[BJOI2019]奥术神杖(AC自动机,DP,分数规划)

题目大意:

给出一个长度 $n$ 的字符串 $T$,只由数字和点组成。你可以把每个点替换成一个任意的数字。再给出 $m$ 个数字串 $S_i$,第 $i$ 个权值为 $t_i$。

对于一个替换方案,这样定义它的价值:

如果数字串 $S_i$ 在 $T$ 中出现了,那么将 $t_i$ 加入多重集。如果出现多次也要加多次。

它的价值就是这个多重集元素的几何平均数(所有 $c$ 个数的乘积开 $c$ 次方根)。

请构造出一个替换方案,使得这个值最大。不用输出这个价值。

$1le nle 1500,1le sum|S_i|le 1500,1le t_ile 10^9$。


首先这个奇怪的式子是乘积的形式。如果将它取对数:(其实不一定要用 $ln$,任意底数都可以)

$$dfrac{1}{c}sumlimits_{i=1}^cln x_i$$

$$dfrac{sumlimits_{i=1}^cln x_i}{sumlimits_{i=1}^c 1}$$

这就是分数规划经典模式了。

二分 $x$,看看这个式子的值能不能 $>v$。可以就缩小左边界,否则缩小右边界。

$$dfrac{1}{c}sumlimits_{i=1}^cln x_i>v$$

$$sumlimits_{i=1}^cln x_i>vc$$

$$sumlimits_{i=1}^c(ln x_i-v)>0$$

设第 $i$ 个串的新权值 $w_i=ln x_i-v$,那么就是要找出一个方案使得 $w_i$ 之和最大,看看是否 $>0$ 即可。

如何求最大值?实际上是个套路 DP 了。

先建出所有 $S_i$ 串的 AC 自动机。然后设 $wsum_u$ 为从 $u$ 开始跳 fail 指针,跳过的所有点的 $w_i$ 之和。要记录这个是因为在 AC 自动机上走到点 $u$ 时,沿着 $u$ 跳 fail 能跳到的所有点都是可以匹配的。

令 $f[i][j]$ 表示按照 $T$ 串走了 $i$ 步,走到点 $j$ 的最大值。

可以从 $f[i][j]+wsum_c$ 转移到 $f[i+1][c]$。其中 $T_{i+1}$ 已经确定时 $c$ 就是对应的儿子,否则就要枚举替换成什么字符,再转移到相应的儿子。

由于要输出方案,要记录从哪个状态转移过来。

时间复杂度 $O(nsum|S_i|log)$。

#include<bits/stdc++.h>
using namespace std;
const int maxn=1555;
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
#define ROF(i,a,b) for(int i=(a);i>=(b);i--)
#define MEM(x,v) memset(x,v,sizeof(x))
inline int read(){
    char ch=getchar();int x=0,f=0;
    while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
    return f?-x:x;
}
int n,m,nd,ch[maxn][10],fail[maxn],cnt[maxn],q[maxn],h,r,fr[maxn][maxn],id[maxn][maxn];
double val[maxn],hhh[maxn],dp[maxn][maxn];
char s[maxn],t[maxn],tmp[maxn];
void insert(const char *s,int x){
    int now=0,l=strlen(s+1);
    FOR(i,1,l){
        int p=s[i]-'0';
        if(!ch[now][p]) ch[now][p]=++nd;
        now=ch[now][p];
    }
    val[now]+=log(x);
    cnt[now]++;
}
void build(){
    h=1;r=0;
    FOR(i,0,9) if(ch[0][i]) q[++r]=ch[0][i];
    while(h<=r){
        int u=q[h++];
        val[u]+=val[fail[u]];
        cnt[u]+=cnt[fail[u]];
        FOR(i,0,9) if(ch[u][i]) fail[q[++r]=ch[u][i]]=ch[fail[u]][i];
        else ch[u][i]=ch[fail[u]][i];
    }
}
void trans(int i,int j,int k){
    int c=ch[j][k];
    if(dp[i][j]+hhh[c]>dp[i+1][c]){
        dp[i+1][c]=dp[i][j]+hhh[c];
        fr[i+1][c]=j;
        id[i+1][c]=k;
    }
}
bool check(double x){
    FOR(i,0,nd) hhh[i]=val[i]-cnt[i]*x;
    FOR(i,0,n) FOR(j,0,nd) dp[i][j]=-1e9;
    dp[0][0]=0;
    FOR(i,0,n-1) FOR(j,0,nd){
        if(s[i+1]=='.') FOR(k,0,9) trans(i,j,k);
        else trans(i,j,s[i+1]-'0');
    }
    int mxid=0;
    FOR(i,1,nd) if(dp[n][i]>dp[n][mxid]) mxid=i;
    if(dp[n][mxid]<=0) return false;
    int at=mxid;
    ROF(i,n,1){
        t[i]=id[i][at]+'0';
        at=fr[i][at];
    }
    return true;
}
int main(){
    n=read();m=read();
    scanf("%s",s+1);
    FOR(i,1,m){
        scanf("%s",tmp+1);
        insert(tmp,read());
    }
    build();
    double l=0,r=log(1e9);
    while(r-l>1e-8){
        double mid=(l+r)/2;
        if(check(mid)) l=mid;
        else r=mid;
    }
    check(l);
    printf("%s
",t+1);
}
View Code