Whisper 架构概览:读懂 OpenAI 语音识别代码库
前置知识
- ›具备基本的 Python 编程能力
- ›对神经网络概念有基本了解
- ›熟悉 PyTorch nn.Module 的基础用法
Whisper 架构概览:读懂 OpenAI 语音识别代码库
OpenAI 的 Whisper 是少数几个以极小代码量实现强大功能的开源项目之一。多语言语音识别、翻译、语言检测、词级时间戳——整个系统只用了约 2500 行 Python 代码,分布在九个模块中。没有臃肿的抽象层,没有插件架构,没有配置 DSL,有的只是一个设计精良的 encoder-decoder Transformer,加上围绕它构建的一套简洁 pipeline。
本文是深度解析 Whisper 代码库系列的第一篇,全系列共六篇,将逐行剖析代码中每一个有意义的细节。读完本文,你将建立起对整个项目的清晰认知,知道遇到任何问题该去哪个文件里找答案。
项目结构与模块概览
Whisper 的包结构是扁平的——没有嵌套的子包层级,没有 core/、utils/、services/ 之类的目录划分。所有模块都放在 whisper/ 目录的顶层,每个模块职责单一、边界清晰。
| 模块 | 代码行数(约) | 职责 |
|---|---|---|
__init__.py |
162 | 模型注册表、下载逻辑、load_model() API |
model.py |
345 | ModelDimensions、Whisper、encoder/decoder 架构 |
audio.py |
157 | 音频加载、梅尔频谱计算 |
tokenizer.py |
395 | Tiktoken 封装、特殊 token、多语言支持 |
decoding.py |
826 | 自回归解码、束搜索、logit 过滤 |
transcribe.py |
623 | 滑动窗口转录循环、CLI |
timing.py |
389 | 基于交叉注意力 + DTW 的词级时间戳 |
utils.py |
318 | 输出写入器、格式化、压缩率 |
triton_ops.py |
118 | GPU 加速的 DTW 和中值滤波内核 |
此外还有一个 assets/ 目录,存放预计算的梅尔滤波器组(mel_filters.npz)和 BPE 词表(gpt2.tiktoken、multilingual.tiktoken);normalizers/ 子包负责文本规范化;version.py 用于版本追踪。
flowchart TD
subgraph "whisper/ package"
INIT["__init__.py\n(model registry + load)"]
MODEL["model.py\n(Whisper nn.Module)"]
AUDIO["audio.py\n(mel spectrograms)"]
TOKEN["tokenizer.py\n(tiktoken wrapper)"]
DECODE["decoding.py\n(beam search + filters)"]
TRANS["transcribe.py\n(sliding window + CLI)"]
TIMING["timing.py\n(word timestamps)"]
UTILS["utils.py\n(output writers)"]
TRITON["triton_ops.py\n(GPU kernels)"]
end
INIT --> MODEL
INIT --> AUDIO
INIT --> DECODE
INIT --> TRANS
TRANS --> AUDIO
TRANS --> DECODE
TRANS --> TIMING
TRANS --> TOKEN
DECODE --> TOKEN
TIMING --> TRITON
TIMING --> MODEL
MODEL --> DECODE
MODEL --> TRANS
提示:
model.py、decoding.py和transcribe.py之间存在循环依赖,这是有意为之的设计。Whisper 通过一种巧妙的方法绑定模式来解决这个问题,我们稍后会详细分析——这是整个代码库中最值得关注的结构性决策之一。
入口点:CLI 与 Python API
Whisper 提供两个入口点,但两条路最终都会汇聚到同一个 transcribe() 函数。
CLI 入口点在 pyproject.toml#L35 中声明:
scripts.whisper = "whisper.transcribe:cli"
也就是说,在命令行运行 whisper 会调用 transcribe.py 中的 cli()。你也可以运行 python -m whisper,这会触发 whisper/__main__.py——一个仅有两行的文件,直接导入并调用 cli()。
Python API 入口点是 whisper.load_model(),从 whisper/__init__.py 导出。加载模型后,调用 model.transcribe(audio) 同样会执行底层的 transcribe() 函数。
flowchart LR
A["$ whisper audio.mp3"] --> B["transcribe.py:cli()"]
C["$ python -m whisper"] --> D["__main__.py"] --> B
E["model.transcribe(audio)"] --> F["transcribe.py:transcribe()"]
B --> G["load_model()"] --> F
两条路径最终都指向 transcribe(),这是整个系统的核心。CLI 在其外层封装了 argparse 参数解析、模型加载和输出写入逻辑;Python API 则将其作为 Whisper 类的方法直接暴露出来。
模型注册表与下载机制
Whisper 提供 12 个模型变体,参数量从最小的 tiny(3900 万)到最大的 large-v3(15.5 亿)不等。模型注册表是一个普通的字典,将模型名称映射到 Azure CDN 的 URL,定义在 whisper/__init__.py#L17-L32:
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
# ...
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
注意 URL 路径中直接嵌入了 SHA256 哈希值——倒数第二段路径就是预期的校验和。第 54 行的 _download() 函数通过 url.split("/")[-2] 提取该哈希值,并对下载的文件内容进行验证。这个设计相当精妙:不需要单独的 checksums 文件,也不需要 manifest,URL 本身就承载了完整性校验信息。
此外,第 36–51 行还定义了 _ALIGNMENT_HEADS——这是经过 base85 编码和 gzip 压缩的布尔数组,标识了哪些交叉注意力头与词级时间戳的相关性最强。这些数据在 model.py 的 set_alignment_heads() 方法中解码,并在提取词级时间戳时使用(详见第 6 篇文章)。
第 103–161 行的 load_model() 函数处理三种情况:从注册表加载命名模型、从本地文件路径加载,或在两者都不满足时抛出友好的错误信息。它负责反序列化 checkpoint、从保存的 dims 字典中构造 ModelDimensions、实例化 Whisper 类、加载权重并配置对齐头。模型默认缓存到 ~/.cache/whisper(优先读取 XDG_CACHE_HOME 环境变量)。
Whisper 类与 ModelDimensions
整个模型架构由一个只有 10 个字段的 dataclass 完整参数化:ModelDimensions。
@dataclass
class ModelDimensions:
n_mels: int # mel frequency bins (80 or 128)
n_audio_ctx: int # audio context length (1500)
n_audio_state: int # encoder hidden size
n_audio_head: int # encoder attention heads
n_audio_layer: int # encoder layers
n_vocab: int # vocabulary size
n_text_ctx: int # text context length (448)
n_text_state: int # decoder hidden size
n_text_head: int # decoder attention heads
n_text_layer: int # decoder layers
这 10 个数字完全决定了模型中每一个维度。第 252–345 行的 Whisper 类根据这些维度组合出 AudioEncoder 和 TextDecoder:
classDiagram
class ModelDimensions {
+int n_mels
+int n_audio_ctx
+int n_audio_state
+int n_audio_head
+int n_audio_layer
+int n_vocab
+int n_text_ctx
+int n_text_state
+int n_text_head
+int n_text_layer
}
class Whisper {
+ModelDimensions dims
+AudioEncoder encoder
+TextDecoder decoder
+embed_audio(mel)
+logits(tokens, audio_features)
+detect_language()
+transcribe()
+decode()
}
class AudioEncoder {
+Conv1d conv1
+Conv1d conv2
+Tensor positional_embedding
+ModuleList blocks
}
class TextDecoder {
+Embedding token_embedding
+Parameter positional_embedding
+ModuleList blocks
}
Whisper --> ModelDimensions
Whisper --> AudioEncoder
Whisper --> TextDecoder
model.py 最值得关注的细节在文件末尾——第 343–345 行:
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
这几行代码将 decoding.py 和 transcribe.py 中的模块级函数作为方法绑定到 Whisper 类上。这正是 Whisper 规避循环导入的方式:model.py 以函数的形式导入 decode 和 transcribe(在文件顶部完成导入),再将它们赋值为类属性。而 decoding.py 和 transcribe.py 只在 TYPE_CHECKING 块中引用 Whisper 类型,运行时不存在循环依赖。
当你调用 model.transcribe(audio) 时,Python 会将模型实例作为 self 传入,对应 transcribe() 函数中的 model 参数。这个方案简单可靠,同时避免了引入插件系统带来的复杂性。
端到端数据流概览
在后续文章深入细节之前,让我们先完整地走一遍从音频文件到文本输出的全过程:
flowchart TD
A["Audio File\n(any format)"] -->|"ffmpeg subprocess"| B["Raw Waveform\n16kHz mono float32"]
B -->|"torch.stft + mel filterbank"| C["Log-Mel Spectrogram\n[n_mels × frames]"]
C -->|"pad_or_trim to 3000 frames"| D["30-Second Chunk\n[n_mels × 3000]"]
D -->|"AudioEncoder\n(conv stem + transformer)"| E["Audio Features\n[1500 × d_model]"]
E -->|"TextDecoder\n(autoregressive)"| F["Token Sequence\nwith timestamps"]
F -->|"Tokenizer.decode()"| G["Text Segments\nwith timing"]
G -->|"Output Writers"| H["TXT / SRT / VTT / JSON / TSV"]
style A fill:#f9f,stroke:#333
style H fill:#9f9,stroke:#333
各关键模块在 pipeline 中的对应阶段如下:
-
音频加载(
audio.py):FFmpeg 将任意音频格式解码为原始 PCM,并一次性将整段音频转换为梅尔频谱。 -
滑动窗口(
transcribe.py):频谱按 30 秒一块进行切分,每块补零或裁剪至恰好 3000 帧后送入 encoder。 -
编码(
model.py→AudioEncoder):两层一维卷积(第二层步长为 2)将时间维度压缩至 1500,再经过 Transformer 堆栈处理特征。 -
解码(
decoding.py):带 logit 过滤、温度回退和可选束搜索的自回归 token 生成,时间戳 token 标记片段边界。 -
词级时间戳(
timing.py):提取交叉注意力权重,经中值滤波处理后,通过动态时间规整(DTW)与 token 对齐。 -
输出(
utils.py):写入器层级结构将结果格式化为纯文本、字幕文件(SRT/VTT)、TSV 或 JSON。
提示: 30 秒的分块大小并非随意选定——这是模型训练时所用的最大音频时长。
audio.py中的那些常量(SAMPLE_RATE=16000、CHUNK_LENGTH=30、N_SAMPLES=480000、N_FRAMES=3000)都是由这一根本性的设计决策推导而来的。
下一篇
有了这张全局地图,我们就可以开始深入细节了。第 2 篇将详细追踪音频预处理 pipeline——从 FFmpeg 子进程的调用,到 STFT 计算,再到卷积 encoder stem,在每个阶段追踪 tensor 的维度变化。你将看到一段声音是如何一步步变成 decoder 所关注的那 1500 帧特征序列的。