线段树典型题解-poj2777

线段树典型例题--poj2777

这到题我认为网上有些人的算法是不对的。

void solve(int l,int r,int root) //询问
{

    if(tree[root].col>=0) //如果父节点有单一的颜色,就直接更新,不需要找到子节点更新
    {
        flag[tree[root].col]=1;//统计哪些颜色出现过
        return;
    }
    if(tree[root].left==tree[root].right) return;
    int mid=(tree[root].left+tree[root].right)>>1;
    if(l>mid) solve(l,r,(root<<1)+1);
    else if(r<=mid) solve(l,r,root<<1);
    else
    {
        solve(l,mid,root<<1);
        solve(mid+1,r,(root<<1)+1);
    }
}

这是某些人线段树中询问的算法,仔细想想,这个算法是会退化到线性复杂度的。只要相邻的点颜色都不一样即可。

我出了一组数据:

#include <cstdio>
#include <iostream>
using namespace std;

int main()
{
    freopen("in","w",stdout);
    int i;
    printf("%d %d %d\n",100000,30,100000);
    for (i=1;i<=50000;i++)
        printf("C %d %d %d\n",i,i,i%30+1);
    for (i=1;i<=50000;i++)
        printf("P %d %d\n",1,100000);
}
上述程序产生的数据会让错误程序超时(大约两分钟才能跑出结果)

那么真正的算法究竟如何??

应当使用延迟标记。同时可以使用二进制记录节点中的颜色情况。

具体维护,我想显而易见。

我说那么多的目的,希望大家在刷题的时候不要被poj弱数据所迷惑,要严格要求自己,才会进步。同时也提醒自己。

【代码】

#include <iostream>
#include <cstring>
#include <string>
#include <cstdio>
#include <algorithm>
using namespace std;

const int N=300000;

int col[N];
int n,m,t,sum;
bool pp[N],ans[33];

void down(int i)
{
    if (!pp[i]) return;
    col[i*2]=col[i*2+1]=col[i];
    pp[i*2]=pp[i*2+1]=true;
    pp[i]=false;
}

void update(int i)
{
    col[i]=col[i*2]|col[i*2+1];
}

void ins(int i,int l,int r,int x,int y,int k)
{
    if (x<=l && y>=r)
    {
        col[i]=1<<(k-1);
        pp[i]=true;
        return;
    }
    down(i);
    int mid=(l+r)/2;
    if (x<=mid) ins(i*2,l,mid,x,y,k);
    if (y>mid) ins(i*2+1,mid+1,r,x,y,k);
    update(i);
}

void find(int i,int l,int r,int x,int y)
{
    if (x<=l && y>=r)
    {
        int tmp=col[i];
        for (int p=1;p<=t;p++)
        {
            ans[p]|=tmp&1;
            tmp>>=1;
            if (tmp==0) break;
        }
        return;
    }
    down(i);
    int mid=(l+r)/2;
    if (x<=mid) find(i*2,l,mid,x,y);
    if (y>mid) find(i*2+1,mid+1,r,x,y);
    update(i);
}

void build(int i,int l,int r)
{
    col[i]=1;
    if (l==r) return;
    int mid=(l+r)/2;
    build(i*2,l,mid);
    build(i*2+1,mid+1,r);
}

int main()
{
    int i,u,v,c;
    char ch;

    freopen("in","r",stdin);
    scanf("%d%d%d\n",&n,&t,&m);
    build(1,1,n);
    while (m--)
    {
        scanf("%c",&ch);
        if (ch=='P')
        {
            scanf("%d%d\n",&u,&v);
            if (u>v) swap(u,v);
            memset(ans,0,sizeof(ans));
            find(1,1,n,u,v);
            sum=0;
            for (i=1;i<=t;i++)
                sum+=ans[i];
            printf("%d\n",sum);
        }
        else
        {
            scanf("%d%d%d\n",&u,&v,&c);
            if (u>v) swap(u,v);
            ins(1,1,n,u,v,c);
        }
    }
}