// Source Code
#include <bits/stdc++.h>
using namespace std;

// Shorthand type names used throughout the solution.
using ll = long long;               // at least 64-bit signed integer
using ii = std::pair<int, int>;

// Short accessors for pair members and an endless-loop helper: `for ever { ... }`.
#define F first
#define S second
#define ever (;;)

// Segment tree over positions 1..sz supporting point assignment and
// range-maximum queries, each in O(log sz).
// Node nd covers the interval [l, r]; its children are nd*2 (left half) and
// nd*2+1 (right half), and maxs[nd] is the maximum over its interval.
struct segtree
{
    // Identity element for max: the value reported for an empty range.
    static constexpr long long NEG_INF = -1000000000000000000LL;

    int sz = 0;

    std::vector<long long> maxs;

    // Builds a tree for positions 1..n with every position equal to `fill`
    // (default 0, as before). Uses assign rather than resize so that calling
    // init again drops values left over from a previous use, and keeps at
    // least 4 nodes so an empty tree can still be indexed safely.
    void init(int n, long long fill = 0)
    {
        sz = n;
        maxs.assign(static_cast<std::size_t>(std::max(n, 1)) * 4, fill);
    }

    // Sets position pos to val inside the subtree of node nd covering [l, r],
    // then refreshes the maxima on the way back up.
    // Precondition: l <= pos <= r (the public overload guarantees this).
    void update(int nd,int l,int r,int pos,long long val)
    {
        if( l == r )
        {
            maxs[nd] = val;
            return;
        }

        int mid = (l+r)/2;

        if( pos <= mid )
            update(nd<<1,l,mid,pos,val);
        else
            update(nd<<1|1,mid+1,r,pos,val);

        maxs[nd] = std::max( maxs[nd<<1] , maxs[nd<<1|1] );
    }

    // Sets position pos (1-based) to val. Positions outside 1..sz are ignored,
    // because descending with such a position would walk off the tree.
    void update(int pos,long long val)
    {
        if( pos < 1 || pos > sz )
            return;
        update(1,1,sz,pos,val);
    }

    // Maximum over [from, to] intersected with [l, r] within the subtree of
    // node nd, or NEG_INF when that intersection is empty.
    long long querymax(int nd,int l,int r,int from,int to)
    {
        if( from > r || l > to )
            return NEG_INF;
        if( from <= l && r <= to )
            return maxs[nd];

        int mid = (l+r)/2;

        return std::max( querymax(nd<<1,l,mid,from,to) , querymax(nd<<1|1,mid+1,r,from,to) );
    }

    // Maximum over positions l..r (1-based, inclusive). Returns NEG_INF for an
    // empty range (l > r) or an empty tree.
    long long querymax(int l,int r)
    {
        if( sz < 1 || l > r )
            return NEG_INF;
        return querymax(1,1,sz,l,r);
    }
};

// Capacity of every global array; also bounds the values used to index `v`.
constexpr int N = 1001000;

// Problem sizes and the 1-indexed input arrays filled by main.
int n, m;
int a[N], b[N], c[N];

// Per-position DP values computed in main.
ll dp[N];
// Large sentinel; `ans` starts at -inf and is maximised in main.
// (inf is declared before ans, so ans is initialised from it.)
ll inf = 1e18;
ll ans = -inf;
// Running total of the (clamped) costs read in main.
ll sum;

// v[val] lists the positions i with a[i] == val, in the order main reads them.
vector <int> v[N];
// Not read anywhere in this file.
set <int> s;
// Range-maximum tree over positions 1..n.
segtree st;

// Entry point. Reads n, m, the value array a, the cost array c (positive costs
// clamped to 0) and the pattern b, then runs a right-to-left DP over b:
//   dp[x] = c[x] + max{ dp[y] : y > x, y holds the next pattern element }
// where a[x] == b[i]. The last element b[m] starts each chain with dp[x] = c[x].
// The segment tree `st` holds the dp values of the positions matched to b[i+1]
// (every other slot is -inf), so the best continuation to the right of x is one
// range-maximum query.
// Prints sum - ans, where sum is the total of the clamped costs and ans is the
// best dp among positions holding b[1].
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        // NOTE(review): a[i] indexes v directly, so this assumes 0 <= a[i] < N.
        v[a[i]].push_back(i);   // positions are appended in increasing order
    }
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&c[i]);
        c[i] = min( c[i] , 0 );   // positive costs are clamped to 0
        sum += c[i];
    }
    for(int i=1;i<=m;i++)
        scanf("%d",&b[i]);

    st.init(n);

    // Start with "no chain" at every position.
    for(int i=1;i<=n;i++)
        st.update(i,-inf);

    // Last pattern element: every matching position is a chain of length one.
    for(int x : v[b[m]])
    {
        dp[x] = c[x];
        st.update(x,dp[x]);
    }

    for(int i=m-1;i>=1;i--)
    {
        // 1) Compute layer i while the tree still holds only layer i+1.
        //    The query is strictly to the right of x: when b[i] == b[i+1] both
        //    layers share positions, and x must not be chained to itself.
        for(int x : v[b[i]])
            dp[x] = ( x < n ? st.querymax(x+1,n) : -inf ) + c[x];

        // 2) Retire layer i+1 ...
        for(int x : v[b[i+1]])
            st.update(x,-inf);

        // 3) ... and publish layer i for the next round.
        for(int x : v[b[i]])
            st.update(x,dp[x]);
    }

    // Only positions holding b[1] can start a full chain.
    for(int x : v[b[1]])
        ans = max( ans , dp[x] );

    // NOTE(review): if no valid chain exists, ans stays near -1e18 and the
    // printed value is huge; presumably the input guarantees a chain exists.
    printf("%lld\n",sum-ans);
}
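/*
 * Submission details (as shown by the judge):
 *   Problem / user: Always with Me, Always with You Naseem17
 *   Compiler:       GNU G++17
 *   Time:           794 ms
 *   Memory:         94.9 MB
 *   Verdict:        Accepted
 */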