Commit: cleanup
Repository: mirror of https://github.com/huggingface/transformers.git
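The hunks below modify `PreTrainedModel` (the file path is not preserved in this extract, but given the class name and line numbers it is presumably `src/transformers/modeling_utils.py`). The cleanup does three things: it hoists the checkpoint-conversion-mapping import to module level, collapses the profiling kwargs into a single `profile_weight_conversion`, and always merges sharded checkpoint files before running weight conversion.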
```diff
@@ -125,6 +125,7 @@ from .utils.import_utils import (
     is_torchdynamo_compiling,
 )
 from .utils.quantization_config import QuantizationMethod
+from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING


 if is_accelerate_available():
```
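The first hunk hoists the conversion-mapping import to module scope, so the later lookup in `from_pretrained` (see the third hunk) no longer re-imports inside the method. A minimal sketch of the lookup pattern this enables; the mapping contents here are hypothetical stand-ins, only the `dict.get` semantics mirror the diff:

```python
# Hypothetical stand-in for the module-level conversion mapping:
# model_type -> list of weight-conversion specs.
DEFAULT_WEIGHT_CONVERSION_MAPPING = {
    "llama": ["<WeightConversion spec>"],  # placeholder entry
}

# dict.get returns None for unknown model types, so callers can branch on
# "no conversions registered" without a try/except, exactly as the third hunk's
# `weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)` does.
print(DEFAULT_WEIGHT_CONVERSION_MAPPING.get("llama"))    # ['<WeightConversion spec>']
print(DEFAULT_WEIGHT_CONVERSION_MAPPING.get("unknown"))  # None
```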
```diff
@@ -4390,9 +4391,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         commit_hash = getattr(config, "_commit_hash", commit_hash)

         download_kwargs_with_commit["commit_hash"] = commit_hash

-        weight_conversion_profile = bool(model_kwargs.pop("weight_conversion_profile", False))
-        profile_kwarg = model_kwargs.pop("profile", None)
-
+        profile_weight_conversion = kwargs.pop("profile_weight_conversion")
         # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
         # to correctly redispatch recursively if the kwarg is provided
```
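The second hunk replaces the two legacy profiling kwargs (`weight_conversion_profile` and `profile`, previously popped from `model_kwargs`) with a single `profile_weight_conversion` popped from `kwargs`. One behavioral detail worth noting, sketched below: `pop` without a default raises `KeyError` when the key is absent, whereas the removed lines supplied defaults. Presumably the caller always seeds this kwarg; the variable names come from the diff, the surrounding scaffolding is illustrative:

```python
# dict.pop semantics relevant to this hunk.
kwargs = {}  # hypothetical: caller did not pass profile_weight_conversion

# The removed code tolerated a missing key via a default:
weight_conversion_profile = bool(kwargs.pop("weight_conversion_profile", False))  # -> False

# The new code, as written, requires the key to be present:
try:
    profile_weight_conversion = kwargs.pop("profile_weight_conversion")
except KeyError:
    # A default (e.g. kwargs.pop("profile_weight_conversion", False)) would avoid this.
    profile_weight_conversion = False
```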
```diff
@@ -4406,12 +4405,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         weight_conversions: Optional[list[WeightConversion]] = None
         model_type = getattr(config, "model_type", None)
         if model_type is not None:
-            from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
-            conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)
-            if conversions:
-                weight_conversions = _clone_weight_conversions(conversions)
+            weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)

-        profile_weight_conversion = kwargs.pop("profile_weight_conversion")

         if gguf_file:
             if hf_quantizer is not None:
```
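The third hunk drops the in-method import (now at module level, per the first hunk), removes the duplicate `profile_weight_conversion` pop (moved earlier, per the second hunk), and stops cloning the mapping entries via `_clone_weight_conversions`. The sketch below shows what that clone bought: with a deep copy, a caller mutating its conversion specs cannot corrupt the shared module-level mapping, while assigning `DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)` directly shares the objects, which is safe only if they are treated as read-only. The dict-based specs and the deep-copy body are assumptions for illustration:

```python
import copy

# Hypothetical mapping and clone helper; the real entries are WeightConversion
# objects, and _clone_weight_conversions' body is not shown in the diff.
SHARED_MAPPING = {"llama": [{"source": "w", "target": "w2"}]}

def _clone_weight_conversions(conversions):
    return copy.deepcopy(conversions)  # assumed deep-copy semantics

# Old behavior: mutations stay local to this load.
cloned = _clone_weight_conversions(SHARED_MAPPING.get("llama"))
cloned[0]["source"] = "mutated"
assert SHARED_MAPPING["llama"][0]["source"] == "w"

# New behavior: the returned list is the shared one; mutating it would
# change the mapping for every subsequent from_pretrained call.
shared = SHARED_MAPPING.get("llama")
assert shared is SHARED_MAPPING["llama"]
```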
```diff
@@ -4704,16 +4699,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         }

         if weight_mapping:
-            if state_dict is None:
-                merged_state_dict = {}
-                for file in checkpoint_files:
-                    merged_state_dict.update(
-                        load_state_dict(file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only)
-                    )
-                state_dict = merged_state_dict
+            merged_state_dict = {}
+            for file in checkpoint_files:  # TODO this is sequential but supposed to be fast
+                merged_state_dict.update(
+                    load_state_dict(file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only)
+                )
             tp_plan = getattr(model, "_tp_plan", None)
             new_state_dict, conversion_ops = convert_state_dict(
-                model, state_dict, weight_mapping, tp_plan, hf_quantizer, profile=profile_weight_conversion
+                model, merged_state_dict, weight_mapping, tp_plan, hf_quantizer, profile=profile_weight_conversion
             )

             # Get all the keys of the state dicts that we have to initialize the model with
```
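The fourth hunk removes the `if state_dict is None:` guard: checkpoint shards are now always loaded (with `map_location="meta"`, i.e. as metadata-only tensors) and merged into `merged_state_dict`, which is then passed to `convert_state_dict` directly instead of being aliased through `state_dict` first. A minimal sketch of the merge step, with a stub standing in for `load_state_dict`:

```python
# Stub for load_state_dict(file, is_quantized=..., map_location="meta",
# weights_only=...); it returns one shard's {parameter_name: tensor} mapping.
def fake_load_state_dict(file):
    return {f"{file}/layer.weight": f"<meta tensor from {file}>"}

checkpoint_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

merged_state_dict = {}
for file in checkpoint_files:  # sequential, as the hunk's TODO comment notes
    # Well-formed shards hold disjoint keys, so update() simply accumulates
    # the full state dict; duplicate keys would be overwritten by later shards.
    merged_state_dict.update(fake_load_state_dict(file))

assert len(merged_state_dict) == 2
```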