// Source Code
#include "bits/stdc++.h"

using namespace std;

// Competitive-programming shorthands (macro bodies are expanded textually).
#define ll long long                            // 64-bit signed integer type
#define INF 0x3f3f3f3f3f3f3f3fLL                // large 64-bit sentinel value
#define endl "\n"                               // plain newline, no std::endl flush
#define all(v) v.begin(), v.end()               // full iterator range of a container
#define sz(s) (int) (s.size())                  // container size as int
#define watch(x) cout << (#x) << " = " << x << endl  // debug print: "expr = value"

// 8-neighbourhood offsets (dr[i], dc[i]), listed clockwise starting from (-1, 0),
// i.e. "up" when row indices grow downward.
const int dr[] = {-1, -1, 0, 1, 1, 1, 0, -1};
const int dc[] = {0, 1, 1, 1, 0, -1, -1, -1};

// Sets up fast C++ stream I/O: unsyncs iostreams from C stdio and unties cin from cout.
// In local builds (ONLINE_JUDGE not defined) stdin is redirected to "input.txt", but
// only if that file can actually be opened. A failed freopen still closes the original
// stdin first, which would leave a local run with no input at all.
void run() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // cout is not tied to any stream by default, so no cout.tie(...) call is needed.
#ifndef ONLINE_JUDGE
    if (std::FILE *probe = std::fopen("input.txt", "r")) {
        std::fclose(probe);
        std::freopen("input.txt", "r", stdin);
    }
#endif
}

// Propagates XOR-of-edge-weights distances through a tree, starting at `node`.
// For every vertex v reached (other than the start), dist[v] = dist[parent(v)] ^ w,
// where w is the weight of the edge to its parent. dist[node] itself is read but
// never written, so the caller must set it (e.g. 0 for the root).
//   node - start vertex; par - vertex to treat as node's parent (-1 for none)
//   adj  - adjacency list of (neighbour, weight) pairs; assumed to describe a tree
//   dist - output array indexed by vertex, sized to hold every vertex in adj
// Uses an explicit stack instead of recursion. On a path-shaped tree the depth would be
// about the vertex count (1e5 here), which can overflow small default thread stacks.
void dfs(int node, int par, vector<vector<pair<int, int>>> &adj, int *dist) {
    vector<pair<int, int>> pending{{node, par}};  // (vertex, its parent)
    while (!pending.empty()) {
        const auto [cur, from] = pending.back();
        pending.pop_back();
        for (const auto &[ch, w] : adj[cur]) {
            if (ch == from) continue;
            dist[ch] = dist[cur] ^ w;
            pending.push_back({ch, cur});
        }
    }
}

// Reads the n - 1 edges of a weighted tree from cin as 1-indexed "u v w" triples,
// stores them in an undirected adjacency list over 0-indexed vertices, and calls dfs
// from vertex 0. That fills dist[v] with the XOR of edge weights on the 0 -> v path.
// The list gets n + 1 slots, so graph[0] exists even when n == 0.
// NOTE(review): dist[0] is never written here; it relies on the array already holding 0
// (true for the zero-initialized globals that main passes in).
void readTree(int n, int *dist) {
    vector<vector<pair<int, int>>> graph(n + 1);
    for (int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft) {
        int a, b, weight;
        cin >> a >> b >> weight;
        --a, --b;  // input vertices are 1-indexed
        graph[a].emplace_back(b, weight);
        graph[b].emplace_back(a, weight);
    }
    dfs(0, -1, graph, dist);
}

// Returns whether bit `idx` (0 = least significant) of `num` is set; expects idx in [0, 31].
// Shifting the unsigned representation reads the same bit for negative inputs too.
bool getBit(int num, int idx) {
    const unsigned shifted = static_cast<unsigned>(num) >> idx;
    return (shifted & 1u) != 0;
}

// N: capacity of the per-vertex arrays (dp also uses row n, so n must stay below N).
// LG: number of low bits the answer is accumulated over. MOD: answer modulus.
constexpr int N = 100009;
constexpr int LG = 30;
constexpr int MOD = 1000000007;

// dist / dist2: per-vertex XOR-from-vertex-0 values of the first / second tree.
// dp[node][mask][val]: counting table that main rebuilds for every bit.
// n: number of vertices.
int dist[N];
int dist2[N];
int dp[N][1 << 3][2];
int n;

// In-place modular addition: a = (a + b) mod MOD.
// Expects both a and b already in [0, MOD). Then a + b < 2 * MOD, which still fits in
// int, and one conditional subtraction brings the result back into range.
void add(int &a, int b) {
    a += b;
    if (a >= MOD) {
        a -= MOD;
    }
}

int main() {
    run();
    cin >> n;
    readTree(n, dist);
    readTree(n, dist2);
    int res = 0;
    for (int bit = LG - 1; bit >= 0; bit--) {
        memset(dp, 0, sizeof dp);
        dp[n][0][1] = 1;
        for (int node = n - 1; node >= 0; node--) {
            vector<int> values = {
                    getBit(dist[node], bit),
                    getBit(dist[node] ^ dist2[node], bit),
                    getBit(dist2[node], bit)
            };
            for (int mask = 0; mask < (1 << 3); mask++) {
                for (int val = 0; val < 2; val++) {
                    int &rt = dp[node][mask][val];
                    rt = dp[node + 1][mask][val];

                    for (int sub = mask; sub > 0; sub = (sub - 1) & mask) {
                        int nval = val;
                        for (int k = 0; k < 3; k++)
                            if (sub & (1 << k))nval ^= values[k];
                        add(rt, dp[node + 1][mask ^ sub][nval]);
                    }
                }
            }
        }
        int cnt = dp[0][(1 << 3) - 1][0];
        cnt = (cnt * (1LL << bit)) % MOD;
        add(res, cnt);
    }
    cout << res << endl;
}
// Judge submission info: Trees xOr — Ma7moud.7amdy | GNU G++17 | 1354 ms | 13.5 MB | Accepted