Author : Dipu Kumar Mohanto
CSE, Batch - 6
BRUR.
Problem Statement : G. Master Cat Viper and Programming
Source : Codemarshal
City University IUPC 2017
Category : Data Structure, Tree
Algorithm : Mo's Algorithm on Trees
Verdict : Accepted
- #include "bits/stdc++.h"
-
- using namespace std;
-
// ---- Compile-time limits -------------------------------------------------
// Capacity of every per-node / per-position array. The Euler tour stores two
// entries per node, hence 2 * 50000 plus slack (presumably n <= 50000 —
// confirm against the problem constraints).
static constexpr int maxn = 50000 * 2 + 5;
// Mo's block size over Euler-tour positions (about sqrt(maxn)).
static constexpr int block = 320;
// Binary-lifting levels: father[x][k] is the 2^k-th ancestor, k in [0, 18).
static constexpr int logn = 18;
// Modulus applied to every running aggregate and answer.
static constexpr long long mod = 1000000007;

// ---- Tree and Euler-tour state (reset by clean() between test cases) -----
vector<vector<int>> graph;  // undirected adjacency lists, indexed by node

int nodeList[maxn];  // nodeList[pos] = node occupying Euler-tour position pos
int in[maxn];        // position of a node's first (entry) occurrence
int out[maxn];       // position of a node's second (exit) occurrence
int dfsTime;         // last timestamp handed out by dfs()
int pointer;         // last nodeList slot written by dfs()

int father[maxn][logn];  // binary-lifting table; 0 means "no ancestor"
int depth[maxn];         // root has depth 1, node 0 (sentinel) has depth 0
-
- void dfs(int u, int p = -1)
- {
- in[u] = ++dfsTime;
- ++pointer;
- nodeList[pointer] = u;
- for (int i = 1; i < logn; i++) father[u][i] = father[ father[u][i - 1] ][i - 1];
- for (int v : graph[u])
- {
- if (v == p) continue;
- depth[v] = depth[u] + 1;
- father[v][0] = u;
- dfs(v, u);
- }
- ++dfsTime;
- ++pointer;
- out[u] = dfsTime;
- nodeList[pointer] = u;
- }
-
- int findLca(int u, int v)
- {
- if (depth[u] < depth[v]) swap(u, v);
- for (int i = logn - 1; i >= 0; i--)
- {
- if (depth[ father[u][i] ] >= depth[v]) u = father[u][i];
- }
- if (u == v) return u;
- for (int i = logn - 1; i >= 0; i--)
- {
- if (father[u][i] != father[v][i])
- {
- u = father[u][i];
- v = father[v][i];
- }
- }
- return father[u][0];
- }
-
- struct Mo
- {
- int l;
- int r;
- int id;
- int lca;
- int d;
- Mo(int l = 0, int r = 0, int id = 0, int lca = 0, int d = 0)
- : l(l)
- , r(r)
- , id(id)
- , lca(lca)
- , d(d) {}
-
- friend bool operator < (Mo p, Mo q)
- {
- int pb = p.l / block;
- int qb = q.l / block;
- if (pb != qb) return p.l < q.l;
- return (p.r < q.r) ^ (p.l / block % 2);
- }
- } queries[maxn];
-
- long long arr[maxn];
- int counter[maxn];
- long long sum;
- long long mul;
- long long sumsq;
- long long ans[maxn];
-
- int l;
- int r;
-
- void add(int pos)
- {
- int node = nodeList[pos];
- long long x = arr[node];
- counter[node]++;
- if (counter[node] == 1)
- {
- mul = (mul + (x * sum) % mod) % mod;
- sum = (sum + x) % mod;
- sumsq = (sumsq + (x * x) % mod) % mod;
- }
- else if (counter[node] == 2)
- {
- sum = (sum - x + mod) % mod;
- mul = (mul - (sum * x) % mod + mod) % mod;
- sumsq = (sumsq - (x * x) % mod + mod) % mod;
- }
- }
-
- void remov(int pos)
- {
- int node = nodeList[pos];
- long long x = arr[node];
- counter[node]--;
- if (counter[node] == 1)
- {
- mul = (mul + (x * sum) % mod) % mod;
- sum = (sum + x) % mod;
- sumsq = (sumsq + (x * x) % mod) % mod;
- }
- else if (counter[node] == 0)
- {
- sum = (sum - x + mod) % mod;
- mul = (mul - (sum * x) % mod + mod) % mod;
- sumsq = (sumsq - (x * x) % mod + mod) % mod;
- }
- }
-
- void clean()
- {
- dfsTime = 0;
- pointer = 0;
- for (int i = 0; i < maxn; i++)
- {
- nodeList[i] = 0;
- in[i] = 0;
- out[i] = 0;
- depth[i] = 0;
- counter[i] = 0;
- for (int j = 0; j < logn; j++) father[i][j] = 0;
- }
- }
-
- int main()
- {
- ios_base::sync_with_stdio(false);
- cin.tie(nullptr);
- cout.tie(nullptr);
-
-
-
-
-
- int tc;
- cin >> tc;
- for (int tcase = 1; tcase <= tc; tcase++)
- {
- int n, q;
- cin >> n >> q;
- graph.clear();
- graph.resize(n + 1);
- for (int i = 1; i <= n; i++) arr[i] = i;
- int root = -1;
- for (int u = 1; u <= n; u++)
- {
- int v;
- cin >> v;
- if (v == 0)
- {
- root = u;
- continue;
- }
- graph[u].push_back(v);
- graph[v].push_back(u);
- }
- for (int i = 1; i <= n; i++) sort(graph[i].begin(), graph[i].end());
- depth[root] = 1;
- dfs(root);
- for (int i = 1; i <= q; i++)
- {
- int u, v;
- cin >> u >> v;
- if (in[u] > in[v]) swap(u, v);
- int lca = findLca(u, v);
- int d = depth[u] + depth[v] - 2 * depth[lca];
- if (lca == u)
- {
- int x = in[u];
- int y = in[v];
- queries[i].l = x;
- queries[i].r = y;
- queries[i].id = i;
- queries[i].lca = 0;
- queries[i].d = d;
- }
- else
- {
- int x = min(out[u], in[v]);
- int y = max(out[u], in[v]);
- queries[i].l = x;
- queries[i].r = y;
- queries[i].id = i;
- queries[i].lca = lca;
- queries[i].d = d;
- }
- }
- sort(queries + 1, queries + q + 1);
- l = 1;
- r = 0;
- sum = 0;
- mul = 0;
- sumsq = 0;
- for (int i = 1; i <= q; i++)
- {
- int L = queries[i].l;
- int R = queries[i].r;
- int id = queries[i].id;
- int lca = queries[i].lca;
- int d = queries[i].d;
- while (l > L) add(--l);
- while (r < R) add(++r);
- while (l < L) remov(l++);
- while (r > R) remov(r--);
- long long newmul = mul;
- long long newsum = sum;
- long long newsumsq = sumsq;
- if (lca > 0)
- {
- newmul = (newmul + (arr[lca] * newsum) % mod) % mod;
- newsumsq = (newsumsq + (arr[lca] * arr[lca]) % mod) % mod;
- }
- long long res = (1LL * 2 * newmul) % mod;
- res = (res + (1LL * d * newsumsq) % mod) % mod;
- ans[id] = res;
- }
- cout << "Case " << tcase << ":\n";
- for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
- clean();
- }
- }
No comments:
Post a Comment
Note: Only a member of this blog may post a comment.