はまやんはまやんはまやん

hamayanhamayan's blog

Variable Spanning Trees [第二回 アルゴリズム実技検定 O]

https://atcoder.jp/contests/past202004-open/tasks/past202004_o

前提知識

解説

https://atcoder.jp/contests/past202004-open/submissions/12901968

どこから手を付けようか。
最小全域木を作る問題なので、既存の最小全域木アルゴリズムから考えてみる。
まず、自明な所として、普通に最小全域木を作成したとして、そこにある辺が含まれていれば、
その最小全域木を構成する重みの総和が答え。
そうでない辺があった場合は、作った最小全域木をベースにして改変することで最小全域木を作成する。
a-bをつなぐ辺を採用すると、木のある頂点をつなぐことになるので、1つのサイクルができる。
このサイクルから1辺を削除すると、木に戻ることができる。
なので、このサイクルからa-bを除く最も重みの大きい辺を選択できれば、
(最小全域木コスト)+(a-b)-(サイクルの最大辺)で答えを導ける。
このサイクルは、lc=lca(a,b)とおくと、a -> lc -> b -> aのサイクルになるので、
(a-lc)と(b-lc)の最大コスト辺を取ってくる問題となる。

つまり、木のあるパスについて最大辺がとってこれればいい。
これにはいくつかやり方がある。

公式解説には、「並列二分探索解」と「ダブリング」が紹介されている。
ダブリング解法が一番優しいと思う。

自分はHL分解+セグメントツリーで解いた。
データ構造でごり押した感じになる。
HL分解を使うと、木へのパスをO(logN)個の区間に分割することができる。
この各区間についてセグメントツリーで最大値を取ってきて、更に区間での最大値を取れば、パスの最大値になる。
類題を解いたことがあれば、この流れはそんなに難しくない。
(習得までが大変ですけど)

int N, M, A[101010], B[101010], C[101010];
bool used[101010];
SegTree<int, 1 << 17> st;
//---------------------------------------------------------------------------------------------------
int tmp;
void _main() {
    cin >> N >> M;
    rep(i, 0, M) cin >> A[i] >> B[i] >> C[i], A[i]--, B[i]--;

    UnionFind uf(N);
    ll tot = 0;
    vector<int> ord;
    rep(i, 0, M) ord.push_back(i);
    sort(all(ord), [&](int a, int b) { return C[a] < C[b]; });
    fore(i, ord) {
        int a = A[i];
        int b = B[i];
        int c = C[i];

        if (uf[a] != uf[b]) {
            uf(a, b);
            used[i] = true;
            tot += c;
        }
    }

    HLDecomposition hld;
    hld.init(N);
    rep(i, 0, M) if (used[i]) hld.add(A[i], B[i]);
    hld.build(0);

    rep(i, 0, M) if (used[i]) {
        int a = A[i];
        int b = B[i];
        int c = C[i];

        if (hld.depth[a] > hld.depth[b]) st.update(hld.vid[a], c);
        else st.update(hld.vid[b], c);
    }

    rep(i, 0, M) {
        if (used[i]) {
            printf("%lld\n", tot);
            continue;
        }

        int a = A[i];
        int b = B[i];
        int c = C[i];

        int lc = hld.lca(a, b);
        int tm = st.get(hld.vid[lc], hld.vid[lc] + 1);
        st.update(hld.vid[lc], 0);

        ll ans = tot + c;

        tmp = 0;
        hld.foreach(a, b, [&](int x, int y) {
            chmax(tmp, st.get(x, y + 1));
        });

        st.update(hld.vid[lc], tm);

        ans -= tmp;
        printf("%lld\n", ans);
    }
}