从 Hub 到 GPU:权重加载流水线
前置知识
- ›第 3 篇:模型结构与类层次体系
- ›PyTorch state_dict、meta device 及内存管理
- ›safetensors 文件格式基础
- ›量化概念入门(INT8、INT4、GPTQ、AWQ)
从 Hub 到 GPU:权重加载流水线
前几篇文章介绍了 Transformers 如何将 "meta-llama/Llama-2-7b-hf" 解析为 LlamaForCausalLM,以及这个类的内部结构。但 14 GB 的权重究竟是怎样以正确的格式落到 GPU 上的?PreTrainedModel 的 from_pretrained() 方法拥有超过 60 个参数,负责统筹协调 safetensors 下载、meta device 初始化、checkpoint 格式转换、量化处理以及多 GPU 分发等全部环节。本文将完整追踪这一流水线。
from_pretrained() 的整体调度
from_pretrained() 的方法签名揭示了它所承担的职责范围:
def from_pretrained(
cls,
pretrained_model_name_or_path,
*model_args,
config=None,
cache_dir=None,
ignore_mismatched_sizes=False,
force_download=False,
local_files_only=False,
token=None,
revision="main",
use_safetensors=None,
weights_only=True,
**kwargs, # ~50 more via kwargs
):
**kwargs 中隐藏着大量参数,包括 device_map、torch_dtype、quantization_config、attn_implementation、low_cpu_mem_usage 等。整体流程如下:
sequenceDiagram
participant User
participant FP as from_pretrained()
participant Hub as HuggingFace Hub
participant QA as Quantizer Auto-Select
participant DM as Device Map
participant Init as Meta Device Init
participant WL as Weight Loading
participant Post as Post-Load Hooks
User->>FP: from_pretrained("meta-llama/Llama-2-7b-hf")
FP->>Hub: Download config.json
Hub-->>FP: Config
FP->>FP: Resolve config class
FP->>QA: Select quantizer (if any)
QA-->>FP: HfQuantizer or None
FP->>DM: Compute device_map
DM-->>FP: {"model.layers.0": 0, ...}
FP->>Init: Initialize model on meta device
Init-->>FP: Empty model skeleton
FP->>WL: Load weights shard by shard
WL->>WL: Apply WeightConverter/WeightRenaming
WL->>WL: Quantize if needed
WL->>WL: Place on target device
FP->>Post: Post-weight-loading hooks
Post-->>User: Ready model
该方法首先解析配置(必要时通过 AutoConfig),然后进入权重加载流程。LoadStateDictConfig dataclass 将所有加载参数打包成一个不可变对象统一传递:
@dataclass(frozen=True)
class LoadStateDictConfig:
pretrained_model_name_or_path: str | None = None
use_safetensors: bool | None = None
ignore_mismatched_sizes: bool = False
device_map: dict | None = None
dtype: torch.dtype | None = None
hf_quantizer: HfQuantizer | None = None
weight_mapping: list[WeightConverter | WeightRenaming] | None = None
# ...
这个配置对象贯穿整个权重加载流水线,避免了将单个参数逐层透传到十几个函数的反模式。
core_model_loading.py:权重转换引擎
core_model_loading.py 模块负责处理 checkpoint 格式差异。同一模型的不同版本可能以不同的键名或 tensor 布局存储权重。WeightConverter 和 WeightRenaming 这两个 dataclass 以声明式的方式来描述这些转换规则。
WeightRenaming 处理简单的键名重命名:
@dataclass(slots=True)
class WeightRenaming(WeightTransform):
# Maps source_patterns → target_patterns
# e.g., "model.layers.*.attn.qkv.weight" → "model.layers.*.self_attn.q_proj.weight"
WeightConverter 则通过一组 ConversionOps 流水线完成实际的 tensor 变换:
flowchart LR
A["Checkpoint tensor<br/>qkv_proj.weight"] --> B["Chunk(dim=0)"]
B --> C["q_proj.weight"]
B --> D["k_proj.weight"]
B --> E["v_proj.weight"]
ConversionOps 层次体系包含 Chunk(拆分 tensor)、Concatenate(合并 tensor)以及针对量化格式的专用操作。每个操作都带有 reverse_op 属性,支持双向转换——这对加载和保存都至关重要。
转换系统使用 build_glob_alternation() 将所有权重模式编译为单个正则表达式,即便面对数百条转换规则,模式匹配依然高效。
提示: 权重转换器在各模型的模块中(或在其配置中)单独定义。如果你加载的 checkpoint 格式与模型原生格式不一致,可以查看模型类中的
_checkpoint_conversion_mapping或weight_mapping属性。
量化器自动选择
在传入 quantization_config 或加载预量化模型时,Transformers 会从 20 余个后端中自动选择合适的量化器。quantizers/auto.py 中的 AUTO_QUANTIZER_MAPPING 负责将方法名映射到对应的量化器类:
AUTO_QUANTIZER_MAPPING = {
"awq": AwqQuantizer,
"bitsandbytes_4bit": Bnb4BitHfQuantizer,
"bitsandbytes_8bit": Bnb8BitHfQuantizer,
"gptq": GptqHfQuantizer,
"quanto": QuantoHfQuantizer,
# ... 20+ more
}
flowchart TD
A["quantization_config provided?"] -->|Yes| B["Extract quant_method"]
A -->|No| C["Check config.json<br/>for quantization_config"]
C -->|Found| B
C -->|Not found| D["No quantization"]
B --> E["AUTO_QUANTIZER_MAPPING[method]"]
E --> F["HfQuantizer subclass"]
F --> G["validate_environment()"]
G --> H["preprocess_model()"]
H --> I["Weight loading with<br/>quantization hooks"]
I --> J["postprocess_model()"]
HfQuantizer 基类定义了包裹权重加载流程的生命周期钩子:
validate_environment()— 检查所需依赖包是否已安装update_device_map()— 部分量化器会强制设置device_map="auto"_process_model_before_weight_loading()— 在 meta device 上将模块替换为量化变体_process_model_after_weight_loading()— 完成量化权重的最终处理param_needs_quantization()— 逐参数决策,判断哪些参数需要量化
get_keys_to_not_convert() 函数会自动将 LM head 和权重绑定层排除在量化范围之外,让输出层保持全精度,以确保数值稳定性。
设备映射与多 GPU 加载
对于单张 GPU 装不下的大模型,Transformers 集成了 Accelerate 的设备映射系统。integrations/accelerate.py 模块提供了这一桥接能力。
当传入 device_map="auto" 时,框架会执行以下步骤:
- 根据配置估算模型大小(无需加载权重)
- 查询各 GPU 的可用显存
- 调用
infer_auto_device_map()为每个模块分配目标设备 - 遵守
_no_split_modules约束——解码器层会完整保留在同一设备上
flowchart TD
A["device_map='auto'"] --> B["Estimate model size<br/>from config"]
B --> C["Query GPU memory"]
C --> D["infer_auto_device_map()"]
D --> E{"Model fits on<br/>single GPU?"}
E -->|Yes| F["All on GPU 0"]
E -->|No| G["Split across GPUs"]
G --> H{"Still too large?"}
H -->|Yes| I["Offload to CPU"]
I --> J{"CPU OOM?"}
J -->|Yes| K["Offload to disk"]
J -->|No| L["device_map ready"]
H -->|No| L
F --> L
设备映射是一个将参数名映射到目标设备的字典,例如:{"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.16": 1, "lm_head": "cpu"}。在权重加载过程中,每个 tensor 直接被放置到目标设备,整个过程无需将完整模型加载到 CPU 内存。
safetensors 与分片逐步加载
现代 checkpoint 采用 safetensors 格式,该格式支持内存映射访问,同时规避了基于 pickle 的安全风险。权重通常分散在多个文件中(例如 model-00001-of-00003.safetensors)。
加载过程首先读取分片索引文件(model.safetensors.index.json),确定每个参数位于哪个分片,然后逐片加载。这对大模型至关重要——不再需要先将全部 14 GB 数据加载到 CPU 内存再进行分发,而是每个分片中的 tensor 直接落到目标设备:
sequenceDiagram
participant FP as from_pretrained
participant Idx as Shard Index
participant S1 as Shard 1
participant S2 as Shard 2
participant GPU0 as GPU 0
participant GPU1 as GPU 1
FP->>Idx: Read index.json
Idx-->>FP: param → shard mapping
FP->>S1: Load shard 1 (memory-mapped)
S1-->>FP: layers 0-15 weights
FP->>GPU0: Place layers 0-15
Note over FP: Shard 1 released from memory
FP->>S2: Load shard 2 (memory-mapped)
S2-->>FP: layers 16-31 weights
FP->>GPU1: Place layers 16-31
use_safetensors 参数默认为 None,即自动检测:优先使用 safetensors,若不可用则回退到 PyTorch .bin 文件。safetensors_conversion.py 模块负责在保存时自动将 .bin 格式转换为 safetensors 格式。
提示: 将环境变量
SAFETENSORS_FAST_GPU=1设置后,可以开启 safetensors 文件的 GPU 直接加载,完全绕过 CPU。此功能需要safetensors库版本 ≥ 0.4.0,在配备高速 NVMe 存储的系统上可以显著提升加载速度。
加载报告
所有权重加载完毕后,from_pretrained() 会输出一份详细报告,列明哪些键已成功加载、哪些键缺失(权重绑定情况下属于预期行为)、哪些键是多余的,以及哪些键存在尺寸不匹配的问题。LoadStateDictInfo dataclass 在整个加载过程中持续收集这些信息,最终报告有助于在微调 checkpoint 与基础模型结构出现偏差时快速定位问题。
该方法最后会调用 model.eval()——这是一个有意为之的默认行为,防止在推理时触发训练模式下的行为(如 dropout)。如果你需要进行微调,则必须显式调用 model.train()。
文件目录
| 文件 | 用途 |
|---|---|
src/transformers/modeling_utils.py |
from_pretrained()、LoadStateDictConfig |
src/transformers/core_model_loading.py |
WeightConverter、WeightRenaming、转换操作 |
src/transformers/quantizers/auto.py |
AUTO_QUANTIZER_MAPPING、量化器选择逻辑 |
src/transformers/quantizers/base.py |
HfQuantizer 基类及生命周期钩子 |
src/transformers/integrations/accelerate.py |
设备映射计算、Accelerate 桥接层 |
src/transformers/safetensors_conversion.py |
格式转换工具 |
模型已完整加载,接下来该生成文本了。下一篇文章将深入解析 generate() 方法——一个长达 1700 行的调度器,统一管理解码策略、KV 缓存、logits 处理、投机解码以及 token 流式输出。