这题最关键的一点是要想清楚这个路径长度是咋算的。
要注意到他是走到所有关键点去取宝藏然后回来。我们假想这些关键点构成一棵虚树,那么这就相当于是虚树上的一个
d
f
s
dfs
dfs。
当关键点确定的时候,其实最短路径长就确定了,也就是在虚树上从一个点开始dfs直到到回到这个节点所经过的路径长之和。起点其实无所谓,不选虚树外面的点就行。
如果我们把虚树上的点按照dfs序从小到大排为
p
1
,
p
2
,
⋯
,
p
n
p_1,p_2,\cdots,p_n
p1,p2,⋯,pn,那么这时候询问的答案即为:
a
n
s
=
∑
i
=
1
n
−
1
d
i
s
(
p
i
,
p
i
+
1
)
+
d
i
s
(
p
n
,
p
1
)
ans=\sum_{i=1}^{n-1} {dis(p_i,p_{i+1})}+dis(p_n,p_1)
ans=i=1∑n−1dis(pi,pi+1)+dis(pn,p1)
算是计算树上路径和的一种套路吧。
把
d
f
s
dfs
dfs序从小到大排列,可以想象成是一个环,要求支持在这个环上加点,删点,求权值和。
发现这个东西用
s
e
t
set
set维护就行了。每次找到前驱后继把权值改一改即可。要注意一下插入的是
d
f
s
dfs
dfs序。
#include<bits/stdc++.h>
#define cs const
#define re register
#define ll long long
cs int N=1e5+10,Log=17;
int Head[N],Next[N<<1],V[N<<1],cnt=0;
int dfn[N],t[N],key[N],dep[N],f[N][Log],tot=0;
int n,m,x,y,pos;ll z,W[N<<1],dis[N];
std::set<int> S;
typedef std::set<int>::iterator It;
namespace IO{
cs int Rlen=1<<22|1;
char buf[Rlen],*p1,*p2;
inline char gc(){return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;}
template<typename T>
inline T get(){
char ch;T x;
while(!isdigit(ch=gc()));x=ch^48;
while(isdigit(ch=gc())) x=((x+(x<<2))<<1)+(ch^48);
return x;
}
inline int gi(){return get<int>();}
inline ll gl(){return get<ll>();}
}
using namespace IO;
inline void add(int u,int v,ll w){Next[++cnt]=Head[u],V[cnt]=v,W[cnt]=w,Head[u]=cnt;}
inline void dfs(int u,int fa){
t[dfn[u]=++tot]=u,f[u][0]=fa,dep[u]=dep[fa]+1;
for(int re i=1;i<Log;++i) f[u][i]=f[f[u][i-1]][i-1];
for(int re i=Head[u],v=V[i];i;v=V[i=Next[i]])
if(v!=fa) dis[v]=dis[u]+W[i],dfs(v,u);
}
inline int lca(int u,int v){
if(dep[u]<dep[v]) std::swap(u,v);
for(int re i=Log-1;~i;--i) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
if(u==v) return u;
for(int re i=Log-1;~i;--i) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
inline ll dist(int u,int v){return dis[u]+dis[v]-2ll*dis[lca(u,v)];}
inline int pre(int u){
It p=S.lower_bound(u);
if(p==S.begin()) return t[*--S.end()];
return t[*--p];
}
inline int nxt(int u){
It p=S.lower_bound(u);
if(++p==S.end()) return t[*S.begin()];
return t[*p];
}
ll ans=0;
int main(){
// freopen("2704.in","r",stdin);
n=gi(),m=gi();
for(int re i=1;i<n;++i)
x=gi(),y=gi(),z=gl(),add(x,y,z),add(y,x,z);
dfs(1,0);
while(m--){
pos=gi();
if(!key[pos]){
key[pos]^=1,S.insert(dfn[pos]),x=pre(dfn[pos]),y=nxt(dfn[pos]);
ans+=dist(x,pos)+dist(y,pos)-dist(x,y);
}
else{
key[pos]^=1,x=pre(dfn[pos]),y=nxt(dfn[pos]);
ans-=dist(x,pos)+dist(y,pos)-dist(x,y);
S.erase(dfn[pos]);
}printf("%lld\n",ans);
}
}