mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix]fix mixed bits and visual language model quantization in AutoRound (#21802)
Signed-off-by: Wenhua Cheng <wenhua.cheng@intel.com>
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -16,6 +16,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -28,7 +31,13 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
SUPPORTED_DTYPES = {"int"}
|
||||
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
|
||||
SUPPORTED_BACKENDS = {
|
||||
"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex"
|
||||
"auto",
|
||||
"gptq",
|
||||
"gptq:marlin",
|
||||
"awq",
|
||||
"awq:marlin",
|
||||
"marlin",
|
||||
"ipex",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@ -109,26 +118,70 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
)
|
||||
|
||||
def get_layer_config(self, layer, layer_name: str):
|
||||
# Priority: extra_config > block_name_to_quantize > type fallback
|
||||
if self.extra_config and layer_name in self.extra_config:
|
||||
cfg = self.extra_config[layer_name]
|
||||
return cfg.get("bits", self.weight_bits), cfg.get(
|
||||
"group_size", self.group_size), cfg.get("sym", self.sym)
|
||||
|
||||
quantized = True
|
||||
def get_config(name: str, quantized: bool = True):
|
||||
cfg = self.extra_config.get(name, {}) if self.extra_config else {}
|
||||
return (
|
||||
cfg.get("bits", self.weight_bits if quantized else 16),
|
||||
cfg.get("group_size", self.group_size if quantized else -1),
|
||||
cfg.get("sym", self.sym if quantized else True),
|
||||
)
|
||||
|
||||
# 1. Exact match from config
|
||||
if self.extra_config and layer_name in self.extra_config:
|
||||
return get_config(layer_name)
|
||||
|
||||
# 2. Determine whether layer should be quantized
|
||||
quantized = not isinstance(layer, ParallelLMHead)
|
||||
if self.block_name_to_quantize:
|
||||
quantized = any(
|
||||
layer_name.startswith(name)
|
||||
for name in self.block_name_to_quantize)
|
||||
elif isinstance(layer, ParallelLMHead):
|
||||
quantized = False
|
||||
|
||||
return (self.weight_bits, self.group_size,
|
||||
self.sym) if quantized else (16, -1, True)
|
||||
# 3. Handle fused MoE
|
||||
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(
|
||||
):
|
||||
moe_configs = [
|
||||
get_config(name, quantized) for name in self.extra_config
|
||||
if name.startswith(layer_name)
|
||||
]
|
||||
if moe_configs:
|
||||
if len(set(moe_configs)) == 1:
|
||||
return moe_configs[0]
|
||||
raise ValueError(f"Fused MoE layer '{layer_name}' requires "
|
||||
f"consistent quant config for all sub-layers")
|
||||
|
||||
# 4. Handle fused QKV or other patterns
|
||||
if self.extra_config:
|
||||
for fusion_key, sub_keys in self.packed_modules_mapping.items():
|
||||
if fusion_key in layer_name and layer_name.count(
|
||||
fusion_key) == 1:
|
||||
sub_names = [
|
||||
layer_name.replace(fusion_key, sub_key)
|
||||
for sub_key in sub_keys
|
||||
]
|
||||
sub_configs = [
|
||||
get_config(name, quantized) for name in sub_names
|
||||
]
|
||||
if len(set(sub_configs)) == 1:
|
||||
return sub_configs[0]
|
||||
raise ValueError(
|
||||
f"Fused module '{layer_name}' requires "
|
||||
f"consistent quant config for {sub_names}")
|
||||
|
||||
# 5. Fallback
|
||||
return get_config(layer_name, quantized)
|
||||
|
||||
def check_quantized(self, weight_bits: int) -> bool:
|
||||
return weight_bits < 16
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.block_name_to_quantize is not None:
|
||||
self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.block_name_to_quantize)
|
||||
if self.extra_config is not None:
|
||||
self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
|
||||
|
||||
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@ -141,9 +194,14 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix, layer.__class__.__name__, weight_bits, group_size,
|
||||
sym)
|
||||
logger.debug(
|
||||
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix,
|
||||
layer.__class__.__name__,
|
||||
weight_bits,
|
||||
group_size,
|
||||
sym,
|
||||
)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
AWQ_TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
@ -162,15 +220,19 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
|
||||
quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
lm_head_quantized=False,
|
||||
full_config={},
|
||||
modules_to_not_convert=[])
|
||||
|
||||
quant_args_marlin = AWQMarlinConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
lm_head_quantized=False,
|
||||
full_config={},
|
||||
modules_to_not_convert=[],
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.awq import (
|
||||
AWQConfig, AWQLinearMethod)
|
||||
|
||||
quant_args = AWQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
@ -182,6 +244,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
return AWQMoEMethod(quant_args_marlin)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
|
||||
config = {
|
||||
"quant_method": "awq",
|
||||
"bits": weight_bits,
|
||||
@ -206,6 +269,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer)
|
||||
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
@ -213,19 +277,24 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix, layer.__class__.__name__, weight_bits, group_size,
|
||||
sym)
|
||||
logger.debug(
|
||||
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix,
|
||||
layer.__class__.__name__,
|
||||
weight_bits,
|
||||
group_size,
|
||||
sym,
|
||||
)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
GPTQ_TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
|
||||
and check_marlin_supported(
|
||||
use_marlin = (weight_bits,
|
||||
sym) in GPTQ_TYPE_MAP and check_marlin_supported(
|
||||
GPTQ_TYPE_MAP[(weight_bits, sym)],
|
||||
group_size,
|
||||
has_zp=not sym))
|
||||
has_zp=not sym)
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size)
|
||||
@ -234,26 +303,33 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
|
||||
quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
is_sym=sym,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
full_config={})
|
||||
|
||||
quant_args_marlin = GPTQMarlinConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
is_sym=sym,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
full_config={},
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.gptq import (
|
||||
GPTQConfig, GPTQLinearMethod)
|
||||
quant_args = GPTQConfig(weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={})
|
||||
|
||||
quant_args = GPTQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": weight_bits,
|
||||
@ -282,6 +358,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
return None
|
||||
from vllm.model_executor.layers.quantization.ipex_quant import (
|
||||
IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if "awq" in self.packing_format:
|
||||
config = IPEXConfig(method="awq",
|
||||
|
Reference in New Issue
Block a user