关于BM算法求解线性递推的细节

关于BM算法求解线性递推的细节

网上已有相当多的讲解BM算法的文章,这里只讨论一些细节部分。

1.第一次遇到非零元素:那么下一个转移数组一定是由i个0组成的,其中i是已经访问过零元素个数加1。应为我们**当前**需要保存的只是对下一位可能正确并且前面必然正确的转移数组,若只有i-1个零,这一位上显然是不对的。

2.新添多少个零元素:在以后遇到delta不为0的位置时(比如说是i),我们在前面找到一个能够让这一个位置上能够减少delta的位置j,那么就要空出i-j-1个位置放0(对于位置j处的转移数组而言)。

3.乘正的还是负的(视你的实现而言):观察计算delta的式子,自然而然就能知道是取正的。

4.小心resize!!!

 1 #include<bits/stdc++.h>
 2 #define mod 998244353
 3 using namespace std;
 4 typedef long long int ll;
 5 inline ll qpow(ll x,ll y)
 6 {
 7     ll ans=1,base=x;
 8     while(y)
 9     {
10         if(y&1)
11             ans=ans*base%mod;
12         base=base*base%mod;
13         y>>=1;
14     }
15     return ans;
16 }
17 inline void add(ll&x,ll y)
18 {
19     x=(x+y)%mod;
20 }
21 namespace BM
22 {
23     ll delta[100005];
24     int fail[100005];
25     vector<ll>r[100005];
26     inline vector<ll>BM(vector<ll>a)
27     {
28         memset(delta,0,sizeof(delta));
29         memset(fail,0,sizeof(fail));
30         for(int i=0;i<100000;++i)
31             r[i].clear();
32         int cnt=0;
33         for(int i=0;i<a.size();++i)
34         {
35             delta[i]=a[i];
36             for(int j=0;j<r[cnt].size();++j)
37                 delta[i]=(delta[i]-a[i-j-1]*r[cnt][j])%mod;
38             delta[i]=(delta[i]%mod+mod)%mod;
39             if(delta[i]==0)
40                 continue;
41             fail[cnt]=i;
42             if(cnt==0)
43             {
44                 r[++cnt].resize(i+1);// !!!!!!!
45                 continue;
46             }
47             int p=0;
48             for(int j=1;j<cnt;++j)// !!! cant be self !!!
49                 if(i-fail[j]+1+r[j].size()<i-fail[p]+1+r[p].size())
50                     p=j;
51             ++cnt;
52             r[cnt]=r[cnt-1];
53             r[cnt].resize(max(i-fail[p]+r[p].size(),r[cnt].size()));
54             ll d=delta[i]*qpow(delta[fail[p]],mod-2)%mod;// !!! fail[p] !!!
55             add(r[cnt][i-fail[p]-1],d);
56             for(int j=0;j<r[p].size();++j)
57                 add(r[cnt][i-fail[p]+j],-r[p][j]*d);
58         }
59         for(int i=0;i<r[cnt].size();++i)
60             r[cnt][i]=(r[cnt][i]%mod+mod)%mod; 
61         return r[cnt];
62     }
63 }
64 int main()
65 {
66     ios::sync_with_stdio(false);
67     vector<ll>f;
68     vector<ll>r=BM::BM(f);
69     for(auto p:r)
70         cout<<p<<" ";cout<<endl; 
71     return 0;
72 }
View Code

5.优化:只要记录-fail[p]+r[p].size()最小的就行了。不过要当心,r[p].size()是一个unsigned int,小心溢出!