Computer Science/Problem Solving

[BOJ 31419] 배열 제작의 달인, 생성 함수와 FFT

리유나 2024. 10. 10. 00:41

안녕하세요. 리유나입니다.

 

지난 UCPC 포스팅 이후로 생성함수 문제를 여럿 풀어보면서 그에 관련된 공부를 어느 정도 했는데, 마침 포스팅하기 좋은 문제가 있어 보여서 간략히 포스팅합니다. 먼저 문제는 다음과 같습니다.

https://www.acmicpc.net/problem/31419

 

1. 접근

간략히 생각해 보면, 1부터 n까지의 숫자가 특정 개수 주어져 있고, 그들 중 0의 개수만큼을 뽑아서 배열하는 방법의 경우의 수입니다. 이런 류의 문제에는 여러 가지 풀이가 있고 DP를 이용하는 풀이 또한 제법 알려져 있지만, 뽑는 종류가 워낙 여러가지고 상황에 따라 바뀔 수 있어서 당장 예쁜 일반항으로 나오기는 조금 어렵습니다.

 

이럴 때 생성 함수를 이용하면 제법 깔끔한 풀이를 사용할 수 있습니다. 우선 조금 더 간단한 예시를 하나 보도록 합시다.

 

https://www.acmicpc.net/problem/13542

1원짜리 우표가 n개, 2원짜리 우표가 m개 있을 때 k원어치 사는 방법을 구하면 됩니다...마는, 정말 큰 문제가 있는데, 범위가 10^12까지입니다. 그래서 일반적인 DP로는 잘 되지 않습니다. 나머지를 구하는 수 p가 소수임이 보장되어 있으므로 일반항을 구하고 뤼카 정리 등을 끼얹어서 나머지를 잘 구해서 하는 방법도 있지만, 생성함수를 사용해 풀이하자면 이 때 1원 우표를 사는 개수를 나타내는 항을 (1+x+...+x^n)으로 나타내고, 2원 우표를 사는 개수는 (1+x^2+...+x^2m)으로 나타내고, 두 식을 곱해서 나오는 x^k 계수가 곧 우표를 구매하는 방법의 가짓수가 됩니다. 이 접근법을 약간 확장시켜보면 이번 문제에도 좋은 풀이가 나올 수 있습니다.

 

2. 숫자를 고르는 방법의 경우의 수

우선 배열시키는 방법은 나중에 생각하고, 숫자를 고르는 방법부터 생각해봅시다. 1부터 n까지의 숫자들이 등장한 횟수를 세어주면, 반대로 1부터 n까지 숫자가 앞으로 몇개 더 나올 수 있는지를 구할 수 있습니다. 그 값을 ct1, ct2, ... ctn이라고 해봅시다.

 

그렇다면 어떤 자연수 i에 대해서, i를 고르는 방법에 대한 항은 (1+x+x^2+...+x^cti)의 꼴로 나타낼 수 있습니다. 이 항이 많아야 3천개가 나오고, 3천개의 항을 곱해주고 x^k의 계수를 구해주면 될...것처럼 보이지만, 약간의 문제가 있습니다. 바로 '배열을 시키는 방법도 생각해야 한다'는 것입니다.

 

3. 배열시키는 방법을 고려해보자.

사실 배열시키는 방법 또한 그렇게까지 어렵지는 않습니다. 같은 것이 있는 순열을 생각하면, 기본적으로 1이 a_1개, 2가 a_2개, ..., n이 a_n개 있고 총 갯수가 k개라고 한다면 이들을 일렬로 배열시키는 방법은 바로 k!/(a_1)!(a_2)!...(a_n)!이 됨을 알 수 있습니다.

 

여기서 중요한 사실은, 어떤 자연수가 1이건 2이건 n이건 상관 없이, 그리고 k의 값에도 상관없이, 자연수를 a_i개 골랐다면 자연스럽게 저 배열시키는 방법 가짓수 식 분모에도 (a_i)!이 포함되게 된다는 것입니다. 그렇다면 위 식을 조금씩 변형해주면 어떨까요?

 

이전에 나온 항을 (1+x/1! + x^2/2!+ ... + x^cti/(cti)!)의 꼴로 변형시켜 줍니다. 그렇다면 i를 a_i개 고르는 것을 나타내는 항이 곱해졌을 때, 자동으로 그 앞에 붙은 계수인 1/(a_i)!도 같이 곱해지게 됩니다. 이 모든 과정을 거쳐 구한 x^k의 계수에 마지막으로 k!을 곱해주면 됩니다.

 

4. 구현 상 주의점

물론 당연히 위의 과정을 막무가내로 float를 쓰든가 해서 하라는 뜻은 절대로 아닙니다. 모듈로 역원을 구하는 연산을 취해서 mod 998244353에서의 1!의 역원, 2!의 역원... n!의 역원을 구해두면 되는데, 제가 선호하는 구현 방식은 n!을 먼저 구하고, n!의 역원을 구한 뒤 거기서 n을 곱해서 (n-1)!의 역원을 구하고...를 반복하는 것입니다.

 

또한, 3000차 다항식을 최대 3000번 곱해서 마지막에는 900만차 다항식까지 나올 가능성이 있지만, 어차피 k 또한 당연히 n 이하의 자연수이므로 FFT 곱셈 과정에서 일정 차수 이상은 그냥 날려버리는 식으로 구현하면 시간을 상당히 절약할 수 있습니다.

 

이제 나온 n개의 다항식들을 잘 FFT로 다항식 곱셈을 하고, x^k의 계수를 구하고 거기에다가 k!을 구하면 어렵지 않게 답을 얻을 수 있습니다.

 

5. 풀이

코드 복사-붙여넣기 하는 당신이 아름답습니다.

 

#include<bits/stdc++.h>
#include <queue>
using namespace std;

typedef long long ll;
typedef vector<ll> vll;

const ll w=3;
const ll mod =998244353;



ll pw(ll a, ll b, ll p=mod){
    ll res=1;
    while(b){
        ll r=b%2;
        if(r)res=res*a%p;
        b>>=1;
        a=(a*a)%p;
    }
    return res;
}

void ntt(vll &L, bool inv=false){
    ll sz=L.size();
    ll j=0;
    for(ll i=1;i<sz;i++){
        ll bit=sz>>1;
        while(j>=bit){
            j-=bit;
            bit>>=1;
        }
        j+=bit;
        if(i<j){
            swap(L[i], L[j]);
        }
    }
    ll m=2;
    while(m<=sz){
        ll u=pw(3,mod/m, mod);
        if(inv) u=pw(u,mod-2,mod);
        for(ll i=0;i<sz;i+=m){
            ll w=1;
            for(ll k=i;k<(i+(m/2ll));k++){
                ll tmp=(L[k+(m/2ll)]*w)%mod;
                L[k+(m/2ll)] = (L[k]-tmp+mod)%mod;
                L[k]=(L[k]+tmp)%mod;
                w=(w*u)%mod;
            }
        }
        m*=2;
    }
    if(inv){
        ll inv_n = mod-((mod-1)/sz);
        for(ll i=0;i<sz;i++){
            L[i]=L[i]*inv_n%mod;
        }
    }
}

vll mul(vll _L1, vll _L2){
    vll L1(_L1.begin(), _L1.end()), L2(_L2.begin(), _L2.end());
    ll n=2;
    while(n<L1.size()+L2.size())n<<=1;
    L1.resize(n);
    L2.resize(n);
    ntt(L1);
    ntt(L2);
    for(ll i=0;i<n;i++){
        L1[i]=(L1[i]*L2[i])%mod;
    }
    ntt(L1,true);
    vll L;
    for(int i=0;i<=3000;i++){
        if(L1.size()<=i)L.push_back(0);
        else L.push_back(L1[i]);
    }
    return L;
}


int main(){
   ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
   int n,m,k;
   k=0;
   cin>>n;
   vll res;
   queue<vll> q;
   vll ct;
   for(int i=0;i<=n;i++){
        ct.push_back(i);
   }
   for(int i=0;i<n;i++){
        cin >> m;
        if(m)ct[m]--;
        else k++;
   }
   ll rev=1;
   for(ll i=1;i<=3000;i++){
        rev*=i;rev%=mod;
   }
   ll r=pw(rev, mod-2);
   vll rL;
   for(int i=3000;i>=0;i--){
        rL.push_back(r);
        r*=i;
        r%=mod;
   }
   reverse(rL.begin(), rL.end());
   sort(ct.begin(), ct.end());
   for(int i=0;i<=n;i++){
        vll nowL;
        for(int j=0;j<=ct[i];j++)nowL.push_back(rL[j]);
        q.push(nowL);
   }
    while(q.size()>1){
        vll L1, L2;
        L1=q.front();
        q.pop();
        L2=q.front();
        q.pop();
        vll L = mul(L1, L2);
        q.push(L);
   }
   vll L = q.front();
   q.pop();
   ll ans=L[k];
   for(int i=k;i>0;i--){
        ans*=i;
        ans%=mod;
   }
   cout<<ans;
   return 0;
}

제 주 언어는 Python이지만, 이번에는 시간 복잡도가 상당히 두려웠기 때문에 C++로 구현하였습니다.