【POJ1741】树中点对统计 点分治

题目描述

  给定一棵N(1<=N<=100000)个结点的带权树,每条边都有一个权值(为正整数,小于等于1001)。定义dis(u,v)为u,v两点间的最短路径长度,路径的长度定义为路径上所有边的权和。再给定一个K(1<=K<=10^9),如果对于不同的两个结点u,v,如果满足dist(u,v)<=K,则称(u,v)为合法点对。求合法点对个数。

题目大意

  求树中距离小于k的点对个数

数据范围

对于50%的数据,n<=1000,k<=1000; 对于100%的数据,n<=100000,k<=10^9;

样例输入

5 4 1 2 3 1 3 1 1 4 2 3 5 1

样例输出

8

解题思路

不写了233

代码

#include <algorithm> #include <iostream> #include <cstring> #include <cstdlib> #include <cstdio> #include <cmath> #include <ctime> #define Maxn 100005 using namespace std; inline int Getint(){int x=0,f=1;char ch=getchar();while('0'>ch||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while('0'<=ch&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} int fa[Maxn],h[Maxn],size[Maxn],dep[Maxn]; bool vis[Maxn]; int n,k,cnt=0,L,r,Min; struct node{int to,next,v;}e[Maxn*2]; void AddEdge(int x,int y,int v){e[++cnt]=(node){y,h[x],v};h[x]=cnt;} void Init(){ n=Getint(),k=Getint(); for(int i=1;i<n;i++){ int x,y,v; x=Getint(),y=Getint(),v=Getint(); fa[y]=x; AddEdge(x,y,v); AddEdge(y,x,v); } memset(vis,0,sizeof(vis)); } int dfssize(int u,int PRe){ size[u]=1; for(int p=h[u];p;p=e[p].next){ int y=e[p].to; if(vis[y]||y==pre)continue; size[u]+=dfssize(y,u); } return size[u]; } void Getroot(int u,int pre,int tot,int &root){ int Max=tot-size[u]; for(int p=h[u];p;p=e[p].next){ int y=e[p].to; if(vis[y]||y==pre)continue; Getroot(y,u,tot,root); Max=max(Max,size[y]); } if(Max<Min){ Min=Max; root=u; } } void Getlen(int u,int pre,int d){ dep[r++]=d; for(int p=h[u];p;p=e[p].next){ int y=e[p].to; if(vis[y]||y==pre)continue; Getlen(y,u,d+e[p].v); } } int Calc(int L,int r){ sort(dep+L,dep+r); int ret=0,Pos=r-1; for(int i=L;i<r;i++){ if(dep[i]>k)break; while(Pos>=L&&dep[i]+dep[Pos]>k)Pos--; ret+=Pos-L+1; if(Pos>i)ret--; } return ret/2; } int Solve(int u){ int tot=dfssize(u,0),ret=0,root; Min=0x7fffffff; Getroot(u,0,tot,root); vis[root]=true; for(int p=h[root];p;p=e[p].next){ int y=e[p].to; if(vis[y])continue; ret+=Solve(y); } L=r=0; for(int p=h[root];p;p=e[p].next){ int y=e[p].to; if(vis[y])continue; Getlen(y,root,e[p].v); ret-=Calc(L,r); L=r; } ret+=Calc(0,r); for(int i=0;i<r;i++) if(dep[i]<=k)ret++; else break; vis[root]=false; return ret; } int main(){ Init(); cout<<Solve(1)<<"\n"; }