mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
6 Commits
v4.52.2
...
hqq_serial
Author | SHA1 | Date | |
---|---|---|---|
a8704d266e | |||
bc9cb55d8d | |||
f2ea032e40 | |||
75dfe0a9c6 | |||
fa8a9f55c0 | |||
ff40f1a9e1 |
@ -97,7 +97,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
|
||||
|
||||
# Convert quantization_config to layer-wise config
|
||||
skip_modules = quantization_config.skip_modules
|
||||
quant_config = quantization_config.to_dict()
|
||||
quant_config = quantization_config.quant_config
|
||||
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
|
||||
|
||||
if any(key in linear_tags for key in quant_config.keys()):
|
||||
@ -113,7 +113,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
|
||||
)
|
||||
|
||||
# We store quantization config as linear_tag -> hqq quant config
|
||||
model.config.quantization_config = patch_params
|
||||
model.config.quantization_config = {
|
||||
"quant_config": quant_config,
|
||||
"quant_method": quantization_config.quant_method,
|
||||
"skip_modules": skip_modules,
|
||||
}
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning("No linear modules were found in your model for quantization.")
|
||||
|
@ -56,6 +56,7 @@ from .pytorch_utils import ( # noqa: F401
|
||||
prune_linear_layer,
|
||||
)
|
||||
from .quantizers import AutoHfQuantizer, HfQuantizer
|
||||
from .quantizers.quantizer_hqq import HqqHfQuantizer
|
||||
from .quantizers.quantizers_utils import get_module_from_name
|
||||
from .safetensors_conversion import auto_conversion
|
||||
from .utils import (
|
||||
@ -858,8 +859,9 @@ def _load_state_dict_into_meta_model(
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
|
||||
for param_name, param in state_dict.items():
|
||||
# print('param_name', param_name, param_name in loaded_state_dict_keys, param_name in expected_keys)
|
||||
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
|
||||
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
|
||||
if param_name not in loaded_state_dict_keys: # or param_name not in expected_keys: #TODO @mobicham
|
||||
continue
|
||||
|
||||
if param_name.startswith(start_prefix):
|
||||
@ -891,12 +893,21 @@ def _load_state_dict_into_meta_model(
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
old_param = model
|
||||
splits = param_name.split(".")
|
||||
for split in splits:
|
||||
old_param = getattr(old_param, split)
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
# TODO @mobicham: We need this for Hqq Quantizer otherwise it would break because state_dict fields (W_q, etc.) are not in nn.Linear
|
||||
check_old_param = True
|
||||
if is_quantized:
|
||||
if isinstance(hf_quantizer, HqqHfQuantizer):
|
||||
check_old_param, old_param = False, None
|
||||
|
||||
if check_old_param:
|
||||
old_param = model
|
||||
splits = param_name.split(".")
|
||||
for split in splits:
|
||||
old_param = getattr(old_param, split)
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
if old_param is not None:
|
||||
if dtype is None:
|
||||
param = param.to(old_param.dtype)
|
||||
@ -3725,6 +3736,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
from_pt = not (from_tf | from_flax)
|
||||
|
||||
# load pt weights early so that we know which dtype to init the model under
|
||||
|
||||
if from_pt:
|
||||
if not is_sharded and state_dict is None:
|
||||
# Time to load the checkpoint
|
||||
@ -4181,7 +4193,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
value = torch.empty(*param.size(), dtype=target_dtype)
|
||||
if (
|
||||
not is_quantized
|
||||
or getattr(hf_quantizer, "requires_parameters_quantization", False)
|
||||
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
|
||||
or not hf_quantizer.check_quantized_param(
|
||||
model, param_value=value, param_name=key, state_dict={}
|
||||
)
|
||||
|
@ -91,6 +91,14 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
else:
|
||||
self.using_multi_gpu = len(set(device_map.values())) > 1
|
||||
|
||||
def update_missing_keys(
|
||||
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
|
||||
) -> List[str]:
|
||||
if self.pre_quantized:
|
||||
return [key for key in missing_keys if ("weight" not in key)]
|
||||
else:
|
||||
return missing_keys
|
||||
|
||||
def check_quantized_param(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
@ -99,9 +107,18 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
|
||||
if self.pre_quantized:
|
||||
return (
|
||||
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
|
||||
and tensor_name != "weight"
|
||||
and tensor_name != "bias"
|
||||
)
|
||||
else:
|
||||
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
@ -122,13 +139,54 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
layer_name = param_name.replace(".weight", "").replace(".bias", "")
|
||||
layer_name = ".".join(param_name.split(".")[:-1])
|
||||
parent_module = find_parent(model, layer_name)
|
||||
node = layer_name.split(".")[-1]
|
||||
|
||||
# Step 0: set module state_dict
|
||||
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
|
||||
# print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj
|
||||
|
||||
# set module state_dict
|
||||
module_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if layer_name + "." in k:
|
||||
module_state_dict[k.split(".")[-1]] = v
|
||||
if unexpected_keys is not None and k in unexpected_keys:
|
||||
unexpected_keys.remove(k)
|
||||
|
||||
if self.pre_quantized:
|
||||
if isinstance(module, HQQLinear):
|
||||
return
|
||||
else:
|
||||
hqq_layer = HQQLinear(
|
||||
linear_layer=None,
|
||||
quant_config=None, # module.quant_config
|
||||
compute_dtype=self.torch_dtype,
|
||||
device=target_device,
|
||||
)
|
||||
|
||||
try:
|
||||
hqq_layer.load_state_dict(module_state_dict)
|
||||
except Exception:
|
||||
# TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this?
|
||||
# Currently setting a fake layer so that loading doesn't break
|
||||
print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys())
|
||||
hqq_layer = HQQLinear(
|
||||
torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False),
|
||||
module.quant_config,
|
||||
compute_dtype=self.torch_dtype,
|
||||
device=target_device,
|
||||
del_orig=True,
|
||||
)
|
||||
|
||||
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
|
||||
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
|
||||
|
||||
if self.using_multi_gpu:
|
||||
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
|
||||
|
||||
setattr(parent_module, node, hqq_layer)
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
|
||||
# Step 1: populate module with weight/bias from module state dict
|
||||
for key in module_state_dict:
|
||||
@ -136,7 +194,6 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
|
||||
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
|
||||
# directly doesn't work.
|
||||
|
||||
if hasattr(module, "quant_config"):
|
||||
hqq_layer = HQQLinear(
|
||||
module,
|
||||
@ -193,7 +250,7 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
|
@ -258,12 +258,26 @@ class HqqConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config: Dict[str, Any]):
|
||||
"""
|
||||
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
|
||||
"""
|
||||
instance = cls()
|
||||
instance.quant_config = config["quant_config"]
|
||||
instance.skip_modules = config["skip_modules"]
|
||||
return instance
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
return self.quant_config
|
||||
return {
|
||||
"quant_config": self.quant_config,
|
||||
"quant_method": self.quant_method,
|
||||
"skip_modules": self.skip_modules,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
config_dict = self.to_dict()
|
||||
|
Reference in New Issue
Block a user