
The Decode Loop: Batching, KV Cache, and Memory Management

Advanced

Prerequisites

  • Articles 1-3
  • Understanding of KV caching in transformer inference


Articles 1–3 established the architecture: libllama builds GGML computation graphs from model definitions and executes them on hardware backends. But we glossed over the most operationally complex part of the system—the decode pipeline that manages batches, allocates KV cache slots, handles micro-batching, and recovers from failures.

This article traces the complete path from llama_decode() to logit extraction, covering context construction, the batch splitting pipeline, the process_ubatch() core loop, the polymorphic memory abstraction, KV cache internals, and error recovery.

Context Construction

When you call llama_init_from_model(), the llama_context constructor does several things:

  1. Merges parameters — The user's llama_context_params are merged with model hyperparameters. For example, if n_ctx is 0, it defaults to the model's training context length.

  2. Creates memory — Calls model.create_memory() to allocate the appropriate memory implementation (KV cache, recurrent state, or hybrid). The memory type depends entirely on the model architecture.

  3. Creates the batch allocator — Initializes llama_batch_allocr, which handles batch sanitization and splitting.

  4. Reserves compute buffers — Builds a worst-case graph (using memory->init_full()) and runs it through the backend scheduler to determine the maximum compute buffer sizes needed.

The context params that shape inference most significantly are:

| Parameter | Effect |
|---|---|
| n_ctx | Total context window size (tokens the KV cache can hold) |
| n_batch | Maximum logical batch size (what the user submits) |
| n_ubatch | Maximum physical batch size (what hits the GPU per graph execution) |
| n_seq_max | Maximum concurrent sequences |
| type_k, type_v | KV cache data types (F16, Q8_0, Q4_0, etc.) |
| flash_attn_type | Whether to use flash attention kernels |

flowchart TD
    PARAMS["llama_context_params"] --> CTX["llama_context constructor"]
    CTX --> MERGE["Merge with hparams"]
    CTX --> MEM["model.create_memory()"]
    CTX --> BALLOC["Create batch_allocr"]
    CTX --> SCHED["Reserve compute buffers\n(worst-case graph)"]
    MEM --> KV["llama_kv_cache"]
    MEM --> REC["llama_memory_recurrent"]
    MEM --> HYB["llama_memory_hybrid"]
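The parameter merging in step 1 can be sketched as a pure function. This is a minimal illustration, not the actual llama.cpp code: `ctx_params` and `merge_params` are made-up stand-ins, and the clamp of n_ubatch to n_batch is an assumption about the merge logic.

```cpp
#include <cstdint>

// Illustrative stand-in for the relevant llama_context_params fields.
struct ctx_params { uint32_t n_ctx, n_batch, n_ubatch; };

// Sketch of the merge: user params combined with model hyperparameters.
ctx_params merge_params(ctx_params user, uint32_t n_ctx_train) {
    ctx_params out = user;
    // n_ctx == 0 means "default to the model's training context length"
    if (out.n_ctx == 0) out.n_ctx = n_ctx_train;
    // a micro-batch can never exceed the logical batch (assumption)
    if (out.n_ubatch == 0 || out.n_ubatch > out.n_batch) out.n_ubatch = out.n_batch;
    return out;
}
```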

The Batch Pipeline

Before any computation happens, the user's llama_batch must be sanitized, validated, and split into micro-batches. This pipeline has three stages:

Stage 1: Sanitization. The llama_batch_allocr takes the raw user batch and auto-fills missing fields. If pos is NULL, positions are auto-generated based on sequence continuity in the memory module. If seq_id is NULL, all tokens default to sequence 0. If logits is NULL, only the last token gets output.

Stage 2: Memory initialization. The memory module's init_batch() is called, which splits the batch into ubatches that fit within n_ubatch and validates that the KV cache (or recurrent state) has room. If it fails, the context may attempt defragmentation and retry.

Stage 3: Ubatch iteration. The decode loop iterates through ubatches via the llama_memory_context_i interface, calling next() to advance and get_ubatch() to retrieve each one.
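Stage 1's defaulting rules can be sketched in isolation. The `sane_batch` type and `sanitize()` helper below are illustrative stand-ins (not llama.cpp types), assuming the simplest case: a single sequence where the user left pos, seq_id, and logits unset.

```cpp
#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// Hypothetical fully-sanitized batch: every field filled in.
struct sane_batch {
    std::vector<llama_pos>    pos;     // one position per token
    std::vector<llama_seq_id> seq_id;  // one sequence id per token
    std::vector<int8_t>       output;  // 1 where logits are requested
};

// Sketch of stage-1 defaulting for n_tokens tokens, where p0 is the
// position following the last token already in memory for sequence 0.
sane_batch sanitize(uint32_t n_tokens, llama_pos p0) {
    sane_batch b;
    for (uint32_t i = 0; i < n_tokens; ++i) {
        b.pos.push_back(p0 + (llama_pos) i);            // continue the sequence
        b.seq_id.push_back(0);                          // default: sequence 0
        b.output.push_back(i == n_tokens - 1 ? 1 : 0);  // only last token
    }
    return b;
}
```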

The llama_ubatch structure is the internal batch representation used throughout the decode pipeline:

struct llama_ubatch {
    uint32_t n_tokens;      // total tokens
    uint32_t n_seq_tokens;  // tokens per sequence set
    uint32_t n_seqs;        // sequence sets
    uint32_t n_seqs_unq;    // unique sequences
    
    llama_token  * token;   // [n_tokens]
    llama_pos    * pos;     // [n_tokens * n_pos]
    llama_seq_id ** seq_id; // [n_tokens]
    int8_t       * output;  // [n_tokens] which positions need logits
};

Tip: The distinction between n_batch and n_ubatch is critical for large batch sizes. A user might submit 512 tokens at once (n_batch), but the GPU may only handle 128 at a time (n_ubatch). The batch pipeline transparently splits the work into 4 ubatches.
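The size arithmetic behind that split can be sketched as follows. This is illustrative only: the real splitter also has to group tokens by sequence and respect memory-module constraints, which this ignores.

```cpp
#include <cstdint>
#include <vector>

// Cut a logical batch of n_tokens into micro-batch sizes of at most
// n_ubatch tokens each (simplified: ignores sequence grouping).
std::vector<uint32_t> ubatch_sizes(uint32_t n_tokens, uint32_t n_ubatch) {
    std::vector<uint32_t> sizes;
    while (n_tokens > 0) {
        uint32_t n = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        sizes.push_back(n);
        n_tokens -= n;
    }
    return sizes;
}
```

With the tip's numbers, 512 tokens at n_ubatch = 128 yields four micro-batches of 128.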

process_ubatch: The Core Loop

The central decode loop lives in llama_context::decode(). After batch preparation and memory initialization, it enters a do/while loop:

sequenceDiagram
    participant D as decode()
    participant M as memory_context
    participant P as process_ubatch()
    participant G as model.build_graph()
    participant S as Backend Scheduler

    D->>M: init_batch(balloc, n_ubatch)
    loop For each ubatch
        D->>M: get_ubatch()
        D->>P: process_ubatch(ubatch, mctx)
        P->>M: apply() — commit to memory
        P->>P: Check graph reuse
        alt Graph can be reused
            P->>P: set_inputs() only
        else Build new graph
            P->>G: build_graph(params)
            P->>S: alloc_graph(gf)
        end
        P->>P: set_inputs(ubatch)
        P->>S: graph_compute(gf)
        P-->>D: logits
        D->>D: Extract logits/embeddings
        D->>M: next() — advance to next ubatch
    end

The process_ubatch() method is where graph building meets execution:

  1. Apply memory state — mctx->apply() commits the current ubatch's KV cache slot assignments to the memory module. This is the only point where memory is mutated.

  2. Check graph reuse — If the previous graph's topology matches (same ubatch shape, same configuration), skip graph building and scheduler allocation entirely.

  3. Build or reuse graph — Either construct a new graph via model.build_graph() and allocate through the scheduler, or reuse the previous graph.

  4. Set inputs — Populate all input tensors (token IDs, positions, masks, KV indices) from the ubatch data.

  5. Compute — Execute the graph through ggml_backend_sched_graph_compute().
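The reuse check in step 2 amounts to comparing the topology-relevant shape of the new ubatch against the previous one. The fields below are illustrative, not the actual criteria; the real check also considers attention configuration and memory-module state.

```cpp
#include <cstdint>

// Hypothetical shape key for a built graph.
struct graph_shape {
    uint32_t n_tokens;
    uint32_t n_seqs;
    bool     causal;

    bool operator==(const graph_shape & o) const {
        return n_tokens == o.n_tokens && n_seqs == o.n_seqs && causal == o.causal;
    }
};

// If shapes match, skip build_graph() and scheduler allocation and
// only refresh the input tensors.
bool can_reuse_graph(const graph_shape & prev, const graph_shape & cur) {
    return prev == cur;
}
```

This matters most in steady-state generation, where every decode step is a one-token ubatch with an identical shape.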

Memory Abstraction Hierarchy

As we saw in Article 1, the three critical types govern inference. But there's a fourth player: the memory interface. llama_memory_i is the abstract base for all memory implementations:

struct llama_memory_i {
    virtual llama_memory_context_ptr init_batch(...) = 0;
    virtual llama_memory_context_ptr init_full() = 0;
    virtual llama_memory_context_ptr init_update(...) = 0;
    virtual bool get_can_shift() const = 0;
    virtual void seq_rm(llama_seq_id, llama_pos, llama_pos) = 0;
    // ...
};

The create_memory() factory in llama_model selects the right implementation:

classDiagram
    class llama_memory_i {
        <<interface>>
        +init_batch()
        +init_full()
        +init_update()
        +seq_rm()
    }
    class llama_kv_cache {
        Standard transformer KV cache
        Cell-based slot allocation
    }
    class llama_kv_cache_iswa {
        Interleaved sliding window
        Two sub-caches (base + SWA)
    }
    class llama_memory_recurrent {
        SSM/RWKV state buffers
        Fixed slots per sequence
    }
    class llama_memory_hybrid {
        KV cache + recurrent state
        Per-layer type selection
    }
    class llama_memory_hybrid_iswa {
        iSWA cache + recurrent state
    }
    llama_memory_i <|-- llama_kv_cache
    llama_memory_i <|-- llama_kv_cache_iswa
    llama_memory_i <|-- llama_memory_recurrent
    llama_memory_i <|-- llama_memory_hybrid
    llama_memory_i <|-- llama_memory_hybrid_iswa

The selection logic is straightforward:

  • Models without caching (BERT, DREAM, etc.) → nullptr
  • Pure recurrent models (Mamba, RWKV) → llama_memory_recurrent
  • Hybrid models (Jamba, Falcon-H1) → llama_memory_hybrid or llama_memory_hybrid_iswa
  • Models with sliding window attention → llama_kv_cache_iswa
  • Standard transformers → llama_kv_cache
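Those selection rules can be written as a plain decision function. The enum and boolean inputs below are illustrative stand-ins; the real create_memory() factory inspects the architecture and hyperparameters on llama_model.

```cpp
// Hypothetical memory-kind enum mirroring the implementations above.
enum class memory_kind { none, recurrent, hybrid, kv_iswa, kv };

// Sketch of the selection logic; hybrid must be checked before
// recurrent, since hybrid models have recurrent layers too.
memory_kind select_memory(bool has_cache, bool recurrent, bool hybrid, bool swa) {
    if (!has_cache) return memory_kind::none;      // e.g. BERT
    if (hybrid)     return memory_kind::hybrid;    // e.g. Jamba, Falcon-H1
    if (recurrent)  return memory_kind::recurrent; // e.g. Mamba, RWKV
    if (swa)        return memory_kind::kv_iswa;   // sliding-window attention
    return memory_kind::kv;                        // standard transformers
}
```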

KV Cache Deep Dive

The llama_kv_cache is the most commonly used memory implementation. It manages a fixed-size buffer of KV pairs organized as "cells"—one cell per position.

The llama_kv_cells data structure tracks metadata for each cell:

  • pos[] — the sequence position stored in this cell (-1 if empty)
  • seq[] — bitset of which sequences use this cell (cells can be shared)
  • shift[] — pending position shift for context shifting
  • used — set of cell indices that are occupied
  • seq_pos[] — per-sequence maps of position → cell index

When init_batch() is called, the KV cache must find contiguous (or at least available) slots for the new tokens. It creates a slot_info structure containing the cell indices where each token's KV pair will be stored. The indices are passed to the computation graph as llm_graph_input_attn_kv's self_k_idxs and self_v_idxs tensors—these tell the ggml_set_rows operations where to write new K/V values.

Defragmentation is triggered when memory_update(optimize=true) is called. This compacts the KV cache by moving cells to eliminate gaps, building and executing a dedicated defragmentation graph through the backend scheduler. The decode loop attempts defragmentation automatically when init_batch() fails with LLAMA_MEMORY_STATUS_FAILED_PREPARE.

flowchart LR
    subgraph "KV Cache Cells"
        C0["Cell 0\nSeq 0, Pos 0"]
        C1["Cell 1\nSeq 0, Pos 1"]
        C2["Cell 2\n(empty)"]
        C3["Cell 3\nSeq 1, Pos 0"]
        C4["Cell 4\nSeq 0, Pos 2"]
        C5["Cell 5\n(empty)"]
        C6["Cell 6\nSeq 1, Pos 1"]
    end
    
    subgraph "After Defrag"
        D0["Cell 0\nSeq 0, Pos 0"]
        D1["Cell 1\nSeq 0, Pos 1"]
        D2["Cell 2\nSeq 0, Pos 2"]
        D3["Cell 3\nSeq 1, Pos 0"]
        D4["Cell 4\nSeq 1, Pos 1"]
        D5["Cell 5\n(empty)"]
        D6["Cell 6\n(empty)"]
    end

Error Recovery and State Rollback

The decode pipeline is designed to handle failures gracefully. The llama_memory_context_i interface enforces an apply/commit protocol:

struct llama_memory_context_i {
    virtual bool next() = 0;      // advance to next ubatch
    virtual bool apply() = 0;     // commit current ubatch to memory
    virtual const llama_ubatch & get_ubatch() const = 0;
    virtual llama_memory_status get_status() const = 0;
};

The apply() method is the only point where memory state is mutated—and it happens inside process_ubatch() before graph computation. If graph computation fails afterward (allocation error, abort callback, etc.), the decode loop handles rollback by removing the positions that were committed:

// from decode() error handling (simplified): find the lowest committed
// position per sequence, then remove everything from there onward
llama_pos pos_min[LLAMA_MAX_SEQ];
std::fill_n(pos_min, LLAMA_MAX_SEQ, std::numeric_limits<llama_pos>::max());

for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
    const llama_seq_id seq_id = ubatch.seq_id[i][0];
    pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
    if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
        continue; // sequence untouched by this ubatch
    }
    memory->seq_rm(s, pos_min[s], -1);
}

This approach—found in llama-context.cpp lines 1690–1711—removes all positions from pos_min onward for each affected sequence, ensuring the KV cache doesn't contain stale data from a partially processed batch.

The decode function returns different error codes for different failure modes:

  • 0 — success
  • 1 — cache full (user should free sequences or reduce batch)
  • 2 — aborted by callback
  • -2 — allocation or internal error
  • -3 — compute failure

Tip: If you're getting return code 1 from llama_decode(), the KV cache is full. Either call llama_memory_seq_rm() to free old sequences, reduce n_predict, or increase n_ctx when creating the context.
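A caller-side dispatch on these return codes might look like the sketch below. The helper itself is illustrative, not part of llama.h; only the code values come from the list above.

```cpp
#include <string>

// Map a llama_decode()-style return code to a human-readable action.
std::string describe_decode_result(int ret) {
    if (ret == 0)  return "ok";
    if (ret == 1)  return "kv cache full: free sequences or shrink the batch";
    if (ret == 2)  return "aborted by the user callback";
    if (ret == -2) return "allocation or internal error";
    if (ret == -3) return "compute failure";
    return "unknown error";
}
```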

What's Next

We've now traced the complete path from a llama_decode() call down to KV cache cell allocation and back. The next article shifts from the library layer to the application layer: how the HTTP server and command-line tools wrap libllama to serve real users, and the surprising architectural decision that makes the CLI a thin wrapper around the server's infrastructure.