Compare commits

...

6 Commits

4 changed files with 105 additions and 18 deletions

View File

@ -97,7 +97,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
quant_config = quantization_config.quant_config
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
if any(key in linear_tags for key in quant_config.keys()):
@ -113,7 +113,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
)
# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
model.config.quantization_config = {
"quant_config": quant_config,
"quant_method": quantization_config.quant_method,
"skip_modules": skip_modules,
}
if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")

View File

@ -56,6 +56,7 @@ from .pytorch_utils import ( # noqa: F401
prune_linear_layer,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizer_hqq import HqqHfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
@ -858,8 +859,9 @@ def _load_state_dict_into_meta_model(
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for param_name, param in state_dict.items():
# print('param_name', param_name, param_name in loaded_state_dict_keys, param_name in expected_keys)
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
if param_name not in loaded_state_dict_keys: # or param_name not in expected_keys: #TODO @mobicham
continue
if param_name.startswith(start_prefix):
@ -891,12 +893,21 @@ def _load_state_dict_into_meta_model(
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if old_param is None:
break
# TODO @mobicham: We need this for Hqq Quantizer otherwise it would break because state_dict fields (W_q, etc.) are not in nn.Linear
check_old_param = True
if is_quantized:
if isinstance(hf_quantizer, HqqHfQuantizer):
check_old_param, old_param = False, None
if check_old_param:
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if old_param is None:
break
if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
@ -3725,6 +3736,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
from_pt = not (from_tf | from_flax)
# load pt weights early so that we know which dtype to init the model under
if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
@ -4181,7 +4193,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)

View File

@ -91,6 +91,14 @@ class HqqHfQuantizer(HfQuantizer):
else:
self.using_multi_gpu = len(set(device_map.values())) > 1
def update_missing_keys(
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
) -> List[str]:
if self.pre_quantized:
return [key for key in missing_keys if ("weight" not in key)]
else:
return missing_keys
def check_quantized_param(
self,
model: "PreTrainedModel",
@ -99,9 +107,18 @@ class HqqHfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if is_hqq_available():
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
if self.pre_quantized:
return (
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
and tensor_name != "weight"
and tensor_name != "bias"
)
else:
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
def create_quantized_param(
self,
@ -122,13 +139,54 @@ class HqqHfQuantizer(HfQuantizer):
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)
layer_name = param_name.replace(".weight", "").replace(".bias", "")
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]
# Step 0: set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
# print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj
# set module state_dict
module_state_dict = {}
for k, v in state_dict.items():
if layer_name + "." in k:
module_state_dict[k.split(".")[-1]] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)
if self.pre_quantized:
if isinstance(module, HQQLinear):
return
else:
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None, # module.quant_config
compute_dtype=self.torch_dtype,
device=target_device,
)
try:
hqq_layer.load_state_dict(module_state_dict)
except Exception:
# TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this?
# Currently setting a fake layer so that loading doesn't break
print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys())
hqq_layer = HQQLinear(
torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False),
module.quant_config,
compute_dtype=self.torch_dtype,
device=target_device,
del_orig=True,
)
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
setattr(parent_module, node, hqq_layer)
torch.cuda.empty_cache()
return
# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
@ -136,7 +194,6 @@ class HqqHfQuantizer(HfQuantizer):
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
# directly doesn't work.
if hasattr(module, "quant_config"):
hqq_layer = HQQLinear(
module,
@ -193,7 +250,7 @@ class HqqHfQuantizer(HfQuantizer):
@property
def is_serializable(self):
return False
return True
@property
def is_trainable(self) -> bool:

View File

@ -258,12 +258,26 @@ class HqqConfig(QuantizationConfigMixin):
"""
pass
@classmethod
def from_dict(cls, config: Dict[str, Any]):
"""
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
"""
instance = cls()
instance.quant_config = config["quant_config"]
instance.skip_modules = config["skip_modules"]
return instance
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
return {
"quant_config": self.quant_config,
"quant_method": self.quant_method,
"skip_modules": self.skip_modules,
}
def __repr__(self):
config_dict = self.to_dict()