Wednesday, September 25, 2019

[Codemarshal] G. Master Cat Viper and Programming

Author            : Dipu Kumar Mohanto 
                    CSE, Batch - 6
                    BRUR.
Problem Statement : G. Master Cat Viper and Programming
Source            : Codemarshal
                    City University IUPC 2017
Category          : Data Structure, Tree
Algorithm         : MO's Algorithm on Tree
Verdict           : Accepted

  1. #include "bits/stdc++.h"  
  2.   
  3. using namespace std;  
  4.   
  5. static const int maxn = 50000 * 2 + 5;  
  6. static const int block = 320;  
  7. static const int logn = 18;  
  8. static const long long mod = 1000000007;  
  9.   
  10.   
  11. vector < vector <int> > graph;  
  12.   
  13. int nodeList[maxn];  
  14. int in[maxn];  
  15. int out[maxn];  
  16. int dfsTime;  
  17. int pointer;  
  18.   
  19. int father[maxn][logn];  
  20. int depth[maxn];  
  21.   
  22. void dfs(int u, int p = -1)  
  23. {  
  24.       in[u] = ++dfsTime;  
  25.       ++pointer;  
  26.       nodeList[pointer] = u;  
  27.       for (int i = 1; i < logn; i++) father[u][i] = father[ father[u][i - 1] ][i - 1];  
  28.       for (int v : graph[u])  
  29.       {  
  30.             if (v == p) continue;  
  31.             depth[v] = depth[u] + 1;  
  32.             father[v][0] = u;  
  33.             dfs(v, u);  
  34.       }  
  35.       ++dfsTime;  
  36.       ++pointer;  
  37.       out[u] = dfsTime;  
  38.       nodeList[pointer] = u;  
  39. }  
  40.   
  41. int findLca(int u, int v)  
  42. {  
  43.       if (depth[u] < depth[v]) swap(u, v);  
  44.       for (int i = logn - 1; i >= 0; i--)  
  45.       {  
  46.             if (depth[ father[u][i] ] >= depth[v]) u = father[u][i];  
  47.       }  
  48.       if (u == v) return u;  
  49.       for (int i = logn - 1; i >= 0; i--)  
  50.       {  
  51.             if (father[u][i] != father[v][i])  
  52.             {  
  53.                   u = father[u][i];  
  54.                   v = father[v][i];  
  55.             }  
  56.       }  
  57.       return father[u][0];  
  58. }  
  59.   
  60. struct Mo  
  61. {  
  62.       int l;  
  63.       int r;  
  64.       int id;  
  65.       int lca;  
  66.       int d;  
  67.       Mo(int l = 0, int r = 0, int id = 0, int lca = 0, int d = 0)  
  68.             : l(l)  
  69.             , r(r)  
  70.             , id(id)  
  71.             , lca(lca)  
  72.             , d(d) {}  
  73.   
  74.       friend bool operator < (Mo p, Mo q)  
  75.       {  
  76.             int pb = p.l / block;  
  77.             int qb = q.l / block;  
  78.             if (pb != qb) return p.l < q.l;  
  79.             return (p.r < q.r) ^ (p.l / block % 2);  
  80.       }  
  81. } queries[maxn];  
  82.   
  83. long long arr[maxn];  
  84. int counter[maxn];  
  85. long long sum;  
  86. long long mul;  
  87. long long sumsq;  
  88. long long ans[maxn];  
  89.   
  90. int l;  
  91. int r;  
  92.   
  93. void add(int pos)  
  94. {  
  95.       int node = nodeList[pos];  
  96.       long long x = arr[node];  
  97.       counter[node]++;  
  98.       if (counter[node] == 1)  
  99.       {  
  100.             mul = (mul + (x * sum) % mod) % mod;  
  101.             sum = (sum + x) % mod;  
  102.             sumsq = (sumsq + (x * x) % mod) % mod;  
  103.       }  
  104.       else if (counter[node] == 2)  
  105.       {  
  106.             sum = (sum - x + mod) % mod;  
  107.             mul = (mul - (sum * x) % mod + mod) % mod;  
  108.             sumsq = (sumsq - (x * x) % mod + mod) % mod;  
  109.       }  
  110. }  
  111.   
  112. void remov(int pos)  
  113. {  
  114.       int node = nodeList[pos];  
  115.       long long x = arr[node];  
  116.       counter[node]--;  
  117.       if (counter[node] == 1)  
  118.       {  
  119.             mul = (mul + (x * sum) % mod) % mod;  
  120.             sum = (sum + x) % mod;  
  121.             sumsq = (sumsq + (x * x) % mod) % mod;  
  122.       }  
  123.       else if (counter[node] == 0)  
  124.       {  
  125.             sum = (sum - x + mod) % mod;  
  126.             mul = (mul - (sum * x) % mod + mod) % mod;  
  127.             sumsq = (sumsq - (x * x) % mod + mod) % mod;  
  128.       }  
  129. }  
  130.   
  131. void clean()  
  132. {  
  133.       dfsTime = 0;  
  134.       pointer = 0;  
  135.       for (int i = 0; i < maxn; i++)  
  136.       {  
  137.             nodeList[i] = 0;  
  138.             in[i] = 0;  
  139.             out[i] = 0;  
  140.             depth[i] = 0;  
  141.             counter[i] = 0;  
  142.             for (int j = 0; j < logn; j++) father[i][j] = 0;  
  143.       }  
  144. }  
  145.   
  146. int main()  
  147. {  
  148.       ios_base::sync_with_stdio(false);  
  149.       cin.tie(nullptr);  
  150.       cout.tie(nullptr);  
  151.   
  152. //      #ifndef ONLINE_JUDGE  
  153. //            freopen("in.txt", "r", stdin);  
  154. //      #endif // ONLINE_JUDGE  
  155.   
  156.       int tc;  
  157.       cin >> tc;  
  158.       for (int tcase = 1; tcase <= tc; tcase++)  
  159.       {  
  160.             int n, q;  
  161.             cin >> n >> q;  
  162.             graph.clear();  
  163.             graph.resize(n + 1);  
  164.             for (int i = 1; i <= n; i++) arr[i] = i;  
  165.             int root = -1;  
  166.             for (int u = 1; u <= n; u++)  
  167.             {  
  168.                   int v;  
  169.                   cin >> v;  
  170.                   if (v == 0)  
  171.                   {  
  172.                         root = u;  
  173.                         continue;  
  174.                   }  
  175.                   graph[u].push_back(v);  
  176.                   graph[v].push_back(u);  
  177.             }  
  178.             for (int i = 1; i <= n; i++) sort(graph[i].begin(), graph[i].end());  
  179.             depth[root] = 1;  
  180.             dfs(root);  
  181.             for (int i = 1; i <= q; i++)  
  182.             {  
  183.                   int u, v;  
  184.                   cin >> u >> v;  
  185.                   if (in[u] > in[v]) swap(u, v);  
  186.                   int lca = findLca(u, v);  
  187.                   int d = depth[u] + depth[v] - 2 * depth[lca];  
  188.                   if (lca == u)  
  189.                   {  
  190.                         int x = in[u];  
  191.                         int y = in[v];  
  192.                         queries[i].l = x;  
  193.                         queries[i].r = y;  
  194.                         queries[i].id = i;  
  195.                         queries[i].lca = 0;  
  196.                         queries[i].d = d;  
  197.                   }  
  198.                   else  
  199.                   {  
  200.                         int x = min(out[u], in[v]);  
  201.                         int y = max(out[u], in[v]);  
  202.                         queries[i].l = x;  
  203.                         queries[i].r = y;  
  204.                         queries[i].id = i;  
  205.                         queries[i].lca = lca;  
  206.                         queries[i].d = d;  
  207.                   }  
  208.             }  
  209.             sort(queries + 1, queries + q + 1);  
  210.             l = 1;  
  211.             r = 0;  
  212.             sum = 0;  
  213.             mul = 0;  
  214.             sumsq = 0;  
  215.             for (int i = 1; i <= q; i++)  
  216.             {  
  217.                   int L = queries[i].l;  
  218.                   int R = queries[i].r;  
  219.                   int id = queries[i].id;  
  220.                   int lca = queries[i].lca;  
  221.                   int d = queries[i].d;  
  222.                   while (l > L) add(--l);  
  223.                   while (r < R) add(++r);  
  224.                   while (l < L) remov(l++);  
  225.                   while (r > R) remov(r--);  
  226.                   long long newmul = mul;  
  227.                   long long newsum = sum;  
  228.                   long long newsumsq = sumsq;  
  229.                   if (lca > 0)  
  230.                   {  
  231.                         newmul = (newmul + (arr[lca] * newsum) % mod) % mod;  
  232.                         newsumsq = (newsumsq + (arr[lca] * arr[lca]) % mod) % mod;  
  233.                   }  
  234.                   long long res = (1LL * 2 * newmul) % mod;  
  235.                   res = (res + (1LL * d * newsumsq) % mod) % mod;  
  236.                   ans[id] = res;  
  237.             }  
  238.             cout << "Case " << tcase << ":\n";  
  239.             for (int i = 1; i <= q; i++) cout << ans[i] << '\n';  
  240.             clean();  
  241.       }  
  242. }  

No comments:

Post a Comment

Note: Only a member of this blog may post a comment.