The Trainer: From Data to Gradients at Scale
Prerequisites
- Article 3: Model anatomy
- Article 4: Weight loading
- PyTorch training loop basics (forward, backward, optimizer.step)
- Distributed training concepts: DDP, gradient accumulation, mixed precision
The Trainer class is Transformers' answer to the question: "How do you write one training loop that works on a single GPU, 8 GPUs with DDP, 64 GPUs with DeepSpeed ZeRO-3, or a TPU pod — all from the same user code?" At ~4400 lines, it's the largest single file in the codebase, handling data loading, mixed precision, gradient accumulation, checkpointing, evaluation, 20+ callback hooks, and integration with a dozen distributed backends.
This article traces the core training flow, examines the callback lifecycle, and shows how the Trainer integrates with distributed backends without the user changing a line of model code.
Trainer.__init__ and TrainingArguments
The Trainer's constructor accepts a model, a TrainingArguments dataclass with 100+ parameters, datasets, a processing class, callbacks, and optimizer overrides.
TrainingArguments is the control surface for the entire training process. Key parameter groups include:
| Group | Examples |
|---|---|
| Basics | output_dir, num_train_epochs, per_device_train_batch_size |
| Optimization | learning_rate, weight_decay, adam_beta1, adam_beta2, lr_scheduler_type |
| Mixed precision | fp16, bf16, fp16_opt_level |
| Distributed | local_rank, deepspeed, fsdp, fsdp_config |
| Checkpointing | save_strategy, save_steps, save_total_limit |
| Logging | logging_steps, report_to, run_name |
| Gradient | gradient_accumulation_steps, gradient_checkpointing, max_grad_norm |
The constructor sets up the Accelerate Accelerator object, which abstracts away the distributed backend. Whether you're running on 1 GPU or 64, the Trainer's internal code is the same — Accelerate handles the data parallelism, gradient synchronization, and device placement.
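In code, the typical setup looks like this (a minimal sketch, not a complete script: `model`, `tokenizer`, and `train_ds` are assumed to already exist, and `processing_class` is the current name for what older versions called the `tokenizer` argument):

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,   # effective batch = 8 * 4 * num_processes
    learning_rate=5e-5,
    bf16=True,                       # mixed precision, handled via Accelerate
    logging_steps=50,
    save_strategy="steps",
    save_steps=500,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    processing_class=tokenizer,
)
trainer.train()
```

The same script runs unchanged under `torchrun` on multiple GPUs; only the launch command differs.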
```mermaid
classDiagram
    class Trainer {
        +model: PreTrainedModel
        +args: TrainingArguments
        +train_dataset: Dataset
        +eval_dataset: Dataset
        +processing_class: Any
        +optimizer: Optimizer
        +lr_scheduler: LRScheduler
        +callback_handler: CallbackHandler
        +train()
        +evaluate()
        +predict()
        +training_step()
        +compute_loss()
    }
    class TrainingArguments {
        +output_dir: str
        +num_train_epochs: float
        +learning_rate: float
        +per_device_train_batch_size: int
        +gradient_accumulation_steps: int
        +fp16: bool
        +deepspeed: str
        +... 100+ more parameters
    }
    Trainer --> TrainingArguments
```
The Training Loop: train() → training_step()
The train() method is the main entry point. Its high-level flow:
```mermaid
flowchart TD
    A["train()"] --> B["Resume from checkpoint?"]
    B --> C["Create DataLoader"]
    C --> D["Setup optimizer + scheduler"]
    D --> E["Call on_train_begin callbacks"]
    E --> F["For each epoch"]
    F --> G["For each batch"]
    G --> H["training_step(model, inputs)"]
    H --> I{"Gradient accumulation<br/>boundary?"}
    I -->|Yes| J["Clip gradients"]
    J --> K["optimizer.step()"]
    K --> L["scheduler.step()"]
    L --> M["optimizer.zero_grad()"]
    I -->|No| G
    M --> N{"eval_steps reached?"}
    N -->|Yes| O["evaluate()"]
    N -->|No| P{"save_steps reached?"}
    P -->|Yes| Q["save_checkpoint()"]
    P -->|No| G
    O --> G
    Q --> G
    F --> R["Call on_train_end callbacks"]
```
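The accumulation-boundary logic in the flowchart above can be sketched with toy numbers (plain Python, no torch — the comments mark where the real calls would go):

```python
# Minimal sketch of the gradient-accumulation boundary check in train().
# An optimizer update happens every `gradient_accumulation_steps` micro-batches,
# plus once at the end of the epoch for any leftover micro-batches.
gradient_accumulation_steps = 4
num_batches = 10
optimizer_steps = 0

for batch_idx in range(num_batches):
    # training_step(model, inputs) would run forward + backward here;
    # gradients accumulate in .grad across micro-batches.
    at_boundary = (batch_idx + 1) % gradient_accumulation_steps == 0
    is_last = batch_idx == num_batches - 1
    if at_boundary or is_last:
        # clip gradients, optimizer.step(), scheduler.step(), zero_grad()
        optimizer_steps += 1

print(optimizer_steps)  # 3 optimizer updates for 10 micro-batches
```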
The training_step() method handles a single batch:
```python
def training_step(self, model, inputs, num_items_in_batch=None):
    model.train()
    inputs = self._prepare_inputs(inputs)

    with self.compute_loss_context_manager():
        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

    # Handle multi-GPU loss averaging (DataParallel returns one loss per GPU)
    if self.args.n_gpu > 1:
        loss = loss.mean()

    # Backward pass (Accelerate handles mixed precision and gradient scaling)
    self.accelerator.backward(loss)

    return loss.detach()
```
The self.accelerator.backward(loss) call is where mixed precision, gradient scaling, and distributed gradient synchronization happen — all transparent to the training step logic. The Accelerate library wraps the model, optimizer, and data loaders during Trainer initialization, injecting the appropriate hooks.
Tip: Override `training_step()` when you need custom training logic (e.g., multi-task losses, adversarial training). Override `compute_loss()` for simpler cases where you just need a different loss function. The Trainer is designed for subclassing — many method docstrings read "Subclass and override to inject custom behavior."
The Callback System
The Trainer provides a rich callback system via TrainerCallback, with 20+ lifecycle hooks:
```mermaid
sequenceDiagram
    participant T as Trainer
    participant CH as CallbackHandler
    participant CB1 as WandbCallback
    participant CB2 as TensorBoardCallback
    participant CB3 as PrinterCallback
    T->>CH: on_train_begin(args, state, control)
    CH->>CB1: on_train_begin(...)
    CH->>CB2: on_train_begin(...)
    CH->>CB3: on_train_begin(...)
    loop Each training step
        T->>CH: on_step_begin(...)
        T->>T: training_step()
        T->>CH: on_step_end(...)
        T->>CH: on_log(logs={"loss": 0.5, ...})
        CH->>CB1: on_log(logs) → upload to W&B
        CH->>CB2: on_log(logs) → write to TB
    end
    T->>CH: on_evaluate(metrics)
    T->>CH: on_train_end(...)
```
The lifecycle hooks span the entire training process:
| Hook | When |
|---|---|
| `on_init_end` | After `Trainer.__init__` |
| `on_train_begin` / `on_train_end` | Start/end of training |
| `on_epoch_begin` / `on_epoch_end` | Start/end of each epoch |
| `on_step_begin` / `on_step_end` | Before/after each training step |
| `on_substep_end` | After each gradient accumulation substep |
| `on_log` | When metrics are logged |
| `on_evaluate` | After evaluation |
| `on_save` | When a checkpoint is saved |
| `on_prediction_step` | During prediction |
Each callback receives `(args, state, control, **kwargs)`. The `control` object is the only mutable return — callbacks can set `control.should_training_stop = True` or `control.should_save = True` to influence the training loop.
Built-in callbacks include integrations with Weights & Biases, TensorBoard, MLflow, CometML, and Neptune — all auto-detected from args.report_to.
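The dispatch-and-mutate pattern can be illustrated with a stripped-down sketch (the real `CallbackHandler` and `TrainerControl` live in `trainer_callback.py` and carry more state; `StopAtLossCallback` and its threshold are mine, invented for illustration):

```python
# Simplified sketch of the callback dispatch mechanism: the handler fans a
# hook out to every callback, and callbacks steer training by mutating the
# shared control object.
from dataclasses import dataclass


@dataclass
class TrainerControl:
    should_training_stop: bool = False
    should_save: bool = False


class StopAtLossCallback:
    """Toy callback: request a stop once the logged loss drops below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and logs.get("loss", float("inf")) < self.threshold:
            control.should_training_stop = True
        return control


class CallbackHandler:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_log(self, args, state, control, logs):
        for cb in self.callbacks:
            result = cb.on_log(args, state, control, logs=logs)
            if result is not None:  # callbacks may return the (mutated) control
                control = result
        return control


handler = CallbackHandler([StopAtLossCallback(threshold=0.1)])
control = handler.on_log(None, None, TrainerControl(), {"loss": 0.05})
print(control.should_training_stop)  # True
```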
Distributed Training Integration
The Trainer supports three main distributed paradigms:
DDP (Distributed Data Parallel) — the default when launching with torchrun. Each GPU has a full model copy, gradients are synchronized via all-reduce.
DeepSpeed ZeRO — progressively shards optimizer states (Stage 1), gradients (Stage 2), and parameters (Stage 3) across GPUs. Configured via args.deepspeed pointing to a JSON config. The integrations/deepspeed.py module handles initialization.
FSDP (Fully Sharded Data Parallel) — PyTorch-native parameter sharding. Configured via args.fsdp and args.fsdp_config. The integrations/fsdp.py module provides the bridge.
```mermaid
flowchart TD
    A["TrainingArguments"] --> B{"deepspeed config?"}
    B -->|Yes| C["Initialize DeepSpeed<br/>ZeRO Stage 1/2/3"]
    B -->|No| D{"fsdp enabled?"}
    D -->|Yes| E["Initialize FSDP<br/>via Accelerate"]
    D -->|No| F{"n_gpu > 1?"}
    F -->|Yes| G["Distributed Data<br/>Parallel (DDP)"]
    F -->|No| H["Single GPU<br/>training"]
    C --> I["Accelerate wraps<br/>model + optimizer"]
    E --> I
    G --> I
    H --> I
```
The beauty of this architecture is that training_step() looks identical regardless of backend. The Accelerate library's wrapping handles:
- Gradient accumulation (calls `accelerator.backward()`, which may or may not sync gradients depending on whether it's an accumulation step)
- Mixed precision (FP16 or BF16 via an `autocast` context)
- Distributed gradient reduction (handled by the wrapped model)
Loss Functions and Data Collators
The LOSS_MAPPING registry maps model class name suffixes to shared loss functions:
```python
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassificationLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    # ... plus object detection variants
}
```
Each model's `loss_function` property looks up its class name in this mapping. The `ForCausalLMLoss` function handles the label shifting (the logit at position n is scored against the label at position n+1) and uses `fixed_cross_entropy`, which supports both mean reduction and sum-with-normalization (for gradient accumulation correctness).
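Both details can be illustrated with a plain-Python toy (this is not the library code, just the arithmetic it implements; -100 is the conventional PyTorch cross-entropy ignore index):

```python
# 1) Label shifting: position n predicts the token at position n+1, and the
#    final position has nothing to predict, so it gets the ignore index.
labels = [10, 11, 12, 13]
shift_targets = labels[1:] + [-100]
print(shift_targets)  # [11, 12, 13, -100]

# 2) Why sum-with-normalization matters under gradient accumulation:
#    averaging per micro-batch over-weights tokens in short batches.
micro_batch_losses = [[1.0, 1.0, 1.0, 1.0], [3.0]]  # per-token losses

mean_of_means = sum(sum(b) / len(b) for b in micro_batch_losses) / len(micro_batch_losses)
total_tokens = sum(len(b) for b in micro_batch_losses)
global_mean = sum(sum(b) for b in micro_batch_losses) / total_tokens

print(mean_of_means)  # 2.0: the 1-token batch counts as much as the 4-token one
print(global_mean)    # 1.4: every token counts equally
```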
Data collators from data/data_collator.py handle batching different data types:
- `DataCollatorForLanguageModeling` — dynamic masking for MLM
- `DataCollatorForSeq2Seq` — padding to longest in batch for encoder-decoder
- `DataCollatorWithPadding` — simple padding with attention mask generation
- `DataCollatorForTokenClassification` — aligns labels with subword tokens
The Trainer auto-selects a collator based on the processing class if none is provided.
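The core of dynamic padding is simple enough to sketch with Python lists instead of tensors (the function name is mine; the real `DataCollatorWithPadding` delegates to the tokenizer's `pad` method and returns tensors):

```python
# Simplified sketch of DataCollatorWithPadding's behavior: pad each example to
# the longest sequence in the batch and emit a matching attention mask.
def collate_with_padding(features, pad_token_id=0):
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": []}
    for f in features:
        ids = f["input_ids"]
        pad = max_len - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
    return batch


batch = collate_with_padding([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}])
print(batch["input_ids"])       # [[5, 6, 7], [8, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```

Padding to the batch maximum (rather than a global maximum) is what makes length-sorted sampling effective: batches of similar-length sequences waste fewer pad positions.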
Optimizers and Schedulers
The optimization.py module provides optimizer and scheduler implementations beyond what PyTorch offers out of the box. The default is AdamW, but args.optim supports AdaFactor, Lion, SGD, schedule-free optimizers, and LOMO (for memory-efficient fine-tuning).
Learning rate schedulers include linear warmup + decay, cosine, cosine with restarts, polynomial decay, and constant with warmup. The scheduler is created from args.lr_scheduler_type during train() setup.
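The linear schedule, for example, scales the base learning rate by a multiplier that ramps from 0 to 1 over warmup and then decays linearly to 0 (an illustrative function mirroring the shape of the lambda behind `get_linear_schedule_with_warmup`, not the library code itself):

```python
# Multiplier applied to the base learning rate at a given optimizer step.
def linear_lr_lambda(step, num_warmup_steps, num_training_steps):
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


print(linear_lr_lambda(50, 100, 1000))    # 0.5 (halfway through warmup)
print(linear_lr_lambda(550, 100, 1000))   # 0.5 (halfway through decay)
print(linear_lr_lambda(1000, 100, 1000))  # 0.0 (end of training)
```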
Tip: For large-scale training, `args.optim = "adamw_torch_fused"` enables PyTorch's fused AdamW kernel, which can be 15-20% faster than the default by avoiding multiple kernel launches per parameter update.
Directory Map
| File | Purpose |
|---|---|
| `src/transformers/trainer.py` | The ~4400-line Trainer class |
| `src/transformers/training_args.py` | TrainingArguments with 100+ parameters |
| `src/transformers/trainer_callback.py` | Callback system with 20+ hooks |
| `src/transformers/optimization.py` | Optimizers and LR schedulers |
| `src/transformers/data/data_collator.py` | Data collators for batching |
| `src/transformers/loss/loss_utils.py` | LOSS_MAPPING and shared loss functions |
| `src/transformers/integrations/deepspeed.py` | DeepSpeed ZeRO integration |
| `src/transformers/integrations/fsdp.py` | FSDP integration |
We've now covered the complete model lifecycle: import → resolve → load → generate → train. In the final article, we'll look at the high-level Pipeline API, the tokenizer hierarchy, multimodal processing, and the extension points for contributing new models to the library.