Read OSS

Inside Whisper's Decoder: Beam Search, Logit Filters, and the KV-Cache

Advanced

Prerequisites

  • Articles 1-3
  • Understanding of beam search and temperature sampling
  • PyTorch forward hooks

At 826 lines, whisper/decoding.py is the largest module in the codebase, and for good reason — it orchestrates the autoregressive generation that turns encoder features into token sequences. This module is where Whisper's design philosophy shines: clean abstractions (Strategy pattern for decoders, chain-of-responsibility for logit filters) composed by a single orchestrator (DecodingTask).

This article breaks down every major component: the configuration dataclasses, the KV-cache implementation via forward hooks, the token decoder hierarchy, the logit filter pipeline, and the main loop that ties it all together.

DecodingOptions and DecodingResult

The decoding system is configured through a frozen dataclass, DecodingOptions, which controls everything from sampling temperature to timestamp behavior:

from dataclasses import dataclass
from typing import Iterable, List, Optional, Union

@dataclass(frozen=True)
class DecodingOptions:
    task: str = "transcribe"
    language: Optional[str] = None
    temperature: float = 0.0
    sample_len: Optional[int] = None
    best_of: Optional[int] = None
    beam_size: Optional[int] = None
    patience: Optional[float] = None
    length_penalty: Optional[float] = None
    prompt: Optional[Union[str, List[int]]] = None
    prefix: Optional[Union[str, List[int]]] = None
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True
    without_timestamps: bool = False
    max_initial_timestamp: Optional[float] = 1.0
    fp16: bool = True

The frozen=True prevents accidental mutation during decoding. The corresponding DecodingResult captures the output: tokens, text, audio features, probabilities, and quality metrics like compression_ratio and no_speech_prob.
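
A standalone sketch (with a hypothetical two-field dataclass, not whisper's actual class) of what frozen=True buys you — a stray mutation mid-decode fails loudly instead of silently changing behavior:

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Optional

@dataclass(frozen=True)
class Options:
    temperature: float = 0.0
    beam_size: Optional[int] = None

opts = Options(beam_size=5)
try:
    opts.temperature = 1.0  # accidental mutation during decoding...
except FrozenInstanceError:
    print("immutable")      # ...raises instead of corrupting the run's config
```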

Note the suppress_tokens default of "-1" — this string sentinel triggers suppression of the non_speech_tokens set we covered in Article 3. It's a convenient default that prevents the model from generating music notes, speaker tags, and similar annotations.

The Inference Abstraction and KV-Cache

The Inference protocol defines the interface for running the decoder forward pass with caching. Its sole implementation, PyTorchInference, manages the KV-cache:

def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
    if not self.kv_cache:
        self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

    if tokens.shape[-1] > self.initial_token_length:
        # only need to use the last token except in the first forward pass
        tokens = tokens[:, -1:]

    return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

The KV-cache is one of Whisper's cleverest implementation details, found in model.py's install_kv_cache_hooks(). Instead of modifying the attention forward() method, it installs PyTorch forward hooks on the key and value projection layers:

sequenceDiagram
    participant D as Decoder
    participant K as Key Projection
    participant Hook as Forward Hook
    participant Cache as KV Cache Dict

    Note over D: First forward pass (all tokens)
    D->>K: key_proj(x) → full key tensor
    K->>Hook: save_to_cache(output)
    Hook->>Cache: cache[key_proj] = output

    Note over D: Subsequent passes (last token only)
    D->>K: key_proj(x[:, -1:]) → single key
    K->>Hook: save_to_cache(output)
    Hook->>Cache: cache[key_proj] = cat(cached, output)
    Hook-->>D: return concatenated keys

The save_to_cache hook function handles two cases:

def save_to_cache(module, _, output):
    if module not in cache or output.shape[1] > self.dims.n_text_ctx:
        cache[module] = output                                    # first token or cross-attn
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]

For self-attention keys/values, each new token's projection is concatenated to the cached sequence. For cross-attention, the full key/value tensors (computed from encoder output) are cached on the first pass and reused unchanged — this is why the encoder only runs once per 30-second window.

The hook's return value is critical: PyTorch forward hooks can modify a module's output by returning a new tensor. Here, the hook returns the full concatenated cache, so the attention layer sees the complete key/value sequence even though only one token was processed.

Tip: The detach() call in the cache concatenation prevents gradient accumulation through the cache, keeping memory usage bounded during inference.
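
The mechanism can be reproduced in isolation. This toy sketch (hypothetical names and dimensions, not whisper's API) hooks a single "key projection" layer and shows the returned tensor growing as tokens arrive one at a time:

```python
import torch
import torch.nn as nn

n_ctx = 8   # stand-in for dims.n_text_ctx
cache = {}

def save_to_cache(module, _, output):
    # first pass (or an oversized cross-attention tensor): store as-is
    if module not in cache or output.shape[1] > n_ctx:
        cache[module] = output
    else:
        # later passes: append the new token's projection to the cache
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]  # returning replaces the module's output

key_proj = nn.Linear(4, 4, bias=False)
hook = key_proj.register_forward_hook(save_to_cache)

out1 = key_proj(torch.randn(1, 3, 4))  # first pass: 3 tokens
out2 = key_proj(torch.randn(1, 1, 4))  # next pass: last token only

print(out1.shape, out2.shape)  # the hook grew the output from 3 to 4 positions
hook.remove()
```

Even though the second call processed a single token, the caller receives a 4-position key tensor — exactly what the attention layer needs.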

Token Decoders: Greedy and Beam Search

The TokenDecoder base class defines the Strategy pattern for token selection. Two implementations handle the core strategies:

classDiagram
    class TokenDecoder {
        <<abstract>>
        +reset()
        +update(tokens, logits, sum_logprobs) → tokens, completed
        +finalize(tokens, sum_logprobs) → candidates
    }

    class GreedyDecoder {
        +float temperature
        +int eot
        +update(): argmax or sample
    }

    class BeamSearchDecoder {
        +int beam_size
        +float patience
        +Inference inference
        +update(): expand and prune beams
    }

    TokenDecoder <|-- GreedyDecoder
    TokenDecoder <|-- BeamSearchDecoder

GreedyDecoder is straightforward: at temperature 0 it takes argmax, otherwise it samples from Categorical(logits=logits/temperature). It accumulates log probabilities and propagates the EOT token once a sequence has ended.
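
The core of that split fits in a few lines. A minimal sketch (a hypothetical helper function, not whisper's class):

```python
import torch
from torch.distributions import Categorical

def next_token(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    if temperature == 0:
        return logits.argmax(dim=-1)  # deterministic: take the top logit
    # stochastic: sample from the temperature-scaled distribution
    return Categorical(logits=logits / temperature).sample()

logits = torch.tensor([[0.1, 2.0, -1.0]])
print(next_token(logits, 0.0))  # tensor([1])
```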

BeamSearchDecoder is more involved. Its update() method works per-audio-sample:

  1. For each beam, compute log probabilities of the top beam_size + 1 tokens
  2. Score all candidate continuations (existing beams × top tokens)
  3. Keep the top beam_size non-EOT candidates; collect EOT candidates as finished
  4. Rearrange the KV-cache to match the new beam ordering
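
A one-step sketch of the expand-and-prune arithmetic, using made-up numbers and a vocabulary of 4 (the real code restricts scoring to the top beam_size + 1 tokens per beam; this sketch scores the whole vocabulary for brevity):

```python
import torch

beam_size, vocab = 2, 4
# made-up per-beam next-token log-probabilities and running scores
logprobs = torch.tensor([[-0.1, -2.0, -3.0, -4.0],
                         [-1.0, -0.5, -2.5, -3.5]])
sum_logprobs = torch.tensor([0.0, -0.2])

# score every (beam, token) continuation, keep the best beam_size overall
scores = (sum_logprobs[:, None] + logprobs).flatten()
top = scores.topk(beam_size)
source_beam = top.indices // vocab  # which beam each survivor extends
next_token = top.indices % vocab    # the token that extends it

print(source_beam.tolist(), next_token.tolist())  # [0, 1] [0, 1]
```

Here source_beam is exactly the permutation that rearrange_kv_cache() must apply to the cached key/value tensors before the next forward pass.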

The patience mechanism (from arXiv:2204.05424) collects up to round(beam_size * patience) finished sequences before declaring completion, letting the search continue past the first finished beam in case a better hypothesis appears.

The rearrange_kv_cache() call at line 365 is where beam search interacts with caching. When beams are reordered, the cached key/value tensors must be shuffled to match — otherwise beam 2's cache would be used for beam 0's continuation.

LogitFilter Chain

The logit filter pipeline implements a chain-of-responsibility pattern. Each filter modifies logits in-place before the decoder selects tokens:

flowchart LR
    A["Raw Logits\nfrom decoder"] --> B["SuppressBlank\n(prevent blank/EOT\nat position 0)"]
    B --> C["SuppressTokens\n(mask non-speech\n& special tokens)"]
    C --> D["ApplyTimestampRules\n(enforce timestamp\ngrammar)"]
    D --> E["Filtered Logits\nto TokenDecoder"]

SuppressBlank fires only at the first sampling position, setting logits for space tokens and EOT to -inf. Without this, the model might immediately produce a blank or end the sequence.

SuppressTokens unconditionally masks a set of token IDs. The default suppress list includes all non-speech tokens (♪, [, (, etc.) plus control tokens like <|translate|> and <|startofprev|> that should never be generated during normal decoding.
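
The masking itself is a one-liner. A sketch with hypothetical token ids (the real filter mutates logits in-place, matching the LogitFilter.apply contract):

```python
import torch

def suppress_tokens(logits: torch.Tensor, token_ids: list) -> None:
    # in-place: suppressed tokens can never win argmax or sampling
    logits[:, token_ids] = -float("inf")

logits = torch.tensor([[0.5, 3.0, 1.0, 2.0]])
suppress_tokens(logits, [1, 3])  # e.g. hypothetical ids for ♪ and <|translate|>
print(logits.argmax(dim=-1))     # tensor([2]) — the best unsuppressed token
```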

ApplyTimestampRules is the most complex filter, enforcing the timestamp grammar we described in Article 3. Its rules:

  1. Pairing: After a timestamp that follows a non-timestamp (i.e., an end-time marker), suppress all text tokens — force the next output to be either a timestamp (start of next segment) or EOT.
  2. After paired timestamps: Suppress all timestamp tokens — force text output after a segment boundary.
  3. Monotonicity: Forbid timestamps smaller than the last one, preventing backward jumps.
  4. Initial constraint: At the first sampling position, suppress all non-timestamp tokens and enforce max_initial_timestamp.
  5. Probability comparison: If the sum of all timestamp token probabilities exceeds the max text token probability, suppress all text tokens — this lets the model "choose" to emit a timestamp when it's confident a segment boundary exists.

This filter turns what would be an unconstrained text generator into a structured segment producer.
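
Rule 5 can be sketched in isolation. Assume a toy vocabulary where ids ≥ 3 are timestamp tokens (timestamp_begin is hypothetical here; the real value comes from the tokenizer):

```python
import torch

timestamp_begin = 3  # hypothetical: token ids >= 3 are timestamps
logits = torch.tensor([1.0, 0.5, 0.2, 1.2, 1.1])
logprobs = logits.log_softmax(dim=-1)

# total timestamp probability mass vs. the single best text token
timestamp_logprob = logprobs[timestamp_begin:].logsumexp(dim=-1)
max_text_logprob = logprobs[:timestamp_begin].max()
if timestamp_logprob > max_text_logprob:
    logits[:timestamp_begin] = -float("inf")  # force a timestamp

print(logits.argmax())  # tensor(3)
```

No single timestamp token beats the best text token here, but collectively the timestamps do — so the filter commits the model to emitting one.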

DecodingTask: Assembly and Main Loop

DecodingTask is the orchestrator. Its __init__ assembles all components based on DecodingOptions:

# inference: forward pass with KV caching
self.inference = PyTorchInference(model, len(self.initial_tokens))

# sequence ranker: how to rank groups of samples
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

# decoder: how to select next tokens
if options.beam_size is not None:
    self.decoder = BeamSearchDecoder(...)
else:
    self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

# logit filters: rules to suppress/penalize tokens
self.logit_filters = []
if self.options.suppress_blank:
    self.logit_filters.append(SuppressBlank(...))
# ... etc

The _main_loop() implements the core autoregressive cycle:

for i in range(self.sample_len):
    logits = self.inference.logits(tokens, audio_features)

    if i == 0 and self.tokenizer.no_speech is not None:
        probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
        no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

    logits = logits[:, -1]

    for logit_filter in self.logit_filters:
        logit_filter.apply(logits, tokens)

    tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

    if completed or tokens.shape[-1] > self.n_ctx:
        break

Notice the no-speech probability extraction at i == 0: it reads the softmax probability of the <|nospeech|> token from the logits at the SOT position (not the last position). This is a single-value extraction that happens only once, providing a voice activity detection signal for the transcription loop.

The run() method orchestrates the full pipeline: encode audio → detect language → repeat tokens for beam/best-of grouping → run main loop → finalize → rank → build results.

Language Detection

The detect_language() function is an elegant single-token trick. It feeds just <|startoftranscript|> as the input token, runs the decoder for a single step, and reads the probability distribution over all language tokens from the output logits:

import numpy as np
import torch

x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)
logits = model.logits(x, mel)[:, 0]

mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)

After the SOT token, the model expects a language token. By masking all non-language tokens to -inf and taking the argmax, we get the model's best guess at the spoken language. A softmax over the masked logits additionally yields a full probability distribution across the supported languages.
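
The masking trick in miniature, with a hypothetical 6-token vocabulary where ids 2 and 4 stand in for language tokens:

```python
import torch

vocab = 6
language_token_ids = [2, 4]  # hypothetical ids for two language tokens
logits = torch.tensor([[0.3, 2.0, 1.0, 0.1, 1.5, 0.0]])

mask = torch.ones(vocab, dtype=torch.bool)
mask[language_token_ids] = False  # True = suppress this token
logits[:, mask] = -float("inf")

best = logits.argmax(dim=-1)   # most probable language token
probs = logits.softmax(dim=-1) # zero everywhere except language ids
print(best.item())  # 4
```

Note that the non-language token at id 1 had the highest raw logit; masking removes it from consideration entirely.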

This is deliberately done outside the main decode loop to avoid polluting the KV-cache — it's a clean, separate forward pass.

Tip: You can batch language detection across multiple audio segments by passing a batched mel tensor. The function handles both single (ndim == 2) and batched (ndim == 3) inputs.

What's Next

We've covered the complete decoding system — from configuration through caching to token selection. In Article 5, we'll zoom out to the transcription loop that wraps this decoder, processing full-length audio files by sliding a 30-second window, recovering from degenerate outputs with temperature fallback, and detecting hallucinations.