[Bugfix] Respect modules_to_not_convert within awq_marlin (#9895)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author:    Michael Goin
Date:      2024-11-04 18:57:44 -05:00 (committed by GitHub)
Parent:    2094062b4e
Commit:    8f0a9ca890

@@ -9,7 +9,9 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }

-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []

         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -52,13 +59,14 @@ class AWQMarlinConfig(QuantizationConfig):

         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)

     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")

     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ class AWQMarlinConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
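
For context, from_config() reads these keys from the checkpoint's HuggingFace quantization_config. A minimal sketch of the flow, assuming a hypothetical config that leaves a vision tower unquantized (the module names below are illustrative, not taken from this commit):

    # Hypothetical "quantization_config" as it might appear in config.json;
    # "modules_to_not_convert" lists module-name fragments to keep unquantized.
    hf_quant_cfg = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
        "modules_to_not_convert": ["visual", "mlp.gate"],  # illustrative names
    }

    quant_config = AWQMarlinConfig.from_config(hf_quant_cfg)
    # quant_config.zero_point == True, quant_config.lm_head_quantized == False
    # quant_config.modules_to_not_convert == ["visual", "mlp.gate"]
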
@@ -109,6 +120,8 @@ class AWQMarlinConfig(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
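
The skip check reuses the helper already used by the plain AWQ backend (imported above from quantization.awq). Its behavior is essentially a substring match of the layer's prefix against the configured module names; a rough sketch of the idea, not the verbatim vLLM source:

    from typing import List

    def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]) -> bool:
        # A layer is skipped when any configured module-name fragment occurs in
        # its fully qualified prefix, e.g. "visual" matches "visual.blocks.0.attn.qkv".
        return any(module_name in prefix for module_name in modules_to_not_convert)
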
@@ -123,7 +136,7 @@ class AWQMarlinConfig(QuantizationConfig):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")

         if not current_platform.is_cuda():
             return False
@@ -132,7 +145,7 @@ class AWQMarlinConfig(QuantizationConfig):
             return False

         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or has_zp is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False

         if num_bits not in cls.TYPE_MAP:
@@ -140,7 +153,7 @@ class AWQMarlinConfig(QuantizationConfig):

         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)


 class AWQMarlinLinearMethod(LinearMethodBase):
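
Taken together, layers whose prefix matches an entry in modules_to_not_convert now fall back to the unquantized linear path, while everything else keeps using Marlin. A hedged end-to-end illustration (linear_layer stands for an already constructed LinearBase module; the prefix is hypothetical):

    quant_config = AWQMarlinConfig(weight_bits=4,
                                   group_size=128,
                                   zero_point=True,
                                   lm_head_quantized=False,
                                   modules_to_not_convert=["visual"])

    # For a layer under a skipped prefix the new branch returns
    # UnquantizedLinearMethod; otherwise AWQMarlinLinearMethod as before.
    method = quant_config.get_quant_method(linear_layer,
                                           prefix="visual.blocks.0.attn.qkv")
    assert isinstance(method, UnquantizedLinearMethod)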