线段树典型题解-poj2777
线段树典型例题--poj2777
这是某些人线段树中询问的算法,仔细想想,这个算法是会退化到线性复杂度的。只要相邻的点颜色都不一样即可。
这到题我认为网上有些人的算法是不对的。
void solve(int l,int r,int root) //询问 { if(tree[root].col>=0) //如果父节点有单一的颜色,就直接更新,不需要找到子节点更新 { flag[tree[root].col]=1;//统计哪些颜色出现过 return; } if(tree[root].left==tree[root].right) return; int mid=(tree[root].left+tree[root].right)>>1; if(l>mid) solve(l,r,(root<<1)+1); else if(r<=mid) solve(l,r,root<<1); else { solve(l,mid,root<<1); solve(mid+1,r,(root<<1)+1); } }
这是某些人线段树中询问的算法,仔细想想,这个算法是会退化到线性复杂度的。只要相邻的点颜色都不一样即可。
我出了一组数据:
#include <cstdio> #include <iostream> using namespace std; int main() { freopen("in","w",stdout); int i; printf("%d %d %d\n",100000,30,100000); for (i=1;i<=50000;i++) printf("C %d %d %d\n",i,i,i%30+1); for (i=1;i<=50000;i++) printf("P %d %d\n",1,100000); }上述程序产生的数据会让错误程序超时(大约两分钟才能跑出结果)
那么真正的算法究竟如何??
应当使用延迟标记。同时可以使用二进制记录节点中的颜色情况。
具体维护,我想显而易见。
我说那么多的目的,希望大家在刷题的时候不要被poj弱数据所迷惑,要严格要求自己,才会进步。同时也提醒自己。
【代码】
#include <iostream> #include <cstring> #include <string> #include <cstdio> #include <algorithm> using namespace std; const int N=300000; int col[N]; int n,m,t,sum; bool pp[N],ans[33]; void down(int i) { if (!pp[i]) return; col[i*2]=col[i*2+1]=col[i]; pp[i*2]=pp[i*2+1]=true; pp[i]=false; } void update(int i) { col[i]=col[i*2]|col[i*2+1]; } void ins(int i,int l,int r,int x,int y,int k) { if (x<=l && y>=r) { col[i]=1<<(k-1); pp[i]=true; return; } down(i); int mid=(l+r)/2; if (x<=mid) ins(i*2,l,mid,x,y,k); if (y>mid) ins(i*2+1,mid+1,r,x,y,k); update(i); } void find(int i,int l,int r,int x,int y) { if (x<=l && y>=r) { int tmp=col[i]; for (int p=1;p<=t;p++) { ans[p]|=tmp&1; tmp>>=1; if (tmp==0) break; } return; } down(i); int mid=(l+r)/2; if (x<=mid) find(i*2,l,mid,x,y); if (y>mid) find(i*2+1,mid+1,r,x,y); update(i); } void build(int i,int l,int r) { col[i]=1; if (l==r) return; int mid=(l+r)/2; build(i*2,l,mid); build(i*2+1,mid+1,r); } int main() { int i,u,v,c; char ch; freopen("in","r",stdin); scanf("%d%d%d\n",&n,&t,&m); build(1,1,n); while (m--) { scanf("%c",&ch); if (ch=='P') { scanf("%d%d\n",&u,&v); if (u>v) swap(u,v); memset(ans,0,sizeof(ans)); find(1,1,n,u,v); sum=0; for (i=1;i<=t;i++) sum+=ans[i]; printf("%d\n",sum); } else { scanf("%d%d%d\n",&u,&v,&c); if (u>v) swap(u,v); ins(1,1,n,u,v,c); } } }