Revert "[Performance] Move apply_w8a8_block_fp8_linear to an op class… (#25607)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Tyler Michael Smith authored this commit on 2025-09-25 04:05:21 -04:00 (committed via GitHub)
parent af4ee63e0e
commit 1260180c67
14 changed files with 205 additions and 346 deletions

View File

@@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm,
w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser, cdiv
@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(

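For orientation when reading the blockwise benchmark entries above: with (128, 128) blocks, the activation carries one scale per token per 128-wide K group and the weight one scale per 128x128 tile. A minimal sketch of the scale-count arithmetic (sizes and tensor names are illustrative, not taken from the benchmark):

import torch

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used when counting scale groups.
    return -(-a // b)

M, N, K = 16, 4096, 7168        # example GEMM problem
block_n, block_k = 128, 128     # weight block size

num_k_groups = cdiv(K, block_k)
act_scales = torch.rand(M, num_k_groups)                    # one scale per token per K group
weight_scales = torch.rand(cdiv(N, block_n), num_k_groups)  # one scale per 128x128 weight tile
print(act_scales.shape, weight_scales.shape)                # (16, 56) and (32, 56)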
View File

@@ -9,7 +9,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import (
@@ -63,7 +63,7 @@ def benchmark_shape(m: int,
# === vLLM Triton Implementation ===
def vllm_triton_gemm():
return w8a8_triton_block_scaled_mm(A_vllm,
return w8a8_block_fp8_matmul(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
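To time these GEMM paths outside this script, triton.testing.do_bench is a convenient helper. A minimal sketch with a stand-in workload (assumes a CUDA device; the real script would pass callables like vllm_triton_gemm above instead):

import torch
from triton.testing import do_bench

# Stand-in workload; replace with the quantized GEMM callables being compared.
a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

ms = do_bench(lambda: a @ b, warmup=25, rep=100)
print(f"latency: {ms:.3f} ms")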

View File

@@ -11,7 +11,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt,
@@ -91,8 +91,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
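For intuition on what the native reference in this test does: it dequantizes block by block and runs the matmul in full precision before comparing against the Triton kernel. A self-contained sketch of that idea (an illustrative re-implementation, not vLLM's native_w8a8_block_matmul):

import torch

def ref_block_fp8_matmul(A, B, As, Bs, block_size, out_dtype=torch.bfloat16):
    # A: (M, K) fp8 with per-(1, blk_k) scales As: (M, ceil(K/blk_k))
    # B: (N, K) fp8 with per-(blk_n, blk_k) scales Bs: (ceil(N/blk_n), ceil(K/blk_k))
    blk_n, blk_k = block_size
    A_deq = A.to(torch.float32) * As.repeat_interleave(blk_k, dim=1)[:, :A.shape[1]]
    Bs_full = Bs.repeat_interleave(blk_n, dim=0)[:B.shape[0]]
    Bs_full = Bs_full.repeat_interleave(blk_k, dim=1)[:, :B.shape[1]]
    B_deq = B.to(torch.float32) * Bs_full
    return (A_deq @ B_deq.t()).to(out_dtype)

M, N, K = 8, 256, 256
A = torch.randn(M, K).to(torch.float8_e4m3fn)
B = torch.randn(N, K).to(torch.float8_e4m3fn)
As = torch.rand(M, K // 128) + 0.5
Bs = torch.rand(N // 128, K // 128) + 0.5
print(ref_block_fp8_matmul(A, B, As, Bs, (128, 128)).shape)  # (8, 256)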

View File

@@ -20,11 +20,9 @@ from vllm.platforms import current_platform
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int,
use_ue8m0: bool) -> None:
group_size: int, seed: int) -> None:
"""Test QuantFP8 group quantization with various configurations.
Tests both CUDA and native implementations, column-major scales,
@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
column_major_scales=False)
# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size
# Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
assert scales_col.shape == (expected_num_groups, batch_size)
# 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible:
@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
def test_quantfp8_group_multidimensional(seed: int) -> None:
current_platform.seed_everything(seed)
group_size = 64
@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
column_major_scales=False)
x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
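The assertion changes above reflect two meanings of "column-major scales" for the native path: the restored code physically transposes the scales (shape becomes (num_groups, batch)), while the removed op-class version kept the logical (batch, num_groups) shape and only made the memory layout column-major. A small sketch of the difference:

import torch

batch_size, num_groups = 4, 8
scales = torch.rand(batch_size, num_groups)   # row-major, one scale per group

# Restored behavior: physical transpose, shape becomes (num_groups, batch)
col_a = scales.transpose(-2, -1).contiguous()
print(col_a.shape, col_a.stride())            # (8, 4), strides (4, 1)

# Removed op-class behavior: same logical shape, column-major strides
col_b = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
print(col_b.shape, col_b.stride())            # (4, 8), strides (1, 4)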

View File

@@ -17,6 +17,8 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -109,6 +111,34 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()
@pytest.mark.skipif(
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
use_rocm_aiter_gemm_w8a8_blockscale: str,
monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
use_rocm_aiter_gemm_w8a8_blockscale)
use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
int(use_rocm_aiter_gemm_w8a8_blockscale)))
block_scale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
if use_cutlass:
assert block_scale_func == cutlass_scaled_mm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_gemm_w8a8_blockscale):
assert block_scale_func == (
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
else:
assert block_scale_func == w8a8_block_fp8_matmul
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
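For quick reference, the dispatch the blockscale test above pins down resolves as in the sketch below, keyed by (use_cutlass, use_aiter_and_is_supported); this mirrors dispatch_w8a8_blockscale_func as restored later in this commit and is only a summary:

# (use_cutlass, use_aiter_and_is_supported) -> selected block-scale GEMM
expected_dispatch = {
    (True, False): "cutlass_scaled_mm",
    (True, True): "cutlass_scaled_mm",   # cutlass takes priority when requested
    (False, True): "torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale",
    (False, False): "w8a8_block_fp8_matmul",   # Triton fallback
}
for flags, kernel in expected_dispatch.items():
    print(flags, "->", kernel)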

View File

@@ -18,9 +18,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -745,35 +742,3 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity)
assert perplexity <= exp_perplexity
def test_compressed_tensors_fp8_block_enabled(vllm_runner):
model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
with vllm_runner(model_path) as llm:
fp8_dtype = current_platform.fp8_dtype()
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
W8A8BlockFp8LinearOp)
assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2
input_quant_op = \
qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method == input_quant_op.forward_cuda
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@@ -545,23 +545,6 @@ class VllmConfig:
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False
# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
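The block removed above opted the quant_fp8 custom op in whenever blocked weights were present, unless it had been explicitly disabled. A sketch of that guard in isolation (the function name is illustrative):

def maybe_enable_quant_fp8(custom_ops: list[str], has_blocked_weights: bool) -> list[str]:
    # Mirrors the removed VllmConfig logic: enable quant_fp8 for blocked
    # weights unless custom ops are globally off or it was disabled explicitly.
    if has_blocked_weights and "none" not in custom_ops and "-quant_fp8" not in custom_ops:
        custom_ops.append("+quant_fp8")
    return custom_ops

print(maybe_enable_quant_fp8([], True))              # ['+quant_fp8']
print(maybe_enable_quant_fp8(["-quant_fp8"], True))  # ['-quant_fp8']
print(maybe_enable_quant_fp8(["none"], True))        # ['none']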

View File

@@ -644,14 +644,6 @@ class CompressedTensorsConfig(QuantizationConfig):
# If no matches, return None
return None
def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
if (weight_quant is not None
and weight_quant.strategy == QuantizationStrategy.BLOCK):
return True
return False
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],

View File

@@ -11,7 +11,7 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@@ -41,30 +41,16 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
@@ -155,14 +141,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(
if layer.weight_block_size is not None:
return apply_fp8_block_linear(
layer,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
return self.fp8_linear.apply(input=x,
weight=layer.weight,

View File

@@ -42,7 +42,7 @@ def prepare_block_fp8_matmul_inputs(
return M, N, K, C
def w8a8_deepgemm_block_scaled_mm(
def w8a8_block_fp8_matmul_deepgemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -58,7 +58,7 @@ def w8a8_deepgemm_block_scaled_mm(
return C
def w8a8_deepgemm_block_scaled_mm_fake(
def w8a8_block_fp8_matmul_deepgemm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -72,7 +72,7 @@ def w8a8_deepgemm_block_scaled_mm_fake(
direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm,
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
op_name="w8a8_block_fp8_matmul_deepgemm",
op_func=w8a8_block_fp8_matmul_deepgemm,
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
)
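The rename above keeps the usual custom-op registration pattern: the fake implementation only needs to produce an output with the right shape, dtype, and device so torch.compile can trace through the op without running it. A generic sketch of that pattern with torch.library (vLLM's direct_register_custom_op wrapper shown in the diff achieves the same thing; requires PyTorch >= 2.4):

import torch

@torch.library.custom_op("demo::block_scaled_mm", mutates_args=())
def block_scaled_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Real implementation (plain matmul stand-in here).
    return a @ b.t()

@block_scaled_mm.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake/meta impl: shape, dtype, and device only, no computation.
    return torch.empty((a.size(0), b.size(0)), dtype=a.dtype, device=a.device)

out = torch.ops.demo.block_scaled_mm(torch.randn(4, 8), torch.randn(6, 8))
print(out.shape)  # torch.Size([4, 6])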

View File

@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
@@ -242,28 +242,15 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
self.act_q_group_shape = GroupShape.PER_TENSOR
if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
def create_weights(
self,
@@ -412,15 +399,12 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias)
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
return apply_fp8_block_linear(
layer,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
return self.fp8_linear.apply(input=x,
weight=layer.weight,

View File

@@ -27,14 +27,11 @@ class QuantFP8(CustomOp):
This CustomOp supports both static and dynamic quantization.
"""
def __init__(
self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None,
column_major_scales: bool = False,
use_ue8m0: Optional[bool] = None, # for Torch compile
):
def __init__(self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None,
column_major_scales: bool = False):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
@@ -49,7 +46,6 @@ class QuantFP8(CustomOp):
self.group_shape = group_shape
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales
self.use_ue8m0 = use_ue8m0
self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:
@@ -74,8 +70,7 @@ class QuantFP8(CustomOp):
x,
group_size=self.group_size,
column_major_scales=self.column_major_scales,
dtype=_FP8_DTYPE,
use_ue8m0=self.use_ue8m0)
dtype=_FP8_DTYPE)
assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
@@ -142,10 +137,7 @@ class QuantFP8(CustomOp):
x_grouped = x.view(-1, num_groups, self.group_size)
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
scales_raw = absmax / _FP8_MAX
if self.use_ue8m0:
scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR)
scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
x_scaled = x_grouped / scales
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
@@ -159,6 +151,6 @@ class QuantFP8(CustomOp):
scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
if self.column_major_scales:
scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
scales = scales.transpose(-2, -1).contiguous()
return x_quant, scales
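The use_ue8m0 branch removed above rounded each group scale up to a power of two (the UE8M0 scale format expected by some DeepGEMM kernels). A worked sketch of that rounding, independent of QuantFP8:

import torch

_FP8_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0

absmax = torch.tensor([0.7, 3.0, 100.0])
scales_raw = absmax / _FP8_MAX

# UE8M0 rounding: snap each scale up to the next power of two
scales_ue8m0 = torch.exp2(torch.ceil(torch.log2(scales_raw)))

print(scales_raw)     # tensor([0.0016, 0.0067, 0.2232])
print(scales_ue8m0)   # tensor([0.0020, 0.0078, 0.2500]) == 2**-9, 2**-7, 2**-2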

View File

@@ -13,9 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, group_broadcast)
group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
@@ -25,7 +24,6 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__)
@@ -37,8 +35,6 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
@@ -46,17 +42,15 @@ def cutlass_scaled_mm(
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
is_hopper: Optional[bool] = None,
) -> torch.Tensor:
if is_hopper is None:
is_hopper = current_platform.is_device_capability(90)
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
scale_b=Bs if block_size is not None
and current_platform.is_device_capability(90) else Bs.T)
def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -102,190 +96,122 @@ if current_platform.is_rocm():
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
block_size, output_dtype)
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
# Note: the check can be removed when CPU torch > 2.7
if not current_platform.is_cpu():
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA",
)
def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool
) -> Callable[[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
list[int],
torch.dtype,
], torch.Tensor]:
if use_cutlass:
return cutlass_scaled_mm
if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
return w8a8_block_fp8_matmul
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
"""
This class executes a Blocked FP8 linear layer using cutlass if supported
and torch.scaled_mm otherwise.
"""
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
def __init__(
self,
weight_group_shape: GroupShape,
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = \
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
self.deepgemm_input_quant_op = (QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
else None)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
self.is_deep_gemm_supported):
output = self._run_deepgemm(input, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
)
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_deepgemm(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
# ensure DeepGEMM-backed custom op is registered before use
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
assert self.deepgemm_input_quant_op is not None
q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
q_input,
weight,
x_scale,
weight_scale,
self.weight_group_shape,
output_dtype=input_2d.dtype)
block_size,
output_dtype=output_dtype)
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
def _run_cutlass(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
if self.is_hopper:
# We pad unconditionally (even if shape is already divisible by 4)
# to support dynamic shape for input_2d.shape[0] in torch.compile
x = torch.nn.functional.pad(input_2d,
(0, 0, 0, -input_2d.shape[0] % 4))
else:
x = input_2d
q_input, x_scale = self.input_quant_op(x)
output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
list(self.weight_group_shape),
input_2d.dtype, self.is_hopper)
output = output[0:input_2d.shape[0], ...]
return output
def _run_aiter(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
input_2d.dtype)
def _run_triton(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
q_input, x_scale = self.input_quant_op(input_2d)
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
input_2d.dtype)
def _dispatch_w8a8_blockscale_op(
self,
use_cutlass: bool,
use_aiter_and_is_supported: bool,
) -> tuple[Callable[[
torch.Tensor,
torch.Tensor,
torch.Tensor,
], torch.Tensor], Optional[QuantFP8]]:
if use_cutlass:
return self._run_cutlass, (QuantFP8(False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=False))
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
if cutlass_block_fp8_supported:
num_pad = 0
if current_platform.is_device_capability(90):
# pad first dimension to be divisible by 4 due to
# cutlass blockwise gemm limitation for hopper
num_pad = 4 - (input_2d.shape[0] % 4)
if num_pad > 0:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, num_pad),
"constant", 0)
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if num_pad > 0:
output = output[:-num_pad]
else:
if use_aiter_and_is_supported:
return self._run_aiter, None
return self._run_triton, (QuantFP8(False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False))
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
if not current_platform.is_cpu():
direct_register_custom_op(
op_name="apply_w8a8_block_fp8_linear",
op_func=apply_w8a8_block_fp8_linear,
mutates_args=[],
fake_impl=apply_w8a8_block_fp8_linear_fake,
)
def input_to_float8(
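A small arithmetic note on the two Hopper padding variants visible in the hunk above: both round the row count up to a multiple of 4 for the SM90 CUTLASS blockwise kernel, but they differ when M is already aligned.

# Rows of padding so M becomes a multiple of 4 (SM90 CUTLASS blockwise requirement)
for M in (5, 7, 8):
    pad_restored = 4 - (M % 4)   # restored function: 3, 1, 4 (adds 4 dummy rows when aligned)
    pad_op_class = -M % 4        # removed op class: 3, 1, 0 (no-op when aligned)
    print(M, pad_restored, pad_op_class)

Either way the padded rows are sliced off the output afterwards, so the results match; the -M % 4 form simply avoids appending four dummy rows in the already-aligned case.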
@@ -537,7 +463,7 @@ def per_token_group_quant_fp8(
@triton.jit
def _w8a8_triton_block_scaled_mm(
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
@@ -662,7 +588,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None
def w8a8_triton_block_scaled_mm(
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -722,7 +648,7 @@ def w8a8_triton_block_scaled_mm(
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_triton_block_scaled_mm[grid](
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
@@ -1005,6 +931,25 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
layer.weight_scale.data.T.contiguous(), requires_grad=False)
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
bias: Optional[torch.Tensor],
cutlass_block_fp8_supported: bool,
use_aiter_and_is_supported: bool) -> torch.Tensor:
"""Apply block-wise FP8 linear operation."""
assert layer.weight_block_size is not None
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=input,
weight=layer.weight,
block_size=layer.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
import functools
import importlib
import os
from typing import Any, Callable, NoReturn, Optional
from typing import Any, Callable, NoReturn
import torch
@@ -184,13 +184,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
return 1 - sim
def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype,
weight: torch.Tensor,
supports_deep_gemm: Optional[bool] = None):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (supports_deep_gemm and output_dtype == torch.bfloat16
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
weight: torch.Tensor):
return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
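The restored signature drops the injectable supports_deep_gemm argument and queries support directly. The dtype/shape half of the condition is easy to check by hand; a quick sketch (the is_deep_gemm_supported() probe is elided, and the helper name here is illustrative):

import torch

def deepgemm_shape_ok(output_dtype: torch.dtype, weight: torch.Tensor) -> bool:
    # Mirrors the dtype/shape half of should_use_deepgemm_for_fp8_linear:
    # bf16 output and both weight dims divisible by 128.
    return (output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0
            and weight.shape[1] % 128 == 0)

w = torch.empty(4096, 7168)
print(deepgemm_shape_ok(torch.bfloat16, w))                         # True
print(deepgemm_shape_ok(torch.float16, w))                          # False
print(deepgemm_shape_ok(torch.bfloat16, torch.empty(4096, 7000)))   # False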