题目
n(n<=1e5)个点的一棵树,每个点有点权ai(1<=ai<=1e5),
对于树上两个点s、t,定义C(s,t)=k×gcd(Ap1,Ap2,…,Apk),
其中,k为s到t这条链的顶点个数,p1为点s,pk为点t,中间的点是链上经过的点
求,答案对998244353取模
思路来源
dls/官方题解
题解1(反演)
根据反演,考虑将gcd的贡献拆开,然后统计每一部分贡献的倍数有哪些,
根据式子,将gcd(a1,...,an)的贡献展开,
即对于每个因子d,在d的倍数的路径上,都有phi(d)的贡献,
则问题转化为,按因子建树,dfs树,对每个因子求贡献
TIPS
根据式子,移项得到
,
即可以得到O(nlogn)求phi的方法,即一开始初始化phi(n)=n,
在枚举到某一对(因子,倍数)关系的时候,令倍数减去因子的phi值
代码1
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10,M=1e5,mod=998244353;
int n,u,v,phi[N],a[N],sz[N],all,ans,cnt;
bool vis[N];
vector<int>e[N],fac[N],h[N];
vector<int>now;
struct edge{
int u,v;
}f[N];
void dfs(int u,int fa){
all++;
vis[u]=1;
sz[u]=1;
for(auto &v:e[u]){
if(v==fa)continue;
dfs(v,u);
sz[u]+=sz[v];
}
}
void dfs2(int u,int fa){
for(auto &v:e[u]){
if(v==fa)continue;
cnt=(cnt+1ll*sz[v]*(all-sz[v])%mod)%mod;
dfs2(v,u);
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
for(int i=1;i<=M;++i){
phi[i]=i;
}
for(int i=1;i<=M;++i){
fac[i].push_back(i);
for(int j=2*i;j<=M;j+=i){
fac[j].push_back(i);
phi[j]-=phi[i];
}
}
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
f[i]=edge{u,v};
int g=__gcd(a[u],a[v]);
//printf("g:%d\n",g);
for(auto &w:fac[g]){
//printf("w:%d i:%d\n",w,i);
h[w].push_back(i);
}
}
for(int i=1;i<=M;++i){
if(!h[i].size())continue;
now.clear();
for(auto &w:h[i]){
u=f[w].u,v=f[w].v;
e[u].push_back(v);
e[v].push_back(u);
//printf("i:%d u:%d <- -> v:%d\n",i,u,v);
if(!vis[u]){
now.push_back(u);
vis[u]=1;
}
if(!vis[v]){
now.push_back(v);
vis[v]=1;
}
}
for(auto &v:now)vis[v]=0;
for(auto &v:now){
if(!vis[v]){
all=cnt=0;
dfs(v,-1);
dfs2(v,-1);
cnt=(cnt+1ll*all*(all-1)/2%mod)%mod;
//printf("i:%d phi:%d all:%d cnt:%d\n",i,phi[i],all,cnt);
ans=(ans+1ll*phi[i]*cnt%mod)%mod;
}
}
for(auto &v:now){
e[v].clear();
vis[v]=0;
}
}
printf("%d\n",ans);
return 0;
}
题解2(容斥)
先对于每个因子d,求出d的倍数的路径有多少条,然后因为会有算重,
记ans[i]为恰为因子d的贡献,则ans[i]需要减去ans[2*i],ans[3*i],...,减去每一个倍数的贡献
代码2
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10,M=1e5,mod=998244353;
int n,u,v,phi[N],a[N],sz[N],all,ans[N],cnt,res;
bool vis[N];
vector<int>e[N],fac[N],h[N];
vector<int>now;
struct edge{
int u,v;
}f[N];
void dfs(int u,int fa){
all++;
vis[u]=1;
sz[u]=1;
for(auto &v:e[u]){
if(v==fa)continue;
dfs(v,u);
sz[u]+=sz[v];
}
}
void dfs2(int u,int fa){
for(auto &v:e[u]){
if(v==fa)continue;
cnt=(cnt+1ll*sz[v]*(all-sz[v])%mod)%mod;
dfs2(v,u);
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
for(int i=1;i<=M;++i){
phi[i]=i;
}
for(int i=1;i<=M;++i){
fac[i].push_back(i);
for(int j=2*i;j<=M;j+=i){
fac[j].push_back(i);
phi[j]-=phi[i];
}
}
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
f[i]=edge{u,v};
int g=__gcd(a[u],a[v]);
//printf("g:%d\n",g);
for(auto &w:fac[g]){
//printf("w:%d i:%d\n",w,i);
h[w].push_back(i);
}
}
for(int i=1;i<=M;++i){
if(!h[i].size())continue;
now.clear();
for(auto &w:h[i]){
u=f[w].u,v=f[w].v;
e[u].push_back(v);
e[v].push_back(u);
//printf("i:%d u:%d <- -> v:%d\n",i,u,v);
if(!vis[u]){
now.push_back(u);
vis[u]=1;
}
if(!vis[v]){
now.push_back(v);
vis[v]=1;
}
}
for(auto &v:now)vis[v]=0;
for(auto &v:now){
if(!vis[v]){
all=cnt=0;
dfs(v,-1);
dfs2(v,-1);
cnt=(cnt+1ll*all*(all-1)/2%mod)%mod;
//printf("i:%d phi:%d all:%d cnt:%d\n",i,phi[i],all,cnt);
ans[i]=(ans[i]+cnt)%mod;
}
}
for(auto &v:now){
e[v].clear();
vis[v]=0;
}
}
for(int i=M;i;--i){
for(int j=2*i;j<=M;j+=i){
ans[i]=(ans[i]-ans[j]+mod)%mod;
}
res=(res+1ll*i*ans[i]%mod)%mod;
}
printf("%d\n",res);
return 0;
}