[POJ1741]树上的点对 树分治

[POJ1741]树上的点对 树分治

Description

给一棵有n个节点的树,每条边都有一个长度(小于1001的正整数)。 
定义dist(u,v)=节点u到节点v的最短路距离。 
给出一个整数k,我们称顶点对(u,v)是合法的当且仅当dist(u,v)不大于k。 
写一个程序,对于给定的树,计算有多少对顶点对是合法的。

Input

输入包含多组数据。 
每组数据的第一行有两个整数N,K(N<=10000)。接下来N-1行每行有三个整数u,v,l,代表节点u和v之间有一条长度l的无向边。 
输入结束标志为N=K=0.

Output

对每组数据输出一行一个正整数,即合法顶点对的数量。

Sample Input

5 4 
1 2 3 
1 3 1 
1 4 2 
3 5 1 
0 0

Sample Output

8

树分治裸题:

每一次找到树的重心,使得二分复杂度降为logn,然后每一个小区间做相同处理:

设 dis[i]表示到当前根节点X的距离 找到以当前节点X为根的满足dis[i]+dis[j]<=k的所有方案。

在统计的时候可能存在把同一子树满足dis[i]+dis[j]<=k的也统计进答案中,产生重复方案。

于是我们要减去每一棵子树中满足dis[i]+dis[j]<=k的方案。

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<algorithm>
  5 #include<cmath>
  6 using namespace std;
  7  
  8 int n,k;
  9  
 10 int gi()
 11 {
 12     int str=0;char ch=getchar();
 13     while(ch>'9' || ch<'0')ch=getchar();
 14     while(ch>='0' && ch<='9')str=str*10+ch-'0',ch=getchar();
 15     return str;
 16 }
 17  
 18 const int N=10005;int sum=0;int ans=0;
 19 int num=0,head[N];int root=0;int son[N];int f[N]={9999999};bool vis[N];int dis[N];
 20 int b[N];
 21 struct Lin
 22 {
 23     int next,to,dis;
 24 }a[N*2];
 25  
 26 void init(int x,int y,int z)
 27 {
 28     a[++num].next=head[x];
 29     a[num].to=y;
 30     a[num].dis=z;
 31     head[x]=num;
 32 }
 33  
 34 void getroot(int x,int last)
 35 {
 36     son[x]=1;f[x]=0;
 37     int u;
 38     for(int i=head[x]; i ;i=a[i].next)
 39     {
 40         u=a[i].to;
 41         if(u==last || vis[u])continue;
 42         getroot(u,x);
 43         son[x]+=son[u];
 44         f[x]=max(f[x],son[u]);
 45     }
 46     f[x]=max(f[x],sum-son[x]);
 47     if(f[x]<f[root])root=x;
 48     return ;
 49 }
 50  
 51 void getdis(int x,int last)
 52 {
 53     int u;
 54     b[++b[0]]=dis[x];
 55     for(int i=head[x];i;i=a[i].next)
 56     {
 57         u=a[i].to;
 58         if(u==last || vis[u])continue;
 59         dis[u]=dis[x]+a[i].dis;
 60         getdis(u,x);
 61     }
 62     return ;
 63 }
 64  
 65 int cal(int x,int dd)
 66 {
 67     int tot=0,u;dis[x]=dd;
 68     b[0]=0;
 69     getdis(x,0);
 70     sort(b+1,b+b[0]+1);
 71     int l=1,r=b[0];
 72     while(l<r)
 73     {
 74         if(b[l]+b[r]<=k)tot+=r-l,l++;
 75         else r--;
 76     }
 77     return tot;
 78 }
 79  
 80 void work(int x)
 81 {
 82     ans+=cal(x,0);
 83     vis[x]=1;
 84     int u;
 85     for(int i=head[x];i;i=a[i].next)
 86     {
 87         u=a[i].to;
 88         if(vis[u])continue;
 89         ans-=cal(u,a[i].dis);
 90         root=0;sum=son[u];
 91         getroot(u,x);
 92         work(root);
 93     }
 94     return ;
 95 }
 96  
 97 void Clear()
 98 {
 99     memset(vis,0,sizeof(vis));
100     f[0]=99999999;
101     memset(a,0,sizeof(a));
102     memset(head,0,sizeof(head));
103     memset(dis,0,sizeof(dis));
104     num=0;
105 }
106 int main()
107 {
108     while(scanf("%d%d",&n,&k))
109     {
110         if(!n && !k)return 0;
111         int x,y,z;
112         Clear();
113         for(int i=1;i<=n-1;i++)
114         {
115             x=gi();y=gi();z=gi();
116             init(x,y,z);init(y,x,z);
117         }
118         sum=n;
119         ans=0;
120         root=0;
121         getroot(1,0);
122         work(root);
123         printf("%d
",ans);
124     }
125     return 0;
126 }