mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Fix FP-Quant quantization fallback CPU dispatch. (#41619)
* fp_quant fix * Update quantizer_fp_quant.py
This commit is contained in:
@ -97,6 +97,10 @@ class FPQuantHfQuantizer(HfQuantizer):
|
|||||||
):
|
):
|
||||||
module, _ = get_module_from_name(model, param_name)
|
module, _ = get_module_from_name(model, param_name)
|
||||||
|
|
||||||
|
if target_device == "cpu" and param_name.endswith("weight"):
|
||||||
|
# Works agains hard-coded missing key dispatch to CPU
|
||||||
|
return
|
||||||
|
|
||||||
# The module holds either:
|
# The module holds either:
|
||||||
# * `weight` when `store_master_weights=True`
|
# * `weight` when `store_master_weights=True`
|
||||||
# * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
|
# * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
|
||||||
|
@ -160,14 +160,14 @@ class FPQuantNVFP4PseudoquantTest(FPQuantBaseTest):
|
|||||||
class FPQuantMXFP4Test(FPQuantBaseTest):
|
class FPQuantMXFP4Test(FPQuantBaseTest):
|
||||||
@classmethod
|
@classmethod
|
||||||
def getQuantizationConfig(cls):
|
def getQuantizationConfig(cls):
|
||||||
return FPQuantConfig(forward_dtype="nvfp4", pseudoquantization=False)
|
return FPQuantConfig(forward_dtype="mxfp4", pseudoquantization=False)
|
||||||
|
|
||||||
|
|
||||||
@require_qutlass
|
@require_qutlass
|
||||||
class FPQuantMXFP4GS128Test(FPQuantBaseTest):
|
class FPQuantMXFP4GS128Test(FPQuantBaseTest):
|
||||||
@classmethod
|
@classmethod
|
||||||
def getQuantizationConfig(cls):
|
def getQuantizationConfig(cls):
|
||||||
return FPQuantConfig(forward_dtype="nvfp4", pseudoquantization=False, hadamard_group_size=128)
|
return FPQuantConfig(forward_dtype="mxfp4", pseudoquantization=False, hadamard_group_size=128)
|
||||||
|
|
||||||
|
|
||||||
@require_qutlass
|
@require_qutlass
|
||||||
|
Reference in New Issue
Block a user