[Feature] Integrate SM100 DeepGEMM support (#20087)

Wentao Ye authored on 2025-07-10 23:18:05 -04:00, committed by GitHub
parent 5b032352cc · commit e2de455c34
16 changed files with 397 additions and 114 deletions


@ -86,6 +86,9 @@ def benchmark_config(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_deep_gemm:
# we use the default block shape for deepgemm
block_quant_shape = [128, 128]
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]


@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
dg_available = has_deep_gemm()
if dg_available:
from deep_gemm import get_m_alignment_for_contiguous_layout
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_m = get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16


@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
@ -368,6 +369,8 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
"""
@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,


@ -13,48 +13,18 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import cdiv
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
import deep_gemm
BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
BLOCK_SIZE = [BLOCK_M, BLOCK_M]
BLOCK_SIZE = [128, 128]
requires_deep_gemm = pytest.mark.skipif(
not has_deep_gemm,
not has_deep_gemm(),
reason="Requires deep_gemm kernels",
)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
e: int,
n: int,
@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
_, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape=block_size,
allow_deep_gemm=True,
)
base = out_triton.abs().mean()
atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3
rtol = 0.05
# ----- Compare -----
torch.testing.assert_close(
out_deepgemm.to(torch.float32),
out_triton.to(torch.float32),
rtol=rtol,
atol=float(atol),
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 0.1%: {diff}"
# Note: W1 has shape (E, 2N, K), so N = 512


@ -8,19 +8,15 @@ import pytest
import torch
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
per_block_cast_to_fp8)
native_w8a8_block_matmul)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(not has_deep_gemm(),
reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1]
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32)
@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /


@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
logger = init_logger(__name__)
@ -271,7 +272,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
import deep_gemm as dg
assert hidden_states.ndim == 3
assert self.block_shape is not None
@ -289,18 +289,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
(w1, w1_scale),
out=workspace1,
masked_m=expert_num_tokens,
expected_m=expected_m)
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
out=workspace1,
masked_m=expert_num_tokens,
expected_m=expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
out=output,
masked_m=expert_num_tokens,
expected_m=expected_m)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
out=output,
masked_m=expert_num_tokens,
expected_m=expected_m)


@ -14,9 +14,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import has_deep_gemm, round_up
from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
per_token_group_cast_to_fp8)
logger = init_logger(__name__)
@ -127,7 +128,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
import deep_gemm as dg
assert self.block_shape is not None
a1q = hidden_states
@ -164,19 +164,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
(M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
self.block_shape[1],
column_major_scales=True,
out_q=quant_out)
a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
self.block_shape[1],
column_major_scales=True,
out_q=quant_out)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))


@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@ -1171,9 +1172,15 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200 we use DeepGemm for all cases because it only supports
# E8M0 scales, which means we requantize the weight and input to that
# specific scale. Falling back to cutlass or triton for some cases would
# cause accuracy issues.
N = w1.size(1)
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2)):
should_use_deep_gemm = ((N > 512
and _valid_deep_gemm(hidden_states, w1, w2))
or is_blackwell_deep_gemm_used())
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
@ -1363,7 +1370,6 @@ def fused_experts_impl(
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,


@ -48,7 +48,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)


@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -102,7 +103,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K)
or is_blackwell_deep_gemm_used()):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
@ -132,7 +134,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
use_deep_gemm = (self.allow_deep_gemm
and _valid_deep_gemm(hidden_states, w1, w2))
and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_used()))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None


@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
per_token_group_cast_to_fp8)
@triton.jit
@ -115,7 +117,10 @@ def _fp8_quantize(
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
if is_blackwell_deep_gemm_used():
A, A_scale = per_token_group_cast_to_fp8(A, block_k)
else:
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
return A, A_scale


@ -8,10 +8,9 @@ from typing import Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
@triton.jit()


@ -6,10 +6,8 @@ import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op, has_deep_gemm
if has_deep_gemm():
import deep_gemm
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_gemm_nt
logger = logging.getLogger(__name__)
@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm(
output_dtype)
# Deepgemm only supports output tensor type as bfloat16
assert C.dtype == torch.bfloat16
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
fp8_gemm_nt((A, As), (B, Bs), C)
return C


@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin)
@ -40,6 +42,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@ -393,6 +396,19 @@ class Fp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin.
del layer.input_scale
# On B200, DeepGemm only supports E8M0 scales, which means we need to
# requantize the weight and input to that specific scale
# at the same time.
if is_blackwell_deep_gemm_used():
assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.weight.data,
layer.weight_scale_inv.data if hasattr(
layer, "weight_scale_inv") else layer.weight_scale.data,
block_sz,
)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
@ -670,15 +686,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm:
if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
@ -797,6 +812,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
if is_blackwell_deep_gemm_used():
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale_inv.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale_inv.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv).contiguous()
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,


@ -5,6 +5,7 @@
import functools
import json
import os
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import torch
@ -13,7 +14,7 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize)
group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
@ -235,7 +236,7 @@ def block_quant_to_tensor_quant(
The outputs are tensor-wise quantization tensor and tensor-wise
quantization scale. Note only float8 is supported for now.
"""
x_dq_block = scaled_dequantize(x_q_block, x_s)
x_dq_block = group_broadcast(x_q_block, x_s)
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale
@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul(
)
return C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return cdiv(x, alignment) * alignment
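# A minimal usage sketch, assuming a float32 scale tensor (element_size == 4,
# so alignment = 16 // 4 = 4 elements along M):
#   get_tma_aligned_size(130, 4)  # -> 132 (cdiv(130, 4) * 4)
#   get_tma_aligned_size(128, 4)  # -> 128 (already a multiple of 4)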
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along
the M axis (thus meets the requirement of LHS scaling tensor in
DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for extreme performance, you may rewrite/fuse this function in
# CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
2) == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
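# A minimal usage sketch with hypothetical shapes: a row-major (M, K // 128)
# float32 scale tensor keeps its logical shape but comes back column-major,
# with the M axis padded internally to the TMA-aligned size.
#   scales = torch.randn(130, 4, device="cuda")         # strides (4, 1)
#   aligned = get_col_major_tma_aligned_tensor(scales)  # shape (130, 4)
#   assert aligned.stride() == (1, 132)                 # M padded 130 -> 132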
def requant_weight_ue8m0_inplace(
weight: torch.Tensor,
weight_scale: torch.Tensor,
block_size: Sequence[int] = (128, 128),
) -> None:
"""Re-quantise *weight* so that its per-block scaling factors are in the
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
Args:
weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``.
Expected shape ``(..., M, K)``.
weight_scale: Corresponding per-block scale tensor (``torch.float32``)
with shape ``(..., M // block_size[0], K // block_size[1])``.
block_size: 2-element iterable ``[block_m, block_k]`` describing the
block quantisation granularity.
"""
if weight.numel() == 0:
return
if weight.dtype != torch.float8_e4m3fn:
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
f"{weight.dtype} instead.")
from vllm.utils.deep_gemm import per_block_cast_to_fp8
block_m, block_k = int(block_size[0]), int(block_size[1])
# Flatten leading dimensions so we can iterate over the last two dims.
leading_shape = weight.shape[:-2]
if len(leading_shape) == 0:
w_view = weight.unsqueeze(0)
s_view = weight_scale.unsqueeze(0)
else:
w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1])
s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:])
num_mats = w_view.size(0)
for idx in range(num_mats):
w_q = w_view[idx]
s_old = s_view[idx]
# De-quantise with the *old* scaling factors (float32).
m_cur, k_cur = w_q.shape
s_float = s_old.to(torch.float32)
# Expand scales along rows and cols by block size, then crop.
s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0)
s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1)
s_exp = s_exp[:m_cur, :k_cur]
w_dq = w_q.to(torch.float32) * s_exp
# Re-quantise using power-of-two scaling (UE8M0).
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k])
# Write back the results in-place.
w_q.copy_(w_requant)
s_old.copy_(s_requant)
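# A rough usage sketch with hypothetical shapes: re-quantise one (M, K)
# block-quantised weight so its scales become UE8M0 (powers of two).
#   w = torch.randn(512, 1024, device="cuda").to(torch.float8_e4m3fn)
#   s = torch.rand(4, 8, device="cuda", dtype=torch.float32)  # (M/128, K/128)
#   requant_weight_ue8m0_inplace(w, s, (128, 128))
#   # w and s are updated in place; s now holds power-of-two scales.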

vllm/utils/deep_gemm.py (new file, 152 lines)

@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations
import functools
import importlib
from typing import Any, Callable, NoReturn
import torch
import vllm.envs as envs
from vllm.utils import cuda_get_device_properties, has_deep_gemm
@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
"""Return ``True`` if vLLM is configured to use DeepGEMM on a
Blackwell-class GPU.
"""
if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
and _per_block_cast_impl is not None):
return False
return cuda_get_device_properties(0, ("major", ))[0] == 10
def _missing(*_: Any, **__: Any) -> NoReturn:
"""Placeholder for unavailable DeepGEMM backend."""
raise RuntimeError(
"DeepGEMM backend is not available. Please install the `deep_gemm` "
"package to enable FP8 kernels.")
def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
"""Return the *new* symbol if it exists, otherwise the *old* one."""
if hasattr(module, new):
return getattr(module, new)
if hasattr(module, old):
return getattr(module, old)
return None
if not has_deep_gemm():
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_token_cast_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
else:
_dg = importlib.import_module("deep_gemm") # type: ignore
_fp8_gemm_nt_impl = _resolve_symbol(
_dg,
"fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt",
)
_grouped_impl = _resolve_symbol(
_dg,
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
)
_grouped_masked_impl = _resolve_symbol(
_dg,
"fp8_m_grouped_gemm_nt_masked",
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
)
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try:
_math_mod = importlib.import_module(
"deep_gemm.utils.math") # type: ignore
_per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
None)
_per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
None)
except ModuleNotFoundError:
_per_token_cast_impl = None
_per_block_cast_impl = None
def fp8_gemm_nt(*args, **kwargs):
if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs)
return _fp8_gemm_nt_impl(*args, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
if _grouped_impl is None:
return _missing(*args, **kwargs)
return _grouped_impl(*args, **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
if _grouped_masked_impl is None:
return _missing(*args, **kwargs)
return _grouped_masked_impl(*args, **kwargs)
def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
"""Wrapper for token-wise FP8 quantisation.
• If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
• Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
"""
if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
assert group_size == 128, "group_size must be 128 for deepgemm"
return _per_token_cast_impl(x)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8 as _ptg)
return _ptg(x, group_size, *args, **kwargs)
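# Shape sketch, assuming K is a multiple of 128 and group_size == 128: for x
# of shape (M, K), both paths return an FP8 tensor of shape (M, K) plus a
# per-group scale tensor whose last dimension is K // 128 (the scale layout
# may differ between the DeepGEMM and vLLM implementations).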
def per_block_cast_to_fp8(x, *args, **kwargs):
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
return _pbcf(x, *args, **kwargs)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
"""Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
error, causing ``torch.testing.assert_close`` to fail. Instead of checking
every element, we compute a cosine-style similarity over the whole tensor
and report ``1 - sim``. Once kernel accuracy improves this helper can be
removed.
"""
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
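# In other words, diff = 1 - 2 * sum(x * y) / (sum(x * x) + sum(y * y)):
# 0 for identical nonzero tensors, approaching 1 as the outputs decorrelate,
# e.g. calc_diff(t, t) == 0.0 for any nonzero tensor t.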
__all__ = [
"calc_diff",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"per_token_group_cast_to_fp8",
"per_block_cast_to_fp8",
"is_blackwell_deep_gemm_used",
]
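
A minimal sketch of the intended call pattern (tensor shapes here are hypothetical, and a CUDA device with FP8 support is assumed): callers import these wrappers instead of `deep_gemm` directly, so the same code runs against both the old and the new DeepGEMM APIs and falls back to vLLM's own quantization helpers when DeepGEMM is absent.

    import torch

    from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
                                      per_token_group_cast_to_fp8)

    # Activations: one FP8 scale per 128-wide group; weights: one scale per
    # 128x128 block. Both casts go through the compatibility wrapper.
    a = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)  # (M, K)
    b = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)  # (N, K)
    a_q, a_s = per_token_group_cast_to_fp8(a, 128)
    b_q, b_s = per_block_cast_to_fp8(b)

    # NT GEMM: out = a @ b.T, written as bfloat16.
    out = torch.empty(256, 512, device="cuda", dtype=torch.bfloat16)
    fp8_gemm_nt((a_q, a_s), (b_q, b_s), out)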