Read OSS

モデルの内側:LlamaConfigからLlamaForCausalLMまで

上級

前提知識

  • 第2回:Autoクラスとconfigシステム
  • PyTorchのnn.Module、forward pass、state_dict
  • Transformerアーキテクチャ:マルチヘッドアテンション、KVキャッシュ、位置埋め込み
  • 勾配チェックポインティングの基本的な理解

モデルの内側:LlamaConfigからLlamaForCausalLMまで

前回の記事では、import transformers がすべてのインポートを遅延させる仕組みと、Autoクラスがモデル名から適切なPythonクラスを引き当てる仕組みを見てきました。今回はそのクラスの中身に踏み込みます。LLaMAを具体的なケーススタディとして、PreTrainedModel から個々のアテンション層まで、モデル階層の全体像をたどります。450以上のモデル実装を一貫性を保ちながら柔軟にしている共通インフラの正体が見えてくるはずです。

現代のTransformersモデルは、単なる nn.Module の積み重ねではありません。差し替え可能なアテンションバックエンド、Hubからカーネルをホットスワップする仕組み、宣言的な並列化プラン、タスク固有モデルを1行に凝縮するジェネリックヘッドクラスを備えた、精巧に設計されたクラス階層です。

PreTrainedModel:基底クラス

Transformersのすべてのモデルは PreTrainedModel を継承しており、このクラス自体が複数のmixinから構成されています。

classDiagram
    class nn_Module {
        +forward()
        +parameters()
        +state_dict()
    }
    class EmbeddingAccessMixin {
        +get_input_embeddings()
        +set_input_embeddings()
        +get_output_embeddings()
    }
    class ModuleUtilsMixin {
        +num_parameters()
        +estimate_tokens()
        +floating_point_ops()
    }
    class PushToHubMixin {
        +push_to_hub()
    }
    class PeftAdapterMixin {
        +load_adapter()
        +set_adapter()
        +disable_adapters()
    }
    class PreTrainedModel {
        +config_class: type
        +base_model_prefix: str
        +_supports_sdpa: bool
        +_supports_flash_attn: bool
        +_supports_flex_attn: bool
        +_tp_plan: dict
        +from_pretrained()
        +save_pretrained()
        +post_init()
    }
    nn_Module <|-- PreTrainedModel
    EmbeddingAccessMixin <|-- PreTrainedModel
    ModuleUtilsMixin <|-- PreTrainedModel
    PushToHubMixin <|-- PreTrainedModel
    PeftAdapterMixin <|-- PreTrainedModel

1104〜1163行目のクラス属性は、モデルの能力を宣言するものです。特に重要なものを挙げます。

  • _supports_sdpa_supports_flash_attn_supports_flex_attn — このモデルが対応しているアテンションバックエンド
  • _no_split_modules — モデル並列時に単一デバイス上に留める必要があるモジュール
  • _tied_weights_keys — 重み共有の宣言(例:{"lm_head.weight": "model.embed_tokens.weight"}
  • _tp_plan_pp_plan — テンソル並列・パイプライン並列の戦略

これらのフラグはライブラリ全体の挙動を制御します。重みローダー、デバイスマップの計算、アテンションのディスパッチ——いずれもこれらの値を参照して動作します。

LlamaModel:forward passの全体像

LlamaModel のforward passは、decoder-only Transformerの標準的なパターンに従っています。

flowchart TD
    A["input_ids"] --> B["embed_tokens"]
    B --> C["inputs_embeds"]
    C --> D["Create causal mask"]
    C --> E["Compute RoPE embeddings"]
    D --> F["Decoder Layer 0"]
    E --> F
    F --> G["Decoder Layer 1"]
    G --> H["..."]
    H --> I["Decoder Layer N"]
    I --> J["RMSNorm"]
    J --> K["last_hidden_state"]

375行目のforwardメソッドは次の流れで処理を進めます。

  1. 埋め込みnn.Embedding でトークンをベクトルに変換
  2. キャッシュの初期化use_cache がTrueでキャッシュが未指定の場合、DynamicCache を新規作成
  3. position IDsの計算 — キャッシュ長から算出
  4. causal maskの生成masking_utils.pycreate_causal_mask() を使用
  5. RoPE位置埋め込みの計算 — 1回だけ計算し、全レイヤーに渡す
  6. デコーダー層の反復処理 — 各レイヤーがhidden states、mask、位置埋め込みを受け取る
  7. 最終ノームLlamaRMSNorm を適用

デコレーターのスタックも注目に値します。@merge_with_config_defaults はconfigからパラメーターを自動補完します。たとえば use_cache=Noneconfig.use_cache に解決されます。@capture_outputs はhidden statesやアテンションの出力記録システムを有効化します。

ヒント: 410行目self.layers[: self.config.num_hidden_layers] というスライスは冗長ではありません。config.num_hidden_layers を実際の層数より小さい値に設定することで動的な層の枝刈りが可能になり、early exitの実験などで活用できます。

AttentionInterfaceディスパッチシステム

Transformersはeager(手動のmatmul)、SDPA(torch.nn.functional.scaled_dot_product_attention)、FlashAttention、FlexAttentionという複数のアテンションバックエンドに対応しています。if/elseの分岐を重ねる代わりに、モデルは AttentionInterface ディスパッチシステムを使っています。

AttentionInterfaceGeneralInterface を継承しており、これは2段階のルックアップを持つ MutableMapping です。

flowchart TD
    A["LlamaAttention.forward()"] --> B["ALL_ATTENTION_FUNCTIONS.get_interface(<br/>config._attn_implementation,<br/>eager_attention_forward)"]
    B --> C{"Check local_mapping"}
    C -->|Found| D["Return local override"]
    C -->|Not found| E{"Check _global_mapping"}
    E -->|"'sdpa'"| F["sdpa_attention_forward"]
    E -->|"'flash_attention_2'"| G["flash_attention_forward"]
    E -->|"'flex_attention'"| H["flex_attention_forward"]
    E -->|Not found| I["Return eager default"]

4869行目 のグローバルシングルトン ALL_ATTENTION_FUNCTIONS が文字列キーをアテンション関数にマッピングしています。_global_mapping にはページドアテンションのバリアント("paged|flash_attention_2""paged|sdpa")も含まれます。

LlamaAttention.forward() では、ディスパッチはたった1行です。

attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
    self.config._attn_implementation, eager_attention_forward
)

GemmaのようにカスタムのSDPA実装が必要なモデルは、AttentionInterface のローカルインスタンスを作成して "sdpa" キーだけを上書きします。このローカルオーバーライドはそのモデルにのみ適用され、グローバルレジストリは変更されません。ローカル→グローバルの2段階ルックアップが GeneralInterface パターンの核心です。

Hubカーネルのホットスワップ

現代のTransformersが持つ最も革新的な機能の一つが、@use_kernel_forward_from_hub デコレーターによる実行時のカーネル置換です。

@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module):
    def forward(self, hidden_states):
        # Pure PyTorch fallback implementation
        ...

関数に対しても同様に使えます。

@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Pure PyTorch fallback
    ...
sequenceDiagram
    participant Code as LlamaRMSNorm
    participant Dec as @use_kernel_forward_from_hub
    participant Kernels as kernels library
    participant Hub as HuggingFace Hub

    Code->>Dec: Decorated with "RMSNorm"
    Dec->>Kernels: Check if kernels library available
    Kernels->>Hub: Fetch optimized kernel<br/>for current hardware
    Hub-->>Kernels: Triton/CUDA kernel
    Kernels-->>Dec: Replace forward() method
    Note over Code: Now uses GPU-optimized kernel<br/>Falls back to PyTorch if unavailable

52〜67行目 のデコレーターが LlamaRMSNorm に適用されており、145行目 では同様に apply_rotary_pos_emb にも適用されています。kernels ライブラリ(別パッケージ)がHubからハードウェア最適化済みの実装を取得し、forwardメソッドを透過的に置き換えます。kernels がインストールされていない場合、あるいは USE_HUB_KERNELS=NO が設定されている場合、このデコレーターは何もしません。

この設計によって、モデルのロジックとカーネルの最適化を明確に分離できます。モデルコードは読みやすい純粋なPyTorchのままで、本番環境ではTritonやCUDAカーネルが自動的に適用されます。

RoPEとマスキングユーティリティ

Rotary Position Embedding(RoPE)は、現代のほぼすべてのdecoderモデルで共有されています。LlamaRotaryEmbedding クラスがそのパターンを示しています。

self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn = self.compute_default_rope_parameters
if self.rope_type != "default":
    rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

modeling_rope_utils.pyROPE_INIT_FUNCTIONS は、standard、dynamic NTK、YaRN、LLaMA 3のθスケーリングなど、RoPEのバリアントを登録したレジストリです。rope typeは config.rope_parameters で設定するため、完全にconfigの選択として扱えます。

masking_utils.py モジュールが提供する create_causal_mask() は、現在のバックエンドに適したアテンションマスクを生成します。SDPAでは None を返すことがあります(SDPAが因果性を内部で処理するため)。FlexAttentionでは BlockMask を、eagarアテンションでは従来の加算型floatマスクを生成します。この抽象化によってモデルコードはシンプルに保たれます。モデルは create_causal_mask() を呼ぶだけで、あとは適切な処理が行われます。

LlamaDecoderLayerとGradientCheckpointingLayer

各デコーダー層は、nn.Module ではなく GradientCheckpointingLayer を継承しています。

class LlamaDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

GradientCheckpointingLayer__call__ をオーバーライドし、勾配チェックポインティングが有効なときにforward passを横取りします。再計算中はキャッシュが使えないため自動的に無効化され、処理は torch.utils.checkpoint.checkpoint() を通じてルーティングされます。重要なのは、勾配チェックポインティングはモデルレベルではなく層レベルで制御されており、forwardメソッドの実装からは完全に隠蔽されているという点です。

LlamaDecoderLayerforward() は、pre-norm残差接続の古典的なパターンです。layernorm → アテンション → 残差加算 → layernorm → MLP → 残差加算という流れで処理します。返り値は hidden_states のみで、アテンション重みはモデルレベルの @capture_outputs デコレーターを通じて流れます。

ジェネリックヘッドクラスと1行モデル

このコードベースで最も洗練された設計パターンの一つが、ジェネリックヘッドシステムです。502〜509行目 でタスク固有のLLaMAモデルがどのように定義されているか見てみましょう。

class LlamaForSequenceClassification(GenericForSequenceClassification, LlamaPreTrainedModel): ...

class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel):
    base_model_prefix = "transformer"  # 後方互換性のため、以前の `transformer` 属性名を維持

class LlamaForTokenClassification(GenericForTokenClassification, LlamaPreTrainedModel): ...

SequenceClassificationとTokenClassificationは本当に1行だけの空クラスです。LlamaForQuestionAnswering は後方互換性のために base_model_prefix"transformer" にオーバーライドしています。旧実装では self.model ではなく self.transformer が使用されていたためです。カラクリは多重継承にあります。GenericForSequenceClassification__init__(分類ヘッドの追加)、forward(ベースモデル+ヘッド+lossの実行)、およびすべての土台となる処理を提供します。LlamaPreTrainedModel はconfigクラスのバインディング、重みの初期化、そしてmixinチェーンを担います。

classDiagram
    class GenericForSequenceClassification {
        +base_model_prefix = "model"
        +__init__(config)
        +forward(input_ids, labels, ...)
    }
    class LlamaPreTrainedModel {
        +config: LlamaConfig
        +base_model_prefix = "model"
        +_supports_flash_attn = True
    }
    class LlamaForSequenceClassification {
        «empty body»
    }
    GenericForSequenceClassification <|-- LlamaForSequenceClassification
    LlamaPreTrainedModel <|-- LlamaForSequenceClassification

GenericForSequenceClassification.__init__AutoModel.from_config(config) を使ってベースモデルを動的に生成し、nn.Linear の分類ヘッドを追加します。base_model_prefix 属性(デフォルト値は "model")がベースモデルの格納先を指定します。LlamaPreTrainedModelbase_model_prefix = "model" を宣言しているため、すべてが正しく噛み合います。

このパターンは、各モデル実装に散在していた何百行もの重複コードを置き換えました。以前は分類・QA・トークン分類それぞれのforward処理がモデルごとにコピーされていましたが、今や任意のモデルに新しいタスクバリアントを追加するのは1行で済みます。

ヒント: ベースモデルの属性名が model ではなく transformer のように異なる場合は、PreTrainedModel サブクラスで base_model_prefix をオーバーライドしてください。ジェネリックヘッドは setattr(self, self.base_model_prefix, ...) を使って格納するため、正しく動作します。

まとめ

LLaMAのクラス階層を上から下まで整理すると次のようになります。

クラス 定義ファイル 役割
PreTrainedModel modeling_utils.py 基底クラス:重み、保存、mixin群
LlamaPreTrainedModel modeling_llama.py configバインディング、アテンションフラグ
LlamaModel modeling_llama.py コアのdecoderスタック
LlamaForCausalLM modeling_llama.py LMヘッド+GenerationMixin
LlamaForSequenceClassification modeling_llama.py 1行のジェネリックヘッド
LlamaDecoderLayer modeling_llama.py 単一のTransformerブロック
LlamaAttention modeling_llama.py ディスパッチ付きマルチヘッドアテンション
LlamaRMSNorm modeling_llama.py Hubカーネルホットスワップ付きノーム
LlamaRotaryEmbedding modeling_llama.py 設定可能なRoPE

次回の記事では、これらのクラスに対して from_pretrained() を呼び出したときに何が起きるかを追います。safetensorsシャードのダウンロード、metaデバイス上での初期化、量子化の適用、複数GPUへのディスパッチなど、重みローディングパイプラインの全体像を解説します。