BZOJ1468 【树的分治】

Tree

[Description]

统计带权树上有多少点对的距离小于等于k.

[Hint]

$n≤10^5,k≤10^9$

[Solution]

将树转化为有根树,满足题意的任意两点间的路径要么经过根结点,要么经过根的子结点,这引导我们使用树的分治,对于每个结点,我们只考虑经过路径根结点的点对,然后递归处理。
设dis[i]为结点i到根结点的路径长度,belong[i]为i位于根结点的第几个子结点的子树内,我需要统计的就是
$$\sum_{i,j}^{belong[i]≠belong[j]}(dis[i]+dis[j]≤k)$$
等价于所有满足dis[i]+dis[j]≤k的i,j数量-所有满足dis[i]+dis[j]≤k且being[i]=being[j]的i,j数量。
对于计算前一部分“所有满足dis[i]+dis[j]≤k的i,j数量”,当我们通过DFS得到dis数组后,有一个经典的排序算法可以$O(n)$求解(详情见代码),后面一部分可以通过同样的方法分别计算每个子节点的贡献量。
这样,递归分治下去,$O(n×log^2n)$的复杂度(算上了排序的时间)是完全可以接受的。
然后就可以愉快地AC辣~

[Code]

嘴上说AC挺容易。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
using namespace std;

inline int readint(){
int x=0; bool minu=false; char c=getchar();
while(!isdigit(c) && c-'-') c=getchar();
if(c=='-') minu=true; else x=c-'0',c=getchar();
while(isdigit(c)) x=x*10+c-'0',c=getchar();
return minu?-x:x;
}

const int inf=~0u>>1;
const int maxn=20000;
int n,k,ans;
int to[maxn],w[maxn],nxt[maxn],cur[maxn],cnt;
void insert(int x,int y,int z){
to[cnt]=y,w[cnt]=z,nxt[cnt]=cur[x],cur[x]=cnt++;
to[cnt]=x,w[cnt]=z,nxt[cnt]=cur[y],cur[y]=cnt++;
}

bool visited[maxn];
bool being[maxn];

int now,bary;
int num[maxn];
void find(int u,int sum){ //找树的重心
being[u]=true;
int mx=0; num[u]=1;
for(int i=cur[u];i>=0;i=nxt[i]){
int v=to[i];
if(!being[v] && !visited[v]){
find(v,sum);
num[u]+=num[v]; mx=max(mx,num[v]);
}
}
mx=max(mx,sum-num[u]);
if(mx<now) now=mx,bary=u;
}

int dis[maxn],cnt2;
void dfs(int u,int d){ //计算离重心的距离
being[u]=true; dis[cnt2++]=d;
for(int i=cur[u];i>=0;i=nxt[i]){
int v=to[i];
if(!being[v] && !visited[v]) dfs(v,d+w[i]);
}
}
int calc(int u,int x){
cnt2=0;
dfs(u,x); memset(being,0,sizeof(being));
sort(dis,dis+cnt2);
int l=0,r=cnt2-1,res=0;
while(l<r){ //经典算法
while(dis[l]+dis[r]>k && l<r) r--;
res+=r-l++;
}
return res;
}

void work(int u,int sum){ //注意sum的用于正确找到重心,被这里坑了好久
now=inf; bary=u;
find(u,sum); memset(being,0,sizeof(being));
visited[bary]=true;
ans+=calc(bary,0);
for(int i=cur[bary];i>=0;i=nxt[i]){
int v=to[i];
if(!visited[v]){
ans-=calc(v,w[i]);
work(v,num[v]);
}
}
}

void init(){
ans=cnt=0;
memset(cur,-1,sizeof(cur));
memset(visited,0,sizeof(visited));
}
int main(){
while(~scanf("%d%d",&n,&k) && n){
init();
for(int i=1,x,y,z;i<n;i++){
x=readint()-1,y=readint()-1,z=readint();
insert(x,y,z);
}
work(0,n);
printf("%d\n",ans);
}
return 0;
}