はまやんはまやんはまやん

hamayanhamayan's blog

貪欲な領主 [yukicoder 386]

問題

http://yukicoder.me/problems/no/386

N頂点の木が与えられ、各頂点に重みUiが設定されている。
この時、M個のクエリ「(頂点Aiから頂点Biへのパスを構成する頂点の重みの和) * Ci」の和を求めよ。

2 <= N <= 10^5
0 < Ui <= 100
1 <= Ci <= 10

考察

1. 木で2頂点間のパスに対する操作と言えば「LCA
2. LCAを使うだろうなと思って考えれば解法が思いつきます

LCAを使うと共通の親を素早く見つけることができる

3. 頂点Aと頂点Bの共有の親を頂点Pとする。すると、A->Bの重みの和は以下の式になる
(A->B) = (根->A) + (根->B) - 2 * (根->P) + U[P]
4. なので、BFSをつかって、予め根からその頂点までの重みの和を全頂点に対して計算しておく -> bfs()

5. あとは、3.の式と4.で計算したやつを使って、各クエリを順番に計算して足せば良い

LCAを知ってるかどうかが問題

実装

http://yukicoder.me/submissions/101172

typedef long long ll;
int N;
vector<int> E[101010];
int U[101010];
//-----------------------------------------------------------------
int sum[101010];
void bfs() {
    queue<int> que;

    sum[0] = U[0];

    que.push(0);
    while (!que.empty()) {
        int i = que.front(); que.pop();

        for (int j : E[i]) if (sum[j] == 0) {
            sum[j] = sum[i] + U[j];
            que.push(j);
        }
    }
}
//-----------------------------------------------------------------
class LCA {
public:
    int NV, logNV;
    vector<int> D;
    vector<vector<int> > P;

    LCA(int N) {
        NV = N;
        logNV = 0;
        while (NV > (1LL << logNV)) logNV++;
        D = vector<int>(NV);
        P = vector<vector<int> >(logNV, vector<int>(NV));
        dfs(0, -1, 0);
        build();
    }

    void dfs(int v, int par, int d) {
        D[v] = d;
        P[0][v] = par;
        for(int i : E[v]) if (i != par) dfs(i, v, d + 1);
    }

    void build() {
        rep(k, 0, logNV - 1) rep(v, 0, NV) {
            if (P[k][v] < 0) 
                P[k + 1][v] = -1;
            else
                P[k + 1][v] = P[k][P[k][v]];
        }
    }

    int query(int u, int v) {
        if (D[u] > D[v]) swap(u, v);
        rep(k, 0, logNV) if ((D[v] - D[u]) >> k & 1) v = P[k][v];
        if (u == v) return u;

        for (int k = logNV - 1; k >= 0; k--) {
            if (P[k][u] != P[k][v]) {
                u = P[k][u];
                v = P[k][v];
            }
        }
        return P[0][u];
    }
};
//-----------------------------------------------------------------
int main() {
    scanf("%d", &N);
    rep(i, 0, N - 1) {
        int a, b; scanf("%d %d", &a, &b);
        E[a].push_back(b);
        E[b].push_back(a);
    }
    rep(i, 0, N) scanf("%d", &U[i]);

    bfs();

    LCA lca(N);

    int M; scanf("%d", &M);
    ll ans = 0;
    rep(i, 0, M) {
        int a, b, c; scanf("%d %d %d", &a, &b, &c);

        int p = lca.query(a, b);
        ans += (sum[a] + sum[b] - sum[p] * 2 + U[p]) * c;
    }
    printf("%lld\n", ans);
}