Read OSS

デコードループ:バッチ処理、KVキャッシュ、メモリ管理

上級

前提知識

  • 第1〜3回の記事
  • トランスフォーマー推論におけるKVキャッシュの理解

デコードループ:バッチ処理、KVキャッシュ、メモリ管理

第1〜3回では、libllama がモデル定義から GGML の計算グラフを構築し、ハードウェアバックエンド上で実行するアーキテクチャを解説しました。しかし、システムの中でも最も複雑な部分(バッチ管理、KVキャッシュのスロット確保、マイクロバッチ処理、そして障害からのリカバリを担うデコードパイプライン)については、まだ詳しく触れていませんでした。

この記事では、llama_decode() から logit の抽出に至るまでの処理の流れを、コンテキスト構築、バッチ分割パイプライン、process_ubatch() のコアループ、ポリモーフィックなメモリ抽象化、KVキャッシュの内部構造、そしてエラーリカバリを軸に追っていきます。

コンテキストの構築

llama_init_from_model() を呼び出すと、llama_context のコンストラクタはいくつかの処理を実行します。

  1. パラメータのマージ — ユーザーが指定した llama_context_params とモデルのハイパーパラメータを統合する。たとえば n_ctx が 0 の場合、モデルの学習時コンテキスト長がデフォルト値となる。

  2. メモリの作成model.create_memory() を呼び出し、適切なメモリ実装(KVキャッシュ、リカレント状態、またはハイブリッド)を確保する。どの実装が使われるかはモデルアーキテクチャで決まる。

  3. バッチアロケータの作成 — バッチのサニタイズと分割を担う llama_batch_allocr を初期化する。

  4. コンピュートバッファの確保memory->init_full() を使って最悪ケースのグラフを構築し、バックエンドスケジューラで実行して必要な最大コンピュートバッファサイズを算出する。

推論に最も大きな影響を与えるコンテキストパラメータは以下のとおりです。

パラメータ 効果
n_ctx コンテキストウィンドウの総サイズ(KVキャッシュが保持できるトークン数)
n_batch 最大論理バッチサイズ(ユーザーが送信するバッチ)
n_ubatch 最大物理バッチサイズ(グラフ実行ごとに GPU に渡されるバッチ)
n_seq_max 最大同時シーケンス数
type_k, type_v KVキャッシュのデータ型(F16、Q8_0、Q4_0 など)
flash_attn_type flash attention カーネルを使用するかどうか
flowchart TD
    PARAMS["llama_context_params"] --> CTX["llama_context constructor"]
    CTX --> MERGE["Merge with hparams"]
    CTX --> MEM["model.create_memory()"]
    CTX --> BALLOC["Create batch_allocr"]
    CTX --> SCHED["Reserve compute buffers\n(worst-case graph)"]
    MEM --> KV["llama_kv_cache"]
    MEM --> REC["llama_memory_recurrent"]
    MEM --> HYB["llama_memory_hybrid"]

バッチパイプライン

計算が始まる前に、ユーザーが送信した llama_batch はサニタイズ・バリデーションされ、マイクロバッチに分割されます。このパイプラインは3つのステージで構成されています。

ステージ1:サニタイズ。 llama_batch_allocr はユーザーから渡された生のバッチを受け取り、欠落フィールドを自動補完します。pos が NULL の場合、メモリモジュール内のシーケンスの連続性をもとに位置が自動生成されます。seq_id が NULL の場合、すべてのトークンはシーケンス 0 に割り当てられます。logits が NULL の場合、最後のトークンだけが出力されます。

ステージ2:メモリの初期化。 メモリモジュールの init_batch() が呼び出されます。ここでバッチは n_ubatch に収まる ubatch に分割され、KVキャッシュ(またはリカレント状態)に空きがあるか検証されます。失敗した場合、コンテキストはデフラグを試みてリトライすることがあります。

ステージ3:ubatch のイテレーション。 デコードループは llama_memory_context_i インターフェース経由で ubatch を順に処理します。next() で次へ進み、get_ubatch() で各 ubatch を取り出します。

llama_ubatch 構造体は、デコードパイプライン全体で使われる内部バッチ表現です。

struct llama_ubatch {
    uint32_t n_tokens;      // total tokens
    uint32_t n_seq_tokens;  // tokens per sequence set
    uint32_t n_seqs;        // sequence sets
    uint32_t n_seqs_unq;    // unique sequences
    
    llama_token  * token;   // [n_tokens]
    llama_pos    * pos;     // [n_tokens * n_pos]
    llama_seq_id ** seq_id; // [n_tokens]
    int8_t       * output;  // [n_tokens] which positions need logits
};

ヒント: 大きなバッチサイズを扱う場合、n_batchn_ubatch の違いは非常に重要です。ユーザーが一度に 512 トークンを送信した場合(n_batch)でも、GPU が一度に処理できるのは 128 トークンだけ(n_ubatch)ということがあります。バッチパイプラインはこの差を透過的に吸収し、処理を4つの ubatch に分割します。

process_ubatch:コアループ

デコードの中心となるループは llama_context::decode() に実装されています。バッチの準備とメモリの初期化が完了すると、do/while ループに入ります。

sequenceDiagram
    participant D as decode()
    participant M as memory_context
    participant P as process_ubatch()
    participant G as model.build_graph()
    participant S as Backend Scheduler

    D->>M: init_batch(balloc, n_ubatch)
    loop For each ubatch
        D->>M: get_ubatch()
        D->>P: process_ubatch(ubatch, mctx)
        P->>M: apply() — commit to memory
        P->>P: Check graph reuse
        alt Graph can be reused
            P->>P: set_inputs() only
        else Build new graph
            P->>G: build_graph(params)
            P->>S: alloc_graph(gf)
        end
        P->>P: set_inputs(ubatch)
        P->>S: graph_compute(gf)
        P-->>D: logits
        D->>D: Extract logits/embeddings
        D->>M: next() — advance to next ubatch
    end

process_ubatch() メソッドは、グラフ構築と実行が交差する場所です。

  1. メモリ状態の適用mctx->apply() が現在の ubatch に対する KVキャッシュのスロット割り当てをメモリモジュールにコミットします。メモリが変更されるのはこの時点だけです。

  2. グラフの再利用チェック — 前回のグラフと同じトポロジー(同じ ubatch の形状と設定)であれば、グラフ構築とスケジューラへの割り当てを完全にスキップします。

  3. グラフの構築または再利用model.build_graph() で新しいグラフを構築してスケジューラに割り当てるか、前回のグラフをそのまま再利用します。

  4. 入力のセット — ubatch のデータをもとに、すべての入力テンソル(トークン ID、位置、マスク、KV インデックス)を埋めます。

  5. 計算の実行ggml_backend_sched_graph_compute() を通じてグラフを実行します。

メモリ抽象化の階層

第1回で見たように、推論を制御する3つの重要な型があります。しかし、もう一つ重要なプレイヤーがいます——メモリインターフェースです。llama_memory_i はすべてのメモリ実装の抽象基底クラスです。

struct llama_memory_i {
    virtual llama_memory_context_ptr init_batch(...) = 0;
    virtual llama_memory_context_ptr init_full() = 0;
    virtual llama_memory_context_ptr init_update(...) = 0;
    virtual bool get_can_shift() const = 0;
    virtual void seq_rm(llama_seq_id, llama_pos, llama_pos) = 0;
    // ...
};

llama_modelcreate_memory() ファクトリが、適切な実装を選択します。

classDiagram
    class llama_memory_i {
        <<interface>>
        +init_batch()
        +init_full()
        +init_update()
        +seq_rm()
    }
    class llama_kv_cache {
        Standard transformer KV cache
        Cell-based slot allocation
    }
    class llama_kv_cache_iswa {
        Interleaved sliding window
        Two sub-caches (base + SWA)
    }
    class llama_memory_recurrent {
        SSM/RWKV state buffers
        Fixed slots per sequence
    }
    class llama_memory_hybrid {
        KV cache + recurrent state
        Per-layer type selection
    }
    class llama_memory_hybrid_iswa {
        iSWA cache + recurrent state
    }
    llama_memory_i <|-- llama_kv_cache
    llama_memory_i <|-- llama_kv_cache_iswa
    llama_memory_i <|-- llama_memory_recurrent
    llama_memory_i <|-- llama_memory_hybrid
    llama_memory_i <|-- llama_memory_hybrid_iswa

選択ロジックはシンプルです。

  • キャッシュなしのモデル(BERT、DREAM など) → nullptr
  • 純粋なリカレントモデル(Mamba、RWKV) → llama_memory_recurrent
  • ハイブリッドモデル(Jamba、Falcon-H1) → llama_memory_hybrid または llama_memory_hybrid_iswa
  • スライディングウィンドウアテンションを使うモデル → llama_kv_cache_iswa
  • 標準的なトランスフォーマー → llama_kv_cache

KVキャッシュの詳細

llama_kv_cache は最もよく使われるメモリ実装です。KVペアを「セル」単位で管理する固定サイズのバッファを持ち、1つのセルが1つの位置に対応します。

llama_kv_cells データ構造は、各セルのメタデータを追跡します。

  • pos[] — このセルに格納されているシーケンス位置(空の場合は -1)
  • seq[] — このセルを使用しているシーケンスのビットセット(セルは共有可能)
  • shift[] — コンテキストシフト用の保留中の位置シフト
  • used — 使用中のセルインデックスの集合
  • seq_pos[] — シーケンスごとの「位置 → セルインデックス」マップ

init_batch() が呼び出されると、KVキャッシュは新しいトークンのために連続した(あるいは少なくとも利用可能な)スロットを探す必要があります。各トークンの KVペアを格納するセルのインデックスを含む slot_info 構造体が作成され、そのインデックスは llm_graph_input_attn_kvself_k_idxs テンソルと self_v_idxs テンソルとして計算グラフに渡されます。これらのテンソルが ggml_set_rows 演算に対して、新しい K/V の値をどこに書き込むかを指示します。

デフラグは memory_update(optimize=true) が呼び出されたときに実行されます。セルを移動してギャップを埋めることで KVキャッシュをコンパクト化し、専用のデフラググラフをバックエンドスケジューラ経由で構築・実行します。init_batch()LLAMA_MEMORY_STATUS_FAILED_PREPARE で失敗した場合、デコードループは自動的にデフラグを試みます。

flowchart LR
    subgraph "KV Cache Cells"
        C0["Cell 0\nSeq 0, Pos 0"]
        C1["Cell 1\nSeq 0, Pos 1"]
        C2["Cell 2\n(empty)"]
        C3["Cell 3\nSeq 1, Pos 0"]
        C4["Cell 4\nSeq 0, Pos 2"]
        C5["Cell 5\n(empty)"]
        C6["Cell 6\nSeq 1, Pos 1"]
    end
    
    subgraph "After Defrag"
        D0["Cell 0\nSeq 0, Pos 0"]
        D1["Cell 1\nSeq 0, Pos 1"]
        D2["Cell 2\nSeq 0, Pos 2"]
        D3["Cell 3\nSeq 1, Pos 0"]
        D4["Cell 4\nSeq 1, Pos 1"]
        D5["Cell 5\n(empty)"]
        D6["Cell 6\n(empty)"]
    end

エラーリカバリとステートロールバック

デコードパイプラインは障害を適切に処理できるよう設計されています。llama_memory_context_i インターフェースは apply/commit プロトコルを強制します。

struct llama_memory_context_i {
    virtual bool next() = 0;      // advance to next ubatch
    virtual bool apply() = 0;     // commit current ubatch to memory
    virtual const llama_ubatch & get_ubatch() const = 0;
    virtual llama_memory_status get_status() const = 0;
};

apply() メソッドはメモリ状態が変更される唯一のポイントです。しかもそれは、グラフ計算のprocess_ubatch() の中で実行されます。グラフ計算が後で失敗した場合(アロケーションエラー、アボートコールバックなど)、デコードループはコミット済みの位置を削除することでロールバックを行います。

// from decode() error handling
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
    pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
    memory->seq_rm(s, pos_min[s], -1);
}

このアプローチは llama-context.cpp の 1690〜1711行目にあります。影響を受けた各シーケンスについて pos_min 以降の全位置を削除することで、部分的に処理されたバッチのデータが KVキャッシュに残らないようにします。

decode 関数は失敗モードに応じて異なるエラーコードを返します。

  • 1 — キャッシュが満杯(シーケンスを解放するか、バッチサイズを縮小してください)
  • -2 — アロケーションエラーまたは内部エラー
  • -3 — 計算の失敗
  • 2 — コールバックによる中断

ヒント: llama_decode() がリターンコード 1 を返す場合、KVキャッシュが満杯になっています。llama_memory_seq_rm() を呼び出して古いシーケンスを解放するか、n_predict を小さくするか、コンテキスト作成時に n_ctx を増やしてください。

次回のテーマ

これで、llama_decode() の呼び出しから KVキャッシュのセル割り当てまでの完全な流れを追いました。次回はライブラリ層からアプリケーション層へと視点を移します。HTTP サーバーと CLI ツールがどのように libllama をラップして実際のユーザーへのサービスを提供しているか、そして CLI をサーバーインフラの薄いラッパーにしているアーキテクチャ上の意外な決断について解説します。