はまやんはまやんはまやん

hamayanhamayan's blog

Avoiding Collision [AtCoder Regular Contest 090 E]

https://beta.atcoder.jp/contests/arc090/tasks/arc090_c

前提知識

解法

https://beta.atcoder.jp/contests/arc090/submissions/2031587

まず点Sからダイクストラをして、最短経路を求めておこう。
最短経路を求める際に最短距離が同じ場合は組合せを足すようにして、同時に最短経路となるパス数も同時に求めておこう。
S->Tへの最短経路をlen, S->Tへの最短経路のパス数をpatとする。
すると、2人が取るパスの組み合わせ数はpat*patとなる。
ans = pat*pat

ここから2人が出会ってしまう組み合わせ数を引くことで計算しよう。
2人が出会う場所を全探索することにする。
このために点Tからもダイクストラをして、最短距離とパス数を求めよう。
点Sからの最短経路表をdisS, そこまでのパス数をpatS
点Tからの最短経路表をdisT, そこまでのパス数をpatT
として以降説明していく。
 
まず、頂点で出会う場合、つまり「disS[i] + disT[i] == len」かつ「disS[i] == len / 2」の場合は、
頂点で出会うので、その頂点を通るパスの組み合わせ数を引く。
ans -= (patS[i] * patT[i]) ^ 2
 
次に、辺(a,b,c)で出会う場合、つまり「disS[a] + c + disT[b] == len」かつ「disS[a] < len / 2」かつ「disT[b] < len/2」の場合は、
その辺で出会うので、その辺を通るパスの組み合わせ数を引く。
ans -= (patS[a] * patT[b]) ^ 2
 
これで答えが求まる。
比較の/2は「a = len/2」だと切り捨てで誤差が生じるので「a*2 == len」で比較する。

int N, M, S, T;
vector<pair<int, ll>> E[101010];
//---------------------------------------------------------------------------------------------------
template<typename T> using min_priority_queue = priority_queue<T, vector<T>, greater<T>>;
int vis[101010];
void dij(int s, vector<ll> &dis, vector<mint> &pat) {
    rep(i, 0, N) dis[i] = infl;
    rep(i, 0, N) pat[i] = 0;
    rep(i, 0, N) vis[i] = 0;
 
    min_priority_queue<pair<ll, int>> que;
 
    dis[s] = 0;
    pat[s] = 1;
    que.push({ 0, s });
 
    while (!que.empty()) {
        auto q = que.top(); que.pop();
 
        int cu = q.second;
        ll cst = q.first;
 
        if (vis[cu]) continue;
        vis[cu] = 1;
 
        fore(p, E[cu]) if(!vis[p.first]) {
            if (dis[cu] + p.second < dis[p.first]) {
                dis[p.first] = dis[cu] + p.second;
                pat[p.first] = pat[cu];
                que.push({ dis[p.first], p.first });
            }
            else if (dis[cu] + p.second == dis[p.first]) {
                pat[p.first] += pat[cu];
            }
        }
    }
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N >> M >> S >> T;
    S--;
    T--;
    rep(i, 0, M) {
        int a, b, c; cin >> a >> b >> c;
        a--; b--;
        E[a].push_back({ b, c });
        E[b].push_back({ a, c });
    }
    vector<ll> disS(N), disT(N);
    vector<mint> patS(N), patT(N);
    dij(S, disS, patS);
    dij(T, disT, patT);
 
    ll len = disS[T];
 
    mint ans = patS[T] * patS[T];
 
    rep(i, 0, N) if (disS[i] + disT[i] == len and disS[i] * 2 == len) {
        ans -= patS[i] * patT[i] * patS[i] * patT[i];
    }
 
    
    rep(i, 0, N) fore(p, E[i]) {
        int j = p.first;
        ll c = p.second;
 
        if (disS[i] + c + disT[j] == len) {
            //printf("[%d %d]\n", i + 1, j + 1);
            if (disS[i] * 2 < len and disT[j] * 2 < len) {
                ans -= patS[i] * patT[j] * patS[i] * patT[j];
            }
        }
    }
 
    cout << ans << endl;
}