词级时间戳:交叉注意力对齐、DTW 与输出格式化
前置知识
- ›第 1–5 篇文章
- ›动态时间规整(DTW)基础知识
- ›对注意力机制有一定了解
词级时间戳:交叉注意力对齐、DTW 与输出格式化
词级时间戳是 Whisper 后处理流程中技术含量最高的功能。给定一段已解码的文本片段,系统会精确判断每个词的发音时刻——这并非依赖解码器的时间戳 token(那只能提供片段级边界),而是通过交叉注意力权重来实现的:这些权重揭示了解码器在生成每个文本 token 时"关注"的是哪些音频帧。
本文将完整梳理整个处理流程:禁用 SDPA 以获取原始注意力权重、从特定注意力头提取 QK 矩阵、通过中值滤波器处理、使用动态时间规整(DTW,含 CPU 和 GPU 两种实现)将 token 对齐到音频帧、合并标点符号,最终格式化输出字幕。
交叉注意力权重的提取
时间对齐流程有一个前提条件,它直接影响了整个实现方案:PyTorch 的 scaled_dot_product_attention(SDPA)是一个融合内核,不会返回注意力权重。SDPA 处于激活状态时,qk 为 None(详见 model.py 第 123–128 行)。要提取 QK 矩阵,必须先禁用 SDPA。
model.py 中的 disable_sdpa() 上下文管理器负责处理这一切:
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
它通过切换 MultiHeadAttention 的类变量,将所有注意力层从融合路径切换到显式 QK 计算路径,从而返回注意力权重。
在 find_alignment() 内部,前向钩子被安装到每一个交叉注意力层,用于捕获 QK 矩阵:
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
sequenceDiagram
participant FA as find_alignment()
participant SDPA as disable_sdpa()
participant Model as Whisper Model
participant Hooks as Cross-Attn Hooks
participant QKs as QK Matrix List
FA->>SDPA: Enter context (disable SDPA)
FA->>Hooks: Install forward hooks on cross_attn layers
FA->>Model: Forward pass (mel, tokens)
Model->>Hooks: Each cross-attn fires hook
Hooks->>QKs: Store QK matrix [heads × tokens × frames]
FA->>Hooks: Remove all hooks
FA->>SDPA: Exit context (restore SDPA)
并非所有注意力头对对齐任务都同样有用。alignment_heads 掩码(由第 1 篇中介绍的 base85 编码 _ALIGNMENT_HEADS 数据设置)标识了与词语时间高度相关的特定头,对齐时只使用这些头:
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
注意力权重的处理
原始 QK 权重在进行 DTW 对齐之前,需要经过一条处理流水线:
flowchart LR
A["Raw QK weights\n[heads × tokens × frames]"] -->|"× qk_scale"| B["Scaled"]
B -->|"softmax(dim=-1)"| C["Normalized"]
C -->|"standardize\n(mean/std across tokens)"| D["Standardized"]
D -->|"median_filter\n(width=7)"| E["Smoothed"]
E -->|"mean across heads"| F["Final matrix\n[tokens × frames]"]
标准化步骤(沿 token 维度减去均值并除以标准差)对每帧的注意力分布进行归一化,使对齐结果更加鲁棒。中值滤波则用于去除注意力权重中的尖刺噪声。
median_filter() 函数优先尝试 Triton GPU 实现,如果失败则回退到使用 torch.sort() 的 CPU 版本:
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn("Failed to launch Triton kernels...")
if result is None:
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
CPU 回退方案利用了 PyTorch 的一个特性:对于此场景,sort() 比 torch.median() 更快(pytorch/pytorch#51450)。
动态时间规整:CPU 与 GPU 实现
DTW 用于在 token 序列与音频帧序列之间寻找最优的单调对齐路径。从概念上讲,它构建一个代价矩阵,其中 cost[i,j] 表示将前 i 个 token 与前 j 帧对齐的最小代价,然后回溯找出最优路径。
dtw_cpu() 使用 Numba JIT 编译来提升性能:
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1] # diagonal
c1 = cost[i - 1, j] # vertical
c2 = cost[i, j - 1] # horizontal
# ... pick minimum, record trace
GPU 实现 dtw_cuda() 使用了带有反对角线波前并行化的 Triton kernel。核心思路在于:同一反对角线上的所有单元格(满足 i + j = k)彼此独立,因此可以并行计算。dtw_kernel 依次遍历从 k = 1 到 N + M 的所有反对角线:
for k in range(1, N + M + 1):
tl.debug_barrier()
# Load costs from three predecessors
c0 = tl.load(p0 + offsets, mask=mask) # diagonal
c1 = tl.load(p1 + offsets, mask=mask) # from above
c2 = tl.load(p2 + offsets, mask=mask) # from left
# Compute minimum and store
cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
在调用 kernel 之前,矩阵会通过 F.pad 和 reshape 进行"倾斜"处理,使反对角线变为行,从而实现连续的内存访问。
dtw() 调度函数优先尝试 GPU,失败时回退到 CPU:
flowchart TD
A["dtw(cost_matrix)"] --> B{"CUDA\navailable?"}
B -->|Yes| C["dtw_cuda()\nTriton kernel"]
C -->|"RuntimeError"| D["dtw_cpu()\nNumba JIT"]
B -->|No| D
C --> E["backtrace()\nExtract optimal path"]
D --> E
Triton 中值滤波器的代码生成技巧
整个 Whisper 代码库中最不寻常的模式,藏在 triton_ops.py 的 median_kernel() 中。它并没有编写一个通用的中值滤波 kernel,而是在运行时动态生成针对特定滤波器宽度的专用 kernel 源码。
模板 kernel 中包含若干占位符标识符:
LOAD_ALL_ROWS_HERE # replaced with N load statements
BUBBLESORT_HERE # replaced with unrolled bubblesort comparisons
MIDDLE_ROW_HERE # replaced with the median variable name
median_kernel() 函数通过字符串替换这些占位符来构建实际的 kernel。以 filter_width=7 为例,这会生成 7 条 load 语句、约 28 次比较交换操作(展开的冒泡排序),并存储 row3(即中位数):
new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join([
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
for i in range(filter_width)
]),
)
生成的冒泡排序只进行恰好足够的轮次,以保证中间元素正确——即 filter_width // 2 + 1 轮嵌套比较交换:
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f" row{j} = smaller",
f" row{j + 1} = larger",
随后,函数通过 kernel._unsafe_update_src(new_kernel)(或在旧版 Triton 中直接赋值给 kernel.src)将生成的源码注入 JIT 函数,并清除哈希值以强制重新编译。@lru_cache 确保每种滤波器宽度只进行一次代码生成。
这是一种由 Triton 约束驱动的分阶段元编程方案:Triton kernel 在某些维度上必须是静态形状的,因此动态滤波器宽度只能通过生成专用 kernel 来实现。
注意:
_unsafe_update_src这个方法名本身就说明了一切——这种用法有多"不被官方支持"。它目前可以正常工作,但深入使用了 Triton 的内部机制。一旦 Triton API 发生变化,这里很可能是第一个出问题的地方。
词边界的确定与启发式修正
DTW 生成 token 到帧的对齐结果后,find_alignment() 将其转换为词边界:
- 使用
tokenizer.split_to_word_tokens()将 token 分割为词(支持多语言,详见第 3 篇) - 根据累积 token 数量计算词边界
- 从 DTW 路径中提取词边界位置对应的起止时间
- 将每个词内所有 token 概率的均值作为该词的概率
merge_punctuations() 随后调整词边界,将标点符号附加到相邻词上。开放性标点("、'、(、[、{、-)前置到下一个词;闭合性标点(.、,、!、?、)、] 等)追加到上一个词。
add_word_timestamps() 还会应用额外的边界修正启发式规则:
- 最大时长限制:句子边界处的词(
.、!、?)被限制为中位词语时长的两倍,防止句号"吸收"句子后的长段静音。 - 片段边界对齐:若第一个词的起始时间远早于片段起始时间,则优先采用片段级时间戳;末尾词的结束时间同理。
- 静音后修正:在长段停顿(超过中位时长的 4 倍)之后,第一个词的起始时间会被截断,以防止它横跨整段停顿。
这些规则在注释中被坦率地称为"hack",并注明最终应由基于 VAD 的分割算法取而代之。但它们确实有效——能够处理最常见的 DTW 对齐误差。
flowchart TD
A["find_alignment()\nDTW → raw word timings"] --> B["merge_punctuations()\nAttach . , ! ? to words"]
B --> C["Compute median/max\nword duration"]
C --> D["Truncate long words\nat sentence boundaries"]
D --> E["Align word starts/ends\nwith segment boundaries"]
E --> F["Post-pause correction\n(first word after silence)"]
F --> G["Final word timestamps\nper segment"]
输出写入器与字幕格式化
whisper/utils.py 中的输出系统以 ResultWriter 为根节点构建了一套类层次结构:
classDiagram
class ResultWriter {
+str extension
+str output_dir
+__call__(result, audio_path)
+write_result(result, file)*
}
class WriteTXT {
+extension = "txt"
}
class WriteJSON {
+extension = "json"
}
class WriteTSV {
+extension = "tsv"
}
class SubtitlesWriter {
+bool always_include_hours
+str decimal_marker
+iterate_result() → yields (start, end, text)
}
class WriteVTT {
+extension = "vtt"
}
class WriteSRT {
+extension = "srt"
}
ResultWriter <|-- WriteTXT
ResultWriter <|-- WriteJSON
ResultWriter <|-- WriteTSV
ResultWriter <|-- SubtitlesWriter
SubtitlesWriter <|-- WriteVTT
SubtitlesWriter <|-- WriteSRT
WriteTXT 和 WriteJSON 的实现相对简单。WriteTSV 使用整数毫秒表示时间,以避免因地区设置不同而导致的小数分隔符差异——这个细节在其文档字符串中有专门说明,颇为用心。
真正有趣的逻辑在 SubtitlesWriter.iterate_result() 中,它负责处理词级高亮。当 highlight_words=True 时,每个词都会生成独立的字幕条目,并用 <u> 标签标注当前激活的词:
00:00.640 --> 00:00.920
Hello <u>world</u> today
该方法还支持通过 max_line_width、max_line_count 和 max_words_per_line 参数控制换行行为,在保持自然词语分组的同时,输出符合这些限制的字幕块。
get_writer() 工厂函数负责格式选择。"all" 选项会创建全部五种写入器,并返回一个依次调用它们的组合函数——这是组合模式(Composite Pattern)的一种简洁实现:
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result, file, options=None, **kwargs):
for writer in all_writers:
writer(result, file, options, **kwargs)
return write_all
系列总结
经过六篇文章,我们已经梳理了 Whisper 约 2500 行代码中的每一条主要执行路径:
- 架构概览:9 个模块的包结构、入口点与方法绑定模式
- 音频前端:FFmpeg 加载、STFT、梅尔频谱图与卷积编码器主干
- 分词:双词表、1501 个时间戳 token 与 SOT 序列协议
- 解码:基于钩子的 KV-cache、束搜索、logit 过滤链与主循环
- 转录:滑动窗口、温度回退、无语音检测与幻觉恢复
- 词级时间戳:交叉注意力提取、DTW 对齐、运行时代码生成与输出格式化
Whisper 真正令人印象深刻的,并非某一项单独的技术——而是那种让代码库保持精简而不失能力的工程品味。没有多余的抽象,没有配置框架,没有插件系统。有的只是一条从音频到文本的清晰流水线,配合几个精心选择的设计模式,以及对那些真正影响语音识别鲁棒性的细节的高度专注。