EDPC T-Permutation Skew tableauxを利用した解法

atcoder.jp

問題概要:(1, 2, ..., N)を並び替えた順列(p _ 1, p _ 2, ..., p _ N)であって、次の条件を満たすものは何通りか?10^ 9+7で割った余りを求めよ。
・文字列Sが与えられ、Si文字目が\ltのときはp _ i \lt p _ i+1であり、Si文字目が\gtのときはp _ i \gt p _ i+1である。

この問題はskew tableauxの標準盤の個数の数え上げに言い換えることができます。 hotmanさんの記事を読んでください。

qiita.com

本記事では↑のアイデアのもと実装するパートのみ紹介します(人によってはやるだけ、かもしれません)。

与えられる文字列Sに対してskew tableauxが一意に定まります。 例えば文字列S\lt\gt\lt\gt\ltのとき、skew tableauxは以下のような図形になります。

f:id:unosss:20210404141402p:plain

求めたいのはこのskew tableauxへの数字の書き込み方の個数(標準盤の個数)です。以下のような手順で求めることができます。

skew tableauxをもとにyoung tableauxを構成します。

f:id:unosss:20210404141626p:plain

このyoung tableauxについて、各マスのフック長を書き込みます。

f:id:unosss:20210404143359p:plain

このyoung tableauxにおいて、最左下のマスから最右上のマスへの最短路を考えます。

f:id:unosss:20210404143834p:plain

例えば上図のような経路において、経路の長さをNとしたとき

val = \frac{N!}{(各マスのフック長hの積)}= \frac{6!}{1 \times 2 \times 3^ 2 \times 4 \times 5}

がこの経路が寄与する標準盤の個数です。

よって、young tableauxにおける最短路の集合をUとしたとき、求める標準盤の個数は

ans = \sum_{ u \in U} val(u)

となります。

最短路の個数が膨大になるため、最短路ごとにvalを求めていると間に合いません。

そこで、最左下のマスからスタートする前にnow = N!を持っておき、

マス(i, j)に移動したときにnow /= h(i, j)(マス(i, j)のフック長)と更新します。

 dp(i, j)= \frac{dp(i-1, j) + dp(i, j-1)}{h(i, j)}という要領です。)

こういう工夫の仕方は以前Atcoderで出題されていましたね。

atcoder.jp

最後に実装例を挙げておきます。

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

#define REP(i,n) for(ll i = 0;i < (n);i++)

const ll mod = 1e9+7;

struct mint {
  ll x; // typedef long long ll;
  mint(ll x=0):x((x%mod+mod)%mod){}
  mint operator-() const { return mint(-x);}
  mint& operator+=(const mint a) {
    if ((x += a.x) >= mod) x -= mod;
    return *this;
  }
  mint& operator-=(const mint a) {
    if ((x += mod-a.x) >= mod) x -= mod;
    return *this;
  }
  mint& operator*=(const mint a) { (x *= a.x) %= mod; return *this;}
  mint operator+(const mint a) const { return mint(*this) += a;}
  mint operator-(const mint a) const { return mint(*this) -= a;}
  mint operator*(const mint a) const { return mint(*this) *= a;}
  mint val(){
    return x;
  }
  mint pow(ll t) const {
    if (!t) return 1;
    mint a = pow(t>>1);
    a *= a;
    if (t&1) a *= *this;
    return a;
  }
 
  // for prime mod
  mint inv() const { return pow(mod-2);}
  mint& operator/=(const mint a) { return *this *= a.inv();}
  mint operator/(const mint a) const { return mint(*this) /= a;}
};
istream& operator>>(istream& is, mint& a) { return is >> a.x;}
ostream& operator<<(ostream& os, const mint& a) { return os << a.x;}
// combination mod prime
// https://www.youtube.com/watch?v=8uowVvQ_-Mo&feature=youtu.be&t=1619
struct combination {
  vector<mint> fact, ifact;
  combination(ll n):fact(n+1),ifact(n+1) {
    assert(n < mod);
    fact[0] = 1;
    for (ll i = 1; i <= n; ++i) fact[i] = fact[i-1]*i;
    ifact[n] = fact[n].inv();
    for (ll i = n; i >= 1; --i) ifact[i-1] = ifact[i]*i;
  }
  mint operator()(ll n, ll k) {
    if (k < 0 || k > n) return 0;
    return fact[n]*ifact[k]*ifact[n-k];
  }
  mint p(ll n, ll k) {
    return fact[n]*ifact[n-k];
  }
} c(3005);

void main_() {
    ll N;
    cin >> N;
    string s;
    cin >> s;
    ll a = 0,b = 0;
    REP(i,N-1){
        if(s[i] == '<')a++;
        else b++;
    }
    
    vector<vector<mint>> dp(b+1,vector<mint>(a+1));
    
    //sumには各マスの下にあるマスの数、lenには各行の列の個数、sumには各マスのフック長
    vector<vector<ll>> num(b+1,vector<ll>(a+1)),sum(b+1,vector<ll>(a+1));
    vector<ll> len(b+1);
    sum[0][0] = 1;
    ll h = 0,w = 1;
    
    REP(i,N-1){
        if(s[i] == '<'){
            sum[h][w] = 1;
            w++;
        }
        else{
            len[h] = w;
            h++;
        }
    }
    
    len[h] = w;
    REP(i,a+1){
        for(ll j = 1;j <= b;j++){
            if(sum[j-1][i]) sum[j][i] = sum[j-1][i] + 1;
        }
    }
    mint val = c.fact[N];
    REP(i,b+1){
        REP(j,a+1){
            num[i][j] = sum[i][j] + len[i] - j - 1;
        }
    }
    dp[0][0] = val;
    dp[0][0] /= num[0][0];
    
    REP(i,b+1){
        REP(j,len[i]){
            if(i == 0 && j == 0)continue;
            if(i){
                mint v = dp[i-1][j];
                v /= num[i][j];
                dp[i][j] += v;
            }
            if(j){
                mint v = dp[i][j-1];
                v /= num[i][j];
                dp[i][j] += v;
            }
        }
    }
    cout << dp[b][a] << endl;
}
 
int main() {
    int t = 1;
    //cin >> t;
    while(t--) main_();
    return 0;
}