桁DPを覚えた

ABC154のE問題 で桁DPが出たのでこの機会に桁DPを覚えたいということでリサーチをかけました.

目についた記事がこちら qiita.com

問題としては$1 \leq i \leq N$の範囲の整数の中で10新数表示にしたときに3が現れるものの数を求めろという話.

おそらく$N$は$10^{100}$という感じの値になっていらっしゃるのでなかなか桁ごとに見る必要がある. ここで課題になってくるのは桁ごとに見ていく時にどうやって$N$未満になるようなものの数を数えられるかということ.

最初コードを読んでもナニコレという感じだったが, 理解した. 記事のコード中にsmallerというbool値があるが, これはsmaller/largerということではなく, smaller/sameということ. 今までに見てきたxの情報を抽象化してDPの状態に落とし込んでいる. "上から$i$桁がNの上から$i$桁と一致しているかソレより小さいかという情報"さえもてば, $N$未満かどうか判定ができる.

桁DPの高レベル記述(?)は以下のような感じ.

基本的には一番上の桁から下りながら, 1-9をそれぞれ調べていく.
上からi桁目までの結果smallerだった時は上から(i+1)はどうやってもsmaller.
上からi桁目までsmallerではないつまりsameだった場合,
     Nのi桁目より小さいものを選べば
         i+1はsmaller
     Nのi桁目と同じものを選べば
         i+1はsame

上の記事の内容を基にして, ABC154のE問題 を通すために書いたコードがこちら.

#include <bits/stdc++.h>
using namespace std;
#define int long long

// 初期条件は dp[0桁目][same][K = 0] = 1
// dp[i][smaller][k] = i桁目まで見た時にsmallerで0でないものがk個だったものの数.

// 桁 / smaller(else: same) / k, k = 4のときは0出ないものが4個以上ある状態
int dp[101][2][5];
signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    string N;
    int K;
    cin >> N >> K;
    int l = N.length();
    vector<int> n;
    for (auto c: N) {
        n.push_back(c - '0');
    }
    // 桁 smaller
    dp[0][0][0] = 1;
    for (int i = 0; i < l; i++) {
        for (int smaller = 0; smaller < 2; smaller++) {
            for (int k = 0; k <= K; k++) {
                // x: 各桁の数
                for (int x = 0; x <= (smaller ? 9 : n[i]); x++) {
                   // 前回の状態数から次の状態数を計算する.
                    int next_smaller = (int)(smaller || x < n[i]);
                    int next_k = k + (int)(x != 0);
                    dp[i + 1][next_smaller][next_k] += dp[i][smaller][k];
                }
            }
        }
    }

 // すべての桁をみたときにちょうど0出ない数がK個あるものの数.
   // smallerであろうがsameであろうがOkってこっちゃ.
    cout << dp[l][0][K] + dp[l][1][K] << endl;
    return 0;
}

桁DPという概念は知っていたがコンテスト中に頭の中を整理しきれなかったのが悔しいところ. ほどほどに競プロ知識スタック.pushしていきたい.