// Source code
#include<bits/stdc++.h>
using namespace std;
// Random engine seeded from the steady clock; not referenced by any code shown here.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// adj[t][u]: edges of tree t incident to vertex u, each stored as {neighbour, weight}.
std::vector<std::vector<std::vector<int>>> adj[2];

// Capacity of the per-vertex count table; vertex indices must stay below N.
const int N = 100000 + 10;

// cnt[t][u][p]: number of vertices x of tree t whose path u -> x has bit B of
// the XOR of its edge weights equal to p.
long long cnt[2][N][2];

// Bit position (0-based) that dfs/reroot currently test in each edge weight.
int B;
void dfs(int node, int tree, int parity, int parent){
        cnt[tree][0][parity]++;
        for(auto &v : adj[tree][node]){
                int to = v[0];
                int w  = v[1];
                if(to == parent) continue;
                
                int new_parity= parity;
                if(w & (1 << B)) new_parity ^= 1;
                dfs(to, tree, new_parity, node);
        }
}
void reroot(int node, int tree, int parent, bool edge){
        for(int parity= 0; parity < 2; parity++){
                cnt[tree][node][parity]= cnt[tree][parent][parity ^ edge];
        }
        for(auto &v : adj[tree][node]){
                int to = v[0];
                int w  = v[1];
                if(to == parent) continue;
                
                reroot(to, tree, node, w & (1 << B));
        }
}
// Modulus applied to the final answer (10^9 + 7).
constexpr int MOD = 1000000007;
// Reads n and two weighted trees on the same n vertices (n - 1 edges each,
// given as "u v w" with 1-based labels). Then, independently for each bit,
// it sums over every middle vertex m:
//     cnt0[m][0] * cnt1[m][1] + cnt0[m][1] * cnt1[m][0].
// Here cntT[m][p] is the number of vertices whose path from m in tree T has
// that bit's parity equal to p. Each per-bit sum is weighted by 2^bit, and
// the total is printed modulo MOD.
int main(){
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int n;
        // With no vertices there is nothing to count; this also avoids
        // indexing adj[p][0] on an empty adjacency list below.
        if(!(cin >> n) || n <= 0){
                cout << 0 << "\n";
                return 0;
        }

        // Read both trees, converting labels to 0-based and storing each edge
        // in both directions as {neighbour, weight}. used_bits records every
        // bit that appears in some weight of either tree.
        unsigned used_bits = 0;
        for(int p = 0; p < 2; p++){
                adj[p].resize(n);
                for(int i = 0; i + 1 < n; i++){
                        int u, v, w;
                        cin >> u >> v >> w;
                        adj[p][u - 1].push_back({v - 1, w});
                        adj[p][v - 1].push_back({u - 1, w});
                        used_bits |= static_cast<unsigned>(w);
                }
        }

        long long ans = 0;
        // Bits 0..30 cover every non-negative int weight. If a bit is set in
        // no edge of either tree, every path has parity 0 in both trees. Then
        // cnt[*][*][1] == 0 and the bit contributes nothing, so it is skipped.
        for(int bt = 0; bt < 31; bt++){
                if(((used_bits >> bt) & 1u) == 0) continue;

                B = bt;
                memset(cnt, 0, sizeof cnt);
                for(int p = 0; p < 2; p++){
                        // Parity counts as seen from vertex 0 ...
                        dfs(0, p, 0, -1);
                        // ... then propagated to every other vertex.
                        for(auto &e : adj[p][0]){
                                const bool edge = ((e[1] >> bt) & 1) != 0;
                                reroot(e[0], p, 0, edge);
                        }
                }

                const long long bit_value = (1LL << bt) % MOD;
                for(int middle_node = 0; middle_node < n; middle_node++){
                        // Each product is at most n^2, so the sum fits in long long before reduction.
                        const long long pairs =
                                (cnt[0][middle_node][0] * cnt[1][middle_node][1] +
                                 cnt[0][middle_node][1] * cnt[1][middle_node][0]) % MOD;
                        ans = (ans + pairs * bit_value) % MOD;
                }
        }
        cout << ans << "\n";
}

// Trees xOr — maghrabyJr_
// GNU G++17 | 4307 ms | 33.7 MB | Accepted