はまやんはまやんはまやん

hamayanhamayan's blog

Fafa and Ancient Mathematics [Codeforces Round #465 E]

http://codeforces.com/contest/935/problem/E

以下のルールの文字列がある。

  • dは1桁の自然数であり、dのみでアーメス式
  • E1,E2がアーメス式なら(E1 op E2)もアーメス式。opは+か-

 
アーメス式のopの部分が全て?となっている式Eがある。
この?に+をP個、-をM個入れて、アーメス式を評価したときの数を最大化せよ。

前提知識

解法

http://codeforces.com/contest/935/submission/35495534

まずは、与えられたアーメス式を構文解析して、木構造にしておく。
parse関数でパースしている。
この辺が参考になりそう?
一番簡単な再帰で書いてくやり方で書いている。
木構造にする時にノードに一意なIDを割り当てておく。
これは後でメモ化再帰をする時に必要になる。
 
あとは、この木構造演算子の所を+か-に確定させて最大値を求めればいい。
この問題の類題がyukicoderにある。
この問題の解法の方針は、「途中経過の最大と最小だけを使って計算していく」という方針である。
これを今回でも採用することにしよう。

「f(n, p) := 頂点n以降で+をp個使った時の数の{最小値, 最大値}」
これを計算していくが、?を+にした場合と-にした場合のそれぞれについて、左の頂点でpをpls個使うというのを全探索する。
全ての可能性について計算し、そのうちの最小と最大を答える。
そのままやるとTLEだが、メモ化をすることで回避しよう。
関数fの状態数を考えると、O(NP)となる
Nは|E|=10^4を考えると、10^4/3位になると想像できる。
PもNと同じ位になる可能性がある。
遷移はO(P)だけある。
このままだと、10^4*10^4*10^4となって計算量的には心もとない。
ここで、いかにもな制約である「min(P,M)≦100」を使おう。
memo[n][p]だが、M<Pの場合はmemo[n][m]でやるようにする。
このように場合分けすることで、10^4*10^2*10^2となるので十分間に合う。
M<Pの場合は関数gとしてある(適当に書いてしまったので、関数g内のコメントは間違っている)。

string E; int P, M;
//---------------------------------------------------------------------------------------------------
struct Node {
    Node *left, *right;
    int value, id, cnt;
};
//---------------------------------------------------------------------------------------------------
int i = 0, idx = 0;
Node* parse() {
    Node *n = new Node;
    n->id = idx;
    idx++;

    if (E[i] == '(') {
        i++;
        n->left = parse();
        i++;
        n->right = parse();
        i++;
        n->value = -1;
        n->cnt = n->left->cnt + n->right->cnt + 1;
    } else {
        n->value = E[i] - '0';
        i++;
        n->cnt = 0;
    }

    return n;
}
//---------------------------------------------------------------------------------------------------
int vis[4101][105];
pair<int,int> memo[4101][105];
pair<int,int> f(Node* n, int p) { // res {min, max}
    int m = n->cnt - p;
    if (m < 0) return { inf, inf };

    int id = n->id;
    if (vis[id][p]) return memo[id][p];

    if (n->cnt == 0) return memo[id][p] = { n->value, n->value };

    int mi = inf, ma = -inf;
    
    // 左+右
    rep(pls, 0, p) {
        auto le = f(n->left, pls);
        if (le.first == inf) continue;
        auto ri = f(n->right, p - pls - 1);
        if (ri.first == inf) continue;

        vector<int> a = { le.first, le.second };
        vector<int> b = { ri.first, ri.second };
        rep(i, 0, 2) rep(j, 0, 2) {
            int res = a[i] + b[j];
            chmin(mi, res);
            chmax(ma, res);
        }
    }
    
    // 左-右
    rep(pls, 0, p + 1) {
        auto le = f(n->left, pls);
        if (le.first == inf) continue;
        auto ri = f(n->right, p - pls);
        if (ri.first == inf) continue;

        vector<int> a = { le.first, le.second };
        vector<int> b = { ri.first, ri.second };
        rep(i, 0, 2) rep(j, 0, 2) {
            int res = a[i] - b[j];
            chmin(mi, res);
            chmax(ma, res);
        }
    }

    assert(mi != inf);

    //printf("%d -> %d,%d\n", id, mi, ma);
    vis[id][p] = 1;
    return memo[id][p] = { mi, ma };
}
pair<int, int> g(Node* n, int p) { // res {min, max}
    int m = n->cnt - p;
    if (m < 0) return { inf, inf };

    int id = n->id;
    if (vis[id][p]) return memo[id][p];

    if (n->cnt == 0) return memo[id][p] = { n->value, n->value };

    int mi = inf, ma = -inf;

    // 左+右
    rep(pls, 0, p) {
        auto le = g(n->left, pls);
        if (le.first == inf) continue;
        auto ri = g(n->right, p - pls - 1);
        if (ri.first == inf) continue;

        vector<int> a = { le.first, le.second };
        vector<int> b = { ri.first, ri.second };
        rep(i, 0, 2) rep(j, 0, 2) {
            int res = a[i] - b[j];
            chmin(mi, res);
            chmax(ma, res);
        }
    }

    // 左-右
    rep(pls, 0, p + 1) {
        auto le = g(n->left, pls);
        if (le.first == inf) continue;
        auto ri = g(n->right, p - pls);
        if (ri.first == inf) continue;

        vector<int> a = { le.first, le.second };
        vector<int> b = { ri.first, ri.second };
        rep(i, 0, 2) rep(j, 0, 2) {
            int res = a[i] + b[j];
            chmin(mi, res);
            chmax(ma, res);
        }
    }

    assert(mi != inf);

    //printf("%d -> %d,%d\n", id, mi, ma);
    vis[id][p] = 1;
    return memo[id][p] = { mi, ma };
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> E >> P >> M;
    E += "=";

    Node *root = parse();

    if (P <= M) {
        auto res = f(root, P);
        int ans = res.second;
        cout << ans << endl;
    }
    else {
        auto res = g(root, M);
        int ans = res.second;
        cout << ans << endl;
    }
}