整体二分学习笔记

Upd:增加折叠代码框。

整体二分学习笔记

整体二分,就是对所有的操作进行一个整体的二分答案,需要数据结构题满足以下性质:

  • 询问的答案具有可二分性。
  • 修改对判定答案的贡献相对独立,修改之间互不影响效果。
  • 修改如果对判定答案有贡献,则贡献为一确定的与判定标准无关的值。
  • 贡献满足交换律、结合律,具有可加性。
  • 题目允许离线。

例题引入:P3332 [ZJOI2013]K大数查询

Solution1:忘了是啥线段树 套 忘了是啥线段树。

太麻烦了。。。。。。。

点击查看代码
#include<bits/stdc++.h> #include<bits/extc++.h> using namespace __gnu_pbds; #define int long long  // 权值线段树套区间和线段树  using namespace std;  const int Size=(1<<20)+1; char buf[Size],*p1=buf,*p2=buf; char buffer[Size]; int op1=-1; const int op2=Size-1; #define getchar()                                                               (tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)      	? EOF                                                                  	: *ss++) char In[1<<20],*ss=In,*tt=In; inline int read() { 	int x=0,c=getchar(),f=0; 	for(;c>'9'||c<'0';f=c=='-',c=getchar()); 	for(;c>='0'&&c<='9';c=getchar()) 		x=(x<<1)+(x<<3)+(c^48); 	return f?-x:x; } inline void write(int x) { 	if(x<0) x=-x,putchar('-'); 	if(x>9)  write(x/10); 	putchar(x%10+'0'); }  int n,m; int root[1<<20],tot; struct Tree{// 	int lazy,cnt,lp,rp; }tree2[(int)2e7];  const int MINN=0,MAXN=50001;  void pushdown(int l,int mid,int r,Tree &p) { 	if(!p.lazy) return; 	if(l==r) return; 	if(!p.lp) p.lp=++tot; 	if(!p.rp) p.rp=++tot; 	tree2[p.lp].cnt+=(mid-l+1)*p.lazy; 	tree2[p.rp].cnt+=(r-mid)*p.lazy; 	tree2[p.lp].lazy+=p.lazy; 	tree2[p.rp].lazy+=p.lazy; 	p.lazy=0; }  void add2(int l,int r,int sl,int sr,int &p) { 	if(!p) p=++tot; 	if(sl<=l&&r<=sr) 	// if(l==r) 	{ 		tree2[p].cnt+=r-l+1; 		tree2[p].lazy++; 		return; 	} 	int mid=(l+r)>>1; 	pushdown(l,mid,r,tree2[p]); 	if(sl<=mid) add2(l,mid,sl,sr,tree2[p].lp); 	if(sr>mid) add2(mid+1,r,sl,sr,tree2[p].rp); 	tree2[p].cnt=tree2[tree2[p].lp].cnt+tree2[tree2[p].rp].cnt; }  void add(int l,int r,int sl,int sr,int x,int p) { 	add2(1,n,sl,sr,root[p]); 	if(l==r) return; 	int mid=(l+r)>>1,lp=(p<<1),rp=(p<<1)|1; 	if(x<=mid) add(l,mid,sl,sr,x,lp); 	else add(mid+1,r,sl,sr,x,rp); }  int query_sum(int l,int r,int sl,int sr,int &p) { 	if(!p) return 0; 	if(sl<=l&&r<=sr) return tree2[p].cnt; 	int mid=(l+r)>>1,ans=0; 	pushdown(l,mid,r,tree2[p]); 	if(sl<=mid) ans=query_sum(l,mid,sl,sr,tree2[p].lp); 	if(sr>mid) ans+=query_sum(mid+1,r,sl,sr,tree2[p].rp); 	return ans; }  int maxkth(int l,int r,int sl,int sr,long long k,int p) { 	if(l==r) return l; 	int mid=(l+r)>>1; 	int lp=(p<<1),rp=(p<<1)|1; 	long long nw=query_sum(1,n,sl,sr,root[rp]); 	if(nw>=k) return maxkth(mid+1,r,sl,sr,k,rp); 	else return maxkth(l,mid,sl,sr,k-nw,lp); }  struct Ask{ 	int op,l,r; 	long long c; }ask[50005]; gp_hash_table<int,int>mp; int a[50005],cnta; int id[50005]; signed main() { 	//mt19937_64 myrand(time(0)); 	n=read(); 	m=read(); 	for(int i=1;i<=m;i++) 	{ 		int op=read(),l=read(),r=read(),c=read(); 		ask[i]={op,l,r,c}; 		if(op==1) a[++cnta]=c; 	} 	sort(a+1,a+1+cnta); 	a[0]=-50001; 	int cntb=0; 	for(int i=1;i<=cnta;i++) 		if(a[i]!=a[i-1]) 		{ 			++cntb; 			mp[a[i]]=cntb; 			id[cntb]=a[i]; 		}   	for(int i=1;i<=m;i++) 	{ 		int op=ask[i].op,l=ask[i].l,r=ask[i].r; 		if(op==1) add(1,cnta,l,r,mp[ask[i].c],1); 		if(op==2) cout<<id[maxkth(1,cnta,ask[i].l,ask[i].r,ask[i].c,1)]<<"n"; 	} 	return 0; } 

Solution2:整体二分。

我们考虑对于一个询问怎么朴素二分求解。

假设询问区间 ([l,r]) 中的第 (k) 大。

有一个想法是,我们二分可能的答案 (mid),再暴力将 ([l,r])(> mid) 的数设为 (1)(le mid) 的数设为 (0),统计 ([l,r])(1) 的个数,设为 (cnt),如果 (cntge k),表明答案一定 (>mid),反之答案一定 (le mid)。然后缩小二分的区间,继续进行上述操作。

至多 (O(log V)) 次二分后,答案一定确定,此时单次二分复杂度最多 (O(n log V))

然后考虑 (q) 组询问。我们仍尝试根据上述过程求解答案。

我们尝试设计一个二分框架 (operatorname{solve}(ql,qr,L,R)),表示:

  • 操作序列 ([ql,qr]) 的答案在值域 ([L,R]) 中。假设当前值域中点为 (mid)

具体地:

对于添数操作,表示添加的值 (c) 在值域 ([L,R]) 中。
对于查询操作,表示该查询的答案在值域 ([L,R]) 中。

判定完当前贡献后,将操作序列分成答案在值域 ([l,mid])([mid+1,r]) 两部分,分别向下递归求解。

注意边界问题。

考虑如何判定当前贡献。

我们仍延续单个询问二分的过程,对操作区间 ([ql,qr])(c > mid) 的区间添数操作,将其视为区间添数 (1),反之则视为 (0)。我们需要一个数据结构来支持区间添数 (1)(可以视为区间加 (1)),区间查询 (1) 的个数,自然想到要用线段树来维护这个过程。

然后我们先从左到右扫一遍操作序列。再开两个临时操作序列 (a1,a2),分别存答案在值域 ([L,mid])([mid+1,R]) 时的操作序列。

对于添数操作,若添加的值 $c in [mid+1,R] $,那么在线段树上区间加 (1),并将该操作存进 (a_2)。反之存进 (a_1)


Sol1:

对于查询操作,直接在线段树上查询该询问 (id) 的询问区间 ([l_{id},r_{id}])(1) 的个数,设为 (nw),若 (nw+tmp_{id}>k)(tmp) 一回再说),那么该询问答案一定在值域 ([mid+1,R]),将该询问存进 (a_2) 中,反之令 (tmp_{id}+nw to tmp_{id}),并将该询问存进 (a_1)

你发现,(tmp_{id}) 作用是记录第 (id) 个询问区间中所有 (>) 当前二分值域上界 (R) 的数的个数。通过这种类似主席树的方法,我们可以严格保证一个操作只会被分进 (a_1)(a_2) 其中之一。

扫完操作序列 ([ql,qr]) 后,需要清空线段树,简单实现是扫一遍 (a_2),将其中所有的添数操作所对应的区间减 (1) 即可。

假设现在 (a_1) 中有 (cnt_1) 个操作。我们将 (a_1) 序列拷贝到 (a) 序列下标 ([ql,ql+cnt_1)) 上,将 (a_2) 拷贝到 (a) 序列下标 ([ql+cnt,qr]) 上。


Sol2:

假了。


递归求解 (operatorname{solve}(ql,ql+cnt_1-1,L,mid))(operatorname{solve}(ql+cnt_1,qr,mid+1,R))

二分的边界条件:当 (L=R) 时,操作区间 ([ql,qr]) 中的所有询问的答案为 (L)

时间复杂度分析:

每个操作最多被遍历 (O(log V)) 次,每次遍历一个操作的复杂度为线段树的 (O(log n))

总时间复杂度 (O(m log V log n)),本题 (n,m,V) 同阶,可以看成 (O(n log^2 n)),可以通过。

Code:

点击查看代码
#include<bits/stdc++.h> #define int long long  using namespace std;  const int Size=(1<<20)+1; char buf[Size],*p1=buf,*p2=buf; char buffer[Size]; int op1=-1; const int op2=Size-1; #define getchar()                                                               (tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)      	? EOF                                                                  	: *ss++) char In[1<<20],*ss=In,*tt=In; inline int read() { 	int x=0,c=getchar(),f=0; 	for(;c>'9'||c<'0';f=c=='-',c=getchar()); 	for(;c>='0'&&c<='9';c=getchar()) 		x=(x<<1)+(x<<3)+(c^48); 	return f?-x:x; } inline void write(int x) { 	if(x<0) x=-x,putchar('-'); 	if(x>9)  write(x/10); 	putchar(x%10+'0'); }  const int N=5e4+5; int n,m; int laz[N<<2],sum[N<<2]; #define lp (p<<1) #define rp ((p<<1)|1)  void pushdown(int p,int l,int mid,int r) { 	if(!laz[p]) return; 	laz[lp]+=laz[p]; 	laz[rp]+=laz[p]; 	sum[lp]+=(mid-l+1)*laz[p]; 	sum[rp]+=(r-mid)*laz[p]; 	laz[p]=0; }  void pushup(int p) { 	sum[p]=sum[lp]+sum[rp]; }  void add(int l,int r,int sl,int sr,int k,int p) { 	if(l>r||sl>sr) return; 	if(sl<=l&&r<=sr) { laz[p]+=k; sum[p]+=k*(r-l+1); return; } 	int mid=(l+r)>>1; 	pushdown(p,l,mid,r); 	if(sl<=mid) add(l,mid,sl,sr,k,lp); 	if(sr>mid) add(mid+1,r,sl,sr,k,rp); 	pushup(p); }  int query(int l,int r,int sl,int sr,int p) { 	if(l>r||sl>sr) return 0; 	if(sl<=l&&r<=sr) return sum[p]; 	int mid=(l+r)>>1,ans=0; 	pushdown(p,l,mid,r); 	if(sl<=mid) ans+=query(l,mid,sl,sr,lp); 	if(sr>mid) ans+=query(mid+1,r,sl,sr,rp); 	pushup(p); 	return ans; }  struct Node{ 	int op,x,y,z,id; }a[N<<1],a1[N<<1],a2[N<<1]; int ans[N],tmp[N];  void solve(int ql,int qr,int l,int r) { 	if(l>r||ql>qr) return; 	if(l==r) 	{ 		for(int i=ql;i<=qr;i++) if(a[i].op==2) ans[a[i].id]=l; 		return; 	} 	int mid=(l+r)>>1; 	int cnt1=0,cnt2=0; 	for(int i=ql;i<=qr;i++) 	{ 		if(a[i].op==1) 		{ 			if(a[i].z>mid) add(1,n,a[i].x,a[i].y,1,1),a2[++cnt2]=a[i]; 			else a1[++cnt1]=a[i]; 		} 		else 		{ 			int nw=query(1,n,a[i].x,a[i].y,1); 			if(nw+tmp[a[i].id]>=a[i].z) a2[++cnt2]=a[i]; 			else a1[++cnt1]=a[i],tmp[a[i].id]+=nw; 		} 	} 	 	for(int i=1;i<=cnt2;i++) 	if(a2[i].op==1) add(1,n,a2[i].x,a2[i].y,-1,1);  	for(int i=1;i<=cnt1;i++) a[i+ql-1]=a1[i]; 	for(int i=1;i<=cnt2;i++) a[qr-cnt2+i]=a2[i];  	// cout<<"["<<l<<","<<r<<"] mid="<<mid<<" cntl="<<cnt1<<" cntr="<<cnt2<<"n"; 	solve(ql,ql+cnt1-1,l,mid); 	solve(ql+cnt1,qr,mid+1,r); }  signed main() { 	// freopen("P3332.in","r",stdin); 	// freopen("P3332.out","w",stdout); 	n=read(); 	m=read(); 	for(int i=1;i<=m;i++) 	{ 		int op=read(),x=read(),y=read(),z=read(); 		a[i]={op,x,y,z,i}; 	} 	solve(1,m,-n-1,n+1); 	for(int i=1;i<=m;i++) 	if(ans[i]) write(ans[i]),putchar('n'); 	//mt19937_64 myrand(time(0)); 	return 0; }  

练习题 1:P4175 [CTSC2008] 网络管理

这题存在单点修改操作,而且不像例题那么好做。

上树了,所以我们先需要剖树。支持单点修改区间查询,可以用线段树或树状数组来维护这个操作。建议用线段树,因为可以将单点修改转为单点覆盖,实现时比较简单。

先把初始点权视为修改以简化代码。

遇到单点修改时,将其拆成两个操作,分别表示删除原数和加入新数两个部分。

然后还是一样的思路,注意这题修改的判定有两个:修改的目标值 (>mid)当前操作是删数操作。也需要 (tmp) 数组记录询问路径上 (>R) 的点的个数。

时间复杂度 (O(n log^3 n))

点击查看代码
#include<bits/stdc++.h> // #define int long long  using namespace std;  const int Size=(1<<20)+1; char buf[Size],*p1=buf,*p2=buf; char buffer[Size]; int k1=-1; const int k2=Size-1; #define getchar()                                                               (tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)      	? EOF                                                                  	: *ss++) char In[1<<20],*ss=In,*tt=In; inline int read() { 	int x=0,c=getchar(),f=0; 	for(;c>'9'||c<'0';f=c=='-',c=getchar()); 	for(;c>='0'&&c<='9';c=getchar()) 		x=(x<<1)+(x<<3)+(c^48); 	return f?-x:x; } inline void write(int x) { 	if(x<0) x=-x,putchar('-'); 	if(x>9)  write(x/10); 	putchar(x%10+'0'); }  int n,q; const int N=2.5e5+5;  struct Node{ 	int k,x,y,id,op; }a[N],a1[N],a2[N]; int m; int c[N]; int tmp[N]; int ans[N]; const int inf=1e8+1;  #define lp (p<<1) #define rp ((p<<1)|1) int sum[N<<2]; vector<int> E[N]; int tot,dfn[N],id[N],top[N],siz[N],son[N],dep[N],fa[N]; void dfs1(int p,int f) { 	fa[p]=f; 	siz[p]=1; 	dep[p]=dep[f]+1; 	for(int to:E[p]) 	{ 		if(to==f) continue; 		dfs1(to,p); 		siz[p]+=siz[to]; 		if(siz[to]>siz[son[p]]) son[p]=to; 	} }  void dfs2(int p,int tp) { 	dfn[p]=++tot; 	top[p]=tp; 	id[tot]=p; 	if(son[p]) dfs2(son[p],tp); 	for(int to:E[p]) 	if(!dfn[to]) dfs2(to,to); } void pushup(int p) { 	sum[p]=sum[lp]+sum[rp]; }  void add(int l,int r,int x,int k,int p) { 	if(l==r) { sum[p]=k; return; } 	int mid=(l+r)>>1; 	if(x<=mid) add(l,mid,x,k,lp); 	else add(mid+1,r,x,k,rp); 	pushup(p); }  int query(int l,int r,int sl,int sr,int p) { 	if(sl<=l&&r<=sr) return sum[p]; 	int mid=(l+r)>>1,ans=0; 	if(sl<=mid) ans+=query(l,mid,sl,sr,lp); 	if(sr>mid) ans+=query(mid+1,r,sl,sr,rp); 	return ans; }  int query(int u,int v) { 	int ans=0; 	while(top[u]!=top[v]) 	{ 		if(dep[top[v]]>dep[top[u]]) swap(u,v); 		ans+=query(1,n,dfn[top[u]],dfn[u],1); 		u=fa[top[u]]; 	} 	if(dep[u]>dep[v]) swap(u,v); 	return ans+query(1,n,dfn[u],dfn[v],1); }  void solve(int ql,int qr,int L,int R) {  	if(ql>qr||L>R) return; 	if(L==R) 	{ 		for(int i=ql;i<=qr;i++) 		if(a[i].k!=0) ans[a[i].id]=L; 		return; 	} 	int mid=(L+R)>>1,cnt1=0,cnt2=0; 	for(int i=ql;i<=qr;i++) 	{ 		if(a[i].k==0) 		{ 			if(a[i].y>mid) a2[++cnt2]=a[i]; 			else a1[++cnt1]=a[i]; 			if(a[i].op==0||a[i].y>mid) add(1,n,dfn[a[i].x],a[i].op,1); 		} 		else 		{ 			int nw=query(a[i].x,a[i].y); 			// cout<<"!!! nw="<<nw<<"n"; 			if(nw+tmp[a[i].id]>=a[i].k) a2[++cnt2]=a[i]; 			else tmp[a[i].id]+=nw,a1[++cnt1]=a[i]; 		} 	} 	for(int i=1;i<=cnt2;i++) if(a2[i].k==0) add(1,n,dfn[a2[i].x],0,1); 	for(int i=1;i<=cnt1;i++) a[ql+i-1]=a1[i]; 	for(int i=1;i<=cnt2;i++) a[ql+cnt1+i-1]=a2[i]; 	solve(ql,ql+cnt1-1,L,mid); 	solve(ql+cnt1,qr,mid+1,R); }  signed main() { 	#ifndef ONLINE_JUDGE 	freopen("P4175.in","r",stdin); 	freopen("P4175.out","w",stdout); 	#endif 	for(int i=0;i<N;i++) ans[i]=inf; 	n=read(); 	q=read(); 	for(int i=1;i<=n;i++) c[i]=read(),a[++m]={0,i,c[i],0,1}; 	for(int i=1;i<n;i++) 	{ 		int u=read(),v=read(); 		E[u].push_back(v); 		E[v].push_back(u); 	} 	dfs1(1,0); 	dfs2(1,1); 	for(int i=1;i<=q;i++) 	{ 		int k=read(),x=read(),y=read(); 		if(k==0) a[++m]={k,x,c[x],i,0},c[x]=y; // !!!!!!! 		a[++m]={k,x,y,i,1}; 	} 	solve(1,m,-inf,inf); 	for(int i=1;i<=m;i++) 	if(ans[i]!=inf) 	{ 		if(ans[i]<-1e8||ans[i]>1e8) puts("invalid request!"); 		else write(ans[i]),putchar('n'); 	} 	return 0; } 

练习题 2:P4602 [CTSC2018] 混合果汁

很好的一道题。

发现题目给出的限制很多。

我们仍考虑使用整体二分求解。

整体二分答案的美味度。假设当前答案值域 ([L,R])

美味度 (in [mid+1,R]) 的那些果汁加进一个 ds 里。

然后查可行性可以转化为这个 ds 里前 (k) 小的值的和(最小花费),设为 (c),与这个人最多可以花的钱,设为 (w),做一个比较,如果 (wge c),那么这个询问的答案 (in [mid+1,R])。反之 (in [L,mid])。这样我们就将操作答案分为两部分,可以递归求解。

这个 ds 显然可以是权值线段树,每个节点记录所管辖区间的果汁总个数和总价钱,查询时直接二分即可。

发现在当前 ds 中的果汁的美味度其实是 ([mid+1,R] cup [R+1,+infty]),因为后者也可以带来贡献。

那么我们修正一下整体二分过程,遍历完当前 ([ql,qr]) 的所有操作后,不执行线段树删除操作,直接递归求解答案值域在 ([L,mid]) 的那些操作,结束后再撤销当前 ([ql,qr]) 对线段树的所有操作,然后递归求解答案值域在 ([mid+1,R]) 的那些操作。

注意无解时线段树边界情况。

时间复杂度 (O(n log n log V)),可以极限通过本题。可以离散化果汁的美味度,时间复杂度可以降为 (O(n log^2 n))

点击查看代码
#include<bits/stdc++.h> #define int long long  using namespace std;  const int Size=(1<<20)+1; char buf[Size],*p1=buf,*p2=buf; char buffer[Size]; int op1=-1; const int op2=Size-1; #define getchar()                                                               (tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)      	? EOF                                                                  	: *ss++) char In[1<<20],*ss=In,*tt=In; inline int read() { 	int x=0,c=getchar(),f=0; 	for(;c>'9'||c<'0';f=c=='-',c=getchar()); 	for(;c>='0'&&c<='9';c=getchar()) 		x=(x<<1)+(x<<3)+(c^48); 	return f?-x:x; } inline void write(int x) { 	if(x<0) x=-x,putchar('-'); 	if(x>9)  write(x/10); 	putchar(x%10+'0'); }  const int N=1e5+5; int n,m;  struct Node{ 	// op, a, b, c, id 	// 0, d, p, l, 0 	// 1, g, L, 0, id 	int op,a,b,c,id; }a[N<<1],a1[N<<1],a2[N<<1]; int ans[N]; int tot;  const int MINN=-1,MAXN=1e18+1; struct Tr{ 	int cnt,sum,lp,rp; }t[N*64]; int root;  void pushup(int p) { 	t[p].cnt=t[t[p].lp].cnt+t[t[p].rp].cnt; 	t[p].sum=t[t[p].lp].sum+t[t[p].rp].sum; }  void add(int l,int r,int x,int k,int &p) { 	if(!p) p=++tot; 	if(l==r) 	{ 		t[p].cnt+=k; 		t[p].sum+=k*x; 		return; 	} 	int mid=(l+r)>>1; 	if(x<=mid) add(l,mid,x,k,t[p].lp); 	else add(mid+1,r,x,k,t[p].rp); 	pushup(p); }  int query_kth(int l,int r,int k,int p) { 	if(l==r) return l==MAXN?MAXN+1:k*l; 	int ans=0,mid=(l+r)>>1; 	if(t[t[p].lp].cnt>=k) return query_kth(l,mid,k,t[p].lp); 	return t[t[p].lp].sum+query_kth(mid+1,r,k-t[t[p].lp].cnt,t[p].rp); }  void solve(int ql,int qr,int L,int R) { 	if(ql>qr||L>R) return; 	if(L==R) 	{ 		for(int i=ql;i<=qr;i++) if(a[i].op) ans[a[i].id]=L; 		return; 	} 	int mid=(L+R)>>1,cnt1=0,cnt2=0;  	for(int i=ql;i<=qr;i++) 	{ 		if(!a[i].op) 		{ 			if(a[i].a>mid) a2[++cnt2]=a[i],add(MINN,MAXN,a[i].b,a[i].c,root); 			else a1[++cnt1]=a[i]; 		}	 		else 		{ 			int nw=query_kth(MINN,MAXN,a[i].b,root); 			if(nw>a[i].a) a1[++cnt1]=a[i]; 			else a2[++cnt2]=a[i]; 		}	 	} 	for(int i=1;i<=cnt1;i++) a[ql+i-1]=a1[i]; 	for(int i=1;i<=cnt2;i++) a[ql+cnt1+i-1]=a2[i]; 	solve(ql,ql+cnt1-1,L,mid); 	for(int i=ql;i<=qr;i++) if(!a[i].op&&a[i].a>mid) add(MINN,MAXN,a[i].b,-a[i].c,root); 	solve(ql+cnt1,qr,mid+1,R); }  signed main() { 	// #ifndef ONLINE_JUDGE 	// freopen("P4602.in","r",stdin); 	// freopen("P4602.out","w",stdout); 	// #endif 	//mt19937_64 myrand(time(0)); 	 	n=read(); 	m=read(); 	for(int i=1;i<=n;i++) 	{ 		int d=read(),p=read(),l=read(); 		a[++tot]={0,d,p,l,0}; 	} 	for(int i=1;i<=m;i++) 	{ 		int g=read(),l=read(); 		a[++tot]={1,g,l,0,i}; 	} 	int nww=tot; 	tot=0; 	add(MINN,MAXN,MAXN,1,root); 	solve(1,nww,-1,1e18+1); 	 	for(int i=1;i<=m;i++) write(ans[i]),putchar('n'); 	return 0; } 

练习题 3:CF868F Yet Another Minimization Problem

类似整体二分的分治方法优化决策单调性 dp。

思路就看这道题的题解吧,我也是看的题解。

这里说一下我的实现方式:

外层遍历 (k)

对于每一层:我们还是设一个 ({l,r,ql,qr}) 表示 dp 数组下标 (in [ql,qr]),决策点一定 (in [l,r])

(mid=frac{ql+qr}{2})

然后我们暴力遍历决策点区间尝试转移 (dp_{mid}),记录一个 (p) 表示决策点 (p) 可以使 (dp_{mid}) 最小。

由于决策单调性,dp 下标 (in [ql,mid)) 的决策点一定 (in [l,p]),dp 下标 (in (mid,qr]) 的决策点一定 (in [p,r])。继续分治求解即可。

边界是下标集合为空或决策点集合为空。

对于快速算一个区间的费用,可以用一个类似莫队的方法(可看成双指针)维护。

为了使复杂度正确,并且易于分析复杂度,可以使用 bfs 分治树的方法进行求解。

具体的,维护一个元素为 ({l,r,ql,qr}) 的队列,每次取出队首,还是按照上面说的方法做,但是将最后分治求解改为队列中依次加入元素 ({l,p,ql,mid-1})({p,r,mid+1,qr})

发现分治树深度为 (log n),对于每一层,对答案造成贡献的区间指针一定是向右移动,所以每一层复杂度为线性。

总时间复杂度 (O(kn log n)),常数可能较大。

点击查看代码
#include<bits/stdc++.h> #define int long long  using namespace std;  const int Size=(1<<20)+1; char buf[Size],*p1=buf,*p2=buf; char buffer[Size]; int op1=-1; const int op2=Size-1; #define getchar()                                                               (tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)      	? EOF                                                                  	: *ss++) char In[1<<20],*ss=In,*tt=In; inline int read() { 	int x=0,c=getchar(),f=0; 	for(;c>'9'||c<'0';f=c=='-',c=getchar()); 	for(;c>='0'&&c<='9';c=getchar()) 		x=(x<<1)+(x<<3)+(c^48); 	return f?-x:x; } inline void write(int x) { 	if(x<0) x=-x,putchar('-'); 	if(x>9)  write(x/10); 	putchar(x%10+'0'); }  const int N=1e5+5; int n,k; int a[N]; int dp[N],dp2[N];  struct Node{ 	int l,r,ql,qr; 	// 决策 in [l,r] ,dp数组下标 in [ql,qr]; }; int cnt[N]; queue<Node> q; int nwl=1,nwr,res;  int calc(int l,int r) { 	while(nwr<r) res+=cnt[a[++nwr]]++; 	while(nwl>l) res+=cnt[a[--nwl]]++; 	while(nwr>r) res-=--cnt[a[nwr--]]; 	while(nwl<l) res-=--cnt[a[nwl++]]; 	return res; }  signed main() { 	memset(dp,0x3f,sizeof(dp)); 	memset(dp2,0x3f,sizeof(dp2));  	dp[0]=0; 	n=read(); 	k=read(); 	for(int i=1;i<=n;i++) a[i]=read(); 	for(int iiii=1;iiii<=k;iiii++) 	{ 		q.push(Node{1,n,1,n}); 		while(q.size()) 		{ 			Node nw=q.front(); 			q.pop(); 			if(nw.l>nw.r||nw.ql>nw.qr) continue; 			int mid=(nw.ql+nw.qr)>>1,p=0;  			for(int i=nw.l;i<=min(nw.r,mid);i++) 			{ 				int to=calc(i,mid)+dp[i-1]; 				if(to<dp2[mid]) { dp2[mid]=to,p=i; } 			} 			q.push({nw.l,p,nw.ql,mid-1}); 			q.push({p,nw.r,mid+1,nw.qr}); 		} 		for(int j=1;j<=n;j++) dp[j]=dp2[j]; 		memset(dp2,0x3f,sizeof(dp2)); 	} 	cout<<dp[n]<<"n"; 	return 0; }  
发表评论

评论已关闭。

相关文章