Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
[from_pretrained] Small refactor of from_pretrained: move around unrelated stuff (#41445)
* drafts
* up
* simplify modeling utils
* more simplifications
* type kwargs
* up
* move more accelerate related stuff
* safeguarding?
* nits
* remove func when func is NOPE
* more
* nits
* styling
* yups
* up
* ups
* revert
* protect trainer utils iport
* fix doc
* Update src/transformers/integrations/peft.py
  Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
* review
* update
* ?
* fixx
* update
* super small update
* ups
* style
* this is stupid
* 🤦 well this was the issue
* small nit
* fix
* nit
* damn the missing return
* one last stupid fix

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
@@ -42,7 +42,3 @@ set this to `False`.
## Pushing to the Hub

[[autodoc]] utils.PushToHubMixin

## Sharded checkpoints

[[autodoc]] modeling_utils.load_sharded_checkpoint
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
import os
import warnings
@@ -365,6 +366,58 @@ class GenerationMixin(ContinuousMixin):
    To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
    """

    def adjust_generation_fn(
        self,
        generation_config,
        from_auto_class,
        from_pipeline,
        pretrained_model_name_or_path,
        cache_dir,
        force_download,
        proxies,
        local_files_only,
        token,
        revision,
        subfolder,
        trust_remote_code,
        **kwargs,
    ):
        if self.can_generate() and generation_config is not None:
            self.generation_config = self.generation_config.from_dict(generation_config.to_dict())
        elif self.can_generate() and pretrained_model_name_or_path is not None:
            repo_loading_kwargs = {
                "cache_dir": cache_dir,
                "force_download": force_download,
                "proxies": proxies,
                "local_files_only": local_files_only,
                "token": token,
                "revision": revision,
                "subfolder": subfolder,
                **kwargs,
            }
            # Load generation config
            try:
                self.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    _from_auto=from_auto_class,
                    _from_pipeline=from_pipeline,
                    **repo_loading_kwargs,
                )
            except OSError:
                logger.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
                pass
            # Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
            if hasattr(self, "load_custom_generate"):
                try:
                    custom_generate = self.load_custom_generate(
                        pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
                    )
                    self.generate = functools.partial(custom_generate, model=self)
                except OSError:  # there is no custom generate function
                    pass

    def load_custom_generate(
        self,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
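For orientation, a sketch of what the generation-config resolution above means for a caller of `from_pretrained` (the model id is just an example):

```python
from transformers import AutoModelForCausalLM, GenerationConfig

# If the repo ships a generation_config.json it is loaded; otherwise the defaults
# derived from the model config are kept (the OSError branch above).
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(model.generation_config)

# An explicit GenerationConfig takes precedence over the repo's file.
gen_cfg = GenerationConfig(max_new_tokens=32, do_sample=False)
model = AutoModelForCausalLM.from_pretrained("gpt2", generation_config=gen_cfg)
```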
@@ -20,15 +20,24 @@ The `init_empty_weights` and `init_on_device` functions were copied from `accele
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`
"""

import collections
import inspect
import os
from contextlib import contextmanager

from ..utils import is_torch_available, logging
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationMethod
from .deepspeed import is_deepspeed_zero3_enabled
from .fsdp import is_fsdp_enabled


if is_torch_available():
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate import dispatch_model


logger = logging.get_logger(__name__)
@@ -194,3 +203,150 @@ def find_tied_parameters(model: "nn.Module", **kwargs):
        tied_param_groups[param_name].append(tied_param_name)

    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]


def check_and_set_device_map(device_map):
    from ..modeling_utils import get_torch_context_manager_or_global_device

    # Potentially detect context manager or global device, and use it (only if no device_map was provided)
    if device_map is None and not is_deepspeed_zero3_enabled():
        device_in_context = get_torch_context_manager_or_global_device()
        if device_in_context == torch.device("meta"):
            raise RuntimeError(
                "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
                "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
                "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
            )
        device_map = device_in_context

    # change device_map into a map if we passed an int, a str or a torch.device
    if isinstance(device_map, torch.device):
        device_map = {"": device_map}
    elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        try:
            device_map = {"": torch.device(device_map)}
        except RuntimeError:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
            )
    elif isinstance(device_map, int):
        if device_map < 0:
            raise ValueError(
                "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
            )
        else:
            device_map = {"": device_map}

    if device_map is not None:
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
        if not is_accelerate_available():
            raise ValueError(
                "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
                "requires `accelerate`. You can install it with `pip install accelerate`"
            )
    return device_map


def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
    device_map_kwargs = {
        "device_map": device_map,
        "offload_dir": offload_folder,
        "offload_index": offload_index,
        "offload_buffers": offload_buffers,
    }
    if "skip_keys" in inspect.signature(dispatch_model).parameters:
        device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
    # For HQQ method we force-set the hooks for single GPU envs
    if (
        "force_hooks" in inspect.signature(dispatch_model).parameters
        and hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
    ):
        device_map_kwargs["force_hooks"] = True
    if (
        hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
        and isinstance(device_map, dict)
        and ("cpu" in device_map.values() or "disk" in device_map.values())
    ):
        device_map_kwargs["offload_buffers"] = True

    if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
        dispatch_model(model, **device_map_kwargs)


def get_disk_only_shard_files(device_map, weight_map):
    """
    Returns the list of shard files containing only weights offloaded to disk.
    """
    files_content = collections.defaultdict(list)
    for weight_name, filename in weight_map.items():
        while len(weight_name) > 0 and weight_name not in device_map:
            weight_name = ".".join(weight_name.split(".")[:-1])
        files_content[filename].append(device_map[weight_name])

    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


def expand_device_map(device_map, param_names):
    """
    Expand a device map to return the correspondence parameter name to device.
    """
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


def accelerate_disk_offload(
    disk_offload_folder,
    checkpoint_files,
    device_map,
    checkpoint_keys,
    key_renaming_mapping,
    sharded_metadata,
    dtype,
    reverse_key_renaming_mapping,
):
    disk_only_shard_files = []
    if disk_offload_folder is not None:
        os.makedirs(disk_offload_folder, exist_ok=True)
    is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
    if disk_offload_folder is None and not is_offloaded_safetensors:
        raise ValueError(
            "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
            " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
            " offers the weights in this format."
        )
    if is_offloaded_safetensors:
        param_device_map = expand_device_map(device_map, checkpoint_keys)
        str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
        if sharded_metadata is None:
            weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
        else:
            folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
            # Fix the weight map keys according to the key mapping
            weight_map = {
                key_renaming_mapping[k]: v
                for k, v in sharded_metadata["weight_map"].items()
                if k in key_renaming_mapping
            }
            weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
        # Find potential checkpoints containing only offloaded weights
        disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
        disk_offload_index = {
            name: {
                "safetensors_file": file,
                "weight_name": reverse_key_renaming_mapping[name],
                "dtype": str_dtype,
            }
            for name, file in weight_map.items()
            if param_device_map[name] == "disk"
        }
    else:
        disk_offload_index = {}
    return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors
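A minimal standalone sketch of the `device_map` normalization that `check_and_set_device_map` performs (paraphrased from the hunk above; the accelerate/DeepSpeed availability checks are omitted):

```python
import torch

def normalize_device_map(device_map):
    # Paraphrase for illustration only; mirrors the int/str/torch.device handling above.
    if isinstance(device_map, torch.device):
        return {"": device_map}
    if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        return {"": torch.device(device_map)}  # invalid device names raise here
    if isinstance(device_map, int):
        if device_map < 0:
            raise ValueError("device_map cannot be a negative int; pass device_map='cpu' instead")
        return {"": device_map}
    return device_map  # None and the special strings pass through unchanged

print(normalize_device_map("cuda:0"))  # {'': device(type='cuda', index=0)}
print(normalize_device_map(0))         # {'': 0}
print(normalize_device_map("auto"))    # 'auto'
```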
@@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.metadata
import inspect
import json
import os
import re
from typing import Any, Optional, Union

from packaging import version

from ..utils import (
    CONFIG_NAME,
    cached_file,
    check_peft_version,
    extract_commit_hash,
    find_adapter_config_file,
    is_accelerate_available,
    is_peft_available,
    is_torch_available,
    logging,
)
from ..utils.hub import DownloadKwargs


if is_torch_available():
@@ -249,7 +255,7 @@ class PeftAdapterMixin:
            else:
                new_key = key

            if key_mapping:
            if key_mapping:  # TODO dynamic weight loader for adapters
                for pattern, replacement in key_mapping.items():
                    new_key, n_replace = re.subn(pattern, replacement, new_key)
                    # Early exit of the loop
@@ -614,3 +620,54 @@ class PeftAdapterMixin:
        if len(self.peft_config) == 0:
            del self.peft_config
            self._hf_peft_config_loaded = False


def maybe_load_adapters(
    pretrained_model_name_or_path,
    download_kwargs: DownloadKwargs,
    **adapter_kwargs,
):
    if pretrained_model_name_or_path is None or not is_peft_available():
        return None, pretrained_model_name_or_path

    token = download_kwargs.get("token")

    if download_kwargs.get("commit_hash") is None:
        resolved_config_file = cached_file(
            pretrained_model_name_or_path,
            CONFIG_NAME,
            cache_dir=download_kwargs.get("cache_dir"),
            force_download=bool(download_kwargs.get("force_download", False)),
            proxies=download_kwargs.get("proxies"),
            local_files_only=bool(download_kwargs.get("local_files_only", False)),
            token=token,
            revision=download_kwargs.get("revision"),
            subfolder=download_kwargs.get("subfolder"),
            _raise_exceptions_for_gated_repo=False,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )
        download_kwargs["commit_hash"] = extract_commit_hash(resolved_config_file, None)

    _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)

    if _adapter_model_path is None:
        _adapter_model_path = find_adapter_config_file(
            pretrained_model_name_or_path,
            cache_dir=download_kwargs.get("cache_dir"),
            force_download=bool(download_kwargs.get("force_download", False)),
            proxies=download_kwargs.get("proxies"),
            token=token,
            revision=download_kwargs.get("revision"),
            local_files_only=bool(download_kwargs.get("local_files_only", False)),
            subfolder=download_kwargs.get("subfolder", ""),
            _commit_hash=download_kwargs.get("commit_hash"),
            **adapter_kwargs,
        )

    if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
        with open(_adapter_model_path, "r", encoding="utf-8") as f:
            _adapter_model_path = pretrained_model_name_or_path
            pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]

    return _adapter_model_path, pretrained_model_name_or_path
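Usage sketch for the adapter path handled by `maybe_load_adapters` (the repo id below is hypothetical): when a checkpoint only contains a PEFT adapter, the base model recorded in its `adapter_config.json` is loaded first and the adapter is attached afterwards via `load_adapter`.

```python
from transformers import AutoModelForCausalLM

# Hypothetical adapter-only repo: its adapter_config.json provides
# "base_model_name_or_path", which becomes the checkpoint actually loaded.
model = AutoModelForCausalLM.from_pretrained("some-user/llama-lora-adapter")
```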
@@ -38,59 +38,80 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
    from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


def initialize_tensor_parallelism(tp_plan, tp_size=None):
def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, device_map=None):
    r"""
    Sets up the device mesh and initialized the backend for tensor parallelism.
    This function is called when the model is loaded and the TP plan is set to 'auto'.
    """
    if tp_plan is None:
        return None, None, None
    if tp_size is not None and tp_plan is None:
        raise ValueError("tp_plan has to be set when tp_size is passed.")
    if tp_plan is not None and tp_plan != "auto":
        # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
        raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
    if tp_plan is not None and device_map is not None:
        raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
    if device_mesh is None:
        if not is_torch_greater_or_equal("2.5"):
            raise OSError("Tensor parallel is only supported for `torch>=2.5`.")

    if not is_torch_greater_or_equal("2.5"):
        raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
        # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
        device_type = torch._C._get_accelerator().type
        if device_type == "mps":
            device_type = "cpu"  # fallback
        current_device = getattr(torch, device_type)
        if not torch.distributed.is_initialized():
            try:
                rank = int(os.environ["RANK"])
                local_rank = int(os.environ["LOCAL_RANK"])
                world_size = int(os.environ["WORLD_SIZE"])

    # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
    device_type = torch._C._get_accelerator().type
    current_device = getattr(torch, device_type)
    if not torch.distributed.is_initialized():
        try:
            rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
                backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
                backend = backend_map.get(device_type)
                if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
                    backend = "ccl"
                if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
                    backend = "ccl"

            backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
            backend = backend_map.get(device_type)
            if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
                backend = "ccl"
            if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
                backend = "ccl"
                torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
                current_device = getattr(torch, device_type)
                if device_type != "cpu":
                    current_device.set_device(local_rank)

            torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
            current_device = getattr(torch, device_type)
            if device_type != "cpu":
                current_device.set_device(local_rank)
            except Exception as e:
                raise OSError(
                    "We tried to initialize torch.distributed for you, but it failed. Make "
                    "sure you init torch distributed in your script to use `tp_plan='auto'`."
                ) from e

        except Exception as e:
            raise OSError(
                "We tried to initialize torch.distributed for you, but it failed. Make "
                "sure you init torch distributed in your script to use `tp_plan='auto'`."
            ) from e
        if device_type != "cpu":
            current_device.set_device(int(os.environ["LOCAL_RANK"]))
        index = current_device.current_device()
        tp_device = torch.device(device_type, index)
        device_map = tp_device
        # Silence output for non-primary ranks
        if index > 0:
            import sys

    if device_type != "cpu":
        current_device.set_device(int(os.environ["LOCAL_RANK"]))
    index = current_device.current_device() if device_type != "cpu" else None
    tp_device = torch.device(device_type, index)
            sys.stdout = open(os.devnull, "w")
            sys.stderr = open(os.devnull, "w")

    # Silence output for non-primary ranks
    if index is not None and index > 0:
        import sys
    else:
        tp_device = torch.device(device_type)
        device_map = device_type or {}

        sys.stdout = open(os.devnull, "w")
        sys.stderr = open(os.devnull, "w")
        tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
        device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
    else:
        if device_mesh.ndim > 1:
            if "tp" not in device_mesh.mesh_dim_names:
                raise ValueError(
                    "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
                    "Please provide a valid `device_mesh`."
                )
            device_mesh = device_mesh["tp"]
        tp_size = device_mesh.size()
        device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")

    device_map = tp_device
    tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
    device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
    return tp_device, device_map, device_mesh, tp_size
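A usage sketch for the tensor-parallel setup above (model id is an example; the script must run under a distributed launcher so RANK, LOCAL_RANK and WORLD_SIZE are set):

```python
# Launch with e.g.: torchrun --nproc-per-node=2 tp_example.py
from transformers import AutoModelForCausalLM

# tp_plan="auto" routes through initialize_tensor_parallelism(); it stays mutually
# exclusive with device_map, as the checks above enforce.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", tp_plan="auto")
```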
@@ -39,7 +39,6 @@ import torch
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch import Tensor, nn
from torch.distributions import constraints
@@ -50,13 +49,20 @@ from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.accelerate import (
    accelerate_disk_offload,
    accelerate_dispatch,
    check_and_set_device_map,
    find_tied_parameters,
    init_empty_weights,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.eager_paged import eager_paged_attention_forward
from .integrations.flash_attention import flash_attention_forward
from .integrations.flash_paged import paged_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.hub_kernels import is_kernel, load_and_register_attn_kernel
from .integrations.peft import maybe_load_adapters
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
@@ -78,7 +84,6 @@ from .safetensors_conversion import auto_conversion
from .utils import (
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    DUMMY_INPUTS,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
@@ -91,7 +96,6 @@ from .utils import (
    check_torch_load_is_safe,
    copy_func,
    download_url,
    extract_commit_hash,
    has_file,
    is_accelerate_available,
    is_bitsandbytes_available,
@@ -99,7 +103,6 @@ from .utils import (
    is_flash_attn_3_available,
    is_kernels_available,
    is_offline_mode,
    is_peft_available,
    is_remote_url,
    is_torch_flex_attn_available,
    is_torch_greater_or_equal,
@@ -110,7 +113,7 @@ from .utils import (
    logging,
)
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
    ENV_VARS_TRUE_VALUES,
    is_huggingface_hub_greater_or_equal,
@@ -122,7 +125,7 @@ from .utils.quantization_config import QuantizationMethod


if is_accelerate_available():
    from accelerate import dispatch_model, infer_auto_device_map
    from accelerate import infer_auto_device_map
    from accelerate.hooks import add_hook_to_module
    from accelerate.utils import (
        check_tied_parameters_on_same_device,
@@ -134,8 +137,6 @@ if is_accelerate_available():
    )
    from accelerate.utils.modeling import get_state_dict_from_offload

if is_peft_available():
    from .utils import find_adapter_config_file

_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
@@ -276,6 +277,27 @@ def restore_default_dtype(func):
    return _wrapper


def get_keep_in_fp32_regex(model, hf_quantizer, dtype):
    # Find fp32 modules if needed
    keep_in_fp32_modules = []
    # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
    # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
    # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
    if model._keep_in_fp32_modules is not None and (
        dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
    ):
        keep_in_fp32_modules.extend(model._keep_in_fp32_modules)

    if model._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
        keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)

    keep_in_fp32_regex = None
    if keep_in_fp32_modules:
        # We need to match exact layers, so we add either `.` on each side, or start/end of string
        keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
    return keep_in_fp32_regex


def get_torch_context_manager_or_global_device():
    """
    Test if a device context manager is currently in use, or if it is not the case, check if the default device
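A quick, self-contained check of the regex built by `get_keep_in_fp32_regex` (module names are made-up examples):

```python
import re

keep_in_fp32_modules = ["layer_norm", "lm_head"]  # example module names
pattern = re.compile("|".join([rf"((^|\.){m}($|\.))" for m in keep_in_fp32_modules]))

print(bool(pattern.search("model.layer_norm.weight")))       # True: full dotted segment
print(bool(pattern.search("lm_head.weight")))                 # True: match at start of string
print(bool(pattern.search("model.other_layer_norm.weight")))  # False: only exact segments match
```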
@@ -368,81 +390,6 @@ def get_state_dict_dtype(state_dict):
    return next(state_dict.values()).dtype


def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        prefer_safe (`bool`, *optional*, defaults to `False`):
            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys
    """
    # Load the index
    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)

    index_present = os.path.isfile(index_file)
    safe_index_present = os.path.isfile(safe_index_file)

    if not index_present and not safe_index_present:
        filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")

    load_safe = safe_index_present and (prefer_safe or not index_present)
    load_index = safe_index_file if load_safe else index_file

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))

    # If strict=True, error before loading any of the state dicts.
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nMissing key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    if load_safe:
        loader = safe_load_file
    else:
        check_torch_load_is_safe()
        loader = partial(torch.load, map_location="cpu", weights_only=True)

    for shard_file in shard_files:
        state_dict = loader(os.path.join(folder, shard_file))
        model.load_state_dict(state_dict, strict=False)

        # Make sure memory is freed before we load the next state dict.
        del state_dict
        gc.collect()

    # Return the same thing as PyTorch load_state_dict function.
    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)


str_to_torch_dtype = {
    "BOOL": torch.bool,
    "U8": torch.uint8,
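The hunk above removes `load_sharded_checkpoint` from `modeling_utils` (its autodoc entry disappears in the first hunk). For reference, a local sharded checkpoint folder can still be loaded through `from_pretrained`, which resolves the index file and streams the shards itself; the path below is a placeholder:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/path/to/local/sharded/checkpoint")
```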
@@ -906,18 +853,11 @@ def update_key_name(keys):

def _get_resolved_checkpoint_files(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    subfolder: str,
    variant: Optional[str],
    gguf_file: Optional[str],
    use_safetensors: Optional[bool],
    cache_dir: str,
    force_download: bool,
    proxies: Optional[dict[str, str]],
    local_files_only: bool,
    token: Optional[Union[str, bool]],
    download_kwargs: DownloadKwargs,
    user_agent: dict,
    revision: str,
    commit_hash: Optional[str],
    is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
    transformers_explicit_filename: Optional[str] = None,
) -> tuple[Optional[list[str]], Optional[dict]]:
@@ -925,6 +865,24 @@ def _get_resolved_checkpoint_files(
    checkpoints are sharded.
    This function will download the data if necessary.
    """
    cache_dir = download_kwargs.get("cache_dir")
    force_download = download_kwargs.get("force_download", False)
    proxies = download_kwargs.get("proxies")
    local_files_only = download_kwargs.get("local_files_only", False)
    token = download_kwargs.get("token")
    revision = download_kwargs.get("revision") or "main"
    subfolder = download_kwargs.get("subfolder", "")
    commit_hash = download_kwargs.get("commit_hash")
    if transformers_explicit_filename is not None:
        if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
            ".safetensors.index.json"
        ):
            raise ValueError(
                "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
                "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
                f"{transformers_explicit_filename}"
            )

    is_sharded = False

    if pretrained_model_name_or_path is not None and gguf_file is None:
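The `DownloadKwargs` bundle replaces the long list of individual download arguments here. Its definition is not shown in this diff; judging from the `.get()` calls above it behaves like a dict with the standard Hub download keys, roughly:

```python
# Illustrative only: keys inferred from the call sites above, values are plausible defaults.
download_kwargs = {
    "cache_dir": None,
    "force_download": False,
    "proxies": None,
    "local_files_only": False,
    "token": None,
    "revision": None,
    "subfolder": "",
    "commit_hash": None,
}
revision = download_kwargs.get("revision") or "main"
```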
@@ -4212,6 +4170,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

        return init_contexts

    def set_use_kernels(self, use_kernels, kernel_config):
        if use_kernels:
            if not is_kernels_available():
                raise ValueError(
                    "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
                )
            from kernels import use_kernel_mapping

            if kernel_config is not None and isinstance(kernel_config, KernelConfig):
                # This will make sure the mapping is valid, and the layers are registered in the model
                kernel_config.sanitize_kernel_mapping(self)

                # This will create a compatible mapping for the model with the kernels library
                kernel_config.create_compatible_mapping(self)

                # This is a context manager to override the default kernel mapping
                # We are calling kernelize inside this context manager using the use_kernels setter
                with use_kernel_mapping(kernel_config.kernel_mapping):
                    self.use_kernels = True
            # We use the default kernel mapping in .integrations.hub_kernels
            else:
                self.use_kernels = True
        else:
            self.use_kernels = False

    @classmethod
    @restore_default_dtype
    def from_pretrained(
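Usage sketch for the kernels switch consolidated in `set_use_kernels` (model id is an example; the optional `kernels` package must be installed):

```python
from transformers import AutoModelForCausalLM

# With no kernel_config, the default mapping from .integrations.hub_kernels is used,
# exactly as the else-branch of set_use_kernels above does.
model = AutoModelForCausalLM.from_pretrained("gpt2", use_kernels=True)
```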
@@ -4431,7 +4414,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        state_dict = kwargs.pop("state_dict", None)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        dtype = kwargs.pop("dtype", None)
@@ -4457,7 +4439,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        kernel_config = kwargs.pop("kernel_config", None)

        key_mapping = kwargs.pop("key_mapping", None)
        # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
        # Load models with key mapping
        if key_mapping is None and any(
            allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
        ):
@@ -4467,32 +4449,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            tp_plan = "auto"

        # Not used anymore -- remove them from the kwargs
        _ = kwargs.pop("mirror", None)
        _ = kwargs.pop("_fast_init", None)
        _ = kwargs.pop("low_cpu_mem_usage", None)
        _ = kwargs.pop("from_tf", None)
        _ = kwargs.pop("from_flax", None)
        _ = kwargs.pop("offload_state_dict", None)
        for name in ["mirror", "_fast_init", "low_cpu_mem_usage", "from_tf", "from_flax", "offload_state_dict"]:
            _ = kwargs.pop(name, None)

        # For BC on torch_dtype argument
        if torch_dtype is not None:
            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
            # If both kwargs are provided, use `dtype`
            dtype = dtype if dtype is not None else torch_dtype

        if is_offline_mode() and not local_files_only:
            local_files_only = True

        download_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
            "subfolder": subfolder,
        }
        download_kwargs_with_commit = {**download_kwargs, "commit_hash": commit_hash}

        if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
            raise ValueError(
                "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
            )
        if tp_size is not None and tp_plan is None:
            raise ValueError("tp_plan has to be set when tp_size is passed.")
        if tp_plan is not None and tp_plan != "auto":
            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
        if tp_plan is not None and device_map is not None:
            raise ValueError(
                "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
            )

        if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")):
            logger.info(
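A small usage note on the dtype handling above (model id is an example): `torch_dtype` still works but only triggers a deprecation warning, and an explicit `dtype` wins when both are passed.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", dtype=torch.bfloat16)        # preferred
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)  # deprecated, still accepted
```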
@@ -4501,221 +4482,84 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                ": PartialState().process_index} where PartialState comes from accelerate library"
            )

        # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
        # `device_map` pointing to the correct device
        if tp_plan is not None:
            if device_mesh is None:
                tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
            else:
                if device_mesh.ndim > 1:
                    if "tp" not in device_mesh.mesh_dim_names:
                        raise ValueError(
                            "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
                            "Please provide a valid `device_mesh`."
                        )
                    device_mesh = device_mesh["tp"]
                tp_size = device_mesh.size()
                device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")

            if tp_size is None:
                tp_size = torch.distributed.get_world_size()

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
        if tp_plan is not None or tp_size is not None:  # TP warnings, and setup
            tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(
                tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
            adapter_kwargs["token"] = token

        if gguf_file is not None and not is_accelerate_available():
            raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")

        if commit_hash is None:
            if not isinstance(config, PreTrainedConfig):
                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    CONFIG_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    _raise_exceptions_for_gated_repo=False,
                    _raise_exceptions_for_missing_entries=False,
                    _raise_exceptions_for_connection_errors=False,
                )
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            else:
                commit_hash = getattr(config, "_commit_hash", None)
        if adapter_kwargs is None:
            adapter_kwargs = {}

        if is_peft_available():
            _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)

            if _adapter_model_path is None:
                _adapter_model_path = find_adapter_config_file(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                    **adapter_kwargs,
                )
            if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
                with open(_adapter_model_path, "r", encoding="utf-8") as f:
                    _adapter_model_path = pretrained_model_name_or_path
                    pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
        else:
            _adapter_model_path = None

        # Potentially detect context manager or global device, and use it (only if no device_map was provided)
        if device_map is None and not is_deepspeed_zero3_enabled():
            device_in_context = get_torch_context_manager_or_global_device()
            if device_in_context == torch.device("meta"):
                raise RuntimeError(
                    "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
                    "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
                    "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
                )
            device_map = device_in_context

        # change device_map into a map if we passed an int, a str or a torch.device
        if isinstance(device_map, torch.device):
            device_map = {"": device_map}
        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            try:
                device_map = {"": torch.device(device_map)}
            except RuntimeError:
                raise ValueError(
                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
                )
        elif isinstance(device_map, int):
            if device_map < 0:
                raise ValueError(
                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
                )
            else:
                device_map = {"": device_map}

        if device_map is not None:
            if is_deepspeed_zero3_enabled():
                raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
            if not is_accelerate_available():
                raise ValueError(
                    "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
                    "requires `accelerate`. You can install it with `pip install accelerate`"
                )
        _adapter_model_path, pretrained_model_name_or_path = maybe_load_adapters(
            pretrained_model_name_or_path,
            download_kwargs_with_commit,
            **adapter_kwargs,
        )
        device_map = check_and_set_device_map(device_map)  # warn, error and fix the device map

        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PreTrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                gguf_file=gguf_file,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **download_kwargs,
                **kwargs,
            )
            if "gguf_file" in model_kwargs:
                model_kwargs.pop("gguf_file")
            commit_hash = model_kwargs.pop("_commit_hash", commit_hash)
        else:
            config = copy.deepcopy(config)
            model_kwargs = kwargs
            commit_hash = getattr(config, "_commit_hash", commit_hash)

        download_kwargs_with_commit["commit_hash"] = commit_hash

        # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
        # to correctly redispatch recursively if the kwarg is provided
        if "attn_implementation" in kwargs:
            config._attn_implementation = kwargs.pop("attn_implementation")

        transformers_explicit_filename = getattr(config, "transformers_weights", None)

        if transformers_explicit_filename is not None:
            if not transformers_explicit_filename.endswith(
                ".safetensors"
            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
                raise ValueError(
                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
                    f"{transformers_explicit_filename}"
                )

        hf_quantizer, config, dtype, device_map = get_hf_quantizer(
            config, quantization_config, dtype, device_map, weights_only, user_agent
        )

        if gguf_file is not None and hf_quantizer is not None:
            raise ValueError(
                "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
            )

        if (
            gguf_file
            and device_map is not None
            and ((isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map)
        ):
            raise RuntimeError(
                "One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
                "loaded from GGUF files."
            )
        if gguf_file:
            if hf_quantizer is not None:
                raise ValueError(
                    "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
                )
            if device_map is not None and (
                (isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map
            ):
                raise RuntimeError(
                    "One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
                    "loaded from GGUF files."
                )

        checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            variant=variant,
            gguf_file=gguf_file,
            use_safetensors=use_safetensors,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            download_kwargs=download_kwargs_with_commit,
            user_agent=user_agent,
            revision=revision,
            commit_hash=commit_hash,
            is_remote_code=cls._auto_class is not None,
            transformers_explicit_filename=transformers_explicit_filename,
            transformers_explicit_filename=getattr(config, "transformers_weights", None),
        )

        is_quantized = hf_quantizer is not None
        is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None

        # Just a helpful message in case we try to load safetensors files coming from old Transformers tf/flax classes
        if is_from_file and checkpoint_files[0].endswith(".safetensors"):
            with safe_open(checkpoint_files[0], framework="pt") as f:
                metadata = f.metadata()
            if metadata is not None and metadata.get("format") in ["tf", "flax"]:
                logger.warning(
                    "The safetensors checkpoint found has format `tf` or `flax`. This mean that the keys will very"
                    "likely not match to the model you are trying to load, and will be newly initialized. If it's the case "
                    "another warning will be raised later. Consider converting your checkpoint to the correct format."
                )

        if gguf_file:
            from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -4746,52 +4590,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        # make sure we use the model's config since the __init__ call might have copied it
        config = model.config

        # Find fp32 modules if needed
        keep_in_fp32_modules = []
        # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
        # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
        # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
        if model._keep_in_fp32_modules is not None and (
            dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
        ):
            keep_in_fp32_modules.extend(model._keep_in_fp32_modules)

        if model._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
            keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)

        keep_in_fp32_regex = None
        if keep_in_fp32_modules:
            # We need to match exact layers, so we add either `.` on each side, or start/end of string
            keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

        if hf_quantizer is not None:
        # Regex to keep a fixed dtype
        keep_in_fp32_regex = get_keep_in_fp32_regex(model, hf_quantizer, dtype)
        if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
            hf_quantizer.preprocess_model(
                model=model,
                device_map=device_map,
                keep_in_fp32_modules=model._keep_in_fp32_modules,
                config=config,
                checkpoint_files=checkpoint_files,
                use_kernels=use_kernels,
            )
            # We store the original dtype for quantized models as we cannot easily retrieve it
            # once the weights have been quantized
            # Note that once you have loaded a quantized model, you can't change its dtype so this will
            # remain a single source of truth
            original_dtype = dtype if dtype is not None else torch.get_default_dtype()

            def _assign_original_dtype(module):
                for child in module.children():
                    if isinstance(child, PreTrainedModel):
                        child.config._pre_quantization_dtype = original_dtype
                    _assign_original_dtype(child)

            config._pre_quantization_dtype = original_dtype
            _assign_original_dtype(model)

            # Torchao needs access to all metadata later
            if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
                hf_quantizer.set_metadata(checkpoint_files)

        if _torch_distributed_available and device_mesh is not None:
        if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
            model = distribute_model(model, distributed_config, device_mesh, tp_size)

        # Prepare the full device map
@@ -4820,109 +4631,34 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                weights_only=weights_only,
            )

        # make sure token embedding weights are still tied if needed
        model.tie_weights()

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        # check if using kernels
        if use_kernels:
            if not is_kernels_available():
                raise ValueError(
                    "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
                )
            from kernels import use_kernel_mapping

            if kernel_config is not None and isinstance(kernel_config, KernelConfig):
                # This will make sure the mapping is valid, and the layers are registered in the model
                kernel_config.sanitize_kernel_mapping(model)

                # This will create a compatible mapping for the model with the kernels library
                kernel_config.create_compatible_mapping(model)

                # This is a context manager to override the default kernel mapping
                # We are calling kernelize inside this context manager using the use_kernels setter
                with use_kernel_mapping(kernel_config.kernel_mapping):
                    model.use_kernels = True
            # We use the default kernel mapping in .integrations.hub_kernels
            else:
                model.use_kernels = True
        model.tie_weights()  # make sure token embedding weights are still tied if needed
        model.eval()  # Set model in evaluation mode to deactivate DropOut modules by default
        model.set_use_kernels(use_kernels, kernel_config)

        # If it is a model with generation capabilities, attempt to load generation files (generation config,
        # custom generate function)
        if model.can_generate() and generation_config is not None:
            logger.info("The user-defined `generation_config` will be used to override the default generation config.")
            model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
        elif model.can_generate() and pretrained_model_name_or_path is not None:
            repo_loading_kwargs = {
                "cache_dir": cache_dir,
                "force_download": force_download,
                "proxies": proxies,
                "local_files_only": local_files_only,
                "token": token,
                "revision": revision,
                "subfolder": subfolder,
        if model.can_generate() and hasattr(model, "adjust_generation_fn"):
            model.adjust_generation_fn(
                generation_config,
                from_auto_class,
                from_pipeline,
                pretrained_model_name_or_path,
                **download_kwargs,
                trust_remote_code=trust_remote_code,
                **kwargs,
            }
            # Load generation config
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    _from_auto=from_auto_class,
                    _from_pipeline=from_pipeline,
                    **repo_loading_kwargs,
                )
            except OSError:
                logger.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
                pass
            # Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
            if hasattr(model, "load_custom_generate"):
                try:
                    custom_generate = model.load_custom_generate(
                        pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
                    )
                    model.generate = functools.partial(custom_generate, model=model)
                except OSError:  # there is no custom generate function
                    pass
            )

        # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
        # harm performances)
        # for device_map="auto" : dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
        # harm performances).
        if device_map is not None and device_mesh is None:
            device_map_kwargs = {
                "device_map": device_map,
                "offload_dir": offload_folder,
                "offload_index": offload_index,
                "offload_buffers": offload_buffers,
            }
            if "skip_keys" in inspect.signature(dispatch_model).parameters:
                device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
            # For HQQ method we force-set the hooks for single GPU envs
            if (
                "force_hooks" in inspect.signature(dispatch_model).parameters
                and hf_quantizer is not None
                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
            ):
                device_map_kwargs["force_hooks"] = True
            if (
                hf_quantizer is not None
                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
                and isinstance(device_map, dict)
                and ("cpu" in device_map.values() or "disk" in device_map.values())
            ):
                device_map_kwargs["offload_buffers"] = True

            if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
                dispatch_model(model, **device_map_kwargs)
            accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)

        if hf_quantizer is not None:
            model.hf_quantizer = hf_quantizer
            hf_quantizer.postprocess_model(model, config=config)
            hf_quantizer.postprocess_model(model, config=config)  # usually a no-op

        if _adapter_model_path is not None:
            adapter_kwargs["key_mapping"] = key_mapping
            adapter_kwargs["key_mapping"] = key_mapping  # TODO: Dynamic weight loader for adapters
            model.load_adapter(
                _adapter_model_path,
                adapter_name=adapter_name,
@@ -5081,7 +4817,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            QuantizationMethod.QUARK,
        }

        # Get all the keys of the state dicts that we have to initialize the model
        # Get all the keys of the state dicts that we have to initialize the model with
        if sharded_metadata is not None:
            original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
        elif state_dict is not None:
@ -5152,43 +4888,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        disk_only_shard_files = []
        # Prepare parameters offloading if needed
        if device_map is not None and "disk" in device_map.values():
            if disk_offload_folder is not None:
                os.makedirs(disk_offload_folder, exist_ok=True)
            is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
            if disk_offload_folder is None and not is_offloaded_safetensors:
                raise ValueError(
                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                    " offers the weights in this format."
                )
            if is_offloaded_safetensors:
                param_device_map = expand_device_map(device_map, checkpoint_keys)
                str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
                if sharded_metadata is None:
                    weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
                else:
                    folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
                    # Fix the weight map keys according to the key mapping
                    weight_map = {
                        key_renaming_mapping[k]: v
                        for k, v in sharded_metadata["weight_map"].items()
                        if k in key_renaming_mapping
                    }
                    weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
                # Find potential checkpoints containing only offloaded weights
                disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
                disk_offload_index = {
                    name: {
                        "safetensors_file": file,
                        "weight_name": reverse_key_renaming_mapping[name],
                        "dtype": str_dtype,
                    }
                    for name, file in weight_map.items()
                    if param_device_map[name] == "disk"
                }
            else:
                disk_offload_index = {}

            disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload(
                disk_offload_folder,
                checkpoint_files,
                device_map,
                checkpoint_keys,
                key_renaming_mapping,
                sharded_metadata,
                dtype,
                reverse_key_renaming_mapping,
            )
        # To be able to iterate, even if we don't use it if the state_dict is already provided
        elif state_dict is not None:
            checkpoint_files = [""]
@ -5286,6 +4995,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            missing_keys, unexpected_keys, loading_task_model_from_base_state_dict
        )

        # TODO: separate this in another function: it's not core....
        # All potential warnings/infos
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
@ -5787,19 +5497,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)


def get_disk_only_shard_files(device_map, weight_map):
    """
    Returns the list of shard files containing only weights offloaded to disk.
    """
    files_content = collections.defaultdict(list)
    for weight_name, filename in weight_map.items():
        while len(weight_name) > 0 and weight_name not in device_map:
            weight_name = ".".join(weight_name.split(".")[:-1])
        files_content[filename].append(device_map[weight_name])

    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
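A toy illustration of the helper above (hypothetical inputs; assumes the function is in scope): only shards whose parameters all resolve to `"disk"` in the device map are returned.

device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "disk", "lm_head": "disk"}
weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.1.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}
print(get_disk_only_shard_files(device_map, weight_map))  # -> ['model-00002-of-00002.safetensors']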


class AttentionInterface(GeneralInterface):
    """
    Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
@ -31,6 +31,13 @@ else:
logger = logging.get_logger(__file__)


def _assign_original_dtype(module, original_dtype):
    for child in module.children():
        if isinstance(child, PreTrainedModel):
            child.config._pre_quantization_dtype = original_dtype
        _assign_original_dtype(child, original_dtype)


class HfQuantizer(ABC):
    """
    Abstract class of the HuggingFace quantizer. For now, it supports quantizing HF transformers models for inference and/or quantization.
@ -206,7 +213,7 @@ class HfQuantizer(ABC):
        "updates the tp plan for the scales"
        return config

    def preprocess_model(self, model: "PreTrainedModel", **kwargs):
    def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
        """
        Setting model attributes and/or converting model before weights loading. At this point
        the model should be initialized on the meta device so you can freely manipulate the skeleton
@ -222,7 +229,18 @@ class HfQuantizer(ABC):
        model.quantization_method = self.quantization_config.quant_method
        if self.pre_quantized:
            self._convert_model_for_quantization(model)
        return self._process_model_before_weight_loading(model, **kwargs)
        self._process_model_before_weight_loading(model, **kwargs)

        # We store the original dtype for quantized models as we cannot easily retrieve it
        # once the weights have been quantized
        # Note that once you have loaded a quantized model, you can't change its dtype so this will
        # remain a single source of truth
        original_dtype = dtype if dtype is not None else torch.get_default_dtype()
        config._pre_quantization_dtype = original_dtype
        _assign_original_dtype(model, original_dtype)

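A minimal sketch (the helper below is hypothetical, not part of this diff) of how downstream code can rely on the attribute stored above to recover the pre-quantization precision, e.g. when materializing new, non-quantized modules:

from types import SimpleNamespace

import torch


def linear_in_original_dtype(config, in_features, out_features):
    # Fall back to the default dtype when the model was never quantized.
    dtype = getattr(config, "_pre_quantization_dtype", torch.get_default_dtype())
    return torch.nn.Linear(in_features, out_features, dtype=dtype)


cfg = SimpleNamespace(_pre_quantization_dtype=torch.bfloat16)  # stand-in for a PretrainedConfig
print(linear_in_original_dtype(cfg, 8, 2).weight.dtype)  # -> torch.bfloat16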
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def postprocess_model(self, model: "PreTrainedModel", **kwargs):
        """
@ -79,9 +79,6 @@ class AqlmHfQuantizer(HfQuantizer):
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self) -> bool:
        aqlm_supports_training = version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.2")
@ -69,9 +69,6 @@ class BitNetHfQuantizer(HfQuantizer):
                "This is not supported. Please remove the CPU or disk device from the device_map."
            )

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
@ -137,9 +137,6 @@ class EetqHfQuantizer(HfQuantizer):
        module._buffers[tensor_name] = new_value.to(target_device)
        module.register("weight_scales", weight_scale.to(target_device))

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
@ -192,9 +192,6 @@ class FbgemmFp8HfQuantizer(HfQuantizer):

        del param_name

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
@ -167,9 +167,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):

        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        from ..integrations import FP8Linear

@ -140,9 +140,6 @@ class FPQuantHfQuantizer(HfQuantizer):
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        from fp_quant import FPQuantLinear

@ -157,9 +157,6 @@ class QuantoHfQuantizer(HfQuantizer):
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_trainable(self) -> bool:
        return True

@ -88,9 +88,6 @@ class QuarkHfQuantizer(HfQuantizer):

        _load_parameter_into_model(model, param_name, param.to(param_device))

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def is_serializable(self, safe_serialization=None):
        return False

@ -79,9 +79,6 @@ class SpQRHfQuantizer(HfQuantizer):
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self):
        return False

@ -350,6 +350,22 @@ class TorchAoHfQuantizer(HfQuantizer):

            quantize_(module, self.quantization_config.get_apply_tensor_subclass())

    def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
        """
        Setting model attributes and/or converting model before weights loading. At this point
        the model should be initialized on the meta device so you can freely manipulate the skeleton
        of the model in order to replace modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.

        Args:
            model (`~transformers.PreTrainedModel`):
                The model to quantize
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along `_process_model_before_weight_loading`.
        """
        super().preprocess_model(model, config, dtype, checkpoint_files, **kwargs)
        # Torchao needs access to all metadata later
        model.set_metadata(checkpoint_files)

    def _process_model_after_weight_loading(self, model, **kwargs):
        """No process required for torchao quantized model"""
        if self.quantization_config.quant_type == "autoquant":
@ -88,9 +88,6 @@ class VptqHfQuantizer(HfQuantizer):
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self) -> bool:
        return False

@ -65,7 +65,7 @@ from .image_processing_utils import BaseImageProcessor
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
@ -123,6 +123,7 @@ from .trainer_utils import (
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    load_sharded_checkpoint,
    neftune_post_forward_hook,
    number_of_arguments,
    seed_worker,
@ -19,18 +19,23 @@ import copy
import functools
import gc
import inspect
import json
import os
import random
import re
import threading
import time
from collections.abc import Callable
from functools import partial
from typing import Any, NamedTuple, Optional, Union

import numpy as np

from .utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    WEIGHTS_INDEX_NAME,
    ExplicitEnum,
    check_torch_load_is_safe,
    is_psutil_available,
    is_torch_available,
    is_torch_cuda_available,
@ -47,6 +52,7 @@ from .utils import (

if is_torch_available():
    import torch
    from safetensors.torch import load_file as safe_load_file


def seed_worker(worker_id: int, num_workers: int, rank: int):
@ -873,3 +879,79 @@ def check_target_module_exists(optim_target_modules, key: str, return_is_regex:
        return target_module_found, is_regex

    return target_module_found


def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        prefer_safe (`bool`, *optional*, defaults to `True`):
            If both safetensors and PyTorch save files are present in the checkpoint and `prefer_safe` is True, the
            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys
    """
    # Load the index
    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)

    index_present = os.path.isfile(index_file)
    safe_index_present = os.path.isfile(safe_index_file)

    if not index_present and not safe_index_present:
        filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")

    load_safe = safe_index_present and (prefer_safe or not index_present)
    load_index = safe_index_file if load_safe else index_file

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))

    # If strict=True, error before loading any of the state dicts.
    # TODO: Here, update the weight map with the config.dynamic_weight_conversion
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    if load_safe:
        loader = safe_load_file
    else:
        check_torch_load_is_safe()
        loader = partial(torch.load, map_location="cpu", weights_only=True)

    for shard_file in shard_files:
        state_dict = loader(os.path.join(folder, shard_file))
        model.load_state_dict(state_dict, strict=False)

        # Make sure memory is freed before we load the next state dict.
        del state_dict
        gc.collect()

    # Return the same thing as the PyTorch load_state_dict function.
    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
@ -23,7 +23,7 @@ import tempfile
import warnings
from concurrent import futures
from pathlib import Path
from typing import Optional, Union
from typing import Optional, TypedDict, Union
from urllib.parse import urlparse
from uuid import uuid4

@ -75,6 +75,18 @@ CHAT_TEMPLATE_DIR = "additional_chat_templates"

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class DownloadKwargs(TypedDict, total=False):
    cache_dir: Optional[Union[str, os.PathLike]]
    force_download: bool
    proxies: Optional[dict[str, str]]
    local_files_only: bool
    token: Optional[Union[str, bool]]
    revision: Optional[str]
    subfolder: str
    commit_hash: Optional[str]
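A small sketch of how the new TypedDict can be used (the consuming helper below is hypothetical, not part of this diff): it only documents the familiar hub-download arguments for type checkers, and remains a plain `dict` at runtime.

def describe_download(kwargs: DownloadKwargs) -> str:
    # Hypothetical helper: the TypedDict gives editors/type-checkers key and value types,
    # while the object itself is an ordinary dict.
    revision = kwargs.get("revision") or "main"
    return f"revision={revision}, offline={kwargs.get('local_files_only', False)}"


dl_kwargs: DownloadKwargs = {"revision": "refs/pr/1", "local_files_only": False, "subfolder": ""}
print(describe_download(dl_kwargs))  # -> "revision=refs/pr/1, offline=False"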

_is_offline_mode = huggingface_hub.constants.HF_HUB_OFFLINE