はまやんはまやんはまやん

hamayanhamayan's blog

いたずら好きなお姉ちゃん [yukicoder 1031]

https://yukicoder.me/problems/no/1031

前提知識

解説

https://yukicoder.me/submissions/466496

自分の解法はちょっとやりすぎかもしれない。

この問題を解くときに厄介なのは、異なる操作であっても同じ結果になることがあり、
その同じ結果がダブって計算しては行けないということである。
これを解決するために以下のように問題を読み替えて考えよう。
ある区間を選んだ時の(最大の要素,最小の要素)の組み合わせは?
このように言い換えると若干数えやすくなる。

解法の方針であるが、最大の要素を固定する。
そこから、ありえる最小の要素を数え上げることで問題を解く。
例えば1 7 3 2 5 8 6 4というサンプル3で考えてみる。
最大の要素を5の場合を考えたいとき、5を含むような区間で、かつ5より大きい部分は含んではいけないので、
3 2 5
この区間に対して、操作区間を考えることになる。
5は必ず含むべきなので、区間を5から伸ばしていくように考えていく。
すると、3 2 5, 2 5の2パターンしかないが、どちらも最小値が2なので、(5, 2)しか答えがない。
もう少し一般化して説明すると、ある最大値から区間を伸ばしていき、最小値が更新される回数が答えになる。

これはgetpre(idx, cnt)に任せよう。
A[idx]から最小値が更新されるようにcnt回移動したときの最終的な場所を求める。
この関数の内部ではダブリングを使っている。

この関数が用意できれば、二分探索によって、最小値が何回更新されるかが分かる。
移動後に区間内に入っているような最大回数を求めれば、何回更新されるかを導ける。

最後に、最大値を固定したときに、どこまでの区間で考えることができるかというのをどうやって扱うかであるが、
列の分割統治を利用する。
二分探索で範囲を特定する方が簡単かもしれない。
列の分割統治では、列の最大値をとってきて、それに対して操作を行い、
操作を行った後に、その最大値で列を分割して、再帰によって処理を行う。

int N, A[101010];
SparseTable<pair<int, int>> st;
//---------------------------------------------------------------------------------------------------
int pre[20][101010];
int post[20][101010];
SegTreeMin<int, 1 << 17> stmin;
SegTreeMax<int, 1 << 17> stmax;
void init() {
    rep(p, 0, 20) rep(idx, 0, 101010) pre[p][idx] = 0;
    rep(i, 0, N) {
        pre[0][i + 1] = stmax.get(0, A[i]);
        stmax.update(A[i], i + 1);
    }
    rep(p, 1, 20) rep(idx, 0, N + 1) pre[p][idx] = pre[p - 1][pre[p - 1][idx]];

    defmin = N + 1;
    rep(p, 0, 20) rep(idx, 0, 101010) post[p][idx] = N + 1;
    rrep(i, N - 1, 0) {
        post[0][i + 1] = stmin.get(0, A[i]);
        stmin.update(A[i], i + 1);
    }
    rep(p, 1, 20) rep(idx, 0, N + 1) post[p][idx] = post[p - 1][post[p - 1][idx]];
}
int getpre(int idx, int cnt) {
    int res = idx;
    rep(p, 0, 20) if (cnt & (1 << p)) res = pre[p][res];
    return res;
}
int getpost(int idx, int cnt) {
    int res = idx;
    rep(p, 0, 20) if (cnt & (1 << p)) res = post[p][res];
    return res;
}
//---------------------------------------------------------------------------------------------------
ll ans = 0;
void f(int L, int R) {
    if (L == R) return;
    if (L + 1 == R) return;
    auto ma = st.get(L, R);

    int C = ma.second;

    {
        int ok = 0, ng = N + 1;
        while (ok + 1 != ng) {
            int md = (ok + ng) / 2;
            if (L <= getpre(C, md)) ok = md;
            else ng = md;
        }
        ans += ok;
    }

    {
        int ok = 0, ng = N + 1;
        while (ok + 1 != ng) {
            int md = (ok + ng) / 2;
            if (getpost(C, md) < R) ok = md;
            else ng = md;
        }
        ans += ok;
    }

    f(L, C);
    f(C + 1, R);
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N;
    rep(i, 0, N) cin >> A[i];

    init();

    vector<pair<int, int>> dic;
    dic.push_back({ -1, -1 });
    rep(i, 0, N) dic.push_back({ A[i], i + 1 });
    st.init(dic);

    f(1, N + 1);
    cout << ans << endl;
}