Source Code
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit integer type used by all modular arithmetic in this file.
using ll = long long;
// Factorial table: F[i] == i! mod `mod`. Filled in main (indices 0..250000 only).
ll F[1000000];
// Inverse-factorial table: iF[i] == (i!)^{-1} mod `mod`. Filled in main (indices 0..250000 only).
ll iF[1000000];
// Prime modulus for all arithmetic. It is never modified, so it is a compile-time
// constant; this also lets the compiler turn every `% mod` into a constant-divisor reduction.
constexpr ll mod = 998244353;
// Computes a^b mod `mod` by iterative binary exponentiation (square-and-multiply).
// Precondition: b >= 0 (P(a, 0) == 1).
// The base is first reduced into [0, mod). Every product is therefore of two values
// below mod (~1e9) and fits in 64 bits, even when `a` is huge or negative.
// The result is always in [0, mod).
ll P(ll a, ll b) {
    ll base = a % mod;
    if (base < 0) base += mod;   // C++ % keeps the sign of the dividend
    ll result = 1;
    while (b > 0) {
        if (b & 1) result = (result * base) % mod;
        base = (base * base) % mod;
        b >>= 1;
    }
    return result;
}
// Modular inverse of x modulo the prime `mod`, via Fermat's little theorem:
// x^(mod-2) * x == 1 (mod mod) whenever x is not a multiple of mod.
// If x is a multiple of mod, the result is 0, which is not a true inverse.
ll inv(ll x) {
    const ll fermatExponent = mod - 2;
    return P(x, fermatExponent);
}
// Binomial coefficient C(n, m) mod `mod`, read from the factorial tables F / iF.
//  - Returns 1 when n == m. This includes C(-1, -1), which main's
//    C(items + buckets - 1, buckets - 1) calls produce for zero items in zero buckets.
//  - Returns 0 when m < 0 or m > n. Previously these cases indexed iF out of bounds;
//    e.g. main calls C(k - 1, -1) for k > 0 whenever buckets[1] == 0.
// Precondition: 0 <= n <= 250000 otherwise (the range of F/iF that main fills).
ll C(ll n, ll m){
    if (n == m) return 1;
    if (m < 0 || m > n) return 0;
    return (((F[n] * iF[n - m]) % mod) * iF[m]) % mod;
}
int main(){
    ios_base::sync_with_stdio(0);
    F[0] = 1;
    iF[0] = 1;
    for (int i=1;i<=250000;i++) {
        F[i] = (i*F[i-1])%mod;
        iF[i] = inv(F[i]);
    }
    int n, m;
    cin>>n>>m;
    string a,b;
    cin>>a>>b;
    int cnt[2] = {0, 0};
    int buckets[2] = {0, 0};
    for (auto x:a) cnt[x-'0']++;
    ll Q = (iF[cnt[0]] * iF[cnt[1]])%mod;
    Q = (Q*F[n])%mod;
    for (auto x:b) cnt[x-'0']--, buckets[1-(x-'0')]++;
    if (cnt[0] < 0 || cnt[1] < 0){
        cout<<0<<endl;
        return 0;
    }
    ll p = 0;
    buckets[0]++;
    int numOnes = cnt[1];
    int numZeros = cnt[0];
    while(numOnes>=0){
        ll cnt = (C(numZeros+buckets[0]-1, buckets[0]-1) * C(numOnes+buckets[1]-1, buckets[1]-1))%mod;
        p = (p+cnt)%mod;
        buckets[0]++;
        numOnes--;
    }
    ll ret = (p*inv(Q))%mod;
    cout<<ret<<endl;
}
Copy
Binary Stack Hard RedNextCentury
GNU G++17
183 ms
5.1 MB
Accepted