Pの競プロ記

競技プログラミングに関することを書きます。

ARC #023 C タコヤ木

概要


問題文はこちら
http://arc023.contest.atcoder.jp/tasks/arc023_3

解法


(他の記事を参考にしながら解いた。説明に間違いがあるかもしれない)

例えば入力で 1 -1 -1 3 が与えられた場合、単調増加な数列になるのは
・1 1 1 3
・1 1 2 3
・1 1 3 3
・1 2 2 3
・1 2 3 3
・1 3 3 3
の6通りが考えられる

よってこの問題では重複組合わせを計算することができれば解くことができる。

n種類の中から重複可でk個取る場合の組合わせは、
nHk = n+k-1Ck
で求めることができる。

そしてn種類の中から重複なしでk個取る場合の組合わせは、
nCk = nPk / k!
で求めることができる。

しかし今回の問題ではAiの値が非常に大きいのでnPkやk!の計算の過程でMODを取る必要がある。

ここで注意しなければいけないのは、
足し算・引き算・かけ算の途中でMODを取ることはできるが、割り算では分子、分母それぞれで単純にMODを取ってしまうと結果が変わってしまうということ。(実際にやってみると分かると思う)

そこで逆元の計算をさせる。

= = = = = = = = = = = = = = =
mod mでの逆元について
aとmが互いに素であるときに、ax≡1(mod m)となるxが存在する。
= = = = = = = = = = = = = = =

ということで逆元を高速に求めることができればよいことになる。(説明不十分ですいません)

逆元の計算は蟻本に載っているので参考にするとよい。
またそれ以外にも1~Nまでの逆元をO(N)で求めることもできる。

これで解くことができる。

ソースコード

#include <iostream>
#include <vector>

using namespace std;

typedef long long int lli;

lli extgcd(lli a, lli b, lli& x, lli& y) {
  lli d = a;
  
  if(b != 0) {
    d = extgcd(b, a % b, y, x);
    y -= (a / b) * x;
  }
  else {
    x = 1;
    y = 0;
  }
  
  return d;
}

// aの逆元を求める
lli calculateModInverse(lli a, lli MOD) {
  lli x, y;
  extgcd(a, MOD, x, y);
  return (MOD + x % MOD) % MOD;
}

vector<lli> listModInverse(int n, lli MOD) {
  vector<lli> inv(n + 1);
  inv[1] = 1;
  
  for(int i = 2; i <= n; i++) {
    inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;
  }

  return inv;
}

lli solve(int N, vector<int> A) {
  const int MOD = 1e9 + 7;
  
  // vector<lli> inv = listModInverse(N, MOD);

  lli res = 1;
  for(int i = 0; i < A.size(); i++) {
    if(A[i] == -1) {
      int start = i;
      while(A[i] == -1) i++;
      int end = i;
      
      lli comb = 1;
      int n = A[end] - A[start - 1] + 1;
      int r = end - start;
      n += r - 1;

      // nHk = n+k-1Ck
      for(int j = 1; j <= r; j++) {
        comb = comb * (n - j + 1) % MOD;
        // comb = comb * inv[j] % MOD;
        comb = comb * calculateModInverse(j, MOD) % MOD;
      }
            
      res = res * comb % MOD;
    }
  }

  return res;
}

int main() {
  int N;
  vector<int> A;
  
  cin >> N;
  for(int i = 0; i < N; i++) {
    int tmp;
    cin >> tmp;
    A.push_back(tmp);
  }

  cout << solve(N, A) << endl;

  return 0;
}