// Source Code
#include <bits/stdc++.h>

using namespace std;
// Modulus applied to the terms of the final sum (1e9+7).
long long mod = 1000000007LL;

// Adjacency lists of the two input trees, indexed by node id.
// Every entry is (neighbour, edge weight).
std::vector<std::pair<int,int>> adj1[100001];
std::vector<std::pair<int,int>> adj2[100001];

// biK[b][v]: number of nodes of tree K for which bit b of the xor of edge
// weights accumulated by preComputeK equals v.
int bi1[30][2] = {};
int bi2[30][2] = {};

// ansK[node][b][v]: per-node, per-bit counts for tree K, filled by iterateK.
int ans1[100001][30][2];
int ans2[100001][30][2];
void preCompute1(int i,int pr,int cur){
    for(int j = 0;j<30;j++){
        if(cur&(1<<j))bi1[j][1]++;
        else bi1[j][0]++;
    }
    for(auto j:adj1[i]){
        if(j.first!=pr){
            preCompute1(j.first,i,cur^j.second);
        }
    }
}
void preCompute2(int i,int pr,int cur){
    for(int j = 0;j<30;j++){
        if(cur&(1<<j))bi2[j][1]++;
        else bi2[j][0]++;
    }
    for(auto j:adj2[i]){
        if(j.first!=pr){
            preCompute2(j.first,i,cur^j.second);
        }
    }
}
// Fills ans1[i][bit][v] for node i and, recursively, for every node below it
// in the first tree (adj1).
//   i  - node being filled
//   pr - node we arrived from; its ans1 rows must already be filled
//   c  - weight of the edge (pr, i); 0 for the starting call
// Node 1 is special-cased: its rows are derived from the global tallies in
// bi1 instead of from a parent. For every other node the parent's two counts
// for a bit are kept as-is when that bit of c is 0 and swapped when it is 1.
void iterate1(int i,int pr,int c){
    for(int j = 0;j<30;j++){
        // Bit j of the edge weight normalised to 0/1. Indexing with the raw
        // mask ((1<<j)&c) would yield 1<<j and run past the size-2 dimension.
        const int flip = (c>>j)&1;
        if(i==1){
            ans1[i][j][1]=bi1[j][1^flip];
            ans1[i][j][0]=bi1[j][flip];
        }else{
            ans1[i][j][1]=ans1[pr][j][1^flip];
            ans1[i][j][0]=ans1[pr][j][flip];
        }
    }
    for(const auto& edge:adj1[i]){
        if(edge.first!=pr){
            iterate1(edge.first,i,edge.second);
        }
    }
}

// Fills ans2[i][bit][v] for node i and, recursively, for every node below it
// in the second tree (adj2).
//   i  - node being filled
//   pr - node we arrived from; its ans2 rows must already be filled
//   c  - weight of the edge (pr, i); 0 for the starting call
// Node 1 is special-cased: its rows are derived from the global tallies in
// bi2 instead of from a parent. For every other node the parent's two counts
// for a bit are kept as-is when that bit of c is 0 and swapped when it is 1.
void iterate2(int i,int pr,int c){
    for(int j = 0;j<30;j++){
        // Bit j of the edge weight normalised to 0/1. Indexing with the raw
        // mask ((1<<j)&c) would yield 1<<j and run past the size-2 dimension.
        const int flip = (c>>j)&1;
        if(i==1){
            ans2[i][j][1]=bi2[j][1^flip];
            ans2[i][j][0]=bi2[j][flip];
        }else{
            ans2[i][j][1]=ans2[pr][j][1^flip];
            ans2[i][j][0]=ans2[pr][j][flip];
        }
    }
    for(const auto& edge:adj2[i]){
        if(edge.first!=pr){
            iterate2(edge.first,i,edge.second);
        }
    }
}
// Reads two weighted trees on the same node set 1..n from standard input,
// builds the per-node/per-bit counts for both trees, and prints the weighted
// sum of the mixed products of those counts, reduced modulo `mod`.
//
// Input: n, then n-1 lines "a b c" (edge a-b with weight c) for the first
// tree, then n-1 such lines for the second tree.
//
// NOTE(review): the formula below multiplies counts taken independently from
// the two trees for the same node; confirm against the problem statement that
// this is the quantity being asked for.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0;
    if(!(cin>>n)) return 0;  // no node count: nothing to compute

    for(int e = 0;e<n-1;e++){
        long long a = 0,b = 0,c = 0;
        cin>>a>>b>>c;
        adj1[a].push_back({static_cast<int>(b),static_cast<int>(c)});
        adj1[b].push_back({static_cast<int>(a),static_cast<int>(c)});
    }
    for(int e = 0;e<n-1;e++){
        long long a = 0,b = 0,c = 0;
        cin>>a>>b>>c;
        adj2[a].push_back({static_cast<int>(b),static_cast<int>(c)});
        adj2[b].push_back({static_cast<int>(a),static_cast<int>(c)});
    }

    // Both trees are processed starting from node 1 with no parent (0) and an
    // empty xor.
    preCompute1(1,0,0);
    preCompute2(1,0,0);
    iterate1(1,0,0);
    iterate2(1,0,0);

    long long fina = 0;
    for(int i = 0;i<30;i++){
        const long long weight = (1LL<<i)%mod;
        for(int j=1;j<=n;j++){
            // Widen to 64 bits before multiplying: the products reach ~1e10
            // for n around 1e5 and would overflow a 32-bit int.
            const long long ones1 = ans1[j][i][1];
            const long long ones2 = ans2[j][i][1];
            const long long mixed = ((n-ones1)*ones2 + ones1*(n-ones2))%mod;
            fina = (fina + mixed*weight)%mod;
        }
    }
    // fina is kept reduced throughout, so the printed value is already < mod.
    cout<<fina<<"\n";
    return 0;
}
// Submission info (copied from the judge page):
//   Trees xOr  Ahmed57
//   GNU G++17
//   5 ms
//   4.9 MB
//   Wrong Answer