Explicitly explain quant method override ordering and ensure all overrides are ordered (#17256)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-28 17:55:31 +01:00
committed by GitHub
parent b6dd32aa07
commit c7941cca18
2 changed files with 39 additions and 9 deletions

View File

@ -28,6 +28,7 @@ import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
@ -767,12 +768,43 @@ class ModelConfig:
"compressed-tensors")
quant_cfg["quant_method"] = quant_method
# Quantization methods which are overrides (i.e. they have a
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides = [
"marlin",
"bitblas",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"ipex",
"moe_wna16",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
]
# Any custom overrides will be in quantization_methods so we place
# them at the start of the list so custom overrides have preference
# over the built in ones.
quantization_methods = quantization_methods + overrides
# Detect which checkpoint is it
for name in QUANTIZATION_METHODS:
for name in quantization_methods:
method = get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
if quantization_override is not None:
# Raise error if the override is not custom (custom would
# be in QUANTIZATION_METHODS but not QuantizationMethods)
# and hasn't been added to the overrides list.
if (name in get_args(QuantizationMethods)
and name not in overrides):
raise ValueError(
f"Quantization method {name} is an override but "
"is has not been added to the `overrides` list "
"above. This is necessary to ensure that the "
"overrides are checked in order of preference.")
quant_method = quantization_override
self.quantization = quantization_override
break

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Type
from typing import Literal, Type, get_args
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QUANTIZATION_METHODS: List[str] = [
QuantizationMethods = Literal[
"aqlm",
"awq",
"deepspeedfp",
@ -15,8 +15,6 @@ QUANTIZATION_METHODS: List[str] = [
"fbgemm_fp8",
"modelopt",
"nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
"bitblas",
"gguf",
@ -36,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [
"moe_wna16",
"torchao",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
@ -111,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: Dict[str, Type[QuantizationConfig]] = {
method_to_config: dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
@ -120,8 +119,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"bitblas": BitBLASConfig,
"gguf": GGUFConfig,
@ -150,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
__all__ = [
"QuantizationConfig",
"QuantizationMethods",
"get_quantization_config",
"QUANTIZATION_METHODS",
]