Quantization: support FP4 quantized models on AMD CDNA2/CDNA3 GPUs (#22527)
Signed-off-by: feng <fengli1702@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
-conch-triton-kernels==1.2.1
+conch-triton-kernels==1.2.1
setup.py
@@ -695,6 +695,8 @@ setup(
         "video": [],  # Kept for backwards compatibility
         # FlashInfer should be updated together with the Dockerfile
         "flashinfer": ["flashinfer-python==0.2.12"],
+        # Optional deps for AMD FP4 quantization support
+        "petit-kernel": ["petit-kernel"],
     },
     cmdclass=cmdclass,
     package_data=package_data,
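As a usage note, the new extra should let the optional kernel be pulled in alongside vLLM, e.g. `pip install "vllm[petit-kernel]"`; installing it directly with `pip install petit-kernel` (the command suggested by the error message in petit_utils.py below) should work as well.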
@@ -1119,9 +1119,20 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = me_quant.QUANTIZATION_METHODS
         optimized_quantization_methods = [
-            "fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
-            "fbgemm_fp8", "compressed-tensors", "experts_int8", "quark",
-            "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
+            "fp8",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed-tensors",
+            "experts_int8",
+            "quark",
+            "modelopt_fp4",
+            "bitblas",
+            "gptq_bitblas",
+            "inc",
+            "petit_nvfp4",
         ]
         if self.quantization is not None:
             self.quantization = cast(me_quant.QuantizationMethods,
@@ -1153,6 +1164,7 @@ class ModelConfig:
             "moe_wna16",
             "modelopt",
             "modelopt_fp4",
+            "petit_nvfp4",
         ]
         quantization_methods = [
             q for q in supported_quantization if q not in overrides
@@ -52,6 +52,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "HQQMarlinMethod",
     "QuarkLinearMethod",
     "ModelOptNvFp4LinearMethod",
+    "PetitNvFp4LinearMethod",
 ]
@@ -35,6 +35,7 @@ QuantizationMethods = Literal[
     "rtn",
     "inc",
     "mxfp4",
+    "petit_nvfp4",
 ]
 QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .moe_wna16 import MoeWNA16Config
     from .mxfp4 import Mxfp4Config
     from .neuron_quant import NeuronQuantConfig
+    from .petit import PetitNvFp4Config
     from .ptpc_fp8 import PTPCFp8Config
     from .rtn import RTNConfig
     from .torchao import TorchAOConfig
@@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "rtn": RTNConfig,
         "inc": INCConfig,
         "mxfp4": Mxfp4Config,
+        "petit_nvfp4": PetitNvFp4Config,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
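For reference, a minimal sketch (assuming a vLLM build that includes this change) of how the new method resolves through the registry updated above:

from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

# "petit_nvfp4" is now part of the Literal-derived method list.
assert "petit_nvfp4" in QUANTIZATION_METHODS
config_cls = get_quantization_config("petit_nvfp4")
print(config_cls.__name__)  # -> PetitNvFp4Config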
vllm/model_executor/layers/quantization/petit.py (new file, 306 lines)
@@ -0,0 +1,306 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

from typing import Any, Optional

import regex as re
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.petit_utils import (
    apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit,
    verify_petit_nvfp4_supported)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform

# Initialize logger for the module
logger = init_logger(__name__)


# Configuration class to support the NVFP4 quantized model
# generated by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        self._check_hardware_support()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning("Detected nvfp4 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    def _check_hardware_support(self) -> None:
        """
        Verifies that the current hardware is supported by the Petit backend.
        This backend is specifically designed for AMD GPUs and is not
        supported on the CUDA platform.
        """
        # This check ensures the code is NOT running on an NVIDIA GPU.
        if current_platform.is_cuda():
            raise ValueError(
                "The 'petit' quantization backend is designed for AMD GPUs "
                "and is not supported on the CUDA platform. For NVIDIA GPUs, "
                "please use a different quantization method such as FP8, AWQ, "
                "or GPTQ.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Petit supports the gfx90a and gfx942 GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
        qc = cls.get_from_keys(config, ["quantization"])

        quant_method_raw = qc.get("quant_algo")
        if not isinstance(quant_method_raw, str) or not quant_method_raw:
            raise ValueError(
                "Missing or invalid 'quant_algo' in quantization config.")
        quant_method = quant_method_raw.upper()

        group_size_raw = qc.get("group_size")
        if not isinstance(group_size_raw, int):
            raise ValueError(
                "Missing or invalid 'group_size' (int) in hf_quant_config.json."
            )
        group_size = group_size_raw

        verify_petit_nvfp4_supported(quant_method, group_size)

        kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
        if not isinstance(kv_cache_quant_algo_raw, str):
            raise ValueError(
                "'kv_cache_quant_algo' must be a string if provided.")
        kv_cache_quant_algo = kv_cache_quant_algo_raw

        exclude_raw = qc.get("exclude_modules", [])
        if exclude_raw is None:
            exclude_modules: list[str] = []
        elif isinstance(exclude_raw, list) and all(
                isinstance(x, str) for x in exclude_raw):
            exclude_modules = exclude_raw
        else:
            raise ValueError(
                "'exclude_modules' must be a list[str] (or omitted).")

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        return cls(
            is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo=kv_cache_quant_algo,
            group_size=group_size,
            exclude_modules=exclude_modules,
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        if not current_platform.is_rocm():
            return None

        qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
            return cls.get_name()  # "petit_nvfp4"
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
        qc = quant_config.get("quantization", quant_config)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        return algo == "NVFP4"

    def is_layer_excluded(self, prefix: str,
                          exclude_modules: list[str]) -> bool:
        for pattern in exclude_modules:
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        exclude = self.require_exclude_modules()

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
                    prefix, exclude):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return PetitFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> list[str]:
        return []

    def require_group_size(self) -> int:
        if self.group_size is None:
            logger.warning("group_size not set; defaulting to 16 for NVFP4.")
            return 16
        return self.group_size

    def require_kv_cache_quant_algo(self) -> str:
        return self.kv_cache_quant_algo or "auto"

    def require_exclude_modules(self) -> list[str]:
        return list(self.exclude_modules or [])


class PetitFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        super().__init__(quant_config)
class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The Petit NVFP4 quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, but dynamic "
                             "quantization is not supported.")

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if input_size_per_partition % 16 != 0:
            raise ValueError("Unsupported model: input feature size must be "
                             "a multiple of 16.")
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)

        weight = ModelWeightParameter(
            data=torch.empty(
                # Two fp4 values are packed into one uint8 along the
                # input dimension.
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        group_size = self.quant_config.require_group_size()
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        prepare_nvfp4_layer_for_petit(layer)
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
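For illustration, a hedged sketch of the hf_quant_config.json shape that PetitNvFp4Config.from_config reads, expressed as the Python dict it would receive. The field values and the exclude pattern are illustrative assumptions, and the constructor refuses to run on the CUDA platform by design, so this is expected to work on a ROCm (or CPU) host only:

from vllm.model_executor.layers.quantization.petit import PetitNvFp4Config

# Mirrors the keys checked in from_config(); values are illustrative.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",          # must be NVFP4 for Petit
        "group_size": 16,               # only 16 is accepted
        "kv_cache_quant_algo": "auto",  # optional, defaults to "auto"
        "exclude_modules": ["lm_head"], # hypothetical exclude pattern
    }
}

cfg = PetitNvFp4Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_nvfp4_serialized and cfg.group_size == 16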
vllm/model_executor/layers/quantization/utils/petit_utils.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional

import torch

# TYPE_CHECKING is used for static type analysis to prevent circular imports.
if TYPE_CHECKING:
    from types import ModuleType

# 1. Create a global variable as a placeholder for the module
_petit_kernel: Optional["ModuleType"] = None

_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with "
                      "`pip install petit-kernel`.")


def _import_petit_kernel() -> "ModuleType":
    """
    A helper function to handle the lazy import.
    The first time this function is called, it will import the petit_kernel
    library and store it in the global _petit_kernel variable.
    Subsequent calls will return the already-loaded module directly.
    """
    global _petit_kernel
    if _petit_kernel is not None:
        return _petit_kernel

    try:
        import petit_kernel
        _petit_kernel = petit_kernel
        return _petit_kernel
    except ImportError:
        # The 'from None' syntax prevents chaining the original ImportError,
        # making the traceback cleaner.
        raise ImportError(_PETIT_INSTALL_MSG) from None


# The _require_petit function can now be a simple alias for consistency.
_require_petit = _import_petit_kernel
def _check_petit_nvfp4_supported(
        quant_method: str,
        group_size: Optional[int]) -> tuple[bool, Optional[str]]:
    if quant_method != "NVFP4":
        return (
            False,
            ("Petit currently only supports NVFP4 quantization. Please check "
             "the `hf_quant_config.json` file for your model's quant "
             "configuration."),
        )
    if group_size is not None and group_size != 16:
        return (
            False,
            "Petit currently only supports group_size=16 quantization.",
        )
    return (True, None)
def verify_petit_nvfp4_supported(quant_method: str,
                                 group_size: Optional[int]) -> None:
    supported, error_msg = _check_petit_nvfp4_supported(
        quant_method, group_size)
    if not supported:
        assert error_msg is not None
        raise ValueError(error_msg)


def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
    # 2. Call _import_petit_kernel() to trigger (or get) the import.
    petit_kernel = _import_petit_kernel()

    # Repack weights to petit format
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    qweight = layer.weight.view(torch.int32).contiguous()

    # 3. Call functions through the imported module variable.
    petit_qweight = petit_kernel.repack_nvfp4(qweight,
                                              size_n=part_size_n,
                                              size_k=part_size_k)
    layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)

    # Permute scales
    weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale,
                                                     size_k=part_size_k,
                                                     size_n=part_size_n)
    layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)


def apply_petit_nvfp4_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Trigger (or get) the import here as well.
    petit_kernel = _import_petit_kernel()

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    # TODO: Use auto-tuning to find the performant solution_id
    # Call the function via the module variable.
    output = petit_kernel.mul_nvfp4_a16(
        a=reshaped_x,
        b=weight,
        s=weight_scale,
        global_scale=weight_scale_2,
        size_m=reshaped_x.size(0),
        size_n=size_n,
        size_k=size_k,
        solution_id=-1,
    )
    if bias is not None:
        output.add_(bias)  # In-place add
    return output.reshape(out_shape)
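A small usage sketch of the validation helper above; this path does not require petit-kernel to be installed, since the lazy import is only triggered inside the prepare/apply helpers:

from vllm.model_executor.layers.quantization.utils.petit_utils import (
    verify_petit_nvfp4_supported)

verify_petit_nvfp4_supported("NVFP4", 16)    # OK: supported combination
try:
    verify_petit_nvfp4_supported("FP8", 16)  # unsupported quant_algo
except ValueError as e:
    print(e)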
@@ -171,7 +171,7 @@ class RocmPlatform(Platform):

     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
-        "quark", "ptpc_fp8", "mxfp4"
+        "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4"
     ]

     @classmethod
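With the ROCm support list extended above, an NVFP4 checkpoint can be targeted explicitly. A hedged end-to-end sketch (the model path is a placeholder, petit-kernel must be installed, and a gfx90a/gfx942 GPU is assumed):

from vllm import LLM, SamplingParams

# Placeholder model path; any ModelOpt-style NVFP4 checkpoint shipping an
# hf_quant_config.json is the intended input here.
llm = LLM(model="path/to/nvfp4-checkpoint", quantization="petit_nvfp4")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)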