Read OSS

Trainer:从数据到大规模梯度计算

高级

前置知识

  • 第 3 篇:模型结构解析
  • 第 4 篇:权重加载
  • PyTorch 训练循环基础(forward、backward、optimizer.step)
  • 分布式训练基本概念:DDP、梯度累积、混合精度

Trainer:从数据到大规模梯度计算

Trainer 类是 Transformers 对这个问题的解答:"如何用同一套用户代码,写出一个既能在单 GPU 上运行,又能在 8 GPU(DDP)、64 GPU(DeepSpeed ZeRO-3)乃至 TPU Pod 上运行的训练循环?"整个文件约 4400 行,是代码库中最大的单文件,涵盖了数据加载、混合精度、梯度累积、检查点保存、评估、20 多个回调钩子,以及与十余种分布式后端的集成。

本文将梳理核心训练流程,剖析回调系统的生命周期,并展示 Trainer 如何在用户不修改任何模型代码的情况下与分布式后端无缝集成。

Trainer.init 与 TrainingArguments

Trainer 的构造函数接收模型、包含 100 多个参数的 TrainingArguments 数据类、数据集、处理类、回调以及优化器覆盖配置。

TrainingArguments 是整个训练流程的控制面板。主要参数分组如下:

分组 示例
基础设置 output_dirnum_train_epochsper_device_train_batch_size
优化 learning_rateweight_decayadam_beta1adam_beta2lr_scheduler_type
混合精度 fp16bf16fp16_opt_level
分布式 local_rankdeepspeedfsdpfsdp_config
检查点 save_strategysave_stepssave_total_limit
日志 logging_stepsreport_torun_name
梯度 gradient_accumulation_stepsgradient_checkpointingmax_grad_norm

构造函数会初始化 Accelerate 的 Accelerator 对象,用于抽象底层的分布式后端。无论运行在 1 块 GPU 还是 64 块 GPU 上,Trainer 的内部代码保持一致——数据并行、梯度同步和设备分配均由 Accelerate 统一处理。

classDiagram
    class Trainer {
        +model: PreTrainedModel
        +args: TrainingArguments
        +train_dataset: Dataset
        +eval_dataset: Dataset
        +processing_class: Any
        +optimizer: Optimizer
        +lr_scheduler: LRScheduler
        +callback_handler: CallbackHandler
        +train()
        +evaluate()
        +predict()
        +training_step()
        +compute_loss()
    }
    class TrainingArguments {
        +output_dir: str
        +num_train_epochs: float
        +learning_rate: float
        +per_device_train_batch_size: int
        +gradient_accumulation_steps: int
        +fp16: bool
        +deepspeed: str
        +100+ more parameters
    }
    Trainer --> TrainingArguments

训练循环:train() → training_step()

train() 方法是训练的主入口,其整体流程如下:

flowchart TD
    A["train()"] --> B["Resume from checkpoint?"]
    B --> C["Create DataLoader"]
    C --> D["Setup optimizer + scheduler"]
    D --> E["Call on_train_begin callbacks"]
    E --> F["For each epoch"]
    F --> G["For each batch"]
    G --> H["training_step(model, inputs)"]
    H --> I{"Gradient accumulation<br/>boundary?"}
    I -->|Yes| J["Clip gradients"]
    J --> K["optimizer.step()"]
    K --> L["scheduler.step()"]
    L --> M["optimizer.zero_grad()"]
    I -->|No| G
    M --> N{"eval_steps reached?"}
    N -->|Yes| O["evaluate()"]
    N -->|No| P{"save_steps reached?"}
    P -->|Yes| Q["save_checkpoint()"]
    P -->|No| G
    O --> G
    Q --> G
    F --> R["Call on_train_end callbacks"]

training_step() 方法负责处理单个 batch:

def training_step(self, model, inputs, num_items_in_batch=None):
    model.train()
    inputs = self._prepare_inputs(inputs)
    
    with self.compute_loss_context_manager():
        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
    
    # Handle multi-GPU loss averaging
    if self.args.n_gpu > 1:
        loss = loss.mean()
    
    # Backward pass (Accelerate handles mixed precision)
    self.accelerator.backward(loss, **kwargs)
    
    return loss.detach()

self.accelerator.backward(loss) 这一行涵盖了混合精度、梯度缩放以及分布式梯度同步的全部逻辑,对训练步骤本身完全透明。Accelerate 在 Trainer 初始化时对模型、优化器和 DataLoader 进行包装,并注入相应的钩子。

提示: 如果需要自定义训练逻辑(例如多任务损失、对抗训练),可以覆写 training_step();如果只是需要替换损失函数,覆写 compute_loss() 更为简洁。Trainer 在设计上就鼓励子类化——许多方法的签名中都明确注明了"子类可覆写以注入自定义行为"。

回调系统

Trainer 通过 TrainerCallback 提供了丰富的回调机制,共有 20 多个生命周期钩子:

sequenceDiagram
    participant T as Trainer
    participant CH as CallbackHandler
    participant CB1 as WandbCallback
    participant CB2 as TensorBoardCallback
    participant CB3 as PrinterCallback

    T->>CH: on_train_begin(args, state, control)
    CH->>CB1: on_train_begin(...)
    CH->>CB2: on_train_begin(...)
    CH->>CB3: on_train_begin(...)
    
    loop Each training step
        T->>CH: on_step_begin(...)
        T->>T: training_step()
        T->>CH: on_step_end(...)
        T->>CH: on_log(logs={"loss": 0.5, ...})
        CH->>CB1: on_log(logs) → upload to W&B
        CH->>CB2: on_log(logs) → write to TB
    end
    
    T->>CH: on_evaluate(metrics)
    T->>CH: on_train_end(...)

这些生命周期钩子覆盖了整个训练过程:

钩子 触发时机
on_init_end Trainer.init 完成后
on_train_begin / on_train_end 训练开始 / 结束
on_epoch_begin / on_epoch_end 每个 epoch 开始 / 结束
on_step_begin / on_step_end 每个训练步骤前 / 后
on_substep_end 每个梯度累积子步骤后
on_log 指标记录时
on_evaluate 评估完成后
on_save 保存检查点时
on_prediction_step 预测过程中

每个回调接收 (args, state, control, **kwargs) 作为参数。control 对象是唯一可变的返回值——回调可以通过设置 control.should_training_stop = Truecontrol.should_save = True 来干预训练循环的行为。

内置回调已包含与 Weights & Biases、TensorBoard、MLflow、CometML 和 Neptune 的集成,均通过 args.report_to 自动检测并启用。

分布式训练集成

Trainer 支持三种主要的分布式训练方案:

DDP(Distributed Data Parallel) — 使用 torchrun 启动时的默认方案。每块 GPU 持有完整的模型副本,梯度通过 all-reduce 同步。

DeepSpeed ZeRO — 按阶段逐步将优化器状态(Stage 1)、梯度(Stage 2)和模型参数(Stage 3)分片到各 GPU 上。通过 args.deepspeed 指定 JSON 配置文件进行配置,integrations/deepspeed.py 模块负责初始化。

FSDP(Fully Sharded Data Parallel) — PyTorch 原生的参数分片方案,通过 args.fsdpargs.fsdp_config 进行配置,integrations/fsdp.py 模块提供桥接层。

flowchart TD
    A["TrainingArguments"] --> B{"deepspeed config?"}
    B -->|Yes| C["Initialize DeepSpeed<br/>ZeRO Stage 1/2/3"]
    B -->|No| D{"fsdp enabled?"}
    D -->|Yes| E["Initialize FSDP<br/>via Accelerate"]
    D -->|No| F{"n_gpu > 1?"}
    F -->|Yes| G["Distributed Data<br/>Parallel (DDP)"]
    F -->|No| H["Single GPU<br/>training"]
    C --> I["Accelerate wraps<br/>model + optimizer"]
    E --> I
    G --> I
    H --> I

这套架构的精妙之处在于,无论使用哪种后端,training_step() 的代码看起来都完全一致。Accelerate 的包装层负责处理以下所有细节:

  • 梯度累积(调用 accelerator.backward(),根据是否处于累积步骤决定是否同步梯度)
  • 混合精度(通过 autocast 上下文支持 FP16 或 BF16)
  • 分布式梯度规约(由包装后的模型自动处理)

损失函数与 Data Collator

LOSS_MAPPING 注册表将模型类名的后缀映射到对应的共享损失函数:

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    # ... plus object detection variants
}

每个模型通过其 loss_function 属性在该映射表中查找自身的类名。ForCausalLMLoss 处理标签偏移(用 labels[n] 预测 token[n+1]),并使用 fixed_cross_entropy,该函数同时支持均值规约和求和归一化(以确保梯度累积时的计算正确性)。

data/data_collator.py 中的 Data Collator 负责处理不同类型数据的批处理:

  • DataCollatorForLanguageModeling — 用于 MLM 的动态掩码
  • DataCollatorForSeq2Seq — 对 encoder-decoder 模型按批次最长序列进行 padding
  • DataCollatorWithPadding — 简单 padding 并生成 attention mask
  • DataCollatorForTokenClassification — 将标签与子词 token 对齐

如果未手动指定,Trainer 会根据处理类自动选择合适的 Collator。

优化器与学习率调度器

optimization.py 模块提供了超出 PyTorch 内置范围的优化器和调度器实现。默认使用 AdamW,但 args.optim 还支持 AdaFactor、Lion、SGD、schedule-free 优化器以及用于内存高效微调的 LOMO。

学习率调度器支持线性 warmup + 衰减、余弦、带重启的余弦、多项式衰减以及带 warmup 的常数调度。调度器在 train() 初始化阶段根据 args.lr_scheduler_type 创建。

提示: 大规模训练时,设置 args.optim = "adamw_torch_fused" 可启用 PyTorch 的融合 AdamW 内核,通过减少每次参数更新的 kernel 启动次数,速度比默认实现快 15%–20%。

目录结构

文件 用途
src/transformers/trainer.py 约 4400 行的 Trainer 类
src/transformers/training_args.py 100 多个 TrainingArguments 参数
src/transformers/trainer_callback.py 包含 20 多个钩子的回调系统
src/transformers/optimization.py 优化器与学习率调度器
src/transformers/data/data_collator.py 用于批处理的 Data Collator
src/transformers/loss/loss_utils.py LOSS_MAPPING 及共享损失函数
src/transformers/integrations/deepspeed.py DeepSpeed ZeRO 集成
src/transformers/integrations/fsdp.py FSDP 集成

至此,我们已经梳理了完整的模型生命周期:导入 → 解析 → 加载 → 生成 → 训练。在最后一篇文章中,我们将聚焦于高层 Pipeline API、tokenizer 层次结构、多模态处理,以及向代码库贡献新模型的扩展接口。