// Source Code
#include <bits/stdc++.h>
#define int long long  
#define all(a) a.begin(),a.end() 
#define pii pair<int,int> 
#define F first 
#define S second 
#define mp make_pair 
#define mod 998244353
using namespace std; 
namespace combinatorics{
	// Modular-arithmetic and counting helpers.
	// Relies on two file-level macros defined above this namespace:
	//   `mod` - the modulus 998244353 (a prime), and
	//   `int` - #defined to long long, so the product of two residues
	//           (< mod^2, about 1e18) fits without overflow.

	// Reduces any integer, including negative ones, into [0, mod).
	int residue(int x){
		x %= mod;
		return x < 0 ? x + mod : x;
	}
	// x = (x + v) mod `mod`; the result is always in [0, mod).
	void add(int& x,int v){
		x = residue(x + v);
	}
	// x = (x - v) mod `mod`; the result is always in [0, mod), even when v > x + mod.
	void sub(int& x,int v){
		x = residue(x - v);
	}
	// x = (x * v) mod `mod`; both factors are reduced first so the product cannot overflow.
	void mul(int& x,int v){
		x = residue(residue(x) * residue(v));
	}
	// Binary exponentiation: returns b^e mod `mod` (1 when e <= 0).
	int be(int b,int e){
		b = residue(b); // reduce up front so b*b and res*b stay below mod^2
		int res = 1;
		for(; e > 0; e >>= 1){
			if(e & 1) res = res * b % mod;
			b = b * b % mod;
		}
		return res;
	}
	// Modular inverse via Fermat's little theorem (valid because `mod` is prime).
	// Returns 0 when n is a multiple of `mod`, which has no inverse.
	int inv(int n){
		return be(n, mod - 2);
	}

	// fac[i] = i! mod `mod` for 0 <= i <= n, after factorial(n) has been called.
	vector<int> fac;
	// Fills fac[0..n]; must be called (with n >= 0) before C or P.
	void factorial(int n){
		fac.resize(n+1);
		fac[0]=1;
		for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
	}
	// Prints an error and terminates unless n! is available in `fac`.
	// A negative n converts to a huge unsigned value, so it is rejected too.
	void require_factorial(int n){
		if(static_cast<size_t>(n) >= fac.size()){
			cout << "Brick! you didnt calculate n!";
			exit(1);
		}
	}
	// Binomial coefficient C(n, r) mod `mod`; 0 when r < 0 or r > n.
	int C(int n , int r){
		require_factorial(n);
		if(r < 0 || r > n) return 0;
		// One modular inverse for the combined denominator instead of two.
		return fac[n] * inv(fac[r] * fac[n - r] % mod) % mod;
	}
	// Ordered selections P(n, r) = n! / (n - r)! mod `mod`; 0 when r < 0 or r > n.
	int P(int n,int r){
		require_factorial(n);
		if(r < 0 || r > n) return 0;
		return fac[n] * inv(fac[n - r]) % mod;
	}
}
using namespace combinatorics ; 
// Size bound shared by the factorial table and the dp memo table.
const int mxn = 5001;
// Memo table for dp(i, j); -1 marks "not yet computed". Every stored value is
// either -1 or a residue below 998244353 (< 2^31), so 32-bit cells suffice.
// This halves the footprint of the 5001 x 5001 table compared with 64-bit cells.
std::vector<std::vector<int32_t>> mem(mxn, std::vector<int32_t>(mxn, -1));
// Number of distinct sequences made of x zeros and y ones:
// (x + y)! / (x! * y!) mod `mod`. Requires x + y < fac.size().
int calc(int x , int y){
	if (x == 0 && y == 0) {
		return 1;
	}
	const int denominator = fac[x] * fac[y]; // < mod^2, fits in long long
	const int numerator = fac[x + y];
	return numerator * inv(denominator) % mod;
}
// Memoized recurrence (results cached in `mem`):
//   dp(0, j) = 1,  dp(i, 0) = 0 for i > 0,
//   dp(i, j) = dp(i - 1, j) + dp(i, j - 1)  (mod `mod`).
// For j >= 1 this equals C(i + j - 1, i): the number of ways to distribute
// i identical items among j ordered slots. Requires i, j < mxn.
int dp(int i, int j ){
	if (i == 0) return 1;
	if (j == 0) return 0;
	// `mem` is never resized, so this reference stays valid across the recursion.
	auto& cell = mem[i][j];
	if (cell == -1) {
		const int fewerItems = dp(i - 1, j);
		const int fewerSlots = dp(i, j - 1);
		cell = (fewerItems + fewerSlots) % mod;
	}
	return cell;
}
signed main(){
	factorial(mxn) ; 
	int shit , shitt ; 
	cin>>shit>>shitt ; 
	string a , b ; 
	cin>>a>>b ; 
	int z = 0 , o = 0 ; 
	for(char x :a) if(x=='0') ++z; 
	o=a.size()-z; 
	int ospots= 0 , zspots=0 ; 
	for(char x :b){
		if(x=='0') --z,ospots++ ;  
		else --o,zspots++ ; 
	} 
	if(o<0||z<0){
		cout << 0  ; 
		return 0 ;
	}
	int p =0 ; 
	int ans = 0 ; 
//	cout << zspots << " " << ospots << endl ; 
	for(int zero=0;zero<=z;zero++){
		for(int one=0;one<=o;++one){
			int x = z-zero , y = o-one ; 
			p=calc(x,y) ;
		//	cout << zero << "  "<< one  << "  " << dp(one,ospots)  << " " ;  
			int pp = dp(zero,zspots)*dp(one,ospots) ;
		//	cout << pp << "  "; 
			pp%=mod ;
			pp*=p ; 
			pp%=mod ;
		//	cout << endl ; 
			ans+=pp ; 
			ans%=mod ;  
		}
	}	
	z=0,o=0 ; 
	for(int x :a){
		if(x=='0') ++z ; 
		else o++ ; 
	}
	int totalshuffles = calc(z,o) ; 
	ans*=inv(totalshuffles) ; 
	ans%=mod; 
	cout << ans ; 
	return 0 ; 
}
// Submission details:
//   Binary Stack Easy Boredom
//   Compiler: GNU G++17
//   Time:     837 ms
//   Memory:   196.9 MB
//   Verdict:  Accepted