随机爬树题解
(n^2) 暴力:
思路:
- 求期望,即求所有点的权值乘上概率后的和,即:
[ans=sum_{u in V}{P_u a_u} ]
-
求每个点的概率 (P_u) :
- 由题,令走到父亲的概率为 (P_f),走到儿子 (s) 的概率则为 (P_f times frac{w_s}{sum_f})(其中 (sum_f) 为 (f) 所有儿子的 (w) 之和)。
-
统计答案:
- 记 (ans_u) 表示 (u) 子树(不含 (u) 本身)的答案之和,最终答案为 (ans_1+a_1)。
- 暴力修改,跑 DFS 暴力求和即可。
代码:
//n^2暴力 60pts #include<iostream> #include<cstdio> #include<cstring> #include<vector> using namespace std; typedef long long ll; const int N=1e5+5,Mod=998244353; int n,q,fa[N]; ll sum[N],inv[N],w[N],a[N]; ll p[N],ans[N]; vector <int> e[N]; ll qpow(ll a,ll b){ ll res=1; while(b){ if(b&1) (res*=a)%=Mod; (a*=a)%=Mod; b>>=1; } return res%Mod; } void dfs(int u){ ans[u]=0; inv[u]=qpow(sum[u],Mod-2); for(int i=0;i<e[u].size();i++){ int v=e[u][i]; p[v]=p[u]*w[v]%Mod*inv[u]%Mod; dfs(v); (ans[u]+=(ans[v]+p[v]*a[v]%Mod)%Mod)%=Mod; } } int main(){ scanf("%d",&n); for(int i=2;i<=n;i++){ scanf("%d",&fa[i]); e[fa[i]].push_back(i); } for(int i=1;i<=n;i++){ scanf("%lld",&w[i]); (sum[fa[i]]+=w[i])%=Mod; } for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); } p[1]=1; dfs(1); printf("%lldn",(ans[1]+a[1])%Mod); scanf("%d",&q); int u; ll ww,aa; for(int i=1;i<=q;i++){ scanf("%d",&u); sum[fa[u]]=(sum[fa[u]]-w[u]+Mod)%Mod; scanf("%lld%lld",&w[u],&a[u]); sum[fa[u]]=(sum[fa[u]]+w[u])%Mod; dfs(1); printf("%lldn",(ans[1]+a[1])%Mod); } return 0; }
优化后正解:
思路:
-
每次修改 (u) 只对 (f) 的整棵子树产生影响,故用线段树维护子树和。
-
考虑有哪些影响:
-
(w_u) 修改为 (ww),使得 (sum_f) 发生改变,故整棵子树的概率都会变化:
-
对于 (f) 子树中每个节点 (t)(不含 (f)),原概率为 (P_t=P_{fa_t}times frac{w_t}{sum_{fa_t}}),修改后变为 (P_t'=P_{fa_t}times frac{w_t}{sum_{fa_t}-w_u+ww})。
[P_t'=P_ttimesfrac{sum_{fa_t}}{sum_{fa_t}-w_u+ww} ]则有 (Delta P=frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww})。
-
对于点 (u),原概率为 (P_u=P_ftimes frac{w_u}{sum_f}),修改后变为 (P_u'=P_f times frac{ww}{sum_f-w_u+ww})。
[P_u'=P_u times frac{ww}{wu} times Delta P ]则有 (Delta w=frac{ww}{wu})。
-
对于 (u) 子树中的每个点 (t)(不含 (u)),原概率为 (P_t=P_{fa_t}times frac{w_t}{sum_{fa_t}}),修改后变为 (P_t'=P_{fa_t}times Delta w times frac{w_t}{sum_{fa_t}})。
-
-
(a_u) 修改为 (aa),只对 (u) 的贡献产生影响,原贡献为 (ans=P_u times a_u),修改后变为 (ans'=P_u times aa),即:
[ans'=ans times frac{aa}{au} ]
-
-
综上,变化有:
- (P_t'=P_ttimesfrac{sum_{fa_t}}{sum_{fa_t}-w_u+ww});
- (P_u'=P_u times frac{ww}{wu} times frac{sum_{fa_t}}{sum_{fa_t}-w_u+ww});
- (P_t'=P_{fa_t}times frac{ww}{wu} times frac{w_t}{sum_{fa_t}});
- (ans'=ans times frac{aa}{au})。
其中操作 (1)、(2) 与 (2)、(3) 可以合并。
代码:
#include<iostream> #include<cstdio> #include<vector> using namespace std; typedef long long ll; const int N=1e5+5,Mod=998244353; int n,q,fa[N]; ll sum[N],w[N],a[N],p[N],val[N]; int tid,dfn[N],siz[N]; vector <int> e[N]; struct Tree{ ll sum,tag; }tr[N<<2]; ll qpow(ll a,ll b){ ll res=1; while(b){ if(b&1) (res*=a)%=Mod; (a*=a)%=Mod; b>>=1; } return res%Mod; } ll inv(ll x){ return qpow(x,Mod-2); } void dfs(int u){ dfn[u]=++tid; val[tid]=p[u]*a[u]%Mod; siz[u]=1; for(int i=0;i<e[u].size();i++){ int v=e[u][i]; p[v]=p[u]*w[v]%Mod*inv(sum[u])%Mod; dfs(v); siz[u]+=siz[v]; } } void update(int k){ tr[k].sum=(tr[k<<1].sum+tr[k<<1|1].sum)%Mod; } void pushdown(int k,int l,int r){ tr[k<<1].sum=tr[k<<1].sum*tr[k].tag%Mod; tr[k<<1].tag=tr[k<<1].tag*tr[k].tag%Mod; tr[k<<1|1].sum=tr[k<<1|1].sum*tr[k].tag%Mod; tr[k<<1|1].tag=tr[k<<1|1].tag*tr[k].tag%Mod; tr[k].tag=1; } void build(int k,int l,int r){ tr[k].sum=0; tr[k].tag=1; if(l==r){ tr[k].sum=val[l]%Mod; return ; } int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); update(k); } void modify(int k,int l,int r,int x,int y,int v){ if(x<=l&&r<=y){ tr[k].sum=tr[k].sum*v%Mod; tr[k].tag=tr[k].tag*v%Mod; return ; } pushdown(k,l,r); int mid=(l+r)>>1; if(x<=mid) modify(k<<1,l,mid,x,y,v); if(mid<y) modify(k<<1|1,mid+1,r,x,y,v); update(k); } int main(){ //freopen("climb2.in","r",stdin); //freopen("climb.out","w",stdout); scanf("%d",&n); for(int i=2;i<=n;i++){ scanf("%d",&fa[i]); e[fa[i]].push_back(i); } for(int i=1;i<=n;i++){ scanf("%lld",&w[i]); sum[fa[i]]=(sum[fa[i]]+w[i])%Mod; } for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); } p[1]=1; dfs(1); build(1,1,n); printf("%lldn",tr[1].sum%Mod); scanf("%d",&q); int u; ll ww,aa; for(int i=1;i<=q;i++){ scanf("%d%lld%lld",&u,&ww,&aa); if(fa[u]){ /* 1.f子树(除f本身):Pt=Pf*(wt/sumf)* (1/(sumf-wu+ww))*sumf 2.u:pu=pf*(wu/sumf)*(ww/wu)* (1/(sumf-wu+ww))*sumf 3.u子树(除u本身):Pt=Pft* (ww/wu) *(wt/sumft) */ //1. 2. 修改f子树(除f本身) modify(1,1,n,dfn[fa[u]]+1,dfn[fa[u]]+siz[fa[u]]-1,inv(((sum[fa[u]]-w[u]+ww)%Mod+Mod)%Mod)%Mod*sum[fa[u]]%Mod); //2. 3. modify(1,1,n,dfn[u],dfn[u]+siz[u]-1,ww*inv(w[u])%Mod); sum[fa[u]]=((sum[fa[u]]-w[u]+ww)%Mod+Mod)%Mod; } w[u]=ww; //修改au:Pu*au* (aa/au) modify(1,1,n,dfn[u],dfn[u],aa*inv(a[u])%Mod); a[u]=aa; printf("%lldn",tr[1].sum%Mod); } return 0; }