[Hardware][XPU] AWQ/GPTQ support for xpu backend (#10107)

Signed-off-by: yan ma <yan.ma@intel.com>
Authored by Yan Ma, 2024-11-19 02:18:05 +08:00; committed by GitHub.
Parent 281cc4b3cd, commit 6b2d25efc7.
7 changed files with 146 additions and 52 deletions.

View File

@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations
- ✅︎
- ✅︎
- ✗
-
- ✅︎
- ✅︎
- ✗
- ✗
@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations
- ✅︎
- ✅︎
- ✗
-
-
- ✅︎
- ✅︎
- ✗
- ✗
* - Marlin (GPTQ/AWQ/FP8)
@ -129,4 +129,4 @@ Notes:
Please note that this compatibility chart may be subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods.
For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization>`_ or consult with the vLLM development team.

View File

@ -1,5 +1,5 @@
"""Test model set-up and inference for quantized HF models supported
on the CPU backend using IPEX (including AWQ).
on the CPU/XPU backend using IPEX (including AWQ/GPTQ).
Validating the configuration and printing results for manual checking.
@ -11,13 +11,15 @@ import pytest
from vllm.platforms import current_platform
MODELS = [
"casperhansen/llama-3-8b-instruct-awq",
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx
]
DTYPE = ["bfloat16"]
@pytest.mark.skipif(not current_platform.is_cpu(),
reason="only supports the CPU backend.")
@pytest.mark.skipif(not current_platform.is_cpu()
and not current_platform.is_xpu(),
reason="only supports Intel CPU/XPU backend.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
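The test body is cut off by the hunk above. Purely as a hypothetical sketch (the real assertions may differ), such a test would typically run a short greedy generation through the vllm_runner fixture and check that output is produced:

# Hypothetical sketch only; not the actual body from this commit.
def test_ipex_quant_sketch(vllm_runner, model, dtype):
    with vllm_runner(model, dtype=dtype) as vllm_model:
        output = vllm_model.generate_greedy(["The capital of France is"],
                                            max_tokens=32)
        assert output
        print(output)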

View File

@ -27,7 +27,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
]

View File

@ -210,7 +210,6 @@ class GPTQLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)

View File

@ -23,6 +23,7 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -134,6 +135,9 @@ class GPTQMarlinConfig(QuantizationConfig):
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
if not current_platform.is_cuda():
return False
if quant_method != "gptq":
return False

View File

@ -2,21 +2,26 @@ from typing import Any, Dict, List, Optional
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.platforms import current_platform
MIN_IPEX_VERSION = "2.5.0"
class IPEXConfig(QuantizationConfig):
"""INT8 quantization config class using IPEX for the CPU backend,
including AWQ.
"""INT8 quantization config class using IPEX for the CPU/XPU backend,
including AWQ, GPTQ.
"""
IPEX_QUANT_METHOD_MAP = {
"awq": 1,
"gptq": 2,
"gptq": 0,
}
def __init__(
@ -24,29 +29,30 @@ class IPEXConfig(QuantizationConfig):
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[List[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size
self.modules_to_not_convert = modules_to_not_convert or []
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
if self.weight_bits not in [4]:
raise ValueError(f"IPEX quantization supports weight bits [4], "
f"but got {self.weight_bits}.")
if self.method == "awq":
self.quant_method = IPEXAWQLinearMethod
else:
raise ValueError(f"IPEX quantization supports [awq], "
if self.method not in ["awq", "gptq"]:
raise ValueError(f"IPEX quantization supports [awq, gptq], "
f"but got {self.method}.")
def __repr__(self) -> str:
return (f"IPEXConfig(method={self.method}"
return (f"IPEXConfig(method={self.method},"
f"weight_bits={self.weight_bits}, "
f"group_size={self.group_size}")
def get_ipex_quant_method_id(self) -> int:
return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]
f"group_size={self.group_size})")
@classmethod
def get_name(cls) -> str:
@ -70,19 +76,32 @@ class IPEXConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower()
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
return cls(method, weight_bits, group_size)
if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config,
["q_group_size", "group_size"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(method, weight_bits, group_size, modules_to_not_convert,
False, False)
# otherwise for gptq
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
return cls(method, weight_bits, group_size, [], desc_act,
lm_head_quantized)
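For reference, a minimal sketch of the two quantization_config shapes this parser accepts; the key names come from the lookups above, while the concrete values are only illustrative:

# Illustrative configs; key names match the lookups above, values are made up.
awq_cfg = {
    "quant_method": "awq",
    "w_bit": 4,
    "q_group_size": 128,
    "modules_to_not_convert": ["lm_head"],
}
gptq_cfg = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": True,
    "lm_head": False,
}
# Both parse into IPEXConfig instances, e.g.:
#   IPEXConfig.from_config(awq_cfg)  -> method="awq",  weight_bits=4, group_size=128
#   IPEXConfig.from_config(gptq_cfg) -> method="gptq", weight_bits=4, group_size=128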
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
if not current_platform.is_cpu():
if not current_platform.is_cpu() and not current_platform.is_xpu():
return None
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if quant_method in ["awq"]:
if quant_method in ["awq", "gptq"]:
return cls.get_name()
return None
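Together with the is_cuda() guard added to GPTQMarlinConfig earlier in this commit, this override routes AWQ/GPTQ checkpoints to IPEX on Intel platforms. A self-contained sketch of that combined selection logic (simplified names, not vLLM's actual dispatcher):

# Simplified sketch; vLLM's real dispatch goes through QuantizationConfig classes.
def pick_backend(quant_method: str, platform: str) -> str:
    # gptq_marlin now declines on anything that is not CUDA.
    if platform == "cuda" and quant_method == "gptq":
        return "gptq_marlin"
    # IPEXConfig.override_quantization_method claims awq/gptq on Intel CPU/XPU.
    if platform in ("cpu", "xpu") and quant_method in ("awq", "gptq"):
        return "ipex"
    return quant_method

assert pick_backend("gptq", "xpu") == "ipex"
assert pick_backend("gptq", "cuda") == "gptq_marlin"
assert pick_backend("awq", "cpu") == "ipex"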
@ -90,12 +109,81 @@ class IPEXConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
return self.quant_method(self)
if self.method == "awq":
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return IPEXAWQLinearMethod(self)
if self.method == "gptq":
return IPEXGPTQLinearMethod(self)
return None
class IPEXGPTQLinearMethod(GPTQLinearMethod):
"""GPTQ linear method using IPEX for the CPU/XPU backend.
"""
def __init__(self, quant_config: IPEXConfig):
self.quant_config = quant_config # type: ignore
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
bias = layer.bias if not layer.skip_bias_add else None
try:
import intel_extension_for_pytorch as ipex
if ipex.__version__ < MIN_IPEX_VERSION:
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
except ImportError as err:
raise ImportError(
"Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method.") from err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# with better performance.
lowp_mode = ipex.quantization.WoqLowpMode.INT8
# The weight will be de-packed from INT4 to INT8.
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
# The float activation will be quantized (dynamic, per-token) to INT8.
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype,
lowp_mode=lowp_mode,
act_quant_mode=act_quant_mode,
group_size=self.quant_config.group_size,
)
layer.ipex_output_size = layer.qweight.shape[-1]
g_idx = layer.g_idx if self.quant_config.desc_act else None
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
IPEXWeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
g_idx=g_idx,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
if bias is not None:
out.add_(bias)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
class IPEXAWQLinearMethod(AWQLinearMethod):
"""AWQ linear method using IPEX for the CPU backend.
"""AWQ linear method using IPEX for the CPU/XPU backend.
"""
def __init__(self, quant_config: IPEXConfig):
@ -108,15 +196,16 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
try:
import intel_extension_for_pytorch as ipex
if ipex.__version__ < "2.4.0":
raise ImportError("intel_extension_for_pytorch version is "
"wrong. Please install "
"intel_extension_for_pytorch>=2.4.0.")
if ipex.__version__ < MIN_IPEX_VERSION:
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
except ImportError as err:
raise ImportError(
"Please install "
"intel_extension_for_pytorch>=2.4.0 via "
"`pip install intel_extension_for_pytorch>=2.4.0`"
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method.") from err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
@ -136,19 +225,18 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
layer.ipex_output_size = layer.qweight.size(
1) * self.quant_config.pack_factor
layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
WeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=
self.quant_config.get_ipex_quant_method_id() # type: ignore
)
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
IPEXWeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore
)
def apply(self,
layer: torch.nn.Module,
@ -156,5 +244,4 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
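Taken together, a user does not select IPEX explicitly: on an Intel CPU/XPU build, override_quantization_method rewrites the quantization method for AWQ/GPTQ checkpoints. A hedged usage sketch (model name taken from the test list above; prompt and sampling settings are illustrative):

from vllm import LLM, SamplingParams

# On an Intel CPU/XPU build, this GPTQ checkpoint is routed to
# IPEXGPTQLinearMethod automatically; no quantization flag is required.
llm = LLM(model="shuyuej/Llama-3.2-1B-Instruct-GPTQ", dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)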

View File

@ -29,6 +29,8 @@ from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
@ -348,7 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the