Trainer:データから大規模な勾配計算まで
前提知識
- ›第3回:モデルの構造
- ›第4回:重みの読み込み
- ›PyTorch の学習ループの基礎(forward、backward、optimizer.step)
- ›分散学習の基本概念:DDP、勾配累積、混合精度
Trainer:データから大規模な勾配計算まで
Trainer クラスは、「1 GPU、8 GPU(DDP)、64 GPU(DeepSpeed ZeRO-3)、TPU ポッドのいずれでも、まったく同じコードで動く学習ループをどう書くか?」という問いに対する Transformers の答えです。約 4,400 行にも及ぶこのファイルはコードベース最大の単一ファイルであり、データ読み込み、混合精度、勾配累積、チェックポイント保存、評価、20 以上のコールバックフック、そして十数種類の分散バックエンドとの統合を一手に担っています。
この記事では、コアとなる学習フローを追いながら、コールバックのライフサイクルを詳しく見ていきます。また、モデルコードを一切変更することなく、Trainer が分散バックエンドとどう統合されるかも解説します。
Trainer.init と TrainingArguments
Trainer のコンストラクタが受け取るのは、モデル、100 以上のパラメータを持つ TrainingArguments データクラス、データセット、処理クラス、コールバック、そしてオプティマイザの上書き設定です。
TrainingArguments は学習プロセス全体の制御パネルとも言える存在です。主なパラメータグループは以下の通りです。
| グループ | パラメータ例 |
|---|---|
| 基本設定 | output_dir, num_train_epochs, per_device_train_batch_size |
| 最適化 | learning_rate, weight_decay, adam_beta1, adam_beta2, lr_scheduler_type |
| 混合精度 | fp16, bf16, fp16_opt_level |
| 分散学習 | local_rank, deepspeed, fsdp, fsdp_config |
| チェックポイント | save_strategy, save_steps, save_total_limit |
| ロギング | logging_steps, report_to, run_name |
| 勾配 | gradient_accumulation_steps, gradient_checkpointing, max_grad_norm |
コンストラクタでは、分散バックエンドを抽象化する Accelerate の Accelerator オブジェクトをセットアップします。1 GPU でも 64 GPU でも、Trainer の内部コードは変わりません。データ並列化、勾配の同期、デバイスへの配置はすべて Accelerate が担います。
classDiagram
class Trainer {
+model: PreTrainedModel
+args: TrainingArguments
+train_dataset: Dataset
+eval_dataset: Dataset
+processing_class: Any
+optimizer: Optimizer
+lr_scheduler: LRScheduler
+callback_handler: CallbackHandler
+train()
+evaluate()
+predict()
+training_step()
+compute_loss()
}
class TrainingArguments {
+output_dir: str
+num_train_epochs: float
+learning_rate: float
+per_device_train_batch_size: int
+gradient_accumulation_steps: int
+fp16: bool
+deepspeed: str
+100+ more parameters
}
Trainer --> TrainingArguments
学習ループ:train() → training_step()
train() メソッドが学習の起点となります。大まかな処理フローは以下の通りです。
flowchart TD
A["train()"] --> B["Resume from checkpoint?"]
B --> C["Create DataLoader"]
C --> D["Setup optimizer + scheduler"]
D --> E["Call on_train_begin callbacks"]
E --> F["For each epoch"]
F --> G["For each batch"]
G --> H["training_step(model, inputs)"]
H --> I{"Gradient accumulation<br/>boundary?"}
I -->|Yes| J["Clip gradients"]
J --> K["optimizer.step()"]
K --> L["scheduler.step()"]
L --> M["optimizer.zero_grad()"]
I -->|No| G
M --> N{"eval_steps reached?"}
N -->|Yes| O["evaluate()"]
N -->|No| P{"save_steps reached?"}
P -->|Yes| Q["save_checkpoint()"]
P -->|No| G
O --> G
Q --> G
F --> R["Call on_train_end callbacks"]
training_step() メソッドは、1 バッチ分の処理を担います。
def training_step(self, model, inputs, num_items_in_batch=None):
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
# Handle multi-GPU loss averaging
if self.args.n_gpu > 1:
loss = loss.mean()
# Backward pass (Accelerate handles mixed precision)
self.accelerator.backward(loss, **kwargs)
return loss.detach()
self.accelerator.backward(loss) の呼び出し箇所で、混合精度、勾配スケーリング、分散勾配の同期がすべて処理されます。これらは学習ステップのロジックからは完全に隠蔽されています。Accelerate ライブラリは Trainer の初期化時にモデル、オプティマイザ、データローダーをラップし、適切なフックを注入します。
ヒント: マルチタスク損失や敵対的学習など、カスタムの学習ロジックが必要な場合は
training_step()をオーバーライドしましょう。損失関数だけを差し替えたいシンプルなケースでは、compute_loss()のオーバーライドで十分です。Trainer はサブクラス化を前提に設計されており、多くのメソッドに「サブクラス化してオーバーライドすることでカスタム動作を注入できる」という説明が記載されています。
コールバックシステム
Trainer は TrainerCallback を通じて、20 以上のライフサイクルフックを持つ豊富なコールバックシステムを提供しています。
sequenceDiagram
participant T as Trainer
participant CH as CallbackHandler
participant CB1 as WandbCallback
participant CB2 as TensorBoardCallback
participant CB3 as PrinterCallback
T->>CH: on_train_begin(args, state, control)
CH->>CB1: on_train_begin(...)
CH->>CB2: on_train_begin(...)
CH->>CB3: on_train_begin(...)
loop Each training step
T->>CH: on_step_begin(...)
T->>T: training_step()
T->>CH: on_step_end(...)
T->>CH: on_log(logs={"loss": 0.5, ...})
CH->>CB1: on_log(logs) → upload to W&B
CH->>CB2: on_log(logs) → write to TB
end
T->>CH: on_evaluate(metrics)
T->>CH: on_train_end(...)
ライフサイクルフックは学習プロセス全体をカバーしています。
| フック | タイミング |
|---|---|
on_init_end |
Trainer.init の完了後 |
on_train_begin / on_train_end |
学習の開始/終了 |
on_epoch_begin / on_epoch_end |
各エポックの開始/終了 |
on_step_begin / on_step_end |
各学習ステップの前後 |
on_substep_end |
各勾配累積サブステップの終了後 |
on_log |
メトリクスのログ記録時 |
on_evaluate |
評価の完了後 |
on_save |
チェックポイントの保存時 |
on_prediction_step |
予測実行中 |
各コールバックは (args, state, control, **kwargs) を受け取ります。control オブジェクトだけが変更可能な戻り値であり、コールバックから control.should_training_stop = True や control.should_save = True を設定することで、学習ループの動作に影響を与えられます。
組み込みコールバックには、Weights & Biases、TensorBoard、MLflow、CometML、Neptune との統合が含まれており、args.report_to の設定から自動的に検出されます。
分散学習との統合
Trainer は主に 3 つの分散学習方式をサポートしています。
DDP(Distributed Data Parallel) — torchrun で起動した際のデフォルト方式です。各 GPU がモデルの完全なコピーを保持し、勾配は all-reduce によって同期されます。
DeepSpeed ZeRO — オプティマイザの状態(Stage 1)、勾配(Stage 2)、パラメータ(Stage 3)を段階的に GPU 間でシャーディングします。args.deepspeed に JSON 設定ファイルのパスを指定することで構成できます。初期化は integrations/deepspeed.py モジュールが担当します。
FSDP(Fully Sharded Data Parallel) — PyTorch ネイティブのパラメータシャーディング方式です。args.fsdp と args.fsdp_config で設定し、integrations/fsdp.py モジュールが橋渡し役を担います。
flowchart TD
A["TrainingArguments"] --> B{"deepspeed config?"}
B -->|Yes| C["Initialize DeepSpeed<br/>ZeRO Stage 1/2/3"]
B -->|No| D{"fsdp enabled?"}
D -->|Yes| E["Initialize FSDP<br/>via Accelerate"]
D -->|No| F{"n_gpu > 1?"}
F -->|Yes| G["Distributed Data<br/>Parallel (DDP)"]
F -->|No| H["Single GPU<br/>training"]
C --> I["Accelerate wraps<br/>model + optimizer"]
E --> I
G --> I
H --> I
このアーキテクチャの優れた点は、バックエンドに関わらず training_step() の見た目がまったく同じであることです。Accelerate ライブラリのラッピングが以下を処理します。
- 勾配累積(
accelerator.backward()を呼び出し、累積ステップかどうかに応じて勾配同期の有無を制御) - 混合精度(
autocastコンテキストによる FP16 または BF16) - 分散勾配の reduction(ラップされたモデルが処理)
損失関数とデータコレーター
LOSS_MAPPING レジストリは、モデルクラス名のサフィックスを共有損失関数にマッピングします。
LOSS_MAPPING = {
"ForCausalLM": ForCausalLMLoss,
"ForMaskedLM": ForMaskedLMLoss,
"ForSequenceClassification": ForSequenceClassificationLoss,
"ForTokenClassification": ForTokenClassification,
"ForQuestionAnswering": ForQuestionAnsweringLoss,
# ... plus object detection variants
}
各モデルの loss_function プロパティは、このマッピングからクラス名を検索します。ForCausalLMLoss 関数はラベルのシフト処理(labels[n] が token[n+1] を予測)を担い、平均 reduction と正規化付き合計 reduction の両方をサポートする fixed_cross_entropy を使用しています(勾配累積の正確性を担保するため)。
data/data_collator.py のデータコレーターは、さまざまなデータ型のバッチ処理を担います。
DataCollatorForLanguageModeling— MLM 向けの動的マスキングDataCollatorForSeq2Seq— エンコーダー・デコーダー向けの最長系列へのパディングDataCollatorWithPadding— アテンションマスク生成を伴うシンプルなパディングDataCollatorForTokenClassification— サブワードトークンとラベルの整合処理
コレーターが指定されていない場合、Trainer は処理クラスに基づいて自動的に適切なコレーターを選択します。
オプティマイザとスケジューラ
optimization.py モジュールは、PyTorch 標準の実装を超えたオプティマイザとスケジューラを提供します。デフォルトは AdamW ですが、args.optim では AdaFactor、Lion、SGD、スケジュールフリーオプティマイザ、LOMO(メモリ効率の良いファインチューニング向け)もサポートしています。
学習率スケジューラとしては、線形ウォームアップ + 減衰、コサイン、コサインリスタート、多項式減衰、ウォームアップ付き定数などが用意されており、train() のセットアップ時に args.lr_scheduler_type から生成されます。
ヒント: 大規模な学習では、
args.optim = "adamw_torch_fused"を指定すると PyTorch の fused AdamW カーネルが有効になります。パラメータ更新ごとに複数のカーネル起動を回避できるため、デフォルトと比べて 15〜20% の高速化が期待できます。
ディレクトリマップ
| ファイル | 役割 |
|---|---|
src/transformers/trainer.py |
4,400 行の Trainer クラス |
src/transformers/training_args.py |
100 以上の TrainingArguments パラメータ |
src/transformers/trainer_callback.py |
20 以上のフックを持つコールバックシステム |
src/transformers/optimization.py |
オプティマイザと学習率スケジューラ |
src/transformers/data/data_collator.py |
バッチ処理用データコレーター |
src/transformers/loss/loss_utils.py |
LOSS_MAPPING と共有損失関数 |
src/transformers/integrations/deepspeed.py |
DeepSpeed ZeRO 統合 |
src/transformers/integrations/fsdp.py |
FSDP 統合 |
これでモデルのライフサイクル全体(import → resolve → load → generate → train)をひと通りカバーしました。最終回では、高レベルの Pipeline API、トークナイザーの階層構造、マルチモーダル処理を解説します。さらに、ライブラリへの新モデル追加に向けた拡張ポイントを取り上げます。