题目大意
有一棵
n
+
1
n+1
n+1个节点的树,根节点为0。给你一个
k
k
k,定义集合
S
i
=
{
j
∈
Z
∣
max
(
1
,
i
−
k
)
≤
j
<
i
}
∪
{
0
}
S_i=\{j\in Z|\max(1,i-k)\leq j<i\}\cup\{0\}
Si={j∈Z∣max(1,i−k)≤j<i}∪{0}。
A
i
=
∑
j
∈
S
i
d
i
s
(
i
,
j
)
2
A_i=\sum\limits_{j\in S_i}dis(i,j)^2
Ai=j∈Si∑dis(i,j)2,
d
i
s
(
i
,
j
)
dis(i,j)
dis(i,j)指
i
i
i到
j
j
j在树上的距离。求
A
1
,
A
2
,
…
,
A
n
A_1,A_2,\dots,A_n
A1,A2,…,An分别是多少。
题解
令
a
i
a_i
ai表示
i
i
i在树上的深度,注意根节点的深度为1。
那么
d
i
s
(
i
,
j
)
2
=
(
a
i
+
a
j
−
2
a
l
c
a
)
2
=
a
i
2
+
a
j
2
+
2
a
i
a
j
−
4
a
l
c
a
a
i
−
4
a
l
c
a
a
j
+
4
a
l
c
a
2
dis(i,j)^2=(a_i+a_j-2a_{lca})^2=a_i^2+a_j^2+2a_ia_j-4a_{lca}a_i-4a_{lca}a_j+4a_{lca}^2
dis(i,j)2=(ai+aj−2alca)2=ai2+aj2+2aiaj−4alcaai−4alcaaj+4alca2
我们可以枚举
i
i
i,用树链剖分来维护
j
j
j的值。
对于
a
i
2
a_i^2
ai2,可以直接得出。
对于
a
j
2
a_j^2
aj2,将所有在
S
i
S_i
Si中的
j
j
j求前缀和即可。
对于
2
a
i
a
j
2a_ia_j
2aiaj,对
a
j
a_j
aj求前缀和,再乘上
2
a
i
2a_i
2ai。
对于
4
a
l
c
a
a
i
4a_{lca}a_i
4alcaai,对每个
S
i
S_i
Si中的
j
j
j,都将
j
j
j到根节点的路径上的点加1,然后查询
i
i
i到根节点的路径上的对应值的和,再乘上
4
a
i
4a_i
4ai即可。
对于
4
a
l
c
a
a
j
4a_{lca}a_j
4alcaaj,对每个
S
i
S_i
Si中的
j
j
j,都将
j
j
j到根节点的路径上的点加
a
j
a_j
aj,然后查询
i
i
i到根节点的路径上的对应值的和,再乘上
4
4
4即可。
对于
4
a
l
c
a
2
4a_{lca}^2
4alca2,对每个
S
i
S_i
Si中的
j
j
j,都将
j
j
j到根节点的路径上的点加
a
k
∗
2
−
1
a_k*2-1
ak∗2−1(
k
k
k表示当前节点),然后查询
i
i
i到根节点的路径上的对应值的和,因为
x
2
=
1
+
3
+
5
+
⋯
+
(
x
∗
2
−
1
)
x^2=1+3+5+\cdots+(x*2-1)
x2=1+3+5+⋯+(x∗2−1),所以这样求出的就是
a
l
c
a
2
a_{lca}^2
alca2,然后乘
4
4
4即可。
对于
j
j
j进入集合或离开集合,在
j
j
j到根节点的路径上进行区间修改即可。
为什么根节点的深度为1而不为0呢?因为只有根节点的深度为1,那么每个点到根节点的路径的长度才能等于这个点的深度,这样才能更好地实现。
时间复杂度为
O
(
n
log
2
n
)
O(n\log^2 n)
O(nlog2n)。
code
#include<bits/stdc++.h>
#define lc k<<1
#define rc k<<1|1
using namespace std;
int n,k,x,y,tot=0,d[500005],l[500005],r[500005],dep[500005],fa[500005],siz[500005],son[500005];
int tp[200005],s[200005],re[200005];
long long ans,hv1[800005],hv2[800005],mx1[800005],mx2[800005],mx3[800005],ly1[800005],ly2[800005],ly3[800005];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs1(int u,int f){
fa[u]=f;dep[u]=dep[f]+1;siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
dfs1(d[i],u);
siz[u]+=siz[d[i]];
if(siz[d[i]]>siz[son[u]]) son[u]=d[i];
}
}
void dfs2(int u,int f){
if(son[u]){
tp[son[u]]=tp[u];
s[son[u]]=++s[0];re[s[0]]=son[u];
dfs2(son[u],u);
}
for(int i=r[u];i;i=l[i]){
if(d[i]==f||d[i]==son[u]) continue;
tp[d[i]]=d[i];
s[d[i]]=++s[0];re[s[0]]=d[i];
dfs2(d[i],u);
}
}
void build(int k,int l,int r){
if(l==r){
hv1[k]=2ll*dep[re[l]]-1;
hv2[k]=1ll;
return;
}
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
hv1[k]=hv1[lc]+hv1[rc];
hv2[k]=hv2[lc]+hv2[rc];
}
void down(int k){
mx1[lc]+=ly1[k]*hv1[lc];
ly1[lc]+=ly1[k];
mx2[lc]+=ly2[k]*hv2[lc];
ly2[lc]+=ly2[k];
mx3[lc]+=ly3[k]*hv2[lc];
ly3[lc]+=ly3[k];
mx1[rc]+=ly1[k]*hv1[rc];
ly1[rc]+=ly1[k];
mx2[rc]+=ly2[k]*hv2[rc];
ly2[rc]+=ly2[k];
mx3[rc]+=ly3[k]*hv2[rc];
ly3[rc]+=ly3[k];
ly1[k]=ly2[k]=ly3[k]=0;
}
void ch(int k,int l,int r,int x,int y,long long t,int u){
if(l>=x&&r<=y){
mx1[k]+=t*hv1[k];
ly1[k]+=t;
mx2[k]+=t*hv2[k];
ly2[k]+=t;
mx3[k]+=t*dep[u]*hv2[k];
ly3[k]+=t*dep[u];
return;
}
if(l>y||r<x) return;
if(l==r) return;
if(ly1[k]||ly2[k]||ly3[k]) down(k);
int mid=l+r>>1;
if(x<=mid) ch(lc,l,mid,x,y,t,u);
if(y>mid) ch(rc,mid+1,r,x,y,t,u);
mx1[k]=mx1[lc]+mx1[rc];
mx2[k]=mx2[lc]+mx2[rc];
mx3[k]=mx3[lc]+mx3[rc];
}
void find(int k,int l,int r,int x,int y,int u){
if(l>=x&&r<=y){
ans+=4ll*(mx1[k]-mx2[k]*dep[u]-mx3[k]);
return;
}
if(l>y||r<x) return;
if(l==r) return;
if(ly1[k]||ly2[k]||ly3[k]) down(k);
int mid=l+r>>1;
if(x<=mid) find(lc,l,mid,x,y,u);
if(y>mid) find(rc,mid+1,r,x,y,u);
}
void ask(int i){
int t=i;
while(i>=1){
find(1,1,s[0],s[tp[i]],s[i],t);
i=fa[tp[i]];
}
}
void ins(int i){
int t=i;
while(i>=1){
ch(1,1,s[0],s[tp[i]],s[i],1,t);
i=fa[tp[i]];
}
}
void del(int i){
int t=i;
while(i>=1){
ch(1,1,s[0],s[tp[i]],s[i],-1,t);
i=fa[tp[i]];
}
}
int main()
{
scanf("%d%d",&n,&k);++n;
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);++x;++y;
add(x,y);add(y,x);
}
dfs1(1,0);
s[1]=++s[0];re[s[0]]=1;tp[1]=1;
dfs2(1,0);
build(1,1,s[0]);
long long sum1=0,sum2=0;
for(int i=2,vt,vk=2;i<=n;i++){
vt=max(2,i-k);
while(vk<vt){
sum1-=1ll*dep[vk]*dep[vk];
sum2-=1ll*dep[vk];
del(vk);++vk;
}
ans=1ll*(dep[i]-1)*(dep[i]-1)+1ll*(i-vt)*dep[i]*dep[i]+sum1+2ll*dep[i]*sum2;
ask(i);
printf("%lld\n",ans);
sum1+=1ll*dep[i]*dep[i];
sum2+=1ll*dep[i];
ins(i);
}
return 0;
}