随机爬树题解

随机爬树题解

题目传送门

更好的阅读体验

(n^2) 暴力:

思路:

  • 求期望,即求所有点的权值乘上概率后的和,即:
[ans=sum_{u in V}{P_u a_u} ]

  • 求每个点的概率 (P_u)

    • 由题,令走到父亲的概率为 (P_f),走到儿子 (s) 的概率则为 (P_f times frac{w_s}{sum_f})(其中 (sum_f)(f) 所有儿子的 (w) 之和)。
  • 统计答案:

    • (ans_u) 表示 (u) 子树(不含 (u) 本身)的答案之和,最终答案为 (ans_1+a_1)
    • 暴力修改,跑 DFS 暴力求和即可。

代码:

//n^2暴力 60pts #include<iostream> #include<cstdio> #include<cstring> #include<vector> using namespace std; typedef long long ll; const int N=1e5+5,Mod=998244353; int n,q,fa[N]; ll sum[N],inv[N],w[N],a[N]; ll p[N],ans[N]; vector <int> e[N]; ll qpow(ll a,ll b){     ll res=1;     while(b){         if(b&1) (res*=a)%=Mod;         (a*=a)%=Mod;         b>>=1;     }     return res%Mod; } void dfs(int u){     ans[u]=0;     inv[u]=qpow(sum[u],Mod-2);     for(int i=0;i<e[u].size();i++){         int v=e[u][i];         p[v]=p[u]*w[v]%Mod*inv[u]%Mod;         dfs(v);         (ans[u]+=(ans[v]+p[v]*a[v]%Mod)%Mod)%=Mod;     } } int main(){     scanf("%d",&n);     for(int i=2;i<=n;i++){         scanf("%d",&fa[i]);         e[fa[i]].push_back(i);     }     for(int i=1;i<=n;i++){         scanf("%lld",&w[i]);         (sum[fa[i]]+=w[i])%=Mod;     }     for(int i=1;i<=n;i++){         scanf("%lld",&a[i]);     }     p[1]=1;     dfs(1);     printf("%lldn",(ans[1]+a[1])%Mod);     scanf("%d",&q);     int u;     ll ww,aa;     for(int i=1;i<=q;i++){         scanf("%d",&u);         sum[fa[u]]=(sum[fa[u]]-w[u]+Mod)%Mod;         scanf("%lld%lld",&w[u],&a[u]);         sum[fa[u]]=(sum[fa[u]]+w[u])%Mod;         dfs(1);         printf("%lldn",(ans[1]+a[1])%Mod);     }      return 0; } 

优化后正解:

思路:

  • 每次修改 (u) 只对 (f) 的整棵子树产生影响,故用线段树维护子树和。

  • 考虑有哪些影响:

    • (w_u) 修改为 (ww),使得 (sum_f) 发生改变,故整棵子树的概率都会变化:

      1. 对于 (f) 子树中每个节点 (t)(不含 (f)),原概率为 (P_t=P_{fa_t}times frac{w_t}{sum_{fa_t}}),修改后变为 (P_t'=P_{fa_t}times frac{w_t}{sum_{fa_t}-w_u+ww})

        [P_t'=P_ttimesfrac{sum_{fa_t}}{sum_{fa_t}-w_u+ww} ]

        则有 (Delta P=frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww})

      2. 对于点 (u),原概率为 (P_u=P_ftimes frac{w_u}{sum_f}),修改后变为 (P_u'=P_f times frac{ww}{sum_f-w_u+ww})

        [P_u'=P_u times frac{ww}{wu} times Delta P ]

        则有 (Delta w=frac{ww}{wu})

      3. 对于 (u) 子树中的每个点 (t)(不含 (u)),原概率为 (P_t=P_{fa_t}times frac{w_t}{sum_{fa_t}}),修改后变为 (P_t'=P_{fa_t}times Delta w times frac{w_t}{sum_{fa_t}})

    • (a_u) 修改为 (aa),只对 (u) 的贡献产生影响,原贡献为 (ans=P_u times a_u),修改后变为 (ans'=P_u times aa),即:

      [ans'=ans times frac{aa}{au} ]

  • 综上,变化有:

    • (P_t'=P_ttimesfrac{sum_{fa_t}}{sum_{fa_t}-w_u+ww})
    • (P_u'=P_u times frac{ww}{wu} times frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww})
    • (P_t'=P_{fa_t}times frac{ww}{wu} times frac{w_t}{sum_{fa_t}})
    • (ans'=ans times frac{aa}{au})

    其中操作 (1)(2)(2)(3) 可以合并。

代码:

#include<iostream> #include<cstdio> #include<vector> using namespace std; typedef long long ll; const int N=1e5+5,Mod=998244353; int n,q,fa[N]; ll sum[N],w[N],a[N],p[N],val[N]; int tid,dfn[N],siz[N]; vector <int> e[N]; struct Tree{     ll sum,tag; }tr[N<<2]; ll qpow(ll a,ll b){     ll res=1;     while(b){         if(b&1) (res*=a)%=Mod;         (a*=a)%=Mod;         b>>=1;     }     return res%Mod; } ll inv(ll x){     return qpow(x,Mod-2); } void dfs(int u){     dfn[u]=++tid;     val[tid]=p[u]*a[u]%Mod;     siz[u]=1;     for(int i=0;i<e[u].size();i++){         int v=e[u][i];         p[v]=p[u]*w[v]%Mod*inv(sum[u])%Mod;         dfs(v);         siz[u]+=siz[v];     } } void update(int k){     tr[k].sum=(tr[k<<1].sum+tr[k<<1|1].sum)%Mod; } void pushdown(int k,int l,int r){     tr[k<<1].sum=tr[k<<1].sum*tr[k].tag%Mod;     tr[k<<1].tag=tr[k<<1].tag*tr[k].tag%Mod;     tr[k<<1|1].sum=tr[k<<1|1].sum*tr[k].tag%Mod;     tr[k<<1|1].tag=tr[k<<1|1].tag*tr[k].tag%Mod;     tr[k].tag=1; } void build(int k,int l,int r){     tr[k].sum=0;     tr[k].tag=1;     if(l==r){         tr[k].sum=val[l]%Mod;         return ;     }     int mid=(l+r)>>1;     build(k<<1,l,mid);     build(k<<1|1,mid+1,r);     update(k); } void modify(int k,int l,int r,int x,int y,int v){     if(x<=l&&r<=y){         tr[k].sum=tr[k].sum*v%Mod;         tr[k].tag=tr[k].tag*v%Mod;         return ;     }     pushdown(k,l,r);     int mid=(l+r)>>1;     if(x<=mid) modify(k<<1,l,mid,x,y,v);     if(mid<y) modify(k<<1|1,mid+1,r,x,y,v);     update(k); } int main(){     //freopen("climb2.in","r",stdin);     //freopen("climb.out","w",stdout);     scanf("%d",&n);     for(int i=2;i<=n;i++){         scanf("%d",&fa[i]);         e[fa[i]].push_back(i);     }     for(int i=1;i<=n;i++){         scanf("%lld",&w[i]);         sum[fa[i]]=(sum[fa[i]]+w[i])%Mod;     }     for(int i=1;i<=n;i++){         scanf("%lld",&a[i]);     }     p[1]=1;     dfs(1);     build(1,1,n);     printf("%lldn",tr[1].sum%Mod);     scanf("%d",&q);     int u;     ll ww,aa;     for(int i=1;i<=q;i++){         scanf("%d%lld%lld",&u,&ww,&aa);         if(fa[u]){             /*             1.f子树(除f本身):Pt=Pf*(wt/sumf)* (1/(sumf-wu+ww))*sumf             2.u:pu=pf*(wu/sumf)*(ww/wu)* (1/(sumf-wu+ww))*sumf             3.u子树(除u本身):Pt=Pft* (ww/wu) *(wt/sumft)             */             //1. 2. 修改f子树(除f本身)             modify(1,1,n,dfn[fa[u]]+1,dfn[fa[u]]+siz[fa[u]]-1,inv(((sum[fa[u]]-w[u]+ww)%Mod+Mod)%Mod)%Mod*sum[fa[u]]%Mod);             //2. 3.             modify(1,1,n,dfn[u],dfn[u]+siz[u]-1,ww*inv(w[u])%Mod);             sum[fa[u]]=((sum[fa[u]]-w[u]+ww)%Mod+Mod)%Mod;         }         w[u]=ww;         //修改au:Pu*au* (aa/au)         modify(1,1,n,dfn[u],dfn[u],aa*inv(a[u])%Mod);         a[u]=aa;         printf("%lldn",tr[1].sum%Mod);     }      return 0; } 

发表评论

评论已关闭。

相关文章