例题:bzoj3572/洛谷P3233
如果,只是考虑每个点被谁控制的话,用两个dfs就能够搞定了。
可是这题n和Q都很大,只有 ∑ m \sum m ∑m相对较小,因此我们应该考虑一种基本上只跟询问点有关的算法。
那就是虚树。
现在我们有若干询问点,一棵以1为根的树。为了能够将询问点按照原本的树上路径连成一棵新树,我们还需要将这些询问点按照dfs序排序后,相邻两个询问点之间的lca,这些点组成的就是虚树(姑且称虚树上的点为“神选之点标记点”)。最后我们把这些标记点按照原树里的辈分关系连成一棵新树,这就是美丽的原谅色虚树。
此题中,为了维护整棵树的信息,1号根节点也是需要的。
考虑一种较快的建立虚树的方法——利用栈维护一条链。
首先我们求一遍dfs序,按照dfs序处理询问点,开一个栈。设栈顶的点的节点叫做a,栈顶第二个点叫做b,当前我们想要将一个点x插入虚树。
我们在栈里面维护了一条链,栈顶是链上深度最大的节点。如果x在这个链上,操作很简单,直接将x推入栈中即可。否则,我们要让这个链“拐弯”朝向x。
所谓拐弯就是图1变成图2(图源xzy大佬)
- 首先我们求出a和x的lca,o。
- 然后我们比较o和b的深度关系。如果b的深度更深,由于我们是按照dfs序处理的,所以b就是a的父亲,连接b和a,然后将a从栈中弹出,继续向上去寻找x与这条链的“拐弯”之地。
否则,比较a和o的深度关系,如果a比较深,说明o是a的父亲,a出栈,退出(因为此时x和a在o的两个不同子树中,说明o就是那个“拐弯点”)。否则说明a=x,直接退出。 - 栈中点两两连边。
写成代码就是酱紫的:
if(bel[1]!=1) s[++top]=1;
for(int i=1;i<=m;++i) {
int x=a[i],o=0;
while(top) {
o=lca(x,s[top]);
if(top>1&&dep[s[top-1]]>dep[o]) add(s[top-1],s[top]),--top;
else if(dep[s[top]]>dep[o]) {add(o,s[top]),--top;break;}
else break;
}
if(s[top]!=o) s[++top]=o;
s[++top]=x;
}
while(top>1) add(s[top-1],s[top]),--top;
建完虚树后,我们还要处理这道题中的询问。
首先利用两遍dfs把虚树上的点被哪个询问点控制都搞一遍。
然后枚举虚树上的每一条边。
如果两边的标记点被同一个询问点控制,说明在原树上找到这两个标记点,他们中间的所有点都被这个询问点控制。
否则,我们可以用类似倍增的方法找到两个询问点控制点的“国界点”来分这条边以及这条边上连的没有标记点的子树。
最后可能还有一些并不处于两个标记点之间的点没有归宿。我们可以让离他们最近的标记点来代表他们,他们的控制点就是那个代表点的控制点。
写成代码就是这样(rem是这个点可以代表多少点,初始值为这个点的子树大小sz)
void cal(int x,int y) {
int tmp=y,mid=y;
for(int i=19;i>=0;--i) if(dep[f[tmp][i]]>dep[x]) tmp=f[tmp][i];
//tmp:x到y的原树路径上第二个点
rem[x]-=sz[tmp];//两个标记点之间的点不能被代表,y的另一边的点也不能被x代表
if(bel[x]==bel[y]) {ans[bel[x]]+=sz[tmp]-sz[y];return;}
for(int i=19;i>=0;--i) {//mid:国界点
int nxt=f[mid][i];
if(dep[nxt]<=dep[x]) continue;
int t1=dis(nxt,bel[x]),t2=dis(nxt,bel[y]);
if(t2<t1||(t1==t2&&bel[y]<bel[x])) mid=nxt;
}
ans[bel[x]]+=sz[tmp]-sz[mid];
ans[bel[y]]+=sz[mid]-sz[y];
}
至此,这道题被解决了,下面要做的就是抄hzwer代码写代码啦!
#include<bits/stdc++.h>
using namespace std;
int read() {
int q=0;char ch=' ';
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return q;
}
const int N=300005;
int n,Q,m,tot,now,top;
int a[N],b[N],bel[N],s[N],p[N],rem[N],ans[N];//bel:归谁管 rem:每个点可以代表多少个点
int h[N],ne[N<<1],to[N<<1],pos[N],f[N][21],dep[N],sz[N];
void add(int x,int y) {to[++tot]=y,ne[tot]=h[x],h[x]=tot;}
void dfs(int x,int las) {
pos[x]=++now,dep[x]=dep[las]+1,sz[x]=1,f[x][0]=las;
for(int i=1;i<=19;++i) f[x][i]=f[f[x][i-1]][i-1];
for(int i=h[x];i;i=ne[i])
if(to[i]!=las) dfs(to[i],x),sz[x]+=sz[to[i]];
}
int lca(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
for(int i=19;i>=0;--i) if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
for(int i=19;i>=0;--i) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int cmp(int x,int y) {return pos[x]<pos[y];}
int dis(int x,int y) {return dep[x]+dep[y]-dep[lca(x,y)]*2;}
void dfs1(int x) {
p[++now]=x,rem[x]=sz[x];
for(int i=h[x];i;i=ne[i]) {
dfs1(to[i]);
if(!bel[to[i]]) continue;
int t1=dis(x,bel[x]),t2=dis(x,bel[to[i]]);
if(!bel[x]||t1>t2||(t1==t2&&bel[to[i]]<bel[x])) bel[x]=bel[to[i]];
}
}
void dfs2(int x) {
for(int i=h[x];i;i=ne[i]) {
int t1=dis(bel[to[i]],to[i]),t2=dis(bel[x],to[i]);
if(!bel[to[i]]||t1>t2||(t1==t2&&bel[to[i]]>bel[x])) bel[to[i]]=bel[x];
dfs2(to[i]);
}
}
void cal(int x,int y) {
int tmp=y,mid=y;
for(int i=19;i>=0;--i) if(dep[f[tmp][i]]>dep[x]) tmp=f[tmp][i];
rem[x]-=sz[tmp];
if(bel[x]==bel[y]) {ans[bel[x]]+=sz[tmp]-sz[y];return;}
for(int i=19;i>=0;--i) {
int nxt=f[mid][i];
if(dep[nxt]<=dep[x]) continue;
int t1=dis(nxt,bel[x]),t2=dis(nxt,bel[y]);
if(t2<t1||(t1==t2&&bel[y]<bel[x])) mid=nxt;
}
ans[bel[x]]+=sz[tmp]-sz[mid];
ans[bel[y]]+=sz[mid]-sz[y];
}
void work() {
tot=top=now=0;
m=read();
for(int i=1;i<=m;++i) a[i]=b[i]=read(),bel[a[i]]=a[i];
sort(a+1,a+1+m,cmp);
if(bel[1]!=1) s[++top]=1;
for(int i=1;i<=m;++i) {
int x=a[i],o=0;
while(top) {
o=lca(x,s[top]);
if(top>1&&dep[s[top-1]]>dep[o]) add(s[top-1],s[top]),--top;
else if(dep[s[top]]>dep[o]) {add(o,s[top]),--top;break;}
else break;
}
if(s[top]!=o) s[++top]=o;
s[++top]=x;
}
while(top>1) add(s[top-1],s[top]),--top;
dfs1(1),dfs2(1); //求出每个点被谁管
for(int j=1;j<=now;++j)
for(int i=h[p[j]];i;i=ne[i]) cal(p[j],to[i]);
for(int i=1;i<=now;++i) ans[bel[p[i]]]+=rem[p[i]];
for(int i=1;i<=m;++i) printf("%d ",ans[b[i]]);
puts("");
for(int i=1;i<=now;++i) h[p[i]]=bel[p[i]]=ans[p[i]]=0;//防止时间爆炸的清空方式
}
int main()
{
n=read();int x,y;
for(int i=1;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
dfs(1,0),Q=read();
for(int i=1;i<=n;++i) h[i]=0;
while(Q--) work();
return 0;
}