Fix: AWQ Marlin get_quant_method does not recognize "modules_to_not_convert" (#21888)

Signed-off-by: JunHowie <JunHowie@aliyun.com>
Co-authored-by: JunHowie <JunHowie@aliyun.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Jun-Howie
2025-08-12 17:03:53 +08:00
committed by GitHub
parent bc8372efc3
commit 1ece7f30ba

View File

@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
@ -141,6 +142,9 @@ class AWQMarlinConfig(QuantizationConfig):
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if is_layer_skipped_awq(
prefix, getattr(self, "modules_to_not_convert", [])):
return UnquantizedFusedMoEMethod(layer.moe_config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
@ -520,4 +524,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace)
workspace=layer.workspace)