
The Trainer: From Data to Gradients at Scale

Advanced

Prerequisites

  • Article 3: Model anatomy
  • Article 4: Weight loading
  • PyTorch training loop basics (forward, backward, optimizer.step)
  • Distributed training concepts: DDP, gradient accumulation, mixed precision


The Trainer class is Transformers' answer to the question: "How do you write one training loop that works on a single GPU, 8 GPUs with DDP, 64 GPUs with DeepSpeed ZeRO-3, or a TPU pod — all from the same user code?" At ~4400 lines, it's the largest single file in the codebase, handling data loading, mixed precision, gradient accumulation, checkpointing, evaluation, 20+ callback hooks, and integration with a dozen distributed backends.

This article traces the core training flow, examines the callback lifecycle, and shows how the Trainer integrates with distributed backends without the user changing a line of model code.

Trainer.__init__ and TrainingArguments

The Trainer's constructor accepts a model, a TrainingArguments dataclass with 100+ parameters, datasets, a processing class, callbacks, and optimizer overrides.

TrainingArguments is the control surface for the entire training process. Key parameter groups include:

Group            Examples
Basics           output_dir, num_train_epochs, per_device_train_batch_size
Optimization     learning_rate, weight_decay, adam_beta1, adam_beta2, lr_scheduler_type
Mixed precision  fp16, bf16, fp16_opt_level
Distributed      local_rank, deepspeed, fsdp, fsdp_config
Checkpointing    save_strategy, save_steps, save_total_limit
Logging          logging_steps, report_to, run_name
Gradient         gradient_accumulation_steps, gradient_checkpointing, max_grad_norm
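These knobs interact: the global batch size the optimizer effectively sees is the per-device batch size multiplied by the number of accumulation steps and the number of processes. A pure-Python arithmetic sketch, with names mirroring the TrainingArguments fields:

```python
# Effective global batch size = per-device batch x accumulation steps x world size.
# Plain arithmetic sketch; the variable names mirror TrainingArguments fields.
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
world_size = 8  # e.g. 8 GPUs under DDP

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * world_size
)
print(effective_batch_size)  # 256
```

This is why shrinking per_device_train_batch_size to fit memory is usually paired with raising gradient_accumulation_steps, keeping the effective batch size (and thus the learning-rate recipe) unchanged.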

The constructor sets up the Accelerate Accelerator object, which abstracts away the distributed backend. Whether you're running on 1 GPU or 64, the Trainer's internal code is the same — Accelerate handles the data parallelism, gradient synchronization, and device placement.

classDiagram
    class Trainer {
        +model: PreTrainedModel
        +args: TrainingArguments
        +train_dataset: Dataset
        +eval_dataset: Dataset
        +processing_class: Any
        +optimizer: Optimizer
        +lr_scheduler: LRScheduler
        +callback_handler: CallbackHandler
        +train()
        +evaluate()
        +predict()
        +training_step()
        +compute_loss()
    }
    class TrainingArguments {
        +output_dir: str
        +num_train_epochs: float
        +learning_rate: float
        +per_device_train_batch_size: int
        +gradient_accumulation_steps: int
        +fp16: bool
        +deepspeed: str
        +100+ more parameters
    }
    Trainer --> TrainingArguments

The Training Loop: train() → training_step()

The train() method is the main entry point. Its high-level flow:

flowchart TD
    A["train()"] --> B["Resume from checkpoint?"]
    B --> C["Create DataLoader"]
    C --> D["Setup optimizer + scheduler"]
    D --> E["Call on_train_begin callbacks"]
    E --> F["For each epoch"]
    F --> G["For each batch"]
    G --> H["training_step(model, inputs)"]
    H --> I{"Gradient accumulation<br/>boundary?"}
    I -->|Yes| J["Clip gradients"]
    J --> K["optimizer.step()"]
    K --> L["scheduler.step()"]
    L --> M["optimizer.zero_grad()"]
    I -->|No| G
    M --> N{"eval_steps reached?"}
    N -->|Yes| O["evaluate()"]
    N -->|No| P{"save_steps reached?"}
    P -->|Yes| Q["save_checkpoint()"]
    P -->|No| G
    O --> G
    Q --> G
    F --> R["Call on_train_end callbacks"]

The training_step() method handles a single batch:

def training_step(self, model, inputs, num_items_in_batch=None):
    model.train()
    inputs = self._prepare_inputs(inputs)
    
    with self.compute_loss_context_manager():
        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
    
    # Handle multi-GPU loss averaging
    if self.args.n_gpu > 1:
        loss = loss.mean()
    
    # Backward pass (Accelerate handles mixed precision)
    self.accelerator.backward(loss)
    
    return loss.detach()

The self.accelerator.backward(loss) call is where mixed precision, gradient scaling, and distributed gradient synchronization happen — all transparent to the training step logic. The Accelerate library wraps the model, optimizer, and data loaders during Trainer initialization, injecting the appropriate hooks.

Tip: Override training_step() when you need custom training logic (e.g., multi-task losses, adversarial training). Override compute_loss() for simpler cases where you only need a different loss function. The Trainer is designed for subclassing — many method docstrings say "Subclass and override to inject custom behavior."
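The override contract can be sketched in plain Python. This is a toy stand-in, not the real Trainer: the point is only that training_step() delegates to compute_loss(), so a subclass can swap the loss without touching the step logic.

```python
# Toy sketch of the Trainer's override contract -- NOT the real class.
# A subclass replaces compute_loss(); training_step() stays untouched.
class ToyTrainer:
    def compute_loss(self, model, inputs):
        return model(inputs)  # default: the model returns its own loss

    def training_step(self, model, inputs):
        loss = self.compute_loss(model, inputs)
        return loss  # backward pass and optimizer step elided in this sketch


class WeightedLossTrainer(ToyTrainer):
    """Override compute_loss for a custom (here: scaled) objective."""

    def compute_loss(self, model, inputs):
        return 0.5 * super().compute_loss(model, inputs)


model = lambda x: float(x)  # stand-in "model" that returns a loss
print(ToyTrainer().training_step(model, 4.0))          # 4.0
print(WeightedLossTrainer().training_step(model, 4.0))  # 2.0
```

The real subclassing pattern is the same shape: inherit from transformers.Trainer and override only the method whose behavior you need to change.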

The Callback System

The Trainer provides a rich callback system via TrainerCallback, with 20+ lifecycle hooks:

sequenceDiagram
    participant T as Trainer
    participant CH as CallbackHandler
    participant CB1 as WandbCallback
    participant CB2 as TensorBoardCallback
    participant CB3 as PrinterCallback

    T->>CH: on_train_begin(args, state, control)
    CH->>CB1: on_train_begin(...)
    CH->>CB2: on_train_begin(...)
    CH->>CB3: on_train_begin(...)
    
    loop Each training step
        T->>CH: on_step_begin(...)
        T->>T: training_step()
        T->>CH: on_step_end(...)
        T->>CH: on_log(logs={"loss": 0.5, ...})
        CH->>CB1: on_log(logs) → upload to W&B
        CH->>CB2: on_log(logs) → write to TB
    end
    
    T->>CH: on_evaluate(metrics)
    T->>CH: on_train_end(...)

The lifecycle hooks span the entire training process:

Hook                            When
on_init_end                     After Trainer.__init__
on_train_begin / on_train_end   Start/end of training
on_epoch_begin / on_epoch_end   Start/end of each epoch
on_step_begin / on_step_end     Before/after each training step
on_substep_end                  After each gradient accumulation substep
on_log                          When metrics are logged
on_evaluate                     After evaluation
on_save                         When a checkpoint is saved
on_prediction_step              During prediction

Each callback receives (args, state, control, **kwargs). The control object is the only mutable return — callbacks can set control.should_training_stop = True or control.should_save = True to influence the training loop.

Built-in callbacks include integrations with Weights & Biases, TensorBoard, MLflow, CometML, and Neptune — all auto-detected from args.report_to.
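The dispatch-and-control mechanics can be sketched in a few lines. These are toy stand-ins for TrainerControl, TrainerCallback, and CallbackHandler, not the real classes; the point is that the handler fans one event out to every callback, and any callback can flip flags on the shared control object:

```python
# Minimal sketch of callback dispatch + the mutable control object.
# Toy stand-ins for TrainerControl / TrainerCallback / CallbackHandler.
class Control:
    should_training_stop = False


class EarlyStop:
    """Request a stop once the logged loss drops below a threshold."""

    def on_log(self, control, logs):
        if logs.get("loss", float("inf")) < 0.1:
            control.should_training_stop = True


class Handler:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_log(self, control, logs):
        for cb in self.callbacks:       # fan the event out to every callback
            cb.on_log(control, logs)
        return control


control = Control()
handler = Handler([EarlyStop()])
handler.on_log(control, {"loss": 0.5})
print(control.should_training_stop)  # False
handler.on_log(control, {"loss": 0.05})
print(control.should_training_stop)  # True
```

The real loop checks flags like control.should_training_stop after each hook and acts on them, which is how a callback can end training without ever touching the loop itself.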

Distributed Training Integration

The Trainer supports three main distributed paradigms:

DDP (Distributed Data Parallel) — the default when launching with torchrun. Each GPU has a full model copy, gradients are synchronized via all-reduce.

DeepSpeed ZeRO — progressively shards optimizer states (Stage 1), gradients (Stage 2), and parameters (Stage 3) across GPUs. Configured via args.deepspeed pointing to a JSON config. The integrations/deepspeed.py module handles initialization.
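To make the configuration concrete, here is a minimal ZeRO Stage 2 config written as a Python dict (args.deepspeed also accepts a path to the equivalent JSON file). This is an illustrative minimal example, not a tuned recipe; the "auto" values are a Hugging Face integration feature that lets the Trainer fill them in from TrainingArguments:

```python
# Minimal DeepSpeed ZeRO Stage 2 config as a Python dict (illustrative).
# "auto" lets the Trainer fill the value in from TrainingArguments.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "fp16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,                              # shard optimizer state + gradients
        "offload_optimizer": {"device": "cpu"},  # optional CPU offload
    },
}
print(ds_config["zero_optimization"]["stage"])  # 2
```

Stage 3 additionally shards the parameters themselves; switching the "stage" value is the main lever, with offload options trading speed for memory.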

FSDP (Fully Sharded Data Parallel) — PyTorch-native parameter sharding. Configured via args.fsdp and args.fsdp_config. The integrations/fsdp.py module provides the bridge.

flowchart TD
    A["TrainingArguments"] --> B{"deepspeed config?"}
    B -->|Yes| C["Initialize DeepSpeed<br/>ZeRO Stage 1/2/3"]
    B -->|No| D{"fsdp enabled?"}
    D -->|Yes| E["Initialize FSDP<br/>via Accelerate"]
    D -->|No| F{"n_gpu > 1?"}
    F -->|Yes| G["Distributed Data<br/>Parallel (DDP)"]
    F -->|No| H["Single GPU<br/>training"]
    C --> I["Accelerate wraps<br/>model + optimizer"]
    E --> I
    G --> I
    H --> I

The beauty of this architecture is that training_step() looks identical regardless of backend. The Accelerate library's wrapping handles:

  • Gradient accumulation (calls accelerator.backward() which may or may not sync gradients depending on whether it's an accumulation step)
  • Mixed precision (FP16 or BF16 via autocast context)
  • Distributed gradient reduction (handled by the wrapped model)
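The accumulation bookkeeping in the first bullet reduces to: scale each micro-batch loss by 1/N, accumulate gradients, and only step and zero the optimizer every N micro-batches. A pure-Python sketch with a scalar standing in for the gradient tensor:

```python
# Gradient accumulation sketch: step the optimizer only every N micro-batches.
# A scalar stands in for the gradient tensor; Accelerate does this for you.
accum_steps = 4
grad = 0.0
num_optimizer_steps = 0

micro_losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
for i, loss in enumerate(micro_losses):
    grad += loss / accum_steps           # backward() on the scaled loss
    if (i + 1) % accum_steps == 0:       # accumulation boundary reached
        num_optimizer_steps += 1         # optimizer.step()
        grad = 0.0                       # optimizer.zero_grad()

print(num_optimizer_steps)  # 2 optimizer steps for 8 micro-batches
```

Under DDP, Accelerate also suppresses the all-reduce on non-boundary micro-batches, so gradients are synchronized once per optimizer step rather than once per backward pass.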

Loss Functions and Data Collators

The LOSS_MAPPING registry maps model class name suffixes to shared loss functions:

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    # ... plus object detection variants
}

Each model's loss_function property looks up its class name in this mapping. The ForCausalLMLoss function handles the label shifting (the logits at position n are trained to predict token n+1) and uses fixed_cross_entropy, which supports both mean reduction and sum-with-normalization (needed for correct loss scaling under gradient accumulation).
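The shift can be illustrated with plain lists (no tensors): the logits drop the last position and the labels drop the first, so position i is scored against token i+1.

```python
# Causal-LM label shifting sketch with plain lists (no tensors).
# The prediction made at position i is scored against token i+1.
tokens = [10, 20, 30, 40]

scored_positions = tokens[:-1]  # positions whose logits get scored
shift_labels = tokens[1:]       # the next token each position must predict

pairs = list(zip(scored_positions, shift_labels))
print(pairs)  # [(10, 20), (20, 30), (30, 40)]
```

Because the last position has no next token, a sequence of length L yields L-1 supervised predictions.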

Data collators from data/data_collator.py handle batching different data types:

  • DataCollatorForLanguageModeling — dynamic masking for MLM
  • DataCollatorForSeq2Seq — padding to longest in batch for encoder-decoder
  • DataCollatorWithPadding — simple padding with attention mask generation
  • DataCollatorForTokenClassification — aligns labels with subword tokens

The Trainer auto-selects a collator based on the processing class if none is provided.
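The core idea behind DataCollatorWithPadding can be sketched in pure Python: pad every sequence to the longest in the batch and emit a matching attention mask. This is an illustrative sketch of the idea, not the library's implementation (which works on tokenizer outputs and returns tensors):

```python
# Illustrative pad-to-longest collator with attention masks.
# A pure-Python sketch of the idea, not the library's implementation.
def collate_with_padding(batch, pad_token_id=0):
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        pad = max_len - len(seq)
        input_ids.append(seq + [pad_token_id] * pad)      # right-pad to max_len
        attention_mask.append([1] * len(seq) + [0] * pad)  # 1 = real, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = collate_with_padding([[5, 6, 7], [8]])
print(batch["input_ids"])       # [[5, 6, 7], [8, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```

Padding to the batch maximum rather than a fixed length is what makes dynamic batching cheap: short batches stay short.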

Optimizers and Schedulers

The optimization.py module provides optimizer and scheduler implementations beyond what PyTorch offers out of the box. The default is AdamW, but args.optim supports AdaFactor, Lion, SGD, schedule-free optimizers, and LOMO (for memory-efficient fine-tuning).

Learning rate schedulers include linear warmup + decay, cosine, cosine with restarts, polynomial decay, and constant with warmup. The scheduler is created from args.lr_scheduler_type during train() setup.
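The workhorse "linear" schedule has a simple shape: the learning-rate multiplier ramps from 0 to 1 over the warmup steps, then decays linearly back to 0 by the final step. A pure-Python sketch of that multiplier (the real version lives behind args.lr_scheduler_type as a LambdaLR):

```python
# Linear warmup + linear decay LR multiplier, as a pure-Python sketch.
# Ramps 0 -> 1 over warmup_steps, then decays 1 -> 0 by total_steps.
def linear_schedule(step, warmup_steps, total_steps):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


print(linear_schedule(50, 100, 1000))   # 0.5  (halfway through warmup)
print(linear_schedule(100, 100, 1000))  # 1.0  (warmup complete, peak LR)
print(linear_schedule(550, 100, 1000))  # 0.5  (halfway through decay)
```

The actual learning rate at any step is this multiplier times args.learning_rate, which is why "learning_rate" in the logs starts near zero, peaks, and then falls.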

Tip: For large-scale training, args.optim = "adamw_torch_fused" enables PyTorch's fused AdamW kernel, which can be 15-20% faster than the default by avoiding multiple kernel launches per parameter update.

Directory Map

File                                         Purpose
src/transformers/trainer.py                  The ~4400-line Trainer class
src/transformers/training_args.py            TrainingArguments with 100+ parameters
src/transformers/trainer_callback.py         Callback system with 20+ hooks
src/transformers/optimization.py             Optimizers and LR schedulers
src/transformers/data/data_collator.py       Data collators for batching
src/transformers/loss/loss_utils.py          LOSS_MAPPING and shared loss functions
src/transformers/integrations/deepspeed.py   DeepSpeed ZeRO integration
src/transformers/integrations/fsdp.py        FSDP integration

We've now covered the complete model lifecycle: import → resolve → load → generate → train. In the final article, we'll look at the high-level Pipeline API, the tokenizer hierarchy, multimodal processing, and the extension points for contributing new models to the library.