ぱーぽーの競プロ記

競技プログラミングに関することを書きます。

BestCoder #029 B : GTY's birthday gift

概要

集合Sがあり、その中にn個の要素が含まれている。

「集合Sに属する2値を選び、その和を集合Sに加える」という動作をk回行った後に集合Sに含まれる要素の和の最大値を求めよ。

2<=n<=100000
1<=k<=1000000000

http://bestcoder.hdu.edu.cn/contests/contest_showproblem.php?pid=1002&cid=567

解法

和を最大化させるには集合Sに属する要素の中で1番大きい値とその次に大きい値を持ってくれば良さそうというのはすぐに思いつく。

しかし上記動作回数のkはとても大きな値で単純にk回計算させていたら間に合わないので工夫する必要がある。

和をSk、集合Sに含まれる要素で1番大きい値をFk、その次に大きい値をFk-1とすると、以下のような更新式が成り立つ。
Sk+1 = Sk+ Fk + Fk-1

更にこれを高速に求めるために行列を用いると以下のような更新式を作ることができる。

{ 
 \begin{pmatrix} S[k+1] \\ F[k+1] \\ F[k] \end{pmatrix} = \begin{pmatrix} 1 & 1 & 1 \\ 0 & 1 & 1 \\ 0 & 1 & 0 \end{pmatrix} \begin{pmatrix} S[k] \\ F[k] \\ F[k-1] \end{pmatrix}
}

そして、

{ 
 A = \begin{pmatrix} 1 & 1 & 1 \\ 0 & 1 & 1 \\ 0 & 1 & 0 \end{pmatrix}
}

とおいて上記式を変形させると、

{ 
 \begin{pmatrix} S[k] \\ F[k] \\ F[k-1] \end{pmatrix} = A^{k} \begin{pmatrix} S[1] \\ F[1] \\ F[0] \end{pmatrix}
}

を得ることができる。

Akは繰り返し二乗法を用いればO(log(k))で求めることができる。

補足

AOJ 1327 : One-Dimensional Cellular Automaton - ぱーぽーのぷろぐらみんぐ記
同じようなやり方で解くことができるので練習するとよい。

ソースコード

#include <iostream>
#include <vector>
#include <algorithm>

#define REP(i, x, n) for(int i = x; i < (int)(n); i++)
#define rep(i, n) REP(i, 0, n)
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define F first
#define S second
#define mp make_pair
#define pb push_back

using namespace std;

typedef long long int lli;
typedef vector<vector<lli> > Matrix;

Matrix mul(const Matrix& A, const Matrix& B, const lli& mod) {
  Matrix C(A.size(), vector<lli>(B[0].size()));

  for(int i = 0; i < A.size(); i++) {
    for(int k = 0; k < B.size(); k++) {
      for(int j = 0; j < B[0].size(); j++) {
        C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod;
      }
    }
  }

  return C;
}

// B = A^n
Matrix modPow(Matrix A, int n, lli mod) {
  Matrix B(A.size(), vector<lli>(A.size()));

  for(int i = 0; i < A.size(); i++) {
    B[i][i] = 1;
  }

  while(n) {
    if(n & 1) B = mul(B, A, mod);
    A = mul(A, A, mod);
    n >>= 1;
  }

  return B;
}

int main() {
  // ios_base::sync_with_stdio(false);
  int n, k;
  const lli mod = 10000007;
  
  while(cin >> n >> k) {
    vector<lli> a(n);
    rep(i, n) cin >> a[i];

    lli s1 = 0;
    rep(i, n) s1 += a[i];

    sort(rall(a));
    lli f1 = a[0];
    lli f0 = a[1];
    
    Matrix A = Matrix(3, vector<lli>(3, 0));
    A[0][0] = 1;
    A[0][1] = 1;
    A[0][2] = 1;
    A[1][1] = 1;
    A[1][2] = 1;
    A[2][1] = 1;

    Matrix B = Matrix(3, vector<lli>(1, 0));
    B[0][0] = s1;
    B[1][0] = f1;
    B[2][0] = f0;

    Matrix ans = mul(modPow(A, k, mod), B, mod);

    cout << ans[0][0] << endl;
  }
  return 0;
}