Mirror of https://github.com/huggingface/transformers.git, synced 2025-11-04 20:14:36 +08:00
Compare commits: v4.56.2...hqq_serial (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a8704d266e | |
| | bc9cb55d8d | |
| | f2ea032e40 | |
| | 75dfe0a9c6 | |
| | fa8a9f55c0 | |
| | ff40f1a9e1 | |
@@ -97,7 +97,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
     # Convert quantization_config to layer-wise config
     skip_modules = quantization_config.skip_modules
-    quant_config = quantization_config.to_dict()
+    quant_config = quantization_config.quant_config
     linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
 
     if any(key in linear_tags for key in quant_config.keys()):
@@ -113,7 +113,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
     )
 
     # We store quantization config as linear_tag -> hqq quant config
-    model.config.quantization_config = patch_params
+    model.config.quantization_config = {
+        "quant_config": quant_config,
+        "quant_method": quantization_config.quant_method,
+        "skip_modules": skip_modules,
+    }
 
     if not has_been_replaced:
         logger.warning("No linear modules were found in your model for quantization.")
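With this change, `prepare_for_hqq_linear` stores a plain dictionary on `model.config.quantization_config` instead of the per-layer `patch_params`. For reference, the stored structure would look roughly like the sketch below; the three top-level keys come from the diff, while the nested values (4-bit, group size 64, skipping `lm_head`) are illustrative assumptions drawn from typical HQQ settings.

```python
# Illustrative shape of the config written by prepare_for_hqq_linear on this branch.
# Top-level keys follow the diff; nested values are assumed example HQQ settings.
quantization_config = {
    "quant_config": {"weight_quant_params": {"nbits": 4, "group_size": 64}},  # from hqq's BaseQuantizeConfig (assumed)
    "quant_method": "hqq",
    "skip_modules": ["lm_head"],
}
```

Keeping the stored config in this flat form is what allows `HqqConfig.from_dict` (added further down) to rebuild the config when the model is reloaded.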
@@ -56,6 +56,7 @@ from .pytorch_utils import (  # noqa: F401
     prune_linear_layer,
 )
 from .quantizers import AutoHfQuantizer, HfQuantizer
+from .quantizers.quantizer_hqq import HqqHfQuantizer
 from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
@@ -858,8 +859,9 @@ def _load_state_dict_into_meta_model(
     is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
 
     for param_name, param in state_dict.items():
+        # print('param_name', param_name, param_name in loaded_state_dict_keys, param_name in expected_keys)
         # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
-        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
+        if param_name not in loaded_state_dict_keys:  # or param_name not in expected_keys: #TODO @mobicham
             continue
 
         if param_name.startswith(start_prefix):
@@ -891,12 +893,21 @@ def _load_state_dict_into_meta_model(
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
-        old_param = model
-        splits = param_name.split(".")
-        for split in splits:
-            old_param = getattr(old_param, split)
-            if old_param is None:
-                break
+
+        # TODO @mobicham: We need this for Hqq Quantizer otherwise it would break because state_dict fields (W_q, etc.) are not in nn.Linear
+        check_old_param = True
+        if is_quantized:
+            if isinstance(hf_quantizer, HqqHfQuantizer):
+                check_old_param, old_param = False, None
+
+        if check_old_param:
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+                if old_param is None:
+                    break
 
         if old_param is not None:
             if dtype is None:
                 param = param.to(old_param.dtype)
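The `check_old_param` bypass exists because a pre-quantized HQQ checkpoint carries tensors such as `W_q` or `meta` that have no matching attribute on the `nn.Linear` placeholder still sitting in the meta model, so the attribute walk would fail. A minimal sketch of the failure mode, with a hypothetical module path:

```python
import torch

# Stand-in for a meta model that still holds an nn.Linear placeholder.
model = torch.nn.Sequential()
model.add_module("q_proj", torch.nn.Linear(8, 8))

param_name = "q_proj.W_q"  # key emitted by a serialized HQQLinear; nn.Linear has no such attribute

old_param = model
for split in param_name.split("."):
    # The original code calls getattr without a default, which would raise AttributeError on "W_q";
    # a default is used here only to show the lookup coming up empty.
    old_param = getattr(old_param, split, None)
    if old_param is None:
        break

print(old_param)  # None -> hence the HqqHfQuantizer-specific bypass above
```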
@@ -3725,6 +3736,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         from_pt = not (from_tf | from_flax)
 
         # load pt weights early so that we know which dtype to init the model under
+
         if from_pt:
             if not is_sharded and state_dict is None:
                 # Time to load the checkpoint
@@ -4181,7 +4193,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     value = torch.empty(*param.size(), dtype=target_dtype)
                     if (
                         not is_quantized
-                        or getattr(hf_quantizer, "requires_parameters_quantization", False)
+                        or (getattr(hf_quantizer, "requires_parameters_quantization", False))
                         or not hf_quantizer.check_quantized_param(
                             model, param_value=value, param_name=key, state_dict={}
                         )
@@ -91,6 +91,14 @@ class HqqHfQuantizer(HfQuantizer):
             else:
                 self.using_multi_gpu = len(set(device_map.values())) > 1
 
+    def update_missing_keys(
+        self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
+    ) -> List[str]:
+        if self.pre_quantized:
+            return [key for key in missing_keys if ("weight" not in key)]
+        else:
+            return missing_keys
+
     def check_quantized_param(
         self,
         model: "PreTrainedModel",
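The new `update_missing_keys` hook filters the missing-key report when loading a pre-quantized checkpoint: the original `weight` tensors are legitimately absent because they were replaced by serialized HQQ tensors (`W_q`, `meta`, ...). A rough sketch with made-up key names shows the effect of the name-based filter:

```python
# Hypothetical missing keys reported while loading a pre-quantized HQQ checkpoint.
missing_keys = [
    "model.layers.0.self_attn.q_proj.weight",  # expected: replaced by W_q/meta in the checkpoint
    "model.rotary_emb.inv_freq",               # unrelated key, still worth reporting
]

pre_quantized = True
filtered = [key for key in missing_keys if "weight" not in key] if pre_quantized else missing_keys
print(filtered)  # ['model.rotary_emb.inv_freq']
```

Note that the filter is purely name-based ("weight" as a substring), so it silences any missing key containing "weight", not only those belonging to quantized linear layers.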
@@ -99,9 +107,18 @@ class HqqHfQuantizer(HfQuantizer):
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
+        if is_hqq_available():
+            from hqq.core.quantize import HQQLinear
         module, tensor_name = get_module_from_name(model, param_name)
 
-        return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+        if self.pre_quantized:
+            return (
+                (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
+                and tensor_name != "weight"
+                and tensor_name != "bias"
+            )
+        else:
+            return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
 
     def create_quantized_param(
         self,
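In the pre-quantized branch, `check_quantized_param` now claims every tensor on a Linear/HQQLinear module except the plain `weight` and `bias`, so serialized HQQ fields are routed to `create_quantized_param` instead of the default loader. A simplified sketch of which tensor names end up handled by the quantizer (the module-type check is omitted here; the key names follow HQQLinear's serialization):

```python
# Simplified: which tensor names the quantizer claims, by loading mode.
def claimed_by_quantizer(tensor_name: str, pre_quantized: bool) -> bool:
    if pre_quantized:
        return tensor_name not in ("weight", "bias")
    return tensor_name == "weight"

for name in ("weight", "bias", "W_q", "meta"):
    print(name, claimed_by_quantizer(name, pre_quantized=True))
# weight False, bias False, W_q True, meta True
```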
@@ -122,13 +139,54 @@ class HqqHfQuantizer(HfQuantizer):
             from hqq.core.quantize import HQQLinear
 
         module, tensor_name = get_module_from_name(model, param_name)
 
-        layer_name = param_name.replace(".weight", "").replace(".bias", "")
+        layer_name = ".".join(param_name.split(".")[:-1])
         parent_module = find_parent(model, layer_name)
         node = layer_name.split(".")[-1]
 
-        # Step 0: set module state_dict
-        module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
+        # print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj
+
+        # set module state_dict
+        module_state_dict = {}
+        for k, v in state_dict.items():
+            if layer_name + "." in k:
+                module_state_dict[k.split(".")[-1]] = v
+                if unexpected_keys is not None and k in unexpected_keys:
+                    unexpected_keys.remove(k)
+
+        if self.pre_quantized:
+            if isinstance(module, HQQLinear):
+                return
+            else:
+                hqq_layer = HQQLinear(
+                    linear_layer=None,
+                    quant_config=None,  # module.quant_config
+                    compute_dtype=self.torch_dtype,
+                    device=target_device,
+                )
+
+                try:
+                    hqq_layer.load_state_dict(module_state_dict)
+                except Exception:
+                    # TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this?
+                    # Currently setting a fake layer so that loading doesn't break
+                    print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys())
+                    hqq_layer = HQQLinear(
+                        torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False),
+                        module.quant_config,
+                        compute_dtype=self.torch_dtype,
+                        device=target_device,
+                        del_orig=True,
+                    )
+
+                if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+                    hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+                if self.using_multi_gpu:
+                    hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+                setattr(parent_module, node, hqq_layer)
+                torch.cuda.empty_cache()
+                return
+
         # Step 1: populate module with weight/bias from module state dict
         for key in module_state_dict:
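The pre-quantized path above rebuilds an `HQQLinear` shell and feeds it the per-layer slice of the state dict via `load_state_dict`. The sketch below shows the intended round trip for a single layer, assuming the `hqq` package and a CUDA device are available; the exact set of serialized keys (`W_q`, `meta`, optionally `bias`) depends on the hqq version, so treat it as an assumption rather than a guaranteed format.

```python
# Sketch, assuming `hqq` is installed and a CUDA device is available.
import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

linear = torch.nn.Linear(128, 128, bias=False)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)  # example settings

# Quantize once, then serialize the layer's tensors (roughly what save_pretrained would store).
hqq_layer = HQQLinear(linear, quant_config, compute_dtype=torch.float16, device="cuda", del_orig=True)
serialized = hqq_layer.state_dict()  # typically W_q, meta, (bias), ... depending on hqq version

# Restore into an empty shell, mirroring the pre-quantized branch of create_quantized_param.
restored = HQQLinear(linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cuda")
restored.load_state_dict(serialized)
```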
@@ -136,7 +194,6 @@ class HqqHfQuantizer(HfQuantizer):
 
         # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
         # directly doesn't work.
-
         if hasattr(module, "quant_config"):
             hqq_layer = HQQLinear(
                 module,
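The "setattr on the parent" pattern referenced in the Step 2 comment is the usual way to swap a submodule in PyTorch, since rebinding a local reference to the child does nothing to the model tree. A minimal illustration with a hypothetical module path:

```python
import torch

model = torch.nn.Sequential()
model.add_module("mlp", torch.nn.Sequential())
model.mlp.add_module("down_proj", torch.nn.Linear(16, 16))

layer_name = "mlp.down_proj"            # hypothetical path derived from param_name
parent = model.get_submodule("mlp")     # plays the role of find_parent(model, layer_name)
node = layer_name.split(".")[-1]

setattr(parent, node, torch.nn.Identity())  # replace the child on its parent, not via a local reference
print(model.mlp.down_proj)  # Identity()
```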
@@ -193,7 +250,7 @@ class HqqHfQuantizer(HfQuantizer):
 
     @property
     def is_serializable(self):
-        return False
+        return True
 
     @property
     def is_trainable(self) -> bool:
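Flipping `is_serializable` to `True` is what unlocks the end-to-end flow this branch is after: quantize with HQQ, `save_pretrained`, and reload. A hedged usage sketch; the model id and quantization settings are placeholders, and it assumes this branch is installed:

```python
# Usage sketch on this branch; model id and settings are placeholders.
import torch
from transformers import AutoModelForCausalLM, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)  # example settings
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # placeholder model id
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

model.save_pretrained("llama3-hqq-4bit")  # now permitted since is_serializable returns True
reloaded = AutoModelForCausalLM.from_pretrained("llama3-hqq-4bit", device_map="cuda")
```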
@@ -258,12 +258,26 @@ class HqqConfig(QuantizationConfigMixin):
         """
         pass
 
+    @classmethod
+    def from_dict(cls, config: Dict[str, Any]):
+        """
+        Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
+        """
+        instance = cls()
+        instance.quant_config = config["quant_config"]
+        instance.skip_modules = config["skip_modules"]
+        return instance
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary. Returns:
             `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
-        return self.quant_config
+        return {
+            "quant_config": self.quant_config,
+            "quant_method": self.quant_method,
+            "skip_modules": self.skip_modules,
+        }
 
     def __repr__(self):
         config_dict = self.to_dict()
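The `from_dict`/`to_dict` pair added above is what lets the HQQ settings survive the `config.json` round trip (`AutoQuantizationConfig.from_dict` dispatches back to `HqqConfig.from_dict` on load). A short sketch of the symmetry on this branch, with illustrative settings:

```python
from transformers import HqqConfig

config = HqqConfig(nbits=4, group_size=64)  # example settings
as_dict = config.to_dict()  # {"quant_config": ..., "quant_method": ..., "skip_modules": ...}

restored = HqqConfig.from_dict(as_dict)  # behaviour of the override added in this branch
assert restored.quant_config == config.quant_config
assert restored.skip_modules == config.skip_modules
```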