mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Feature] Integrate SM100 DeepGEMM support (#20087)
This commit is contained in:
@ -86,6 +86,9 @@ def benchmark_config(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_deep_gemm:
|
||||
# we use the default block shape for deepgemm
|
||||
block_quant_shape = [128, 128]
|
||||
if use_fp8_w8a8:
|
||||
if block_quant_shape:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
|
@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
|
||||
dg_available = False
|
||||
try:
|
||||
import deep_gemm
|
||||
dg_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
dg_available = has_deep_gemm()
|
||||
|
||||
if dg_available:
|
||||
from deep_gemm import get_m_alignment_for_contiguous_layout
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
monkeypatch):
|
||||
@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
block_m = get_m_alignment_for_contiguous_layout()
|
||||
block_size = [block_m, block_m]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
|
@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
from .utils import make_test_weights
|
||||
@ -368,6 +369,8 @@ NUM_EXPERTS = [32]
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
topk: int, world_dp_size: tuple[int, int]):
|
||||
"""
|
||||
@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False]
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
def test_ll_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
|
@ -13,48 +13,18 @@ import torch
|
||||
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
|
||||
per_token_group_cast_to_fp8)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
if has_deep_gemm:
|
||||
import deep_gemm
|
||||
BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
BLOCK_SIZE = [BLOCK_M, BLOCK_M]
|
||||
BLOCK_SIZE = [128, 128]
|
||||
|
||||
requires_deep_gemm = pytest.mark.skipif(
|
||||
not has_deep_gemm,
|
||||
not has_deep_gemm(),
|
||||
reason="Requires deep_gemm kernels",
|
||||
)
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
"""
|
||||
tokens_bf16 = torch.randn(
|
||||
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
|
||||
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
|
||||
_, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
|
||||
|
||||
# expert weight tensors
|
||||
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
|
||||
@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
block_shape=block_size,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
|
||||
base = out_triton.abs().mean()
|
||||
atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3
|
||||
rtol = 0.05
|
||||
# ----- Compare -----
|
||||
torch.testing.assert_close(
|
||||
out_deepgemm.to(torch.float32),
|
||||
out_triton.to(torch.float32),
|
||||
rtol=rtol,
|
||||
atol=float(atol),
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
|
||||
# Note: W1 has shape (E, 2N, K), so N = 512
|
||||
|
@ -8,19 +8,15 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
per_block_cast_to_fp8)
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
dg_available = False
|
||||
try:
|
||||
import deep_gemm
|
||||
dg_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
|
||||
per_token_group_cast_to_fp8)
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@pytest.mark.skipif(not has_deep_gemm(),
|
||||
reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
# only aligned sizes
|
||||
@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
_, block_k = block_size[0], block_size[1]
|
||||
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
|
||||
A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
|
||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
|
||||
|
||||
As = As_fp8.to(torch.float32)
|
||||
@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
out_dtype)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
|
||||
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
|
||||
|
||||
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
|
||||
|
||||
assert As_fp8.shape == (M, (K + 127) //
|
||||
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
|
@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -271,7 +272,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert expert_tokens_meta is not None
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
|
||||
import deep_gemm as dg
|
||||
assert hidden_states.ndim == 3
|
||||
assert self.block_shape is not None
|
||||
|
||||
@ -289,18 +289,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# for the M expectation of each batch, correctly setting this value
|
||||
# may lead to better performance.
|
||||
expected_m = max_num_tokens
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
|
||||
(w1, w1_scale),
|
||||
out=workspace1,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
||||
out=workspace1,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
|
||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
|
||||
expert_num_tokens)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
|
||||
(w2, w2_scale),
|
||||
out=output,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
|
||||
out=output,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
|
@ -14,9 +14,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.utils import has_deep_gemm, round_up
|
||||
from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
|
||||
per_token_group_cast_to_fp8)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -127,7 +128,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
):
|
||||
import deep_gemm as dg
|
||||
assert self.block_shape is not None
|
||||
|
||||
a1q = hidden_states
|
||||
@ -164,19 +164,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
(M_sum, N // 2))
|
||||
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
|
||||
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
|
||||
mm1_out, expert_ids)
|
||||
|
||||
self.activation(activation, act_out, mm1_out.view(-1, N))
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
|
||||
self.block_shape[1],
|
||||
column_major_scales=True,
|
||||
out_q=quant_out)
|
||||
a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
|
||||
self.block_shape[1],
|
||||
column_major_scales=True,
|
||||
out_q=quant_out)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
|
||||
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
|
||||
mm2_out, expert_ids)
|
||||
|
||||
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
|
||||
|
||||
|
@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
|
||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||
|
||||
@ -1171,9 +1172,15 @@ def fused_experts(
|
||||
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
|
||||
# For now, disable DeepGemm for small N (<= 512) until better
|
||||
# permute/unpermute ops are available.
|
||||
# However, on B200, we use DeepGemm for all cases becuase they only support
|
||||
# E8M0 scale, which means we requantize the weight and input to the specific
|
||||
# scale. Fallen back to cutlass or triton for some cases would cause
|
||||
# accuracy issue.
|
||||
N = w1.size(1)
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
|
||||
and _valid_deep_gemm(hidden_states, w1, w2)):
|
||||
should_use_deep_gemm = ((N > 512
|
||||
and _valid_deep_gemm(hidden_states, w1, w2))
|
||||
or is_blackwell_deep_gemm_used())
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
|
||||
assert apply_router_weight_on_input is False
|
||||
return deep_gemm_moe_fp8(
|
||||
hidden_states=hidden_states,
|
||||
@ -1363,7 +1370,6 @@ def fused_experts_impl(
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
A_scale=a1_scale,
|
||||
|
@ -48,7 +48,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1, a1_scale, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
|
@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
|
||||
|
||||
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@ -102,7 +103,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
|
||||
if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K)
|
||||
or is_blackwell_deep_gemm_used()):
|
||||
assert self.deep_gemm_expert is not None
|
||||
return self.deep_gemm_expert.workspace_shapes(
|
||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||
@ -132,7 +134,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
):
|
||||
use_deep_gemm = (self.allow_deep_gemm
|
||||
and _valid_deep_gemm(hidden_states, w1, w2))
|
||||
and (_valid_deep_gemm(hidden_states, w1, w2)
|
||||
or is_blackwell_deep_gemm_used()))
|
||||
|
||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||
assert experts is not None
|
||||
|
@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
|
||||
per_token_group_cast_to_fp8)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -115,7 +117,10 @@ def _fp8_quantize(
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
if is_blackwell_deep_gemm_used():
|
||||
A, A_scale = per_token_group_cast_to_fp8(A, block_k)
|
||||
else:
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
|
||||
|
||||
return A, A_scale
|
||||
|
@ -8,10 +8,9 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit()
|
||||
|
@ -6,10 +6,8 @@ import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import direct_register_custom_op, has_deep_gemm
|
||||
|
||||
if has_deep_gemm():
|
||||
import deep_gemm
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import fp8_gemm_nt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm(
|
||||
output_dtype)
|
||||
# Deepgemm only supports output tensor type as bfloat16
|
||||
assert C.dtype == torch.bfloat16
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
fp8_gemm_nt((A, As), (B, Bs), C)
|
||||
return C
|
||||
|
||||
|
||||
|
@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
@ -40,6 +42,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@ -393,6 +396,19 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
|
||||
# On B200, DeepGemm only support E8M0 scale, which means we need to
|
||||
# requantize the weight and input to the specific scale
|
||||
# at the same time.
|
||||
if is_blackwell_deep_gemm_used():
|
||||
assert layer.weight_block_size is not None
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.weight.data,
|
||||
layer.weight_scale_inv.data if hasattr(
|
||||
layer, "weight_scale_inv") else layer.weight_scale.data,
|
||||
block_sz,
|
||||
)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
@ -670,15 +686,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
if self.allow_deep_gemm:
|
||||
if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
|
||||
get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
|
||||
if _is_col_major(layer.w2_weight_scale_inv):
|
||||
layer.w2_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||
get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
@ -797,6 +812,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
if is_blackwell_deep_gemm_used():
|
||||
assert layer.weight_block_size is not None
|
||||
# Re-quantise the expert weights so their scales are UE8M0.
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.w13_weight.data,
|
||||
layer.w13_weight_scale_inv.data,
|
||||
block_sz,
|
||||
)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.w2_weight.data,
|
||||
layer.w2_weight_scale_inv.data,
|
||||
block_sz,
|
||||
)
|
||||
|
||||
# Ensure column-major TMA alignment expected by DeepGEMM.
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||
layer.w13_weight_scale_inv).contiguous()
|
||||
if _is_col_major(layer.w2_weight_scale_inv):
|
||||
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||
layer.w2_weight_scale_inv).contiguous()
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
|
@ -5,6 +5,7 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -13,7 +14,7 @@ import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize)
|
||||
group_broadcast)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.platforms import current_platform
|
||||
@ -235,7 +236,7 @@ def block_quant_to_tensor_quant(
|
||||
The outputs are tensor-wise quantization tensor and tensor-wise
|
||||
quantization scale. Note only float8 is supported for now.
|
||||
"""
|
||||
x_dq_block = scaled_dequantize(x_q_block, x_s)
|
||||
x_dq_block = group_broadcast(x_q_block, x_s)
|
||||
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||
return x_q_tensor, scale
|
||||
|
||||
@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul(
|
||||
)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
|
||||
# TODO(wentao): remove this function when DeepGEMM exposes this function
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory address of TMA must be 16-byte aligned.
|
||||
Since we use column-major layout for the LHS scaling tensor,
|
||||
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
|
||||
16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return cdiv(x, alignment) * alignment
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
|
||||
# TODO(wentao): remove this function when DeepGEMM exposes this function
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
|
||||
will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along
|
||||
the M axis (thus meets the requirement of LHS scaling tensor in
|
||||
DeepGEMM), this function will do nothing.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in
|
||||
# CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
if x.dim() == 2:
|
||||
if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
|
||||
2) == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(
|
||||
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x[:, :m, :] = x
|
||||
aligned_x = aligned_x[:, :m, :]
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
|
||||
|
||||
def requant_weight_ue8m0_inplace(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: Sequence[int] = (128, 128),
|
||||
) -> None:
|
||||
"""Re-quantise *weight* so that its per-block scaling factors are in the
|
||||
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
|
||||
|
||||
Args:
|
||||
weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``.
|
||||
Expected shape ``(..., M, K)``.
|
||||
weight_scale: Corresponding per-block scale tensor (``torch.float32``)
|
||||
with shape ``(..., M // block_size[0], K // block_size[1])``.
|
||||
block_size: 2-element iterable ``[block_m, block_k]`` describing the
|
||||
block quantisation granularity.
|
||||
"""
|
||||
if weight.numel() == 0:
|
||||
return
|
||||
|
||||
if weight.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
|
||||
f"{weight.dtype} instead.")
|
||||
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
|
||||
block_m, block_k = int(block_size[0]), int(block_size[1])
|
||||
|
||||
# Flatten leading dimensions so we can iterate over the last two dims.
|
||||
leading_shape = weight.shape[:-2]
|
||||
if len(leading_shape) == 0:
|
||||
w_view = weight.unsqueeze(0)
|
||||
s_view = weight_scale.unsqueeze(0)
|
||||
else:
|
||||
w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1])
|
||||
s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:])
|
||||
|
||||
num_mats = w_view.size(0)
|
||||
for idx in range(num_mats):
|
||||
w_q = w_view[idx]
|
||||
s_old = s_view[idx]
|
||||
|
||||
# De-quantise with the *old* scaling factors (float32).
|
||||
m_cur, k_cur = w_q.shape
|
||||
s_float = s_old.to(torch.float32)
|
||||
# Expand scales along rows and cols by block size, then crop.
|
||||
s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0)
|
||||
s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1)
|
||||
s_exp = s_exp[:m_cur, :k_cur]
|
||||
w_dq = w_q.to(torch.float32) * s_exp
|
||||
# Re-quantise using power-of-two scaling (UE8M0).
|
||||
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k])
|
||||
|
||||
# Write back the results in-place.
|
||||
w_q.copy_(w_requant)
|
||||
s_old.copy_(s_requant)
|
||||
|
152
vllm/utils/deep_gemm.py
Normal file
152
vllm/utils/deep_gemm.py
Normal file
@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compatibility wrapper for DeepGEMM API changes.
|
||||
|
||||
Users of vLLM should always import **only** these wrappers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Any, Callable, NoReturn
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.utils import cuda_get_device_properties, has_deep_gemm
|
||||
|
||||
|
||||
@functools.cache
|
||||
def is_blackwell_deep_gemm_used() -> bool:
|
||||
"""Return ``True`` if vLLM is configured to use DeepGEMM on a
|
||||
Blackwell-class GPU.
|
||||
"""
|
||||
|
||||
if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
|
||||
and _per_block_cast_impl is not None):
|
||||
return False
|
||||
|
||||
return cuda_get_device_properties(0, ("major", ))[0] == 10
|
||||
|
||||
|
||||
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
"""Placeholder for unavailable DeepGEMM backend."""
|
||||
raise RuntimeError(
|
||||
"DeepGEMM backend is not available. Please install the `deep_gemm` "
|
||||
"package to enable FP8 kernels.")
|
||||
|
||||
|
||||
def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
|
||||
"""Return the *new* symbol if it exists, otherwise the *old* one."""
|
||||
if hasattr(module, new):
|
||||
return getattr(module, new)
|
||||
if hasattr(module, old):
|
||||
return getattr(module, old)
|
||||
return None
|
||||
|
||||
|
||||
if not has_deep_gemm():
|
||||
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||
_grouped_impl: Callable[..., Any] | None = None
|
||||
_grouped_masked_impl: Callable[..., Any] | None = None
|
||||
_per_token_cast_impl: Callable[..., Any] | None = None
|
||||
_per_block_cast_impl: Callable[..., Any] | None = None
|
||||
else:
|
||||
_dg = importlib.import_module("deep_gemm") # type: ignore
|
||||
|
||||
_fp8_gemm_nt_impl = _resolve_symbol(
|
||||
_dg,
|
||||
"fp8_gemm_nt",
|
||||
"gemm_fp8_fp8_bf16_nt",
|
||||
)
|
||||
_grouped_impl = _resolve_symbol(
|
||||
_dg,
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
|
||||
)
|
||||
_grouped_masked_impl = _resolve_symbol(
|
||||
_dg,
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
|
||||
)
|
||||
|
||||
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
|
||||
try:
|
||||
_math_mod = importlib.import_module(
|
||||
"deep_gemm.utils.math") # type: ignore
|
||||
_per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
|
||||
None)
|
||||
_per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
|
||||
None)
|
||||
except ModuleNotFoundError:
|
||||
_per_token_cast_impl = None
|
||||
_per_block_cast_impl = None
|
||||
|
||||
|
||||
def fp8_gemm_nt(*args, **kwargs):
|
||||
if _fp8_gemm_nt_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _fp8_gemm_nt_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
|
||||
if _grouped_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
|
||||
if _grouped_masked_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_masked_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
|
||||
"""Wrapper for token-wise FP8 quantisation.
|
||||
|
||||
• If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
|
||||
• Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
|
||||
"""
|
||||
|
||||
if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
|
||||
assert group_size == 128, "group_size must be 128 for deepgemm"
|
||||
return _per_token_cast_impl(x)
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8 as _ptg)
|
||||
return _ptg(x, group_size, *args, **kwargs)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x, *args, **kwargs):
|
||||
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
|
||||
return _per_block_cast_impl(x)
|
||||
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
|
||||
from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
|
||||
return _pbcf(x, *args, **kwargs)
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
"""Return a global difference metric for unit tests.
|
||||
|
||||
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
|
||||
error, causing ``torch.testing.assert_close`` to fail. Instead of checking
|
||||
every element, we compute a cosine-style similarity over the whole tensor
|
||||
and report ``1 - sim``. Once kernel accuracy improves this helper can be
|
||||
removed.
|
||||
"""
|
||||
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
__all__ = [
|
||||
"calc_diff",
|
||||
"fp8_gemm_nt",
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"per_token_group_cast_to_fp8",
|
||||
"per_block_cast_to_fp8",
|
||||
"is_blackwell_deep_gemm_used",
|
||||
]
|
Reference in New Issue
Block a user