Read OSS

Auto クラス:Transformers がモデル名をコードにマッピングする仕組み

中級

前提知識

  • 第1回:遅延ロードとインポートシステム
  • Python のデータクラスとクラス継承
  • HuggingFace Hub の基本的な知識(モデルリポジトリ、config.json)

Auto クラス:Transformers がモデル名をコードにマッピングする仕組み

AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") を呼び出すと、Transformers は 450 以上あるモデル実装の中からどれをインスタンス化すべきかを瞬時に判断しなければなりません。モデル名は Hub リポジトリを指す単なる文字列にすぎませんが、数秒のうちに config.json をダウンロードしてモデルタイプを特定します。そして LlamaForCausalLM だけをインポートし、重みを読み込みます。このディスパッチ機構が Auto クラスシステムです。第1回で見た遅延インポートの仕組みと、ライブラリ全体のモデルを駆動する設定階層の交差点に位置しています。

本記事では、手動メンテナンスされるマッピングレジストリから始めて、_LazyAutoMapping による遅延クラス解決を経て、すべてを統合する検証済みデータクラス PreTrainedConfig に至る解決チェーン全体を追っていきます。

3 つのマッピングレジストリ

Auto システムの基盤となるのは、model_type 文字列をキーとする 3 種類の OrderedDict マッピングです。これらはライブラリ全体における唯一の手動登録ポイントです。

CONFIG_MAPPING_NAMES はモデルタイプをコンフィグクラス名にマッピングします。

CONFIG_MAPPING_NAMES = OrderedDict([
    ("llama", "LlamaConfig"),
    ("bert", "BertConfig"),
    ("gpt2", "GPT2Config"),
    # ... 450+ entries
])

MODEL_MAPPING_NAMES はモデルタイプをベースモデルクラス名にマッピングします。さらに MODEL_FOR_CAUSAL_LM_MAPPING_NAMESMODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES のようなタスク別マッピングが 20 種類以上あります。

classDiagram
    class CONFIG_MAPPING_NAMES {
        "llama" → "LlamaConfig"
        "bert" → "BertConfig"
        "gpt2" → "GPT2Config"
    }
    class MODEL_MAPPING_NAMES {
        "llama" → "LlamaModel"
        "bert" → "BertModel"
        "gpt2" → "GPT2Model"
    }
    class MODEL_FOR_CAUSAL_LM_MAPPING_NAMES {
        "llama" → "LlamaForCausalLM"
        "gpt2" → "GPT2LMHeadModel"
    }
    CONFIG_MAPPING_NAMES --> MODEL_MAPPING_NAMES : model_type key
    CONFIG_MAPPING_NAMES --> MODEL_FOR_CAUSAL_LM_MAPPING_NAMES : model_type key

これらの辞書がクラスオブジェクトではなくクラス名の文字列を格納している点に注目してください。これは意図的な設計です。モデルコードをインポートせずにマッピングを定義できるため、実際のクラス解決は _LazyAutoMapping を通じて遅延して行われます。

ヒント: 新しいモデルをコントリビュートする際に手動で行う作業は、CONFIG_MAPPING_NAMES と該当する MODEL_FOR_*_MAPPING_NAMES にエントリを追加するだけです。それ以外はすべて第1回で説明した遅延インポートシステムが処理します。

_LazyAutoMapping:コンフィグからモデルクラスへの解決

_LazyAutoMapping クラスは、コンフィグクラスオブジェクトとモデルクラスオブジェクトの橋渡し役です。OrderedDict のサブクラスで、コンフィグとモデルの 2 つの名前ベースのマッピングを受け取り、遅延的に解決します。

sequenceDiagram
    participant User
    participant LAM as _LazyAutoMapping
    participant CM as CONFIG_MAPPING_NAMES
    participant MM as MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
    participant IMP as importlib

    User->>LAM: mapping[LlamaConfig]
    LAM->>LAM: _reverse_config_mapping["LlamaConfig"] → "llama"
    LAM->>MM: _model_mapping["llama"] → "LlamaForCausalLM"
    LAM->>LAM: model_type_to_module_name("llama") → "llama"
    LAM->>IMP: import_module(".llama", "transformers.models")
    IMP-->>LAM: llama module
    LAM->>LAM: getattr(module, "LlamaForCausalLM")
    LAM-->>User: LlamaForCausalLM class

__getitem__ メソッドは LlamaConfig のようなコンフィグクラスをキーとして受け取り、model_type 文字列に逆引きし、モデルクラス名を取得した上でモジュールを遅延インポートします。インポート済みのモジュールは self._modules にキャッシュされるため、2 回目以降の参照は即座に返ります。

_extra_content 辞書はランタイム登録をサポートしています。AutoModelForCausalLM.register(MyConfig, MyModel) を呼び出すと、静的な名前マッピングではなくこの辞書にマッピングが保存されます。

AutoModelForCausalLM.from_pretrained() の解決フロー

AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") を呼び出したときの処理を、最初から最後まで追ってみましょう。

sequenceDiagram
    participant User
    participant Auto as AutoModelForCausalLM
    participant Hub as HuggingFace Hub
    participant AC as AutoConfig
    participant LAM as _LazyAutoMapping
    participant Model as LlamaForCausalLM

    User->>Auto: from_pretrained("meta-llama/Llama-2-7b-hf")
    Auto->>Hub: Download config.json
    Hub-->>Auto: {"model_type": "llama", ...}
    Auto->>AC: AutoConfig.from_pretrained(...)
    AC->>AC: CONFIG_MAPPING["llama"] → LlamaConfig
    AC-->>Auto: LlamaConfig instance
    Auto->>Auto: Check trust_remote_code
    Auto->>LAM: _model_mapping[LlamaConfig]
    LAM-->>Auto: LlamaForCausalLM class
    Auto->>Model: LlamaForCausalLM.from_pretrained(...)
    Model-->>User: Loaded model

_BaseAutoModelClassfrom_pretrained メソッドはまずコンフィグを解決し(未指定の場合)、その後、解決されたモデルクラス自身の from_pretrained に処理を委譲します。肝心のクラス解決は _get_model_class で行われます。

def _get_model_class(config, model_mapping):
    supported_models = model_mapping[type(config)]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models
    # If multiple models match, use config.architectures to disambiguate
    name_to_model = {model.__name__: model for model in supported_models}
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in name_to_model:
            return name_to_model[arch]
    return supported_models[0]

config.jsonarchitectures フィールドが重要になるのは、1 つの model_type が複数のモデルクラスに対応するケースです。たとえば LLaMA のコンフィグは LlamaModel にも LlamaForCausalLM にも対応しますが、architectures リスト(["LlamaForCausalLM"])によってどちらを使うかが明確になります。

PreTrainedConfig:検証済みデータクラス

すべてのモデルコンフィグは PreTrainedConfig を継承します。このクラスには @stricthuggingface_hub 由来)と @dataclass の両方のデコレータが付いています。

@strict(accept_kwargs=True)
@dataclass(repr=False)
class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
    ...

@strict デコレータは、宣言済みのフィールド以外をセットできないよう強制します。hiden_size のようなタイポがあればコンストラクト時に即座に検出できます。accept_kwargs=True フラグは後方互換性のための逃げ道で、未知のキーワード引数はエラーにせず __post_init__ に渡され、サブクラスが独自に処理できるようになっています。

PreTrainedConfig のクラスメソッド from_pretrained は Hub から config.json をダウンロードして解析し、model_type フィールドを使って正しいサブクラスにディスパッチします。

具体的な例として LlamaConfig を見てみましょう。

@auto_docstring(checkpoint="meta-llama/Llama-2-7b-hf")
@strict
class LlamaConfig(PreTrainedConfig):
    model_type = "llama"
    
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        # ...
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    
    vocab_size: int = 32000
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    # ...
classDiagram
    class PreTrainedConfig {
        +model_type: str
        +architectures: list
        +name_or_path: str
        +from_pretrained()
        +save_pretrained()
        +to_dict()
    }
    class LlamaConfig {
        +model_type = "llama"
        +vocab_size: int = 32000
        +hidden_size: int = 4096
        +num_hidden_layers: int = 32
        +base_model_tp_plan: dict
        +base_model_pp_plan: dict
    }
    PreTrainedConfig <|-- LlamaConfig

クラスレベルの辞書が 2 つあります:base_model_tp_planbase_model_pp_plan です。これらはテンソル並列化とパイプライン並列化の戦略をコンフィグレベルで宣言するものです。モデルの並列実行計画がコンフィグだけで完全に定義され、コードを変更する必要がないという点が重要です。これらの計画が実際にどう使われるかは第3回で詳しく見ていきます。

auto_class_update とドキュメント生成

AutoModelFor* の派生クラスは CausalLM、SequenceClassification、TokenClassification、QuestionAnswering など 20 種類以上あり、そのままでは大量の定型コードが生まれかねません。auto_class_update() 関数はそれを解消します。

class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING

AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM, head_doc="causal language modeling"
)

auto_class_update_BaseAutoModelClass から from_configfrom_pretrained をコピーし、ドキュメント文字列内のプレースホルダーをクラス名やタスク説明に置き換え、サポートするモデルのリストを動的に生成します。結果として各 Auto クラスには、対応するすべてのモデルを列挙した完全なドキュメントが自動生成されます。これもすべてマッピングレジストリから導出されます。

flowchart LR
    A["_BaseAutoModelClass<br/>(from_config, from_pretrained)"] --> B["auto_class_update()"]
    B --> C["Copy methods"]
    B --> D["Replace docstring<br/>placeholders"]
    B --> E["Inject model list<br/>from mapping"]
    C --> F["AutoModelForCausalLM"]
    D --> F
    E --> F

1971 行目 のクラス定義では、from_pretrained の戻り値の型アノテーションを _BaseModelWithGenerate にオーバーライドしています。これは PreTrainedModelGenerationMixin を組み合わせた合成型で、IDE のサポートを向上させるためのものです。

登録チェーンの全体像

model_type 文字列から実際に使える Python クラスに至るまでのチェーン全体をまとめましょう。

レイヤー 格納するもの キーの型 → 値の型
CONFIG_MAPPING_NAMES クラス名の文字列 strstr
MODEL_FOR_*_MAPPING_NAMES クラス名の文字列 strstr
_LazyAutoMapping 遅延クラス解決 type[Config]type[Model]
AutoModelForCausalLM タスク別マッピング ユーザー向け API

ここでの重要な設計判断は、名前ベースの登録OrderedDict)とクラスベースの解決_LazyAutoMapping)を分離している点です。名前ベースで登録することで、モデルのコードをインポートせずに追加できます。クラスベースで解決することで、実際のインポートはそのモデルが使われる瞬間まで遅延されます。これは第1回で取り上げた遅延インポートの思想を、モデルディスパッチ層に適用したものです。

ヒント: モデル解決の問題をデバッグするときは AutoModelForCausalLM._model_mapping を確認しましょう。イテレートすることで登録済みの (config_class, model_class) ペアをすべて確認できます。マッピングのルックアップ前に print(type(config)) を挿入すると、コンフィグクラスの不一致を発見しやすくなります。

次回予告

Transformers がモデル名から正しいクラスを見つける仕組みは理解できました。では、そのクラスの内部はどうなっているのでしょうか?次回は LlamaForCausalLM の中身を解剖します。PreTrainedModel のミックスインチェーンから、eager・SDPA・FlashAttention・FlexAttention のバックエンドを切り替えるアテンションカーネルのディスパッチシステムまで追っていきます。