Trainer:从数据到大规模梯度计算
前置知识
- ›第 3 篇:模型结构解析
- ›第 4 篇:权重加载
- ›PyTorch 训练循环基础(forward、backward、optimizer.step)
- ›分布式训练基本概念:DDP、梯度累积、混合精度
Trainer:从数据到大规模梯度计算
Trainer 类是 Transformers 对这个问题的解答:"如何用同一套用户代码,写出一个既能在单 GPU 上运行,又能在 8 GPU(DDP)、64 GPU(DeepSpeed ZeRO-3)乃至 TPU Pod 上运行的训练循环?"整个文件约 4400 行,是代码库中最大的单文件,涵盖了数据加载、混合精度、梯度累积、检查点保存、评估、20 多个回调钩子,以及与十余种分布式后端的集成。
本文将梳理核心训练流程,剖析回调系统的生命周期,并展示 Trainer 如何在用户不修改任何模型代码的情况下与分布式后端无缝集成。
Trainer.init 与 TrainingArguments
Trainer 的构造函数接收模型、包含 100 多个参数的 TrainingArguments 数据类、数据集、处理类、回调以及优化器覆盖配置。
TrainingArguments 是整个训练流程的控制面板。主要参数分组如下:
| 分组 | 示例 |
|---|---|
| 基础设置 | output_dir、num_train_epochs、per_device_train_batch_size |
| 优化 | learning_rate、weight_decay、adam_beta1、adam_beta2、lr_scheduler_type |
| 混合精度 | fp16、bf16、fp16_opt_level |
| 分布式 | local_rank、deepspeed、fsdp、fsdp_config |
| 检查点 | save_strategy、save_steps、save_total_limit |
| 日志 | logging_steps、report_to、run_name |
| 梯度 | gradient_accumulation_steps、gradient_checkpointing、max_grad_norm |
构造函数会初始化 Accelerate 的 Accelerator 对象,用于抽象底层的分布式后端。无论运行在 1 块 GPU 还是 64 块 GPU 上,Trainer 的内部代码保持一致——数据并行、梯度同步和设备分配均由 Accelerate 统一处理。
classDiagram
class Trainer {
+model: PreTrainedModel
+args: TrainingArguments
+train_dataset: Dataset
+eval_dataset: Dataset
+processing_class: Any
+optimizer: Optimizer
+lr_scheduler: LRScheduler
+callback_handler: CallbackHandler
+train()
+evaluate()
+predict()
+training_step()
+compute_loss()
}
class TrainingArguments {
+output_dir: str
+num_train_epochs: float
+learning_rate: float
+per_device_train_batch_size: int
+gradient_accumulation_steps: int
+fp16: bool
+deepspeed: str
+100+ more parameters
}
Trainer --> TrainingArguments
训练循环:train() → training_step()
train() 方法是训练的主入口,其整体流程如下:
flowchart TD
A["train()"] --> B["Resume from checkpoint?"]
B --> C["Create DataLoader"]
C --> D["Setup optimizer + scheduler"]
D --> E["Call on_train_begin callbacks"]
E --> F["For each epoch"]
F --> G["For each batch"]
G --> H["training_step(model, inputs)"]
H --> I{"Gradient accumulation<br/>boundary?"}
I -->|Yes| J["Clip gradients"]
J --> K["optimizer.step()"]
K --> L["scheduler.step()"]
L --> M["optimizer.zero_grad()"]
I -->|No| G
M --> N{"eval_steps reached?"}
N -->|Yes| O["evaluate()"]
N -->|No| P{"save_steps reached?"}
P -->|Yes| Q["save_checkpoint()"]
P -->|No| G
O --> G
Q --> G
F --> R["Call on_train_end callbacks"]
training_step() 方法负责处理单个 batch:
def training_step(self, model, inputs, num_items_in_batch=None):
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
# Handle multi-GPU loss averaging
if self.args.n_gpu > 1:
loss = loss.mean()
# Backward pass (Accelerate handles mixed precision)
self.accelerator.backward(loss, **kwargs)
return loss.detach()
self.accelerator.backward(loss) 这一行涵盖了混合精度、梯度缩放以及分布式梯度同步的全部逻辑,对训练步骤本身完全透明。Accelerate 在 Trainer 初始化时对模型、优化器和 DataLoader 进行包装,并注入相应的钩子。
提示: 如果需要自定义训练逻辑(例如多任务损失、对抗训练),可以覆写
training_step();如果只是需要替换损失函数,覆写compute_loss()更为简洁。Trainer 在设计上就鼓励子类化——许多方法的签名中都明确注明了"子类可覆写以注入自定义行为"。
回调系统
Trainer 通过 TrainerCallback 提供了丰富的回调机制,共有 20 多个生命周期钩子:
sequenceDiagram
participant T as Trainer
participant CH as CallbackHandler
participant CB1 as WandbCallback
participant CB2 as TensorBoardCallback
participant CB3 as PrinterCallback
T->>CH: on_train_begin(args, state, control)
CH->>CB1: on_train_begin(...)
CH->>CB2: on_train_begin(...)
CH->>CB3: on_train_begin(...)
loop Each training step
T->>CH: on_step_begin(...)
T->>T: training_step()
T->>CH: on_step_end(...)
T->>CH: on_log(logs={"loss": 0.5, ...})
CH->>CB1: on_log(logs) → upload to W&B
CH->>CB2: on_log(logs) → write to TB
end
T->>CH: on_evaluate(metrics)
T->>CH: on_train_end(...)
这些生命周期钩子覆盖了整个训练过程:
| 钩子 | 触发时机 |
|---|---|
on_init_end |
Trainer.init 完成后 |
on_train_begin / on_train_end |
训练开始 / 结束 |
on_epoch_begin / on_epoch_end |
每个 epoch 开始 / 结束 |
on_step_begin / on_step_end |
每个训练步骤前 / 后 |
on_substep_end |
每个梯度累积子步骤后 |
on_log |
指标记录时 |
on_evaluate |
评估完成后 |
on_save |
保存检查点时 |
on_prediction_step |
预测过程中 |
每个回调接收 (args, state, control, **kwargs) 作为参数。control 对象是唯一可变的返回值——回调可以通过设置 control.should_training_stop = True 或 control.should_save = True 来干预训练循环的行为。
内置回调已包含与 Weights & Biases、TensorBoard、MLflow、CometML 和 Neptune 的集成,均通过 args.report_to 自动检测并启用。
分布式训练集成
Trainer 支持三种主要的分布式训练方案:
DDP(Distributed Data Parallel) — 使用 torchrun 启动时的默认方案。每块 GPU 持有完整的模型副本,梯度通过 all-reduce 同步。
DeepSpeed ZeRO — 按阶段逐步将优化器状态(Stage 1)、梯度(Stage 2)和模型参数(Stage 3)分片到各 GPU 上。通过 args.deepspeed 指定 JSON 配置文件进行配置,integrations/deepspeed.py 模块负责初始化。
FSDP(Fully Sharded Data Parallel) — PyTorch 原生的参数分片方案,通过 args.fsdp 和 args.fsdp_config 进行配置,integrations/fsdp.py 模块提供桥接层。
flowchart TD
A["TrainingArguments"] --> B{"deepspeed config?"}
B -->|Yes| C["Initialize DeepSpeed<br/>ZeRO Stage 1/2/3"]
B -->|No| D{"fsdp enabled?"}
D -->|Yes| E["Initialize FSDP<br/>via Accelerate"]
D -->|No| F{"n_gpu > 1?"}
F -->|Yes| G["Distributed Data<br/>Parallel (DDP)"]
F -->|No| H["Single GPU<br/>training"]
C --> I["Accelerate wraps<br/>model + optimizer"]
E --> I
G --> I
H --> I
这套架构的精妙之处在于,无论使用哪种后端,training_step() 的代码看起来都完全一致。Accelerate 的包装层负责处理以下所有细节:
- 梯度累积(调用
accelerator.backward(),根据是否处于累积步骤决定是否同步梯度) - 混合精度(通过
autocast上下文支持 FP16 或 BF16) - 分布式梯度规约(由包装后的模型自动处理)
损失函数与 Data Collator
LOSS_MAPPING 注册表将模型类名的后缀映射到对应的共享损失函数:
LOSS_MAPPING = {
"ForCausalLM": ForCausalLMLoss,
"ForMaskedLM": ForMaskedLMLoss,
"ForSequenceClassification": ForSequenceClassificationLoss,
"ForTokenClassification": ForTokenClassification,
"ForQuestionAnswering": ForQuestionAnsweringLoss,
# ... plus object detection variants
}
每个模型通过其 loss_function 属性在该映射表中查找自身的类名。ForCausalLMLoss 处理标签偏移(用 labels[n] 预测 token[n+1]),并使用 fixed_cross_entropy,该函数同时支持均值规约和求和归一化(以确保梯度累积时的计算正确性)。
data/data_collator.py 中的 Data Collator 负责处理不同类型数据的批处理:
DataCollatorForLanguageModeling— 用于 MLM 的动态掩码DataCollatorForSeq2Seq— 对 encoder-decoder 模型按批次最长序列进行 paddingDataCollatorWithPadding— 简单 padding 并生成 attention maskDataCollatorForTokenClassification— 将标签与子词 token 对齐
如果未手动指定,Trainer 会根据处理类自动选择合适的 Collator。
优化器与学习率调度器
optimization.py 模块提供了超出 PyTorch 内置范围的优化器和调度器实现。默认使用 AdamW,但 args.optim 还支持 AdaFactor、Lion、SGD、schedule-free 优化器以及用于内存高效微调的 LOMO。
学习率调度器支持线性 warmup + 衰减、余弦、带重启的余弦、多项式衰减以及带 warmup 的常数调度。调度器在 train() 初始化阶段根据 args.lr_scheduler_type 创建。
提示: 大规模训练时,设置
args.optim = "adamw_torch_fused"可启用 PyTorch 的融合 AdamW 内核,通过减少每次参数更新的 kernel 启动次数,速度比默认实现快 15%–20%。
目录结构
| 文件 | 用途 |
|---|---|
src/transformers/trainer.py |
约 4400 行的 Trainer 类 |
src/transformers/training_args.py |
100 多个 TrainingArguments 参数 |
src/transformers/trainer_callback.py |
包含 20 多个钩子的回调系统 |
src/transformers/optimization.py |
优化器与学习率调度器 |
src/transformers/data/data_collator.py |
用于批处理的 Data Collator |
src/transformers/loss/loss_utils.py |
LOSS_MAPPING 及共享损失函数 |
src/transformers/integrations/deepspeed.py |
DeepSpeed ZeRO 集成 |
src/transformers/integrations/fsdp.py |
FSDP 集成 |
至此,我们已经梳理了完整的模型生命周期:导入 → 解析 → 加载 → 生成 → 训练。在最后一篇文章中,我们将聚焦于高层 Pipeline API、tokenizer 层次结构、多模态处理,以及向代码库贡献新模型的扩展接口。