Make quantizers good citizens loading-wise (#41138)

* fix param_needs_quantization

* rewrite most hqq

* clean

* fix

* comment

* remove it from exception of safetensors

* start on bnb 4bits

* post-rebase fix

* make bnb4 bit a good citizen

* remove forgotten print

* make bnb 8bits a good citizen

* better hqq

* fix

* clean

* remove state dict from signature

* switch method

* make torchao a good citizen

* fixes

* fix torchao

* add check

* typo

Author: Cyril Vallez
Date: 2025-09-29 17:04:45 +02:00 (committed by GitHub)
Parent: 399c589dfa
Commit: 5426edecab
15 changed files with 350 additions and 532 deletions
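
Taken together, the diff below converges every quantizer on a slimmer loading interface: `param_needs_quantization(model, param_name, **kwargs)` no longer sees the tensor value or the full state dict, `create_quantized_param(...)` drops its `state_dict` argument, and format-specific checkpoint keys are filtered through a new `update_unexpected_keys` hook. A minimal sketch of the resulting shape, using a hypothetical int8 layer (`MyQuantLinear` and the standalone functions below are illustrative stand-ins, not the exact transformers code):

import torch
import torch.nn as nn


class MyQuantLinear(nn.Module):
    # Stand-in for a quantized linear layer holding an int8 weight and a per-row scale
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_features, 1))


def get_module_from_name(model: nn.Module, param_name: str):
    # Same behaviour as transformers.quantizers.quantizers_utils.get_module_from_name
    module_name, _, tensor_name = param_name.rpartition(".")
    return (model.get_submodule(module_name) if module_name else model), tensor_name


def param_needs_quantization(model: nn.Module, param_name: str, **kwargs) -> bool:
    # New shape: the decision is made from the parameter name and the target module alone
    module, tensor_name = get_module_from_name(model, param_name)
    return isinstance(module, MyQuantLinear) and tensor_name == "weight"


def create_quantized_param(model: nn.Module, param_value: torch.Tensor, param_name: str,
                           target_device: torch.device, **kwargs) -> None:
    # New shape: quantize and write the result straight into the module, no state_dict involved
    module, tensor_name = get_module_from_name(model, param_name)
    abs_max_per_row = param_value.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
    scale = abs_max_per_row / 127.0
    module._buffers["weight_scale"] = scale.to(target_device)
    module._buffers[tensor_name] = (param_value / scale).round().to(torch.int8).to(target_device)

The file-by-file changes below migrate each quantizer to this shape.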

View File

@ -159,24 +159,13 @@ class Int8SymmetricQuantizer(HfQuantizer):
pre_quantized=self.pre_quantized,
)
def param_needs_quantization(
self,
model,
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool:
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, Int8SymmetricLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
@ -186,11 +175,18 @@ class Int8SymmetricQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Quantizes weights to INT8 symmetric format.
"""
# Sanity check
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, Int8SymmetricLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5)
weight_scale = abs_max_per_row / 127.0

View File

@ -104,7 +104,6 @@ from .utils import (
is_torch_npu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
logging,
)
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@ -119,9 +118,6 @@ from .utils.import_utils import (
from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
if is_torchao_available():
from torchao.quantization import Int4WeightOnlyConfig
if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import add_hook_to_module
@ -644,6 +640,7 @@ def _infer_parameter_dtype(
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
QuantizationMethod.MXFP4,
QuantizationMethod.BITS_AND_BYTES,
}:
return True, None
else:
@ -698,13 +695,8 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
is_quantized = hf_quantizer is not None
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}
is_safetensors = shard_file.endswith(".safetensors")
is_meta_state_dict = is_safetensors and not is_hqq_or_bnb_or_ao
is_meta_state_dict = is_safetensors
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) if is_meta_state_dict else None
params_to_load = list(state_dict.keys())
@ -726,9 +718,7 @@ def _load_state_dict_into_meta_model(
)
if device_mesh is not None:
if not is_quantized or not hf_quantizer.param_needs_quantization(
model, param, param_name, state_dict, device_map=device_map
):
if not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
# In this case, the param is already on the correct device!
shard_and_distribute_module(
model,
@ -740,7 +730,8 @@ def _load_state_dict_into_meta_model(
device_mesh.get_local_rank(),
device_mesh,
)
else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param:
else:
# we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param
sharding_kwargs = {
"empty_param": empty_param,
"casting_dtype": casting_dtype,
@ -753,7 +744,6 @@ def _load_state_dict_into_meta_model(
param,
param_name,
device_mesh.get_local_rank(),
state_dict,
**sharding_kwargs,
)
else:
@ -775,9 +765,7 @@ def _load_state_dict_into_meta_model(
if param_device == "disk":
if not is_safetensors:
disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
elif not is_quantized or not hf_quantizer.param_needs_quantization(
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
):
elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
@ -785,7 +773,7 @@ def _load_state_dict_into_meta_model(
else:
# TODO naming is stupid it loads it as well
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict)
hf_quantizer.create_quantized_param(model, param, param_name, param_device)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
@ -823,7 +811,6 @@ def load_shard_file(args):
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb_or_ao,
is_quantized,
device_map,
hf_quantizer,
@ -842,22 +829,8 @@ def load_shard_file(args):
return [], disk_offload_index
map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb_or_ao
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["disk"]][0])
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
@ -868,14 +841,7 @@ def load_shard_file(args):
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
if shard_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(shard_file, framework="pt") as f:
metadata = f.metadata()
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
error_msgs = []
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
# Skip it with fsdp on ranks other than 0
@ -1384,6 +1350,7 @@ def _find_missing_and_unexpected_keys(
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
return missing_keys, unexpected_keys
@ -4398,9 +4365,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@ -4931,6 +4895,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)
# Torchao needs access to all metadata later
if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
hf_quantizer.set_metadata(checkpoint_files)
if _torch_distributed_available and device_mesh is not None:
model = distribute_model(model, distributed_config, device_mesh, tp_size)
@ -5201,11 +5169,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}
# Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None:
@ -5338,7 +5301,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb_or_ao,
is_quantized,
device_map,
hf_quantizer,
@ -5709,12 +5671,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
if param.device == torch.device("meta"):
value = torch.empty_like(param, dtype=dtype, device="cpu")
if not is_quantized or not hf_quantizer.param_needs_quantization(
self, param_value=value, param_name=key, state_dict={}
):
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
_load_parameter_into_model(self, key, value)
else:
hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict)
hf_quantizer.create_quantized_param(self, value, key, "cpu")
def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to

View File

@ -140,6 +140,9 @@ class HfQuantizer(ABC):
"""
return expected_keys
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
return unexpected_keys
def get_special_dtypes_update(self, model, dtype: "torch.dtype") -> dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
@ -175,10 +178,12 @@ class HfQuantizer(ABC):
"""
return False
def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
def create_quantized_param(self, *args, **kwargs):
"""
takes needed components from state_dict and creates quantized param; only applicable if
requires_parameters_quantization == True
Take the needed components from the state dict (those for which `param_needs_quantization` is True) and create
the quantized param.
It usually also loads the new param directly into the `model`.
Note: only applicable if requires_parameters_quantization == True.
"""
if not self.requires_parameters_quantization:
raise AttributeError(
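
The `update_unexpected_keys` hook added to the base class above is a no-op by default; quantizers whose checkpoints carry extra per-parameter state override it so those keys are not reported back as unexpected, as several quantizers do further down in this diff. A minimal sketch (the suffix list mirrors the 8-bit bitsandbytes override below):

def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
    # Checkpoint-only auxiliary keys produced by this quantization format
    aux_suffixes = ["SCB", "weight_format"]
    return [k for k in unexpected_keys if not any(k.endswith(s) for s in aux_suffixes)]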

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from packaging import version
@ -67,6 +68,15 @@ class Bnb4BitHfQuantizer(HfQuantizer):
if self.quantization_config.llm_int8_skip_modules is not None:
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
# This describes the additional items that are saved on the state dict (on the params themselves)
self.bnb_keys = [
f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
"absmax",
"quant_map",
]
if self.quantization_config.bnb_4bit_use_double_quant:
self.bnb_keys.extend(["nested_absmax", "nested_quant_map"])
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
raise ImportError(
@ -132,26 +142,17 @@ class Bnb4BitHfQuantizer(HfQuantizer):
"calculation. You may encounter unexpected behavior, or pass your own device map"
)
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)]
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
import bitsandbytes as bnb
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
# Add here check for loaded components' dtypes once serialization is implemented
# They are on the params themselves, so we cannot easily extract the module from the name
if any(param_name.endswith(x) for x in self.bnb_keys):
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return True
else:
return False
module, name = get_module_from_name(model, param_name)
return isinstance(module, bnb.nn.Linear4bit) and name != "bias"
def create_quantized_param(
self,
@ -159,78 +160,51 @@ class Bnb4BitHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
"""
import bitsandbytes as bnb
is_quant_stat = any(param_name.endswith(x) for x in self.bnb_keys)
full_name = param_name
if is_quant_stat:
param_name = (
param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
)
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
old_value = getattr(module, tensor_name)
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if isinstance(target_device, int) and is_torch_npu_available():
target_device = f"npu:{target_device}"
if tensor_name == "bias":
if param_value is None:
new_value = old_value.to(target_device)
else:
new_value = param_value.to(target_device)
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
return
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
raise ValueError("this function only loads `Linear4bit components`")
if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
and param_value is None
):
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
# construct `new_value` for the module._parameters[tensor_name]:
# construct `new_value` for the module._parameters[tensor_name]
if self.pre_quantized:
# 4bit loading. Collecting components for restoring quantized weight
# This can be expanded to make a universal call for any quantized weight loading
if not self.is_serializable:
raise ValueError(
"Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
):
raise ValueError(
f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
)
quantized_stats = {}
for k, v in state_dict.items():
if param_name + "." in k:
quantized_stats[k] = v
module_name = param_name.rsplit(".", 1)[0]
# Save the states for later quantization when they are all gathered
if not hasattr(self, "param_quant_stats"):
self.param_quant_stats = defaultdict(dict)
self.param_quant_stats[module_name].update({full_name: param_value})
# We are ready for quantization in this case (note, the +1 is for the weight itself)
if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1:
param_kwargs = {}
if self.is_bnb_supports_quant_storage_module:
param_kwargs["module"] = module
weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight")
new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
data=weight,
quantized_stats=self.param_quant_stats[module_name],
requires_grad=False,
device=target_device,
**param_kwargs,
)
# Set it
module._parameters[tensor_name] = new_value
# Delete the states
del self.param_quant_stats[module_name]
else:
new_value = param_value.to("cpu")
old_value = getattr(module, tensor_name)
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
@ -313,7 +287,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
model = replace_with_bnb_linear(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
# TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here
model.config.quantization_config = self.quantization_config
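
Because parameters now stream in one by one without a state_dict, the pre-quantized 4-bit path above buffers the quant-state tensors per module and only builds the `Params4bit` once the full set has arrived. A condensed sketch of that accumulation pattern (illustrative; `bnb_keys` as built in `__init__` above, device placement and the quant-storage handling are left out):

from collections import defaultdict

import torch

bnb_keys = ["quant_state.bitsandbytes__nf4", "absmax", "quant_map"]  # plus nested_* with double quant
param_quant_stats: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)


def on_bnb4bit_param(param_name: str, param_value: torch.Tensor):
    # Quant-state tensors live on the param itself (e.g. "...q_proj.weight.absmax"), so strip the
    # trailing component(s) to recover the weight name, then the owning module name
    full_name = param_name
    if any(param_name.endswith(x) for x in bnb_keys):
        param_name = (param_name.rsplit(".", 1)[0] if "quant_state." not in param_name
                      else param_name.rsplit(".", 2)[0])
    module_name = param_name.rsplit(".", 1)[0]
    param_quant_stats[module_name][full_name] = param_value
    # The +1 accounts for the packed weight itself
    if len(param_quant_stats[module_name]) == len(bnb_keys) + 1:
        stats = param_quant_stats.pop(module_name)
        weight = stats.pop(f"{module_name}.weight")
        # At this point the real quantizer calls bnb.nn.Params4bit.from_prequantized(
        #     data=weight, quantized_stats=stats, requires_grad=False, device=target_device)
        return weight, stats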

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from packaging import version
@ -158,27 +158,15 @@ class Bnb8BitHfQuantizer(HfQuantizer):
logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
return torch.int8
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
bnb_keys = ["SCB", "weight_format"]
return [k for k in unexpected_keys if not any(k.endswith(x) for x in bnb_keys)]
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
import bitsandbytes as bnb
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
if self.pre_quantized:
if param_name.replace("weight", "SCB") not in state_dict:
raise ValueError("Missing quantization component `SCB`")
if param_value.dtype != torch.int8:
raise ValueError(
f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`."
)
return True
return False
module, name = get_module_from_name(model, param_name)
return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias"
def create_quantized_param(
self,
@ -186,52 +174,38 @@ class Bnb8BitHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
needs aux items from state dicts, if found
"""
import bitsandbytes as bnb
fp16_statistics_key = param_name.replace("weight", "SCB")
fp16_statistics = state_dict.get(fp16_statistics_key)
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
old_value = getattr(module, tensor_name)
if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
raise TypeError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.")
if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
and param_value is None
):
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
new_value = param_value.to("cpu")
if self.pre_quantized and not self.is_serializable():
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
# Those 2 can only happen when self.pre_quantized == True
if tensor_name == "SCB":
setattr(module.weight, "SCB", param_value.to(target_device))
return
# It's not used, but it's getting serialized for BC reasons...
elif tensor_name == "weight_format":
return
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D):
if fp16_statistics is None:
new_value = new_value.T
if issubclass(module.source_cls, Conv1D) and not self.pre_quantized:
param_value = param_value.T
old_value = getattr(module, tensor_name)
kwargs = old_value.__dict__
kwargs.pop("_is_hf_initialized", None)
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)
new_value = bnb.nn.Int8Params(param_value.to("cpu"), requires_grad=False, **kwargs).to(target_device)
# Set it to the module
module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(target_device))
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
model.is_loaded_in_8bit = True
@ -268,7 +242,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
model = replace_with_bnb_linear(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
# TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here
model.config.quantization_config = self.quantization_config

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from .base import HfQuantizer
@ -100,26 +100,15 @@ class EetqHfQuantizer(HfQuantizer):
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with EETQ.")
return dtype
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from eetq import EetqLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, EetqLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
@ -129,16 +118,22 @@ class EetqHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
quantizes weights into qweight and weight_scales
"""
from eetq import quantize_and_preprocess_weights
from eetq import EetqLinear, quantize_and_preprocess_weights
module, tensor_name = get_module_from_name(model, param_name)
new_value, weight_scale = quantize_and_preprocess_weights(param_value)
# Sanity check
if isinstance(module, EetqLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
module._buffers[tensor_name] = new_value.to(target_device)
module.register("weight_scales", weight_scale.to(target_device))

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from .base import HfQuantizer
@ -105,33 +105,20 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
)
return dtype
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FbgemmFp8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
if isinstance(module, FbgemmFp8Llama4TextExperts):
if self.pre_quantized or tensor_name == "bias":
return False
else:
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
@ -141,15 +128,25 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Quantizes weights into weight and weight_scale
"""
from ..integrations import FbgemmFp8Llama4TextExperts
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
# Sanity checks
if isinstance(module, FbgemmFp8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
if isinstance(module, FbgemmFp8Llama4TextExperts):
if not (self.pre_quantized or tensor_name == "bias"):
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj":
# Process each expert separately

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
from .base import HfQuantizer
@ -81,13 +81,21 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Quantizes weights to FP8 format using Block-wise quantization
"""
from ..integrations.finegrained_fp8 import FP8Linear
from ..modeling_utils import _load_parameter_into_model
# Sanity checks
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale_inv":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
param_value = param_value.to(target_device)
# Get FP8 min/max values
@ -128,26 +136,14 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
_load_parameter_into_model(model, param_name, quantized_param)
_load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale)
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations.finegrained_fp8 import FP8Linear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale_inv":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
@ -89,7 +89,7 @@ class FPQuantHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
module, _ = get_module_from_name(model, param_name)
@ -159,14 +159,7 @@ class FPQuantHfQuantizer(HfQuantizer):
def is_serializable(self, safe_serialization=None):
return True
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from fp_quant import FPQuantLinear
module, tensor_name = get_module_from_name(model, param_name)

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from ..utils.logging import tqdm
from .base import HfQuantizer
@ -87,13 +87,10 @@ class HiggsHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
from ..integrations import quantize_with_higgs
"""
Quantizes weights into weight and weight_scale
"""
flute_dict = quantize_with_higgs(
param_value.to(target_device),
self.quantization_config.bits,
@ -180,18 +177,11 @@ class HiggsHfQuantizer(HfQuantizer):
def is_serializable(self, safe_serialization=None):
return True
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations import HiggsLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
if isinstance(module, HiggsLinear) and tensor_name == "weight":
# Only quantize weights of HiggsLinear modules that are not already quantized
return True
else:

View File

@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any
from collections import defaultdict
from typing import TYPE_CHECKING
from ..integrations import prepare_for_hqq_linear
from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging
from ..utils import is_hqq_available, is_torch_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
@ -24,24 +25,24 @@ if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
if is_accelerate_available():
from accelerate.hooks import remove_hook_from_module
if is_torch_available():
import torch
if is_hqq_available():
from hqq.core.quantize import HQQLinear
# This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
# but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
# we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
@property
def weight(self):
return torch.empty(0, dtype=self.compute_dtype, device=self.device)
HQQLinear.weight = weight
logger = logging.get_logger(__name__)
# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent
class HqqHfQuantizer(HfQuantizer):
"""
HQQ quantizer base HF class.
@ -54,16 +55,17 @@ class HqqHfQuantizer(HfQuantizer):
required_packages = ["hqq"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.dtype = None
self.using_multi_gpu = False
def validate_environment(self, *args, **kwargs):
if not (is_hqq_available()):
if not is_hqq_available():
raise ImportError(
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
)
super().__init__(quantization_config, **kwargs)
self.dtype = None
self.using_multi_gpu = False
# Keys that are serialized specifically by hqq
self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"}
def validate_environment(self, *args, **kwargs):
if self.dtype is None:
if "dtype" in kwargs:
self.dtype = kwargs["dtype"]
@ -104,8 +106,6 @@ class HqqHfQuantizer(HfQuantizer):
_find_hqq_quantizable_layers(module, layers)
new_keys = set(expected_keys)
if is_hqq_available():
from hqq.core.quantize import HQQLinear
# Name modules
for name, module in model.named_modules():
@ -151,28 +151,11 @@ class HqqHfQuantizer(HfQuantizer):
return list(new_keys)
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
if is_hqq_available():
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
return (isinstance(module, (torch.nn.Linear, HQQLinear))) and tensor_name != "weight"
else:
return (
isinstance(module, torch.nn.Linear)
and tensor_name == "weight"
# bias doesn't need to be quantized, we use this as a workaround to avoid loading bias into HQQLinear assuming it was loaded
# in the state_dict directly with the weight because hqq overwrote load_state_dict for this layer
or (isinstance(module, HQQLinear) and tensor_name == "bias")
)
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
module, _ = get_module_from_name(model, param_name)
# Since we do not prepare the modules in advance, we need every param of the Linear layer to go through
# `create_quantized_param`, even when `self.pre_quantized == True`
return isinstance(module, torch.nn.Linear)
def create_quantized_param(
self,
@ -180,45 +163,33 @@ class HqqHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Each nn.Linear layer is processed here.
We first check if the corresponding module state_dict contains already HQQ quantized parameters.
If not, we create a temp linear layer with the module state_dict params and use it for quantization
"""
if is_hqq_available():
from hqq.core.quantize import HQQLinear
# TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
# but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
# we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
@property
def weight(_self: HQQLinear):
return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)
HQQLinear.weight = weight
module, tensor_name = get_module_from_name(model, param_name)
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]
module_name = param_name.rsplit(".", 1)[0]
parent_module, node = get_module_from_name(model, module_name)
if tensor_name == "bias":
# this should already be set
quant_config = model.config.quantization_config["quant_config"]
skip_modules = model.config.quantization_config["skip_modules"]
# In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
if any(skip_module in module.name for skip_module in skip_modules):
module.load_state_dict(
{tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
)
return
# set module state_dict
module_state_dict = {}
for k, v in state_dict.items():
if layer_name + "." in k:
module_state_dict[k.split(".")[-1]] = v
# We need this hack as the model is not pre-prepared as an empty skeleton on meta device
if self.pre_quantized:
if isinstance(module, HQQLinear):
return
else:
# Save them for later
if not hasattr(self, "hqq_params"):
self.hqq_params = defaultdict(dict)
self.hqq_params[module_name].update({tensor_name: param_value})
hqq_params = self.hqq_params[module_name]
# If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
# hqq does not support it...)
if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None,
@ -226,43 +197,32 @@ class HqqHfQuantizer(HfQuantizer):
device=target_device,
del_orig=False,
)
hqq_layer.load_state_dict(module_state_dict)
hqq_layer.load_state_dict(hqq_params)
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
setattr(parent_module, node, hqq_layer)
# cleanup
del module.__dict__, module
torch.cuda.empty_cache()
del self.hqq_params[module_name], module
return
# Step 1: populate module with weight/bias from module state dict
for key, tensor in module_state_dict.items():
setattr(module, key, torch.nn.Parameter(tensor))
# Load param in the module (without caring about device or dtype, it will be changed later)
module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
# directly doesn't work.
quant_config = model.config.quantization_config["quant_config"]
skip_modules = model.config.quantization_config["skip_modules"]
# If both the weight and bias have already been loaded, time to quantize!
module_is_ready = module.weight.device.type != "meta" and (
module.bias is None or module.bias.device.type != "meta"
)
if module_is_ready:
module_tag = ".".join(module.name.split(".")[-2:])
module_quant_config = None
if "weight_quant_params" in quant_config:
module_quant_config = quant_config
elif module_tag in quant_config:
module_quant_config = quant_config[module_tag]
for skip_module in skip_modules:
if skip_module in module.name:
module_quant_config = None
break
if module_quant_config is not None:
hqq_layer = HQQLinear(
module,
quant_config=module_quant_config,
@ -279,16 +239,7 @@ class HqqHfQuantizer(HfQuantizer):
setattr(parent_module, node, hqq_layer)
else:
module = module.to(dtype=self.dtype, device=target_device)
setattr(parent_module, node, module)
torch.cuda.empty_cache()
# Remove accelerate hook and uses a simpler forward pass. Otherwise, this breaks with multi-gpu
def _patch_layer_for_multigpu(self, hqq_layer):
hqq_layer = remove_hook_from_module(hqq_layer)
def forward_with_device(self, x):
out = torch.matmul(x.to(self.device), self.dequantize().t())
if self.bias is not None:

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
from .base import HfQuantizer
@ -153,14 +153,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
)
return dtype
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations import Mxfp4GptOssExperts
from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
@ -183,7 +176,6 @@ class Mxfp4HfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
from ..integrations import (

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from packaging import version
@ -103,26 +103,10 @@ class QuantoHfQuantizer(HfQuantizer):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
device_map = kwargs.get("device_map")
param_device = kwargs.get("param_device")
# we don't quantize the model if the module is going to be offloaded to the cpu
if device_map is not None and param_device is not None:
device_map_values = set(device_map.values())
if param_device == "cpu" and len(device_map_values) > 1:
if not (device_map_values == {"cpu"} or device_map_values == {"cpu", "disk"}):
return False
module, tensor_name = get_module_from_name(model, param_name)
# We only quantize the weights and the bias is not quantized.
if isinstance(module, QModuleMixin) and "weight" in tensor_name:
@ -141,15 +125,11 @@ class QuantoHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
*args,
**kwargs,
):
"""
Create the quantized parameter by calling .freeze() after setting it to the module.
"""
from accelerate.utils import set_module_tensor_to_device
from ..modeling_utils import _load_parameter_into_model
set_module_tensor_to_device(model, param_name, target_device, param_value)
_load_parameter_into_model(model, param_name, param_value.to(target_device))
module, _ = get_module_from_name(model, param_name)
module.freeze()
module.weight.requires_grad = False

View File

@ -13,23 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
from ..file_utils import is_torch_available
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
if is_torch_available():
import torch
from ..utils import is_quark_available, logging
from ..utils import is_accelerate_available, is_quark_available, logging
if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__)
@ -82,23 +75,18 @@ class QuarkHfQuantizer(HfQuantizer):
return model
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
return True
def create_quantized_param(self, model, param, param_name, param_device, state_dict) -> "torch.nn.Parameter":
def create_quantized_param(self, model, param, param_name, param_device, **kwargs):
from ..modeling_utils import _load_parameter_into_model
postfix = param_name.split(".")[-1]
if postfix in CHECKPOINT_KEYS:
param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
set_module_tensor_to_device(model, param_name, param_device, value=param)
_load_parameter_into_model(model, param_name, param.to(param_device))
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model

View File

@ -14,6 +14,7 @@
import importlib
import re
import types
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Union
from packaging import version
@ -25,10 +26,12 @@ from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from typing import Any
from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig
from ..utils import is_safetensors_available, is_torch_available, is_torchao_available, logging
if is_safetensors_available():
from safetensors import safe_open
if is_torch_available():
@ -64,15 +67,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
return None
# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent
def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
@ -113,6 +107,20 @@ class TorchAoHfQuantizer(HfQuantizer):
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
if isinstance(self.quantization_config.quant_type, str):
is_int_4 = "int4" in self.quantization_config.quant_type
else:
config_name = self.quantization_config.quant_type.__class__.__name__
is_int_4 = fuzzy_match_size(config_name) == "4"
# TODO: better way to get the serialized key names? Hard to read from torchao codebase
if is_int_4:
self.weight_ao_keys = ["qdata", "scale", "zero_point"]
else:
self.weight_ao_keys = ["qdata", "scale"]
# Instead of serializing plain torch.Tensors as usual, torchao adds a `:_data` suffix, so we need this
self.full_ao_keys = self.weight_ao_keys + ["_data"]
def validate_environment(self, *args, **kwargs):
if not is_torchao_available():
raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
@ -229,31 +237,25 @@ class TorchAoHfQuantizer(HfQuantizer):
]
return
def param_needs_quantization(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
if self.quantization_config.quant_type == "autoquant":
return False
param_device = kwargs.pop("param_device", None)
# check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
return False
elif param_device == "cpu" and self.offload:
# We don't quantize weights that we offload
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
return False
elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
return True
else:
# we only quantize the weight of nn.Linear and nn.Embedding
module, tensor_name = get_module_from_name(model, param_name)
_QUANTIZABLE = [torch.nn.Linear]
if self.quantization_config.include_input_output_embeddings:
_QUANTIZABLE.append(torch.nn.Embedding)
return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
def create_quantized_param(
self,
@ -261,29 +263,56 @@ class TorchAoHfQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Each nn.Linear layer that needs to be quantized is processed here.
First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
"""
if self.quantization_config.quant_type == "autoquant":
return
from torchao.quantization import quantize_
full_name = param_name
# Those are the pre quantized weights
if ":" in param_name:
param_name = param_name.rsplit(":", 1)[0]
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
# If it's a bias, no need to do anything special (except removing the ":_data" part of the key, which was
# already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
is_unsafe_serialization = ":" not in full_name
if tensor_name == "bias" or is_unsafe_serialization:
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(device=target_device), requires_grad=param_value.requires_grad
param_value.to(target_device), requires_grad=param_value.requires_grad
)
return
# Sanity check for the new serialization format
elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
# Save the states for later quantization when they are all gathered
if not hasattr(self, "ao_params"):
self.ao_params = defaultdict(dict)
self.ao_params[param_name].update({full_name: param_value})
# We are ready for quantization in this case (we retrieved all the needed keys)
if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
# Set it
module._parameters[tensor_name] = torch.nn.Parameter(
new_param.to(target_device), requires_grad=new_param.requires_grad
)
# Free memory
del self.ao_params[param_name]
# Add repr to the module
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
assert isinstance(self.quantization_config, TorchAoConfig)
module._parameters[tensor_name] = torch.nn.Parameter(
param_value, requires_grad=param_value.requires_grad
).to(device=target_device)
).to(target_device)
# if we are quantizing tied parameters, to avoid tying the quantized weights
# the correct order to do it is
# 1. load the weight to model
@ -313,16 +342,6 @@ class TorchAoHfQuantizer(HfQuantizer):
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
def update_state_dict_with_metadata(self, state_dict, metadata):
"""
If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
from the provided state_dict and metadata.
"""
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
return unflatten_tensor_state_dict(state_dict, metadata)
else:
return state_dict
def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
if self.quantization_config.quant_type == "autoquant":
@ -415,3 +434,13 @@ class TorchAoHfQuantizer(HfQuantizer):
@property
def is_compileable(self) -> bool:
return True
def set_metadata(self, checkpoint_files: list[str]):
if checkpoint_files[0].endswith(".safetensors") and is_safetensors_available():
metadata = {}
for checkpoint in checkpoint_files:
with safe_open(checkpoint, framework="pt") as f:
metadata_ = f.metadata() or {}
metadata.update(metadata_)
# Save it
self.metadata = metadata
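
Putting the torchao pieces together: `set_metadata` collects the safetensors metadata up front, each flattened `<weight>:qdata` / `<weight>:scale` (plus `:zero_point` for int4) tensor is buffered under its weight name, and once every component is present the subclass tensor is rebuilt with `unflatten_tensor_state_dict` and assigned to the module. A compressed sketch of that reassembly (illustrative; the bias and non-safetensors branches shown above are skipped):

from collections import defaultdict

weight_ao_keys = ["qdata", "scale"]  # plus "zero_point" for int4 configs
ao_params: dict[str, dict] = defaultdict(dict)


def on_torchao_param(param_name: str, param_value):
    # Safetensors-serialized subclass tensors arrive flattened as "<weight>:qdata", "<weight>:scale", ...
    full_name = param_name
    if ":" in param_name:
        param_name = param_name.rsplit(":", 1)[0]
    ao_params[param_name][full_name] = param_value
    # Once every component for this weight is buffered, the real quantizer rebuilds the subclass
    # tensor with the metadata stored by set_metadata():
    #   new_param = unflatten_tensor_state_dict(ao_params[param_name], self.metadata)[param_name]
    if len(ao_params[param_name]) == len(weight_ao_keys):
        return ao_params.pop(param_name)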