// Source Code
#include <bits/stdc++.h>
using namespace std;
 
using ll = long long;
using pii = pair<int, int>;
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()

const int M = 1e9 + 7;  // modulus of the printed answer
const int BITS = 31;    // bit positions examined; assumes edge weights are non-negative ints — TODO confirm against constraints

// Reads n-1 undirected weighted edges "x y w" from stdin and returns an
// adjacency list indexed 1..n (index 0 unused). Containers are sized from n
// at runtime: the previous fixed bound N = 1005 let larger vertex ids index
// out of bounds. Assumes every endpoint lies in [1, n].
static vector<vector<pii>> readTree(int n) {
    vector<vector<pii>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int x, y, w;
        cin >> x >> y >> w;
        adj[x].push_back({y, w});
        adj[y].push_back({x, w});
    }
    return adj;
}

// Walks the tree from vertex 1. For every reached vertex v it sets d[v] to the
// XOR of edge weights on the root-to-v path and increments cnt[bit][j], where
// bit is bit j of d[v]. An explicit stack replaces recursion so that a
// path-shaped tree cannot overflow the call stack.
static void xorFromRoot(const vector<vector<pii>>& adj, vector<int>& d, ll cnt[2][BITS]) {
    vector<pii> stk{{1, 0}};  // (vertex, parent); 0 is never a vertex id
    d[1] = 0;
    while (!stk.empty()) {
        const auto [u, p] = stk.back();
        stk.pop_back();
        for (int j = 0; j < BITS; ++j) cnt[(d[u] >> j) & 1][j]++;
        for (const auto& [v, w] : adj[u]) {
            if (v == p) continue;
            d[v] = d[u] ^ w;
            stk.push_back({v, u});
        }
    }
}

// Reads n and two trees on vertices 1..n. Prints, modulo M, the sum over all
// triples (i, a, b) of xor1(i, a) ^ xor2(i, b), where xorK(x, y) is the XOR of
// edge weights on the x-y path in tree K. Since xorK(x, y) = dK[x] ^ dK[y],
// bit j of a term is bit j of (d1[i] ^ d2[i]) ^ d1[a] ^ d2[b]. This lets each
// bit be counted independently.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0;
    if (!(cin >> n) || n < 1) {  // no vertices: nothing to sum
        cout << 0;
        return 0;
    }

    vector<vector<pii>> adj[2];
    adj[0] = readTree(n);
    adj[1] = readTree(n);

    vector<int> d[2] = {vector<int>(n + 1, 0), vector<int>(n + 1, 0)};
    ll cnt[2][2][BITS] = {};
    for (int t = 0; t < 2; ++t) xorFromRoot(adj[t], d[t], cnt[t]);

    ll ans = 0;
    for (int j = 0; j < BITS; ++j) {
        // Pairs (a, b) whose bit j agrees / differs between tree 1 and tree 2.
        const ll agree  = (cnt[0][0][j] * cnt[1][0][j] % M + cnt[0][1][j] * cnt[1][1][j] % M) % M;
        const ll differ = (cnt[0][0][j] * cnt[1][1][j] % M + cnt[0][1][j] * cnt[1][0][j] % M) % M;

        // If bit j of d1[i] ^ d2[i] is 1, the term has bit j set exactly when a
        // and b agree on bit j. Otherwise it has bit j set when they differ.
        ll ones = 0;
        for (int i = 1; i <= n; ++i) ones += ((d[0][i] ^ d[1][i]) >> j) & 1;
        const ll zeros = n - ones;

        const ll pairs = (ones % M * agree + zeros % M * differ) % M;
        ans = (ans + (1LL << j) % M * pairs) % M;
    }
    cout << ans;
}
// Problem: Trees xOr — submitted by ahmet23
// Compiler: GNU G++17
// Time: 6 ms, Memory: 504 KB
// Verdict reported for this submission: Runtime Error