题目
Imakf 送给了 Clamee 一个 n 行 n 列 (n<=1e6) 的阶梯状棋盘,
为了测试 Clamee 的智商,他问了 Clamee 一个问题:
在这个棋盘中填 k(1≤k≤n)个相同的棋子,
这 k 个棋子中没有两个棋子在同一行,也没有两个棋子在同一列的方案数是多少?
答案对998244353取模。
例如,n=3,m=2时,合法方案有7种,如下图。
思路来源
江老师
题解
dp[i][j]表示只考虑前i行还有j个空列的方案数,
有dp[i][j]=dp[i-1][j-1]+dp[i-1][j]*j
决策是枚举第i-1行的决策,
1. 第i-1行没取的话,第i行会新增一列,dp[i-1][j-1]转移到dp[i][j]
2. 第i-1行取了的话,从j列里挑一列取,j种方案,
取完之后,第i行还是会新增一列,dp[i-1][j]*j转移到dp[i][j]
等到第n行决策完取不取后,第n+1行都会新增一列,所以dp[n+1][n-k+1]即为所求
暴力dp是的,而注意到这个转移式和第二类斯特林数一模一样,
所以,求S(n,k)时,用O(k)的容斥版本即可
代码
// Problem: G - Many Good Tuple Problems
// Contest: AtCoder - HHKB Programming Contest 2023(AtCoder Beginner Contest 327)
// URL: https://atcoder.jp/contests/abc327/tasks/abc327_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<ll,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define scll(a) scanf("%lld",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=1e6+10,mod=998244353;
int n,k,ans,Finv[N],fac[N],inv[N];
void ADD(int &x,int y){x=(x+y)%mod;}
int modpow(int x,int n,int mod){
int res=1;
for(;n;x=1ll*x*x%mod,n>>=1)
if(n&1)res=1ll*res*x%mod;
return res;
}
void init(int n){ //n<N
inv[1]=1;
for(int i=2;i<=n;++i)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
fac[0]=Finv[0]=1;
for(int i=1;i<=n;++i)fac[i]=1ll*fac[i-1]*i%mod,Finv[i]=1ll*Finv[i-1]*inv[i]%mod;
//Finv[n]=modpow(fac[n],mod-2,mod);
//for(int i=n-1;i>=1;--i)Finv[i]=1ll*Finv[i+1]*(i+1)%mod;
}
int C(int n,int m){
if(m<0||m>n)return 0;
return 1ll*fac[n]*Finv[n-m]%mod*Finv[m]%mod;
}
int S(int n,int m){
int res=0;
rep(j,1,m){
int sg=((m-j)&1)?-1:1;
ADD(res,(1ll*sg*C(m,j)%mod*modpow(j,n,mod)%mod+mod)%mod);
}
res=1ll*res*Finv[m]%mod;
return res;
}
int main(){
init(N-5);
sci(n),sci(k);
pte(S(n+1,n-k+1));
return 0;
}