题目传送门
题意:在一个有n个结点n-1条边的树状的网络中,键盘侠a和键盘侠b一起移动,两个人轮流选择一个当前结点的相邻结点移动,每个结点至多被经过一次,当没有可以移动的结点的时候旅途结束,到达结点i后两个人获得的金币为A[i]-B[i],键盘侠a想让最终结果金币尽可能多,b想让最终结果金币尽可能少,第一轮由a任选一个结点作为起点,求最终的结果(a肯定会选择一个能使最终结果最大的结点作为起点啦~)
题解:显然为树形dp+换根。注意一些实现细节就行了
#include<bits/stdc++.h>
using namespace std;
//#define debug(x) cout<<#x<<" is "<<x<<endl;
typedef long long ll;
const int maxn=1e5+5;
const ll inf=1e16;
struct edge{
int to;
int nex;
}e[maxn<<1];
int cnt,head[maxn];
ll ac,a[maxn],b[maxn],dp[2][maxn],siz[maxn];
void adde(int x,int y){
e[cnt].to=y;
e[cnt].nex=head[x];
head[x]=cnt++;
}
void dfs1(int x,int f){
dp[0][x]=-inf;
dp[1][x]=inf;
for(int i=head[x];i!=-1;i=e[i].nex){
int v=e[i].to;
if(v==f)continue;
dfs1(v,x);
dp[0][x]=max(dp[0][x],dp[1][v]);
dp[1][x]=min(dp[1][x],dp[0][v]);
}
if(dp[0][x]==-inf)dp[0][x]=0;
if(dp[1][x]==inf)dp[1][x]=0;
dp[0][x]+=a[x]-b[x];
dp[1][x]+=a[x]-b[x];
}
void dfs2(int x,int f){
ac=max(ac,dp[1][x]);
ll tle[2][2];
tle[0][0]=tle[0][1]=-inf;
tle[1][0]=tle[1][1]=inf;
for(int j=head[x];j!=-1;j=e[j].nex){
int v2=e[j].to;
if(dp[1][v2]>=tle[0][0]){
tle[0][1]=tle[0][0];
tle[0][0]=dp[1][v2];
}
else{
tle[0][1]=max(tle[0][1],dp[1][v2]);
}
if(dp[0][v2]<=tle[1][0]){
tle[1][1]=tle[1][0];
tle[1][0]=dp[0][v2];
}
else{
tle[1][1]=min(tle[1][1],dp[0][v2]);
}
}
for(int i=head[x];i!=-1;i=e[i].nex){
int v=e[i].to;
if(v==f)continue;
ll xx=dp[0][x];
ll yy=dp[1][x];
ll xx2=dp[0][v];
ll yy2=dp[1][v];
if(dp[1][v]!=tle[0][0]){
dp[0][x]=tle[0][0];
}
else{
dp[0][x]=tle[0][1];
}
if(dp[0][v]!=tle[1][0]){
dp[1][x]=tle[1][0];
}
else{
dp[1][x]=tle[1][1];
}
if(dp[0][x]==-inf)dp[0][x]=0;
if(dp[1][x]==inf)dp[1][x]=0;
dp[0][x]+=a[x]-b[x];
dp[1][x]+=a[x]-b[x];
if(siz[v]>1)dp[0][v]=max(dp[0][v],dp[1][x]+a[v]-b[v]);
else dp[0][v]=dp[1][x]+a[v]-b[v];
if(siz[v]>1)dp[1][v]=min(dp[1][v],dp[0][x]+a[v]-b[v]);
else dp[1][v]=dp[0][x]+a[v]-b[v];
dfs2(v,x);
dp[0][x]=xx;
dp[1][x]=yy;
dp[0][v]=xx2;
dp[1][v]=yy2;
}
}
int main() {
int t;
scanf("%d",&t);
while(t--){
cnt=0;
ac=-inf;
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)siz[i]=0;
memset(head,-1,sizeof(head));
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1;i<=n;i++)scanf("%lld",&b[i]);
int N=n-1;
while(N--){
int fr,to;
scanf("%d%d",&fr,&to);
adde(fr,to);
adde(to,fr);
siz[fr]++;
siz[to]++;
}
dfs1(1,-1);
dfs2(1,-1);
printf("%lld\n",ac);
}
return 0;
}