mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

commit 1260180c67
parent af4ee63e0e
committed by GitHub

Revert "[Performance] Move apply_w8a8_block_fp8_linear to an op class… (#25607)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm,
w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser, cdiv

@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
@@ -9,7 +9,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import (

@@ -63,7 +63,7 @@ def benchmark_shape(m: int,
# === vLLM Triton Implementation ===
def vllm_triton_gemm():
return w8a8_triton_block_scaled_mm(A_vllm,
return w8a8_block_fp8_matmul(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
@@ -11,7 +11,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt,

@@ -91,8 +91,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
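For readers following the test above, which compares the Triton kernel against native_w8a8_block_matmul: a self-contained float32 emulation of what a block-scaled W8A8 matmul computes. This is an illustrative sketch with made-up names, not the vLLM implementation; each (block_n, block_k) tile of B and each 1 x block_k group of A carries its own scale, and the output is the matmul of the dequantized operands.

import torch

def reference_block_scaled_matmul(A_q, A_s, B_q, B_s, block_size,
                                  out_dtype=torch.float32):
    """Dequantize-then-matmul reference for a block-scaled W8A8 GEMM.

    A_q: (M, K) quantized activations, A_s: (M, K // block_k) per-group scales.
    B_q: (N, K) quantized weights,     B_s: (N // block_n, K // block_k) per-block scales.
    """
    block_n, block_k = block_size
    # Broadcast each scale over every element of its group / block.
    A_deq = A_q.float() * A_s.repeat_interleave(block_k, dim=1)
    B_deq = (B_q.float() * B_s.repeat_interleave(block_n, dim=0)
             .repeat_interleave(block_k, dim=1))
    return (A_deq @ B_deq.t()).to(out_dtype)

# Toy shapes divisible by the block size, mirroring the (128, 128) case above.
M, N, K, block = 4, 256, 256, (128, 128)
A_q = torch.randn(M, K)  # stand-ins for FP8 values, kept in float for the demo
B_q = torch.randn(N, K)
A_s = torch.rand(M, K // block[1])
B_s = torch.rand(N // block[0], K // block[1])
out = reference_block_scaled_matmul(A_q, A_s, B_q, B_s, block)
print(out.shape)  # torch.Size([4, 256])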
@@ -20,11 +20,9 @@ from vllm.platforms import current_platform
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int,
use_ue8m0: bool) -> None:
group_size: int, seed: int) -> None:
"""Test QuantFP8 group quantization with various configurations.

Tests both CUDA and native implementations, column-major scales,

@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
column_major_scales=False)

# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())

@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size

# Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
assert scales_col.shape == (expected_num_groups, batch_size)

# 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible:

@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,

@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
def test_quantfp8_group_multidimensional(seed: int) -> None:
current_platform.seed_everything(seed)

group_size = 64

@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
column_major_scales=False)

x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape

@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
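The use_ue8m0 parametrization dropped above only toggles power-of-two rounding of the scales; the group quantization itself is unchanged. As a reference for what QuantFP8's group path computes, a self-contained sketch of dynamic per-token-group FP8 quantization (illustrative only; MIN_SCALE is a floor picked for the demo, not vLLM's constant):

import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max
MIN_SCALE = 1e-10  # avoid division by zero; demo value only

def ref_per_group_quant_fp8(x: torch.Tensor, group_size: int):
    """Quantize the last dim of x in groups of `group_size`, one scale per group."""
    orig_shape = x.shape
    num_groups = orig_shape[-1] // group_size
    xg = x.view(-1, num_groups, group_size)
    absmax = xg.abs().max(dim=-1, keepdim=True)[0].float()
    scales = (absmax / FP8_MAX).clamp(min=MIN_SCALE)
    x_q = (xg / scales).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return x_q.view(orig_shape), scales.view(*orig_shape[:-1], num_groups)

x = torch.randn(8, 256)
x_q, scales = ref_per_group_quant_fp8(x, group_size=64)
print(x_q.shape, x_q.dtype, scales.shape)  # (8, 256) torch.float8_e4m3fn (8, 4)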
@@ -17,6 +17,8 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]

@@ -109,6 +111,34 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()

@pytest.mark.skipif(
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
use_rocm_aiter_gemm_w8a8_blockscale: str,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
use_rocm_aiter_gemm_w8a8_blockscale)

use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
int(use_rocm_aiter_gemm_w8a8_blockscale)))
block_scale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
if use_cutlass:
assert block_scale_func == cutlass_scaled_mm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_gemm_w8a8_blockscale):
assert block_scale_func == (
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
else:
assert block_scale_func == w8a8_block_fp8_matmul

@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
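A stripped-down restatement of the selection order this test asserts, with placeholder callables standing in for cutlass_scaled_mm, torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale and w8a8_block_fp8_matmul. The function name here is invented for the sketch; the real logic is dispatch_w8a8_blockscale_func, restored later in this diff.

from typing import Callable

def cutlass_backend(*args): ...   # placeholder for cutlass_scaled_mm
def aiter_backend(*args): ...     # placeholder for the ROCm AITER custom op
def triton_backend(*args): ...    # placeholder for w8a8_block_fp8_matmul

def pick_w8a8_blockscale_backend(use_cutlass: bool,
                                 use_aiter_and_is_supported: bool) -> Callable:
    """Mirror of the selection order asserted by the test above."""
    if use_cutlass:
        return cutlass_backend
    if use_aiter_and_is_supported:
        return aiter_backend
    return triton_backend

assert pick_w8a8_blockscale_backend(True, True) is cutlass_backend
assert pick_w8a8_blockscale_backend(False, True) is aiter_backend
assert pick_w8a8_blockscale_backend(False, False) is triton_backend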
@@ -18,9 +18,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (

@@ -745,35 +742,3 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity)
assert perplexity <= exp_perplexity

def test_compressed_tensors_fp8_block_enabled(vllm_runner):
model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
with vllm_runner(model_path) as llm:

fp8_dtype = current_platform.fp8_dtype()

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
W8A8BlockFp8LinearOp)

assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2

input_quant_op = \
qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method == input_quant_op.forward_cuda

llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@@ -545,23 +545,6 @@ class VllmConfig:
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False

# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
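The guard in the block above only opts into the quant_fp8 custom op when the user has not disabled custom ops globally ("none") or disabled quant_fp8 explicitly ("-quant_fp8"). A standalone restatement of that guard (function name invented for the sketch):

def maybe_enable_quant_fp8(custom_ops: list) -> list:
    # Mirrors the guard above: append "+quant_fp8" only when custom ops are
    # not globally disabled and quant_fp8 was not explicitly disabled.
    if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
        custom_ops.append("+quant_fp8")
    return custom_ops

print(maybe_enable_quant_fp8([]))              # ['+quant_fp8']
print(maybe_enable_quant_fp8(["none"]))        # ['none']
print(maybe_enable_quant_fp8(["-quant_fp8"]))  # ['-quant_fp8']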
@@ -644,14 +644,6 @@ class CompressedTensorsConfig(QuantizationConfig):
# If no matches, return None
return None

def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
if (weight_quant is not None
and weight_quant.strategy == QuantizationStrategy.BLOCK):
return True
return False

@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
@@ -11,7 +11,7 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,

@@ -41,30 +41,16 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)

self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN

self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

if self.weight_block_size is not None:
assert not self.is_static_input_scheme
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)

@classmethod
def get_min_capability(cls) -> int:
# lovelace and up

@@ -155,14 +141,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(
if layer.weight_block_size is not None:
return apply_fp8_block_linear(
layer,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)

return self.fp8_linear.apply(input=x,
weight=layer.weight,
@@ -42,7 +42,7 @@ def prepare_block_fp8_matmul_inputs(
return M, N, K, C

def w8a8_deepgemm_block_scaled_mm(
def w8a8_block_fp8_matmul_deepgemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,

@@ -58,7 +58,7 @@ def w8a8_deepgemm_block_scaled_mm(
return C

def w8a8_deepgemm_block_scaled_mm_fake(
def w8a8_block_fp8_matmul_deepgemm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,

@@ -72,7 +72,7 @@ def w8a8_deepgemm_block_scaled_mm_fake(

direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm,
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
op_name="w8a8_block_fp8_matmul_deepgemm",
op_func=w8a8_block_fp8_matmul_deepgemm,
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
)
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,

@@ -242,28 +242,15 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
self.act_q_group_shape = GroupShape.PER_TENSOR

if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)

def create_weights(
self,

@@ -412,15 +399,12 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias)

if self.block_quant:
assert self.weight_block_size is not None

return self.w8a8_block_fp8_linear.apply(
return apply_fp8_block_linear(
layer,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)

return self.fp8_linear.apply(input=x,
weight=layer.weight,
@@ -27,14 +27,11 @@ class QuantFP8(CustomOp):
This CustomOp supports both static and dynamic quantization.
"""

def __init__(
self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None,
column_major_scales: bool = False,
use_ue8m0: Optional[bool] = None, # for Torch compile
):
def __init__(self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None,
column_major_scales: bool = False):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,

@@ -49,7 +46,6 @@ class QuantFP8(CustomOp):
self.group_shape = group_shape
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales
self.use_ue8m0 = use_ue8m0

self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:

@@ -74,8 +70,7 @@ class QuantFP8(CustomOp):
x,
group_size=self.group_size,
column_major_scales=self.column_major_scales,
dtype=_FP8_DTYPE,
use_ue8m0=self.use_ue8m0)
dtype=_FP8_DTYPE)

assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape

@@ -142,10 +137,7 @@ class QuantFP8(CustomOp):

x_grouped = x.view(-1, num_groups, self.group_size)
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
scales_raw = absmax / _FP8_MAX
if self.use_ue8m0:
scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR)
scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)

x_scaled = x_grouped / scales
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)

@@ -159,6 +151,6 @@ class QuantFP8(CustomOp):
scales = scales.reshape(orig_shape[:-1] + (num_groups, ))

if self.column_major_scales:
scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
scales = scales.transpose(-2, -1).contiguous()

return x_quant, scales
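The column-major handling touched above is easiest to see on strides. A small standalone demonstration of the two variants that appear in this hunk: plain .transpose(-2, -1).contiguous() changes the logical shape to (num_groups, batch), while adding a trailing .transpose(-1, -2) keeps the (batch, num_groups) shape but leaves the data column-major in memory, which is what the stride assertions in the test file earlier in this diff check (stride(0) == 1, stride(1) == batch).

import torch

batch, num_groups = 8, 4
scales = torch.randn(batch, num_groups)

# Variant A: shape actually changes to (num_groups, batch).
col = scales.transpose(-2, -1).contiguous()
print(col.shape, col.stride())              # torch.Size([4, 8]) (8, 1)

# Variant B: same logical shape, column-major memory layout.
col_major = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
print(col_major.shape, col_major.stride())  # torch.Size([8, 4]) (1, 8)
assert col_major.stride(0) == 1 and col_major.stride(1) == batch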
@@ -13,9 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, group_broadcast)
group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,

@@ -25,7 +24,6 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear)

logger = init_logger(__name__)

@@ -37,8 +35,6 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz

# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,

@@ -46,17 +42,15 @@ def cutlass_scaled_mm(
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
is_hopper: Optional[bool] = None,
) -> torch.Tensor:
if is_hopper is None:
is_hopper = current_platform.is_device_capability(90)
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
scale_b=Bs if block_size is not None
and current_platform.is_device_capability(90) else Bs.T)

def rocm_aiter_gemm_w8a8_blockscale_impl(

@@ -102,190 +96,122 @@ if current_platform.is_rocm():
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)

# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
block_size, output_dtype)

def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)

# Note: the check can be removed when CPU torch > 2.7
if not current_platform.is_cpu():
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA",
)
def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool
) -> Callable[[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
list[int],
torch.dtype,
], torch.Tensor]:
if use_cutlass:
return cutlass_scaled_mm
if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
return w8a8_block_fp8_matmul

# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
"""
This class executes a Blocked FP8 linear layer using cutlass if supported
and torch.scaled_mm otherwise.
"""
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype

def __init__(
self,
weight_group_shape: GroupShape,
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):

# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = \
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
self.deepgemm_input_quant_op = (QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
else None)

def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype

if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
self.is_deep_gemm_supported):
output = self._run_deepgemm(input, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
)

output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)

def _run_deepgemm(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
# ensure DeepGEMM-backed custom op is registered before use
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401

assert self.deepgemm_input_quant_op is not None
q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
q_input,
weight,
x_scale,
weight_scale,
self.weight_group_shape,
output_dtype=input_2d.dtype)
block_size,
output_dtype=output_dtype)
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)

def _run_cutlass(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
if self.is_hopper:
# We pad unconditionally (even if shape is already divisible by 4)
# to support dynamic shape for input_2d.shape[0] in torch.compile
x = torch.nn.functional.pad(input_2d,
(0, 0, 0, -input_2d.shape[0] % 4))
else:
x = input_2d

q_input, x_scale = self.input_quant_op(x)
output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
list(self.weight_group_shape),
input_2d.dtype, self.is_hopper)
output = output[0:input_2d.shape[0], ...]
return output

def _run_aiter(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
input_2d.dtype)

def _run_triton(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
q_input, x_scale = self.input_quant_op(input_2d)
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
input_2d.dtype)

def _dispatch_w8a8_blockscale_op(
self,
use_cutlass: bool,
use_aiter_and_is_supported: bool,
) -> tuple[Callable[[
torch.Tensor,
torch.Tensor,
torch.Tensor,
], torch.Tensor], Optional[QuantFP8]]:
if use_cutlass:
return self._run_cutlass, (QuantFP8(False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=False))
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
if cutlass_block_fp8_supported:
num_pad = 0
if current_platform.is_device_capability(90):
# pad first dimension to be divisible by 4 due to
# cutlass blockwise gemm limitation for hopper
num_pad = 4 - (input_2d.shape[0] % 4)
if num_pad > 0:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, num_pad),
"constant", 0)
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if num_pad > 0:
output = output[:-num_pad]
else:
if use_aiter_and_is_supported:
return self._run_aiter, None
return self._run_triton, (QuantFP8(False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False))
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False)

output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)

if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)

def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device)

if not current_platform.is_cpu():
direct_register_custom_op(
op_name="apply_w8a8_block_fp8_linear",
op_func=apply_w8a8_block_fp8_linear,
mutates_args=[],
fake_impl=apply_w8a8_block_fp8_linear_fake,
)

def input_to_float8(
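The registration above pairs a real implementation with a shape-only fake_impl so torch.compile can trace the custom op without executing it. A minimal sketch of the same pattern with a toy op, assuming a vLLM checkout and a recent PyTorch are available; the op name and bodies are invented for illustration, and dispatch_key="CPU" is chosen only so the demo runs on a CPU-only box.

import torch
from vllm.utils import direct_register_custom_op  # same helper used above

def _toy_double(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

def _toy_double_fake(x: torch.Tensor) -> torch.Tensor:
    # fake_impl only has to produce correct metadata (shape/dtype/device)
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="toy_double",            # hypothetical op name for this sketch
    op_func=_toy_double,
    mutates_args=[],
    fake_impl=_toy_double_fake,
    dispatch_key="CPU",
)

out = torch.ops.vllm.toy_double(torch.ones(2, 3))
print(out)  # tensor filled with 2.0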
@@ -537,7 +463,7 @@ def per_token_group_quant_fp8(

@triton.jit
def _w8a8_triton_block_scaled_mm(
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,

@@ -662,7 +588,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None

def w8a8_triton_block_scaled_mm(
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,

@@ -722,7 +648,7 @@ def w8a8_triton_block_scaled_mm(
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )

_w8a8_triton_block_scaled_mm[grid](
_w8a8_block_fp8_matmul[grid](
A,
B,
C,

@@ -1005,6 +931,25 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
layer.weight_scale.data.T.contiguous(), requires_grad=False)

def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
bias: Optional[torch.Tensor],
cutlass_block_fp8_supported: bool,
use_aiter_and_is_supported: bool) -> torch.Tensor:
"""Apply block-wise FP8 linear operation."""
assert layer.weight_block_size is not None

return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=input,
weight=layer.weight,
block_size=layer.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)

def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
@@ -9,7 +9,7 @@ from __future__ import annotations
import functools
import importlib
import os
from typing import Any, Callable, NoReturn, Optional
from typing import Any, Callable, NoReturn

import torch

@@ -184,13 +184,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
return 1 - sim

def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype,
weight: torch.Tensor,
supports_deep_gemm: Optional[bool] = None):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (supports_deep_gemm and output_dtype == torch.bfloat16
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
weight: torch.Tensor):
return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
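Restated as a runnable check, the condition above gates the DeepGEMM path on bf16 output and 128-divisible weight dimensions (helper name invented for the sketch; the real function also consults is_deep_gemm_supported()):

import torch

def deepgemm_eligible(output_dtype: torch.dtype, weight_shape: tuple,
                      supports_deep_gemm: bool) -> bool:
    # Mirrors the condition above: bf16 output and both weight dims % 128 == 0.
    return (supports_deep_gemm and output_dtype == torch.bfloat16
            and weight_shape[0] % 128 == 0 and weight_shape[1] % 128 == 0)

print(deepgemm_eligible(torch.bfloat16, (512, 1024), True))  # True
print(deepgemm_eligible(torch.float16, (512, 1024), True))   # False
print(deepgemm_eligible(torch.bfloat16, (500, 1024), True))  # False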