[Bugfix] Respect modules_to_not_convert within awq_marlin (#9895)
Signed-off-by: mgoin <michael@neuralmagic.com>
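
Before this change, the awq_marlin backend ignored the `modules_to_not_convert` list that AWQ checkpoints can carry in their quantization config, so layers the quantizer had deliberately left in full precision were still routed through the Marlin kernels. The diff below (against `vllm/model_executor/layers/quantization/awq_marlin.py`) threads that list through `AWQMarlinConfig` and falls back to an unquantized linear method for any matching layer. As a hedged illustration of the kind of checkpoint config this is meant to honor (the module prefix is invented, not taken from the commit):

```python
# Illustrative AWQ quantization_config as stored in a checkpoint's
# config.json. "modules_to_not_convert" names submodules that were left
# unquantized and must be served with an unquantized linear method.
quant_config = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    "modules_to_not_convert": ["visual"],  # hypothetical module prefix
}
```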
@@ -9,7 +9,9 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }

-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []

         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
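
Besides accepting the new field, this hunk renames `has_zp` to `zero_point` so the attribute matches the checkpoint key it is read from, and normalizes a missing list to `[]`. A minimal usage sketch of the new signature, with illustrative values:

```python
# Constructing the config directly; modules_to_not_convert defaults to
# None, so existing call sites that omit it keep working.
cfg = AWQMarlinConfig(weight_bits=4,
                      group_size=128,
                      zero_point=True,
                      lm_head_quantized=False,
                      modules_to_not_convert=["visual"])  # hypothetical
assert cfg.pack_factor == 8  # 32 // 4 four-bit weights per int32
assert cfg.modules_to_not_convert == ["visual"]
```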
@@ -52,13 +59,14 @@ class AWQMarlinConfig(QuantizationConfig):

         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)

     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")

     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ class AWQMarlinConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
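
`from_config` now also reads `modules_to_not_convert`, defaulting to `None` when the key is absent. A hedged sketch of the round trip, assuming `get_from_keys_or` returns its default for missing keys (as its use with `lm_head` above suggests):

```python
# Checkpoint config without the optional key: the constructor receives
# None and normalizes it to [] via `modules_to_not_convert or []`.
hf_cfg = {"bits": 4, "group_size": 128, "zero_point": True}
cfg = AWQMarlinConfig.from_config(hf_cfg)
assert cfg.modules_to_not_convert == []
```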
@@ -109,6 +120,8 @@ class AWQMarlinConfig(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
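
This hunk is the actual fix: before handing a linear layer to Marlin, its prefix is checked against `modules_to_not_convert`. A minimal sketch of the imported helper, assuming it performs the plain prefix-membership test its call site implies (the real implementation lives in `vllm/model_executor/layers/quantization/awq.py`):

```python
from typing import List

def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: List[str]) -> bool:
    # A layer is skipped when any not-to-convert entry occurs in its
    # dotted module prefix, e.g. "visual" in "model.visual.blocks.0.qkv".
    return any(module_name in prefix
               for module_name in modules_to_not_convert)
```

With `modules_to_not_convert=["visual"]`, a layer prefixed `model.visual.blocks.0.attn.qkv` now receives `UnquantizedLinearMethod()` instead of `AWQMarlinLinearMethod(self)`.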
@@ -123,7 +136,7 @@ class AWQMarlinConfig(QuantizationConfig):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")

         if not current_platform.is_cuda():
             return False
@@ -132,7 +145,7 @@ class AWQMarlinConfig(QuantizationConfig):
             return False

         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or has_zp is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False

         if num_bits not in cls.TYPE_MAP:
@@ -140,7 +153,7 @@ class AWQMarlinConfig(QuantizationConfig):

         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)


 class AWQMarlinLinearMethod(LinearMethodBase):
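
The last three hunks finish the `has_zp` to `zero_point` rename inside the class-level compatibility check (the `has_zp=` keyword survives because `check_marlin_supported` itself is unchanged). An illustrative call into that classmethod, assuming it is the `is_awq_marlin_compatible` check this file exposes; the result still depends on the CUDA platform and capability tests above:

```python
compatible = AWQMarlinConfig.is_awq_marlin_compatible({
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
})
```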