Read OSS

从 Hub 到 GPU:权重加载流水线

高级

前置知识

  • 第 3 篇:模型结构与类层次体系
  • PyTorch state_dict、meta device 及内存管理
  • safetensors 文件格式基础
  • 量化概念入门(INT8、INT4、GPTQ、AWQ)

从 Hub 到 GPU:权重加载流水线

前几篇文章介绍了 Transformers 如何将 "meta-llama/Llama-2-7b-hf" 解析为 LlamaForCausalLM,以及这个类的内部结构。但 14 GB 的权重究竟是怎样以正确的格式落到 GPU 上的?PreTrainedModelfrom_pretrained() 方法拥有超过 60 个参数,负责统筹协调 safetensors 下载、meta device 初始化、checkpoint 格式转换、量化处理以及多 GPU 分发等全部环节。本文将完整追踪这一流水线。

from_pretrained() 的整体调度

from_pretrained() 的方法签名揭示了它所承担的职责范围:

def from_pretrained(
    cls,
    pretrained_model_name_or_path,
    *model_args,
    config=None,
    cache_dir=None,
    ignore_mismatched_sizes=False,
    force_download=False,
    local_files_only=False,
    token=None,
    revision="main",
    use_safetensors=None,
    weights_only=True,
    **kwargs,  # ~50 more via kwargs
):

**kwargs 中隐藏着大量参数,包括 device_maptorch_dtypequantization_configattn_implementationlow_cpu_mem_usage 等。整体流程如下:

sequenceDiagram
    participant User
    participant FP as from_pretrained()
    participant Hub as HuggingFace Hub
    participant QA as Quantizer Auto-Select
    participant DM as Device Map
    participant Init as Meta Device Init
    participant WL as Weight Loading
    participant Post as Post-Load Hooks

    User->>FP: from_pretrained("meta-llama/Llama-2-7b-hf")
    FP->>Hub: Download config.json
    Hub-->>FP: Config
    FP->>FP: Resolve config class
    FP->>QA: Select quantizer (if any)
    QA-->>FP: HfQuantizer or None
    FP->>DM: Compute device_map
    DM-->>FP: {"model.layers.0": 0, ...}
    FP->>Init: Initialize model on meta device
    Init-->>FP: Empty model skeleton
    FP->>WL: Load weights shard by shard
    WL->>WL: Apply WeightConverter/WeightRenaming
    WL->>WL: Quantize if needed
    WL->>WL: Place on target device
    FP->>Post: Post-weight-loading hooks
    Post-->>User: Ready model

该方法首先解析配置(必要时通过 AutoConfig),然后进入权重加载流程。LoadStateDictConfig dataclass 将所有加载参数打包成一个不可变对象统一传递:

@dataclass(frozen=True)
class LoadStateDictConfig:
    pretrained_model_name_or_path: str | None = None
    use_safetensors: bool | None = None
    ignore_mismatched_sizes: bool = False
    device_map: dict | None = None
    dtype: torch.dtype | None = None
    hf_quantizer: HfQuantizer | None = None
    weight_mapping: list[WeightConverter | WeightRenaming] | None = None
    # ...

这个配置对象贯穿整个权重加载流水线,避免了将单个参数逐层透传到十几个函数的反模式。

core_model_loading.py:权重转换引擎

core_model_loading.py 模块负责处理 checkpoint 格式差异。同一模型的不同版本可能以不同的键名或 tensor 布局存储权重。WeightConverterWeightRenaming 这两个 dataclass 以声明式的方式来描述这些转换规则。

WeightRenaming 处理简单的键名重命名:

@dataclass(slots=True)
class WeightRenaming(WeightTransform):
    # Maps source_patterns → target_patterns
    # e.g., "model.layers.*.attn.qkv.weight" → "model.layers.*.self_attn.q_proj.weight"

WeightConverter 则通过一组 ConversionOps 流水线完成实际的 tensor 变换:

flowchart LR
    A["Checkpoint tensor<br/>qkv_proj.weight"] --> B["Chunk(dim=0)"]
    B --> C["q_proj.weight"]
    B --> D["k_proj.weight"]
    B --> E["v_proj.weight"]

ConversionOps 层次体系包含 Chunk(拆分 tensor)、Concatenate(合并 tensor)以及针对量化格式的专用操作。每个操作都带有 reverse_op 属性,支持双向转换——这对加载和保存都至关重要。

转换系统使用 build_glob_alternation() 将所有权重模式编译为单个正则表达式,即便面对数百条转换规则,模式匹配依然高效。

提示: 权重转换器在各模型的模块中(或在其配置中)单独定义。如果你加载的 checkpoint 格式与模型原生格式不一致,可以查看模型类中的 _checkpoint_conversion_mappingweight_mapping 属性。

量化器自动选择

在传入 quantization_config 或加载预量化模型时,Transformers 会从 20 余个后端中自动选择合适的量化器。quantizers/auto.py 中的 AUTO_QUANTIZER_MAPPING 负责将方法名映射到对应的量化器类:

AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "quanto": QuantoHfQuantizer,
    # ... 20+ more
}
flowchart TD
    A["quantization_config provided?"] -->|Yes| B["Extract quant_method"]
    A -->|No| C["Check config.json<br/>for quantization_config"]
    C -->|Found| B
    C -->|Not found| D["No quantization"]
    B --> E["AUTO_QUANTIZER_MAPPING[method]"]
    E --> F["HfQuantizer subclass"]
    F --> G["validate_environment()"]
    G --> H["preprocess_model()"]
    H --> I["Weight loading with<br/>quantization hooks"]
    I --> J["postprocess_model()"]

HfQuantizer 基类定义了包裹权重加载流程的生命周期钩子:

  • validate_environment() — 检查所需依赖包是否已安装
  • update_device_map() — 部分量化器会强制设置 device_map="auto"
  • _process_model_before_weight_loading() — 在 meta device 上将模块替换为量化变体
  • _process_model_after_weight_loading() — 完成量化权重的最终处理
  • param_needs_quantization() — 逐参数决策,判断哪些参数需要量化

get_keys_to_not_convert() 函数会自动将 LM head 和权重绑定层排除在量化范围之外,让输出层保持全精度,以确保数值稳定性。

设备映射与多 GPU 加载

对于单张 GPU 装不下的大模型,Transformers 集成了 Accelerate 的设备映射系统。integrations/accelerate.py 模块提供了这一桥接能力。

当传入 device_map="auto" 时,框架会执行以下步骤:

  1. 根据配置估算模型大小(无需加载权重)
  2. 查询各 GPU 的可用显存
  3. 调用 infer_auto_device_map() 为每个模块分配目标设备
  4. 遵守 _no_split_modules 约束——解码器层会完整保留在同一设备上
flowchart TD
    A["device_map='auto'"] --> B["Estimate model size<br/>from config"]
    B --> C["Query GPU memory"]
    C --> D["infer_auto_device_map()"]
    D --> E{"Model fits on<br/>single GPU?"}
    E -->|Yes| F["All on GPU 0"]
    E -->|No| G["Split across GPUs"]
    G --> H{"Still too large?"}
    H -->|Yes| I["Offload to CPU"]
    I --> J{"CPU OOM?"}
    J -->|Yes| K["Offload to disk"]
    J -->|No| L["device_map ready"]
    H -->|No| L
    F --> L

设备映射是一个将参数名映射到目标设备的字典,例如:{"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.16": 1, "lm_head": "cpu"}。在权重加载过程中,每个 tensor 直接被放置到目标设备,整个过程无需将完整模型加载到 CPU 内存。

safetensors 与分片逐步加载

现代 checkpoint 采用 safetensors 格式,该格式支持内存映射访问,同时规避了基于 pickle 的安全风险。权重通常分散在多个文件中(例如 model-00001-of-00003.safetensors)。

加载过程首先读取分片索引文件(model.safetensors.index.json),确定每个参数位于哪个分片,然后逐片加载。这对大模型至关重要——不再需要先将全部 14 GB 数据加载到 CPU 内存再进行分发,而是每个分片中的 tensor 直接落到目标设备:

sequenceDiagram
    participant FP as from_pretrained
    participant Idx as Shard Index
    participant S1 as Shard 1
    participant S2 as Shard 2
    participant GPU0 as GPU 0
    participant GPU1 as GPU 1

    FP->>Idx: Read index.json
    Idx-->>FP: param → shard mapping
    FP->>S1: Load shard 1 (memory-mapped)
    S1-->>FP: layers 0-15 weights
    FP->>GPU0: Place layers 0-15
    Note over FP: Shard 1 released from memory
    FP->>S2: Load shard 2 (memory-mapped)
    S2-->>FP: layers 16-31 weights
    FP->>GPU1: Place layers 16-31

use_safetensors 参数默认为 None,即自动检测:优先使用 safetensors,若不可用则回退到 PyTorch .bin 文件。safetensors_conversion.py 模块负责在保存时自动将 .bin 格式转换为 safetensors 格式。

提示: 将环境变量 SAFETENSORS_FAST_GPU=1 设置后,可以开启 safetensors 文件的 GPU 直接加载,完全绕过 CPU。此功能需要 safetensors 库版本 ≥ 0.4.0,在配备高速 NVMe 存储的系统上可以显著提升加载速度。

加载报告

所有权重加载完毕后,from_pretrained() 会输出一份详细报告,列明哪些键已成功加载、哪些键缺失(权重绑定情况下属于预期行为)、哪些键是多余的,以及哪些键存在尺寸不匹配的问题。LoadStateDictInfo dataclass 在整个加载过程中持续收集这些信息,最终报告有助于在微调 checkpoint 与基础模型结构出现偏差时快速定位问题。

该方法最后会调用 model.eval()——这是一个有意为之的默认行为,防止在推理时触发训练模式下的行为(如 dropout)。如果你需要进行微调,则必须显式调用 model.train()

文件目录

文件 用途
src/transformers/modeling_utils.py from_pretrained()LoadStateDictConfig
src/transformers/core_model_loading.py WeightConverterWeightRenaming、转换操作
src/transformers/quantizers/auto.py AUTO_QUANTIZER_MAPPING、量化器选择逻辑
src/transformers/quantizers/base.py HfQuantizer 基类及生命周期钩子
src/transformers/integrations/accelerate.py 设备映射计算、Accelerate 桥接层
src/transformers/safetensors_conversion.py 格式转换工具

模型已完整加载,接下来该生成文本了。下一篇文章将深入解析 generate() 方法——一个长达 1700 行的调度器,统一管理解码策略、KV 缓存、logits 处理、投机解码以及 token 流式输出。