Read OSS

単語レベルのタイムスタンプ:クロスアテンション整合、DTW、出力フォーマット

上級

前提知識

  • 第1〜5回の記事
  • Dynamic Time Warpingの基礎知識
  • アテンション機構への理解

単語レベルのタイムスタンプ:クロスアテンション整合、DTW、出力フォーマット

単語レベルのタイムスタンプは、Whisperが持つ後処理機能の中でも技術的に最も洗練されたものです。デコードされたテキストセグメントに対し、各単語がいつ発話されたかを正確に特定します。ただし、その根拠はデコーダーのタイムスタンプトークン(セグメント単位の境界を示すもの)ではありません。各テキストトークンを生成する際にデコーダーがどの音声フレームに「注目していたか」を表すクロスアテンション重みを使用します。

本記事では、このパイプライン全体を解説します。生のアテンション重みにアクセスするためのSDPA無効化、特定のアテンションヘッドからのQK行列の抽出、メディアンフィルターによる処理について説明します。さらに、Dynamic Time Warping(CPUとGPUの両実装)を用いたトークンと音声フレームの整合、句読点のマージ、字幕出力へのフォーマット変換を順を追って見ていきます。

クロスアテンション重みの抽出

タイミング処理のパイプラインは、実装全体の設計を左右するひとつの制約から始まります。PyTorchのscaled_dot_product_attention(SDPA)は融合カーネルであり、アテンション重みを返しません。SDPAが有効な状態ではqkNoneになります(model.py 123〜128行目を参照)。QK行列を取り出すには、SDPAを無効にする必要があります。

model.pydisable_sdpa()コンテキストマネージャーがこれを担います。

@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state

MultiHeadAttentionのクラス変数を切り替えることで、すべてのアテンション層を融合パスから、アテンション重みを返す明示的なQK計算パスに切り替えます。

find_alignment()の内部では、すべてのクロスアテンション層にフォワードフックを設置してQK行列をキャプチャします。

QKs = [None] * model.dims.n_text_layer
hooks = [
    block.cross_attn.register_forward_hook(
        lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
    )
    for i, block in enumerate(model.decoder.blocks)
]
sequenceDiagram
    participant FA as find_alignment()
    participant SDPA as disable_sdpa()
    participant Model as Whisper Model
    participant Hooks as Cross-Attn Hooks
    participant QKs as QK Matrix List

    FA->>SDPA: Enter context (disable SDPA)
    FA->>Hooks: Install forward hooks on cross_attn layers
    FA->>Model: Forward pass (mel, tokens)
    Model->>Hooks: Each cross-attn fires hook
    Hooks->>QKs: Store QK matrix [heads × tokens × frames]
    FA->>Hooks: Remove all hooks
    FA->>SDPA: Exit context (restore SDPA)

すべてのアテンションヘッドが整合に等しく有効なわけではありません。alignment_headsマスク(第1回で触れたbase85エンコードの_ALIGNMENT_HEADSデータから設定される)は、単語タイミングとの相関が高い特定のヘッドを識別します。DTW整合にはこれらのヘッドのみが使用されます。

weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])

アテンション重みの処理

生のQK重みは、DTW整合の前にいくつかの処理ステップを経ます。

flowchart LR
    A["Raw QK weights\n[heads × tokens × frames]"] -->|"× qk_scale"| B["Scaled"]
    B -->|"softmax(dim=-1)"| C["Normalized"]
    C -->|"standardize\n(mean/std across tokens)"| D["Standardized"]
    D -->|"median_filter\n(width=7)"| E["Smoothed"]
    E -->|"mean across heads"| F["Final matrix\n[tokens × frames]"]

標準化ステップ(トークン次元にわたる平均を引き、標準偏差で割る)は、各フレームのアテンション分布を正規化し、整合の安定性を高めます。メディアンフィルターは、アテンション重みに含まれるスパイク状のノイズを除去します。

median_filter()はまずTriton GPU実装を試み、失敗した場合はtorch.sort()を用いたCPU版にフォールバックします。

if x.is_cuda:
    try:
        from .triton_ops import median_filter_cuda
        result = median_filter_cuda(x, filter_width)
    except (RuntimeError, subprocess.CalledProcessError):
        warnings.warn("Failed to launch Triton kernels...")

if result is None:
    result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

CPUフォールバックはPyTorchのある特性を活用しています。このユースケースではsort()torch.median()より高速なのです(pytorch/pytorch#51450)。

Dynamic Time Warping:CPUとGPU

DTWは、トークン列と音声フレーム列の間の最適な単調整合を求めるアルゴリズムです。概念的には、cost[i,j]が最初のiトークンと最初のjフレームを整合させるための最小コストを表すコスト行列を構築し、そこから最適パスをバックトレースします。

dtw_cpu()はNumba JITコンパイルを使って高速化されています。

@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]  # diagonal
            c1 = cost[i - 1, j]      # vertical
            c2 = cost[i, j - 1]      # horizontal
            # ... pick minimum, record trace

GPU実装のdtw_cuda()は、Tritonカーネルによる反対角ウェーブフロント並列化を採用しています。鍵となるアイデアは、同じ反対角線上(i + j = k)にあるセルはすべて互いに独立しているため、並列計算が可能だという点です。dtw_kernelは反対角線k = 1からN + Mまで順に処理します。

for k in range(1, N + M + 1):
    tl.debug_barrier()
    # Load costs from three predecessors
    c0 = tl.load(p0 + offsets, mask=mask)  # diagonal
    c1 = tl.load(p1 + offsets, mask=mask)  # from above
    c2 = tl.load(p2 + offsets, mask=mask)  # from left
    # Compute minimum and store
    cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

カーネル実行前に行列を「スキュー」させます(F.padとreshapeを使用)。これにより反対角線が行として並ぶため、連続したメモリアクセスが可能になります。

dtw()ディスパッチャーはGPUを優先し、失敗時はCPUにフォールバックします。

flowchart TD
    A["dtw(cost_matrix)"] --> B{"CUDA\navailable?"}
    B -->|Yes| C["dtw_cuda()\nTriton kernel"]
    C -->|"RuntimeError"| D["dtw_cpu()\nNumba JIT"]
    B -->|No| D
    C --> E["backtrace()\nExtract optimal path"]
    D --> E

Tritonメディアンフィルターのコード生成トリック

Whisperのコードベース全体の中で最も変わったパターンが、triton_ops.pymedian_kernel()に存在します。汎用的なメディアンフィルターカーネルを書くのではなく、フィルター幅ごとに特化したカーネルのソースコードを実行時に動的生成するアプローチをとっています。

テンプレートカーネルにはプレースホルダーの識別子が含まれています。

LOAD_ALL_ROWS_HERE    # replaced with N load statements
BUBBLESORT_HERE       # replaced with unrolled bubblesort comparisons
MIDDLE_ROW_HERE       # replaced with the median variable name

median_kernel()はこれらのプレースホルダーを文字列置換することで実際のカーネルを構築します。filter_width=7の場合、7つのload文、約28回のcompare-and-swap操作(展開されたバブルソート)が生成され、中央値としてrow3が格納されます。

new_kernel = kernel.src.replace(
    "    LOAD_ALL_ROWS_HERE",
    "\n".join([
        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    ]),
)

生成されるバブルソートのパス数は、中央要素の正確さを保証するのに必要な最小限に抑えられています。具体的にはfilter_width // 2 + 1回のネストされたcompare-and-swapです。

f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f"    row{j} = smaller",
f"    row{j + 1} = larger",

その後、kernel._unsafe_update_src(new_kernel)(古いTritonバージョンでは直接kernel.srcへ代入)でJIT関数のソースを書き換え、ハッシュをクリアして強制的に再コンパイルを起こします。@lru_cacheにより、このコード生成はフィルター幅ごとに一度しか実行されません。

これはTritonの制約に動機づけられた段階的メタプログラミングの一例です。Tritonカーネルはある次元において静的な形状を要求するため、動的なフィルター幅に対応するには特化カーネルを生成するしかないのです。

注意: _unsafe_update_srcというメソッド名が、このアプローチのサポート状況を雄弁に物語っています。動作はしますが、Tritonの内部深くに依存しています。TritonのAPIが変わった際に最初に壊れるのは、ここでしょう。

単語境界の割り当てとヒューリスティック

DTWがトークンとフレームの整合を出力したあと、find_alignment()はそれを単語境界へと変換します。

  1. tokenizer.split_to_word_tokens()でトークンを単語に分割する(第3回で解説した言語対応の処理)
  2. 累積トークン数から単語境界を計算する
  3. 単語境界位置のDTWパスから開始・終了時刻を取り出す
  4. 各単語内のトークン確率の平均から単語確率を計算する

続いてmerge_punctuations()が、句読点を隣接する単語に付加するよう単語境界を調整します。開き記号("'([{-)は直後の単語の前に付き、閉じ記号(.,!?)]など)は直前の単語の後に付きます。

さらにadd_word_timestamps()が、追加の境界補正ヒューリスティックを適用します。

  • 最大継続時間の上限設定: 文末記号(.!?)に続く単語の継続時間を、中央値の2倍に制限します。これにより、ピリオドが文の後の長い無音区間を「吸収」してしまうのを防ぎます。
  • セグメント境界への整合: 最初の単語の開始時刻がセグメントの開始時刻よりも大幅に早い場合、セグメントレベルのタイムスタンプを優先します。最後の単語の終了時刻についても同様です。
  • ポーズ後の補正: 長いポーズ(中央値の4倍超)の後に続く最初の単語の開始時刻を、ポーズ全体にまたがらないようクランプします。

これらのヒューリスティックはコードのコメントでも「ハック」と自認されており、VADベースのセグメンテーションアルゴリズムへの置き換えが将来的な課題として記されています。それでも実際には効果的で、DTW整合で起こりやすい代表的なアーティファクトをうまく処理します。

flowchart TD
    A["find_alignment()\nDTW → raw word timings"] --> B["merge_punctuations()\nAttach . , ! ? to words"]
    B --> C["Compute median/max\nword duration"]
    C --> D["Truncate long words\nat sentence boundaries"]
    D --> E["Align word starts/ends\nwith segment boundaries"]
    E --> F["Post-pause correction\n(first word after silence)"]
    F --> G["Final word timestamps\nper segment"]

出力ライターと字幕フォーマット

whisper/utils.pyの出力システムは、ResultWriterを頂点とするクラス階層で構成されています。

classDiagram
    class ResultWriter {
        +str extension
        +str output_dir
        +__call__(result, audio_path)
        +write_result(result, file)*
    }

    class WriteTXT {
        +extension = "txt"
    }

    class WriteJSON {
        +extension = "json"
    }

    class WriteTSV {
        +extension = "tsv"
    }

    class SubtitlesWriter {
        +bool always_include_hours
        +str decimal_marker
        +iterate_result() → yields (start, end, text)
    }

    class WriteVTT {
        +extension = "vtt"
    }

    class WriteSRT {
        +extension = "srt"
    }

    ResultWriter <|-- WriteTXT
    ResultWriter <|-- WriteJSON
    ResultWriter <|-- WriteTSV
    ResultWriter <|-- SubtitlesWriter
    SubtitlesWriter <|-- WriteVTT
    SubtitlesWriter <|-- WriteSRT

WriteTXTWriteJSONはシンプルな実装です。WriteTSVはロケール依存の小数点区切り文字を避けるため、時刻をミリ秒の整数で扱います。ドキュメントコメントにも記されている細かい配慮です。

興味深いのはSubtitlesWriter.iterate_result()の実装です。highlight_words=Trueの場合、各単語が独自の字幕キューを持ち、アクティブな単語が<u>タグで囲まれます。

00:00.640 --> 00:00.920
Hello <u>world</u> today

このメソッドはmax_line_widthmax_line_countmax_words_per_lineの制約に従った改行処理も担い、自然な単語のまとまりを保ちながら字幕ブロックを生成します。

get_writer()ファクトリー関数はフォーマットの選択を担います。"all"オプションを指定すると、5つのライターすべてをインスタンス化し、それぞれを順に呼び出すコンポジット関数を返します。Compositeパターンのシンプルな実装です。

if output_format == "all":
    all_writers = [writer(output_dir) for writer in writers.values()]
    def write_all(result, file, options=None, **kwargs):
        for writer in all_writers:
            writer(result, file, options, **kwargs)
    return write_all

シリーズのまとめ

6回にわたる本シリーズでは、Whisperの約2,500行にわたるコードの主要なパスをすべて追ってきました。

  1. アーキテクチャ概観: 9モジュール構成のパッケージ、エントリーポイント、メソッドバインディングパターン
  2. 音声フロントエンド: FFmpegによる読み込み、STFT、メルスペクトログラム、畳み込みエンコーダーのステム
  3. トークナイザー: デュアル語彙、1501個のタイムスタンプトークン、SOTシーケンスプロトコル
  4. デコーディング: フックによるKVキャッシュ、ビームサーチ、ロジットフィルターチェーン、メインループ
  5. 文字起こし: スライディングウィンドウ、温度フォールバック、無音検出、幻覚からの回復
  6. 単語タイムスタンプ: クロスアテンション抽出、DTW整合、実行時コード生成、出力フォーマット

Whisperが優れているのは、ある1つの技術が突出しているからではありません。コードベースを小さく保ちながらも高い能力を実現する、エンジニアリング上のセンスにあります。不要な抽象化もなく、設定フレームワークもなく、プラグインシステムもない。音声からテキストへ至る明快なパイプラインが、いくつかの的確なデザインパターンと、堅牢な音声認識に本当に重要な細部への深い注意によって実装されているだけです。