Read OSS

生成引擎:model.generate() 如何生成文本

高级

前置知识

  • 第 3 篇:模型结构解析(前向传播、注意力机制)
  • 第 4 篇:权重加载(模型已就绪,可进行推理)
  • 自回归 Transformer 中的 KV-cache 基本概念
  • 采样基础知识:temperature、top-k、top-p

生成引擎:model.generate() 如何生成文本

在第 4 篇完成模型加载之后,我们来看最关键的用户接口:model.generate()。这个方法有将近 1700 行代码,是一个功能完备的调度器,涵盖贪婪解码、多项式采样、束搜索、推测(辅助)解码、token 流式输出等多种策略。它还管理着一套可组合的 logits 处理管线、多种 cache 后端以及停止条件——同时保持与 torch.compile 的兼容性,以满足生产环境的吞吐量需求。

本文将完整梳理生成流程,从 GenerationConfig 验证到主自回归循环,并重点介绍 KV-cache 系统以及能带来 2–3 倍加速的辅助解码机制。

GenerationConfig 与模式选择

每次调用 generate() 都从一个 GenerationConfig 实例开始。这个类保存了所有解码参数:max_new_tokenstemperaturetop_ktop_pnum_beamsdo_samplerepetition_penalty 等等。

模式选择遵循以下优先级流程:

flowchart TD
    A["generate() called"] --> B["Resolve GenerationConfig"]
    B --> C{"assistant_model<br/>provided?"}
    C -->|Yes| D["Assisted/Speculative<br/>Decoding"]
    C -->|No| E{"num_beams > 1?"}
    E -->|Yes| F{"do_sample?"}
    F -->|Yes| G["Beam Search<br/>Multinomial Sampling"]
    F -->|No| H["Beam Search"]
    E -->|No| I{"do_sample?"}
    I -->|Yes| J["Multinomial<br/>Sampling"]
    I -->|No| K["Greedy<br/>Decoding"]

GenerationMixin 上的 generate() 方法首先将传入的 kwargs 合并到 generation config 中,校验参数组合的合法性(例如 temperature=0 意味着使用贪婪解码),然后分发到对应的内部方法。

GenerationConfig 还支持 custom_generate——可以传入字符串(Hub 仓库名)或可调用对象,完全替换默认的生成循环。这一扩展点让研究人员无需 fork 代码库,就能尝试新颖的解码算法。

提示: 相比直接向 generate() 传 kwargs,建议显式传入 generation_config=GenerationConfig(...)。这样可以避免参数合并的额外开销,同时让解码配置可以在多次调用之间复用。

KV-Cache 系统

自回归生成在每一步都需要对所有历史 token 重新计算注意力——除非将 key-value 投影缓存下来。Transformers 实现了一套完整的 cache 层级体系:

classDiagram
    class CacheLayerMixin {
        <<abstract>>
        +keys: Tensor
        +values: Tensor
        +is_initialized: bool
        +update(key_states, value_states)
        +get_seq_length() int
        +offload()
        +prefetch()
    }
    class Cache {
        +layers: list[CacheLayerMixin]
        +layer_class_to_replicate: type
        +offloading: bool
        +update(key, value, layer_idx)
        +get_seq_length()
    }
    class DynamicCache {
        «grows as needed»
    }
    class StaticCache {
        «fixed size, compile-friendly»
    }
    class QuantizedCache {
        «INT8 keys/values»
    }
    CacheLayerMixin --o Cache : layers
    Cache <|-- DynamicCache
    Cache <|-- StaticCache
    Cache <|-- QuantizedCache

CacheLayerMixin 是单层的抽象接口,支持 offloading(将 KV pair 移至 CPU 以释放显存)和 prefetching(提前将数据移回 GPU,实现流水线执行)。

Cache 基类是 CacheLayerMixin 对象的容器,每个模型层对应一个实例。它可以从预分配的层构建,也可以通过 layer_class_to_replicate 按需懒惰扩展。

三种主要实现:

  • DynamicCache — 动态追加到不断增长的 tensor 中,灵活但不支持 torch.compile(动态 shape 问题)。
  • StaticCache — 预分配固定大小的 buffer,与 torch.compile 兼容,也是 CUDA graph capture 的必要条件。
  • QuantizedCache — 以 INT8 存储 key 和 value,内存占用减少约 50%,精度略有损失。

正如第 3 篇所介绍的,LlamaModel.forward()use_cache=True 时默认创建 DynamicCache。cache 在每个解码层的注意力计算中流转,逐步积累 key-value pair。

Logits 处理管线

从模型输出的原始 logits 到最终的 token 选择,中间有一套可组合的 LogitsProcessor 管线对分数进行变换:

flowchart LR
    A["Raw logits<br/>[batch, vocab]"] --> B["TemperatureLogitsWarper"]
    B --> C["TopKLogitsWarper"]
    C --> D["TopPLogitsWarper"]
    D --> E["RepetitionPenaltyProcessor"]
    E --> F["NoBadWordsLogitsProcessor"]
    F --> G["Final scores"]
    G --> H["torch.multinomial<br/>or argmax"]

LogitsProcessorListlist 的子类,通过 __call__ 串联所有 processor。每个 processor 接收 (input_ids, scores) 并返回修改后的 scores。管线根据 GenerationConfig 参数自动组装:

  • temperatureTemperatureLogitsWarper
  • top_kTopKLogitsWarper
  • top_pTopPLogitsWarper
  • repetition_penaltyRepetitionPenaltyLogitsProcessor
  • no_repeat_ngram_sizeNoRepeatNGramLogitsProcessor

你可以通过 generate()logits_processor 参数注入自定义 processor,它们会被追加到内置 processor 之后。

提示: processor 的顺序至关重要。temperature 缩放应排在最前面(它会影响整体分布),其次是过滤操作(top-k/top-p),最后才是惩罚项。Transformers 默认按照正确顺序组装,但自定义 processor 会被追加到末尾,使用时需注意这一点。

推测解码 / 辅助解码

CandidateGenerator 系统实现了推测解码——用一个更小、更快的模型草拟候选 token,再由主模型并行验证。

sequenceDiagram
    participant Main as Main Model (70B)
    participant Draft as Draft Model (1B)
    participant Verify as Verification

    Note over Draft: Generate K candidate tokens
    Draft->>Draft: token_1, token_2, ..., token_K
    Draft-->>Main: Candidate sequence
    Main->>Main: Run forward pass on<br/>all K+1 positions simultaneously
    Main-->>Verify: Logits for each position
    Verify->>Verify: Compare draft vs main<br/>Accept matching tokens
    Verify-->>Main: Accepted: token_1..token_j<br/>Rejected from token_j+1
    Note over Main: Only 1 forward pass<br/>for up to K tokens!

AssistedCandidateGenerator 负责驱动草稿模型生成候选序列,随后主模型通过单次前向传播完成验证(由于注意力是因果的,所有位置可以同时检查)。验证通过的 token 被保留,第一个被拒绝的 token 则从主模型的分布中重新采样。

这一机制能带来 2–3 倍的吞吐量提升:草稿模型的前向传播开销远低于主模型,而主模型的验证过程又是批量完成的。关键在于,其输出分布在数学上与标准采样完全等价——拒绝步骤保证了这一点。

Transformers 还在第 341 行提供了 AssistedCandidateGeneratorDifferentTokenizers,专门处理草稿模型与主模型使用不同 tokenizer 的情况,需要进行 token 级别的对齐。

流式输出与停止条件

对于交互式应用来说,等待整个序列生成完毕是不可接受的。BaseStreamer 接口支持实时 token 输出:

class BaseStreamer:
    def put(self, value): ...    # Called with new token IDs
    def end(self): ...           # Called when generation finishes

TextStreamer 实时解码 token 并打印到标准输出。TextIteratorStreamer 将 token 放入队列供异步消费,非常适合 Web 服务场景。AsyncTextIteratorStreamer 则进一步支持 async 迭代。

生成终止由 StoppingCriteria 控制:

flowchart TD
    A["After each generation step"] --> B["StoppingCriteriaList"]
    B --> C["MaxLengthCriteria"]
    B --> D["MaxTimeCriteria"]
    B --> E["EosTokenCriteria"]
    B --> F["StopStringCriteria"]
    B --> G["Custom criteria"]
    C --> H{"Any criterion<br/>returns True?"}
    D --> H
    E --> H
    F --> H
    G --> H
    H -->|Yes| I["Stop generation"]
    H -->|No| J["Continue"]

StoppingCriteriaList 在每个 token 生成后检查所有条件。内置条件包括最大长度、最大时间、EOS token 检测和停止字符串匹配。自定义条件可以检查完整的已生成序列及分数。

主循环

将所有部分组合在一起,核心自回归循环(简化版)如下所示:

while not stopping_criteria(input_ids, scores):
    # 1. Run model forward pass
    outputs = model(input_ids, past_key_values=cache, ...)
    
    # 2. Extract next-token logits
    next_token_logits = outputs.logits[:, -1, :]
    
    # 3. Process logits
    next_token_scores = logits_processor(input_ids, next_token_logits)
    
    # 4. Select next token
    if do_sample:
        next_tokens = torch.multinomial(probs, num_samples=1)
    else:
        next_tokens = torch.argmax(next_token_scores, dim=-1)
    
    # 5. Update input_ids and stream
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    if streamer is not None:
        streamer.put(next_tokens)

实际实现中还需处理 batch 维度追踪、束管理、分布式生成的 GPU 同步以及辅助解码验证等细节——但核心模式始终如一:前向传播 → 处理 → 选择 → 追加。

目录结构

文件 用途
src/transformers/generation/utils.py GenerationMixin.generate() — 约 1700 行的调度器
src/transformers/generation/configuration_utils.py GenerationConfig — 所有解码参数
src/transformers/generation/logits_process.py 可组合的 logits 处理管线
src/transformers/generation/candidate_generator.py 推测解码 / 辅助解码
src/transformers/generation/streamers.py 实时 token 流式输出
src/transformers/generation/stopping_criteria.py 生成终止条件
src/transformers/cache_utils.py KV-cache 层级体系

至此,我们已经完整覆盖了推理侧的内容。但 Transformers 同样是一个训练库。下一篇文章,我们将深入 Trainer 类——一个约 4400 行的训练循环调度器,集成了分布式后端、回调机制和损失函数注册表。