深入 Whisper 解码器:Beam Search、Logit 过滤器与 KV-Cache
前置知识
- ›第 1-3 篇文章
- ›理解 beam search 与温度采样
- ›PyTorch forward hooks
深入 Whisper 解码器:Beam Search、Logit 过滤器与 KV-Cache
whisper/decoding.py 是整个代码库中最大的模块,共 826 行,这并非偶然——它负责协调自回归生成的全过程,将编码器输出的特征转换为 token 序列。这个模块集中体现了 Whisper 的设计哲学:以清晰的抽象(解码器采用策略模式,logit 过滤器采用责任链模式)为基础,由统一的协调者 DecodingTask 将它们组合在一起。
本文将逐一拆解每个核心组件:配置用的 dataclass、基于 forward hook 的 KV-cache 实现、token 解码器的层级结构、logit 过滤器流水线,以及将所有环节串联起来的主循环。
DecodingOptions 与 DecodingResult
解码系统通过一个冻结的 dataclass DecodingOptions 进行配置,它控制着从采样温度到时间戳行为的方方面面:
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe"
language: Optional[str] = None
temperature: float = 0.0
sample_len: Optional[int] = None
best_of: Optional[int] = None
beam_size: Optional[int] = None
patience: Optional[float] = None
length_penalty: Optional[float] = None
prompt: Optional[Union[str, List[int]]] = None
prefix: Optional[Union[str, List[int]]] = None
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True
without_timestamps: bool = False
max_initial_timestamp: Optional[float] = 1.0
fp16: bool = True
frozen=True 可以防止解码过程中的意外修改。与之对应的 DecodingResult 负责承载输出结果:token 序列、文本、音频特征、概率分布,以及 compression_ratio、no_speech_prob 等质量指标。
值得注意的是,suppress_tokens 的默认值是字符串 "-1"——这个哨兵值会触发对 non_speech_tokens 集合的抑制(我们在第 3 篇文章中有过介绍)。这个默认设置非常实用,能防止模型生成音符、说话人标签等标注符号。
Inference 抽象与 KV-Cache
Inference 协议定义了带缓存的解码器前向传播接口。其唯一实现 PyTorchInference 负责管理 KV-cache:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
KV-cache 是 Whisper 中最精妙的实现细节,具体逻辑位于 model.py 的 install_kv_cache_hooks()。这里没有直接修改 attention 的 forward() 方法,而是在 key 和 value 的投影层上安装 PyTorch forward hook:
sequenceDiagram
participant D as Decoder
participant K as Key Projection
participant Hook as Forward Hook
participant Cache as KV Cache Dict
Note over D: First forward pass (all tokens)
D->>K: key_proj(x) → full key tensor
K->>Hook: save_to_cache(output)
Hook->>Cache: cache[key_proj] = output
Note over D: Subsequent passes (last token only)
D->>K: key_proj(x[:, -1:]) → single key
K->>Hook: save_to_cache(output)
Hook->>Cache: cache[key_proj] = cat(cached, output)
Hook-->>D: return concatenated keys
save_to_cache hook 函数处理两种情况:
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
cache[module] = output # first token or cross-attn
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
对于自注意力的 key/value,每个新 token 的投影结果会被拼接到已缓存的序列上。对于交叉注意力,完整的 key/value 张量(由编码器输出计算得到)在第一次前向传播时被缓存,后续复用不变——这正是编码器在每个 30 秒窗口内只运行一次的原因所在。
hook 的返回值至关重要:PyTorch forward hook 可以通过返回新张量来修改模块的输出。在这里,hook 返回完整的拼接缓存,因此注意力层看到的始终是完整的 key/value 序列,即便实际上只处理了一个 token。
提示: 缓存拼接时调用的
detach()会切断梯度在缓存中的累积,从而将推理期间的内存占用控制在合理范围内。
Token 解码器:贪心解码与 Beam Search
TokenDecoder 基类通过策略模式定义了 token 选择的接口,两个具体实现分别对应核心策略:
classDiagram
class TokenDecoder {
<<abstract>>
+reset()
+update(tokens, logits, sum_logprobs) → tokens, completed
+finalize(tokens, sum_logprobs) → candidates
}
class GreedyDecoder {
+float temperature
+int eot
+update(): argmax or sample
}
class BeamSearchDecoder {
+int beam_size
+float patience
+Inference inference
+update(): expand and prune beams
}
TokenDecoder <|-- GreedyDecoder
TokenDecoder <|-- BeamSearchDecoder
GreedyDecoder 的逻辑很直接:温度为 0 时取 argmax,否则从 Categorical(logits=logits/temperature) 中采样。它会累积对数概率,并在序列结束后传播 EOT token。
BeamSearchDecoder 则要复杂得多。它的 update() 方法按音频样本逐一处理:
- 对每条 beam,计算概率最高的
beam_size + 1个 token 的对数概率 - 对所有候选续写(现有 beam × 候选 token)进行打分
- 保留得分最高的
beam_size条非 EOT 候选;将 EOT 候选收入已完成序列 - 重新排列 KV-cache,使其与新的 beam 顺序保持一致
patience 机制(来自 arxiv:2204.05424)允许最多收集 round(beam_size * patience) 条已完成序列后再宣告结束,这样即使第一条 beam 已经完成,搜索仍可继续,从而找到更优的结果。
第 365 行的 rearrange_kv_cache() 调用,正是 beam search 与缓存机制交互的关键所在。当 beam 顺序发生变化时,缓存的 key/value 张量也必须随之重排——否则 beam 2 的缓存就会被错误地用于 beam 0 的后续生成。
LogitFilter 责任链
logit 过滤器流水线采用责任链模式实现。每个过滤器在解码器选取 token 之前就地修改 logits:
flowchart LR
A["Raw Logits\nfrom decoder"] --> B["SuppressBlank\n(prevent blank/EOT\nat position 0)"]
B --> C["SuppressTokens\n(mask non-speech\n& special tokens)"]
C --> D["ApplyTimestampRules\n(enforce timestamp\ngrammar)"]
D --> E["Filtered Logits\nto TokenDecoder"]
SuppressBlank 仅在第一个采样位置触发,将空格 token 和 EOT 的 logit 设为 -inf。没有这个过滤器,模型可能一开始就输出空白或直接结束序列。
SuppressTokens 无条件屏蔽一组指定的 token ID。默认抑制列表包含所有非语音 token(♪、[、( 等),以及在正常解码中不应出现的控制 token,如 <|translate|> 和 <|startofprev|>。
ApplyTimestampRules 是最复杂的过滤器,用于强制执行我们在第 3 篇文章中介绍的时间戳语法规则:
- 配对约束:在一个紧随非时间戳 token 的时间戳(即结束时间标记)之后,抑制所有文本 token——强制下一个输出要么是时间戳(下一段的开始),要么是 EOT。
- 成对时间戳之后:抑制所有时间戳 token——在段边界之后强制输出文本。
- 单调性约束:禁止出现小于上一个时间戳的时间戳,防止时间倒退。
- 初始约束:在第一个采样位置,抑制所有非时间戳 token,并强制执行
max_initial_timestamp限制。 - 概率比较:若所有时间戳 token 的概率之和超过最大文本 token 的概率,则抑制所有文本 token——这让模型在确信存在段边界时能够"主动选择"输出时间戳。
这个过滤器将一个无约束的文本生成器改造成了结构化的分段生产者。
DecodingTask:组装与主循环
DecodingTask 是整个系统的协调者。其 __init__ 方法根据 DecodingOptions 将所有组件组装在一起:
# inference: forward pass with KV caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: how to rank groups of samples
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: how to select next tokens
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(...)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: rules to suppress/penalize tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(...))
# ... etc
_main_loop() 实现了核心的自回归循环:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
logits = logits[:, -1]
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
注意 i == 0 时的无语音概率提取:它从 SOT 位置(而非最后一个位置)的 logits 中读取 <|nospeech|> token 的 softmax 概率。这是一次只发生在开头的单值提取,为转录循环提供语音活动检测信号。
run() 方法统筹整个流水线:编码音频 → 检测语言 → 为 beam/best-of 分组复制 token → 运行主循环 → 完成后处理 → 排序 → 构建结果。
语言检测
detect_language() 函数的实现堪称优雅——仅靠一个 token 就完成了语言检测。它以 <|startoftranscript|> 作为输入 token,让解码器运行一步,然后从输出 logits 中读取所有语言 token 的概率分布:
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)
logits = model.logits(x, mel)[:, 0]
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
在 SOT token 之后,模型期望下一个输出是语言 token。通过将所有非语言 token 的 logit 设为 -inf 并取 argmax,我们便能得到模型对所说语言的最优预测。对结果取 softmax,则可获得覆盖全部支持语言的完整概率分布。
这一步有意放在主解码循环之外,目的是避免污染 KV-cache——它是一次独立、干净的前向传播。
提示: 通过传入批量的 mel 张量,可以对多段音频同时进行语言检测。该函数同时支持单条(
ndim == 2)和批量(ndim == 3)输入。
下一篇
至此,我们已经完整地走过了整个解码系统——从配置到缓存,再到 token 选择。在第 5 篇文章中,我们将把视角拉远,聚焦于包裹这个解码器的转录循环:它通过滑动 30 秒窗口处理完整的长音频,在遇到退化输出时借助温度回退机制进行恢复,并识别幻觉输出。