Arthur
2025-10-17 10:17:26 +02:00
parent 8ca058d64c
commit a08b927826

@@ -125,6 +125,7 @@ from .utils.import_utils import (
     is_torchdynamo_compiling,
 )
 from .utils.quantization_config import QuantizationMethod
+from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
 
 if is_accelerate_available():
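The conversion mapping import moves to module scope here, so the lookup in `from_pretrained` (third hunk below) no longer needs a function-local import. For orientation, a minimal sketch of the shape this mapping plausibly has — the diff only implies "dict keyed by `model_type`, values usable as `list[WeightConversion]`"; the field names and the example entry are illustrative assumptions:

```python
# Hypothetical sketch of .conversion_mapping — entry contents are assumptions.
from dataclasses import dataclass

@dataclass
class WeightConversion:
    source_pattern: str  # key (or glob) as stored in the checkpoint
    target_pattern: str  # parameter name expected by the model

# Keyed by `config.model_type`, so the lookup below is a plain dict.get().
_checkpoint_conversion_mapping: dict[str, list[WeightConversion]] = {
    "my_model": [WeightConversion("transformer.h.*", "model.layers.*")],  # illustrative
}
```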
@@ -4390,9 +4391,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         commit_hash = getattr(config, "_commit_hash", commit_hash)
         download_kwargs_with_commit["commit_hash"] = commit_hash
 
-        weight_conversion_profile = bool(model_kwargs.pop("weight_conversion_profile", False))
-        profile_kwarg = model_kwargs.pop("profile", None)
+        profile_weight_conversion = kwargs.pop("profile_weight_conversion")
 
         # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
         # to correctly redispatch recursively if the kwarg is provided
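The two older knobs (`weight_conversion_profile`, `profile`) collapse into a single `profile_weight_conversion` kwarg, popped once near the top of `from_pretrained`. Note the pop has no default, so a missing key would raise `KeyError`; presumably the caller always supplies it. A hedged usage sketch — the checkpoint name is a placeholder, and the kwarg's effect (profiling the conversion ops) is inferred from the `profile=` forwarding in the last hunk:

```python
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; profile_weight_conversion is assumed to reach
# **kwargs of from_pretrained (the pop has no default, so it must be present).
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder
    profile_weight_conversion=True,  # forwarded as profile= to convert_state_dict
)
```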
@@ -4406,12 +4405,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         weight_conversions: Optional[list[WeightConversion]] = None
         model_type = getattr(config, "model_type", None)
         if model_type is not None:
-            from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
-            conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)
-            if conversions:
-                weight_conversions = _clone_weight_conversions(conversions)
+            weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)
 
-        profile_weight_conversion = kwargs.pop("profile_weight_conversion")
 
         if gguf_file:
             if hf_quantizer is not None:
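With `_clone_weight_conversions` gone, `weight_conversions` now aliases the list stored in the module-level default mapping (`dict.get` returns `None` for unknown model types, which the `Optional` annotation already allows). The cost of dropping the clone is shared mutable state: any in-place mutation of `weight_conversions` would change the default for every later `from_pretrained` call. A small self-contained illustration of that aliasing (`copy.deepcopy` stands in for whatever `_clone_weight_conversions` did — an assumption):

```python
import copy

defaults = {"my_model": [{"src": "a", "dst": "b"}]}  # stand-in for the global mapping

aliased = defaults.get("my_model")
aliased.append({"src": "c", "dst": "d"})
assert len(defaults["my_model"]) == 2  # mutation leaked into the shared default

cloned = copy.deepcopy(defaults["my_model"])  # what the removed clone presumably avoided
cloned.append({"src": "e", "dst": "f"})
assert len(defaults["my_model"]) == 2  # shared default untouched
```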
@ -4704,16 +4699,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
}
if weight_mapping:
if state_dict is None:
merged_state_dict = {}
for file in checkpoint_files:
merged_state_dict.update(
load_state_dict(file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only)
)
state_dict = merged_state_dict
merged_state_dict = {}
for file in checkpoint_files: # TODO this is sequential but supposed to be fast
merged_state_dict.update(
load_state_dict(file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only)
)
tp_plan = getattr(model, "_tp_plan", None)
new_state_dict, conversion_ops = convert_state_dict(
model, state_dict, weight_mapping, tp_plan, hf_quantizer, profile=profile_weight_conversion
model, merged_state_dict, weight_mapping, tp_plan, hf_quantizer, profile=profile_weight_conversion
)
# Get all the keys of the state dicts that we have to initialize the model with
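The `state_dict is None` guard disappears: whenever a `weight_mapping` exists, the shards are always merged into a fresh `merged_state_dict` on the meta device, and that dict, rather than any caller-supplied `state_dict`, is handed to `convert_state_dict`. A self-contained sketch of the merge pattern, using `torch.load` as a stand-in for the internal `load_state_dict` (the meta `map_location` assumes a PyTorch version that supports it, as the internal loader does):

```python
import torch

def merge_sharded_state_dict(checkpoint_files: list[str]) -> dict[str, torch.Tensor]:
    """Flatten .bin shards into one dict; later shards win on duplicate keys."""
    merged: dict[str, torch.Tensor] = {}
    for file in checkpoint_files:  # sequential, mirroring the TODO in the diff
        # map_location="meta" keeps tensors data-free, so merging is cheap;
        # the real loader also handles safetensors and quantized checkpoints.
        merged.update(torch.load(file, map_location="meta", weights_only=True))
    return merged
```

Because `dict.update` is last-writer-wins, shard order only matters for keys duplicated across files, which well-formed checkpoints avoid.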