Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Quant][Perf] Use moe_wna16 kernel by default for MoEs with many experts (#13236)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -12,7 +12,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME",
                             "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
 REVISION = os.environ.get("REVISION", "main")
 QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
-MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
+MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")


 @pytest.mark.skipif(
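The only substantive change above is the default MIN_CAPABILITY dropping from 89 to 80, so the gated test also runs on SM80 (Ampere) GPUs. Below is a minimal sketch of how such an env-var threshold typically feeds the `@pytest.mark.skipif(` guard that follows it; the helper `capability_at_least` is hypothetical and not part of vLLM or this diff.

```python
import os

import pytest
import torch

MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")


def capability_at_least(min_capability: str) -> bool:
    # Hypothetical helper: encode (major, minor) as e.g. (8, 0) -> 80 and compare.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= int(min_capability)


@pytest.mark.skipif(not capability_at_least(MIN_CAPABILITY),
                    reason="Requires compute capability >= MIN_CAPABILITY")
def test_weight_loading_smoke():
    # Placeholder body; the real test loads MODEL_NAME with QUANTIZATION.
    assert True
```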
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
@@ -134,7 +135,12 @@ class AWQMarlinConfig(QuantizationConfig):
                     self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return AWQMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return AWQMoEMethod(self)
         return None

     @classmethod
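This is the behavioural change in `AWQMarlinConfig.get_quant_method`: a `FusedMoE` layer with more than 32 experts is now handed to a `MoeWNA16Config` built from the same `full_config`, which selects the moe_wna16 fused-MoE kernel, while smaller MoEs keep `AWQMoEMethod`. A self-contained sketch of the routing rule with stand-in names (nothing below is vLLM's real API):

```python
from dataclasses import dataclass

EXPERT_THRESHOLD = 32  # cutoff introduced by this commit


@dataclass
class ToyMoELayer:
    num_experts: int


def select_moe_backend(layer: ToyMoELayer) -> str:
    """Mirror of the new dispatch rule, using plain strings as stand-ins."""
    if layer.num_experts > EXPERT_THRESHOLD:
        # For MoEs with many experts the moe_wna16 kernel is faster
        return "moe_wna16"
    return "awq_marlin_moe"


assert select_moe_backend(ToyMoELayer(num_experts=8)) == "awq_marlin_moe"
assert select_moe_backend(ToyMoELayer(num_experts=64)) == "moe_wna16"
```

Models with only a handful of experts stay on the Marlin MoE path; the 32-expert cutoff is the heuristic this commit adds.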
@@ -10,20 +10,18 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
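The import list shrinks here because `get_quant_method` (see the later hunk) is re-annotated with the shared base class `QuantizeMethodBase` instead of a `Union` of concrete method types, so `UnquantizedLinearMethod` and `UnquantizedEmbeddingMethod` appear to be needed only for the old annotation. A generic illustration of that simplification with toy classes (not vLLM's):

```python
from typing import Optional, Union


class QuantizeMethodBase:  # toy stand-in for the shared base class
    pass


class MarlinLinearMethod(QuantizeMethodBase):
    pass


class MarlinMoEMethod(QuantizeMethodBase):
    pass


# Before: the annotation enumerates every concrete subtype it may return.
def get_quant_method_old(kind: str) -> Optional[Union[MarlinLinearMethod,
                                                       MarlinMoEMethod]]:
    return MarlinMoEMethod() if kind == "moe" else MarlinLinearMethod()


# After: annotate with the base class; new return types need no signature change.
def get_quant_method_new(kind: str) -> Optional[QuantizeMethodBase]:
    return MarlinMoEMethod() if kind == "moe" else MarlinLinearMethod()
```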
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }

-    def __init__(
-        self,
-        weight_bits: int,
-        group_size: int,
-        desc_act: bool,
-        is_sym: bool,
-        lm_head_quantized: bool,
-        dynamic: Dict[str, Dict[str, Union[int, bool]]],
-    ) -> None:
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 is_sym: bool, lm_head_quantized: bool,
+                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
+                 full_config: Dict[str, Any]) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -90,6 +83,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         self.group_size = group_size
         self.desc_act = desc_act
         self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config

         if (weight_bits, is_sym) not in self.TYPE_MAP:
             raise ValueError("Unsupported quantization config: "
@@ -132,7 +126,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized, dynamic)
+                   lm_head_quantized, dynamic, config)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
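Together, the three hunks above thread the raw quantization config through `GPTQMarlinConfig`: the constructor gains a `full_config` parameter, `__init__` stores it, and `from_config` now forwards the original dict. Keeping the untouched dict around is what later allows a `MoeWNA16Config` to be rebuilt from it. A toy sketch of the pattern (class name and keys below are illustrative only):

```python
from typing import Any, Dict


class ToyMarlinConfig:
    def __init__(self, weight_bits: int, group_size: int,
                 full_config: Dict[str, Any]) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        # Keep the raw dict so another config class can be built from it later.
        self.full_config = full_config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyMarlinConfig":
        return cls(config["bits"], config["group_size"], config)


cfg = ToyMarlinConfig.from_config({"bits": 4, "group_size": 128, "sym": True})
assert cfg.full_config["sym"] is True  # the original dict survives intact
```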
@@ -155,12 +149,15 @@ class GPTQMarlinConfig(QuantizationConfig):
                         " faster inference")
         return None

-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
-                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)

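`GPTQMarlinConfig.get_quant_method` now mirrors the AWQ change: MoE layers with more than 32 experts are delegated to `MoeWNA16Config` built from the stored `full_config`, everything else keeps its previous path. From the caller's side nothing changes; a hedged usage sketch follows (the model name is a placeholder, and the fast path only triggers if the checkpoint's MoE layers exceed 32 experts):

```python
from vllm import LLM, SamplingParams

# Any GPTQ-quantized MoE checkpoint; the name is a placeholder, not from the diff.
llm = LLM(model="some-org/some-gptq-moe-model")

# With this commit, FusedMoE layers with > 32 experts are served by the
# moe_wna16 kernel automatically; dense linear layers still use GPTQ-Marlin.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```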
@@ -9,13 +9,8 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supports_layer)
 from vllm.model_executor.utils import set_weight_attrs
@@ -37,6 +32,12 @@ class MoeWNA16Config(QuantizationConfig):
         self.linear_quant_method = linear_quant_method
         self.full_config = full_config
         self.use_marlin = False
+        # Avoid circular import
+        from vllm.model_executor.layers.quantization.awq import AWQConfig
+        from vllm.model_executor.layers.quantization.awq_marlin import (
+            AWQMarlinConfig)
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinConfig)
         if self.linear_quant_method == "gptq":
             self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
                 full_config)
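Because `awq_marlin.py` and `gptq_marlin.py` now import `MoeWNA16Config` at module level, `moe_wna16.py` can no longer import their config classes at module level in return; the hunks above and below move those imports inside the methods with an "Avoid circular import" comment. A generic, runnable illustration of the deferred-import idea, using a stdlib module as a stand-in:

```python
def build_decoder():
    # Deferred (function-local) import: the "json" module is resolved only
    # when this function first runs, not when the enclosing module is
    # imported. moe_wna16.py now uses the same trick for AWQConfig,
    # AWQMarlinConfig, GPTQConfig and GPTQMarlinConfig, since a module-level
    # import in both directions would create an import cycle.
    import json
    return json.JSONDecoder()


print(type(build_decoder()).__name__)  # -> JSONDecoder
```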
@@ -115,6 +116,8 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = current_platform.get_device_capability()
         device_capability = (-1 if capability_tuple is None else
                              capability_tuple.to_int())
+        # Avoid circular import
+        from vllm.model_executor.layers.quantization.awq import AWQConfig
         awq_min_capability = AWQConfig.get_min_capability()

         gptq_compatible = quant_method == "gptq" and \
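`is_moe_wna16_compatible` compares the device capability, collapsed to a single integer, against each linear method's minimum capability. A toy sketch of that encoding and comparison, assuming the usual `major * 10 + minor` convention behind `to_int()` (the threshold values below are examples, not taken from the diff):

```python
from typing import Optional, Tuple


def capability_to_int(cap: Optional[Tuple[int, int]]) -> int:
    # (major, minor) -> major * 10 + minor, e.g. (8, 0) -> 80; -1 if unknown.
    if cap is None:
        return -1
    major, minor = cap
    return major * 10 + minor


AWQ_MIN_CAPABILITY = 75   # example threshold, not taken from the diff
GPTQ_MIN_CAPABILITY = 60  # example threshold, not taken from the diff

device_capability = capability_to_int((8, 0))  # e.g. an A100 (SM80)
awq_compatible = device_capability >= AWQ_MIN_CAPABILITY
gptq_compatible = device_capability >= GPTQ_MIN_CAPABILITY
print(awq_compatible, gptq_compatible)  # True True
```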
@@ -129,6 +132,13 @@ class MoeWNA16Config(QuantizationConfig):
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
             return UnquantizedLinearMethod()
         elif isinstance(layer, LinearBase):
+            # Avoid circular import
+            from vllm.model_executor.layers.quantization.awq import AWQConfig
+            from vllm.model_executor.layers.quantization.awq_marlin import (
+                AWQMarlinConfig)
+            from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+            from vllm.model_executor.layers.quantization.gptq_marlin import (
+                GPTQMarlinConfig)
             if self.linear_quant_method == "gptq":
                 if self.use_marlin:
                     return GPTQMarlinConfig.from_config(
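This final hunk shows the other half of the split: `MoeWNA16Config` owns only the fused-MoE path, while plain `LinearBase` layers are routed back to a GPTQ/AWQ (Marlin) config rebuilt from the same `full_config`, which is why those imports reappear locally here. A toy sketch of that two-way dispatch (all classes below are stand-ins, not vLLM's):

```python
from typing import Any, Dict, Optional


class ToyLinearLayer:
    pass


class ToyMoELayer:
    pass


class ToyGPTQMarlinConfig:
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyGPTQMarlinConfig":
        return cls()

    def get_quant_method(self, layer, prefix: str) -> str:
        return "gptq_marlin linear method"


class ToyMoeWNA16Config:
    def __init__(self, full_config: Dict[str, Any]) -> None:
        self.full_config = full_config

    def get_quant_method(self, layer, prefix: str) -> Optional[str]:
        if isinstance(layer, ToyLinearLayer):
            # Dense layers go back to the (Marlin) linear method, rebuilt
            # from the same raw config dict.
            return ToyGPTQMarlinConfig.from_config(
                self.full_config).get_quant_method(layer, prefix)
        if isinstance(layer, ToyMoELayer):
            return "moe_wna16 fused-MoE method"
        return None


cfg = ToyMoeWNA16Config({"bits": 4, "group_size": 128})
print(cfg.get_quant_method(ToyLinearLayer(), prefix=""))  # gptq_marlin linear method
print(cfg.get_quant_method(ToyMoELayer(), prefix=""))     # moe_wna16 fused-MoE method
```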