Read OSS

词级时间戳:交叉注意力对齐、DTW 与输出格式化

高级

前置知识

  • 第 1–5 篇文章
  • 动态时间规整(DTW)基础知识
  • 对注意力机制有一定了解

词级时间戳:交叉注意力对齐、DTW 与输出格式化

词级时间戳是 Whisper 后处理流程中技术含量最高的功能。给定一段已解码的文本片段,系统会精确判断每个词的发音时刻——这并非依赖解码器的时间戳 token(那只能提供片段级边界),而是通过交叉注意力权重来实现的:这些权重揭示了解码器在生成每个文本 token 时"关注"的是哪些音频帧。

本文将完整梳理整个处理流程:禁用 SDPA 以获取原始注意力权重、从特定注意力头提取 QK 矩阵、通过中值滤波器处理、使用动态时间规整(DTW,含 CPU 和 GPU 两种实现)将 token 对齐到音频帧、合并标点符号,最终格式化输出字幕。

交叉注意力权重的提取

时间对齐流程有一个前提条件,它直接影响了整个实现方案:PyTorch 的 scaled_dot_product_attention(SDPA)是一个融合内核,不会返回注意力权重。SDPA 处于激活状态时,qkNone(详见 model.py 第 123–128 行)。要提取 QK 矩阵,必须先禁用 SDPA。

model.py 中的 disable_sdpa() 上下文管理器负责处理这一切:

@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state

它通过切换 MultiHeadAttention 的类变量,将所有注意力层从融合路径切换到显式 QK 计算路径,从而返回注意力权重。

find_alignment() 内部,前向钩子被安装到每一个交叉注意力层,用于捕获 QK 矩阵:

QKs = [None] * model.dims.n_text_layer
hooks = [
    block.cross_attn.register_forward_hook(
        lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
    )
    for i, block in enumerate(model.decoder.blocks)
]
sequenceDiagram
    participant FA as find_alignment()
    participant SDPA as disable_sdpa()
    participant Model as Whisper Model
    participant Hooks as Cross-Attn Hooks
    participant QKs as QK Matrix List

    FA->>SDPA: Enter context (disable SDPA)
    FA->>Hooks: Install forward hooks on cross_attn layers
    FA->>Model: Forward pass (mel, tokens)
    Model->>Hooks: Each cross-attn fires hook
    Hooks->>QKs: Store QK matrix [heads × tokens × frames]
    FA->>Hooks: Remove all hooks
    FA->>SDPA: Exit context (restore SDPA)

并非所有注意力头对对齐任务都同样有用。alignment_heads 掩码(由第 1 篇中介绍的 base85 编码 _ALIGNMENT_HEADS 数据设置)标识了与词语时间高度相关的特定头,对齐时只使用这些头:

weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])

注意力权重的处理

原始 QK 权重在进行 DTW 对齐之前,需要经过一条处理流水线:

flowchart LR
    A["Raw QK weights\n[heads × tokens × frames]"] -->|"× qk_scale"| B["Scaled"]
    B -->|"softmax(dim=-1)"| C["Normalized"]
    C -->|"standardize\n(mean/std across tokens)"| D["Standardized"]
    D -->|"median_filter\n(width=7)"| E["Smoothed"]
    E -->|"mean across heads"| F["Final matrix\n[tokens × frames]"]

标准化步骤(沿 token 维度减去均值并除以标准差)对每帧的注意力分布进行归一化,使对齐结果更加鲁棒。中值滤波则用于去除注意力权重中的尖刺噪声。

median_filter() 函数优先尝试 Triton GPU 实现,如果失败则回退到使用 torch.sort() 的 CPU 版本:

if x.is_cuda:
    try:
        from .triton_ops import median_filter_cuda
        result = median_filter_cuda(x, filter_width)
    except (RuntimeError, subprocess.CalledProcessError):
        warnings.warn("Failed to launch Triton kernels...")

if result is None:
    result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

CPU 回退方案利用了 PyTorch 的一个特性:对于此场景,sort()torch.median() 更快(pytorch/pytorch#51450)。

动态时间规整:CPU 与 GPU 实现

DTW 用于在 token 序列与音频帧序列之间寻找最优的单调对齐路径。从概念上讲,它构建一个代价矩阵,其中 cost[i,j] 表示将前 i 个 token 与前 j 帧对齐的最小代价,然后回溯找出最优路径。

dtw_cpu() 使用 Numba JIT 编译来提升性能:

@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]  # diagonal
            c1 = cost[i - 1, j]      # vertical
            c2 = cost[i, j - 1]      # horizontal
            # ... pick minimum, record trace

GPU 实现 dtw_cuda() 使用了带有反对角线波前并行化的 Triton kernel。核心思路在于:同一反对角线上的所有单元格(满足 i + j = k)彼此独立,因此可以并行计算。dtw_kernel 依次遍历从 k = 1N + M 的所有反对角线:

for k in range(1, N + M + 1):
    tl.debug_barrier()
    # Load costs from three predecessors
    c0 = tl.load(p0 + offsets, mask=mask)  # diagonal
    c1 = tl.load(p1 + offsets, mask=mask)  # from above
    c2 = tl.load(p2 + offsets, mask=mask)  # from left
    # Compute minimum and store
    cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

在调用 kernel 之前,矩阵会通过 F.pad 和 reshape 进行"倾斜"处理,使反对角线变为行,从而实现连续的内存访问。

dtw() 调度函数优先尝试 GPU,失败时回退到 CPU:

flowchart TD
    A["dtw(cost_matrix)"] --> B{"CUDA\navailable?"}
    B -->|Yes| C["dtw_cuda()\nTriton kernel"]
    C -->|"RuntimeError"| D["dtw_cpu()\nNumba JIT"]
    B -->|No| D
    C --> E["backtrace()\nExtract optimal path"]
    D --> E

Triton 中值滤波器的代码生成技巧

整个 Whisper 代码库中最不寻常的模式,藏在 triton_ops.pymedian_kernel() 中。它并没有编写一个通用的中值滤波 kernel,而是在运行时动态生成针对特定滤波器宽度的专用 kernel 源码。

模板 kernel 中包含若干占位符标识符:

LOAD_ALL_ROWS_HERE    # replaced with N load statements
BUBBLESORT_HERE       # replaced with unrolled bubblesort comparisons
MIDDLE_ROW_HERE       # replaced with the median variable name

median_kernel() 函数通过字符串替换这些占位符来构建实际的 kernel。以 filter_width=7 为例,这会生成 7 条 load 语句、约 28 次比较交换操作(展开的冒泡排序),并存储 row3(即中位数):

new_kernel = kernel.src.replace(
    "    LOAD_ALL_ROWS_HERE",
    "\n".join([
        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    ]),
)

生成的冒泡排序只进行恰好足够的轮次,以保证中间元素正确——即 filter_width // 2 + 1 轮嵌套比较交换:

f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f"    row{j} = smaller",
f"    row{j + 1} = larger",

随后,函数通过 kernel._unsafe_update_src(new_kernel)(或在旧版 Triton 中直接赋值给 kernel.src)将生成的源码注入 JIT 函数,并清除哈希值以强制重新编译。@lru_cache 确保每种滤波器宽度只进行一次代码生成。

这是一种由 Triton 约束驱动的分阶段元编程方案:Triton kernel 在某些维度上必须是静态形状的,因此动态滤波器宽度只能通过生成专用 kernel 来实现。

注意: _unsafe_update_src 这个方法名本身就说明了一切——这种用法有多"不被官方支持"。它目前可以正常工作,但深入使用了 Triton 的内部机制。一旦 Triton API 发生变化,这里很可能是第一个出问题的地方。

词边界的确定与启发式修正

DTW 生成 token 到帧的对齐结果后,find_alignment() 将其转换为词边界:

  1. 使用 tokenizer.split_to_word_tokens() 将 token 分割为词(支持多语言,详见第 3 篇)
  2. 根据累积 token 数量计算词边界
  3. 从 DTW 路径中提取词边界位置对应的起止时间
  4. 将每个词内所有 token 概率的均值作为该词的概率

merge_punctuations() 随后调整词边界,将标点符号附加到相邻词上。开放性标点("'([{-)前置到下一个词;闭合性标点(.,!?)] 等)追加到上一个词。

add_word_timestamps() 还会应用额外的边界修正启发式规则:

  • 最大时长限制:句子边界处的词(.!?)被限制为中位词语时长的两倍,防止句号"吸收"句子后的长段静音。
  • 片段边界对齐:若第一个词的起始时间远早于片段起始时间,则优先采用片段级时间戳;末尾词的结束时间同理。
  • 静音后修正:在长段停顿(超过中位时长的 4 倍)之后,第一个词的起始时间会被截断,以防止它横跨整段停顿。

这些规则在注释中被坦率地称为"hack",并注明最终应由基于 VAD 的分割算法取而代之。但它们确实有效——能够处理最常见的 DTW 对齐误差。

flowchart TD
    A["find_alignment()\nDTW → raw word timings"] --> B["merge_punctuations()\nAttach . , ! ? to words"]
    B --> C["Compute median/max\nword duration"]
    C --> D["Truncate long words\nat sentence boundaries"]
    D --> E["Align word starts/ends\nwith segment boundaries"]
    E --> F["Post-pause correction\n(first word after silence)"]
    F --> G["Final word timestamps\nper segment"]

输出写入器与字幕格式化

whisper/utils.py 中的输出系统以 ResultWriter 为根节点构建了一套类层次结构:

classDiagram
    class ResultWriter {
        +str extension
        +str output_dir
        +__call__(result, audio_path)
        +write_result(result, file)*
    }

    class WriteTXT {
        +extension = "txt"
    }

    class WriteJSON {
        +extension = "json"
    }

    class WriteTSV {
        +extension = "tsv"
    }

    class SubtitlesWriter {
        +bool always_include_hours
        +str decimal_marker
        +iterate_result() → yields (start, end, text)
    }

    class WriteVTT {
        +extension = "vtt"
    }

    class WriteSRT {
        +extension = "srt"
    }

    ResultWriter <|-- WriteTXT
    ResultWriter <|-- WriteJSON
    ResultWriter <|-- WriteTSV
    ResultWriter <|-- SubtitlesWriter
    SubtitlesWriter <|-- WriteVTT
    SubtitlesWriter <|-- WriteSRT

WriteTXTWriteJSON 的实现相对简单。WriteTSV 使用整数毫秒表示时间,以避免因地区设置不同而导致的小数分隔符差异——这个细节在其文档字符串中有专门说明,颇为用心。

真正有趣的逻辑在 SubtitlesWriter.iterate_result() 中,它负责处理词级高亮。当 highlight_words=True 时,每个词都会生成独立的字幕条目,并用 <u> 标签标注当前激活的词:

00:00.640 --> 00:00.920
Hello <u>world</u> today

该方法还支持通过 max_line_widthmax_line_countmax_words_per_line 参数控制换行行为,在保持自然词语分组的同时,输出符合这些限制的字幕块。

get_writer() 工厂函数负责格式选择。"all" 选项会创建全部五种写入器,并返回一个依次调用它们的组合函数——这是组合模式(Composite Pattern)的一种简洁实现:

if output_format == "all":
    all_writers = [writer(output_dir) for writer in writers.values()]
    def write_all(result, file, options=None, **kwargs):
        for writer in all_writers:
            writer(result, file, options, **kwargs)
    return write_all

系列总结

经过六篇文章,我们已经梳理了 Whisper 约 2500 行代码中的每一条主要执行路径:

  1. 架构概览:9 个模块的包结构、入口点与方法绑定模式
  2. 音频前端:FFmpeg 加载、STFT、梅尔频谱图与卷积编码器主干
  3. 分词:双词表、1501 个时间戳 token 与 SOT 序列协议
  4. 解码:基于钩子的 KV-cache、束搜索、logit 过滤链与主循环
  5. 转录:滑动窗口、温度回退、无语音检测与幻觉恢复
  6. 词级时间戳:交叉注意力提取、DTW 对齐、运行时代码生成与输出格式化

Whisper 真正令人印象深刻的,并非某一项单独的技术——而是那种让代码库保持精简而不失能力的工程品味。没有多余的抽象,没有配置框架,没有插件系统。有的只是一条从音频到文本的清晰流水线,配合几个精心选择的设计模式,以及对那些真正影响语音识别鲁棒性的细节的高度专注。