Read OSS

Whisper 架构概览:读懂 OpenAI 语音识别代码库

中级

前置知识

  • 具备基本的 Python 编程能力
  • 对神经网络概念有基本了解
  • 熟悉 PyTorch nn.Module 的基础用法

Whisper 架构概览:读懂 OpenAI 语音识别代码库

OpenAI 的 Whisper 是少数几个以极小代码量实现强大功能的开源项目之一。多语言语音识别、翻译、语言检测、词级时间戳——整个系统只用了约 2500 行 Python 代码,分布在九个模块中。没有臃肿的抽象层,没有插件架构,没有配置 DSL,有的只是一个设计精良的 encoder-decoder Transformer,加上围绕它构建的一套简洁 pipeline。

本文是深度解析 Whisper 代码库系列的第一篇,全系列共六篇,将逐行剖析代码中每一个有意义的细节。读完本文,你将建立起对整个项目的清晰认知,知道遇到任何问题该去哪个文件里找答案。

项目结构与模块概览

Whisper 的包结构是扁平的——没有嵌套的子包层级,没有 core/utils/services/ 之类的目录划分。所有模块都放在 whisper/ 目录的顶层,每个模块职责单一、边界清晰。

模块 代码行数(约) 职责
__init__.py 162 模型注册表、下载逻辑、load_model() API
model.py 345 ModelDimensionsWhisper、encoder/decoder 架构
audio.py 157 音频加载、梅尔频谱计算
tokenizer.py 395 Tiktoken 封装、特殊 token、多语言支持
decoding.py 826 自回归解码、束搜索、logit 过滤
transcribe.py 623 滑动窗口转录循环、CLI
timing.py 389 基于交叉注意力 + DTW 的词级时间戳
utils.py 318 输出写入器、格式化、压缩率
triton_ops.py 118 GPU 加速的 DTW 和中值滤波内核

此外还有一个 assets/ 目录,存放预计算的梅尔滤波器组(mel_filters.npz)和 BPE 词表(gpt2.tiktokenmultilingual.tiktoken);normalizers/ 子包负责文本规范化;version.py 用于版本追踪。

flowchart TD
    subgraph "whisper/ package"
        INIT["__init__.py\n(model registry + load)"]
        MODEL["model.py\n(Whisper nn.Module)"]
        AUDIO["audio.py\n(mel spectrograms)"]
        TOKEN["tokenizer.py\n(tiktoken wrapper)"]
        DECODE["decoding.py\n(beam search + filters)"]
        TRANS["transcribe.py\n(sliding window + CLI)"]
        TIMING["timing.py\n(word timestamps)"]
        UTILS["utils.py\n(output writers)"]
        TRITON["triton_ops.py\n(GPU kernels)"]
    end

    INIT --> MODEL
    INIT --> AUDIO
    INIT --> DECODE
    INIT --> TRANS
    TRANS --> AUDIO
    TRANS --> DECODE
    TRANS --> TIMING
    TRANS --> TOKEN
    DECODE --> TOKEN
    TIMING --> TRITON
    TIMING --> MODEL
    MODEL --> DECODE
    MODEL --> TRANS

提示: model.pydecoding.pytranscribe.py 之间存在循环依赖,这是有意为之的设计。Whisper 通过一种巧妙的方法绑定模式来解决这个问题,我们稍后会详细分析——这是整个代码库中最值得关注的结构性决策之一。

入口点:CLI 与 Python API

Whisper 提供两个入口点,但两条路最终都会汇聚到同一个 transcribe() 函数。

CLI 入口点pyproject.toml#L35 中声明:

scripts.whisper = "whisper.transcribe:cli"

也就是说,在命令行运行 whisper 会调用 transcribe.py 中的 cli()。你也可以运行 python -m whisper,这会触发 whisper/__main__.py——一个仅有两行的文件,直接导入并调用 cli()

Python API 入口点whisper.load_model(),从 whisper/__init__.py 导出。加载模型后,调用 model.transcribe(audio) 同样会执行底层的 transcribe() 函数。

flowchart LR
    A["$ whisper audio.mp3"] --> B["transcribe.py:cli()"]
    C["$ python -m whisper"] --> D["__main__.py"] --> B
    E["model.transcribe(audio)"] --> F["transcribe.py:transcribe()"]
    B --> G["load_model()"] --> F

两条路径最终都指向 transcribe(),这是整个系统的核心。CLI 在其外层封装了 argparse 参数解析、模型加载和输出写入逻辑;Python API 则将其作为 Whisper 类的方法直接暴露出来。

模型注册表与下载机制

Whisper 提供 12 个模型变体,参数量从最小的 tiny(3900 万)到最大的 large-v3(15.5 亿)不等。模型注册表是一个普通的字典,将模型名称映射到 Azure CDN 的 URL,定义在 whisper/__init__.py#L17-L32

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    # ...
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}

注意 URL 路径中直接嵌入了 SHA256 哈希值——倒数第二段路径就是预期的校验和。第 54 行_download() 函数通过 url.split("/")[-2] 提取该哈希值,并对下载的文件内容进行验证。这个设计相当精妙:不需要单独的 checksums 文件,也不需要 manifest,URL 本身就承载了完整性校验信息。

此外,第 36–51 行还定义了 _ALIGNMENT_HEADS——这是经过 base85 编码和 gzip 压缩的布尔数组,标识了哪些交叉注意力头与词级时间戳的相关性最强。这些数据在 model.pyset_alignment_heads() 方法中解码,并在提取词级时间戳时使用(详见第 6 篇文章)。

第 103–161 行load_model() 函数处理三种情况:从注册表加载命名模型、从本地文件路径加载,或在两者都不满足时抛出友好的错误信息。它负责反序列化 checkpoint、从保存的 dims 字典中构造 ModelDimensions、实例化 Whisper 类、加载权重并配置对齐头。模型默认缓存到 ~/.cache/whisper(优先读取 XDG_CACHE_HOME 环境变量)。

Whisper 类与 ModelDimensions

整个模型架构由一个只有 10 个字段的 dataclass 完整参数化:ModelDimensions

@dataclass
class ModelDimensions:
    n_mels: int          # mel frequency bins (80 or 128)
    n_audio_ctx: int     # audio context length (1500)
    n_audio_state: int   # encoder hidden size
    n_audio_head: int    # encoder attention heads
    n_audio_layer: int   # encoder layers
    n_vocab: int         # vocabulary size
    n_text_ctx: int      # text context length (448)
    n_text_state: int    # decoder hidden size
    n_text_head: int     # decoder attention heads
    n_text_layer: int    # decoder layers

这 10 个数字完全决定了模型中每一个维度。第 252–345 行Whisper 类根据这些维度组合出 AudioEncoderTextDecoder

classDiagram
    class ModelDimensions {
        +int n_mels
        +int n_audio_ctx
        +int n_audio_state
        +int n_audio_head
        +int n_audio_layer
        +int n_vocab
        +int n_text_ctx
        +int n_text_state
        +int n_text_head
        +int n_text_layer
    }

    class Whisper {
        +ModelDimensions dims
        +AudioEncoder encoder
        +TextDecoder decoder
        +embed_audio(mel)
        +logits(tokens, audio_features)
        +detect_language()
        +transcribe()
        +decode()
    }

    class AudioEncoder {
        +Conv1d conv1
        +Conv1d conv2
        +Tensor positional_embedding
        +ModuleList blocks
    }

    class TextDecoder {
        +Embedding token_embedding
        +Parameter positional_embedding
        +ModuleList blocks
    }

    Whisper --> ModelDimensions
    Whisper --> AudioEncoder
    Whisper --> TextDecoder

model.py 最值得关注的细节在文件末尾——第 343–345 行

detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function

这几行代码将 decoding.pytranscribe.py 中的模块级函数作为方法绑定到 Whisper 类上。这正是 Whisper 规避循环导入的方式:model.py 以函数的形式导入 decodetranscribe(在文件顶部完成导入),再将它们赋值为类属性。而 decoding.pytranscribe.py 只在 TYPE_CHECKING 块中引用 Whisper 类型,运行时不存在循环依赖。

当你调用 model.transcribe(audio) 时,Python 会将模型实例作为 self 传入,对应 transcribe() 函数中的 model 参数。这个方案简单可靠,同时避免了引入插件系统带来的复杂性。

端到端数据流概览

在后续文章深入细节之前,让我们先完整地走一遍从音频文件到文本输出的全过程:

flowchart TD
    A["Audio File\n(any format)"] -->|"ffmpeg subprocess"| B["Raw Waveform\n16kHz mono float32"]
    B -->|"torch.stft + mel filterbank"| C["Log-Mel Spectrogram\n[n_mels × frames]"]
    C -->|"pad_or_trim to 3000 frames"| D["30-Second Chunk\n[n_mels × 3000]"]
    D -->|"AudioEncoder\n(conv stem + transformer)"| E["Audio Features\n[1500 × d_model]"]
    E -->|"TextDecoder\n(autoregressive)"| F["Token Sequence\nwith timestamps"]
    F -->|"Tokenizer.decode()"| G["Text Segments\nwith timing"]
    G -->|"Output Writers"| H["TXT / SRT / VTT / JSON / TSV"]

    style A fill:#f9f,stroke:#333
    style H fill:#9f9,stroke:#333

各关键模块在 pipeline 中的对应阶段如下:

  1. 音频加载audio.py):FFmpeg 将任意音频格式解码为原始 PCM,并一次性将整段音频转换为梅尔频谱。

  2. 滑动窗口transcribe.py):频谱按 30 秒一块进行切分,每块补零或裁剪至恰好 3000 帧后送入 encoder。

  3. 编码model.pyAudioEncoder):两层一维卷积(第二层步长为 2)将时间维度压缩至 1500,再经过 Transformer 堆栈处理特征。

  4. 解码decoding.py):带 logit 过滤、温度回退和可选束搜索的自回归 token 生成,时间戳 token 标记片段边界。

  5. 词级时间戳timing.py):提取交叉注意力权重,经中值滤波处理后,通过动态时间规整(DTW)与 token 对齐。

  6. 输出utils.py):写入器层级结构将结果格式化为纯文本、字幕文件(SRT/VTT)、TSV 或 JSON。

提示: 30 秒的分块大小并非随意选定——这是模型训练时所用的最大音频时长。audio.py 中的那些常量(SAMPLE_RATE=16000、CHUNK_LENGTH=30、N_SAMPLES=480000、N_FRAMES=3000)都是由这一根本性的设计决策推导而来的。

下一篇

有了这张全局地图,我们就可以开始深入细节了。第 2 篇将详细追踪音频预处理 pipeline——从 FFmpeg 子进程的调用,到 STFT 计算,再到卷积 encoder stem,在每个阶段追踪 tensor 的维度变化。你将看到一段声音是如何一步步变成 decoder 所关注的那 1500 帧特征序列的。