Read OSS

Hub から GPU へ:重みロードのパイプライン

上級

前提知識

  • 第3回:モデルの構造とクラス階層
  • PyTorch の state_dict・meta デバイス・メモリ管理
  • safetensors ファイルフォーマットの基礎
  • 量子化の概念(INT8・INT4・GPTQ・AWQ)についての基本的な理解

Hub から GPU へ:重みロードのパイプライン

前回の記事では、Transformers が "meta-llama/Llama-2-7b-hf"LlamaForCausalLM に解決する仕組みと、そのクラスの内部構造を見てきました。では、14 GB もの重みが正しいフォーマットで GPU に乗るまでに、何が起きているのでしょうか。PreTrainedModelfrom_pretrained() は 60 以上のパラメータを受け取るオーケストレーターで、safetensors のダウンロード・meta デバイスの初期化・チェックポイントフォーマット変換・量子化・マルチ GPU へのディスパッチをまとめて処理します。本記事では、このパイプライン全体を順を追って解説します。

from_pretrained() のオーケストレーション

from_pretrained() のシグネチャを見るだけで、何を担っているかが伝わります。

def from_pretrained(
    cls,
    pretrained_model_name_or_path,
    *model_args,
    config=None,
    cache_dir=None,
    ignore_mismatched_sizes=False,
    force_download=False,
    local_files_only=False,
    token=None,
    revision="main",
    use_safetensors=None,
    weights_only=True,
    **kwargs,  # ~50 more via kwargs
):

**kwargs の裏には多くのパラメータが隠れています。device_maptorch_dtypequantization_configattn_implementationlow_cpu_mem_usage など、重要なオプションのほとんどはここに収まっています。処理の大まかな流れは次のとおりです。

sequenceDiagram
    participant User
    participant FP as from_pretrained()
    participant Hub as HuggingFace Hub
    participant QA as Quantizer Auto-Select
    participant DM as Device Map
    participant Init as Meta Device Init
    participant WL as Weight Loading
    participant Post as Post-Load Hooks

    User->>FP: from_pretrained("meta-llama/Llama-2-7b-hf")
    FP->>Hub: Download config.json
    Hub-->>FP: Config
    FP->>FP: Resolve config class
    FP->>QA: Select quantizer (if any)
    QA-->>FP: HfQuantizer or None
    FP->>DM: Compute device_map
    DM-->>FP: {"model.layers.0": 0, ...}
    FP->>Init: Initialize model on meta device
    Init-->>FP: Empty model skeleton
    FP->>WL: Load weights shard by shard
    WL->>WL: Apply WeightConverter/WeightRenaming
    WL->>WL: Quantize if needed
    WL->>WL: Place on target device
    FP->>Post: Post-weight-loading hooks
    Post-->>User: Ready model

まず config の解決(必要に応じて AutoConfig を利用)が行われ、その後に重みのロード処理へと進みます。LoadStateDictConfig データクラスは、ロードに関するすべてのパラメータをひとつのイミュータブルなオブジェクトにまとめる役割を担っています。

@dataclass(frozen=True)
class LoadStateDictConfig:
    pretrained_model_name_or_path: str | None = None
    use_safetensors: bool | None = None
    ignore_mismatched_sizes: bool = False
    device_map: dict | None = None
    dtype: torch.dtype | None = None
    hf_quantizer: HfQuantizer | None = None
    weight_mapping: list[WeightConverter | WeightRenaming] | None = None
    # ...

この config を重みロードパイプライン全体に渡すことで、個々のパラメータを十数個の関数に引き回すというアンチパターンを回避しています。

core_model_loading.py:重み変換エンジン

core_model_loading.py は、チェックポイントのフォーマット差異を吸収するモジュールです。同じモデルでもバージョンによって重みのキー名やテンソルレイアウトが異なる場合があります。WeightConverterWeightRenaming データクラスは、こうした変換を宣言的に表現します。

WeightRenaming は単純なキー名の変換を担います。

@dataclass(slots=True)
class WeightRenaming(WeightTransform):
    # Maps source_patterns → target_patterns
    # e.g., "model.layers.*.attn.qkv.weight" → "model.layers.*.self_attn.q_proj.weight"

WeightConverter は、ConversionOps のパイプラインを通じて実際のテンソル変換を行います。

flowchart LR
    A["Checkpoint tensor<br/>qkv_proj.weight"] --> B["Chunk(dim=0)"]
    B --> C["q_proj.weight"]
    B --> D["k_proj.weight"]
    B --> E["v_proj.weight"]

ConversionOps の階層には、テンソルを分割する Chunk、結合する Concatenate、量子化フォーマット向けの特殊な操作などが含まれます。各操作は reverse_op プロパティを持っており、ロードとセーブの両方向での変換が可能です。

この変換システムでは、build_glob_alternation() を使ってすべての重みパターンをひとつの正規表現にコンパイルします。変換ルールが数百個あっても、パターンマッチングを高速に処理できる仕組みです。

ヒント: 重みコンバーターはモデルのモジュール(またはコンフィグ)ごとに定義されています。ネイティブフォーマット以外のチェックポイントをロードする場合は、モデルクラスの _checkpoint_conversion_mappingweight_mapping 属性を確認してみてください。

量子化器の自動選択

quantization_config を指定してロードする場合(または事前量子化済みモデルの場合)、Transformers は 20 以上のバックエンドの中から適切な量子化器を自動で選択します。quantizers/auto.pyAUTO_QUANTIZER_MAPPING が、メソッド名と量子化クラスの対応を定義しています。

AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "quanto": QuantoHfQuantizer,
    # ... 20+ more
}
flowchart TD
    A["quantization_config provided?"] -->|Yes| B["Extract quant_method"]
    A -->|No| C["Check config.json<br/>for quantization_config"]
    C -->|Found| B
    C -->|Not found| D["No quantization"]
    B --> E["AUTO_QUANTIZER_MAPPING[method]"]
    E --> F["HfQuantizer subclass"]
    F --> G["validate_environment()"]
    G --> H["preprocess_model()"]
    H --> I["Weight loading with<br/>quantization hooks"]
    I --> J["postprocess_model()"]

HfQuantizer 基底クラスは、重みロード処理を包むライフサイクルフックを定義しています。

  • validate_environment() — 必要なパッケージがインストールされているか確認する
  • update_device_map() — 一部の量子化器は device_map="auto" を強制する
  • _process_model_before_weight_loading() — meta デバイス上のモジュールを量子化済みバリアントに差し替える
  • _process_model_after_weight_loading() — 量子化後の重みをファイナライズする
  • param_needs_quantization() — パラメータごとに量子化するかどうかを判断する

get_keys_to_not_convert() 関数は、LM head と tied weights を量子化の対象から自動的に除外します。出力層をフル精度のまま保つことで、数値的な安定性を確保しています。

デバイスマップとマルチ GPU ロード

単一の GPU に収まらない大規模モデルを扱う場合、Transformers は Accelerate のデバイスマップシステムと連携します。integrations/accelerate.py モジュールがその橋渡しを担っています。

device_map="auto" を指定すると、以下の流れで処理が進みます。

  1. config からモデルサイズを推定する(重みはロードしない)
  2. 利用可能な GPU メモリを照会する
  3. infer_auto_device_map() を実行して各モジュールのデバイスを決定する
  4. _no_split_modules を尊重して、デコーダー層が複数デバイスにまたがらないようにする
flowchart TD
    A["device_map='auto'"] --> B["Estimate model size<br/>from config"]
    B --> C["Query GPU memory"]
    C --> D["infer_auto_device_map()"]
    D --> E{"Model fits on<br/>single GPU?"}
    E -->|Yes| F["All on GPU 0"]
    E -->|No| G["Split across GPUs"]
    G --> H{"Still too large?"}
    H -->|Yes| I["Offload to CPU"]
    I --> J{"CPU OOM?"}
    J -->|Yes| K["Offload to disk"]
    J -->|No| L["device_map ready"]
    H -->|No| L
    F --> L

デバイスマップはパラメータ名からデバイスへの辞書として表現されます(例:{"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.16": 1, "lm_head": "cpu"})。重みのロード時には、各テンソルが直接ターゲットデバイスに配置されるため、モデル全体を一度に CPU メモリに乗せる必要がありません。

safetensors とシャード単位のロード

最近のチェックポイントは safetensors フォーマットを採用しており、メモリマップトアクセスをサポートし、pickle ベースの攻撃も受けません。重みは通常、複数のファイルにシャーディングされています(例:model-00001-of-00003.safetensors)。

ロード処理では、まずシャードインデックス(model.safetensors.index.json)を読み込んでどのパラメータがどのシャードにあるかを把握し、その後シャードを 1 つずつ読み込みます。大規模モデルでは特に重要な設計で、14 GB すべてを CPU メモリに乗せてからディスパッチするのではなく、各シャードのテンソルが直接ターゲットデバイスに送られます。

sequenceDiagram
    participant FP as from_pretrained
    participant Idx as Shard Index
    participant S1 as Shard 1
    participant S2 as Shard 2
    participant GPU0 as GPU 0
    participant GPU1 as GPU 1

    FP->>Idx: Read index.json
    Idx-->>FP: param → shard mapping
    FP->>S1: Load shard 1 (memory-mapped)
    S1-->>FP: layers 0-15 weights
    FP->>GPU0: Place layers 0-15
    Note over FP: Shard 1 released from memory
    FP->>S2: Load shard 2 (memory-mapped)
    S2-->>FP: layers 16-31 weights
    FP->>GPU1: Place layers 16-31

use_safetensors パラメータのデフォルト値は None です。これは自動検出を意味し、safetensors が利用可能であれば優先して使用し、なければ PyTorch の .bin ファイルにフォールバックします。safetensors_conversion.py モジュールは、保存時に .bin から safetensors フォーマットへの自動変換を担っています。

ヒント: 環境変数 SAFETENSORS_FAST_GPU=1 を設定すると、safetensors ファイルを CPU を経由せず直接 GPU にロードできます。safetensors ライブラリのバージョン 0.4.0 以上が必要ですが、高速な NVMe ストレージを搭載したシステムではロード時間を大幅に短縮できます。

ロードレポート

すべての重みがロードされた後、from_pretrained() は詳細なレポートをログに出力します。ロードに成功したキー・見つからなかったキー(tied weights では想定内)・想定外のキー・サイズ不一致のキーが一覧表示されます。この情報は LoadStateDictInfo データクラスがロード処理全体を通じて収集したものです。ファインチューニング済みチェックポイントがベースモデルの構造から外れている場合に、問題の診断に役立ちます。

最後に model.eval() が呼び出されます。これは意図的なデフォルトで、推論時にドロップアウトなどの学習時の挙動が誤って有効になるのを防ぎます。ファインチューニングを行う場合は、明示的に model.train() を呼び出す必要があります。

ディレクトリマップ

ファイル 役割
src/transformers/modeling_utils.py from_pretrained()LoadStateDictConfig
src/transformers/core_model_loading.py WeightConverterWeightRenaming、変換操作
src/transformers/quantizers/auto.py AUTO_QUANTIZER_MAPPING、量子化器の選択
src/transformers/quantizers/base.py ライフサイクルフックを持つ HfQuantizer 基底クラス
src/transformers/integrations/accelerate.py デバイスマップの計算、Accelerate との連携
src/transformers/safetensors_conversion.py フォーマット変換ユーティリティ

モデルのロードが完了したら、次はテキスト生成です。次回は generate() メソッドに深く踏み込みます。デコード戦略・KV キャッシュ・logits 処理・スペキュラティブデコーディング・トークンストリーミングを統括する 1700 行のオーケストレーターを解説していきます。