Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[AMD][ROCm]Quantization methods on ROCm; Fix _scaled_mm call (#8380)
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Committed by GitHub
Parent: e18749ff09
Commit: b3195bc9e4
@@ -255,7 +255,10 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["awq", "gptq", "fp8"]
+        rocm_supported_quantization = [
+            "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
+            "fbgemm_fp8"
+        ]
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
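For context, the widened list above is what ModelConfig._verify_quantization validates the configured method against when running on ROCm. A minimal sketch of that kind of whitelist check follows; the function name, signature, and error message are hypothetical, not copied from vLLM:

def check_rocm_quantization(method: str, is_rocm: bool) -> None:
    # Hypothetical helper illustrating the whitelist pattern used above.
    rocm_supported_quantization = [
        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
        "fbgemm_fp8"
    ]
    if is_rocm and method not in rocm_supported_quantization:
        raise ValueError(
            f"{method!r} quantization is currently not supported on ROCm.")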
@@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
+from vllm.utils import is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
@@ -39,16 +41,37 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 logical_widths=layer.logical_widths,
             )
 
+            if is_hip():
+                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=max_w_scale,
+                    input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight = layer.weight
 
+            if is_hip():
+                weight, weight_scale, input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=layer.weight_scale,
+                        input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+            else:
+                weight_scale = layer.weight_scale.data
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             # required by torch.compile to be torch.nn.Parameter
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
         else:
             raise ValueError(f"Unknown quantization strategy {self.strategy}")
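The new is_hip() branches rely on normalize_e4m3fn_to_e4m3fnuz to rewrite OCP float8_e4m3fn weights into AMD's float8_e4m3fnuz format. The sketch below shows what such a conversion plausibly involves, based on the relationship between the two formats (the same bit pattern encodes half the value in e4m3fnuz, and e4m3fn's 0x80 pattern is NaN there); it is an illustration, not the verbatim vLLM helper:

from typing import Optional, Tuple
import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it before
    # reinterpreting the storage as e4m3fnuz.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For identical bits, an e4m3fnuz value is half the e4m3fn value,
    # so double the scales to keep the dequantized result unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale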
@@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -125,8 +126,18 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-
+
+        if is_hip():
+            weight, weight_scale, input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=None)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
         if self.quant_config.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
             # Activations not quantized for marlin.
@@ -6,11 +6,9 @@ from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
 from vllm.utils import is_hip
 
-# scaled_mm in pytorch on rocm has a bug that requires always
-# providing scaling factor for result. This value is created
-# as global value to avoid multiple tensor allocations, and
-# can be removed once pytorch fixes the bug.
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
+# Input scaling factors are no longer optional in _scaled_mm starting
+# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
 
 
 def cutlass_fp8_supported() -> bool:
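TORCH_DEVICE_IDENTITY is a one-element ones tensor used as a do-nothing scale, because torch._scaled_mm stopped accepting calls without input scales in PyTorch 2.5. A minimal, hedged illustration of passing such identity scales follows; torch._scaled_mm is a private API, so the exact signature and dtype requirements depend on the PyTorch build and device:

import torch

if torch.cuda.is_available():
    # Shapes chosen to satisfy the fp8 GEMM alignment requirements.
    a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()
    identity = torch.ones(1, device="cuda")  # dummy scale, multiplies by 1.0
    out = torch._scaled_mm(a, b,
                           scale_a=identity,
                           scale_b=identity,
                           out_dtype=torch.float16)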
@@ -131,19 +129,17 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output = torch._scaled_mm(
-            qinput,
-            weight,
-            out_dtype=input.dtype,
-            scale_a=x_scale,
-            scale_b=weight_scale,
-            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
-            bias=bias)
-        # Since in torch 2.5, scaled_mm only returns single value
-        # This should be removed when vllm-nvidia also moves to 2.5
-        if is_hip():
-            return torch.narrow(output, 0, 0, input.shape[0])
-        return torch.narrow(output[0], 0, 0, input.shape[0])
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)
+        # A fix for discrepancy in scaled_mm which returns tuple
+        # for torch < 2.5 and a single value in torch >= 2.5
+        if type(output) is tuple and len(output) == 2:
+            return torch.narrow(output[0], 0, 0, input.shape[0])
+        return torch.narrow(output, 0, 0, input.shape[0])
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
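The new tuple check encodes the version split described in the comment: torch._scaled_mm returned a two-element tuple before PyTorch 2.5 and a single tensor from 2.5 onward. The same compatibility shim could be factored into a small wrapper, sketched below as an illustration rather than part of this change:

import torch

def scaled_mm_compat(*args, **kwargs) -> torch.Tensor:
    # Normalize torch._scaled_mm's return value across PyTorch versions:
    # < 2.5 returns a (output, amax) tuple, >= 2.5 returns the output tensor.
    out = torch._scaled_mm(*args, **kwargs)
    if isinstance(out, tuple) and len(out) == 2:
        out = out[0]
    return out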
@@ -161,12 +157,23 @@
         # For the scaled_mm fallback case, we break this down, since it
         # does not support s_w being a vector.
 
+        # Making sure the dummy tensor is on the same device as the weight
+        global TORCH_DEVICE_IDENTITY
+        if TORCH_DEVICE_IDENTITY.device != weight.device:
+            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
         # GEMM
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=torch.float32)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  scale_a=TORCH_DEVICE_IDENTITY,
+                                  scale_b=TORCH_DEVICE_IDENTITY,
+                                  out_dtype=torch.float32)
+        # A fix for discrepancy in scaled_mm which returns tuple
+        # for torch < 2.5 and a single value in torch >= 2.5
+        if type(output) is tuple and len(output) == 2:
+            output = output[0]
         # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
         x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
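The fp32 GEMM in this hunk feeds an unfused dequantization with the per-token activation scales and per-channel weight scales. A sketch of that dequantization step, assuming the usual C = s_x * s_w * (X @ W) + bias formulation (not the exact lines from the file):

from typing import Optional
import torch

def unfused_dequant_sketch(output_fp32: torch.Tensor,
                           x_scale: torch.Tensor,
                           weight_scale: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           out_dtype: torch.dtype) -> torch.Tensor:
    # x_scale broadcasts over rows (tokens), weight_scale.t() over columns.
    out = output_fp32 * x_scale * weight_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(dtype=out_dtype)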