はまやんはまやんはまやん

hamayanhamayan's blog

PopCount [yukicoder No.737]

https://yukicoder.me/problems/no/737

考察過程

1. まず目につくのがpopcountの組み合わせ数の少なさである
2. 60通りくらいしかないので、問題を言い換えることができそう
3. 2進数表記にしたときに1の数がx個であり、N以下の数の総和をsum[x]とする
4. すると答えはsum[x]*popcount(x)の総和になる
5. このsumは桁DPで求められそう

解法

https://yukicoder.me/submissions/288502

今回の問題は「2進数表記にしたときに1の数がx個であり、N以下の数の総和をsum[x]」とすると、
sum[x]*popcount(x)の総和を求める問題になる。
配列sumは桁DPを使って計算していく。
 
配列を2つ用意する。
comb[dgt][tot][isless] := dgt桁目まで確定していて、1の数がtot個、現時点でN以下かどうかがislessである数の組み合わせ数
sum[dgt][tot][isless] := dgt桁目まで確定していて、1の数がtot個、現時点でN以下かどうかがislessである数の総和
最初はcomb[0][0][0] = 1, sum[0][0][0] = 0
 
遷移は次の桁に0を置くか、1を置くかの2通り。
組み合わせ数は桁DPだと頻出なので、遷移は書きやすいかもしれない。
総和の方は、まず×2をして、すべての数を1つ左にビットシフトさせる。
次に、次に1を置く場合は、全ての数の下1ビットを1にするので、+1することになるが、

  1. 1回数は組み合わせ数と等しいので、nxtが1の場合は組み合わせ数を足し合わせる。

 
最後にpopcountとの積の総和を取れば答え。

ll N;
string B; int NB;
mint comb[61][62][2];
mint sum[61][62][2];
//---------------------------------------------------------------------------------------------------
string to_binary(ll N, int digit = 0) {
    string res = "";
    while (N) {
        if (N & 1) res = "1" + res;
        else res = "0" + res;
        N >>= 1;
    }
    if (0 < digit) {
        int n = digit - res.length();
        rep(i, 0, n) res = "0" + res;
    }
    return res;
}
//---------------------------------------------------------------------------------------------------
void _main() {
    cin >> N;

    B = to_binary(N);
    NB = B.length();

    comb[0][0][0] = 1;
    sum[0][0][0] = 0;

    rep(dgt, 0, NB) rep(tot, 0, NB + 1) rep(isless, 0, 2) {
        int c = B[dgt] - '0';
        rep(nxt, 0, 2) {
            if (isless == 0 and c < nxt) continue;
            int isless2 = isless;
            if (nxt < c) isless2 = 1;
            comb[dgt + 1][tot + nxt][isless2] += comb[dgt][tot][isless];
            sum[dgt + 1][tot + nxt][isless2] += sum[dgt][tot][isless] * 2 + comb[dgt][tot][isless] * nxt;
        }
    }

    mint ans = 0;
    rep(tot, 0, NB + 2) rep(isless, 0, 2) ans += sum[NB][tot][isless] * tot;
    cout << ans << endl;
}