Read OSS

Trainer:データから大規模な勾配計算まで

上級

前提知識

  • 第3回:モデルの構造
  • 第4回:重みの読み込み
  • PyTorch の学習ループの基礎(forward、backward、optimizer.step)
  • 分散学習の基本概念:DDP、勾配累積、混合精度

Trainer:データから大規模な勾配計算まで

Trainer クラスは、「1 GPU、8 GPU(DDP)、64 GPU(DeepSpeed ZeRO-3)、TPU ポッドのいずれでも、まったく同じコードで動く学習ループをどう書くか?」という問いに対する Transformers の答えです。約 4,400 行にも及ぶこのファイルはコードベース最大の単一ファイルであり、データ読み込み、混合精度、勾配累積、チェックポイント保存、評価、20 以上のコールバックフック、そして十数種類の分散バックエンドとの統合を一手に担っています。

この記事では、コアとなる学習フローを追いながら、コールバックのライフサイクルを詳しく見ていきます。また、モデルコードを一切変更することなく、Trainer が分散バックエンドとどう統合されるかも解説します。

Trainer.init と TrainingArguments

Trainer のコンストラクタが受け取るのは、モデル、100 以上のパラメータを持つ TrainingArguments データクラス、データセット、処理クラス、コールバック、そしてオプティマイザの上書き設定です。

TrainingArguments は学習プロセス全体の制御パネルとも言える存在です。主なパラメータグループは以下の通りです。

グループ パラメータ例
基本設定 output_dir, num_train_epochs, per_device_train_batch_size
最適化 learning_rate, weight_decay, adam_beta1, adam_beta2, lr_scheduler_type
混合精度 fp16, bf16, fp16_opt_level
分散学習 local_rank, deepspeed, fsdp, fsdp_config
チェックポイント save_strategy, save_steps, save_total_limit
ロギング logging_steps, report_to, run_name
勾配 gradient_accumulation_steps, gradient_checkpointing, max_grad_norm

コンストラクタでは、分散バックエンドを抽象化する Accelerate の Accelerator オブジェクトをセットアップします。1 GPU でも 64 GPU でも、Trainer の内部コードは変わりません。データ並列化、勾配の同期、デバイスへの配置はすべて Accelerate が担います。

classDiagram
    class Trainer {
        +model: PreTrainedModel
        +args: TrainingArguments
        +train_dataset: Dataset
        +eval_dataset: Dataset
        +processing_class: Any
        +optimizer: Optimizer
        +lr_scheduler: LRScheduler
        +callback_handler: CallbackHandler
        +train()
        +evaluate()
        +predict()
        +training_step()
        +compute_loss()
    }
    class TrainingArguments {
        +output_dir: str
        +num_train_epochs: float
        +learning_rate: float
        +per_device_train_batch_size: int
        +gradient_accumulation_steps: int
        +fp16: bool
        +deepspeed: str
        +100+ more parameters
    }
    Trainer --> TrainingArguments

学習ループ:train() → training_step()

train() メソッドが学習の起点となります。大まかな処理フローは以下の通りです。

flowchart TD
    A["train()"] --> B["Resume from checkpoint?"]
    B --> C["Create DataLoader"]
    C --> D["Setup optimizer + scheduler"]
    D --> E["Call on_train_begin callbacks"]
    E --> F["For each epoch"]
    F --> G["For each batch"]
    G --> H["training_step(model, inputs)"]
    H --> I{"Gradient accumulation<br/>boundary?"}
    I -->|Yes| J["Clip gradients"]
    J --> K["optimizer.step()"]
    K --> L["scheduler.step()"]
    L --> M["optimizer.zero_grad()"]
    I -->|No| G
    M --> N{"eval_steps reached?"}
    N -->|Yes| O["evaluate()"]
    N -->|No| P{"save_steps reached?"}
    P -->|Yes| Q["save_checkpoint()"]
    P -->|No| G
    O --> G
    Q --> G
    F --> R["Call on_train_end callbacks"]

training_step() メソッドは、1 バッチ分の処理を担います。

def training_step(self, model, inputs, num_items_in_batch=None):
    model.train()
    inputs = self._prepare_inputs(inputs)
    
    with self.compute_loss_context_manager():
        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
    
    # Handle multi-GPU loss averaging
    if self.args.n_gpu > 1:
        loss = loss.mean()
    
    # Backward pass (Accelerate handles mixed precision)
    self.accelerator.backward(loss, **kwargs)
    
    return loss.detach()

self.accelerator.backward(loss) の呼び出し箇所で、混合精度、勾配スケーリング、分散勾配の同期がすべて処理されます。これらは学習ステップのロジックからは完全に隠蔽されています。Accelerate ライブラリは Trainer の初期化時にモデル、オプティマイザ、データローダーをラップし、適切なフックを注入します。

ヒント: マルチタスク損失や敵対的学習など、カスタムの学習ロジックが必要な場合は training_step() をオーバーライドしましょう。損失関数だけを差し替えたいシンプルなケースでは、compute_loss() のオーバーライドで十分です。Trainer はサブクラス化を前提に設計されており、多くのメソッドに「サブクラス化してオーバーライドすることでカスタム動作を注入できる」という説明が記載されています。

コールバックシステム

Trainer は TrainerCallback を通じて、20 以上のライフサイクルフックを持つ豊富なコールバックシステムを提供しています。

sequenceDiagram
    participant T as Trainer
    participant CH as CallbackHandler
    participant CB1 as WandbCallback
    participant CB2 as TensorBoardCallback
    participant CB3 as PrinterCallback

    T->>CH: on_train_begin(args, state, control)
    CH->>CB1: on_train_begin(...)
    CH->>CB2: on_train_begin(...)
    CH->>CB3: on_train_begin(...)
    
    loop Each training step
        T->>CH: on_step_begin(...)
        T->>T: training_step()
        T->>CH: on_step_end(...)
        T->>CH: on_log(logs={"loss": 0.5, ...})
        CH->>CB1: on_log(logs) → upload to W&B
        CH->>CB2: on_log(logs) → write to TB
    end
    
    T->>CH: on_evaluate(metrics)
    T->>CH: on_train_end(...)

ライフサイクルフックは学習プロセス全体をカバーしています。

フック タイミング
on_init_end Trainer.init の完了後
on_train_begin / on_train_end 学習の開始/終了
on_epoch_begin / on_epoch_end 各エポックの開始/終了
on_step_begin / on_step_end 各学習ステップの前後
on_substep_end 各勾配累積サブステップの終了後
on_log メトリクスのログ記録時
on_evaluate 評価の完了後
on_save チェックポイントの保存時
on_prediction_step 予測実行中

各コールバックは (args, state, control, **kwargs) を受け取ります。control オブジェクトだけが変更可能な戻り値であり、コールバックから control.should_training_stop = Truecontrol.should_save = True を設定することで、学習ループの動作に影響を与えられます。

組み込みコールバックには、Weights & Biases、TensorBoard、MLflow、CometML、Neptune との統合が含まれており、args.report_to の設定から自動的に検出されます。

分散学習との統合

Trainer は主に 3 つの分散学習方式をサポートしています。

DDP(Distributed Data Parallel)torchrun で起動した際のデフォルト方式です。各 GPU がモデルの完全なコピーを保持し、勾配は all-reduce によって同期されます。

DeepSpeed ZeRO — オプティマイザの状態(Stage 1)、勾配(Stage 2)、パラメータ(Stage 3)を段階的に GPU 間でシャーディングします。args.deepspeed に JSON 設定ファイルのパスを指定することで構成できます。初期化は integrations/deepspeed.py モジュールが担当します。

FSDP(Fully Sharded Data Parallel) — PyTorch ネイティブのパラメータシャーディング方式です。args.fsdpargs.fsdp_config で設定し、integrations/fsdp.py モジュールが橋渡し役を担います。

flowchart TD
    A["TrainingArguments"] --> B{"deepspeed config?"}
    B -->|Yes| C["Initialize DeepSpeed<br/>ZeRO Stage 1/2/3"]
    B -->|No| D{"fsdp enabled?"}
    D -->|Yes| E["Initialize FSDP<br/>via Accelerate"]
    D -->|No| F{"n_gpu > 1?"}
    F -->|Yes| G["Distributed Data<br/>Parallel (DDP)"]
    F -->|No| H["Single GPU<br/>training"]
    C --> I["Accelerate wraps<br/>model + optimizer"]
    E --> I
    G --> I
    H --> I

このアーキテクチャの優れた点は、バックエンドに関わらず training_step() の見た目がまったく同じであることです。Accelerate ライブラリのラッピングが以下を処理します。

  • 勾配累積(accelerator.backward() を呼び出し、累積ステップかどうかに応じて勾配同期の有無を制御)
  • 混合精度(autocast コンテキストによる FP16 または BF16)
  • 分散勾配の reduction(ラップされたモデルが処理)

損失関数とデータコレーター

LOSS_MAPPING レジストリは、モデルクラス名のサフィックスを共有損失関数にマッピングします。

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    # ... plus object detection variants
}

各モデルの loss_function プロパティは、このマッピングからクラス名を検索します。ForCausalLMLoss 関数はラベルのシフト処理(labels[n] が token[n+1] を予測)を担い、平均 reduction と正規化付き合計 reduction の両方をサポートする fixed_cross_entropy を使用しています(勾配累積の正確性を担保するため)。

data/data_collator.py のデータコレーターは、さまざまなデータ型のバッチ処理を担います。

  • DataCollatorForLanguageModeling — MLM 向けの動的マスキング
  • DataCollatorForSeq2Seq — エンコーダー・デコーダー向けの最長系列へのパディング
  • DataCollatorWithPadding — アテンションマスク生成を伴うシンプルなパディング
  • DataCollatorForTokenClassification — サブワードトークンとラベルの整合処理

コレーターが指定されていない場合、Trainer は処理クラスに基づいて自動的に適切なコレーターを選択します。

オプティマイザとスケジューラ

optimization.py モジュールは、PyTorch 標準の実装を超えたオプティマイザとスケジューラを提供します。デフォルトは AdamW ですが、args.optim では AdaFactor、Lion、SGD、スケジュールフリーオプティマイザ、LOMO(メモリ効率の良いファインチューニング向け)もサポートしています。

学習率スケジューラとしては、線形ウォームアップ + 減衰、コサイン、コサインリスタート、多項式減衰、ウォームアップ付き定数などが用意されており、train() のセットアップ時に args.lr_scheduler_type から生成されます。

ヒント: 大規模な学習では、args.optim = "adamw_torch_fused" を指定すると PyTorch の fused AdamW カーネルが有効になります。パラメータ更新ごとに複数のカーネル起動を回避できるため、デフォルトと比べて 15〜20% の高速化が期待できます。

ディレクトリマップ

ファイル 役割
src/transformers/trainer.py 4,400 行の Trainer クラス
src/transformers/training_args.py 100 以上の TrainingArguments パラメータ
src/transformers/trainer_callback.py 20 以上のフックを持つコールバックシステム
src/transformers/optimization.py オプティマイザと学習率スケジューラ
src/transformers/data/data_collator.py バッチ処理用データコレーター
src/transformers/loss/loss_utils.py LOSS_MAPPING と共有損失関数
src/transformers/integrations/deepspeed.py DeepSpeed ZeRO 統合
src/transformers/integrations/fsdp.py FSDP 統合

これでモデルのライフサイクル全体(import → resolve → load → generate → train)をひと通りカバーしました。最終回では、高レベルの Pipeline API、トークナイザーの階層構造、マルチモーダル処理を解説します。さらに、ライブラリへの新モデル追加に向けた拡張ポイントを取り上げます。