Read OSS

Whisper のデコーダー内部:Beam Search、Logit フィルター、KV キャッシュ

上級

前提知識

  • 記事 1〜3
  • Beam Search と温度サンプリングの理解
  • PyTorch の forward フック

Whisper のデコーダー内部:Beam Search、Logit フィルター、KV キャッシュ

826 行に及ぶ whisper/decoding.py は、コードベース最大のモジュールです。それも当然で、このファイルはエンコーダーの特徴量をトークン列へと変換する自己回帰生成を一手に担っています。Whisper の設計思想がもっとも色濃く表れているのもこのモジュールです。デコーダーには Strategy パターン、logit フィルターには Chain of Responsibility パターンと、整理された抽象化が DecodingTask という単一のオーケストレーターによって組み合わされています。

この記事では、主要なコンポーネントをひとつずつ解説します。設定用の dataclass、forward フックによる KV キャッシュの実装、トークンデコーダーの継承階層、logit フィルターのパイプライン、そしてすべてを束ねるメインループです。

DecodingOptions と DecodingResult

デコードシステムの設定は、frozen dataclass である DecodingOptions を通じて行います。サンプリング温度からタイムスタンプの挙動まで、あらゆる動作をここで制御できます。

@dataclass(frozen=True)
class DecodingOptions:
    task: str = "transcribe"
    language: Optional[str] = None
    temperature: float = 0.0
    sample_len: Optional[int] = None
    best_of: Optional[int] = None
    beam_size: Optional[int] = None
    patience: Optional[float] = None
    length_penalty: Optional[float] = None
    prompt: Optional[Union[str, List[int]]] = None
    prefix: Optional[Union[str, List[int]]] = None
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True
    without_timestamps: bool = False
    max_initial_timestamp: Optional[float] = 1.0
    fp16: bool = True

frozen=True を指定することで、デコード処理中に誤って値が書き換えられるのを防いでいます。出力側は DecodingResult に格納されます。トークン、テキスト、音声特徴量、確率のほか、compression_rationo_speech_prob といった品質指標も含まれます。

suppress_tokens のデフォルト値が文字列 "-1" であることに注目してください。このセンチネル値は、記事 3 で紹介した non_speech_tokens セットの抑制をトリガーします。音符記号やスピーカータグといったアノテーション系のトークンをモデルが生成しないようにする、便利なデフォルト設定です。

Inference 抽象と KV キャッシュ

Inference プロトコルは、キャッシュを使ったデコーダーの forward パスを実行するためのインターフェースを定義します。その唯一の実装である PyTorchInference が KV キャッシュを管理します。

def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
    if not self.kv_cache:
        self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

    if tokens.shape[-1] > self.initial_token_length:
        # only need to use the last token except in the first forward pass
        tokens = tokens[:, -1:]

    return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

KV キャッシュは Whisper の実装の中でもとりわけ巧妙な部分で、model.pyinstall_kv_cache_hooks() に実装されています。attention の forward() メソッド自体を改変するのではなく、key と value の projection レイヤーに PyTorch の forward フックを取り付けるアプローチを採っています。

sequenceDiagram
    participant D as Decoder
    participant K as Key Projection
    participant Hook as Forward Hook
    participant Cache as KV Cache Dict

    Note over D: First forward pass (all tokens)
    D->>K: key_proj(x) → full key tensor
    K->>Hook: save_to_cache(output)
    Hook->>Cache: cache[key_proj] = output

    Note over D: Subsequent passes (last token only)
    D->>K: key_proj(x[:, -1:]) → single key
    K->>Hook: save_to_cache(output)
    Hook->>Cache: cache[key_proj] = cat(cached, output)
    Hook-->>D: return concatenated keys

save_to_cache フック関数は 2 つのケースを処理します。

def save_to_cache(module, _, output):
    if module not in cache or output.shape[1] > self.dims.n_text_ctx:
        cache[module] = output                                    # first token or cross-attn
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]

self-attention の key/value については、新しいトークンの projection 結果がキャッシュ済みシーケンスに連結されていきます。一方、cross-attention の key/value はエンコーダー出力から計算され、最初の forward パスでキャッシュされた後は再利用されます。これにより、30 秒の音声ウィンドウにつきエンコーダーの実行が 1 回で済むわけです。

フックの戻り値がここでは重要です。PyTorch の forward フックは新しいテンソルを返すことでモジュールの出力を差し替えられます。このフックはキャッシュ全体を連結したテンソルを返すため、処理するトークンが 1 つだけであっても、attention レイヤーには完全な key/value シーケンスが渡されます。

ヒント: キャッシュの連結時に detach() を呼んでいるのは、キャッシュを通じた勾配の蓄積を防ぐためです。推論中のメモリ使用量を一定に抑える効果があります。

TokenDecoder 基底クラスは、トークン選択のための Strategy パターンを定義します。主要な 2 つの戦略がそれぞれ実装されています。

classDiagram
    class TokenDecoder {
        <<abstract>>
        +reset()
        +update(tokens, logits, sum_logprobs) → tokens, completed
        +finalize(tokens, sum_logprobs) → candidates
    }

    class GreedyDecoder {
        +float temperature
        +int eot
        +update(): argmax or sample
    }

    class BeamSearchDecoder {
        +int beam_size
        +float patience
        +Inference inference
        +update(): expand and prune beams
    }

    TokenDecoder <|-- GreedyDecoder
    TokenDecoder <|-- BeamSearchDecoder

GreedyDecoder はシンプルです。温度 0 では argmax を取り、それ以外では Categorical(logits=logits/temperature) からサンプリングします。対数確率を累積し、シーケンスが終了したら EOT トークンを伝播させます。

BeamSearchDecoder はより複雑です。update() メソッドは音声サンプルごとに次の手順で動作します。

  1. 各ビームについて、上位 beam_size + 1 トークンの対数確率を計算する
  2. 候補となる継続トークン(既存ビーム × 上位トークン)をすべてスコアリングする
  3. 上位 beam_size の非 EOT 候補を保持し、EOT 候補は完了済みとして収集する
  4. 新しいビームの並び順に合わせて KV キャッシュを並び替える

patience 機構(arxiv:2204.05424 より)を使うと、最初のビームが完了した後も探索を継続し、最大 round(beam_size * patience) 個の完了シーケンスを収集できます。これにより、より良い結果が得られる可能性が高まります。

365 行目rearrange_kv_cache() 呼び出しは、ビームサーチとキャッシュが交差する箇所です。ビームが並び替えられると、key/value テンソルも同じ順序に揃える必要があります。揃えなければ、ビーム 2 のキャッシュがビーム 0 の続きに使われてしまいます。

LogitFilter チェーン

logit フィルターパイプラインは Chain of Responsibility パターンを実装しています。デコーダーがトークンを選択する前に、各フィルターが logit をインプレースで変更していきます。

flowchart LR
    A["Raw Logits\nfrom decoder"] --> B["SuppressBlank\n(prevent blank/EOT\nat position 0)"]
    B --> C["SuppressTokens\n(mask non-speech\n& special tokens)"]
    C --> D["ApplyTimestampRules\n(enforce timestamp\ngrammar)"]
    D --> E["Filtered Logits\nto TokenDecoder"]

SuppressBlank は最初のサンプリング位置でのみ動作し、スペーストークンと EOT の logit を -inf に設定します。これがないと、モデルが冒頭で空白を出力したりシーケンスを即座に終了させたりする可能性があります。

SuppressTokens は指定されたトークン ID を無条件でマスクします。デフォルトの抑制リストには、非音声トークン(♪、[、( など)のすべてと、通常のデコード中に生成されるべきでない <|translate|><|startofprev|> といった制御トークンが含まれています。

ApplyTimestampRules はもっとも複雑なフィルターで、記事 3 で解説したタイムスタンプの文法を強制します。ルールは次のとおりです。

  1. ペアリング: 非タイムスタンプの後に来るタイムスタンプ(つまり終了時刻マーカー)の直後は、テキストトークンをすべて抑制し、次の出力を別のタイムスタンプ(次セグメントの開始)か EOT に強制する。
  2. ペアのタイムスタンプの後: タイムスタンプトークンをすべて抑制し、セグメント境界後はテキスト出力を強制する。
  3. 単調性: 直前のタイムスタンプより小さい値のタイムスタンプを禁止し、時間の逆行を防ぐ。
  4. 初期制約: 最初のサンプリング位置では非タイムスタンプトークンをすべて抑制し、max_initial_timestamp を適用する。
  5. 確率比較: タイムスタンプトークンの確率の合計が最大テキストトークンの確率を上回る場合、テキストトークンをすべて抑制する。これにより、セグメント境界の存在に確信があるときにタイムスタンプを出力するかどうかをモデル自身が「選択」できる。

このフィルターのおかげで、制約のない文章生成器が構造化されたセグメント生成器へと変わります。

DecodingTask:組み立てとメインループ

DecodingTask はオーケストレーターです。__init__ では DecodingOptions に基づいてすべてのコンポーネントを組み立てます。

# inference: forward pass with KV caching
self.inference = PyTorchInference(model, len(self.initial_tokens))

# sequence ranker: how to rank groups of samples
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

# decoder: how to select next tokens
if options.beam_size is not None:
    self.decoder = BeamSearchDecoder(...)
else:
    self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

# logit filters: rules to suppress/penalize tokens
self.logit_filters = []
if self.options.suppress_blank:
    self.logit_filters.append(SuppressBlank(...))
# ... etc

_main_loop() は自己回帰の中核サイクルを実装しています。

for i in range(self.sample_len):
    logits = self.inference.logits(tokens, audio_features)

    if i == 0 and self.tokenizer.no_speech is not None:
        probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
        no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

    logits = logits[:, -1]

    for logit_filter in self.logit_filters:
        logit_filter.apply(logits, tokens)

    tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

    if completed or tokens.shape[-1] > self.n_ctx:
        break

i == 0 のタイミングで無音確率を取り出していることに注目してください。最後の位置ではなく、SOT 位置の logit から <|nospeech|> トークンの softmax 確率を読み取っています。この処理はループ内で 1 度だけ行われ、トランスクリプションループに対して音声区間検出 (VAD) のシグナルを提供します。

run() メソッドはパイプライン全体をつなぎます。音声のエンコード → 言語検出 → beam/best-of のグループ化のためのトークン複製 → メインループの実行 → 終端処理 → ランキング → 結果の構築という流れで進みます。

言語検出

detect_language() 関数は、1 トークンだけを使う巧妙なトリックです。入力トークンとして <|startoftranscript|> のみを渡し、デコーダーを 1 ステップだけ動かして、出力 logit から全言語トークンの確率分布を読み取ります。

x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)
logits = model.logits(x, mel)[:, 0]

mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)

SOT トークンの直後、モデルは言語トークンを期待します。言語トークン以外をすべて -inf でマスクしてから argmax を取ることで、モデルが推定する話されている言語が得られます。softmax を取れば、サポートされているすべての言語にわたる完全な確率分布が得られます。

この処理はメインのデコードループの外で意図的に実行されます。KV キャッシュを汚染しないようにするためで、独立したクリーンな forward パスとして完結します。

ヒント: バッチ化した mel テンソルを渡せば、複数の音声セグメントの言語検出を一括で実行できます。この関数は単一入力(ndim == 2)とバッチ入力(ndim == 3)の両方に対応しています。

次回予告

デコードシステム全体を、設定からキャッシュ、トークン選択まで一通り見てきました。記事 5 では、このデコーダーをラップするトランスクリプションループへとズームアウトします。30 秒のウィンドウをスライドさせながら長尺の音声ファイルを処理し、温度フォールバックで異常な出力から回復し、幻覚を検出する仕組みを解説します。