摘要:卷积计算匹配

luogu3763[TJOI2017]DNA

题面

题目分析

要计算字符串带通配符在每个位置上的匹配,一般还是想到用多项式卷积.

视两个字符串为多项式$f$,$g$,匹配相当于是:

其中一个倒序卷一下就可以了,不卷不是中国人.

这道题相当于是带通配符,但是不能用超过三个,转换下思路,枚举字符集作为当前的通配符,把第一个串里不等于通配符的赋为$1$,表示这个位置匹配需要一个通配符,另一个串等于通配符的赋为$1$,表示第一个串的这个地方需要被换成通配符,然后进行卷积,算出每个位置匹配上要多少个通配符,最后全部加起来即可.

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include "iostream"
#include "cstdlib"
#include "cstdio"
#include "cstring"
#include "cmath"
#include "cctype"
#include "ctime"
#include "iomanip"
#include "algorithm"
#include "set"
#include "queue"
#include "map"
#include "stack"
#include "deque"
#include "vector"
#define R register
#define INF 0x3f3f3f3f
#define debug(x) printf("debug:%lld\n",x)
#define debugi(x) printf("debug:%d\n",x)
#define debugf(x) printf("debug:%llf\n",x)
#define endl putchar('\n')
typedef long long lxl;
const lxl big=100010,mod=998244353,g=3;
const double pi=acos(-1);
char ch[4]={'A','G','C','T'};
char s1[big],s2[big];
lxl T,n,m,len,out,InvG,InvN;
lxl tr[big<<2],ans[big];
lxl a[big<<2],b[big<<2];
inline lxl read()
{
char c(getchar());
lxl f(1),x(0);
for(;!isdigit(c);(c=='-')&&(f=-1),c=getchar());
for(;isdigit(c);x=(x<<1)+(x<<3)+(c^48),c=getchar());
return f*x;
}
inline lxl FastPow(lxl a,lxl b)
{
lxl sum(1);
for(;b;b>>=1,a=a*a%mod)(b&1)&&(sum=sum*a%mod);
return sum;
}
void NTT(lxl *f,lxl IDFT)
{
for(R int i=0;i<len;++i)if(i<tr[i])std::swap(f[i],f[tr[i]]);
for(R int p=2,leng;p<=len;p<<=1)
{
leng=p>>1;
lxl tG(FastPow((~IDFT)?g:InvG,(mod-1)/p));
for(R int k=0;k<len;k+=p)
{
lxl buf(1);
for(R int l=k;l<k+leng;++l)
{
lxl tt=(f[l+leng]*buf)%mod;
f[leng+l]=f[l]-tt;
if(f[leng+l]<0)f[leng+l]+=mod;
f[l]=f[l]+tt;
if(f[l]>mod)f[l]-=mod;
buf=(buf*tG)%mod;
}
}
}
}
int main(void)
{
T=read();
InvG=FastPow(g,mod-2);
while(T--)
{
out=0,memset(tr,0,sizeof tr);
memset(ans,0,sizeof ans);
scanf("%s",s1),scanf("%s",s2);
n=strlen(s1),m=strlen(s2);
for(R int i(0);i<m/2;++i)std::swap(s2[i],s2[m-i-1]);
// for(R int i(0);i<m;++i)printf("%c",s2[i]);endl;
for(len=1;len<n+m;len<<=1);
InvN=FastPow(len,mod-2);
for(R int i=0;i<len;++i)tr[i]=(tr[i>>1]>>1)|((i&1)?(len>>1):0);
for(R int i(0);i<4;++i)
{
memset(a,0,sizeof a),memset(b,0,sizeof b);
for(R int j(0);j<n;++j)a[j]=(s1[j]!=ch[i]);
for(R int j(0);j<m;++j)b[j]=(s2[j]==ch[i]);
NTT(a,1),NTT(b,1);
for(R int j(0);j<len;++j)a[j]=a[j]*b[j]%mod;
NTT(a,-1);
for(R int j(0);j<len;++j)a[j]=a[j]*InvN%mod;
for(R int j(m-1);j<n;++j)ans[j]+=a[j];
}
for(R int i(m-1);i<n;++i)out+=(ans[i]<=3);
printf("%lld\n",out);
}
return 0;
}