Mirror of https://github.com/huggingface/transformers.git
Make quantizers good citizens loading-wise (#41138)
* fix param_needs_quantization
* rewrite most hqq
* clean
* fix
* comment
* remove it from exception of safetensors
* start on bnb 4bits
* post-rebase fix
* make bnb4 bit a good citizen
* remove forgotten print
* make bnb 8bits a good citizen
* better hqq
* fix
* clean
* remove state dict from signature
* switch method
* make torchao a good citizen
* fixes
* fix torchao
* add check
* typo
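At a high level, the commit narrows the quantizer hooks: `param_needs_quantization` no longer receives `param_value` or the full `state_dict`, and `create_quantized_param` receives only the single tensor being loaded. Below is a minimal, self-contained sketch of the new contract; the class and its decision logic are illustrative assumptions, not the real `HfQuantizer` base class.

    import torch

    class ToyQuantizer:
        """Illustrative stand-in for an HfQuantizer subclass after this commit."""

        def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool:
            # New contract: decide purely from the model structure and the param name.
            module_name, _, tensor_name = param_name.rpartition(".")
            return tensor_name == "weight" and "lm_head" not in module_name

        def create_quantized_param(self, model, param_value: torch.Tensor, param_name: str,
                                   target_device: "torch.device | str", **kwargs) -> None:
            # New contract: receives only the single tensor (no full state_dict) and is
            # expected to load the resulting quantized param into `model` itself.
            print(f"would quantize {param_name} ({tuple(param_value.shape)}) on {target_device}")

    quantizer = ToyQuantizer()
    print(quantizer.param_needs_quantization(None, "model.layers.0.mlp.up_proj.weight"))
    quantizer.create_quantized_param(None, torch.randn(4, 4), "model.layers.0.mlp.up_proj.weight", "cpu")

The point of the change is that the loader can stream tensors one at a time (including from safetensors opened lazily) instead of handing every quantizer a fully materialized state dict.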
@@ -159,24 +159,13 @@ class Int8SymmetricQuantizer(HfQuantizer):
pre_quantized=self.pre_quantized,
)

def param_needs_quantization(
self,
model,
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool:
module, tensor_name = get_module_from_name(model, param_name)

if isinstance(module, Int8SymmetricLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
@@ -186,11 +175,18 @@ class Int8SymmetricQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Quantizes weights to INT8 symmetric format.
"""
# Sanity check
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, Int8SymmetricLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")

abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5)

weight_scale = abs_max_per_row / 127.0
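To make the scheme above concrete: the scale is the per-row absolute maximum divided by 127, and the int8 weight is the rounded ratio. A small numeric example follows; the rounding step is an assumption about how the quantizer finishes the job, since the hunk above only shows the scale computation.

    import torch

    weight = torch.tensor([[0.5, -1.27, 0.02],
                           [2.54, 0.1, -0.3]])

    abs_max_per_row = torch.max(torch.abs(weight), dim=1, keepdim=True)[0].clamp(min=1e-5)
    weight_scale = abs_max_per_row / 127.0                      # one scale per output row
    weight_int8 = torch.round(weight / weight_scale).to(torch.int8)

    # Dequantizing recovers the original values up to rounding error.
    print(weight_int8)
    print(weight_int8.float() * weight_scale)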
@ -104,7 +104,6 @@ from .utils import (
|
||||
is_torch_npu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchao_available,
|
||||
logging,
|
||||
)
|
||||
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
|
||||
@ -119,9 +118,6 @@ from .utils.import_utils import (
|
||||
from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
from accelerate.hooks import add_hook_to_module
|
||||
@ -644,6 +640,7 @@ def _infer_parameter_dtype(
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
QuantizationMethod.MXFP4,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
}:
|
||||
return True, None
|
||||
else:
|
||||
@@ -698,13 +695,8 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

is_quantized = hf_quantizer is not None
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}
is_safetensors = shard_file.endswith(".safetensors")
is_meta_state_dict = is_safetensors and not is_hqq_or_bnb_or_ao
is_meta_state_dict = is_safetensors
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) if is_meta_state_dict else None
params_to_load = list(state_dict.keys())
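Since `is_meta_state_dict` is now simply `is_safetensors`, quantized checkpoints stored as safetensors also go through the lazy path: the shard is opened once and tensors are materialized one by one. A rough sketch of that pattern using the public safetensors API; the shard built here is a toy stand-in so the example is self-contained.

    import torch
    from safetensors import safe_open
    from safetensors.torch import save_file

    # Build a tiny stand-in shard.
    save_file({"linear.weight": torch.randn(2, 2), "linear.bias": torch.zeros(2)},
              "toy_shard.safetensors")

    with safe_open("toy_shard.safetensors", framework="pt", device="cpu") as f:
        for name in f.keys():
            tensor = f.get_tensor(name)   # only this tensor is materialized from disk
            print(name, tuple(tensor.shape))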
@ -726,9 +718,7 @@ def _load_state_dict_into_meta_model(
|
||||
)
|
||||
|
||||
if device_mesh is not None:
|
||||
if not is_quantized or not hf_quantizer.param_needs_quantization(
|
||||
model, param, param_name, state_dict, device_map=device_map
|
||||
):
|
||||
if not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
|
||||
# In this case, the param is already on the correct device!
|
||||
shard_and_distribute_module(
|
||||
model,
|
||||
@ -740,7 +730,8 @@ def _load_state_dict_into_meta_model(
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param:
|
||||
else:
|
||||
# we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param
|
||||
sharding_kwargs = {
|
||||
"empty_param": empty_param,
|
||||
"casting_dtype": casting_dtype,
|
||||
@ -753,7 +744,6 @@ def _load_state_dict_into_meta_model(
|
||||
param,
|
||||
param_name,
|
||||
device_mesh.get_local_rank(),
|
||||
state_dict,
|
||||
**sharding_kwargs,
|
||||
)
|
||||
else:
|
||||
@ -775,9 +765,7 @@ def _load_state_dict_into_meta_model(
|
||||
if param_device == "disk":
|
||||
if not is_safetensors:
|
||||
disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
|
||||
elif not is_quantized or not hf_quantizer.param_needs_quantization(
|
||||
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
|
||||
):
|
||||
elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
|
||||
if is_fsdp_enabled():
|
||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||
|
||||
@ -785,7 +773,7 @@ def _load_state_dict_into_meta_model(
|
||||
|
||||
else:
|
||||
# TODO: the naming is misleading, as this call also loads the param
|
||||
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict)
|
||||
hf_quantizer.create_quantized_param(model, param, param_name, param_device)
|
||||
|
||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||
@ -823,7 +811,6 @@ def load_shard_file(args):
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_hqq_or_bnb_or_ao,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
@ -842,22 +829,8 @@ def load_shard_file(args):
|
||||
return [], disk_offload_index
|
||||
|
||||
map_location = "cpu"
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb_or_ao
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["disk"]][0])
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
@ -868,14 +841,7 @@ def load_shard_file(args):
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
|
||||
if shard_file.endswith(".safetensors") and is_safetensors_available():
|
||||
with safe_open(shard_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
|
||||
|
||||
error_msgs = []
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
@ -1384,6 +1350,7 @@ def _find_missing_and_unexpected_keys(
|
||||
|
||||
if hf_quantizer is not None:
|
||||
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
|
||||
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
|
||||
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
@ -4398,9 +4365,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@ -4931,6 +4895,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
config._pre_quantization_dtype = original_dtype
|
||||
_assign_original_dtype(model)
|
||||
|
||||
# Torchao needs access to all metadata later
|
||||
if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
|
||||
hf_quantizer.set_metadata(checkpoint_files)
|
||||
|
||||
if _torch_distributed_available and device_mesh is not None:
|
||||
model = distribute_model(model, distributed_config, device_mesh, tp_size)
|
||||
|
||||
@ -5201,11 +5169,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
}
|
||||
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
QuantizationMethod.TORCHAO,
|
||||
}
|
||||
|
||||
# Get all the keys of the state dicts that we have to initialize the model
|
||||
if sharded_metadata is not None:
|
||||
@ -5338,7 +5301,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_hqq_or_bnb_or_ao,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
@ -5709,12 +5671,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
|
||||
if param.device == torch.device("meta"):
|
||||
value = torch.empty_like(param, dtype=dtype, device="cpu")
|
||||
if not is_quantized or not hf_quantizer.param_needs_quantization(
|
||||
self, param_value=value, param_name=key, state_dict={}
|
||||
):
|
||||
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
|
||||
_load_parameter_into_model(self, key, value)
|
||||
else:
|
||||
hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict)
|
||||
hf_quantizer.create_quantized_param(self, value, key, "cpu")
|
||||
|
||||
def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
|
||||
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
|
||||
|
@ -140,6 +140,9 @@ class HfQuantizer(ABC):
|
||||
"""
|
||||
return expected_keys
|
||||
|
||||
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
|
||||
return unexpected_keys
|
||||
|
||||
def get_special_dtypes_update(self, model, dtype: "torch.dtype") -> dict[str, "torch.dtype"]:
|
||||
"""
|
||||
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
|
||||
@ -175,10 +178,12 @@ class HfQuantizer(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
|
||||
def create_quantized_param(self, *args, **kwargs):
|
||||
"""
|
||||
takes needed components from state_dict and creates quantized param; only applicable if
|
||||
requires_parameters_quantization == True
|
||||
Take needed components from the state_dict (those for which `param_needs_quantization` is True) and create the
quantized param.
It usually also loads the new param directly into the `model`.
|
||||
Note: only applicable if requires_parameters_quantization == True.
|
||||
"""
|
||||
if not self.requires_parameters_quantization:
|
||||
raise AttributeError(
|
||||
|
@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import defaultdict
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -67,6 +68,15 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
if self.quantization_config.llm_int8_skip_modules is not None:
|
||||
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
||||
|
||||
# This describes the additional items that are saved on the state dict (on the params themselves)
|
||||
self.bnb_keys = [
|
||||
f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
|
||||
"absmax",
|
||||
"quant_map",
|
||||
]
|
||||
if self.quantization_config.bnb_4bit_use_double_quant:
|
||||
self.bnb_keys.extend(["nested_absmax", "nested_quant_map"])
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
@ -132,26 +142,17 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
"calculation. You may encounter unexpected behavior, or pass your own device map"
|
||||
)
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
|
||||
return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)]
|
||||
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
|
||||
# Add here check for loaded components' dtypes once serialization is implemented
|
||||
# They are on the params themselves, so we cannot easily extract the module from the name
|
||||
if any(param_name.endswith(x) for x in self.bnb_keys):
|
||||
return True
|
||||
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
|
||||
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
|
||||
# but it would wrongly use uninitialized weight there.
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
module, name = get_module_from_name(model, param_name)
|
||||
return isinstance(module, bnb.nn.Linear4bit) and name != "bias"
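The bitsandbytes 4-bit quant-state tensors are stored on the weight itself, so the quantizer recognizes them purely by suffix, both when filtering `unexpected_keys` and when deciding what needs quantization. A small sketch of that matching; the sample checkpoint keys are made up, while the suffixes follow the `bnb_keys` defined earlier in this file.

    bnb_keys = [
        "quant_state.bitsandbytes__nf4",  # or ...__fp4, depending on bnb_4bit_quant_type
        "absmax",
        "quant_map",
        "nested_absmax",      # only present with double quantization
        "nested_quant_map",   # only present with double quantization
    ]

    checkpoint_keys = [
        "model.layers.0.mlp.up_proj.weight",
        "model.layers.0.mlp.up_proj.weight.absmax",
        "model.layers.0.mlp.up_proj.weight.quant_map",
        "model.layers.0.mlp.up_proj.weight.quant_state.bitsandbytes__nf4",
    ]

    # Quant-state entries ride on the weight itself, so they are matched by suffix.
    quant_state_keys = [k for k in checkpoint_keys if any(k.endswith(x) for x in bnb_keys)]
    plain_keys = [k for k in checkpoint_keys if k not in quant_state_keys]
    print(quant_state_keys)
    print(plain_keys)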
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
@ -159,78 +160,51 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
is_quant_stat = any(param_name.endswith(x) for x in self.bnb_keys)
|
||||
full_name = param_name
|
||||
if is_quant_stat:
|
||||
param_name = (
|
||||
param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
|
||||
)
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if tensor_name not in module._parameters:
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
|
||||
if isinstance(target_device, int) and is_torch_npu_available():
|
||||
target_device = f"npu:{target_device}"
|
||||
if tensor_name == "bias":
|
||||
if param_value is None:
|
||||
new_value = old_value.to(target_device)
|
||||
else:
|
||||
new_value = param_value.to(target_device)
|
||||
|
||||
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
|
||||
module._parameters[tensor_name] = new_value
|
||||
return
|
||||
|
||||
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
|
||||
raise ValueError("this function only loads `Linear4bit components`")
|
||||
if (
|
||||
old_value.device == torch.device("meta")
|
||||
and target_device not in ["meta", torch.device("meta")]
|
||||
and param_value is None
|
||||
):
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
|
||||
|
||||
# construct `new_value` for the module._parameters[tensor_name]:
|
||||
# construct `new_value` for the module._parameters[tensor_name]
|
||||
if self.pre_quantized:
|
||||
# 4bit loading. Collecting components for restoring quantized weight
|
||||
# This can be expanded to make a universal call for any quantized weight loading
|
||||
module_name = param_name.rsplit(".", 1)[0]
|
||||
# Save the states for later quantization when they are all gathered
|
||||
if not hasattr(self, "param_quant_stats"):
|
||||
self.param_quant_stats = defaultdict(dict)
|
||||
self.param_quant_stats[module_name].update({full_name: param_value})
|
||||
|
||||
if not self.is_serializable:
|
||||
raise ValueError(
|
||||
"Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
# We are ready for quantization in this case (note, the +1 is for the weight itself)
|
||||
if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1:
|
||||
param_kwargs = {}
|
||||
if self.is_bnb_supports_quant_storage_module:
|
||||
param_kwargs["module"] = module
|
||||
|
||||
weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight")
|
||||
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight,
|
||||
quantized_stats=self.param_quant_stats[module_name],
|
||||
requires_grad=False,
|
||||
device=target_device,
|
||||
**param_kwargs,
|
||||
)
|
||||
|
||||
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
|
||||
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
|
||||
):
|
||||
raise ValueError(
|
||||
f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
|
||||
)
|
||||
|
||||
quantized_stats = {}
|
||||
for k, v in state_dict.items():
|
||||
if param_name + "." in k:
|
||||
quantized_stats[k] = v
|
||||
|
||||
param_kwargs = {}
|
||||
if self.is_bnb_supports_quant_storage_module:
|
||||
param_kwargs["module"] = module
|
||||
|
||||
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||
data=param_value,
|
||||
quantized_stats=quantized_stats,
|
||||
requires_grad=False,
|
||||
device=target_device,
|
||||
**param_kwargs,
|
||||
)
|
||||
# Set it
|
||||
module._parameters[tensor_name] = new_value
|
||||
# Delete the states
|
||||
del self.param_quant_stats[module_name]
|
||||
else:
|
||||
new_value = param_value.to("cpu")
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
|
||||
# Since weights are saved in the correct "orientation", we skip transposing when loading.
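Because those components now arrive one tensor at a time instead of through a full state dict, the 4-bit path above buffers them per module and only rebuilds the parameter once the weight plus every quant-state entry has been seen (`len(...) == len(self.bnb_keys) + 1`). A stripped-down sketch of that bookkeeping, independent of bitsandbytes; the print stands in for `bnb.nn.Params4bit.from_prequantized`.

    from collections import defaultdict

    expected_suffixes = {"weight", "weight.absmax", "weight.quant_map",
                         "weight.quant_state.bitsandbytes__nf4"}
    param_quant_stats = defaultdict(dict)

    def accept(full_name: str, value) -> None:
        # Strip the known suffix to recover the owning module name, mirroring the rsplit logic above.
        module_name = full_name
        for suffix in expected_suffixes:
            if full_name.endswith("." + suffix):
                module_name = full_name[: -len(suffix) - 1]
                break
        param_quant_stats[module_name][full_name] = value
        # Once the weight and every quant-state entry have arrived, the real code builds the
        # prequantized Params4bit and drops the buffered pieces.
        if len(param_quant_stats[module_name]) == len(expected_suffixes):
            print(f"all components gathered for {module_name}, building the 4-bit param now")
            del param_quant_stats[module_name]

    for key in ["m.weight", "m.weight.absmax", "m.weight.quant_map",
                "m.weight.quant_state.bitsandbytes__nf4"]:
        accept(key, object())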
|
||||
@ -241,7 +215,7 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
kwargs.pop("_is_hf_initialized", None)
|
||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
|
||||
|
||||
module._parameters[tensor_name] = new_value
|
||||
module._parameters[tensor_name] = new_value
|
||||
|
||||
# Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
|
||||
def adjust_max_memory(self, max_memory: dict[str, Union[int, str]]) -> dict[str, Union[int, str]]:
|
||||
@ -313,7 +287,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
model = replace_with_bnb_linear(
|
||||
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
|
||||
)
|
||||
# TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbytes.py to here
|
||||
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -158,27 +158,15 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
|
||||
return torch.int8
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
|
||||
bnb_keys = ["SCB", "weight_format"]
|
||||
return [k for k in unexpected_keys if not any(k.endswith(x) for x in bnb_keys)]
|
||||
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
|
||||
if self.pre_quantized:
|
||||
if param_name.replace("weight", "SCB") not in state_dict:
|
||||
raise ValueError("Missing quantization component `SCB`")
|
||||
if param_value.dtype != torch.int8:
|
||||
raise ValueError(
|
||||
f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
module, name = get_module_from_name(model, param_name)
|
||||
return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias"
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
@ -186,52 +174,38 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
|
||||
needs aux items from state dicts, if found
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
fp16_statistics_key = param_name.replace("weight", "SCB")
|
||||
fp16_statistics = state_dict.get(fp16_statistics_key)
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if tensor_name not in module._parameters:
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
|
||||
raise TypeError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.")
|
||||
if (
|
||||
old_value.device == torch.device("meta")
|
||||
and target_device not in ["meta", torch.device("meta")]
|
||||
and param_value is None
|
||||
):
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
|
||||
|
||||
new_value = param_value.to("cpu")
|
||||
if self.pre_quantized and not self.is_serializable():
|
||||
raise ValueError(
|
||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
|
||||
# Those 2 can only happen when self.pre_quantized == True
|
||||
if tensor_name == "SCB":
|
||||
setattr(module.weight, "SCB", param_value.to(target_device))
|
||||
return
|
||||
# It's not used, but it's getting serialized for BC reason...
|
||||
elif tensor_name == "weight_format":
|
||||
return
|
||||
|
||||
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
|
||||
# Since weights are saved in the correct "orientation", we skip transposing when loading.
|
||||
if issubclass(module.source_cls, Conv1D):
|
||||
if fp16_statistics is None:
|
||||
new_value = new_value.T
|
||||
if issubclass(module.source_cls, Conv1D) and not self.pre_quantized:
|
||||
param_value = param_value.T
|
||||
|
||||
old_value = getattr(module, tensor_name)
|
||||
kwargs = old_value.__dict__
|
||||
kwargs.pop("_is_hf_initialized", None)
|
||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)
|
||||
new_value = bnb.nn.Int8Params(param_value.to("cpu"), requires_grad=False, **kwargs).to(target_device)
|
||||
|
||||
# Set it to the module
|
||||
module._parameters[tensor_name] = new_value
|
||||
if fp16_statistics is not None:
|
||||
setattr(module.weight, "SCB", fp16_statistics.to(target_device))
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
model.is_loaded_in_8bit = True
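For the 8-bit path, the only extra serialized piece is the per-row `SCB` statistic, which is attached to the already-loaded int8 weight rather than kept in a state dict. As a rough intuition for how the two relate, here is a simplification of LLM.int8() dequantization; this is shown for illustration and is not code from this diff.

    import torch

    weight_int8 = torch.tensor([[127, -64], [12, 3]], dtype=torch.int8)
    scb = torch.tensor([0.5, 0.02])   # per-row absmax statistics ("SCB")

    # Approximate floating-point weight recovered from the int8 data and its row-wise scales.
    weight_fp = weight_int8.float() * scb.unsqueeze(1) / 127.0
    print(weight_fp)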
|
||||
@ -268,7 +242,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
model = replace_with_bnb_linear(
|
||||
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
|
||||
)
|
||||
# TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbytes.py to here
|
||||
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .base import HfQuantizer
|
||||
|
||||
@ -100,26 +100,15 @@ class EetqHfQuantizer(HfQuantizer):
|
||||
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with EETQ.")
|
||||
return dtype
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
from eetq import EetqLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if isinstance(module, EetqLinear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.int8:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
return False
|
||||
else:
|
||||
if tensor_name == "weight_scale":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -129,16 +118,22 @@ class EetqHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
quantizes weights into qweight and weight_scales
|
||||
"""
|
||||
from eetq import quantize_and_preprocess_weights
|
||||
from eetq import EetqLinear, quantize_and_preprocess_weights
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
new_value, weight_scale = quantize_and_preprocess_weights(param_value)
|
||||
|
||||
# Sanity check
|
||||
if isinstance(module, EetqLinear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.int8:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
else:
|
||||
if tensor_name == "weight_scale":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
|
||||
module._buffers[tensor_name] = new_value.to(target_device)
|
||||
module.register("weight_scales", weight_scale.to(target_device))
|
||||
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .base import HfQuantizer
|
||||
|
||||
@ -105,33 +105,20 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
|
||||
)
|
||||
return dtype
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if isinstance(module, FbgemmFp8Linear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
return False
|
||||
else:
|
||||
if tensor_name == "weight_scale":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
return True
|
||||
if isinstance(module, FbgemmFp8Llama4TextExperts):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
return False
|
||||
else:
|
||||
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -141,15 +128,25 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Quantizes weights into weight and weight_scale
|
||||
"""
|
||||
|
||||
from ..integrations import FbgemmFp8Llama4TextExperts
|
||||
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
# Sanity checks
|
||||
if isinstance(module, FbgemmFp8Linear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
else:
|
||||
if tensor_name == "weight_scale":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
if isinstance(module, FbgemmFp8Llama4TextExperts):
|
||||
if not (self.pre_quantized or tensor_name == "bias"):
|
||||
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
|
||||
if isinstance(module, FbgemmFp8Llama4TextExperts):
|
||||
if tensor_name == "gate_up_proj":
|
||||
# Process each expert separately
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
|
||||
from .base import HfQuantizer
|
||||
@ -81,13 +81,21 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Quantizes weights to FP8 format using Block-wise quantization
|
||||
"""
|
||||
from ..integrations.finegrained_fp8 import FP8Linear
|
||||
from ..modeling_utils import _load_parameter_into_model
|
||||
|
||||
# Sanity checks
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module, FP8Linear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
else:
|
||||
if tensor_name == "weight_scale_inv":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
|
||||
param_value = param_value.to(target_device)
|
||||
|
||||
# Get FP8 min/max values
|
||||
@ -128,26 +136,14 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
||||
_load_parameter_into_model(model, param_name, quantized_param)
|
||||
_load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale)
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
from ..integrations.finegrained_fp8 import FP8Linear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if isinstance(module, FP8Linear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
return False
|
||||
else:
|
||||
if tensor_name == "weight_scale_inv":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .base import HfQuantizer
|
||||
from .quantizers_utils import get_module_from_name
|
||||
@ -89,7 +89,7 @@ class FPQuantHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
|
||||
@ -159,14 +159,7 @@ class FPQuantHfQuantizer(HfQuantizer):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
from fp_quant import FPQuantLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..utils.logging import tqdm
|
||||
from .base import HfQuantizer
|
||||
@ -87,13 +87,10 @@ class HiggsHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
from ..integrations import quantize_with_higgs
|
||||
|
||||
"""
|
||||
Quantizes weights into weight and weight_scale
|
||||
"""
|
||||
flute_dict = quantize_with_higgs(
|
||||
param_value.to(target_device),
|
||||
self.quantization_config.bits,
|
||||
@ -180,18 +177,11 @@ class HiggsHfQuantizer(HfQuantizer):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
from ..integrations import HiggsLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
|
||||
if isinstance(module, HiggsLinear) and tensor_name == "weight":
|
||||
# Only quantize weights of HiggsLinear modules that are not already quantized
|
||||
return True
|
||||
else:
|
||||
|
@ -12,10 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..integrations import prepare_for_hqq_linear
|
||||
from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging
|
||||
from ..utils import is_hqq_available, is_torch_available, logging
|
||||
from .base import HfQuantizer
|
||||
from .quantizers_utils import get_module_from_name
|
||||
|
||||
@ -24,24 +25,24 @@ if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
# This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
|
||||
# but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
|
||||
# we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
|
||||
@property
|
||||
def weight(self):
|
||||
return torch.empty(0, dtype=self.compute_dtype, device=self.device)
|
||||
|
||||
HQQLinear.weight = weight
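Hoisting this patch to import time means any caller that inspects `layer.weight.dtype` or `layer.weight.device` keeps working even though an HQQ layer no longer stores a real weight tensor. The same trick in isolation, on a toy class rather than `HQQLinear`:

    import torch

    class ToyQuantLinear:
        def __init__(self):
            self.compute_dtype = torch.float16
            self.device = "cpu"

    # Patch in a dummy `weight` so callers can still read `.weight.dtype` / `.weight.device`.
    @property
    def weight(self):
        return torch.empty(0, dtype=self.compute_dtype, device=self.device)

    ToyQuantLinear.weight = weight

    layer = ToyQuantLinear()
    print(layer.weight.dtype, layer.weight.device)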
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Finds the parent of a node module named "name"
|
||||
def find_parent(model, name):
|
||||
module_tree = name.split(".")[:-1]
|
||||
parent = model
|
||||
for m in module_tree:
|
||||
parent = parent._modules[m]
|
||||
return parent
|
||||
|
||||
|
||||
class HqqHfQuantizer(HfQuantizer):
|
||||
"""
|
||||
HQQ quantizer base HF class.
|
||||
@ -54,16 +55,17 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
required_packages = ["hqq"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
self.dtype = None
|
||||
self.using_multi_gpu = False
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not (is_hqq_available()):
|
||||
if not is_hqq_available():
|
||||
raise ImportError(
|
||||
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
|
||||
)
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
self.dtype = None
|
||||
self.using_multi_gpu = False
|
||||
# Keys that are serialized specifically by hqq
|
||||
self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"}
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if self.dtype is None:
|
||||
if "dtype" in kwargs:
|
||||
self.dtype = kwargs["dtype"]
|
||||
@ -104,75 +106,56 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
_find_hqq_quantizable_layers(module, layers)
|
||||
|
||||
new_keys = set(expected_keys)
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
# Name modules
|
||||
for name, module in model.named_modules():
|
||||
module.name = name
|
||||
# Name modules
|
||||
for name, module in model.named_modules():
|
||||
module.name = name
|
||||
|
||||
# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
|
||||
_valid_modules = set()
|
||||
_find_hqq_quantizable_layers(model, _valid_modules)
|
||||
# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
|
||||
_valid_modules = set()
|
||||
_find_hqq_quantizable_layers(model, _valid_modules)
|
||||
|
||||
# Remove skipped modules
|
||||
_skipped_modules = set()
|
||||
for _module in _valid_modules:
|
||||
for _skip_module in model.config.quantization_config["skip_modules"]:
|
||||
if _skip_module in _module:
|
||||
_skipped_modules.add(_module)
|
||||
_valid_modules -= _skipped_modules
|
||||
# Remove skipped modules
|
||||
_skipped_modules = set()
|
||||
for _module in _valid_modules:
|
||||
for _skip_module in model.config.quantization_config["skip_modules"]:
|
||||
if _skip_module in _module:
|
||||
_skipped_modules.add(_module)
|
||||
_valid_modules -= _skipped_modules
|
||||
|
||||
# Append new expected layers based on _ref_keys
|
||||
_ref_keys = HQQLinear(
|
||||
linear_layer=None,
|
||||
quant_config=None,
|
||||
compute_dtype=torch.float16,
|
||||
device="cpu",
|
||||
del_orig=False,
|
||||
).state_dict_keys() - {"bias"}
|
||||
# Append new expected layers based on _ref_keys
|
||||
_ref_keys = HQQLinear(
|
||||
linear_layer=None,
|
||||
quant_config=None,
|
||||
compute_dtype=torch.float16,
|
||||
device="cpu",
|
||||
del_orig=False,
|
||||
).state_dict_keys() - {"bias"}
|
||||
|
||||
# Clean-up
|
||||
_rm_keys = set()
|
||||
for key in new_keys:
|
||||
if any(_module in key for _module in _valid_modules):
|
||||
_rm_keys.add(key)
|
||||
new_keys -= _rm_keys
|
||||
# At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
|
||||
# Clean-up
|
||||
_rm_keys = set()
|
||||
for key in new_keys:
|
||||
if any(_module in key for _module in _valid_modules):
|
||||
_rm_keys.add(key)
|
||||
new_keys -= _rm_keys
|
||||
# At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
|
||||
|
||||
# Re-populate Linear/HQQLinear
|
||||
for _module in _valid_modules:
|
||||
if _module + ".weight" in loaded_keys:
|
||||
new_keys.add(_module + ".weight")
|
||||
else:
|
||||
new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
|
||||
if _module + ".bias" in loaded_keys:
|
||||
new_keys.add(_module + ".bias")
|
||||
# Re-populate Linear/HQQLinear
|
||||
for _module in _valid_modules:
|
||||
if _module + ".weight" in loaded_keys:
|
||||
new_keys.add(_module + ".weight")
|
||||
else:
|
||||
new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
|
||||
if _module + ".bias" in loaded_keys:
|
||||
new_keys.add(_module + ".bias")
|
||||
|
||||
return list(new_keys)
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if self.pre_quantized:
|
||||
return (isinstance(module, (torch.nn.Linear, HQQLinear))) and tensor_name != "weight"
|
||||
else:
|
||||
return (
|
||||
isinstance(module, torch.nn.Linear)
|
||||
and tensor_name == "weight"
|
||||
# bias doesn't need to be quantized, we use this as a workaround to avoid loading bias into HQQLinear assuming it was loaded
|
||||
# in the state_dict directly with the weight because hqq overwrote load_state_dict for this layer
|
||||
or (isinstance(module, HQQLinear) and tensor_name == "bias")
|
||||
)
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
# Since we do not prepare the modules in advance, we need every param of the Linear layer to go through
|
||||
# `create_quantized_param`, even when `self.is_quantized == True`
|
||||
return isinstance(module, torch.nn.Linear)
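Because HQQ modules are no longer pre-converted on the meta skeleton, every parameter belonging to a plain `nn.Linear` is routed through `create_quantized_param`, which later swaps the module for an `HQQLinear` once its weight and bias have both arrived. A toy version of that check, using `get_submodule` instead of the internal `get_module_from_name` helper:

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))

    def param_needs_quantization(model, param_name: str) -> bool:
        module_name, _, _ = param_name.rpartition(".")
        module = model.get_submodule(module_name)
        # Every param of a plain Linear goes through the quantizer, bias included,
        # because the module itself is only replaced once its params have been loaded.
        return isinstance(module, nn.Linear)

    for name, _ in model.named_parameters():
        print(name, param_needs_quantization(model, name))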
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
@ -180,45 +163,33 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Each nn.Linear layer is processed here.
|
||||
We first check if the corresponding module state_dict contains already HQQ quantized parameters.
|
||||
If not, we create a temp linear layer with the module state_dict params and use it for quantization
|
||||
"""
|
||||
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
# TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
|
||||
# but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
|
||||
# we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
|
||||
@property
|
||||
def weight(_self: HQQLinear):
|
||||
return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)
|
||||
|
||||
HQQLinear.weight = weight
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
layer_name = ".".join(param_name.split(".")[:-1])
|
||||
parent_module = find_parent(model, layer_name)
|
||||
node = layer_name.split(".")[-1]
|
||||
module_name = param_name.rsplit(".", 1)[0]
|
||||
parent_module, node = get_module_from_name(model, module_name)
|
||||
|
||||
if tensor_name == "bias":
|
||||
# this should already be set
|
||||
quant_config = model.config.quantization_config["quant_config"]
|
||||
skip_modules = model.config.quantization_config["skip_modules"]
|
||||
|
||||
# In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
|
||||
if any(skip_module in module.name for skip_module in skip_modules):
|
||||
module.load_state_dict(
|
||||
{tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
|
||||
)
|
||||
return
|
||||
|
||||
# set module state_dict
|
||||
module_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if layer_name + "." in k:
|
||||
module_state_dict[k.split(".")[-1]] = v
|
||||
|
||||
# We need this hack as the model is not pre-prepared as an empty skeleton on meta device
|
||||
if self.pre_quantized:
|
||||
if isinstance(module, HQQLinear):
|
||||
return
|
||||
else:
|
||||
# Save them for later
|
||||
if not hasattr(self, "hqq_params"):
|
||||
self.hqq_params = defaultdict(dict)
|
||||
self.hqq_params[module_name].update({tensor_name: param_value})
|
||||
hqq_params = self.hqq_params[module_name]
|
||||
|
||||
# If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
|
||||
# hqq does not support it...)
|
||||
if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
|
||||
hqq_layer = HQQLinear(
|
||||
linear_layer=None,
|
||||
quant_config=None,
|
||||
@ -226,43 +197,32 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
device=target_device,
|
||||
del_orig=False,
|
||||
)
|
||||
hqq_layer.load_state_dict(hqq_params)
|
||||
|
||||
hqq_layer.load_state_dict(module_state_dict)
|
||||
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
|
||||
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
|
||||
if self.using_multi_gpu:
|
||||
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
|
||||
|
||||
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
|
||||
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
|
||||
|
||||
if self.using_multi_gpu:
|
||||
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
|
||||
|
||||
setattr(parent_module, node, hqq_layer)
|
||||
|
||||
# cleanup
|
||||
del module.__dict__, module
|
||||
torch.cuda.empty_cache()
|
||||
setattr(parent_module, node, hqq_layer)
|
||||
del self.hqq_params[module_name], module
|
||||
return
|
||||
|
||||
# Step 1: populate module with weight/bias from module state dict
|
||||
for key, tensor in module_state_dict.items():
|
||||
setattr(module, key, torch.nn.Parameter(tensor))
|
||||
# Load param in the module (without caring about device or dtype, it will be changed later)
|
||||
module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
|
||||
|
||||
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
|
||||
# directly doesn't work.
|
||||
quant_config = model.config.quantization_config["quant_config"]
|
||||
skip_modules = model.config.quantization_config["skip_modules"]
|
||||
module_tag = ".".join(module.name.split(".")[-2:])
|
||||
module_quant_config = None
|
||||
if "weight_quant_params" in quant_config:
|
||||
module_quant_config = quant_config
|
||||
elif module_tag in quant_config:
|
||||
module_quant_config = quant_config[module_tag]
|
||||
# If both the weight and bias have already been loaded, time to quantize!
|
||||
module_is_ready = module.weight.device.type != "meta" and (
|
||||
module.bias is None or module.bias.device.type != "meta"
|
||||
)
|
||||
|
||||
for skip_module in skip_modules:
|
||||
if skip_module in module.name:
|
||||
module_quant_config = None
|
||||
break
|
||||
if module_is_ready:
|
||||
module_tag = ".".join(module.name.split(".")[-2:])
|
||||
if "weight_quant_params" in quant_config:
|
||||
module_quant_config = quant_config
|
||||
elif module_tag in quant_config:
|
||||
module_quant_config = quant_config[module_tag]
|
||||
|
||||
if module_quant_config is not None:
|
||||
hqq_layer = HQQLinear(
|
||||
module,
|
||||
quant_config=module_quant_config,
|
||||
@ -279,16 +239,7 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
|
||||
setattr(parent_module, node, hqq_layer)
|
||||
|
||||
else:
|
||||
module = module.to(dtype=self.dtype, device=target_device)
|
||||
setattr(parent_module, node, module)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Remove the accelerate hook and use a simpler forward pass. Otherwise, this breaks with multi-gpu
|
||||
def _patch_layer_for_multigpu(self, hqq_layer):
|
||||
hqq_layer = remove_hook_from_module(hqq_layer)
|
||||
|
||||
def forward_with_device(self, x):
|
||||
out = torch.matmul(x.to(self.device), self.dequantize().t())
|
||||
if self.bias is not None:
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .base import HfQuantizer
|
||||
|
||||
@ -153,14 +153,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
||||
)
|
||||
return dtype
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
from ..integrations import Mxfp4GptOssExperts
|
||||
from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
|
||||
|
||||
@ -183,7 +176,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
from ..integrations import (
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -103,26 +103,10 @@ class QuantoHfQuantizer(HfQuantizer):
|
||||
not_missing_keys.append(missing)
|
||||
return [k for k in missing_keys if k not in not_missing_keys]
|
||||
|
||||
def param_needs_quantization(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QModuleMixin
|
||||
|
||||
device_map = kwargs.get("device_map")
|
||||
param_device = kwargs.get("param_device")
|
||||
# we don't quantize the model if the module is going to be offloaded to the cpu
|
||||
if device_map is not None and param_device is not None:
|
||||
device_map_values = set(device_map.values())
|
||||
if param_device == "cpu" and len(device_map_values) > 1:
|
||||
if not (device_map_values == {"cpu"} or device_map_values == {"cpu", "disk"}):
|
||||
return False
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
# We only quantize the weights and the bias is not quantized.
|
||||
if isinstance(module, QModuleMixin) and "weight" in tensor_name:
|
||||
@ -141,15 +125,11 @@ class QuantoHfQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create the quantized parameter by calling .freeze() after setting it to the module.
|
||||
"""
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from ..modeling_utils import _load_parameter_into_model
|
||||
|
||||
set_module_tensor_to_device(model, param_name, target_device, param_value)
|
||||
_load_parameter_into_model(model, param_name, param_value.to(target_device))
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
module.freeze()
|
||||
module.weight.requires_grad = False
|
||||
|
@ -13,23 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from ..file_utils import is_torch_available
from .base import HfQuantizer


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

if is_torch_available():
import torch
from ..utils import is_quark_available, logging

from ..utils import is_accelerate_available, is_quark_available, logging


if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device

logger = logging.get_logger(__name__)

@ -82,23 +75,18 @@ class QuarkHfQuantizer(HfQuantizer):

return model

def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
return True

def create_quantized_param(self, model, param, param_name, param_device, state_dict) -> "torch.nn.Parameter":
def create_quantized_param(self, model, param, param_name, param_device, **kwargs):
from ..modeling_utils import _load_parameter_into_model

postfix = param_name.split(".")[-1]

if postfix in CHECKPOINT_KEYS:
param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])

set_module_tensor_to_device(model, param_name, param_device, value=param)
_load_parameter_into_model(model, param_name, param.to(param_device))

def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
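As a quick illustration of the key remapping the new Quark `create_quantized_param` performs before loading, here is a self-contained sketch. The mapping below is a made-up stand-in for `CHECKPOINT_KEYS` (the real table lives in the quark quantizer module and is not shown in this diff), and the parameter paths are invented.

# Hypothetical stand-in for CHECKPOINT_KEYS; entries are illustrative only.
CHECKPOINT_KEYS = {"weight_scale": "weight_quantizer.scale"}

def remap(param_name: str) -> str:
    # Mirror of the hunk above: if the last path component is a known checkpoint key,
    # substitute it with its module attribute path.
    postfix = param_name.split(".")[-1]
    if postfix in CHECKPOINT_KEYS:
        param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
    return param_name

print(remap("model.layers.0.self_attn.q_proj.weight_scale"))
# -> model.layers.0.self_attn.q_proj.weight_quantizer.scale (with the mapping above)
print(remap("model.layers.0.self_attn.q_proj.weight"))
# -> unchanged, "weight" is not a remapped key here

The remapped name is then written with `_load_parameter_into_model` instead of accelerate's `set_module_tensor_to_device`, which is the recurring change throughout this PR.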
@ -14,6 +14,7 @@
import importlib
import re
import types
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Union

from packaging import version

@ -25,10 +26,12 @@ from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from typing import Any

from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig
from ..utils import is_safetensors_available, is_torch_available, is_torchao_available, logging


if is_safetensors_available():
from safetensors import safe_open


if is_torch_available():

@ -64,15 +67,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
return None


# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent


def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

@ -113,6 +107,20 @@ class TorchAoHfQuantizer(HfQuantizer):
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

if isinstance(self.quantization_config.quant_type, str):
is_int_4 = "int4" in self.quantization_config.quant_type
else:
config_name = self.quantization_config.quant_type.__class__.__name__
is_int_4 = fuzzy_match_size(config_name) == "4"

# TODO: better way to get the serialized key names? Hard to read from torchao codebase
if is_int_4:
self.weight_ao_keys = ["qdata", "scale", "zero_point"]
else:
self.weight_ao_keys = ["qdata", "scale"]
# Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this
self.full_ao_keys = self.weight_ao_keys + ["_data"]

def validate_environment(self, *args, **kwargs):
if not is_torchao_available():
raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
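To make the `weight_ao_keys` / `full_ao_keys` lists above concrete: with safetensors serialization, torchao flattens each quantized weight into several `:`-suffixed entries (plus `:_data` for plain tensors). The key names come from the diff itself; the parameter paths below are invented for illustration.

weight_ao_keys = ["qdata", "scale", "zero_point"]   # int4 case, as built in __init__ above
full_ao_keys = weight_ao_keys + ["_data"]

# Invented checkpoint entries, shaped like "<param_name>:<ao_key>".
checkpoint_keys = [
    "model.layers.0.mlp.down_proj.weight:qdata",
    "model.layers.0.mlp.down_proj.weight:scale",
    "model.layers.0.mlp.down_proj.weight:zero_point",
    "model.embed_tokens.weight:_data",
]

# Same suffix test that param_needs_quantization uses later in this diff.
for key in checkpoint_keys:
    print(key, any(key.endswith(f":{x}") for x in full_ao_keys))  # prints True for all four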
@ -229,31 +237,25 @@ class TorchAoHfQuantizer(HfQuantizer):
]
return

def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]

def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
if self.quantization_config.quant_type == "autoquant":
return False

param_device = kwargs.pop("param_device", None)
# check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
return False
elif param_device == "cpu" and self.offload:
# We don't quantize weights that we offload
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
return False
elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
return True
else:
# we only quantize the weight of nn.Linear and nn.Embedding
module, tensor_name = get_module_from_name(model, param_name)
_QUANTIZABLE = [torch.nn.Linear]
if self.quantization_config.include_input_output_embeddings:
_QUANTIZABLE.append(torch.nn.Embedding)
return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"

def create_quantized_param(
self,
@ -261,29 +263,56 @@ class TorchAoHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Each nn.Linear layer that needs to be quantized is processed here.
First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
"""
if self.quantization_config.quant_type == "autoquant":
return

from torchao.quantization import quantize_

full_name = param_name
# Those are the pre-quantized weights
if ":" in param_name:
param_name = param_name.rsplit(":", 1)[0]
module, tensor_name = get_module_from_name(model, param_name)

if self.pre_quantized:
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(device=target_device), requires_grad=param_value.requires_grad
)
# If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but that was
# already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
is_unsafe_serialization = ":" not in full_name
if tensor_name == "bias" or is_unsafe_serialization:
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(target_device), requires_grad=param_value.requires_grad
)
return
# Sanity check for the new serialization format
elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

# Save the states for later quantization when they are all gathered
if not hasattr(self, "ao_params"):
self.ao_params = defaultdict(dict)
self.ao_params[param_name].update({full_name: param_value})

# We are ready for quantization in this case (we retrieved all the needed keys)
if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
# Set it
module._parameters[tensor_name] = torch.nn.Parameter(
new_param.to(target_device), requires_grad=new_param.requires_grad
)

# Free memory
del self.ao_params[param_name]

# Add repr to the module
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
assert isinstance(self.quantization_config, TorchAoConfig)
module._parameters[tensor_name] = torch.nn.Parameter(
param_value, requires_grad=param_value.requires_grad
).to(device=target_device)
).to(target_device)
# if we are quantizing tied parameters, to avoid tying the quantized weights
# the correct order to do it is
# 1. load the weight to model
@ -313,16 +342,6 @@ class TorchAoHfQuantizer(HfQuantizer):

quantize_(module, self.quantization_config.get_apply_tensor_subclass())

def update_state_dict_with_metadata(self, state_dict, metadata):
"""
If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
from the provided state_dict and metadata.
"""
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
return unflatten_tensor_state_dict(state_dict, metadata)
else:
return state_dict

def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
if self.quantization_config.quant_type == "autoquant":
@ -415,3 +434,13 @@ class TorchAoHfQuantizer(HfQuantizer):
@property
def is_compileable(self) -> bool:
return True

def set_metadata(self, checkpoint_files: list[str]):
if checkpoint_files[0].endswith(".safetensors") and is_safetensors_available():
metadata = {}
for checkpoint in checkpoint_files:
with safe_open(checkpoint, framework="pt") as f:
metadata_ = f.metadata() or {}
metadata.update(metadata_)
# Save it
self.metadata = metadata
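Putting the torchao pieces together: `set_metadata` merges the safetensors header metadata across shards, and `create_quantized_param` buffers the `:qdata` / `:scale` (and `:zero_point`) entries of each weight until they are all present, then rebuilds the tensor with `unflatten_tensor_state_dict`. A minimal sketch of that gather-then-rebuild control flow, with a toy stand-in for torchao's `unflatten_tensor_state_dict` (torchao >= 0.14) and invented parameter names:

from collections import defaultdict

import torch

weight_ao_keys = ["qdata", "scale"]  # int8-style weight, as in __init__ above
ao_params: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)

def toy_unflatten(flat: dict[str, torch.Tensor], metadata: dict) -> dict[str, torch.Tensor]:
    # Toy stand-in: the real torchao helper rebuilds a tensor subclass from the shards + metadata.
    name = next(iter(flat)).rsplit(":", 1)[0]
    return {name: flat[f"{name}:qdata"].float() * flat[f"{name}:scale"]}

def on_loaded_entry(full_name: str, value: torch.Tensor, metadata: dict) -> None:
    # Buffer shards per weight; rebuild once every expected ao key has arrived, then free the buffer.
    param_name = full_name.rsplit(":", 1)[0]
    ao_params[param_name][full_name] = value
    if len(ao_params[param_name]) == len(weight_ao_keys):
        rebuilt = toy_unflatten(ao_params.pop(param_name), metadata)[param_name]
        print(param_name, tuple(rebuilt.shape))

metadata = {}  # in the real flow, merged from safe_open(shard).metadata() as in set_metadata above
on_loaded_entry("linear.weight:qdata", torch.randint(-128, 127, (4, 4), dtype=torch.int8), metadata)
on_loaded_entry("linear.weight:scale", torch.full((4, 1), 0.02), metadata)  # triggers the rebuild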