Word-Level Timestamps: Cross-Attention Alignment, DTW, and Output Formatting

Advanced

Prerequisites

  • Articles 1-5
  • Basic understanding of Dynamic Time Warping
  • Familiarity with attention mechanisms

Word-level timestamps are Whisper's most technically impressive post-processing feature. Given a decoded text segment, the system determines exactly when each word was spoken — not from the decoder's timestamp tokens (which give segment-level boundaries), but from the cross-attention weights that reveal which audio frames the decoder was "looking at" when generating each text token.

This article covers the full pipeline: disabling SDPA to access raw attention weights, extracting QK matrices from specific attention heads, processing them through a median filter, aligning tokens to audio frames via Dynamic Time Warping (with both CPU and GPU implementations), merging punctuation, and formatting the results for subtitle output.

Cross-Attention Weight Extraction

The timing pipeline starts with a requirement that shapes the entire implementation: PyTorch's scaled_dot_product_attention (SDPA) is a fused kernel that does not return attention weights. When SDPA is active, qk is None (as we can see in model.py lines 123-128). To extract QK matrices, SDPA must be disabled.

The disable_sdpa() context manager in model.py handles this:

@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state

It toggles a class variable on MultiHeadAttention, switching all attention layers from the fused path to the explicit QK computation path that returns attention weights.
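
The key property is that the flag lives on the class, not on instances, so one toggle reaches every layer at once. A minimal PyTorch-free sketch of the same pattern (the `Attention` class, `use_fast` flag, and `disable_fast_path()` names below are illustrative stand-ins, not Whisper's real identifiers):

```python
from contextlib import contextmanager

class Attention:
    # Class-level flag shared by every instance, like MultiHeadAttention.use_sdpa
    use_fast = True

    def forward(self):
        return "fused" if Attention.use_fast else "explicit"

@contextmanager
def disable_fast_path():
    prev_state = Attention.use_fast
    try:
        Attention.use_fast = False
        yield
    finally:
        Attention.use_fast = prev_state  # restored even if the body raises

layers = [Attention() for _ in range(3)]
with disable_fast_path():
    inside = [layer.forward() for layer in layers]   # every layer: "explicit"
outside = [layer.forward() for layer in layers]      # every layer: "fused" again
```

Because the `finally` block restores the previous value, an exception during the forward pass cannot leave the model stuck on the slow path.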

Inside find_alignment(), forward hooks are installed on every cross-attention layer to capture the QK matrices:

QKs = [None] * model.dims.n_text_layer
hooks = [
    block.cross_attn.register_forward_hook(
        lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
    )
    for i, block in enumerate(model.decoder.blocks)
]

The full extraction sequence:

sequenceDiagram
    participant FA as find_alignment()
    participant SDPA as disable_sdpa()
    participant Model as Whisper Model
    participant Hooks as Cross-Attn Hooks
    participant QKs as QK Matrix List

    FA->>SDPA: Enter context (disable SDPA)
    FA->>Hooks: Install forward hooks on cross_attn layers
    FA->>Model: Forward pass (mel, tokens)
    Model->>Hooks: Each cross-attn fires hook
    Hooks->>QKs: Store QK matrix [heads × tokens × frames]
    FA->>Hooks: Remove all hooks
    FA->>SDPA: Exit context (restore SDPA)

Not all attention heads are equally informative for alignment. The alignment_heads mask (set from the base85-encoded _ALIGNMENT_HEADS data we saw in Article 1) identifies specific heads that correlate strongly with word timing. Only these heads are used:

weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])

Attention Weight Processing

The raw QK weights go through a processing pipeline before DTW alignment:

flowchart LR
    A["Raw QK weights\n[heads × tokens × frames]"] -->|"× qk_scale"| B["Scaled"]
    B -->|"softmax(dim=-1)"| C["Normalized"]
    C -->|"standardize\n(mean/std across tokens)"| D["Standardized"]
    D -->|"median_filter\n(width=7)"| E["Smoothed"]
    E -->|"mean across heads"| F["Final matrix\n[tokens × frames]"]

The standardization step (subtract mean, divide by std across the token dimension) normalizes each frame's attention distribution, making the alignment more robust. The median filter removes spiky noise from the attention weights.

The median_filter() function tries the Triton GPU implementation first, falling back to a CPU version using torch.sort():

result = None  # stays None if the Triton path is unavailable or fails
if x.is_cuda:
    try:
        from .triton_ops import median_filter_cuda
        result = median_filter_cuda(x, filter_width)
    except (RuntimeError, subprocess.CalledProcessError):
        warnings.warn("Failed to launch Triton kernels...")

if result is None:
    result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

The CPU fallback exploits a PyTorch quirk: sort() is faster than torch.median() for this use case (pytorch/pytorch#51450).
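
The unfold/sort trick is easy to sanity-check: a median filter is just "sort each sliding window, take the middle element". A pure-Python sketch of that definition (not the tensorized fallback itself):

```python
def median_filter_1d(xs, filter_width):
    """Median filter via sorted sliding windows.

    Mirrors unfold(-1, w, 1).sort()[0][..., w // 2] on a plain list;
    edges simply produce fewer outputs (no padding).
    """
    half = filter_width // 2
    return [
        sorted(xs[i : i + filter_width])[half]
        for i in range(len(xs) - filter_width + 1)
    ]

# An isolated spike is removed; a genuine plateau survives
print(median_filter_1d([0, 0, 9, 0, 0, 5, 5, 5, 0], 3))  # → [0, 0, 0, 0, 5, 5, 5]
```

This is exactly what makes the median filter suitable for attention weights: single-frame attention spikes vanish while sustained attention regions pass through unchanged.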

Dynamic Time Warping: CPU and GPU

DTW finds the optimal monotonic alignment between the token sequence and the audio frame sequence. Conceptually, it builds a cost matrix where cost[i,j] is the minimum cost of aligning the first i tokens with the first j frames, then backtraces to find the optimal path.

The dtw_cpu() implementation uses Numba JIT compilation for performance:

@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]  # diagonal
            c1 = cost[i - 1, j]      # vertical
            c2 = cost[i, j - 1]      # horizontal
            # pick the cheapest predecessor, record which move was taken
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t
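
The recurrence plus the backtrace can be sketched end to end in plain Python (a simplified stand-in for dtw_cpu() and backtrace(), without Numba; hypothetical names, not the library's API):

```python
import math

def dtw_with_backtrace(x):
    """Minimal DTW over a cost matrix x[token][frame]; returns the alignment path."""
    N, M = len(x), len(x[0])
    cost = [[math.inf] * (M + 1) for _ in range(N + 1)]
    trace = [[-1] * (M + 1) for _ in range(N + 1)]
    cost[0][0] = 0.0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1][j - 1]   # diagonal: advance token and frame
            c1 = cost[i - 1][j]       # vertical: advance token only
            c2 = cost[i][j - 1]       # horizontal: advance frame only
            c = min(c0, c1, c2)
            cost[i][j] = x[i - 1][j - 1] + c
            trace[i][j] = (c0, c1, c2).index(c)
    # Force moves along the borders, then walk back from the corner to (0, 0)
    for j in range(M + 1):
        trace[0][j] = 2
    for i in range(N + 1):
        trace[i][0] = 1
    i, j, path = N, M, []
    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        t = trace[i][j]
        if t == 0:
            i, j = i - 1, j - 1
        elif t == 1:
            i -= 1
        else:
            j -= 1
    return path[::-1]

# Cheap diagonal: each token aligns to its own frame
print(dtw_with_backtrace([[0, 9, 9], [9, 0, 9], [9, 9, 0]]))  # → [(0, 0), (1, 1), (2, 2)]
```

The returned path is monotonic in both indices, which is exactly the property word timing needs: a later token can never be aligned to an earlier frame.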

The GPU implementation dtw_cuda() uses a Triton kernel with anti-diagonal wavefront parallelization. The key insight: all cells on the same anti-diagonal (i + j = k) are independent of each other, so they can be computed in parallel. The dtw_kernel iterates over anti-diagonals k = 1 to N + M:

for k in range(1, N + M + 1):
    tl.debug_barrier()
    # Load costs from three predecessors
    c0 = tl.load(p0 + offsets, mask=mask)  # diagonal
    c1 = tl.load(p1 + offsets, mask=mask)  # from above
    c2 = tl.load(p2 + offsets, mask=mask)  # from left
    # Compute minimum and store
    cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

The matrix is "skewed" before the kernel (using F.pad and reshape) so that anti-diagonals become rows, enabling contiguous memory access.
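
The layout trick itself is simple to demonstrate with lists: shifting row i right by i cells places element (i, j) at column i + j, so each anti-diagonal becomes one contiguous column (an illustrative `skew()` helper, not the actual F.pad/reshape code):

```python
import math

def skew(x):
    """Shift row i right by i cells so anti-diagonal i + j = k becomes column k."""
    n_rows = len(x)
    return [
        [math.inf] * i + row + [math.inf] * (n_rows - 1 - i)  # inf-pad, as before the kernel
        for i, row in enumerate(x)
    ]

s = skew([[1, 2, 3],
          [4, 5, 6]])
# Column 2 of the skewed layout is anti-diagonal i + j = 2: x[0][2] and x[1][1]
col2 = [row[2] for row in s]   # → [3, 5]
```

With this layout, one wavefront step of the kernel reads a single column, giving the contiguous memory access that makes the anti-diagonal parallelization worthwhile.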

The dtw() dispatcher tries GPU first, falling back to CPU:

flowchart TD
    A["dtw(cost_matrix)"] --> B{"CUDA\navailable?"}
    B -->|Yes| C["dtw_cuda()\nTriton kernel"]
    C -->|"RuntimeError"| D["dtw_cpu()\nNumba JIT"]
    B -->|No| D
    C --> E["backtrace()\nExtract optimal path"]
    D --> E

The Triton Median Filter's Code Generation Trick

The most unusual pattern in the entire Whisper codebase lives in triton_ops.py's median_kernel(). Rather than writing a generic median filter kernel, it generates specialized kernel source code at runtime for each filter width.

The template kernel contains placeholder identifiers:

LOAD_ALL_ROWS_HERE    # replaced with N load statements
BUBBLESORT_HERE       # replaced with unrolled bubblesort comparisons
MIDDLE_ROW_HERE       # replaced with the median variable name

The median_kernel() function constructs the actual kernel by string-replacing these placeholders. For filter_width=7, this generates 7 load statements, ~28 compare-and-swap operations (an unrolled bubblesort), and stores row3 (the median):

new_kernel = kernel.src.replace(
    "    LOAD_ALL_ROWS_HERE",
    "\n".join([
        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    ]),
)

The generated bubblesort is just enough passes to guarantee the middle element is correct — filter_width // 2 + 1 passes of nested compare-and-swap:

f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f"    row{j} = smaller",
f"    row{j + 1} = larger",

The function then patches the JIT function's source via kernel._unsafe_update_src(new_kernel) (or direct kernel.src assignment for older Triton versions), and clears the hash to force recompilation. The @lru_cache ensures this code generation happens only once per filter width.

This is an example of staged metaprogramming motivated by Triton's constraints: Triton kernels must be statically shaped in certain dimensions, so dynamic filter widths require specialized kernels.
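
The pass count is worth verifying: each bubblesort pass settles the next-largest element into its final slot at the right end, so after filter_width // 2 + 1 passes, positions w // 2 through w - 1 hold the top half in sorted order, and index w // 2 is the true median. A scalar plain-Python check of the same schedule (the generated kernel performs these compare-and-swaps on whole rows of windows at once):

```python
import random

def partial_bubblesort_median(window):
    """Run just enough bubblesort passes to pin down the middle element."""
    w = len(window)
    xs = list(window)
    for p in range(w // 2 + 1):           # 4 passes for width 7
        for j in range(w - 1 - p):        # each pass settles one more maximum
            if xs[j] > xs[j + 1]:
                xs[j], xs[j + 1] = xs[j + 1], xs[j]
    return xs[w // 2]                     # middle slot now holds the median

print(partial_bubblesort_median([7, 1, 5, 3, 9, 2, 8]))  # → 5, the median

# Spot-check the pass-count claim against a full sort
random.seed(0)
windows = [[random.randint(0, 99) for _ in range(7)] for _ in range(1000)]
checks = all(partial_bubblesort_median(ws) == sorted(ws)[3] for ws in windows)
```

Stopping early is the point: a full sort would waste work ordering the bottom half of each window, which the median never needs.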

Tip: The _unsafe_update_src method name tells you everything about how supported this approach is. It works, but it's reaching deep into Triton's internals. If Triton's API changes, this is the first thing that will break.

Word Boundary Assignment and Heuristics

After DTW produces the token-to-frame alignment, find_alignment() converts it to word boundaries:

  1. Split tokens into words using tokenizer.split_to_word_tokens() (language-aware, as covered in Article 3)
  2. Compute word boundaries from cumulative token counts
  3. Extract start/end times from the DTW path at word boundary positions
  4. Compute word probabilities as the mean token probability within each word

The merge_punctuations() function then adjusts word boundaries to attach punctuation to adjacent words. Opening punctuation (", ', (, [, {, -) is prepended to the following word; closing punctuation (., ,, !, ?, ), ], etc.) is appended to the preceding word.
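
A simplified sketch of that merging on plain strings (the real merge_punctuations() mutates per-word dicts, merging their token lists and timing spans as well, and scans backwards for prepended marks first):

```python
def attach_punctuation(words, prepended="\"'([{-", appended="\"'.,!?:)]}"):
    """Attach free-standing punctuation tokens to the adjacent word."""
    merged = []
    for word in words:
        if merged and len(word) == 1 and word in appended:
            merged[-1] += word            # closing punctuation joins the previous word
        else:
            merged.append(word)
    result = []
    for word in merged:
        if result and len(result[-1]) == 1 and result[-1] in prepended:
            result[-1] += word            # opening punctuation joins the following word
        else:
            result.append(word)
    return result

print(attach_punctuation(['"', "Hello", ",", "world", "!", '"']))
# → ['"Hello,', 'world!"']
```

The effect on timing is what matters: a comma no longer occupies its own (often misaligned) time span but inherits the span of the word it belongs to.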

The add_word_timestamps() function applies additional boundary correction heuristics:

  • Max duration capping: Words at sentence boundaries (., !, ?) are capped at twice the median word duration. This prevents a period from "absorbing" a long silence after a sentence.
  • Segment boundary alignment: If the first word's start time is much earlier than the segment's start time, prefer the segment-level timestamp. Similarly for the last word's end time.
  • Post-pause correction: After a long pause (> 4× median duration), the first word's start is clamped to prevent it from spanning the entire pause.

These heuristics are self-described as "hacks" in the comments, with notes that a VAD-based segmentation algorithm should eventually replace them. But they're effective — they handle the most common DTW alignment artifacts.

flowchart TD
    A["find_alignment()\nDTW → raw word timings"] --> B["merge_punctuations()\nAttach . , ! ? to words"]
    B --> C["Compute median/max\nword duration"]
    C --> D["Truncate long words\nat sentence boundaries"]
    D --> E["Align word starts/ends\nwith segment boundaries"]
    E --> F["Post-pause correction\n(first word after silence)"]
    F --> G["Final word timestamps\nper segment"]
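
The max-duration capping step can be sketched as follows, assuming words arrive as (text, start, end) tuples (the real add_word_timestamps() works on dicts inside each segment and applies further conditions):

```python
from statistics import median

def cap_sentence_end_durations(words):
    """Cap words that end a sentence at twice the median word duration."""
    durations = [end - start for _, start, end in words]
    max_duration = 2 * median(durations)
    capped = []
    for text, start, end in words:
        if text.endswith((".", "!", "?")) and end - start > max_duration:
            end = start + max_duration   # stop the period from absorbing silence
        capped.append((text, start, end))
    return capped

# "world." originally spans a 2.6 s silence; it gets capped to 0.8 s
words = [("Hello", 0.0, 0.4), ("world.", 0.4, 3.0), ("Next", 3.0, 3.4)]
print(cap_sentence_end_durations(words))
```

Using the median (rather than the mean) keeps one pathological word from inflating the cap for the whole segment.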

Output Writers and Subtitle Formatting

The output system in whisper/utils.py uses a class hierarchy rooted in ResultWriter:

classDiagram
    class ResultWriter {
        +str extension
        +str output_dir
        +__call__(result, audio_path)
        +write_result(result, file)*
    }

    class WriteTXT {
        +extension = "txt"
    }

    class WriteJSON {
        +extension = "json"
    }

    class WriteTSV {
        +extension = "tsv"
    }

    class SubtitlesWriter {
        +bool always_include_hours
        +str decimal_marker
        +iterate_result() → yields (start, end, text)
    }

    class WriteVTT {
        +extension = "vtt"
    }

    class WriteSRT {
        +extension = "srt"
    }

    ResultWriter <|-- WriteTXT
    ResultWriter <|-- WriteJSON
    ResultWriter <|-- WriteTSV
    ResultWriter <|-- SubtitlesWriter
    SubtitlesWriter <|-- WriteVTT
    SubtitlesWriter <|-- WriteSRT

WriteTXT and WriteJSON are trivial. WriteTSV uses integer milliseconds to avoid locale-dependent decimal separators — a nice touch documented in its docstring.

The interesting logic lives in SubtitlesWriter.iterate_result(), which handles word-level highlighting. When highlight_words=True, each word gets its own subtitle cue with <u> tags around the active word:

00:00.640 --> 00:00.920
Hello <u>world</u> today

The method also handles line breaking with max_line_width, max_line_count, and max_words_per_line constraints, yielding subtitle blocks that respect these limits while keeping natural word groupings.
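
The cue times themselves come from a small formatting helper; a sketch consistent with format_timestamp() in utils.py (details may differ; SRT always includes hours and uses a comma, VTT omits zero hours and uses a period):

```python
def format_timestamp(seconds, always_include_hours=False, decimal_marker="."):
    """Render seconds as [HH:]MM:SS<marker>mmm for subtitle cue lines."""
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)
    hours, milliseconds = divmod(milliseconds, 3_600_000)
    minutes, milliseconds = divmod(milliseconds, 60_000)
    secs, milliseconds = divmod(milliseconds, 1_000)
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{secs:02d}{decimal_marker}{milliseconds:03d}"

print(format_timestamp(0.64))                                     # VTT-style cue time
print(format_timestamp(0.64, always_include_hours=True, decimal_marker=","))  # SRT-style
```

Rounding to integer milliseconds first, then splitting with divmod, guarantees the three fields never disagree about where a millisecond boundary falls.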

The get_writer() factory handles format selection. The "all" option creates all five writers and returns a composite function that calls each one — a simple implementation of the Composite pattern:

if output_format == "all":
    all_writers = [writer(output_dir) for writer in writers.values()]
    def write_all(result, file, options=None, **kwargs):
        for writer in all_writers:
            writer(result, file, options, **kwargs)
    return write_all

Series Conclusion

Over these six articles, we've traced every major code path in Whisper's ~2,500 lines:

  1. Architecture overview: The 9-module package, entry points, and method-binding pattern
  2. Audio frontend: FFmpeg loading, STFT, mel spectrograms, and the convolutional encoder stem
  3. Tokenization: Dual vocabularies, 1501 timestamp tokens, and the SOT sequence protocol
  4. Decoding: KV-cache via hooks, beam search, logit filter chain, and the main loop
  5. Transcription: Sliding window, temperature fallback, no-speech detection, hallucination recovery
  6. Word timestamps: Cross-attention extraction, DTW alignment, runtime code generation, and output formatting

What makes Whisper remarkable isn't any single technique — it's the engineering taste that keeps the codebase small without sacrificing capability. You won't find unnecessary abstractions, configuration frameworks, or plugin systems. Just a clear pipeline from audio to text, implemented with a few well-chosen design patterns and a lot of attention to the details that actually matter for robust speech recognition.