// Source Code
#include<bits/stdc++.h>
// Shorthand for the 64-bit signed type used throughout the file. A type alias
// (rather than a macro) also works in forms such as ll(x).
using ll = long long;

using namespace std;
// Modulus for bigMod/choose. choose inverts via down^(MOD-2), which requires
// MOD to be prime; 998244353 = 119 * 2^23 + 1 is prime. Never modified.
const ll MOD = 998244353;
// Returns x^y mod MOD using iterative binary exponentiation.
// x is first reduced into [0, MOD), so negative or large bases are handled and
// every intermediate product stays below MOD^2 (< 2^60, no overflow).
// y must be non-negative; y == 0 yields 1.
ll bigMod(ll x,ll y){
    ll base = ((x % MOD) + MOD) % MOD;
    ll result = 1;
    while (y > 0) {
        if (y & 1) result = result * base % MOD;
        base = base * base % MOD;
        y >>= 1;
    }
    return result;
}
ll fac[1000005];
// Returns the binomial coefficient C(x, y) modulo MOD.
//
// fac[i] must hold i! mod MOD. Nothing else in this file fills it, so the
// table fac[0..1000004] is built here on the first call. The denominator is
// inverted as down^(MOD-2) via bigMod, which relies on MOD being prime.
// Returns 1 when x == y or y == 0, and 0 when y < 0 or y > x.
// x must be smaller than the table size (1000005).
ll choose(ll x,ll y){
    static bool factorialsReady = false;
    if (!factorialsReady) {
        fac[0] = 1;
        for (ll i = 1; i < 1000005; i++) fac[i] = fac[i - 1] * i % MOD;
        factorialsReady = true;
    }
    if (x == y) return 1;
    if (y < 0 || y > x) return 0;  // y < 0 would otherwise index fac[] out of bounds
    if (y == 0) return 1;
    assert(x < 1000005 && "choose: x exceeds the factorial table");
    const ll down = fac[y] * fac[x - y] % MOD;
    return fac[x] * bigMod(down, MOD - 2) % MOD;
}

// Global inputs and work tables for main(). All live at namespace scope, so
// they are zero-initialized; each array has 1000005 slots.
int t;              // not read anywhere in this file
ll n;               // number of elements in a and c
ll k;               // not read anywhere in this file
ll m;               // number of elements in b
ll a[1000005];      // a[0..n)
ll b[1000005];      // b[0..m)
ll c[1000005];      // c[i] is the value paired with a[i]
ll dp[1000005];     // indexed by value (filled in main)
ll idx[1000005];    // idx[v]: position of v in b, or -1 (filled in main)
int pre[1000005];   // not read anywhere in this file
int ch[1000005];    // indexed by value; -1 when unset (filled in main)
int main()
{
    ios::sync_with_stdio(0);
    cin >> n >> m;
    ll sum = 0;
    for (int i=0;i<n;i++)cin >> a[i];
    for (int i=0;i<n;i++)cin >> c[i] , sum += c[i];
    for (int i=0;i<m;i++)cin >> b[i];
    for (int i=0;i<=1000000;i++)idx[i] = -1;
    for (int i=0;i<m;i++){
        idx[b[i]] = i;
    }
    memset(ch,-1,sizeof ch);
    ll add = 0;
    dp[0] = 0;
    for (int i=1;i<=1000000;i++)dp[i] = -1e17;
    for (int i=0;i<n;i++){
        if (idx[a[i]] == -1){
            continue;
        }
        int x = idx[a[i]];
        if (x == 0){
            // dp[a[i]] = max(dp[a[i]] , max(dp[a[i]] + c[i] , c[i]));
            if (c[i] > dp[a[i]]){
                ch[a[i]] = i;
                dp[a[i]] = max(dp[a[i]] , c[i]);
            }
            else {
                dp[a[i]] = dp[a[i]];
            }
        }
        else {
            int last = x - 1;
            int num = b[last];
            // dp[a[i]] = max(dp[a[i]] ,max(dp[a[i]] + c[i] , dp[num] + c[i]));
            if (dp[num] + c[i] > dp[a[i]]){
                ch[a[i]] = i;
                dp[a[i]] = max(dp[a[i]] , dp[num] + c[i]);
            }
            else {
                dp[a[i]] = dp[a[i]];
            }
        }
    }
    set<int> se;
    for (int i=0;i<m;i++){
        int x = ch[b[i]];
        se.insert(x);
    }
    for (int i=0;i<n;i++){
        if (se.find(i) == se.end() && c[i] > 0)add+=c[i];
    }
    // cout << dp[b[m-1]] << " " << add << endl;
    cout << min(0LL , sum - (add + dp[b[m-1]])) << endl;
    return 0;
}
/*
 * Judge-page residue from the original paste, kept for reference:
 *   Copy
 *   Always with Me, Always with You Khaled97ha
 *   GNU G++17
 *   287 ms
 *   48.9 MB
 *   Wrong Answer
 */