// Friday, December 14, 2018
//
// [SPOJ] COT - Count on a Tree

/*
Author            : Dipu Kumar Mohanto
                    CSE, Batch - 6
                    BRUR.
Problem Statement : COT - Count on a Tree
Source            : SPOJ
Category          : Data Structure
Algorithm         : Persistent Segment Tree
Verdict           : Accepted
Persistent Segment Tree + Binary Search + Fast IO = O(n log n + q log^2 n) - TLE
Persistent Segment Tree + Direct Descent + Fast IO = O((n + q) log n)     - AC
Works for duplicate values: equal values receive distinct compressed ranks
(ties are broken by vertex index).
*/
 
#include <bits/stdc++.h>
 
using namespace std;
 
// Capacity of every vertex-indexed array (vertices are numbered 1..tNode).
static constexpr int MAXN = 100005;
// Number of binary-lifting levels stored per vertex in father[][].
static constexpr int LOGN = 19;

// NOTE(review): never read or written anywhere in this file.
std::map<int, int> valTOind;
// Compressed rank -> original vertex value.
int indTOval[MAXN];

// Vertex count, query count, and number of compressed ranks.
int tNode, tQuery, n;
// Per vertex: the input value, later overwritten by main() with its compressed rank.
int nodeVal[MAXN];
// Undirected adjacency lists.
std::vector<int> graph[MAXN];
 
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither '-' nor a digit are skipped. The character that
// terminates the number is consumed.
// Returns 0 if the input ends before a number starts. The skip loop must test
// for EOF explicitly: getchar() returns an int, and storing it in a char
// makes EOF indistinguishable, which previously caused an endless loop.
int readInt()
{
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9'))
        ch = getchar();
    if (ch == EOF) return 0;

    bool negative = false;
    if (ch == '-')
    {
        negative = true;
        ch = getchar();
    }

    int result = 0;
    while (ch >= '0' && ch <= '9')
    {
        result = result * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -result : result;
}
 
// A node of the persistent segment tree.
// `val` is the count stored for the node's range. `left` and `right` may be
// shared between versions, because update() reuses the untouched child of the
// previous version.
struct node
{
    int val;
    node *left, *right;
    node(int val) : val(val), left(nullptr), right(nullptr) {}
} *version[MAXN<<1];  // version[i] is the root of tree version i; version[0] is built by build()
 
// Materializes a complete tree of zero counts over the range [a, b] below
// `root`, allocating fresh child nodes. An empty range (a > b) is a no-op.
void build(node *root, int a, int b)
{
    if (a > b) return;
    if (a < b)
    {
        const int mid = (a + b) >> 1;
        root->left = new node(0);
        root->right = new node(0);
        build(root->left, a, mid);
        build(root->right, mid + 1, b);
        root->val = root->left->val + root->right->val;
    }
    else
    {
        root->val = 0;  // single-position leaf
    }
}
 
// Builds a new version below `root` that equals version `proot` with the count
// at position `pos` increased by one. `root` must be a freshly allocated node
// covering [a, b]. Along the path to `pos`, new nodes are allocated. The
// sibling subtree at each level is shared with `proot`. Positions outside
// [a, b] leave `root` untouched.
void update(node *proot, node *root, int a, int b, int pos)
{
    if (a > b || pos < a || pos > b) return;
    if (a == b)
    {
        // Carry the previous version's count forward so repeated insertions
        // at the same position accumulate instead of resetting to 1.
        root->val = proot->val + 1;
        return;
    }
    int mid = (a+b) >> 1;
    if (pos <= mid)
    {
        root->left = new node(0);
        root->right = proot->right;  // unchanged half is shared
        update(proot->left, root->left, a, mid, pos);
    }
    else
    {
        root->left = proot->left;    // unchanged half is shared
        root->right = new node(0);
        update(proot->right, root->right, mid+1, b, pos);
    }
    root->val = root->left->val + root->right->val;
}
 
/**
int query(node *root, int a, int b, int i, int j)
{
    if (a > b || a > j || b < i)
        return 0;
 
    if (a >= i && b <= j)
        return root->val;
 
    int mid = (a+b)>>1;
 
    int p1 = query(root->left, a, mid, i, j);
    int p2 = query(root->right, mid+1, b, i, j);
 
    return p1+p2;
}
 
int Count(node *root_u, node *root_v, node *root_lca, node *root_plca, int k)
{
    int sum = 0;
        sum += query(root_u, 1, n, 1, k);
        sum += query(root_v, 1, n, 1, k);
        sum -= query(root_lca, 1, n, 1, k);
        sum -= query(root_plca, 1, n, 1, k);
 
    return sum;
}
 
int binarySearch(node *root_u, node *root_v, node *root_lca, node *root_plca, int k)  // l : version[u], r : version[v]
{
    int low = 1;
    int high = n;
    int ans;
 
    while (low <= high)
    {
        int mid = (low+high)>>1;
 
        int cnt = Count(root_u, root_v, root_lca, root_plca, mid);
 
        if (cnt >= k)
        {
            ans = mid;
            high = mid-1;
        }
        else
        {
            low = mid+1;
        }
    }
    return ans;
}
**/
 
// Walks four tree versions in lock-step and returns the compressed rank of the
// k-th smallest value (1-based) on the tree path u..v.
// The number of path values inside a subtree is taken as
//     cnt(u) + cnt(v) - cnt(lca) - cnt(parent(lca)),
// where each cnt comes from the corresponding version root passed in.
// If k exceeds the number of values on the path, the walk keeps turning right
// and returns b.
int query1(node *root_u, node *root_v, node *root_lca, node *root_plca, int a, int b, int k)
{
    while (a != b)
    {
        const int inLeft = root_u->left->val + root_v->left->val
                         - root_lca->left->val - root_plca->left->val;
        const int mid = (a + b) >> 1;
        if (k <= inLeft)
        {
            root_u = root_u->left;
            root_v = root_v->left;
            root_lca = root_lca->left;
            root_plca = root_plca->left;
            b = mid;
        }
        else
        {
            k -= inLeft;
            root_u = root_u->right;
            root_v = root_v->right;
            root_lca = root_lca->right;
            root_plca = root_plca->right;
            a = mid + 1;
        }
    }
    return a;
}
 
int father[MAXN][LOGN];
int depth[MAXN];
 
void dfs(int u, int p = -1)
{
    for (int i = 1; i < LOGN; i++) father[u][i] = father[father[u][i-1]][i-1];
    for (int v : graph[u])
    {
        if (v == p) continue;
        father[v][0] = u;
        depth[v] = depth[u] + 1;
        dfs(v, u);
    }
}
 
// Lowest common ancestor of u and v via binary lifting over father[][] and
// depth[] (filled by dfs). Vertex 0 acts as the parent of the root and has
// depth 0, so no jump ever moves above the root.
int LCA(int u, int v)
{
    int deep = u, shallow = v;
    if (depth[deep] < depth[shallow]) std::swap(deep, shallow);

    // Raise the deeper vertex to the depth of the shallower one.
    for (int step = LOGN - 1; step >= 0; --step)
        if (depth[father[deep][step]] >= depth[shallow])
            deep = father[deep][step];

    if (deep == shallow) return deep;

    // Raise both while their ancestors differ. They end up as children of the LCA.
    for (int step = LOGN - 1; step >= 0; --step)
    {
        const int upDeep = father[deep][step];
        const int upShallow = father[shallow][step];
        if (upDeep != upShallow)
        {
            deep = upDeep;
            shallow = upShallow;
        }
    }
    return father[deep][0];
}
 
int ver[MAXN];
bool vis[MAXN];
 
void bfs(int src)
{
    memset(vis, 0, sizeof vis);
    queue <int> PQ;
    PQ.push(src);
    int vr = 1;
    ver[src] = vr;
    vis[src] = 1;
    version[ver[src]] = new node(0);
    update(version[0], version[ver[src]], 1, n, nodeVal[src]);
    while (!PQ.empty())
    {
        int u = PQ.front(); PQ.pop();
        for (int v : graph[u])
        {
            if (vis[v]) continue;
            vis[v] = 1;
            vr++;
            ver[v] = vr;
            version[ver[v]] = new node(0);
            update(version[ver[u]], version[ver[v]], 1, n, nodeVal[v]);
            PQ.push(v);
        }
    }
 
}
 
// Coordinate-compression record: a vertex's original value and its index.
// Ordered by value, then by index, so equal values still receive distinct
// ranks once sorted.
struct structure
{
    int val, ind;
    structure() {}
    structure(int val, int ind) : val(val), ind(ind) {}
    friend bool operator<(structure A, structure B)
    {
        return std::tie(A.val, A.ind) < std::tie(B.val, B.ind);
    }
} temp[MAXN];  // main() fills and sorts temp[1..tNode]
 
 
int main()
{
    //freopen("in.txt", "r", stdin);
 
    //scanf("%d %d", &tNode, &tQuery);
    tNode = readInt();
    tQuery = readInt();
    for (int i = 1; i <= tNode; i++)
    {
        //scanf("%d", &nodeVal[i]);
        nodeVal[i] = readInt();
        temp[i] = {nodeVal[i], i};
    }
    for (int e = 1; e < tNode; e++)
    {
        int u, v;
        //scanf("%d %d", &u, &v);
        u = readInt();
        v = readInt();
        graph[u].push_back(v);
        graph[v].push_back(u);
    }
    sort(temp+1, temp+tNode+1);
    int id = 0;
    for (int i = 1; i <= tNode; i++)
    {
        int num = temp[i].val;
        int ind = temp[i].ind;
        id++;
        nodeVal[ind] = id;
        indTOval[id] = num;
    }
    n = id;
    version[0] = new node(0);
    build(version[0], 1, n);
    bfs(1);
    depth[1] = 1;
    dfs(1);
    while (tQuery--)
    {
        int u, v, k;
        //scanf("%d %d %d", &u, &v, &k);
        u = readInt();
        v = readInt();
        k = readInt();

        int lca = LCA(u, v);
        int plca = father[lca][0];

        //int ansb = binarySearch(version[ver[u]], version[ver[v]], version[ver[lca]], version[ver[plca]], k);
        int ans = query1(version[ver[u]], version[ver[v]], version[ver[lca]], version[ver[plca]], 1, n, k);
 
        printf("%d\n", indTOval[ans]);
    }
}
 

// No comments:
//
// Post a Comment
//
// Note: Only a member of this blog may post a comment.