Read OSS

解码循环:批处理、KV 缓存与内存管理

高级

前置知识

  • 第 1–3 篇文章
  • 了解 transformer 推理中的 KV 缓存机制

解码循环:批处理、KV 缓存与内存管理

前三篇文章奠定了架构基础:libllama 根据模型定义构建 GGML 计算图,并在硬件后端上执行。但我们略过了系统中操作复杂度最高的部分——负责管理批次、分配 KV 缓存槽位、处理微批次(micro-batch)以及从故障中恢复的解码流水线。

本文将完整追踪从 llama_decode() 到 logit 提取的全过程,内容涵盖:上下文构建、批次拆分流水线、process_ubatch() 核心循环、多态内存抽象、KV 缓存内部机制,以及错误恢复策略。

上下文构建

调用 llama_init_from_model() 时,llama_context 的构造函数 会依次执行以下操作:

  1. 合并参数 — 将用户传入的 llama_context_params 与模型超参数合并。例如,若 n_ctx 为 0,则默认使用模型的训练上下文长度。

  2. 创建内存 — 调用 model.create_memory(),根据模型架构分配对应的内存实现(KV 缓存、循环状态或混合模式)。

  3. 创建批次分配器 — 初始化 llama_batch_allocr,负责批次的规范化与拆分。

  4. 预留计算缓冲区 — 构建最坏情况下的计算图(通过 memory->init_full()),并通过后端调度器运行,以确定所需的最大计算缓冲区大小。

对推理影响最显著的上下文参数如下:

参数 作用
n_ctx 总上下文窗口大小(KV 缓存可容纳的 token 数)
n_batch 最大逻辑批次大小(用户提交的批次上限)
n_ubatch 最大物理批次大小(每次图执行实际送入 GPU 的 token 数)
n_seq_max 最大并发序列数
type_k, type_v KV 缓存数据类型(F16、Q8_0、Q4_0 等)
flash_attn_type 是否使用 flash attention kernel
flowchart TD
    PARAMS["llama_context_params"] --> CTX["llama_context constructor"]
    CTX --> MERGE["Merge with hparams"]
    CTX --> MEM["model.create_memory()"]
    CTX --> BALLOC["Create batch_allocr"]
    CTX --> SCHED["Reserve compute buffers\n(worst-case graph)"]
    MEM --> KV["llama_kv_cache"]
    MEM --> REC["llama_memory_recurrent"]
    MEM --> HYB["llama_memory_hybrid"]

批次流水线

在任何计算开始之前,用户提交的 llama_batch 必须经过规范化、校验,并拆分为若干微批次。整个流水线分为三个阶段:

阶段一:规范化。 llama_batch_allocr 接收用户的原始批次,并自动填充缺失字段。若 pos 为 NULL,则根据内存模块中的序列连续性自动生成位置信息;若 seq_id 为 NULL,所有 token 默认归属序列 0;若 logits 为 NULL,则仅对最后一个 token 输出结果。

阶段二:内存初始化。 调用内存模块的 init_batch(),将批次拆分为不超过 n_ubatch 大小的微批次,并验证 KV 缓存(或循环状态)是否有足够空间。若失败,上下文可能会尝试碎片整理后重试。

阶段三:微批次迭代。 解码循环通过 llama_memory_context_i 接口遍历各微批次,调用 next() 推进迭代,调用 get_ubatch() 获取当前微批次。

llama_ubatch 结构体 是解码流水线内部使用的批次表示:

struct llama_ubatch {
    uint32_t n_tokens;      // total tokens
    uint32_t n_seq_tokens;  // tokens per sequence set
    uint32_t n_seqs;        // sequence sets
    uint32_t n_seqs_unq;    // unique sequences
    
    llama_token  * token;   // [n_tokens]
    llama_pos    * pos;     // [n_tokens * n_pos]
    llama_seq_id ** seq_id; // [n_tokens]
    int8_t       * output;  // [n_tokens] which positions need logits
};

提示: n_batchn_ubatch 的区别在处理大批次时至关重要。用户可能一次提交 512 个 token(n_batch),但 GPU 每次只能处理 128 个(n_ubatch)。批次流水线会自动将任务拆分为 4 个微批次,整个过程对调用方完全透明。

process_ubatch:核心循环

解码的核心逻辑位于 llama_context::decode()。完成批次准备和内存初始化后,函数进入一个 do/while 循环:

sequenceDiagram
    participant D as decode()
    participant M as memory_context
    participant P as process_ubatch()
    participant G as model.build_graph()
    participant S as Backend Scheduler

    D->>M: init_batch(balloc, n_ubatch)
    loop For each ubatch
        D->>M: get_ubatch()
        D->>P: process_ubatch(ubatch, mctx)
        P->>M: apply() — commit to memory
        P->>P: Check graph reuse
        alt Graph can be reused
            P->>P: set_inputs() only
        else Build new graph
            P->>G: build_graph(params)
            P->>S: alloc_graph(gf)
        end
        P->>P: set_inputs(ubatch)
        P->>S: graph_compute(gf)
        P-->>D: logits
        D->>D: Extract logits/embeddings
        D->>M: next() — advance to next ubatch
    end

process_ubatch() 是图构建与执行相交汇的地方,依次完成以下步骤:

  1. 应用内存状态mctx->apply() 将当前微批次的 KV 缓存槽位分配提交到内存模块。这是内存发生实际变更的唯一时机。

  2. 检查图复用 — 若上一次图的拓扑结构与当前微批次一致(形状相同、配置相同),则跳过图构建和调度器分配,直接复用。

  3. 构建或复用图 — 通过 model.build_graph() 构建新图并经调度器分配,或直接复用已有图。

  4. 设置输入 — 将微批次数据(token ID、位置、掩码、KV 索引)填充到所有输入张量中。

  5. 执行计算 — 通过 ggml_backend_sched_graph_compute() 执行计算图。

内存抽象层次结构

正如第一篇文章所介绍的,三种核心类型主导着推理过程。但还有第四个角色:内存接口。llama_memory_i 是所有内存实现的抽象基类:

struct llama_memory_i {
    virtual llama_memory_context_ptr init_batch(...) = 0;
    virtual llama_memory_context_ptr init_full() = 0;
    virtual llama_memory_context_ptr init_update(...) = 0;
    virtual bool get_can_shift() const = 0;
    virtual void seq_rm(llama_seq_id, llama_pos, llama_pos) = 0;
    // ...
};

llama_model 中的 create_memory() 工厂函数 负责选择合适的实现:

classDiagram
    class llama_memory_i {
        <<interface>>
        +init_batch()
        +init_full()
        +init_update()
        +seq_rm()
    }
    class llama_kv_cache {
        Standard transformer KV cache
        Cell-based slot allocation
    }
    class llama_kv_cache_iswa {
        Interleaved sliding window
        Two sub-caches (base + SWA)
    }
    class llama_memory_recurrent {
        SSM/RWKV state buffers
        Fixed slots per sequence
    }
    class llama_memory_hybrid {
        KV cache + recurrent state
        Per-layer type selection
    }
    class llama_memory_hybrid_iswa {
        iSWA cache + recurrent state
    }
    llama_memory_i <|-- llama_kv_cache
    llama_memory_i <|-- llama_kv_cache_iswa
    llama_memory_i <|-- llama_memory_recurrent
    llama_memory_i <|-- llama_memory_hybrid
    llama_memory_i <|-- llama_memory_hybrid_iswa

选择逻辑如下:

  • 无需缓存的模型(BERT、DREAM 等)→ nullptr
  • 纯循环模型(Mamba、RWKV)→ llama_memory_recurrent
  • 混合模型(Jamba、Falcon-H1)→ llama_memory_hybridllama_memory_hybrid_iswa
  • 带滑动窗口注意力机制的模型 → llama_kv_cache_iswa
  • 标准 transformer → llama_kv_cache

KV 缓存深度解析

llama_kv_cache 是使用最广泛的内存实现。它管理一块固定大小的 KV 对缓冲区,以"单元格(cell)"为单位组织——每个位置对应一个单元格。

llama_kv_cells 数据结构记录每个单元格的元数据:

  • pos[] — 该单元格存储的序列位置(-1 表示空闲)
  • seq[] — 标记哪些序列使用该单元格的位集(单元格可被共享)
  • shift[] — 上下文滑动时待处理的位置偏移量
  • used — 已占用单元格的索引集合
  • seq_pos[] — 每个序列的"位置 → 单元格索引"映射

调用 init_batch() 时,KV 缓存需要为新 token 找到可用槽位(连续或非连续均可)。它会创建一个 slot_info 结构,记录每个 token 的 KV 对将写入哪些单元格。这些索引作为 llm_graph_input_attn_kv 中的 self_k_idxsself_v_idxs 张量传递给计算图,指示 ggml_set_rows 操作将新的 K/V 值写入何处。

memory_update(optimize=true) 被调用时,会触发碎片整理。该过程通过移动单元格来消除空隙,并通过后端调度器构建并执行一个专用的碎片整理计算图。当 init_batch() 返回 LLAMA_MEMORY_STATUS_FAILED_PREPARE 时,解码循环会自动尝试碎片整理。

flowchart LR
    subgraph "KV Cache Cells"
        C0["Cell 0\nSeq 0, Pos 0"]
        C1["Cell 1\nSeq 0, Pos 1"]
        C2["Cell 2\n(empty)"]
        C3["Cell 3\nSeq 1, Pos 0"]
        C4["Cell 4\nSeq 0, Pos 2"]
        C5["Cell 5\n(empty)"]
        C6["Cell 6\nSeq 1, Pos 1"]
    end
    
    subgraph "After Defrag"
        D0["Cell 0\nSeq 0, Pos 0"]
        D1["Cell 1\nSeq 0, Pos 1"]
        D2["Cell 2\nSeq 0, Pos 2"]
        D3["Cell 3\nSeq 1, Pos 0"]
        D4["Cell 4\nSeq 1, Pos 1"]
        D5["Cell 5\n(empty)"]
        D6["Cell 6\n(empty)"]
    end

错误恢复与状态回滚

解码流水线在设计上具备良好的故障处理能力。llama_memory_context_i 接口定义了 apply/commit 协议:

struct llama_memory_context_i {
    virtual bool next() = 0;      // advance to next ubatch
    virtual bool apply() = 0;     // commit current ubatch to memory
    virtual const llama_ubatch & get_ubatch() const = 0;
    virtual llama_memory_status get_status() const = 0;
};

apply() 是内存状态发生变更的唯一入口——它在 process_ubatch() 内部、图计算之前执行。如果后续图计算失败(分配错误、abort 回调等),解码循环会通过删除已提交的位置来执行回滚:

// from decode() error handling
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
    pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
    memory->seq_rm(s, pos_min[s], -1);
}

这段逻辑位于 llama-context.cpp 第 1690–1711 行,它会移除每个受影响序列从 pos_min 起的所有位置,确保 KV 缓存中不会残留来自部分处理批次的脏数据。

解码函数针对不同的失败场景返回不同的错误码:

  • 1 — 缓存已满(应释放部分序列或减小批次大小)
  • -2 — 分配失败或内部错误
  • -3 — 计算执行失败
  • 2 — 被回调函数中止

提示: 如果 llama_decode() 返回 1,说明 KV 缓存已满。可以调用 llama_memory_seq_rm() 释放旧序列、减少 n_predict,或在创建上下文时增大 n_ctx

下一步

至此,我们完整追踪了从 llama_decode() 调用到 KV 缓存单元格分配的全链路。下一篇文章将视角从库层切换到应用层:HTTP server 和 CLI 工具如何封装 libllama 以服务真实用户,以及一个令人意外的架构决策——CLI 不过是对 server 基础设施的轻量封装。