生成引擎:model.generate() 如何生成文本
前置知识
- ›第 3 篇:模型结构解析(前向传播、注意力机制)
- ›第 4 篇:权重加载(模型已就绪,可进行推理)
- ›自回归 Transformer 中的 KV-cache 基本概念
- ›采样基础知识:temperature、top-k、top-p
生成引擎:model.generate() 如何生成文本
在第 4 篇完成模型加载之后,我们来看最关键的用户接口:model.generate()。这个方法有将近 1700 行代码,是一个功能完备的调度器,涵盖贪婪解码、多项式采样、束搜索、推测(辅助)解码、token 流式输出等多种策略。它还管理着一套可组合的 logits 处理管线、多种 cache 后端以及停止条件——同时保持与 torch.compile 的兼容性,以满足生产环境的吞吐量需求。
本文将完整梳理生成流程,从 GenerationConfig 验证到主自回归循环,并重点介绍 KV-cache 系统以及能带来 2–3 倍加速的辅助解码机制。
GenerationConfig 与模式选择
每次调用 generate() 都从一个 GenerationConfig 实例开始。这个类保存了所有解码参数:max_new_tokens、temperature、top_k、top_p、num_beams、do_sample、repetition_penalty 等等。
模式选择遵循以下优先级流程:
flowchart TD
A["generate() called"] --> B["Resolve GenerationConfig"]
B --> C{"assistant_model<br/>provided?"}
C -->|Yes| D["Assisted/Speculative<br/>Decoding"]
C -->|No| E{"num_beams > 1?"}
E -->|Yes| F{"do_sample?"}
F -->|Yes| G["Beam Search<br/>Multinomial Sampling"]
F -->|No| H["Beam Search"]
E -->|No| I{"do_sample?"}
I -->|Yes| J["Multinomial<br/>Sampling"]
I -->|No| K["Greedy<br/>Decoding"]
GenerationMixin 上的 generate() 方法首先将传入的 kwargs 合并到 generation config 中,校验参数组合的合法性(例如 temperature=0 意味着使用贪婪解码),然后分发到对应的内部方法。
GenerationConfig 还支持 custom_generate——可以传入字符串(Hub 仓库名)或可调用对象,完全替换默认的生成循环。这一扩展点让研究人员无需 fork 代码库,就能尝试新颖的解码算法。
提示: 相比直接向
generate()传 kwargs,建议显式传入generation_config=GenerationConfig(...)。这样可以避免参数合并的额外开销,同时让解码配置可以在多次调用之间复用。
KV-Cache 系统
自回归生成在每一步都需要对所有历史 token 重新计算注意力——除非将 key-value 投影缓存下来。Transformers 实现了一套完整的 cache 层级体系:
classDiagram
class CacheLayerMixin {
<<abstract>>
+keys: Tensor
+values: Tensor
+is_initialized: bool
+update(key_states, value_states)
+get_seq_length() int
+offload()
+prefetch()
}
class Cache {
+layers: list[CacheLayerMixin]
+layer_class_to_replicate: type
+offloading: bool
+update(key, value, layer_idx)
+get_seq_length()
}
class DynamicCache {
«grows as needed»
}
class StaticCache {
«fixed size, compile-friendly»
}
class QuantizedCache {
«INT8 keys/values»
}
CacheLayerMixin --o Cache : layers
Cache <|-- DynamicCache
Cache <|-- StaticCache
Cache <|-- QuantizedCache
CacheLayerMixin 是单层的抽象接口,支持 offloading(将 KV pair 移至 CPU 以释放显存)和 prefetching(提前将数据移回 GPU,实现流水线执行)。
Cache 基类是 CacheLayerMixin 对象的容器,每个模型层对应一个实例。它可以从预分配的层构建,也可以通过 layer_class_to_replicate 按需懒惰扩展。
三种主要实现:
DynamicCache— 动态追加到不断增长的 tensor 中,灵活但不支持torch.compile(动态 shape 问题)。StaticCache— 预分配固定大小的 buffer,与torch.compile兼容,也是 CUDA graph capture 的必要条件。QuantizedCache— 以 INT8 存储 key 和 value,内存占用减少约 50%,精度略有损失。
正如第 3 篇所介绍的,LlamaModel.forward() 在 use_cache=True 时默认创建 DynamicCache。cache 在每个解码层的注意力计算中流转,逐步积累 key-value pair。
Logits 处理管线
从模型输出的原始 logits 到最终的 token 选择,中间有一套可组合的 LogitsProcessor 管线对分数进行变换:
flowchart LR
A["Raw logits<br/>[batch, vocab]"] --> B["TemperatureLogitsWarper"]
B --> C["TopKLogitsWarper"]
C --> D["TopPLogitsWarper"]
D --> E["RepetitionPenaltyProcessor"]
E --> F["NoBadWordsLogitsProcessor"]
F --> G["Final scores"]
G --> H["torch.multinomial<br/>or argmax"]
LogitsProcessorList 是 list 的子类,通过 __call__ 串联所有 processor。每个 processor 接收 (input_ids, scores) 并返回修改后的 scores。管线根据 GenerationConfig 参数自动组装:
temperature→TemperatureLogitsWarpertop_k→TopKLogitsWarpertop_p→TopPLogitsWarperrepetition_penalty→RepetitionPenaltyLogitsProcessorno_repeat_ngram_size→NoRepeatNGramLogitsProcessor
你可以通过 generate() 的 logits_processor 参数注入自定义 processor,它们会被追加到内置 processor 之后。
提示: processor 的顺序至关重要。temperature 缩放应排在最前面(它会影响整体分布),其次是过滤操作(top-k/top-p),最后才是惩罚项。Transformers 默认按照正确顺序组装,但自定义 processor 会被追加到末尾,使用时需注意这一点。
推测解码 / 辅助解码
CandidateGenerator 系统实现了推测解码——用一个更小、更快的模型草拟候选 token,再由主模型并行验证。
sequenceDiagram
participant Main as Main Model (70B)
participant Draft as Draft Model (1B)
participant Verify as Verification
Note over Draft: Generate K candidate tokens
Draft->>Draft: token_1, token_2, ..., token_K
Draft-->>Main: Candidate sequence
Main->>Main: Run forward pass on<br/>all K+1 positions simultaneously
Main-->>Verify: Logits for each position
Verify->>Verify: Compare draft vs main<br/>Accept matching tokens
Verify-->>Main: Accepted: token_1..token_j<br/>Rejected from token_j+1
Note over Main: Only 1 forward pass<br/>for up to K tokens!
AssistedCandidateGenerator 负责驱动草稿模型生成候选序列,随后主模型通过单次前向传播完成验证(由于注意力是因果的,所有位置可以同时检查)。验证通过的 token 被保留,第一个被拒绝的 token 则从主模型的分布中重新采样。
这一机制能带来 2–3 倍的吞吐量提升:草稿模型的前向传播开销远低于主模型,而主模型的验证过程又是批量完成的。关键在于,其输出分布在数学上与标准采样完全等价——拒绝步骤保证了这一点。
Transformers 还在第 341 行提供了 AssistedCandidateGeneratorDifferentTokenizers,专门处理草稿模型与主模型使用不同 tokenizer 的情况,需要进行 token 级别的对齐。
流式输出与停止条件
对于交互式应用来说,等待整个序列生成完毕是不可接受的。BaseStreamer 接口支持实时 token 输出:
class BaseStreamer:
def put(self, value): ... # Called with new token IDs
def end(self): ... # Called when generation finishes
TextStreamer 实时解码 token 并打印到标准输出。TextIteratorStreamer 将 token 放入队列供异步消费,非常适合 Web 服务场景。AsyncTextIteratorStreamer 则进一步支持 async 迭代。
生成终止由 StoppingCriteria 控制:
flowchart TD
A["After each generation step"] --> B["StoppingCriteriaList"]
B --> C["MaxLengthCriteria"]
B --> D["MaxTimeCriteria"]
B --> E["EosTokenCriteria"]
B --> F["StopStringCriteria"]
B --> G["Custom criteria"]
C --> H{"Any criterion<br/>returns True?"}
D --> H
E --> H
F --> H
G --> H
H -->|Yes| I["Stop generation"]
H -->|No| J["Continue"]
StoppingCriteriaList 在每个 token 生成后检查所有条件。内置条件包括最大长度、最大时间、EOS token 检测和停止字符串匹配。自定义条件可以检查完整的已生成序列及分数。
主循环
将所有部分组合在一起,核心自回归循环(简化版)如下所示:
while not stopping_criteria(input_ids, scores):
# 1. Run model forward pass
outputs = model(input_ids, past_key_values=cache, ...)
# 2. Extract next-token logits
next_token_logits = outputs.logits[:, -1, :]
# 3. Process logits
next_token_scores = logits_processor(input_ids, next_token_logits)
# 4. Select next token
if do_sample:
next_tokens = torch.multinomial(probs, num_samples=1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# 5. Update input_ids and stream
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
if streamer is not None:
streamer.put(next_tokens)
实际实现中还需处理 batch 维度追踪、束管理、分布式生成的 GPU 同步以及辅助解码验证等细节——但核心模式始终如一:前向传播 → 处理 → 选择 → 追加。
目录结构
| 文件 | 用途 |
|---|---|
src/transformers/generation/utils.py |
GenerationMixin.generate() — 约 1700 行的调度器 |
src/transformers/generation/configuration_utils.py |
GenerationConfig — 所有解码参数 |
src/transformers/generation/logits_process.py |
可组合的 logits 处理管线 |
src/transformers/generation/candidate_generator.py |
推测解码 / 辅助解码 |
src/transformers/generation/streamers.py |
实时 token 流式输出 |
src/transformers/generation/stopping_criteria.py |
生成终止条件 |
src/transformers/cache_utils.py |
KV-cache 层级体系 |
至此,我们已经完整覆盖了推理侧的内容。但 Transformers 同样是一个训练库。下一篇文章,我们将深入 Trainer 类——一个约 4400 行的训练循环调度器,集成了分布式后端、回调机制和损失函数注册表。