校内模拟赛题,倒在了正解前最后一步 qwq。
解题思路
首先,发现题目要求的东西很不好做。于是转化一下,考虑计算每条边对答案贡献了几次。
这样问题就转化成了求有多少个区间的点分布在一条边两端的两个子树中。
发现这个东西还是不好求。于是正难则反,考虑计算区间的点全部在同一个子树里的区间数量。
显然,这个数量可以通过 set 维护子树内外点的编号所形成的连续段实现。但实现起来细节较多,需要注意各种边界问题。
如果暴力地对每棵子树都把节点挨个 insert 一下去算这个东西,复杂度就是 (O(n^2 log n)) 的,会直接 T 飞。
我们使用树上启发式合并优化这个过程。时间复杂度 (O(n log^2 n))。
Code
#include <bits/stdc++.h> #define rep(i,a,b) for(int i(a);i<b;++i) #define rept(i,a,b) for(int i(a);i<=b;++i) #define fi first #define se second #define int long long using namespace std; constexpr int N=3e5+5,P=1e9+7; vector<pair<int,int>> g[N]; int n,tim; int ans; int l[N],r[N],siz[N],up[N],ch[N],rk[N]; signed maintenance_costs_sum(vector<signed> U,vector<signed> V,vector<signed> W); inline int calc(int x){ return x*(x+1)/2%P; } struct Set{ set<pair<int,int>> s; int tot; void clear(){ s.clear(); s.emplace(1,n); tot=calc(n); } void insert(int x){ auto it=prev(s.lower_bound({x,n+1})); int l=it->fi,r=it->se; if(l<x&&x<r){ (tot-=calc(r-l+1))%=P; (tot+=calc(1))%=P; (tot+=calc(x-l))%=P; (tot+=calc(r-x))%=P; s.erase(it); s.emplace(l,x-1); s.emplace(x+1,r); return; } int lm=it==s.begin()?0:prev(it)->se; int rm=next(it)==s.end()?n+1:next(it)->fi; if(l==r){ (tot-=calc(l-lm-1))%=P; (tot-=calc(rm-r-1))%=P; (tot-=calc(1))%=P; (tot+=calc(rm-lm-1))%=P; s.erase(it); return; } if(l==x){ (tot-=calc(r-l+1))%=P; (tot-=calc(l-lm-1))%=P; (tot+=calc(r-l))%=P; (tot+=calc(l-lm))%=P; s.erase(it); s.emplace(l+1,r); return; } if(r==x){ (tot-=calc(r-l+1))%=P; (tot-=calc(rm-r-1))%=P; (tot+=calc(r-l))%=P; (tot+=calc(rm-r))%=P; s.erase(it); s.emplace(l,r-1); } } }st; void add(int l,int r){ rept(i,l,r){ st.insert(rk[i]); } } void dfs(int u,int pre){ l[u]=++tim,rk[l[u]]=u,siz[u]=1; for(auto [v,w]:g[u]){ if(v==pre) continue; up[v]=w; dfs(v,u); siz[u]+=siz[v]; if(siz[ch[u]]<siz[v]) ch[u]=v; } r[u]=tim; } void dsu(int u,int pre){ st.clear(); for(auto [v,w]:g[u]){ if(v==pre||v==ch[u]) continue; dsu(v,u); st.clear(); } if(ch[u]) dsu(ch[u],u); for(auto [v,w]:g[u]){ if(v==pre||v==ch[u]) continue; add(l[v],r[v]); } add(l[u],l[u]); (ans+=up[u]*(calc(n)-st.tot))%=P; } signed maintenance_costs_sum(vector<signed> U,vector<signed> V,vector<signed> W){ n=U.size()+1; rep(i,0,n-1){ int u=U[i]+1,v=V[i]+1,w=W[i]; g[u].emplace_back(v,w); g[v].emplace_back(u,w); } dfs(1,0); dsu(1,0); return (ans+P)%P; }