ぱーぽーの競プロ記

競技プログラミングに関することを書きます。

Valentine's Day Round (BestCoder #030) C : The Experience of Love

概要

N個の町とN-1本の道がある。

町には1~Nまで番号がつけられており、道は2つの町a,bを距離cでつないでいる。

ある町から別の町に行くときに通る道のうち、最大の距離を持つ道と最小の距離を持つ道の距離の差を計算する。

例えば町1,2,3があり、町1,2間の距離は5、町2,3間の距離は2として町1,3の移動を考える。
そのとき最大の距離を持つ道は5、最小の距離を持つ道は2なので、その差は3となる。

町iと町jの選び方を全通り試したとき、上記の計算した値の和を求めよ。
(1 <= i < j <= N)

N <= 150000
c <= 10^9

http://bestcoder.hdu.edu.cn/contests/contest_showproblem.php?cid=568&pid=1003

解法

(※自分では理解できているつもりなのにうまく文章にまとめられていない。)

クラスカル法っぽいやり方で解くことができる。

Union-Findで集合を管理するときに、あるノードが属する集合の要素数を持たせておく。
そうすることで町a,bを距離cで繋ぐ道を考えたときに、(aが属する集合の要素数) * (bが属する集合の要素数)が全パターン試すときにその道の距離が使われる回数を示し、それに距離cをかけることでその距離の合計値となる。
(たぶん紙に書いてみると分かりやすいと思う。)

道を距離でソートし、
・距離が小さい順に道を選び計算させると全パターン試したときの最大距離の合計
・距離が大きい順に道を選び計算させると全パターン試したときの最小距離の合計
を求めることができ、それらの差を取ると解を求めることができる。

ちなみに計算のさせ方によってはlong long intだとオーバーフローするので注意。

ソースコード

#include <iostream>
#include <vector>
#include <algorithm>
#include <map>

#define REP(i, x, n) for(int i = x; i < (int)(n); i++)
#define rep(i, n) REP(i, 0, n)
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define F first
#define S second
#define mp make_pair
#define pb push_back

using namespace std;

typedef long long int lli;
typedef unsigned long long int ulli;
typedef pair<int, int> Pii;

class UnionFindTree {
private:
  int nodeSize;
  vector<int> parent;
  vector<int> rank;
  vector<int> treeSize;

public:
  UnionFindTree(int ns) {
    nodeSize = ns;
    parent = vector<int>(nodeSize);
    rank = vector<int>(nodeSize, 0);
    treeSize = vector<int>(nodeSize, 1);

    for(int i = 0; i < nodeSize; i++) {
      parent[i] = i;
    }
  }

  ~UnionFindTree() {}
  
  int find(int x) {
    if(parent[x] == x) {
      return x;
    }
    else {
      return parent[x] = find(parent[x]);
    }
  }

  bool unite(int x, int y) {
    x = find(x);
    y = find(y);
    
    if(x == y) {
      return false;
    }
    
    if(rank[x] < rank[y]) {
      parent[x] = y;
      treeSize[y] += treeSize[x];
    }
    else {
      parent[y] = x;
      treeSize[x] += treeSize[y];
      if(rank[x] == rank[y]) {
        rank[x]++;
      }
    }

    return true;
  }

  bool same(int x, int y) {
    return find(x) == find(y);
  }

  int getSize(int x) {
    return treeSize[find(x)];
  }
};

ulli solve(int N, const vector<pair<int, pair<int, int> > >& edge) {
  UnionFindTree uf(N);
  ulli res = 0;

  rep(i, N - 1) {
    ulli c = edge[i].F;
    int a = edge[i].S.F;
    int b = edge[i].S.S;

    res += c * uf.getSize(a) * uf.getSize(b);
    uf.unite(a, b);
  }

  return res;
}

int main() {
  int T = 1;
  int N;

  while(cin >> N) {
    vector<pair<int, pair<int, int> > > edge;
    
    rep(i, N - 1) {
      int a, b, c;
      cin >> a >> b >> c;
      edge.push_back(mp(c, mp(a - 1, b - 1)));
    }

    ulli ans = 0;
    sort(all(edge));
    ans += solve(N, edge);
    sort(rall(edge));
    ans -= solve(N, edge);
    
    cout << "Case #" << T++ << ": " << ans << endl;
  }

  return 0;
}