// Source Code
#include <bits/stdc++.h>

using namespace std;
// Modulus for all modular arithmetic in this program (the prime 1e9 + 7).
long long mod = 1000000007;

// Returns a^b modulo `mod` using iterative binary exponentiation.
//
// The base is reduced modulo `mod` before any multiplication, so every product
// below stays under mod * mod (< 2^63) and cannot overflow even when `a` is
// large. Expects b >= 0; b == 0 yields 1.
long long fast(long long a,long long b){
    long long result = 1;
    a %= mod;
    while(b > 0){
        if(b & 1){
            result = (result * a) % mod;
        }
        a = (a * a) % mod;
        b >>= 1;
    }
    return result;
}
// Adjacency lists of the two input trees: adjK[v] holds (neighbour, edge weight).
vector<pair<long long,long long>> adj1[100001];
vector<pair<long long,long long>> adj2[100001];

// biK[j][b]: number of vertices of tree K whose xor of edge weights from the
// DFS start vertex has bit j equal to b (filled by preCompute1 / preCompute2).
long long bi1[30][2] = {};
long long bi2[30][2] = {};

// ansK[v][j][b]: number of vertices u of tree K for which bit j of the xor of
// edge weights along the path v..u equals b (filled by iterate1 / iterate2).
long long ans1[100001][30][2];
long long ans2[100001][30][2];
// Visits every vertex reachable from `start` without stepping back to
// `startParent`, assigning each one the xor of edge weights from `start`
// (with `start` itself getting `startXor`). For each such value and each bit
// j in [0, 30), increments bits[j][1] if bit j is set and bits[j][0] if it is
// clear; counts are added on top of whatever `bits` already holds.
// Uses an explicit stack instead of recursion, so a path-shaped tree of 1e5
// vertices cannot exhaust the call stack.
static void countXorBits(const vector<pair<long long,long long>> adj[],
                         long long bits[][2],
                         int start, int startParent, long long startXor){
    struct Frame { int vertex; int parent; long long value; };
    vector<Frame> todo;
    todo.push_back({start, startParent, startXor});
    while(!todo.empty()){
        const Frame f = todo.back();
        todo.pop_back();
        for(int j = 0; j < 30; j++){
            // Read bit j with a shift; a modular power of two is not a general bit mask.
            bits[j][(f.value >> j) & 1]++;
        }
        for(const auto& edge : adj[f.vertex]){
            if(edge.first != f.parent){
                todo.push_back({static_cast<int>(edge.first), f.vertex, f.value ^ edge.second});
            }
        }
    }
}

// Tree 1: accumulate per-bit counts of xor-from-start values into bi1.
void preCompute1(int i,int pr,long long cur){
    countXorBits(adj1, bi1, i, pr, cur);
}

// Tree 2: accumulate per-bit counts of xor-from-start values into bi2.
void preCompute2(int i,int pr,long long cur){
    countXorBits(adj2, bi2, i, pr, cur);
}
// Fills ans[v] for every vertex v reachable from `start` without stepping back
// to `startParent`. The tree is walked iteratively with an explicit stack, so a
// path-shaped tree of 1e5 vertices cannot exhaust the call stack.
//
// For a vertex v entered over an edge of weight w, and each bit j in [0, 30):
//   - if v == 1 the source counts are rootBits[j], otherwise ans[parent(v)][j];
//   - if bit j of w is set, the source's 0/1 counts are swapped, because
//     xor(path v..u) == w ^ xor(path parent(v)..u) for every vertex u.
// A parent is always fully written before any of its children is popped.
static void propagatePathBits(const vector<pair<long long,long long>> adj[],
                              const long long rootBits[][2],
                              long long ans[][30][2],
                              int start, int startParent, long long startWeight){
    struct Frame { int vertex; int parent; long long weight; };
    vector<Frame> todo;
    todo.push_back({start, startParent, startWeight});
    while(!todo.empty()){
        const Frame f = todo.back();
        todo.pop_back();
        for(int j = 0; j < 30; j++){
            // Read bit j with a shift; a modular power of two is not a general bit mask.
            const int flip = static_cast<int>((f.weight >> j) & 1);
            const long long* src = (f.vertex == 1) ? rootBits[j] : ans[f.parent][j];
            ans[f.vertex][j][1] = src[flip ^ 1];
            ans[f.vertex][j][0] = src[flip];
        }
        for(const auto& edge : adj[f.vertex]){
            if(edge.first != f.parent){
                todo.push_back({static_cast<int>(edge.first), f.vertex, edge.second});
            }
        }
    }
}

// Tree 1: when started at vertex 1 after preCompute1, ans1[v][j][b] becomes the
// number of vertices u whose path xor from v has bit j equal to b.
void iterate1(int i,int pr,long long c){
    propagatePathBits(adj1, bi1, ans1, i, pr, c);
}

// Tree 2: same as iterate1, using adj2 / bi2 / ans2.
void iterate2(int i,int pr,long long c){
    propagatePathBits(adj2, bi2, ans2, i, pr, c);
}
// Entry point: reads n and then two trees (n - 1 lines "a b c" each, vertices
// 1..n), builds the per-vertex path-xor bit counts for both trees, and prints
//   sum over bits b in [0, 30) and vertices v in [1, n] of
//     2^b * (ans1[v][b][1] * ans2[v][b][0] + ans2[v][b][1] * ans1[v][b][0])
// modulo `mod`. Only this single number is written to stdout.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // The count arrays are globals with static storage, so they already start
    // at zero; clearing them with memset would only touch every page of the
    // large ans1/ans2 tables for nothing.
    int n = 0;
    cin>>n;

    // Reads n - 1 undirected weighted edges into the given adjacency array.
    auto readTree = [n](vector<pair<long long,long long>> adj[]){
        for(int e = 0; e < n - 1; e++){
            long long a, b, c;
            cin >> a >> b >> c;
            adj[a].push_back({b, c});
            adj[b].push_back({a, c});
        }
    };
    readTree(adj1);
    readTree(adj2);

    preCompute1(1,0,0);
    preCompute2(1,0,0);
    iterate1(1,0,0);
    iterate2(1,0,0);

    long long fina = 0;
    for(int bit = 0; bit < 30; bit++){
        const long long weight = fast(2, bit);  // 2^bit, already < mod for bit < 30
        for(int v = 1; v <= n; v++){
            const long long cross = (ans1[v][bit][1] * ans2[v][bit][0]) % mod
                                  + (ans2[v][bit][1] * ans1[v][bit][0]) % mod;
            fina = (fina + (cross % mod) * weight) % mod;
        }
    }
    cout << fina << "\n";
}
// Submission details (from the judge page):
//   Trees xOr Ahmed57
//   GNU G++17
//   494 ms
//   105.5 MB
//   Runtime Error