// Problem Link : EvenOddity
// Category : Data Structure, Counting, Bits, Centroid Decomposition
// Contest : December Circuits'23
// Tutorial : Counting Subarrays
#include "bits/stdc++.h"
using namespace std;
// Every later `int` in this file expands to `long long int`, so all counters,
// node ids and edge weights are 64-bit. This is why the entry point below is
// spelled `signed main`.
#define int long long int
// `endl` expands to a plain newline character, so streaming it does not flush.
#define endl '\n'
// Size bound (1e6 + 5); not referenced anywhere in the visible code.
const int maxn = 1e6 + 5;
// One adjacency-list entry: the neighbouring vertex and the weight of the
// edge leading to it. Both fields default to 0.
struct Node {
    int v;  // neighbour vertex id
    int w;  // weight of the connecting edge

    Node(int to = 0, int weight = 0) : v(to), w(weight) {}
};
// Number of nodes in the current test case (read in main).
int n;
// Adjacency list: graph[u] holds one Node {neighbour, edge weight} per edge
// incident to u. main inserts every edge in both directions.
vector<vector<Node>> graph;
// subtree[u]: size of the subtree rooted at u, as last written by
// calcSubtreeSize (only valid for the component it was last run on).
vector<int> subtree;
// isCentroid[u]: true once u has been picked as a centroid by decompose;
// such nodes are skipped by every traversal, which splits the tree.
vector<bool> isCentroid;
// Node count of the component currently being decomposed (set in decompose,
// read by getCentroid).
int curTreeNodes;
void calcSubtreeSize(int u, int p) {
subtree[u] = 1;
for (Node it : graph[u]) {
int v = it.v;
if (v == p or isCentroid[v]) {
continue;
}
calcSubtreeSize(v, u);
subtree[u] += subtree[v];
}
}
int getCentroid(int u, int p) {
for (Node it : graph[u]) {
int v = it.v;
if (v == p or isCentroid[v]) {
continue;
}
if (subtree[v] > curTreeNodes / 2) {
return getCentroid(v, u);
}
}
return u;
}
// cache[d][b]: number of nodes already merged for the current centroid whose
// distance (in edges) from the centroid has parity d and whose path XOR has
// popcount parity b. decompose seeds cache[0][0] = 1 for the centroid itself.
int cache[2][2]; // cache[ distance % 2 ][ parity(xorSum) % 2 ]
// Same layout as cache, but only for the child subtree currently being walked
// by add; decompose folds it into cache once that subtree is finished.
int curCache[2][2];
// Running count of matched pairs for the current test case; main prints ans * 2.
int ans;
void add(int u, int p, int lvl, int xorSum) {
int bit = lvl & 1;
int parity = __builtin_popcount(xorSum) & 1;
int curAns = cache[bit ^ 1][parity];
ans += curAns;
curCache[bit][parity]++;
for (Node it : graph[u]) {
int v = it.v;
int w = it.w;
if (v == p or isCentroid[v]) {
continue;
}
add(v, u, lvl + 1, xorSum ^ w);
}
}
void decompose(int u) {
calcSubtreeSize(u, 0);
curTreeNodes = subtree[u];
int centroid = getCentroid(u, 0);
isCentroid[centroid] = true;
memset(cache, 0, sizeof cache);
cache[0][0] = 1;
for (Node it : graph[centroid]) {
int child = it.v;
if (isCentroid[child]) {
continue;
}
memset(curCache, 0, sizeof curCache);
add(child, centroid, 1, it.w);
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
cache[i][j] += curCache[i][j];
}
}
}
for (Node it : graph[centroid]) {
int child = it.v;
if (isCentroid[child]) {
continue;
}
decompose(child);
}
}
signed main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cout.precision(12);
bool FILEIO = 1;
if (FILEIO and fopen("f2.txt", "r")) {
freopen("f2.txt", "r", stdin);
freopen("f2out.txt", "w", stdout);
}
int tc;
cin >> tc;
for (int tcase = 1; tcase <= tc; tcase++) {
cin >> n;
graph.clear(); graph.resize(n + 1);
subtree.clear(); subtree.resize(n + 1);
isCentroid.clear(); isCentroid.resize(n + 1);
for (int e = 1; e < n; e++) {
int u, v, w;
cin >> u >> v >> w;
graph[u].emplace_back(v, w);
graph[v].emplace_back(u, w);
}
ans = 0;
decompose(1);
cout << ans * 2 << endl;
}
}