[from_pretrained] Small refactor from_pretrained: move around unrelated stuff (#41445)

* drafts

* up

* simplify modeling utils

* more simplifications

* type kwargs

* up

* move more accelerate related stuff

* safeguarding?

* nits

* remove func when func is NOPE

* more

* nits

* styling

* yups

* up

* ups

* revert

* protect trainer utils import

* fix doc

* Update src/transformers/integrations/peft.py

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* review

* update

* ?

* fixx

* update

* super small update

* ups

* style

* this is stupid

* 🤦 well this was the issue

* small nit

* fix

* nit

* damn the missing return

* one last stupid fix

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Author: Arthur
Date: 2025-10-13 16:33:32 +02:00
Committed by: GitHub
Parent: cad74496ca
Commit: 1ee3b288a6
21 changed files with 619 additions and 540 deletions

View File

@ -42,7 +42,3 @@ set this to `False`.
## Pushing to the Hub
[[autodoc]] utils.PushToHubMixin
## Sharded checkpoints
[[autodoc]] modeling_utils.load_sharded_checkpoint

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
import os
import warnings
@ -365,6 +366,58 @@ class GenerationMixin(ContinuousMixin):
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
def adjust_generation_fn(
self,
generation_config,
from_auto_class,
from_pipeline,
pretrained_model_name_or_path,
cache_dir,
force_download,
proxies,
local_files_only,
token,
revision,
subfolder,
trust_remote_code,
**kwargs,
):
if self.can_generate() and generation_config is not None:
self.generation_config = self.generation_config.from_dict(generation_config.to_dict())
elif self.can_generate() and pretrained_model_name_or_path is not None:
repo_loading_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"subfolder": subfolder,
**kwargs,
}
# Load generation config
try:
self.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**repo_loading_kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(self, "load_custom_generate"):
try:
custom_generate = self.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
)
self.generate = functools.partial(custom_generate, model=self)
except OSError: # there is no custom generate function
pass
def load_custom_generate(
self,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,

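The generation-config and custom-`generate` loading that used to be inlined in `from_pretrained` now lives in this `adjust_generation_fn` hook. A minimal usage sketch, assuming a causal LM that mixes in `GenerationMixin` ("gpt2" is only a convenient example checkpoint; argument names mirror the signature above):

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Passing an explicit generation_config takes the first branch: the user config
# overrides the default one attached to the model.
gen_config = GenerationConfig(max_new_tokens=32, do_sample=False)
model.adjust_generation_fn(
    generation_config=gen_config,
    from_auto_class=True,
    from_pipeline=None,
    pretrained_model_name_or_path="gpt2",
    cache_dir=None,
    force_download=False,
    proxies=None,
    local_files_only=False,
    token=None,
    revision="main",
    subfolder="",
    trust_remote_code=False,
)
assert model.generation_config.max_new_tokens == 32
```

In `from_pretrained` itself the hook is called with the shared download kwargs, as shown at the call site further down in this diff.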
View File

@ -20,15 +20,24 @@ The `init_empty_weights` and `init_on_device` functions were copied from `accele
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`
"""
import collections
import inspect
import os
from contextlib import contextmanager
from ..utils import is_torch_available, logging
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationMethod
from .deepspeed import is_deepspeed_zero3_enabled
from .fsdp import is_fsdp_enabled
if is_torch_available():
import torch
import torch.nn as nn
if is_accelerate_available():
from accelerate import dispatch_model
logger = logging.get_logger(__name__)
@ -194,3 +203,150 @@ def find_tied_parameters(model: "nn.Module", **kwargs):
tied_param_groups[param_name].append(tied_param_name)
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
def check_and_set_device_map(device_map):
from ..modeling_utils import get_torch_context_manager_or_global_device
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
if device_map is None and not is_deepspeed_zero3_enabled():
device_in_context = get_torch_context_manager_or_global_device()
if device_in_context == torch.device("meta"):
raise RuntimeError(
"You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
"This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
"empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
)
device_map = device_in_context
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
try:
device_map = {"": torch.device(device_map)}
except RuntimeError:
raise ValueError(
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
)
elif isinstance(device_map, int):
if device_map < 0:
raise ValueError(
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
)
else:
device_map = {"": device_map}
if device_map is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
if not is_accelerate_available():
raise ValueError(
"Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
"requires `accelerate`. You can install it with `pip install accelerate`"
)
return device_map
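A quick sketch of the normalization `check_and_set_device_map` performs, assuming `accelerate` is installed and DeepSpeed ZeRO-3 is not enabled (the import path is inferred from this file):

```python
import torch

from transformers.integrations.accelerate import check_and_set_device_map  # path assumed

# ints, device strings and torch.device objects are folded into a {"" : device} map.
assert check_and_set_device_map(0) == {"": 0}
assert check_and_set_device_map("cpu") == {"": torch.device("cpu")}
assert check_and_set_device_map(torch.device("cuda:0")) == {"": torch.device("cuda:0")}

# accelerate-style strategies pass through untouched.
assert check_and_set_device_map("auto") == "auto"

# Invalid values fail fast.
try:
    check_and_set_device_map(-1)
except ValueError as err:
    print(err)
```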
def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
device_map_kwargs = {
"device_map": device_map,
"offload_dir": offload_folder,
"offload_index": offload_index,
"offload_buffers": offload_buffers,
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
# For HQQ method we force-set the hooks for single GPU envs
if (
"force_hooks" in inspect.signature(dispatch_model).parameters
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if (
hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
device_map_kwargs["offload_buffers"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
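A hedged sketch of the dispatch helper on a toy module; with a single-device map, no quantizer and no offload it reduces to a plain `accelerate.dispatch_model` call (the `_skip_keys_device_placement` attribute is normally set by `PreTrainedModel`, so it is faked here):

```python
import torch.nn as nn

from transformers.integrations.accelerate import accelerate_dispatch  # path assumed

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model._skip_keys_device_placement = None  # stand-in for the PreTrainedModel attribute

accelerate_dispatch(
    model,
    hf_quantizer=None,          # no HQQ/FBGEMM special-casing
    device_map={"": "cpu"},     # everything on one device
    offload_folder=None,
    offload_index=None,
    offload_buffers=False,
)
print(model.hf_device_map)      # set by accelerate's dispatch_model
```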
def get_disk_only_shard_files(device_map, weight_map):
"""
Returns the list of shard files containing only weights offloaded to disk.
"""
files_content = collections.defaultdict(list)
for weight_name, filename in weight_map.items():
while len(weight_name) > 0 and weight_name not in device_map:
weight_name = ".".join(weight_name.split(".")[:-1])
files_content[filename].append(device_map[weight_name])
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
def expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
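Both helpers above are pure functions, which makes the move easy to sanity-check. A small illustration with made-up parameter and shard names:

```python
from transformers.integrations.accelerate import expand_device_map, get_disk_only_shard_files  # paths assumed

device_map = {"encoder": 0, "decoder": "disk"}
param_names = ["encoder.layer.0.weight", "decoder.layer.0.weight", "decoder.layer.1.weight"]

# Module-level devices are expanded into one entry per parameter.
print(expand_device_map(device_map, param_names))
# {'encoder.layer.0.weight': 0, 'decoder.layer.0.weight': 'disk', 'decoder.layer.1.weight': 'disk'}

weight_map = {
    "encoder.layer.0.weight": "model-00001-of-00002.safetensors",
    "decoder.layer.0.weight": "model-00002-of-00002.safetensors",
    "decoder.layer.1.weight": "model-00002-of-00002.safetensors",
}
# The second shard only holds disk-offloaded weights, so it never needs to be loaded into RAM.
print(get_disk_only_shard_files(device_map, weight_map))
# ['model-00002-of-00002.safetensors']
```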
def accelerate_disk_offload(
disk_offload_folder,
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, checkpoint_keys)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": reverse_key_renaming_mapping[name],
"dtype": str_dtype,
}
for name, file in weight_map.items()
if param_device_map[name] == "disk"
}
else:
disk_offload_index = {}
return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors
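A hedged sketch of what `accelerate_disk_offload` produces for a single-file safetensors checkpoint with one module offloaded to disk (all names are placeholders; with safetensors no `disk_offload_folder` is required):

```python
import torch

from transformers.integrations.accelerate import accelerate_disk_offload  # path assumed

checkpoint_keys = ["model.embed.weight", "lm_head.weight"]
identity = {k: k for k in checkpoint_keys}  # no key renaming in this toy case

disk_offload_index, disk_only_shards, is_safetensors = accelerate_disk_offload(
    disk_offload_folder=None,
    checkpoint_files=["model.safetensors"],
    device_map={"model": 0, "lm_head": "disk"},
    checkpoint_keys=checkpoint_keys,
    key_renaming_mapping=identity,
    sharded_metadata=None,
    dtype=torch.float16,
    reverse_key_renaming_mapping=identity,
)
# disk_offload_index == {"lm_head.weight": {"safetensors_file": "model.safetensors",
#                                           "weight_name": "lm_head.weight", "dtype": "float16"}}
print(disk_offload_index, disk_only_shards, is_safetensors)
```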

View File

@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.metadata
import inspect
import json
import os
import re
from typing import Any, Optional, Union
from packaging import version
from ..utils import (
CONFIG_NAME,
cached_file,
check_peft_version,
extract_commit_hash,
find_adapter_config_file,
is_accelerate_available,
is_peft_available,
is_torch_available,
logging,
)
from ..utils.hub import DownloadKwargs
if is_torch_available():
@ -249,7 +255,7 @@ class PeftAdapterMixin:
else:
new_key = key
if key_mapping:
if key_mapping: # TODO dynamic weight loader for adapters
for pattern, replacement in key_mapping.items():
new_key, n_replace = re.subn(pattern, replacement, new_key)
# Early exit of the loop
@ -614,3 +620,54 @@ class PeftAdapterMixin:
if len(self.peft_config) == 0:
del self.peft_config
self._hf_peft_config_loaded = False
def maybe_load_adapters(
pretrained_model_name_or_path,
download_kwargs: DownloadKwargs,
**adapter_kwargs,
):
if pretrained_model_name_or_path is None or not is_peft_available():
return None, pretrained_model_name_or_path
token = download_kwargs.get("token")
if download_kwargs.get("commit_hash") is None:
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
cache_dir=download_kwargs.get("cache_dir"),
force_download=bool(download_kwargs.get("force_download", False)),
proxies=download_kwargs.get("proxies"),
local_files_only=bool(download_kwargs.get("local_files_only", False)),
token=token,
revision=download_kwargs.get("revision"),
subfolder=download_kwargs.get("subfolder"),
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
download_kwargs["commit_hash"] = extract_commit_hash(resolved_config_file, None)
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=download_kwargs.get("cache_dir"),
force_download=bool(download_kwargs.get("force_download", False)),
proxies=download_kwargs.get("proxies"),
token=token,
revision=download_kwargs.get("revision"),
local_files_only=bool(download_kwargs.get("local_files_only", False)),
subfolder=download_kwargs.get("subfolder", ""),
_commit_hash=download_kwargs.get("commit_hash"),
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
return _adapter_model_path, pretrained_model_name_or_path
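A hedged sketch of the new helper; the repo id below is a placeholder, so the call is illustrative rather than directly runnable. When `peft` is not installed (or the repo is not an adapter), it simply returns `(None, pretrained_model_name_or_path)`; when the repo is a PEFT adapter, the adapter path is returned and the model id is swapped for the `base_model_name_or_path` recorded in `adapter_config.json`:

```python
from transformers.integrations.peft import maybe_load_adapters  # path assumed
from transformers.utils.hub import DownloadKwargs

download_kwargs: DownloadKwargs = {
    "cache_dir": None,
    "force_download": False,
    "local_files_only": False,
    "token": None,
    "revision": "main",
    "subfolder": "",
    "commit_hash": None,
}

adapter_path, base_model_id = maybe_load_adapters(
    "some-user/some-lora-adapter",  # placeholder adapter repo id
    download_kwargs,
)
print(adapter_path, base_model_id)
```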

View File

@ -38,59 +38,80 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
def initialize_tensor_parallelism(tp_plan, tp_size=None):
def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, device_map=None):
r"""
Sets up the device mesh and initializes the backend for tensor parallelism.
This function is called when the model is loaded and the TP plan is set to 'auto'.
"""
if tp_plan is None:
return None, None, None
if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
if tp_plan is not None and device_map is not None:
raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
if device_mesh is None:
if not is_torch_greater_or_equal("2.5"):
raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
if not is_torch_greater_or_equal("2.5"):
raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
if device_type == "mps":
device_type = "cpu" # fallback
current_device = getattr(torch, device_type)
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
current_device = getattr(torch, device_type)
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
backend = backend_map.get(device_type)
if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
backend = "ccl"
if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
backend = "ccl"
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
backend = backend_map.get(device_type)
if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
backend = "ccl"
if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
backend = "ccl"
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
current_device = getattr(torch, device_type)
if device_type != "cpu":
current_device.set_device(local_rank)
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
current_device = getattr(torch, device_type)
if device_type != "cpu":
current_device.set_device(local_rank)
except Exception as e:
raise OSError(
"We tried to initialize torch.distributed for you, but it failed. Make "
"sure you init torch distributed in your script to use `tp_plan='auto'`."
) from e
except Exception as e:
raise OSError(
"We tried to initialize torch.distributed for you, but it failed. Make "
"sure you init torch distributed in your script to use `tp_plan='auto'`."
) from e
if device_type != "cpu":
current_device.set_device(int(os.environ["LOCAL_RANK"]))
index = current_device.current_device()
tp_device = torch.device(device_type, index)
device_map = tp_device
# Silence output for non-primary ranks
if index > 0:
import sys
if device_type != "cpu":
current_device.set_device(int(os.environ["LOCAL_RANK"]))
index = current_device.current_device() if device_type != "cpu" else None
tp_device = torch.device(device_type, index)
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# Silence output for non-primary ranks
if index is not None and index > 0:
import sys
else:
tp_device = torch.device(device_type)
device_map = device_type or {}
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
else:
if device_mesh.ndim > 1:
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh.size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
device_map = tp_device
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
return tp_device, device_map, device_mesh, tp_size
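All the `tp_plan='auto'` validation and process-group setup that used to sit in `from_pretrained` is now concentrated here. A hedged usage sketch, assuming a `torchrun` launch so that `RANK`/`LOCAL_RANK`/`WORLD_SIZE` are set; the unpacking mirrors the call site shown later in this diff:

```python
# Launch with e.g.: torchrun --nproc_per_node=2 script.py
from transformers.integrations.tensor_parallel import initialize_tensor_parallelism  # path assumed

tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(
    "auto",            # only 'auto' is supported for now
    tp_size=None,      # defaults to the world size
    device_mesh=None,  # a prebuilt mesh with a 'tp' dimension may also be passed
    device_map=None,   # mutually exclusive with tp_plan
)
print(device_map, device_mesh.size(), tp_size)
```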

View File

@ -39,7 +39,6 @@ import torch
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch import Tensor, nn
from torch.distributions import constraints
@ -50,13 +49,20 @@ from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.accelerate import (
accelerate_disk_offload,
accelerate_dispatch,
check_and_set_device_map,
find_tied_parameters,
init_empty_weights,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.eager_paged import eager_paged_attention_forward
from .integrations.flash_attention import flash_attention_forward
from .integrations.flash_paged import paged_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.hub_kernels import is_kernel, load_and_register_attn_kernel
from .integrations.peft import maybe_load_adapters
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
@ -78,7 +84,6 @@ from .safetensors_conversion import auto_conversion
from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
DUMMY_INPUTS,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@ -91,7 +96,6 @@ from .utils import (
check_torch_load_is_safe,
copy_func,
download_url,
extract_commit_hash,
has_file,
is_accelerate_available,
is_bitsandbytes_available,
@ -99,7 +103,6 @@ from .utils import (
is_flash_attn_3_available,
is_kernels_available,
is_offline_mode,
is_peft_available,
is_remote_url,
is_torch_flex_attn_available,
is_torch_greater_or_equal,
@ -110,7 +113,7 @@ from .utils import (
logging,
)
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_huggingface_hub_greater_or_equal,
@ -122,7 +125,7 @@ from .utils.quantization_config import QuantizationMethod
if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map
from accelerate import infer_auto_device_map
from accelerate.hooks import add_hook_to_module
from accelerate.utils import (
check_tied_parameters_on_same_device,
@ -134,8 +137,6 @@ if is_accelerate_available():
)
from accelerate.utils.modeling import get_state_dict_from_offload
if is_peft_available():
from .utils import find_adapter_config_file
_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
@ -276,6 +277,27 @@ def restore_default_dtype(func):
return _wrapper
def get_keep_in_fp32_regex(model, hf_quantizer, dtype):
# Find fp32 modules if needed
keep_in_fp32_modules = []
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
if model._keep_in_fp32_modules is not None and (
dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
if model._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
keep_in_fp32_regex = None
if keep_in_fp32_modules:
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
return keep_in_fp32_regex
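`get_keep_in_fp32_regex` returns either `None` or a compiled pattern that is matched against parameter names during loading. A toy illustration of the pattern it builds (the module list and parameter names are made up):

```python
import re

# Same construction as in get_keep_in_fp32_regex, for a hypothetical _keep_in_fp32_modules list.
keep_in_fp32_modules = ["layernorm", "lm_head"]
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

assert keep_in_fp32_regex.search("model.layers.0.layernorm.weight")
assert keep_in_fp32_regex.search("lm_head.weight")
# No match here: the module name has to be a full dotted component, not a substring.
assert not keep_in_fp32_regex.search("model.post_layernorm_extra.weight")
```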
def get_torch_context_manager_or_global_device():
"""
Test if a device context manager is currently in use, or if it is not the case, check if the default device
@ -368,81 +390,6 @@ def get_state_dict_dtype(state_dict):
return next(iter(state_dict.values())).dtype
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
"""
This is the same as
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
but for a sharded checkpoint.
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
model (`torch.nn.Module`): The model in which to load the checkpoint.
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
strict (`bool`, *optional*, defaults to `True`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
prefer_safe (`bool`, *optional*, defaults to `False`):
If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
Returns:
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
- `missing_keys` is a list of str containing the missing keys
- `unexpected_keys` is a list of str containing the unexpected keys
"""
# Load the index
index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
index_present = os.path.isfile(index_file)
safe_index_present = os.path.isfile(safe_index_file)
if not index_present and not safe_index_present:
filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
load_safe = safe_index_present and (prefer_safe or not index_present)
load_index = safe_index_file if load_safe else index_file
with open(load_index, "r", encoding="utf-8") as f:
index = json.load(f)
shard_files = list(set(index["weight_map"].values()))
# If strict=True, error before loading any of the state dicts.
loaded_keys = index["weight_map"].keys()
model_keys = model.state_dict().keys()
missing_keys = [key for key in model_keys if key not in loaded_keys]
unexpected_keys = [key for key in loaded_keys if key not in model_keys]
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
if len(missing_keys) > 0:
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
error_message += f"\nMissing key(s): {str_missing_keys}."
if len(unexpected_keys) > 0:
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
if load_safe:
loader = safe_load_file
else:
check_torch_load_is_safe()
loader = partial(torch.load, map_location="cpu", weights_only=True)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
model.load_state_dict(state_dict, strict=False)
# Make sure memory is freed before we load the next state dict.
del state_dict
gc.collect()
# Return the same thing as PyTorch load_state_dict function.
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
@ -906,18 +853,11 @@ def update_key_name(keys):
def _get_resolved_checkpoint_files(
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
subfolder: str,
variant: Optional[str],
gguf_file: Optional[str],
use_safetensors: Optional[bool],
cache_dir: str,
force_download: bool,
proxies: Optional[dict[str, str]],
local_files_only: bool,
token: Optional[Union[str, bool]],
download_kwargs: DownloadKwargs,
user_agent: dict,
revision: str,
commit_hash: Optional[str],
is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
transformers_explicit_filename: Optional[str] = None,
) -> tuple[Optional[list[str]], Optional[dict]]:
@ -925,6 +865,24 @@ def _get_resolved_checkpoint_files(
checkpoints are sharded.
This function will download the data if necessary.
"""
cache_dir = download_kwargs.get("cache_dir")
force_download = download_kwargs.get("force_download", False)
proxies = download_kwargs.get("proxies")
local_files_only = download_kwargs.get("local_files_only", False)
token = download_kwargs.get("token")
revision = download_kwargs.get("revision") or "main"
subfolder = download_kwargs.get("subfolder", "")
commit_hash = download_kwargs.get("commit_hash")
if transformers_explicit_filename is not None:
if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
".safetensors.index.json"
):
raise ValueError(
"The transformers file in the config seems to be incorrect: it is neither a safetensors file "
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
f"{transformers_explicit_filename}"
)
is_sharded = False
if pretrained_model_name_or_path is not None and gguf_file is None:
@ -4212,6 +4170,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return init_contexts
def set_use_kernels(self, use_kernels, kernel_config):
if use_kernels:
if not is_kernels_available():
raise ValueError(
"Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
)
from kernels import use_kernel_mapping
if kernel_config is not None and isinstance(kernel_config, KernelConfig):
# This will make sure the mapping is valid, and the layers are registered in the model
kernel_config.sanitize_kernel_mapping(self)
# This will create a compatible mapping for the model with the kernels library
kernel_config.create_compatible_mapping(self)
# This is a context manager to override the default kernel mapping
# We are calling kernelize inside this context manager using the use_kernels setter
with use_kernel_mapping(kernel_config.kernel_mapping):
self.use_kernels = True
# We use the default kernel mapping in .integrations.hub_kernels
else:
self.use_kernels = True
else:
self.use_kernels = False
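`from_pretrained` now routes the kernels setup through this single setter (see the simplified call site below). A hedged sketch; the `use_kernels=False` branch is a no-op and needs neither the `kernels` package nor a `KernelConfig`:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint

# Equivalent to the default from_pretrained(use_kernels=False) path.
model.set_use_kernels(False, kernel_config=None)

# Enabling kernels requires `pip install kernels`; with kernel_config=None the
# default mapping from .integrations.hub_kernels is used.
# model.set_use_kernels(True, kernel_config=None)
```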
@classmethod
@restore_default_dtype
def from_pretrained(
@ -4431,7 +4414,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
state_dict = kwargs.pop("state_dict", None)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
use_auth_token = kwargs.pop("use_auth_token", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
dtype = kwargs.pop("dtype", None)
@ -4457,7 +4439,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
kernel_config = kwargs.pop("kernel_config", None)
key_mapping = kwargs.pop("key_mapping", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
# Load models with key mapping
if key_mapping is None and any(
allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
):
@ -4467,32 +4449,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
tp_plan = "auto"
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("mirror", None)
_ = kwargs.pop("_fast_init", None)
_ = kwargs.pop("low_cpu_mem_usage", None)
_ = kwargs.pop("from_tf", None)
_ = kwargs.pop("from_flax", None)
_ = kwargs.pop("offload_state_dict", None)
for name in ["mirror", "_fast_init", "low_cpu_mem_usage", "from_tf", "from_flax", "offload_state_dict"]:
_ = kwargs.pop(name, None)
# For BC on torch_dtype argument
if torch_dtype is not None:
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
# If both kwargs are provided, use `dtype`
dtype = dtype if dtype is not None else torch_dtype
if is_offline_mode() and not local_files_only:
local_files_only = True
download_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"subfolder": subfolder,
}
download_kwargs_with_commit = {**download_kwargs, "commit_hash": commit_hash}
if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
raise ValueError(
"`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
)
if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
if tp_plan is not None and device_map is not None:
raise ValueError(
"`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
)
if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")):
logger.info(
@ -4501,221 +4482,84 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
": PartialState().process_index} where PartialState comes from accelerate library"
)
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
if tp_plan is not None:
if device_mesh is None:
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
else:
if device_mesh.ndim > 1:
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh.size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
if tp_size is None:
tp_size = torch.distributed.get_world_size()
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
if tp_plan is not None or tp_size is not None: # TP warnings, and setup
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(
tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
adapter_kwargs["token"] = token
if gguf_file is not None and not is_accelerate_available():
raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
if commit_hash is None:
if not isinstance(config, PreTrainedConfig):
# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
else:
commit_hash = getattr(config, "_commit_hash", None)
if adapter_kwargs is None:
adapter_kwargs = {}
if is_peft_available():
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
else:
_adapter_model_path = None
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
if device_map is None and not is_deepspeed_zero3_enabled():
device_in_context = get_torch_context_manager_or_global_device()
if device_in_context == torch.device("meta"):
raise RuntimeError(
"You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
"This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
"empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
)
device_map = device_in_context
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
try:
device_map = {"": torch.device(device_map)}
except RuntimeError:
raise ValueError(
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
)
elif isinstance(device_map, int):
if device_map < 0:
raise ValueError(
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
)
else:
device_map = {"": device_map}
if device_map is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
if not is_accelerate_available():
raise ValueError(
"Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
"requires `accelerate`. You can install it with `pip install accelerate`"
)
_adapter_model_path, pretrained_model_name_or_path = maybe_load_adapters(
pretrained_model_name_or_path,
download_kwargs_with_commit,
**adapter_kwargs,
)
device_map = check_and_set_device_map(device_map) # warn, error and fix the device map
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, PreTrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
gguf_file=gguf_file,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**download_kwargs,
**kwargs,
)
if "gguf_file" in model_kwargs:
model_kwargs.pop("gguf_file")
commit_hash = model_kwargs.pop("_commit_hash", commit_hash)
else:
config = copy.deepcopy(config)
model_kwargs = kwargs
commit_hash = getattr(config, "_commit_hash", commit_hash)
download_kwargs_with_commit["commit_hash"] = commit_hash
# Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
# to correctly redispatch recursively if the kwarg is provided
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")
transformers_explicit_filename = getattr(config, "transformers_weights", None)
if transformers_explicit_filename is not None:
if not transformers_explicit_filename.endswith(
".safetensors"
) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
raise ValueError(
"The transformers file in the config seems to be incorrect: it is neither a safetensors file "
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
f"{transformers_explicit_filename}"
)
hf_quantizer, config, dtype, device_map = get_hf_quantizer(
config, quantization_config, dtype, device_map, weights_only, user_agent
)
if gguf_file is not None and hf_quantizer is not None:
raise ValueError(
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
)
if (
gguf_file
and device_map is not None
and ((isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map)
):
raise RuntimeError(
"One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
"loaded from GGUF files."
)
if gguf_file:
if hf_quantizer is not None:
raise ValueError(
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
)
if device_map is not None and (
(isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map
):
raise RuntimeError(
"One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
"loaded from GGUF files."
)
checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
variant=variant,
gguf_file=gguf_file,
use_safetensors=use_safetensors,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
download_kwargs=download_kwargs_with_commit,
user_agent=user_agent,
revision=revision,
commit_hash=commit_hash,
is_remote_code=cls._auto_class is not None,
transformers_explicit_filename=transformers_explicit_filename,
transformers_explicit_filename=getattr(config, "transformers_weights", None),
)
is_quantized = hf_quantizer is not None
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
# Just a helpful message in case we try to load safetensors files coming from old Transformers tf/flax classes
if is_from_file and checkpoint_files[0].endswith(".safetensors"):
with safe_open(checkpoint_files[0], framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") in ["tf", "flax"]:
logger.warning(
"The safetensors checkpoint found has format `tf` or `flax`. This mean that the keys will very"
"likely not match to the model you are trying to load, and will be newly initialized. If it's the case "
"another warning will be raised later. Consider converting your checkpoint to the correct format."
)
if gguf_file:
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@ -4746,52 +4590,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
# Find fp32 modules if needed
keep_in_fp32_modules = []
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
if model._keep_in_fp32_modules is not None and (
dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
if model._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
keep_in_fp32_regex = None
if keep_in_fp32_modules:
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
if hf_quantizer is not None:
# Regex to keep a fixed dtype
keep_in_fp32_regex = get_keep_in_fp32_regex(model, hf_quantizer, dtype)
if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
hf_quantizer.preprocess_model(
model=model,
device_map=device_map,
keep_in_fp32_modules=model._keep_in_fp32_modules,
config=config,
checkpoint_files=checkpoint_files,
use_kernels=use_kernels,
)
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
original_dtype = dtype if dtype is not None else torch.get_default_dtype()
def _assign_original_dtype(module):
for child in module.children():
if isinstance(child, PreTrainedModel):
child.config._pre_quantization_dtype = original_dtype
_assign_original_dtype(child)
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)
# Torchao needs access to all metadata later
if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
hf_quantizer.set_metadata(checkpoint_files)
if _torch_distributed_available and device_mesh is not None:
if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
model = distribute_model(model, distributed_config, device_mesh, tp_size)
# Prepare the full device map
@ -4820,109 +4631,34 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
weights_only=weights_only,
)
# make sure token embedding weights are still tied if needed
model.tie_weights()
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# check if using kernels
if use_kernels:
if not is_kernels_available():
raise ValueError(
"Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
)
from kernels import use_kernel_mapping
if kernel_config is not None and isinstance(kernel_config, KernelConfig):
# This will make sure the mapping is valid, and the layers are registered in the model
kernel_config.sanitize_kernel_mapping(model)
# This will create a compatible mapping for the model with the kernels library
kernel_config.create_compatible_mapping(model)
# This is a context manager to override the default kernel mapping
# We are calling kernelize inside this context manager using the use_kernels setter
with use_kernel_mapping(kernel_config.kernel_mapping):
model.use_kernels = True
# We use the default kernel mapping in .integrations.hub_kernels
else:
model.use_kernels = True
model.tie_weights() # make sure token embedding weights are still tied if needed
model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
model.set_use_kernels(use_kernels, kernel_config)
# If it is a model with generation capabilities, attempt to load generation files (generation config,
# custom generate function)
if model.can_generate() and generation_config is not None:
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
elif model.can_generate() and pretrained_model_name_or_path is not None:
repo_loading_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"subfolder": subfolder,
if model.can_generate() and hasattr(model, "adjust_generation_fn"):
model.adjust_generation_fn(
generation_config,
from_auto_class,
from_pipeline,
pretrained_model_name_or_path,
**download_kwargs,
trust_remote_code=trust_remote_code,
**kwargs,
}
# Load generation config
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**repo_loading_kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(model, "load_custom_generate"):
try:
custom_generate = model.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
)
model.generate = functools.partial(custom_generate, model=model)
except OSError: # there is no custom generate function
pass
)
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harm performances)
# for device_map="auto": dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harms performance).
if device_map is not None and device_mesh is None:
device_map_kwargs = {
"device_map": device_map,
"offload_dir": offload_folder,
"offload_index": offload_index,
"offload_buffers": offload_buffers,
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
# For HQQ method we force-set the hooks for single GPU envs
if (
"force_hooks" in inspect.signature(dispatch_model).parameters
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if (
hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
device_map_kwargs["offload_buffers"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
if hf_quantizer is not None:
model.hf_quantizer = hf_quantizer
hf_quantizer.postprocess_model(model, config=config)
hf_quantizer.postprocess_model(model, config=config) # usually a no-op
if _adapter_model_path is not None:
adapter_kwargs["key_mapping"] = key_mapping
adapter_kwargs["key_mapping"] = key_mapping # TODO: Dynamic weight loader for adapters
model.load_adapter(
_adapter_model_path,
adapter_name=adapter_name,
@ -5081,7 +4817,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
QuantizationMethod.QUARK,
}
# Get all the keys of the state dicts that we have to initialize the model
# Get all the keys of the state dicts that we have to initialize the model with
if sharded_metadata is not None:
original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
elif state_dict is not None:
@ -5152,43 +4888,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
disk_only_shard_files = []
# Prepare parameters offloading if needed
if device_map is not None and "disk" in device_map.values():
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, checkpoint_keys)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": reverse_key_renaming_mapping[name],
"dtype": str_dtype,
}
for name, file in weight_map.items()
if param_device_map[name] == "disk"
}
else:
disk_offload_index = {}
disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload(
disk_offload_folder,
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
)
# To be able to iterate, even if we don't use it if the state_dict is already provided
elif state_dict is not None:
checkpoint_files = [""]
@ -5286,6 +4995,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
missing_keys, unexpected_keys, loading_task_model_from_base_state_dict
)
# TODO: separate this in another function: it's not core....
# All potential warnings/infos
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@ -5787,19 +5497,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
def get_disk_only_shard_files(device_map, weight_map):
"""
Returns the list of shard files containing only weights offloaded to disk.
"""
files_content = collections.defaultdict(list)
for weight_name, filename in weight_map.items():
while len(weight_name) > 0 and weight_name not in device_map:
weight_name = ".".join(weight_name.split(".")[:-1])
files_content[filename].append(device_map[weight_name])
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
class AttentionInterface(GeneralInterface):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function

View File

@ -31,6 +31,13 @@ else:
logger = logging.get_logger(__file__)
def _assign_original_dtype(module, original_dtype):
for child in module.children():
if isinstance(child, PreTrainedModel):
child.config._pre_quantization_dtype = original_dtype
_assign_original_dtype(child, original_dtype)
class HfQuantizer(ABC):
"""
Abstract class of the HuggingFace quantizer. Supports for now quantizing HF transformers models for inference and/or quantization.
@ -206,7 +213,7 @@ class HfQuantizer(ABC):
"updates the tp plan for the scales"
return config
def preprocess_model(self, model: "PreTrainedModel", **kwargs):
def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
the model should be initialized on the meta device so you can freely manipulate the skeleton
@ -222,7 +229,18 @@ class HfQuantizer(ABC):
model.quantization_method = self.quantization_config.quant_method
if self.pre_quantized:
self._convert_model_for_quantization(model)
return self._process_model_before_weight_loading(model, **kwargs)
self._process_model_before_weight_loading(model, **kwargs)
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
original_dtype = dtype if dtype is not None else torch.get_default_dtype()
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model, original_dtype)
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def postprocess_model(self, model: "PreTrainedModel", **kwargs):
"""

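Under the new contract, `preprocess_model` receives `config`, `dtype` and `checkpoint_files` directly, no longer returns the model, and records `_pre_quantization_dtype` on the top-level config and on every nested sub-model via `_assign_original_dtype`. A standalone sketch of that dtype bookkeeping only (toy classes, not the library types):

```python
import torch
import torch.nn as nn


class TinyConfig:
    pass


class TinyPreTrained(nn.Module):
    # Stand-in for PreTrainedModel: anything that carries a .config attribute.
    def __init__(self, *children):
        super().__init__()
        self.config = TinyConfig()
        for i, child in enumerate(children):
            self.add_module(f"child_{i}", child)


def assign_original_dtype(module, original_dtype):
    # Mirrors _assign_original_dtype: recursively stamp the pre-quantization dtype
    # on every nested sub-model's config so it survives once the weights are quantized.
    for child in module.children():
        if isinstance(child, TinyPreTrained):
            child.config._pre_quantization_dtype = original_dtype
            assign_original_dtype(child, original_dtype)


backbone = TinyPreTrained()
model = TinyPreTrained(backbone)

original_dtype = torch.float16  # what preprocess_model now receives as `dtype`
model.config._pre_quantization_dtype = original_dtype  # top-level config
assign_original_dtype(model, original_dtype)
assert backbone.config._pre_quantization_dtype == torch.float16
```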
View File

@ -79,9 +79,6 @@ class AqlmHfQuantizer(HfQuantizer):
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
@property
def is_trainable(self) -> bool:
aqlm_supports_training = version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.2")

View File

@ -69,9 +69,6 @@ class BitNetHfQuantizer(HfQuantizer):
"This is not supported. Please remove the CPU or disk device from the device_map."
)
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",

View File

@ -137,9 +137,6 @@ class EetqHfQuantizer(HfQuantizer):
module._buffers[tensor_name] = new_value.to(target_device)
module.register("weight_scales", weight_scale.to(target_device))
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",

View File

@ -192,9 +192,6 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
del param_name
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",

View File

@ -167,9 +167,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from ..integrations import FP8Linear

View File

@ -140,9 +140,6 @@ class FPQuantHfQuantizer(HfQuantizer):
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from fp_quant import FPQuantLinear

View File

@ -157,9 +157,6 @@ class QuantoHfQuantizer(HfQuantizer):
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model, **kwargs):
return model
@property
def is_trainable(self) -> bool:
return True

View File

@ -88,9 +88,6 @@ class QuarkHfQuantizer(HfQuantizer):
_load_parameter_into_model(model, param_name, param.to(param_device))
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def is_serializable(self, safe_serialization=None):
return False

View File

@ -79,9 +79,6 @@ class SpQRHfQuantizer(HfQuantizer):
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
@property
def is_trainable(self):
return False

View File

@ -350,6 +350,22 @@ class TorchAoHfQuantizer(HfQuantizer):
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
the model should be initialized on the meta device so you can freely manipulate the skeleton
of the model in order to replace modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.
Args:
model (`~transformers.PreTrainedModel`):
The model to quantize
kwargs (`dict`, *optional*):
The keyword arguments that are passed along `_process_model_before_weight_loading`.
"""
super().preprocess_model(model, config, dtype, checkpoint_files, **kwargs)
# Torchao needs access to all metadata later
self.set_metadata(checkpoint_files)
def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
if self.quantization_config.quant_type == "autoquant":

View File

@ -88,9 +88,6 @@ class VptqHfQuantizer(HfQuantizer):
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
@property
def is_trainable(self) -> bool:
return False

View File

@ -65,7 +65,7 @@ from .image_processing_utils import BaseImageProcessor
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
@ -123,6 +123,7 @@ from .trainer_utils import (
find_executable_batch_size,
get_last_checkpoint,
has_length,
load_sharded_checkpoint,
neftune_post_forward_hook,
number_of_arguments,
seed_worker,

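For downstream code that used the old location, the import now comes from `trainer_utils`; a minimal before/after sketch:

```python
# New location after this refactor:
from transformers.trainer_utils import load_sharded_checkpoint

# Previous location, removed by this commit:
# from transformers.modeling_utils import load_sharded_checkpoint
```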
View File

@ -19,18 +19,23 @@ import copy
import functools
import gc
import inspect
import json
import os
import random
import re
import threading
import time
from collections.abc import Callable
from functools import partial
from typing import Any, NamedTuple, Optional, Union
import numpy as np
from .utils import (
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_INDEX_NAME,
ExplicitEnum,
check_torch_load_is_safe,
is_psutil_available,
is_torch_available,
is_torch_cuda_available,
@ -47,6 +52,7 @@ from .utils import (
if is_torch_available():
import torch
from safetensors.torch import load_file as safe_load_file
def seed_worker(worker_id: int, num_workers: int, rank: int):
@ -873,3 +879,79 @@ def check_target_module_exists(optim_target_modules, key: str, return_is_regex:
return target_module_found, is_regex
return target_module_found
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
"""
This is the same as
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
but for a sharded checkpoint.
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
model (`torch.nn.Module`): The model in which to load the checkpoint.
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
strict (`bool`, *optional*, defaults to `True`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
prefer_safe (`bool`, *optional*, defaults to `True`):
If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
Returns:
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
- `missing_keys` is a list of str containing the missing keys
- `unexpected_keys` is a list of str containing the unexpected keys
"""
# Load the index
index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
index_present = os.path.isfile(index_file)
safe_index_present = os.path.isfile(safe_index_file)
if not index_present and not safe_index_present:
filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
load_safe = safe_index_present and (prefer_safe or not index_present)
load_index = safe_index_file if load_safe else index_file
with open(load_index, "r", encoding="utf-8") as f:
index = json.load(f)
shard_files = list(set(index["weight_map"].values()))
# If strict=True, error before loading any of the state dicts.
# TODO: Here, update the weight map with the config.dynamic_weight_conversion
loaded_keys = index["weight_map"].keys()
model_keys = model.state_dict().keys()
missing_keys = [key for key in model_keys if key not in loaded_keys]
unexpected_keys = [key for key in loaded_keys if key not in model_keys]
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
if len(missing_keys) > 0:
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
error_message += f"\nMissing key(s): {str_missing_keys}."
if len(unexpected_keys) > 0:
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
if load_safe:
loader = safe_load_file
else:
check_torch_load_is_safe()
loader = partial(torch.load, map_location="cpu", weights_only=True)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
model.load_state_dict(state_dict, strict=False)
# Make sure memory is freed before we load the next state dict.
del state_dict
gc.collect()
# Return the same thing as PyTorch load_state_dict function.
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
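A hedged usage sketch of the relocated helper; the folder name is a placeholder and must already contain a sharded checkpoint together with its index file:

```python
from transformers import AutoModelForCausalLM
from transformers.trainer_utils import load_sharded_checkpoint

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example model skeleton

# "my_sharded_checkpoint/" must hold model.safetensors.index.json (or the torch
# equivalent) plus the shard files it references.
result = load_sharded_checkpoint(model, "my_sharded_checkpoint", strict=False, prefer_safe=True)
print(result.missing_keys, result.unexpected_keys)
```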

View File

@ -23,7 +23,7 @@ import tempfile
import warnings
from concurrent import futures
from pathlib import Path
from typing import Optional, Union
from typing import Optional, TypedDict, Union
from urllib.parse import urlparse
from uuid import uuid4
@ -75,6 +75,18 @@ CHAT_TEMPLATE_DIR = "additional_chat_templates"
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DownloadKwargs(TypedDict, total=False):
cache_dir: Optional[Union[str, os.PathLike]]
force_download: bool
proxies: Optional[dict[str, str]]
local_files_only: bool
token: Optional[Union[str, bool]]
revision: Optional[str]
subfolder: str
commit_hash: Optional[str]
_is_offline_mode = huggingface_hub.constants.HF_HUB_OFFLINE
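Because `DownloadKwargs` is declared with `total=False`, callers can build a partial dict and the refactored helpers read each field with `.get()` and a default. A small sketch (values are placeholders):

```python
from transformers.utils.hub import DownloadKwargs

download_kwargs: DownloadKwargs = {
    "force_download": False,
    "local_files_only": False,
    "revision": "main",
    "subfolder": "",
}

# Missing keys are fine: helpers fall back to sensible defaults.
revision = download_kwargs.get("revision") or "main"
commit_hash = download_kwargs.get("commit_hash")  # None if never resolved
print(revision, commit_hash)
```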