Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。
Solution
二分套树或树套树别想了,没戏。
这里有一个常用trick:
两点路径上点权和=sum[x]+sum[y]-sum[lca(x,y)]-sum[fa[lca(x,y)]]
可以自己想想为什么。
那么按树结构建主席树,查询直接这样算即可。
Code
#include<bits/stdc++.h>
using namespace std;
map<long long,int>mp;
map<int,long long>mp1;
struct Edge{
int v,nxt;
}e[400010];
int n,m,tot,sum,s,cnt;
int head[200010];
long long a[200010];
long long c[200010];
int fa[200010];
int id[200010];
int siz[200010];
int son[200010];
int dep[200010];
int top[200010];
int node[200010];
void addEdge(int u,int v){
tot++;
e[tot].v = v;
e[tot].nxt = head[u];
head[u] = tot;
}
inline void ad(int u,int v){
addEdge(u,v);
addEdge(v,u);
}
int tree[200010*40];
int ls[200010*40];
int rs[200010*40];
int rt[200010];
int build(int l,int r){
int id=++cnt;
int mid=(l+r)>>1;
if(l==r) return id;
ls[id]=build(l,mid);
rs[id]=build(mid+1,r);
return id;
}
int update(int pre,int l,int r,int x){
int id=++cnt;
ls[id]=ls[pre];
rs[id]=rs[pre];
tree[id]=tree[pre]+1;
if(l==r) return id;
int mid=(l+r)>>1;
if(x<=mid) ls[id]=update(ls[pre],l,mid,x);
else rs[id]=update(rs[pre],mid+1,r,x);
return id;
}
int query(int LT,int RT,int lca,int lcaf,int l,int r,int k){
if(l>=r) return l;
int mid=(l+r)>>1;
int sum=tree[ls[LT]]+tree[ls[RT]]-tree[ls[lca]]-tree[ls[lcaf]];
if(sum>=k) return query(ls[LT],ls[RT],ls[lca],ls[lcaf],l,mid,k);
else return query(rs[LT],rs[RT],rs[lca],rs[lcaf],mid+1,r,k-sum);
}
void dfs1(int u,int father,int depth){
rt[u]=update(rt[fa[u]],1,s,mp[a[u]]);
//cout<<u<<" "<<fa[u]<<endl;
siz[u]=1;
fa[u]=father;
dep[u]=depth;
int maxn=-1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==father) continue;
fa[v]=u;
dfs1(v,u,depth+1);
siz[u]+=siz[v];
if(siz[v]>maxn){
maxn=siz[v];
son[u]=v;
}
}
}
void dfs2(int u,int ltop){
top[u]=ltop;
if(!son[u]) return;
dfs2(son[u],ltop);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa[u]||v==son[u])
continue;
dfs2(v,v);
}
}
int lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
int main(){
int op,x,y,z;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
c[i]=a[i];
}
sort(c+1,c+n+1);
for(int i=1;i<=n;i++){
if(c[i]!=c[i-1]) s++;
mp[c[i]]=s;
mp1[s]=c[i];
}
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
ad(x,y);
}
rt[0]=build(1,s);
dfs1(1,-1,1);
dfs2(1,1);
//cout<<"pre:ok\n";
int las=0;
for(int i=1;i<=m;i++){
scanf("%d%d%d",&x,&y,&z); x^=las;
printf("%lld\n",las=mp1[query(rt[x],rt[y],rt[lca(x,y)],rt[fa[lca(x,y)]],1,s,z)]);
}
}