// Source Code
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used throughout the file (a type alias rather than
// a macro, so it obeys scoping and cannot rewrite unrelated tokens).
using ll = long long;
// Capacity of the factorial tables below.
constexpr int MAXN = 1000000;
// F[i] = i! mod `mod` and iF[i] = (i!)^{-1} mod `mod`; main() fills them in.
ll F[MAXN];
ll iF[MAXN];
// Prime modulus for all arithmetic; constexpr so it can never be reassigned.
constexpr ll mod = 998244353;
// Modular exponentiation: returns a^b mod `mod` by recursive squaring
// (intended for b >= 0).
// The base is first reduced into [0, mod) so that the 64-bit products
// x*x and x*a below stay below 2^63 even for large or negative `a`.
ll P(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;
    if (b == 0) return 1;
    ll x = P(a, b/2);
    x = (x*x)%mod;
    if (b%2) x = (x*a)%mod;
    return x;
}
// Modular inverse of x via Fermat's little theorem: since `mod` is prime,
// x^(mod-2) is x^(-1) modulo `mod` for any x not divisible by `mod`.
ll inv(ll x) {
    const ll fermatExponent = mod - 2;
    const ll inverse = P(x, fermatExponent);
    return inverse;
}
// Binomial coefficient "n choose m" modulo `mod`, read from the factorial
// tables F and iF (only valid for indices that main() has filled in).
// Out-of-range arguments (n < 0, m < 0, or m > n) yield 0 instead of
// indexing the tables at a negative offset, which is undefined behavior.
ll C(ll n, ll m){
    if (n < 0 || m < 0 || m > n) return 0;
    return (((F[n]*iF[n-m])%mod)*iF[m])%mod;
}
int main(){
    ios_base::sync_with_stdio(0);
    F[0] = 1;
    iF[0] = 1;
    for (int i=1;i<=5000;i++) {
        F[i] = (i*F[i-1])%mod;
        iF[i] = inv(F[i]);
    }
    int n, m;
    cin>>n>>m;
    string a,b;
    cin>>a>>b;
    int cnt[2] = {0, 0};
    int buckets[2] = {0, 0};
    for (auto x:a) cnt[x-'0']++;
    ll Q = (iF[cnt[0]] * iF[cnt[1]])%mod;
    Q = (Q*F[n])%mod;
    for (auto x:b) cnt[x-'0']--, buckets[1-(x-'0')]++;
    if (cnt[0] < 0 || cnt[1] < 0){
        cout<<0<<endl;
        return 0;
    }
    ll p = 0;
    for (int i=0;i<=cnt[0];i++){
        for (int j=0;j<=cnt[1];j++) {
            if (buckets[0] == 0 && i>0) continue;
            if (buckets[1] == 0 && j>0) continue;
            ll afterOnes = cnt[1] - j;
            ll afterZeros = cnt[0] - i;
            ll afterCnt = (F[afterOnes+afterZeros]*iF[afterOnes])%mod;
            afterCnt = (afterCnt*iF[afterZeros])%mod;
            ll beforeCnt = (C(i+buckets[0]-1, buckets[0]-1) * C(j+buckets[1]-1, buckets[1]-1))%mod;
            p = (p+(beforeCnt*afterCnt)%mod)%mod;
        }
    }
    ll ret = (p*inv(Q))%mod;
    cout<<ret<<endl;
}

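// Submission details (copied from the judge page):
//   Problem: Binary Stack Easy RedNextCentury
//   Compiler: GNU G++17
//   Time: 4 ms, Memory: 712 KB
//   Verdict: Wrong Answer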