生成エンジンの内側:model.generate() がテキストを生成する仕組み
前提知識
- ›第3回:モデルの構造(フォワードパス、アテンション)
- ›第4回:重みの読み込み(推論可能な状態のモデル)
- ›自己回帰 Transformer における KV キャッシュの概念
- ›サンプリングの基礎:temperature、top-k、top-p
生成エンジンの内側:model.generate() がテキストを生成する仕組み
第4回でモデルの読み込みを終えたところで、今回はユーザー向け API の中で最もパフォーマンスに直結する model.generate() を深く掘り下げます。この単一メソッドは約1700行からなるオーケストレーターで、greedy decoding、multinomial サンプリング、beam search、投機的(アシスト)デコーディング、トークンストリーミングなど、多岐にわたる処理を担っています。composable な logits 処理パイプライン、複数の cache バックエンド、そして停止条件の管理も行いつつ、本番環境のスループット向上のために torch.compile との互換性も維持しています。
この記事では、GenerationConfig の検証から自己回帰ループの本体まで、生成フロー全体を順を追って解説します。特に KV キャッシュシステムと、2〜3倍の高速化をもたらすアシストデコーディングの仕組みに焦点を当てます。
GenerationConfig とモード選択
generate() の呼び出しはすべて GenerationConfig から始まります。このクラスにはデコーディングに関するすべてのパラメータが集約されています。max_new_tokens、temperature、top_k、top_p、num_beams、do_sample、repetition_penalty など、数多くの設定値を保持しています。
モードの選択は、以下の優先順位に従って決定されます。
flowchart TD
A["generate() called"] --> B["Resolve GenerationConfig"]
B --> C{"assistant_model<br/>provided?"}
C -->|Yes| D["Assisted/Speculative<br/>Decoding"]
C -->|No| E{"num_beams > 1?"}
E -->|Yes| F{"do_sample?"}
F -->|Yes| G["Beam Search<br/>Multinomial Sampling"]
F -->|No| H["Beam Search"]
E -->|No| I{"do_sample?"}
I -->|Yes| J["Multinomial<br/>Sampling"]
I -->|No| K["Greedy<br/>Decoding"]
GenerationMixin の generate() メソッドはまず、受け取った kwargs を generation config にマージし、パラメータの組み合わせを検証します(例:temperature=0 は greedy を意味します)。その後、適切な内部メソッドへ処理を委譲します。
GenerationConfig は custom_generate もサポートしています。これは Hub のリポジトリ名を表す文字列または callable を受け取り、生成ループ全体を丸ごと置き換えることができます。この拡張ポイントにより、研究者はライブラリをフォークすることなく、独自のデコーディングアルゴリズムを試せるようになっています。
ヒント:
generate()に kwargs を直接渡すのではなく、generation_config=GenerationConfig(...)として明示的に渡しましょう。config のマージ処理が省かれるだけでなく、デコーディングパラメータを呼び出し間で再利用しやすくなります。
KV キャッシュシステム
自己回帰生成では、KV プロジェクションをキャッシュしない限り、ステップごとにすべての過去トークンのアテンションを再計算することになります。Transformers はこれを解決するため、階層的な cache システムを実装しています。
classDiagram
class CacheLayerMixin {
<<abstract>>
+keys: Tensor
+values: Tensor
+is_initialized: bool
+update(key_states, value_states)
+get_seq_length() int
+offload()
+prefetch()
}
class Cache {
+layers: list[CacheLayerMixin]
+layer_class_to_replicate: type
+offloading: bool
+update(key, value, layer_idx)
+get_seq_length()
}
class DynamicCache {
«grows as needed»
}
class StaticCache {
«fixed size, compile-friendly»
}
class QuantizedCache {
«INT8 keys/values»
}
CacheLayerMixin --o Cache : layers
Cache <|-- DynamicCache
Cache <|-- StaticCache
Cache <|-- QuantizedCache
CacheLayerMixin はレイヤーごとの抽象化を担うクラスです。不要なタイミングで KV ペアを CPU へオフロードしたり、パイプライン実行に備えて事前にプリフェッチしたりする機能をサポートしています。
Cache 基底クラスは、モデルの各レイヤーに対応する CacheLayerMixin オブジェクトのコンテナです。事前に確保されたレイヤーから構築することも、layer_class_to_replicate を通じて遅延的に拡張することもできます。
主な実装は3種類あります。
DynamicCache— テンソルを動的に拡張していく実装です。柔軟ですが動的な形状を持つため、torch.compileとの相性はよくありません。StaticCache— 固定サイズのバッファを事前に確保します。torch.compileに対応しており、CUDA graph のキャプチャにも必要です。QuantizedCache— キーとバリューを INT8 で保存し、メモリ使用量を約50%削減します。精度にはわずかなトレードオフがあります。
第3回で見たように、LlamaModel.forward() は use_cache=True のとき、デフォルトで DynamicCache を生成します。この cache はすべての decoder レイヤーのアテンション処理を流れながら、KV ペアを蓄積していきます。
Logits 処理パイプライン
モデルが出力した生の logits から最終的なトークン選択に至るまでの間に、composable な LogitsProcessor オブジェクトのパイプラインがスコアを変換していきます。
flowchart LR
A["Raw logits<br/>[batch, vocab]"] --> B["TemperatureLogitsWarper"]
B --> C["TopKLogitsWarper"]
C --> D["TopPLogitsWarper"]
D --> E["RepetitionPenaltyProcessor"]
E --> F["NoBadWordsLogitsProcessor"]
F --> G["Final scores"]
G --> H["torch.multinomial<br/>or argmax"]
LogitsProcessorList は list のサブクラスで、__call__ によってプロセッサを順番に呼び出します。各プロセッサは (input_ids, scores) を受け取り、変換後のスコアを返します。パイプラインは GenerationConfig のパラメータから次のように組み立てられます。
temperature→TemperatureLogitsWarpertop_k→TopKLogitsWarpertop_p→TopPLogitsWarperrepetition_penalty→RepetitionPenaltyLogitsProcessorno_repeat_ngram_size→NoRepeatNGramLogitsProcessor
generate() の logits_processor 引数にカスタムプロセッサを渡すと、組み込みのプロセッサの後ろに追加されます。
ヒント: プロセッサの順序は重要です。最初に temperature スケーリング(分布全体に影響します)を行い、続いてフィルタリング(top-k/top-p)、そしてペナルティという順番が正しい並びです。Transformers はデフォルトで正しい順序に組み立てますが、カスタムプロセッサは末尾に追加されることに注意してください。
投機的デコーディング / アシストデコーディング
CandidateGenerator システムは投機的デコーディングを実装しています。小型で高速なモデルが候補トークンを先に生成し、メインモデルがそれを並列で検証するという仕組みです。
sequenceDiagram
participant Main as Main Model (70B)
participant Draft as Draft Model (1B)
participant Verify as Verification
Note over Draft: Generate K candidate tokens
Draft->>Draft: token_1, token_2, ..., token_K
Draft-->>Main: Candidate sequence
Main->>Main: Run forward pass on<br/>all K+1 positions simultaneously
Main-->>Verify: Logits for each position
Verify->>Verify: Compare draft vs main<br/>Accept matching tokens
Verify-->>Main: Accepted: token_1..token_j<br/>Rejected from token_j+1
Note over Main: Only 1 forward pass<br/>for up to K tokens!
AssistedCandidateGenerator が draft モデルを動かして候補を生成します。メインモデルはそれらをまとめて1回のフォワードパスで検証します(アテンションは causal であるため、すべての位置を同時にチェックできます)。受理されたトークンはそのまま採用され、最初に棄却されたトークンの位置からはメインモデルの分布を使って再サンプリングします。
これにより 2〜3倍のスループット向上が実現します。draft モデルのフォワードパスはコストが低く、メインモデルの検証はバッチ処理されるためです。重要な特性として、棄却サンプリングの仕組みにより出力の分布は通常のサンプリングと数学的に等価です。
Transformers は AssistedCandidateGeneratorDifferentTokenizers(341行目)もサポートしており、draft モデルとメインモデルが異なるトークナイザーを使っている場合のトークンレベルのアライメントにも対応しています。
ストリーミングと停止条件
インタラクティブなアプリケーションでは、生成が完了するまで待ち続けることは現実的ではありません。BaseStreamer インターフェースによって、生成されたトークンをリアルタイムに出力できます。
class BaseStreamer:
def put(self, value): ... # Called with new token IDs
def end(self): ... # Called when generation finishes
TextStreamer はトークンをデコードしてリアルタイムに stdout へ出力します。TextIteratorStreamer はトークンをキューに入れて非同期に消費できるようにします(Web サーバーとの組み合わせに最適です)。AsyncTextIteratorStreamer はさらに async イテレーション対応を加えたものです。
生成の終了は StoppingCriteria によって制御されます。
flowchart TD
A["After each generation step"] --> B["StoppingCriteriaList"]
B --> C["MaxLengthCriteria"]
B --> D["MaxTimeCriteria"]
B --> E["EosTokenCriteria"]
B --> F["StopStringCriteria"]
B --> G["Custom criteria"]
C --> H{"Any criterion<br/>returns True?"}
D --> H
E --> H
F --> H
G --> H
H -->|Yes| I["Stop generation"]
H -->|No| J["Continue"]
StoppingCriteriaList はトークンが生成されるたびにすべての条件をチェックします。組み込みの条件としては、最大長、最大実行時間、EOS トークンの検出、停止文字列のマッチングがあります。カスタム条件では、生成済みのシーケンス全体とスコアを参照することもできます。
メインループ
ここまでの要素をまとめると、自己回帰ループの核心部分(簡略化)は次のようになります。
while not stopping_criteria(input_ids, scores):
# 1. Run model forward pass
outputs = model(input_ids, past_key_values=cache, ...)
# 2. Extract next-token logits
next_token_logits = outputs.logits[:, -1, :]
# 3. Process logits
next_token_scores = logits_processor(input_ids, next_token_logits)
# 4. Select next token
if do_sample:
next_tokens = torch.multinomial(probs, num_samples=1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# 5. Update input_ids and stream
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
if streamer is not None:
streamer.put(next_tokens)
実際のループではバッチ次元の追跡、beam の管理、分散生成のための GPU 同期、アシストデコーディングの検証処理なども行われています。しかし基本的なパターンは常に同じです。フォワード → 処理 → 選択 → 追加、この繰り返しです。
ディレクトリ構成
| ファイル | 役割 |
|---|---|
src/transformers/generation/utils.py |
GenerationMixin.generate() — 約1700行のオーケストレーター |
src/transformers/generation/configuration_utils.py |
GenerationConfig — すべてのデコーディングパラメータ |
src/transformers/generation/logits_process.py |
Composable な logits 処理パイプライン |
src/transformers/generation/candidate_generator.py |
投機的/アシストデコーディング |
src/transformers/generation/streamers.py |
リアルタイムトークンストリーミング |
src/transformers/generation/stopping_criteria.py |
生成の終了条件 |
src/transformers/cache_utils.py |
KV キャッシュの階層構造 |
生成(推論)の仕組みはここまでです。しかし Transformers は訓練ライブラリとしての側面も同様に重要です。次回は Trainer クラスを掘り下げます。分散バックエンド、コールバック、loss 関数レジストリを備えた約4400行の訓練ループオーケストレーターです。