解码循环:批处理、KV 缓存与内存管理
前置知识
- ›第 1–3 篇文章
- ›了解 transformer 推理中的 KV 缓存机制
解码循环:批处理、KV 缓存与内存管理
前三篇文章奠定了架构基础:libllama 根据模型定义构建 GGML 计算图,并在硬件后端上执行。但我们略过了系统中操作复杂度最高的部分——负责管理批次、分配 KV 缓存槽位、处理微批次(micro-batch)以及从故障中恢复的解码流水线。
本文将完整追踪从 llama_decode() 到 logit 提取的全过程,内容涵盖:上下文构建、批次拆分流水线、process_ubatch() 核心循环、多态内存抽象、KV 缓存内部机制,以及错误恢复策略。
上下文构建
调用 llama_init_from_model() 时,llama_context 的构造函数 会依次执行以下操作:
-
合并参数 — 将用户传入的
llama_context_params与模型超参数合并。例如,若n_ctx为 0,则默认使用模型的训练上下文长度。 -
创建内存 — 调用
model.create_memory(),根据模型架构分配对应的内存实现(KV 缓存、循环状态或混合模式)。 -
创建批次分配器 — 初始化
llama_batch_allocr,负责批次的规范化与拆分。 -
预留计算缓冲区 — 构建最坏情况下的计算图(通过
memory->init_full()),并通过后端调度器运行,以确定所需的最大计算缓冲区大小。
对推理影响最显著的上下文参数如下:
| 参数 | 作用 |
|---|---|
n_ctx |
总上下文窗口大小(KV 缓存可容纳的 token 数) |
n_batch |
最大逻辑批次大小(用户提交的批次上限) |
n_ubatch |
最大物理批次大小(每次图执行实际送入 GPU 的 token 数) |
n_seq_max |
最大并发序列数 |
type_k, type_v |
KV 缓存数据类型(F16、Q8_0、Q4_0 等) |
flash_attn_type |
是否使用 flash attention kernel |
flowchart TD
PARAMS["llama_context_params"] --> CTX["llama_context constructor"]
CTX --> MERGE["Merge with hparams"]
CTX --> MEM["model.create_memory()"]
CTX --> BALLOC["Create batch_allocr"]
CTX --> SCHED["Reserve compute buffers\n(worst-case graph)"]
MEM --> KV["llama_kv_cache"]
MEM --> REC["llama_memory_recurrent"]
MEM --> HYB["llama_memory_hybrid"]
批次流水线
在任何计算开始之前,用户提交的 llama_batch 必须经过规范化、校验,并拆分为若干微批次。整个流水线分为三个阶段:
阶段一:规范化。 llama_batch_allocr 接收用户的原始批次,并自动填充缺失字段。若 pos 为 NULL,则根据内存模块中的序列连续性自动生成位置信息;若 seq_id 为 NULL,所有 token 默认归属序列 0;若 logits 为 NULL,则仅对最后一个 token 输出结果。
阶段二:内存初始化。 调用内存模块的 init_batch(),将批次拆分为不超过 n_ubatch 大小的微批次,并验证 KV 缓存(或循环状态)是否有足够空间。若失败,上下文可能会尝试碎片整理后重试。
阶段三:微批次迭代。 解码循环通过 llama_memory_context_i 接口遍历各微批次,调用 next() 推进迭代,调用 get_ubatch() 获取当前微批次。
llama_ubatch 结构体 是解码流水线内部使用的批次表示:
struct llama_ubatch {
uint32_t n_tokens; // total tokens
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets
uint32_t n_seqs_unq; // unique sequences
llama_token * token; // [n_tokens]
llama_pos * pos; // [n_tokens * n_pos]
llama_seq_id ** seq_id; // [n_tokens]
int8_t * output; // [n_tokens] which positions need logits
};
提示:
n_batch与n_ubatch的区别在处理大批次时至关重要。用户可能一次提交 512 个 token(n_batch),但 GPU 每次只能处理 128 个(n_ubatch)。批次流水线会自动将任务拆分为 4 个微批次,整个过程对调用方完全透明。
process_ubatch:核心循环
解码的核心逻辑位于 llama_context::decode()。完成批次准备和内存初始化后,函数进入一个 do/while 循环:
sequenceDiagram
participant D as decode()
participant M as memory_context
participant P as process_ubatch()
participant G as model.build_graph()
participant S as Backend Scheduler
D->>M: init_batch(balloc, n_ubatch)
loop For each ubatch
D->>M: get_ubatch()
D->>P: process_ubatch(ubatch, mctx)
P->>M: apply() — commit to memory
P->>P: Check graph reuse
alt Graph can be reused
P->>P: set_inputs() only
else Build new graph
P->>G: build_graph(params)
P->>S: alloc_graph(gf)
end
P->>P: set_inputs(ubatch)
P->>S: graph_compute(gf)
P-->>D: logits
D->>D: Extract logits/embeddings
D->>M: next() — advance to next ubatch
end
process_ubatch() 是图构建与执行相交汇的地方,依次完成以下步骤:
-
应用内存状态 —
mctx->apply()将当前微批次的 KV 缓存槽位分配提交到内存模块。这是内存发生实际变更的唯一时机。 -
检查图复用 — 若上一次图的拓扑结构与当前微批次一致(形状相同、配置相同),则跳过图构建和调度器分配,直接复用。
-
构建或复用图 — 通过
model.build_graph()构建新图并经调度器分配,或直接复用已有图。 -
设置输入 — 将微批次数据(token ID、位置、掩码、KV 索引)填充到所有输入张量中。
-
执行计算 — 通过
ggml_backend_sched_graph_compute()执行计算图。
内存抽象层次结构
正如第一篇文章所介绍的,三种核心类型主导着推理过程。但还有第四个角色:内存接口。llama_memory_i 是所有内存实现的抽象基类:
struct llama_memory_i {
virtual llama_memory_context_ptr init_batch(...) = 0;
virtual llama_memory_context_ptr init_full() = 0;
virtual llama_memory_context_ptr init_update(...) = 0;
virtual bool get_can_shift() const = 0;
virtual void seq_rm(llama_seq_id, llama_pos, llama_pos) = 0;
// ...
};
llama_model 中的 create_memory() 工厂函数 负责选择合适的实现:
classDiagram
class llama_memory_i {
<<interface>>
+init_batch()
+init_full()
+init_update()
+seq_rm()
}
class llama_kv_cache {
Standard transformer KV cache
Cell-based slot allocation
}
class llama_kv_cache_iswa {
Interleaved sliding window
Two sub-caches (base + SWA)
}
class llama_memory_recurrent {
SSM/RWKV state buffers
Fixed slots per sequence
}
class llama_memory_hybrid {
KV cache + recurrent state
Per-layer type selection
}
class llama_memory_hybrid_iswa {
iSWA cache + recurrent state
}
llama_memory_i <|-- llama_kv_cache
llama_memory_i <|-- llama_kv_cache_iswa
llama_memory_i <|-- llama_memory_recurrent
llama_memory_i <|-- llama_memory_hybrid
llama_memory_i <|-- llama_memory_hybrid_iswa
选择逻辑如下:
- 无需缓存的模型(BERT、DREAM 等)→
nullptr - 纯循环模型(Mamba、RWKV)→
llama_memory_recurrent - 混合模型(Jamba、Falcon-H1)→
llama_memory_hybrid或llama_memory_hybrid_iswa - 带滑动窗口注意力机制的模型 →
llama_kv_cache_iswa - 标准 transformer →
llama_kv_cache
KV 缓存深度解析
llama_kv_cache 是使用最广泛的内存实现。它管理一块固定大小的 KV 对缓冲区,以"单元格(cell)"为单位组织——每个位置对应一个单元格。
llama_kv_cells 数据结构记录每个单元格的元数据:
pos[]— 该单元格存储的序列位置(-1 表示空闲)seq[]— 标记哪些序列使用该单元格的位集(单元格可被共享)shift[]— 上下文滑动时待处理的位置偏移量used— 已占用单元格的索引集合seq_pos[]— 每个序列的"位置 → 单元格索引"映射
调用 init_batch() 时,KV 缓存需要为新 token 找到可用槽位(连续或非连续均可)。它会创建一个 slot_info 结构,记录每个 token 的 KV 对将写入哪些单元格。这些索引作为 llm_graph_input_attn_kv 中的 self_k_idxs 和 self_v_idxs 张量传递给计算图,指示 ggml_set_rows 操作将新的 K/V 值写入何处。
当 memory_update(optimize=true) 被调用时,会触发碎片整理。该过程通过移动单元格来消除空隙,并通过后端调度器构建并执行一个专用的碎片整理计算图。当 init_batch() 返回 LLAMA_MEMORY_STATUS_FAILED_PREPARE 时,解码循环会自动尝试碎片整理。
flowchart LR
subgraph "KV Cache Cells"
C0["Cell 0\nSeq 0, Pos 0"]
C1["Cell 1\nSeq 0, Pos 1"]
C2["Cell 2\n(empty)"]
C3["Cell 3\nSeq 1, Pos 0"]
C4["Cell 4\nSeq 0, Pos 2"]
C5["Cell 5\n(empty)"]
C6["Cell 6\nSeq 1, Pos 1"]
end
subgraph "After Defrag"
D0["Cell 0\nSeq 0, Pos 0"]
D1["Cell 1\nSeq 0, Pos 1"]
D2["Cell 2\nSeq 0, Pos 2"]
D3["Cell 3\nSeq 1, Pos 0"]
D4["Cell 4\nSeq 1, Pos 1"]
D5["Cell 5\n(empty)"]
D6["Cell 6\n(empty)"]
end
错误恢复与状态回滚
解码流水线在设计上具备良好的故障处理能力。llama_memory_context_i 接口定义了 apply/commit 协议:
struct llama_memory_context_i {
virtual bool next() = 0; // advance to next ubatch
virtual bool apply() = 0; // commit current ubatch to memory
virtual const llama_ubatch & get_ubatch() const = 0;
virtual llama_memory_status get_status() const = 0;
};
apply() 是内存状态发生变更的唯一入口——它在 process_ubatch() 内部、图计算之前执行。如果后续图计算失败(分配错误、abort 回调等),解码循环会通过删除已提交的位置来执行回滚:
// from decode() error handling
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
memory->seq_rm(s, pos_min[s], -1);
}
这段逻辑位于 llama-context.cpp 第 1690–1711 行,它会移除每个受影响序列从 pos_min 起的所有位置,确保 KV 缓存中不会残留来自部分处理批次的脏数据。
解码函数针对不同的失败场景返回不同的错误码:
1— 缓存已满(应释放部分序列或减小批次大小)-2— 分配失败或内部错误-3— 计算执行失败2— 被回调函数中止
提示: 如果
llama_decode()返回1,说明 KV 缓存已满。可以调用llama_memory_seq_rm()释放旧序列、减少n_predict,或在创建上下文时增大n_ctx。
下一步
至此,我们完整追踪了从 llama_decode() 调用到 KV 缓存单元格分配的全链路。下一篇文章将视角从库层切换到应用层:HTTP server 和 CLI 工具如何封装 libllama 以服务真实用户,以及一个令人意外的架构决策——CLI 不过是对 server 基础设施的轻量封装。