摘要:矩阵表示状压dp的状态转移

luogu1357花园

题面

题目分析

其实有个很显然的思路,设$f_{i,s}$为考虑前$i$位,前$m$位放置的状态为$s$的方案数.

然后转移的时候,可以发现$s$的第一位是没用的,因为已经超出了$m$位的范围,也就是:

于是可以发现,对于一个确定的状态$s$,有且最多仅有$2$个状态能转移到它,那么可以将每一个状态的转移关系处理好放在矩阵中,类似于图论中邻接矩阵的用法,我们可以通过矩阵乘法来得到每一个状态能够转移到的状态.

考虑上环,通过观察题解不难发现,对于一个确定的状态$s$,将dp的初始状态设成$f_{0,s}=1$,终止状态取为$f_{n,s}$,因为$0$号位的答案与$n$号位的状态相同所以是一个合法的环.

把这个做法移到矩阵上,只须要把初始的转移矩阵的主对角线视为$f_{0,s}=1$,在之后的乘法中自然会发生转移,最后统计主对角线上的值即可.

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include "iostream"
#include "cstdlib"
#include "cstdio"
#include "cstring"
#include "cmath"
#include "cctype"
#include "ctime"
#include "iomanip"
#include "algorithm"
#include "set"
#include "queue"
#include "map"
#include "stack"
#include "deque"
#include "vector"
#define R register
#define INF 0x3f3f3f3f
#define debug(x) printf("debug:%lld\n",x)
#define debugi(x) printf("debug:%d\n",x)
#define debugf(x) printf("debug:%llf\n",x)
#define endl putchar('\n')
typedef long long lxl;
const lxl big=32,p=1000000007;
lxl n,m,k,all,ans;
struct _Matrix
{
lxl a[big][big];
_Matrix(){memset(a,0,sizeof a);}
inline lxl* operator [](const lxl i){return a[i];}
inline _Matrix operator *(const _Matrix &another)const
{
_Matrix b;
R int i,j,k;
for(i=0;i<big;++i)
for(k=0;k<big;++k)
for(j=0;j<big;++j)
b.a[i][j]=(b.a[i][j]+a[i][k]*another.a[k][j]%p)%p;
return b;
}
}I,A,T;
inline lxl BitCnt(lxl s)
{
#define lowbit(x) (x&-x)
lxl cnt(0);
while(s)++cnt,s-=lowbit(s);
return cnt;
}
inline _Matrix FastPow(_Matrix A,lxl b)
{
_Matrix C=I;
for(;b;b>>=1,A=A*A)if(b&1)C=C*A;
return C;
}
inline lxl read()
{
char c(getchar());
lxl f(1),x(0);
for(;!isdigit(c);(c=='-')&&(f=-1),c=getchar());
for(;isdigit(c);x=(x<<1)+(x<<3)+(c^48),c=getchar());
return f*x;
}
inline void prework()
{
for(R int i(0);i<all;++i)
{
if(BitCnt(i)>k)continue;
lxl j=i>>1;
A[j][i]=1;
j|=1<<(m-1);
if(BitCnt(j)<=k)A[j][i]=1;
}
for(R int i(0);i<all;++i)I[i][i]=1;
}
int main(void)
{
n=read(),m=read(),k=read();
all=1<<m;
prework();
T=FastPow(A,n);
for(R int i(0);i<all;++i)ans=(ans+T[i][i])%p;
printf("%lld\n",ans);
return 0;
}