// Source Code
#include <bits/stdc++.h>
#define int long long
#define all(a) a.begin(),a.end()
#define pii pair<int,int>
#define F first
#define S second
#define mp make_pair
#define mod 998244353
using namespace std;
namespace combinatorics{
	// Modular-arithmetic and counting helpers over the modulus `mod`.
	// Don't forget to define mod and to define int as long long:
	// every product below relies on 64-bit intermediates.

	// Maps any value onto its representative in [0, mod).
	int normalized(int v){
		v%=mod;
		if(v<0) v+=mod;
		return v;
	}
	// x = (x + v) mod `mod`; operands are reduced first so that negative or
	// unreduced inputs can neither overflow nor leave a negative result.
	void add(int& x,int v){
		x=normalized(normalized(x)+normalized(v));
	}
	// x = (x - v) mod `mod`, always left in [0, mod).
	void sub(int&x ,int v){
		x=normalized(normalized(x)-normalized(v));
	}
	// x = (x * v) mod `mod`; both factors are reduced before multiplying so
	// the product of two residues fits in 64 bits.
	void mul(int& x,int v){
		x=normalized(x)*normalized(v)%mod;
	}
	// Binary exponentiation: returns b^e mod `mod`.
	// Precondition: e >= 0. A negative exponent makes the loop body never run,
	// so 1 is returned (the previous version never terminated in that case).
	int be(int b,int e){
		b=normalized(b); // reduce once up front; the first multiply used the raw base before
		int res=1;
		while(e>0){
			if(e&1) res=res*b%mod;
			b=b*b%mod;
			e>>=1;
		}
		return res;
	}
	// Returns n^(mod-2) mod `mod`, which is the multiplicative inverse of n
	// when mod is prime and n is not a multiple of it. For n == 0 (mod `mod`)
	// the result is 0.
	int inv(int n){
		return be(n,mod-2);
	}
	// fac[i] holds i! mod `mod` for every i filled in by factorial().
	std::vector<int> fac;
	// Fills fac[0..n]. A negative n is treated as 0 so that fac[0] always exists.
	void factorial(int n){
		if(n<0) n=0;
		fac.resize(n+1);
		fac[0]=1;
		for(int i=1;i<=n;i++) fac[i]=fac[i-1]*(i%mod)%mod;
		return ;
	}
	// Binomial coefficient "n choose r" mod `mod`.
	// Prints a diagnostic and exits with status 1 when n lies outside the
	// precomputed table (negative n included, as before, where the
	// signed/unsigned comparison made it look huge).
	// Returns 0 when r > n or r < 0.
	int C(int n , int r){
		if(n<0 || n>=static_cast<int>(fac.size())){
			std::cout << "Brick! you didnt calculate n!" ;
			std::exit(1);
		}
		if(r<0 || r>n) return 0;
		// One exponentiation instead of two: invert the product of both factorials.
		return fac[n]*inv(fac[r]*fac[n-r]%mod)%mod;
	}
	// Number of ordered selections n!/(n-r)! mod `mod`.
	// Same table check as C(); returns 0 when r > n or r < 0.
	int P(int n,int r){
		if(n<0 || n>=static_cast<int>(fac.size())){
			std::cout << "Brick! you didnt calculate n!";
			std::exit(1);
		}
		if(r<0 || r>n) return 0;
		return fac[n]*inv(fac[n-r])%mod;
	}
}
using namespace combinatorics ;
// Side length of the dp memo table; also passed to factorial() by main.
const int mxn = 5001;
// Memo table for dp(i, j); -1 marks a cell that has not been computed yet.
// dp only stores values already reduced modulo 998244353 (< 2^30), so 32-bit
// cells are sufficient and cut the table from roughly 200 MB to roughly 100 MB.
std::vector<std::vector<std::int32_t>> mem(mxn, std::vector<std::int32_t>(mxn, -1));
// Returns (x+y)! / (x! * y!) modulo `mod`, using the precomputed fac table.
// The (0, 0) case is answered directly without touching the table.
int calc(int x , int y){
	const bool bothZero = (x==0) && (y==0);
	if(bothZero){
		return 1;
	}
	int numerator = fac[x+y];
	int denominator = fac[x]*fac[y];
	int quotient = numerator*inv(denominator);
	return quotient%mod;
}
// Memoised recurrence:
//   dp(0, j) = 1
//   dp(i, 0) = 0            for i != 0
//   dp(i, j) = dp(i-1, j) + dp(i, j-1)   (mod `mod`)
// Results are cached in mem[i][j]; -1 means "not computed yet".
int dp(int i, int j ){
	if(i==0){
		return 1 ;
	}
	if(j==0){
		return 0 ;
	}
	// mem is never resized, so this reference stays valid across the recursive calls.
	auto& cell = mem[i][j] ;
	if(cell==-1){
		cell = (dp(i-1,j)+dp(i,j-1))%mod ;
	}
	return cell ;
}
signed main(){
	factorial(mxn) ;
	int shit , shitt ;
	cin>>shit>>shitt ;
	string a , b ;
	cin>>a>>b ;
	int z = 0 , o = 0 ;
	for(char x :a) if(x=='0') ++z;
	o=a.size()-z;
	int ospots= 0 , zspots=0 ;
	for(char x :b){
		if(x=='0') --z,ospots++ ;
		else --o,zspots++ ;
	}
	if(o<0||z<0){
		cout << 0  ;
		return 0 ;
	}
	int p =0 ;
	int ans = 0 ;
//	cout << zspots << " " << ospots << endl ;
	for(int zero=0;zero<=z;zero++){
		for(int one=0;one<=o;++one){
			int x = z-zero , y = o-one ;
			p=calc(x,y) ;
		//	cout << zero << "  "<< one  << "  " << dp(one,ospots)  << " " ;
			int pp = dp(zero,zspots)*dp(one,ospots) ;
		//	cout << pp << "  ";
			pp%=mod ;
			pp*=p ;
			pp%=mod ;
		//	cout << endl ;
			ans+=pp ;
			ans%=mod ;
		}
	}
	z=0,o=0 ;
	for(int x :a){
		if(x=='0') ++z ;
		else o++ ;
	}
	int totalshuffles = calc(z,o) ;
	ans*=inv(totalshuffles) ;
	ans%=mod;
	cout << ans ;
	return 0 ;
}
// Copy
// Binary Stack Hard tamahom
// GNU G++17
// 823 ms
// 197.1 MB
// Runtime Error