vllm-dev/vllm/model_executor/layers/quantization/ptpc_fp8.py


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
                                                         Fp8KVCacheMethod,
                                                         Fp8LinearMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp)
from vllm.platforms import current_platform

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class PTPCFp8Config(Fp8Config):
    """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""

    def __init__(
        self,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
    ) -> None:
        if not current_platform.is_rocm():
            raise ValueError(
                "ptpc_fp8 quantization is supported only on ROCm.")
        if not current_platform.has_device_capability(94):
            raise ValueError(
                "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer."  # noqa: E501
            )
        if activation_scheme == "static":
            raise ValueError(
                "ptpc_fp8 currently supports only dynamic quantization.")
        super().__init__(is_checkpoint_fp8_serialized=False,
                         activation_scheme=activation_scheme,
                         ignored_layers=ignored_layers)

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "ptpc_fp8"

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)
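
    # Editorial sketch (values are hypothetical, not from the original file):
    # from_config expects a mapping such as a checkpoint's quantization_config
    # section, e.g.
    #     {"activation_scheme": "dynamic", "ignored_layers": ["lm_head"]}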

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return PTPCFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class PTPCFp8LinearMethod(Fp8LinearMethod):
    """Linear method for Per-Token and Per-Channel FP8 Quantization.

    Only supports loading BF16 model checkpoints with dynamic activation
    scaling. To load FP16 model checkpoints, the user must request that the
    FP16 weights be converted to BF16 at load time. The weight scaling factor
    is initialized after the model weights are loaded.

    Limitations:
    1. Only supports the float8_e4m3fnuz data type, due to a limitation of
       torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: PTPCFp8Config):
        super().__init__(quant_config=quant_config)
        # Force weight quantization
        self.quant_config.is_checkpoint_fp8_serialized = False
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
            force_fp8_e4m3fnuz=True)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)

        assert layer.weight.data.dtype == torch.bfloat16, \
            f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only supports an output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified."  # noqa: E501

        # Quantize the weights.
        qweight, weight_scale = ops.scaled_fp8_quant(
            layer.weight, scale=None, use_per_token_if_dynamic=True)

        # Update the layer with the new values.
        layer.weight = Parameter(
            qweight.t(), requires_grad=False)  # Pretranspose the weight
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        layer.input_scale = None
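
    # Editorial note (assumption, based on scaled_fp8_quant's per-token mode):
    # for a [out_features, in_features] weight, weight_scale is expected to
    # carry one scale per output channel (roughly shape [out_features, 1]),
    # which Fp8LinearOp then applies row-wise in apply() below.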

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     input_scale=None,
                                     input_scale_ub=None,
                                     bias=bias)
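
# Example (editorial sketch, not part of the original file): assuming vLLM's
# public LLM entry point accepts quantization="ptpc_fp8" (the name returned by
# PTPCFp8Config.get_name) and a BF16 dtype, on-the-fly PTPC FP8 quantization
# on a ROCm MI300-class GPU would look roughly like:
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model
#               dtype="bfloat16",          # weights must be BF16 (see assert above)
#               quantization="ptpc_fp8")   # routes linear layers through PTPCFp8LinearMethod
#     outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))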