// Source Code
#include "bits/stdc++.h"
using namespace std;
// Short type aliases used throughout the solution.
using ll = long long;
using pii = pair<int, int>;
using vi = vector<int>;

// Container shorthands.
#define all(v) v.begin(), v.end()
#define pb push_back
#define sz(x) (int)(x).size()

constexpr int N = 300005;        // 3e5 + 5 (not referenced by the code in this file)
constexpr ll mod = 998244353;    // prime modulus for all modular arithmetic below
constexpr int nax = 200005;      // 2e5 + 5: capacity of the factorial tables

// fac[i] = i! and inv_fac[i] = (i!)^{-1} modulo `mod`; zero-initialized until
// the solution fills them at runtime.
int fac[nax];
int inv_fac[nax];
// Modular addition for operands already reduced into [0, mod):
// a single conditional subtraction brings the sum back into range.
inline ll add(ll a, ll b) {
	const ll sum = a + b;
	if (sum >= mod)
		return sum - mod;
	return sum;
}
// Modular multiplication. For operands in [0, mod) the raw product is below
// mod^2 < 2^63, so it fits in a long long before reduction.
inline ll mul(ll a, ll b) {
	const ll product = a * b;
	return product % mod;
}
// Computes a^b modulo `mod` by binary (square-and-multiply) exponentiation.
// The exponent b is expected to be non-negative.
int my_pow(int a, int b) {
	int acc = 1;
	for (int base = a, e = b; e != 0; e /= 2) {
		if (e % 2 != 0)
			acc = mul(acc, base);  // fold in the current bit of the exponent
		base = mul(base, base);
	}
	return acc;
}
// Modular inverse via Fermat's little theorem (998244353 is prime):
// a^{-1} = a^(mod - 2) mod `mod`. Meaningless when a is a multiple of mod.
int my_inv(int a) {
	const int fermat_exponent = mod - 2;
	return my_pow(a, fermat_exponent);
}
// Modular division a / b, i.e. a * b^(mod-2) mod `mod`; b must not be divisible by mod.
inline int d(int a, int b) {
	const int inv_b = my_pow(b, mod - 2);
	return mul(a, inv_b);
}
// Binomial coefficient C(a, b) modulo `mod`, read from the fac / inv_fac
// tables (which must already be filled).
// Returns 0 when b < 0 or b > a, the standard combinatorial convention, so
// callers never index the tables with a negative subscript. `a` must be < nax.
int C(int a, int b) {
	if (b < 0 || b > a)
		return 0;
	return mul(fac[a], mul(inv_fac[b], inv_fac[a - b]));
}
int inv_C(int a, int b) {
	return mul(inv_fac[a], mul(fac[b], fac[a-b]));
}
void solve() {
    
    fac[0] = inv_fac[0] = 1;
	for(int i = 1; i < nax; ++i) {
		fac[i] = mul(fac[i-1], i);
		inv_fac[i] = my_inv(fac[i]);
	}
    
    int n, m;
    string a, s;
    cin >> n >> m >> a >> s;
    auto perm = [&](int a, int b) { 
        return d(fac[a + b], mul(fac[a], fac[b]));
    };

    vi freqs(2), freq(2);
    
    for(int i = 0; i < n; ++i)
        ++freq[a[i] - '0'];
    for(int i = 0; i < m; ++i)
        --freq[s[i] - '0'], ++freqs[s[i] - '0'];

    ll ans = 0;
    for(int z = 0; z <= freq[0]; ++z) {
        for(int o = 0; o <= freq[1]; ++o) {
            int tz = freq[0] - z, to = freq[1] - o;
            ll pr = mul(perm(to, tz), mul(C(freqs[0], z), C(freqs[1], o)));
            ans = add(ans, pr);
        }
    }
    freq = vi(2, 0);
    for(int i = 0; i < n; ++i)
        ++freq[a[i] - '0'];
    printf("%lld", mul(ans, my_inv(perm(freq[0], freq[1]))));
    // ll ans = mul(fac[C(freq[0], freqs[0])], mul(fac[freq[s]]C(freq[1], freqs[1])]);
    // ll zero = mul(C(freq[0], freqs[0]), fac[freqs[0]]);
    // ll one = mul(C(freq[1], freqs[1]), fac[freqs[1]]);
    // ll ans = mul(zero, one);
    // // ll ans = mul(
    // ll de = d(fac[n],  mul(fac[freq[0]], fac[freq[1]]));
    // printf("%lld/%lld\n", ans, de);
    // ll t = mul(ans, my_inv(de));
    // printf("ans = %lld", t);
}
int main() {
    // Unsynchronized C++ streams with cin untied from cout for fast input.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // Single test case.
    for (int tests = 1; tests > 0; --tests)
        solve();
}
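// Submission metadata (pasted from the judge page):
//   Problem: Binary Stack Easy (noomaK)
//   Compiler: GNU G++17
//   Time: 25 ms, Memory: 2.1 MB
//   Verdict: Wrong Answer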