Read OSS

Whisperのアーキテクチャを一望する:OpenAIの音声認識コードベースを読み解く

中級

前提知識

  • Pythonの基礎的な知識
  • ニューラルネットワークの概念に対する一般的な理解
  • PyTorchのnn.Moduleの基本的な知識

Whisperのアーキテクチャを一望する:OpenAIの音声認識コードベースを読み解く

OpenAIのWhisperは、コード量と実現できる機能のバランスという観点で、オープンソースプロジェクトの中でも際立った存在です。多言語の音声認識、翻訳、言語検出、単語レベルのタイムスタンプといった機能の全体が、わずか9つのモジュールに分散した約2,500行のPythonコードに収まっています。冗長な抽象化も、プラグインアーキテクチャも、独自の設定DSLも存在しません。あるのは、よく設計されたパイプラインを備えた、シンプルなエンコーダー・デコーダー型Transformerだけです。

この記事は、Whisperコードベースのあらゆる重要なコードを解剖する全6回シリーズの第1回です。この記事を読み終えることで、プロジェクト全体のメンタルモデルが構築でき、システムに関するどんな疑問が生じても迷わず該当ファイルを開けるようになります。

プロジェクト構造とモジュールマップ

Whisperパッケージはフラットな構成を採用しており、ネストされたサブパッケージ階層も、core/utils/services/のような分散した構造も存在しません。すべてのモジュールはwhisper/ディレクトリの直下に置かれ、それぞれが単一の明確な責務を持っています。

モジュール 行数(概算) 責務
__init__.py 162 モデルレジストリ、ダウンロード、load_model() API
model.py 345 ModelDimensionsWhisper、エンコーダー/デコーダーアーキテクチャ
audio.py 157 音声読み込み、メルスペクトログラムの計算
tokenizer.py 395 Tiktokenラッパー、特殊トークン、言語サポート
decoding.py 826 自己回帰デコード、ビームサーチ、ロジットフィルター
transcribe.py 623 スライディングウィンドウ方式の文字起こしループ、CLI
timing.py 389 クロスアテンション + DTWによる単語タイムスタンプ
utils.py 318 出力ライター、フォーマット処理、圧縮率
triton_ops.py 118 GPU加速されたDTWおよびメディアンフィルターカーネル

また、事前計算済みのメルフィルターバンク(mel_filters.npz)とBPEボキャブラリー(gpt2.tiktokenmultilingual.tiktoken)を格納するassets/ディレクトリも存在します。テキスト正規化を担うnormalizers/サブパッケージやバージョン管理用のversion.pyも含まれています。

flowchart TD
    subgraph "whisper/ package"
        INIT["__init__.py\n(model registry + load)"]
        MODEL["model.py\n(Whisper nn.Module)"]
        AUDIO["audio.py\n(mel spectrograms)"]
        TOKEN["tokenizer.py\n(tiktoken wrapper)"]
        DECODE["decoding.py\n(beam search + filters)"]
        TRANS["transcribe.py\n(sliding window + CLI)"]
        TIMING["timing.py\n(word timestamps)"]
        UTILS["utils.py\n(output writers)"]
        TRITON["triton_ops.py\n(GPU kernels)"]
    end

    INIT --> MODEL
    INIT --> AUDIO
    INIT --> DECODE
    INIT --> TRANS
    TRANS --> AUDIO
    TRANS --> DECODE
    TRANS --> TIMING
    TRANS --> TOKEN
    DECODE --> TOKEN
    TIMING --> TRITON
    TIMING --> MODEL
    MODEL --> DECODE
    MODEL --> TRANS

ヒント: model.pydecoding.pytranscribe.pyの間に描かれた依存の矢印は、意図的な循環設計です。Whisperがこれをどう解決しているかは、後ほど詳しく見ていく巧妙なメソッドバインディングのパターンに秘密があります。コードベースの中でも特に興味深い設計上の判断の一つです。

エントリーポイント:CLIとPython API

Whisperには2つのエントリーポイントがあり、どちらも最終的に同じtranscribe()関数を呼び出します。

CLIのエントリーポイントpyproject.toml#L35に定義されています。

scripts.whisper = "whisper.transcribe:cli"

コマンドラインからwhisperを実行すると、transcribe.pycli()が呼び出される仕組みです。python -m whisperでも同様に起動でき、こちらはwhisper/__main__.pyにたどり着きます。このファイルはわずか2行で、cli()をインポートして呼び出すだけの内容です。

Python APIのエントリーポイントwhisper.load_model()で、whisper/__init__.pyからエクスポートされています。モデルを読み込んだ後、model.transcribe(audio)を呼び出すと、内部で同じ関数が実行されます。

flowchart LR
    A["$ whisper audio.mp3"] --> B["transcribe.py:cli()"]
    C["$ python -m whisper"] --> D["__main__.py"] --> B
    E["model.transcribe(audio)"] --> F["transcribe.py:transcribe()"]
    B --> G["load_model()"] --> F

どちらのパスも最終的にtranscribe()へと合流します。これがシステムの中核です。CLIはargparseの処理、モデルの読み込み、出力の書き込みをその周りに加えており、Python APIはこれをWhisperクラスのメソッドとして直接公開しています。

モデルレジストリとダウンロードの仕組み

Whisperには12種類のモデルバリアントが用意されており、パラメーター数3,900万のtinyから15億5,000万のlarge-v3まで揃っています。モデルレジストリは、モデル名とAzure CDNのURLを対応させたシンプルな辞書で、whisper/__init__.py#L17-L32に定義されています。

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    # ...
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}

URLのパスに注目してください。SHA256ハッシュがURLパスそのものに埋め込まれています。末尾から2番目のパスセグメントが期待されるチェックサムです。54行目_download()関数はurl.split("/")[-2]でこれを取り出し、ダウンロードしたバイト列と照合します。別途チェックサムファイルを用意する必要も、マニフェストも不要で、URLそのものが整合性の記録になっているという、エレガントな設計です。

モデルURLと並んで、36〜51行目には_ALIGNMENT_HEADSも定義されています。これはbase85エンコードされたgzip圧縮のブール配列で、単語レベルのタイミングと最も相関の高いクロスアテンションヘッドを特定するものです。model.pyset_alignment_heads()メソッドでデコードされ、単語タイムスタンプの抽出時に使用されます(詳細は第6回で解説します)。

103〜161行目load_model()は、レジストリから名前でモデルを読み込む、ローカルのファイルパスから読み込む、どちらでもない場合は分かりやすいエラーを出す、という3つのケースに対応しています。チェックポイントをデシリアライズし、保存されたdims辞書からModelDimensionsを構築してWhisperクラスをインスタンス化し、重みを読み込んでアラインメントヘッドを設定します。モデルはデフォルトで~/.cache/whisperにキャッシュされます(XDG_CACHE_HOMEの設定も尊重されます)。

WhisperクラスとModelDimensions

アーキテクチャ全体は、たった1つの10フィールドのdataclassで定義されています。それがModelDimensionsです。

@dataclass
class ModelDimensions:
    n_mels: int          # mel frequency bins (80 or 128)
    n_audio_ctx: int     # audio context length (1500)
    n_audio_state: int   # encoder hidden size
    n_audio_head: int    # encoder attention heads
    n_audio_layer: int   # encoder layers
    n_vocab: int         # vocabulary size
    n_text_ctx: int      # text context length (448)
    n_text_state: int    # decoder hidden size
    n_text_head: int     # decoder attention heads
    n_text_layer: int    # decoder layers

この10個の数値がモデル内のすべての次元を完全に規定します。252〜345行目Whisperクラスは、これらの次元をもとにAudioEncoderTextDecoderを組み合わせて構成されています。

classDiagram
    class ModelDimensions {
        +int n_mels
        +int n_audio_ctx
        +int n_audio_state
        +int n_audio_head
        +int n_audio_layer
        +int n_vocab
        +int n_text_ctx
        +int n_text_state
        +int n_text_head
        +int n_text_layer
    }

    class Whisper {
        +ModelDimensions dims
        +AudioEncoder encoder
        +TextDecoder decoder
        +embed_audio(mel)
        +logits(tokens, audio_features)
        +detect_language()
        +transcribe()
        +decode()
    }

    class AudioEncoder {
        +Conv1d conv1
        +Conv1d conv2
        +Tensor positional_embedding
        +ModuleList blocks
    }

    class TextDecoder {
        +Embedding token_embedding
        +Parameter positional_embedding
        +ModuleList blocks
    }

    Whisper --> ModelDimensions
    Whisper --> AudioEncoder
    Whisper --> TextDecoder

アーキテクチャ上で最も興味深いのは、model.pyの末尾——343〜345行目です。

detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function

これらの行は、decoding.pytranscribe.pyのモジュールレベル関数をWhisperクラスのメソッドとしてバインドしています。循環インポートを回避するためのWhisper流の工夫です。model.pydecodetranscribeを関数としてインポートし(ファイル冒頭でインポート済み)、それをクラス属性として代入しています。一方、decoding.pytranscribe.pyTYPE_CHECKING配下でのみWhisper型を参照するため、実行時に循環依存は発生しません。

model.transcribe(audio)を呼び出すと、Pythonはモデルインスタンスをselfとして渡し、それがtranscribe()関数のmodelパラメーターになります。シンプルで機能し、プラグインシステムのような複雑さを一切持ち込まない設計です。

エンドツーエンドのデータフロー概観

続く記事で詳しく掘り下げる前に、音声ファイルからテキスト出力に至るまでの全体の流れをざっと追っておきましょう。

flowchart TD
    A["Audio File\n(any format)"] -->|"ffmpeg subprocess"| B["Raw Waveform\n16kHz mono float32"]
    B -->|"torch.stft + mel filterbank"| C["Log-Mel Spectrogram\n[n_mels × frames]"]
    C -->|"pad_or_trim to 3000 frames"| D["30-Second Chunk\n[n_mels × 3000]"]
    D -->|"AudioEncoder\n(conv stem + transformer)"| E["Audio Features\n[1500 × d_model]"]
    E -->|"TextDecoder\n(autoregressive)"| F["Token Sequence\nwith timestamps"]
    F -->|"Tokenizer.decode()"| G["Text Segments\nwith timing"]
    G -->|"Output Writers"| H["TXT / SRT / VTT / JSON / TSV"]

    style A fill:#f9f,stroke:#333
    style H fill:#9f9,stroke:#333

各モジュールがどのステージに対応しているかを整理しましょう。

  1. 音声読み込みaudio.py):FFmpegがあらゆる音声フォーマットをPCMに変換します。音声全体は最初にメルスペクトログラムへと変換されます。

  2. スライディングウィンドウtranscribe.py):スペクトログラムを30秒のチャンクに分割して処理します。各チャンクはちょうど3000フレームにパディングまたはトリミングされ、エンコーダーに渡されます。

  3. エンコードmodel.pyAudioEncoder):2つの1D畳み込み(2つ目はストライド2)が時間次元を半分の1500に圧縮し、その後Transformerが特徴量を処理します。

  4. デコードdecoding.py):ロジットフィルタリング、温度フォールバック、オプションのビームサーチを組み合わせた自己回帰的なトークン生成を行います。タイムスタンプトークンがセグメントの区切りを示します。

  5. 単語タイムスタンプtiming.py):クロスアテンションの重みを抽出し、メディアンフィルターで処理した後、Dynamic Time Warpingによってトークンにアライメントします。

  6. 出力utils.py):ライター階層が結果をプレーンテキスト、字幕(SRT/VTT)、TSV、JSONのいずれかの形式に整形します。

ヒント: 30秒というチャンクサイズは恣意的に決められたものではありません。これはモデルが学習した最大の音声長です。audio.pyの定数(SAMPLE_RATE=16000、CHUNK_LENGTH=30、N_SAMPLES=480000、N_FRAMES=3000)はすべて、この根本的な設計上の選択から導かれています。

次回予告

全体の地図が手に入ったところで、いよいよ深く掘り下げていきましょう。第2回では、音声前処理パイプラインを詳細に追っていきます。FFmpegのサブプロセス呼び出しから始まり、STFT計算、畳み込みエンコーダーのステムに至るまで、各ステージでのテンソルの次元変化を丁寧に追跡します。音声波形がデコーダーの注意対象となる1500フレームの特徴列に変換されるまでの過程を、具体的なコードとともに確認していきましょう。