Inside a Model: From LlamaConfig to LlamaForCausalLM
Prerequisites
- Article 2: Auto classes and config system
- PyTorch nn.Module, forward pass, state_dict
- Transformer architecture: multi-head attention, KV-cache, positional embeddings
- Basic understanding of gradient checkpointing
In the previous articles, we saw how import transformers defers all imports and how the Auto class system maps a model name to the right Python class. Now it's time to look inside that class. Using LLaMA as our concrete case study, this article traces the full model hierarchy from PreTrainedModel down to individual attention layers, revealing the shared infrastructure that makes 450+ model implementations consistent yet flexible.
The modern Transformers model is more than a stack of nn.Modules. It's a carefully designed class hierarchy with pluggable attention backends, Hub-sourced kernel hot-swapping, declarative parallelism plans, and generic head classes that reduce task-specific models to one line of code.
PreTrainedModel: The Base Class
Every model in Transformers inherits from PreTrainedModel, which itself is a richly composed class:
```mermaid
classDiagram
    class nn_Module {
        +forward()
        +parameters()
        +state_dict()
    }
    class EmbeddingAccessMixin {
        +get_input_embeddings()
        +set_input_embeddings()
        +get_output_embeddings()
    }
    class ModuleUtilsMixin {
        +num_parameters()
        +estimate_tokens()
        +floating_point_ops()
    }
    class PushToHubMixin {
        +push_to_hub()
    }
    class PeftAdapterMixin {
        +load_adapter()
        +set_adapter()
        +disable_adapters()
    }
    class PreTrainedModel {
        +config_class: type
        +base_model_prefix: str
        +_supports_sdpa: bool
        +_supports_flash_attn: bool
        +_supports_flex_attn: bool
        +_tp_plan: dict
        +from_pretrained()
        +save_pretrained()
        +post_init()
    }
    nn_Module <|-- PreTrainedModel
    EmbeddingAccessMixin <|-- PreTrainedModel
    ModuleUtilsMixin <|-- PreTrainedModel
    PushToHubMixin <|-- PreTrainedModel
    PeftAdapterMixin <|-- PreTrainedModel
```
The class attributes at lines 1104–1163 are a declaration of the model's capabilities. Key ones include:
- `_supports_sdpa`, `_supports_flash_attn`, `_supports_flex_attn` — which attention backends this model can use
- `_no_split_modules` — modules that must stay on a single device during model parallelism
- `_tied_weights_keys` — weight tying declarations (e.g., `{"lm_head.weight": "model.embed_tokens.weight"}`)
- `_tp_plan` and `_pp_plan` — tensor and pipeline parallelism strategies
These flags drive behavior throughout the library — the weight loader, device map computation, and attention dispatch all read them.
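To make the idea concrete, here is a minimal, hypothetical sketch of how a dispatcher can consult these class-level flags — the class names and loader logic below are illustrative, not the real `modeling_utils.py` code:

```python
# Illustrative sketch: capability flags declared on the class drive dispatch.
# These classes and this helper are hypothetical stand-ins for the article's
# PreTrainedModel/LlamaPreTrainedModel and the library's internal checks.
class PreTrainedModelSketch:
    _supports_sdpa = False
    _supports_flash_attn = False
    _no_split_modules = []

class LlamaSketch(PreTrainedModelSketch):
    _supports_sdpa = True
    _supports_flash_attn = True
    _no_split_modules = ["LlamaDecoderLayer"]

def pick_attn_implementation(model_cls, requested):
    # Fall back to eager when the class does not declare support.
    supported = {
        "sdpa": model_cls._supports_sdpa,
        "flash_attention_2": model_cls._supports_flash_attn,
    }
    return requested if supported.get(requested, False) else "eager"

print(pick_attn_implementation(LlamaSketch, "sdpa"))            # sdpa
print(pick_attn_implementation(PreTrainedModelSketch, "sdpa"))  # eager
```

The point is that the subclass only *declares* what it supports; the machinery that reads the declaration lives once, in shared code.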
LlamaModel: The Forward Pass
The LlamaModel forward pass is the canonical decoder-only transformer pattern:
```mermaid
flowchart TD
    A["input_ids"] --> B["embed_tokens"]
    B --> C["inputs_embeds"]
    C --> D["Create causal mask"]
    C --> E["Compute RoPE embeddings"]
    D --> F["Decoder Layer 0"]
    E --> F
    F --> G["Decoder Layer 1"]
    G --> H["..."]
    H --> I["Decoder Layer N"]
    I --> J["RMSNorm"]
    J --> K["last_hidden_state"]
```
The forward method at line 375 follows this flow:
- Embed tokens via `nn.Embedding`
- Initialize cache — if `use_cache` is True and no cache is provided, create a `DynamicCache`
- Compute position IDs from the cache length
- Create causal mask via `create_causal_mask()` from `masking_utils.py`
- Compute RoPE position embeddings once, passed to all layers
- Iterate decoder layers — each layer receives hidden states, mask, and position embeddings
- Final norm via `LlamaRMSNorm`
The decorator stack is worth noting: @merge_with_config_defaults auto-fills parameters from config (so use_cache=None resolves to config.use_cache), and @capture_outputs enables the output recording system for hidden states and attentions.
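The config-default merging can be re-created in a few lines. This is a hypothetical sketch of the idea only — the real transformers decorator is more involved:

```python
import functools
import inspect

# Hypothetical re-creation of a "merge with config defaults" decorator:
# any argument left as None is filled from self.config if the config has
# an attribute of the same name.
def merge_with_config_defaults(fn):
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        bound = sig.bind_partial(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            if value is None and hasattr(self.config, name):
                bound.arguments[name] = getattr(self.config, name)
        return fn(*bound.args, **bound.kwargs)

    return wrapper

class Config:
    use_cache = True

class Model:
    config = Config()

    @merge_with_config_defaults
    def forward(self, use_cache=None):
        return use_cache

print(Model().forward())  # True — use_cache=None resolved from config
```

An explicitly passed value (e.g. `forward(use_cache=False)`) wins over the config default, which is exactly the `None`-sentinel convention the real decorator relies on.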
Tip: The slice `self.layers[: self.config.num_hidden_layers]` at line 410 is not redundant — it enables dynamic layer pruning by setting `config.num_hidden_layers` to less than the total, useful for early-exit experiments.
The AttentionInterface Dispatch System
Transformers supports multiple attention backends: eager (manual matmul), SDPA (torch.nn.functional.scaled_dot_product_attention), FlashAttention, and FlexAttention. Rather than if/else chains, models use the AttentionInterface dispatch system.
AttentionInterface extends GeneralInterface, a MutableMapping with two-level lookup:
```mermaid
flowchart TD
    A["LlamaAttention.forward()"] --> B["ALL_ATTENTION_FUNCTIONS.get_interface(<br/>config._attn_implementation,<br/>eager_attention_forward)"]
    B --> C{"Check local_mapping"}
    C -->|Found| D["Return local override"]
    C -->|Not found| E{"Check _global_mapping"}
    E -->|"'sdpa'"| F["sdpa_attention_forward"]
    E -->|"'flash_attention_2'"| G["flash_attention_forward"]
    E -->|"'flex_attention'"| H["flex_attention_forward"]
    E -->|Not found| I["Return eager default"]
```
The global singleton ALL_ATTENTION_FUNCTIONS at line 4869 maps string keys to attention functions. The _global_mapping includes paged attention variants too ("paged|flash_attention_2", "paged|sdpa").
In LlamaAttention.forward(), the dispatch is a single line:
```python
attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
    self.config._attn_implementation, eager_attention_forward
)
```
If a model like Gemma needs a custom SDPA implementation, it creates a local AttentionInterface instance and overrides just the "sdpa" key. This local override only affects that model — the global registry is unchanged. The two-level lookup (local → global) is the GeneralInterface pattern.
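The two-level lookup is simple enough to sketch in pure Python. The class below is an illustrative stand-in for `GeneralInterface`, not the library code — the real class is a full `MutableMapping` with more bookkeeping:

```python
# Minimal sketch of the local -> global lookup behind
# GeneralInterface / AttentionInterface.
def eager_attention_forward(*args, **kwargs):
    return "eager"

def sdpa_attention_forward(*args, **kwargs):
    return "sdpa"

class GeneralInterfaceSketch:
    # Shared across all instances, like a global registry.
    _global_mapping = {"sdpa": sdpa_attention_forward}

    def __init__(self):
        self._local_mapping = {}  # per-instance overrides

    def __setitem__(self, key, value):
        self._local_mapping[key] = value

    def get_interface(self, key, default):
        if key in self._local_mapping:  # model-local override wins
            return self._local_mapping[key]
        return self._global_mapping.get(key, default)

ALL_ATTENTION_FUNCTIONS = GeneralInterfaceSketch()
fn = ALL_ATTENTION_FUNCTIONS.get_interface("sdpa", eager_attention_forward)
print(fn())  # sdpa

# A Gemma-style local override shadows "sdpa" for this instance only;
# the shared _global_mapping is untouched.
gemma_interface = GeneralInterfaceSketch()
gemma_interface["sdpa"] = lambda *a, **k: "custom-sdpa"
print(gemma_interface.get_interface("sdpa", eager_attention_forward)())  # custom-sdpa
```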
Hub Kernel Hot-Swapping
One of the most innovative features in modern Transformers is runtime kernel replacement via the @use_kernel_forward_from_hub decorator:
```python
@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module):
    def forward(self, hidden_states):
        # Pure PyTorch fallback implementation
        ...
```
And for functions:
```python
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Pure PyTorch fallback
    ...
```
```mermaid
sequenceDiagram
    participant Code as LlamaRMSNorm
    participant Dec as @use_kernel_forward_from_hub
    participant Kernels as kernels library
    participant Hub as HuggingFace Hub
    Code->>Dec: Decorated with "RMSNorm"
    Dec->>Kernels: Check if kernels library available
    Kernels->>Hub: Fetch optimized kernel<br/>for current hardware
    Hub-->>Kernels: Triton/CUDA kernel
    Kernels-->>Dec: Replace forward() method
    Note over Code: Now uses GPU-optimized kernel<br/>Falls back to PyTorch if unavailable
```
The decorator at lines 52–67 is applied to LlamaRMSNorm, and line 145 shows apply_rotary_pos_emb decorated similarly. The kernels library (a separate package) fetches hardware-optimized implementations from the Hub and transparently replaces the forward method. If kernels isn't installed or USE_HUB_KERNELS=NO is set, the decorator is a no-op.
This design separates model logic from kernel optimization — model code stays readable pure PyTorch, while production deployments get Triton or CUDA kernels automatically.
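The "no-op when unavailable" half of that contract is easy to sketch. The decorator below is a hypothetical illustration of the fallback pattern only — the real swap logic lives in the separate `kernels` package and actually fetches code from the Hub:

```python
import importlib.util
import os

# Hypothetical sketch of the fallback pattern behind
# @use_kernel_forward_from_hub: if the `kernels` package is missing or
# kernel use is disabled, the decorator leaves the class untouched.
def use_kernel_forward_from_hub(kernel_name):
    def decorator(cls):
        kernels_available = importlib.util.find_spec("kernels") is not None
        if not kernels_available or os.environ.get("USE_HUB_KERNELS") == "NO":
            return cls  # no-op: keep the pure-PyTorch forward
        # In the real library, a hardware-specific implementation for
        # `kernel_name` would be fetched here and cls.forward swapped out.
        return cls  # (this sketch never swaps, so behavior is unchanged)
    return decorator

@use_kernel_forward_from_hub("RMSNorm")
class RMSNormSketch:
    def forward(self, hidden_states):
        return "pure-pytorch"

print(RMSNormSketch().forward(None))  # pure-pytorch
```

Because the decoration happens at class-definition time, model authors never write backend-conditional code — the same class definition works with or without the kernels package installed.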
RoPE and Masking Utilities
Rotary Position Embeddings are shared across virtually all modern decoder models. The LlamaRotaryEmbedding class demonstrates the pattern:
```python
self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn = self.compute_default_rope_parameters
if self.rope_type != "default":
    rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
```
ROPE_INIT_FUNCTIONS from modeling_rope_utils.py is a registry of RoPE variants — standard, dynamic NTK, YaRN, LLaMA 3's theta scaling, and more. The rope type is configured in config.rope_parameters, making it a pure configuration choice.
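The registry pattern itself fits in a few lines. The sketch below uses the standard RoPE inverse-frequency formula for the default entry; the "scaled" variant and the function signatures are hypothetical, chosen only to show the lookup:

```python
# Sketch of a rope_type -> init-function registry in the spirit of
# ROPE_INIT_FUNCTIONS. Signatures here are simplified stand-ins.
def default_rope(base, dim):
    # Standard RoPE: inv_freq[i] = base^(-2i/dim) for i = 0 .. dim/2 - 1
    return [base ** (-2 * i / dim) for i in range(dim // 2)]

def scaled_rope(base, dim, factor=2.0):
    # Hypothetical variant: stretch all frequencies by `factor`.
    return [f / factor for f in default_rope(base, dim)]

ROPE_INIT_FUNCTIONS = {"default": default_rope, "scaled": scaled_rope}

rope_type = "scaled"  # in the library this comes from config.rope_parameters
rope_init_fn = ROPE_INIT_FUNCTIONS.get(rope_type, default_rope)
inv_freq = rope_init_fn(base=10000.0, dim=8)
print(len(inv_freq), inv_freq[0])  # 4 frequencies; first is 1.0 / factor = 0.5
```

Switching RoPE variants is then just a dictionary lookup keyed by a config string, which is why the article calls it a pure configuration choice.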
The masking_utils.py module provides create_causal_mask(), which creates the attention mask appropriate for the current backend. For SDPA, this might be None (SDPA handles causality internally). For FlexAttention, it creates a BlockMask. For eager attention, it creates the traditional additive float mask. This abstraction keeps model code clean — models just call create_causal_mask() and the right thing happens.
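The backend-dependent return type can be illustrated with a tensor-free sketch. This is not the real `create_causal_mask()` — it uses nested lists and covers only two backends — but it shows the contract:

```python
# Simplified, hypothetical sketch of backend-aware causal mask creation.
NEG_INF = float("-inf")

def create_causal_mask(attn_implementation, seq_len):
    if attn_implementation == "sdpa":
        # SDPA can apply causality internally (is_causal=True), so no
        # explicit mask is materialized.
        return None
    if attn_implementation == "eager":
        # Additive float mask: 0 on/below the diagonal, -inf above it,
        # added to the attention scores before softmax.
        return [
            [0.0 if k <= q else NEG_INF for k in range(seq_len)]
            for q in range(seq_len)
        ]
    raise ValueError(f"unhandled backend: {attn_implementation}")

print(create_causal_mask("sdpa", 3))   # None
print(create_causal_mask("eager", 3))
# [[0.0, -inf, -inf], [0.0, 0.0, -inf], [0.0, 0.0, 0.0]]
```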
LlamaDecoderLayer and GradientCheckpointingLayer
Each decoder layer inherits from GradientCheckpointingLayer rather than plain nn.Module:
```python
class LlamaDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
```
The GradientCheckpointingLayer overrides __call__ to intercept the forward pass when gradient checkpointing is enabled. It automatically disables caching (since you can't cache during recomputation) and routes through torch.utils.checkpoint.checkpoint(). The key insight is that gradient checkpointing is controlled at the layer level, not the model level — and it's transparent to the forward method implementation.
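A torch-free sketch of that interception — with a stand-in `checkpoint` function in place of `torch.utils.checkpoint.checkpoint`, and class names that are illustrative rather than the library's — looks like this:

```python
# Stand-in for torch.utils.checkpoint.checkpoint: real checkpointing
# discards activations and recomputes them in backward; here we only
# record that the checkpointed path was taken.
def checkpoint(fn, *args, **kwargs):
    checkpoint.calls += 1
    return fn(*args, **kwargs)

checkpoint.calls = 0

class GradientCheckpointingLayerSketch:
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing:
            kwargs["use_cache"] = False  # cannot cache during recomputation
            return checkpoint(self.forward, *args, **kwargs)
        return self.forward(*args, **kwargs)

class DecoderLayerSketch(GradientCheckpointingLayerSketch):
    # forward() is written with no knowledge of checkpointing at all.
    def forward(self, hidden_states, use_cache=True):
        return hidden_states, use_cache

layer = DecoderLayerSketch()
print(layer("h"))                    # ('h', True)
layer.gradient_checkpointing = True
print(layer("h"), checkpoint.calls)  # ('h', False) 1
```

The subclass's `forward` never mentions checkpointing; flipping one flag on the layer reroutes the call and disables caching, which is the transparency the article describes.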
The forward() of LlamaDecoderLayer is the classic pre-norm residual pattern: layernorm → attention → residual add → layernorm → MLP → residual add. It returns only hidden_states — attention weights flow through the @capture_outputs decorator at the model level.
Generic Head Classes and One-Liner Models
Perhaps the most elegant design pattern in the codebase is the generic head system. Look at how task-specific LLaMA models are defined at lines 502–509:
```python
class LlamaForSequenceClassification(GenericForSequenceClassification, LlamaPreTrainedModel): ...

class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel):
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`

class LlamaForTokenClassification(GenericForTokenClassification, LlamaPreTrainedModel): ...
```
The SequenceClassification and TokenClassification variants are truly single-line empty bodies. LlamaForQuestionAnswering overrides base_model_prefix to "transformer" for backward compatibility — the old implementation used self.transformer rather than self.model. The magic is in multiple inheritance: GenericForSequenceClassification provides the __init__ (adds a classification head), forward (runs base model + head + loss), and all the scaffolding. LlamaPreTrainedModel provides config class binding, weight initialization, and the mixin chain.
```mermaid
classDiagram
    class GenericForSequenceClassification {
        +base_model_prefix = "model"
        +__init__(config)
        +forward(input_ids, labels, ...)
    }
    class LlamaPreTrainedModel {
        +config: LlamaConfig
        +base_model_prefix = "model"
        +_supports_flash_attn = True
    }
    class LlamaForSequenceClassification {
        «empty body»
    }
    GenericForSequenceClassification <|-- LlamaForSequenceClassification
    LlamaPreTrainedModel <|-- LlamaForSequenceClassification
```
The GenericForSequenceClassification.__init__ creates the base model dynamically using AutoModel.from_config(config) and adds a nn.Linear classification head. The base_model_prefix attribute (defaulting to "model") tells it where to store the base model. Since LlamaPreTrainedModel also declares base_model_prefix = "model", everything aligns.
This pattern has replaced hundreds of lines of duplicated code across model implementations. Before it existed, every model had its own copy of the classification/QA/token-classification forward logic. Now, adding a new task variant for any model is a one-liner.
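The whole mechanism can be demonstrated with plain classes. In this sketch, a trivial factory stands in for `AutoModel.from_config`, and the string-returning `forward` methods stand in for real tensor computation — everything here is illustrative:

```python
# Pure-Python sketch of the generic-head pattern: the Generic* base
# supplies __init__ and forward; the model-specific subclass is an
# empty body.
class BaseModelSketch:
    def forward(self, x):
        return f"hidden({x})"

def auto_model_from_config(config):
    # Stand-in for AutoModel.from_config(config).
    return BaseModelSketch()

class GenericForSequenceClassification:
    base_model_prefix = "model"

    def __init__(self, config):
        # Store the base model under whatever name the subclass declares.
        setattr(self, self.base_model_prefix, auto_model_from_config(config))

    def forward(self, x):
        base = getattr(self, self.base_model_prefix)
        return f"logits({base.forward(x)})"

class LlamaPreTrainedModelSketch:
    base_model_prefix = "model"

class LlamaForSequenceClassificationSketch(
    GenericForSequenceClassification, LlamaPreTrainedModelSketch
): ...

clf = LlamaForSequenceClassificationSketch(config=None)
print(clf.forward("ids"))      # logits(hidden(ids))
print(hasattr(clf, "model"))   # True — stored under base_model_prefix
```

Python's MRO resolves `__init__` and `forward` to the generic base, so the subclass body really can be empty; only a model that needs a different attribute name (like the QA head's `"transformer"`) has to override anything.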
Tip: If your model uses a different attribute name for the base model (e.g., `transformer` instead of `model`), override `base_model_prefix` in your `PreTrainedModel` subclass. The generic heads use `setattr(self, self.base_model_prefix, ...)`, so the base model will be stored under the correct name.
Putting It All Together
Here's the complete LLaMA class hierarchy from top to bottom:
| Class | Defined in | Role |
|---|---|---|
| `PreTrainedModel` | `modeling_utils.py` | Base: weights, saving, mixins |
| `LlamaPreTrainedModel` | `modeling_llama.py` | Config binding, attention flags |
| `LlamaModel` | `modeling_llama.py` | Core decoder stack |
| `LlamaForCausalLM` | `modeling_llama.py` | LM head + GenerationMixin |
| `LlamaForSequenceClassification` | `modeling_llama.py` | One-liner generic head |
| `LlamaDecoderLayer` | `modeling_llama.py` | Single transformer block |
| `LlamaAttention` | `modeling_llama.py` | Multi-head attention with dispatch |
| `LlamaRMSNorm` | `modeling_llama.py` | Norm with Hub kernel hot-swap |
| `LlamaRotaryEmbedding` | `modeling_llama.py` | Configurable RoPE |
In the next article, we'll trace what happens when you call from_pretrained() on one of these classes — the weight loading pipeline that downloads safetensors shards, initializes on meta device, applies quantization, and dispatches across multiple GPUs.