Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.11.0rc3...deepep_twe (12 commits)
Commits:

- fcec8c8827
- 850dafea92
- b4f17e12a4
- 21ffc7353a
- 39d5d33f8f
- 7a821f0e7f
- 26fd8ca33c
- d5f206767c
- 2b5ad9f233
- 299f829180
- 104a984e6a
- 8de2fd39fc
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)
from vllm.platforms import current_platform

# (E, T, H, group_size, seed)
CASES = [
    (1, 1, 128, 64, 0),
    (1, 4, 128, 128, 0),
    (2, 4, 256, 128, 0),
    (32, 64, 256, 128, 0),
    (17, 31, 768, 128, 0),
]


@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
    current_platform.seed_everything(seed)

    # Input tensor of shape (E, T, 2*H)
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E, ),
        dtype=torch.int32,
        device="cuda",
    )

    # Run the Triton kernel
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
                                            tokens_per_expert,
                                            group_size=group_size,
                                            eps=1e-10)

    # Reference implementation
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max
    fp8_min = fp8_info.min
    eps = 1e-10

    # Compute silu activation and elementwise multiplication
    y1 = y[..., :H]
    y2 = y[..., H:]
    silu_x = y1 * torch.sigmoid(y1)
    merged = silu_x * y2

    # Compute reference scales and quantized output, skipping padded tokens
    for e in range(E):
        nt = tokens_per_expert[e].item()
        ref_s = torch.empty((T, H // group_size),
                            dtype=torch.float32,
                            device="cuda")
        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
        for t in range(nt):
            data = merged[e, t]
            data_grp = data.view(H // group_size, group_size)
            amax = data_grp.abs().amax(dim=1).clamp(min=eps)
            scale = amax / fp8_max

            scaled = data / scale.repeat_interleave(group_size)
            clamped = scaled.clamp(fp8_min, fp8_max)
            q = clamped.to(torch.float8_e4m3fn)

            ref_s[t] = scale
            ref_q[t] = q

        y_se = y_s[e]
        y_qe = y_q[e]

        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
        torch.testing.assert_close(
            y_qe[:nt].to(torch.float32),
            ref_q[:nt].to(torch.float32),
            atol=2,
            rtol=2e-1,
        )
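The reference path in this test is plain group-wise symmetric FP8 quantization: one scale per `group_size` contiguous elements, scale = amax / fp8_max. A minimal CPU-only sketch of the same math, with toy sizes and no vLLM imports (the helper name below is hypothetical):

```python
import torch

def group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Toy per-group FP8 quantization: one scale per group of `group_size` elements."""
    fp8 = torch.finfo(torch.float8_e4m3fn)
    groups = x.view(-1, group_size)
    amax = groups.abs().amax(dim=1, keepdim=True).clamp(min=eps)
    scale = amax / fp8.max                     # one scale per group
    q = (groups / scale).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)
    return q.view_as(x), scale.squeeze(1)

x = torch.randn(256)
q, s = group_quant_fp8_ref(x, group_size=128)
# Dequantize and check that the round-trip error stays within fp8 resolution.
x_hat = (q.to(torch.float32).view(-1, 128) * s.unsqueeze(1)).view_as(x)
print((x_hat - x).abs().max())
```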
@@ -138,9 +138,29 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # This is the DeepEP default. Stick to it till we can establish
        # reasonable defaults based on profiling.
        self.num_sms = 20
        # Use all SMs for all2all communication.
        # This will need to be adjusted for dual-batch overlap.
        device = self.dp_group.device
        props = torch.cuda.get_device_properties(device)
        self.num_sms = props.multi_processor_count
        print(f"Setting num sms to {self.num_sms}")

    def get_handle(self, kwargs):
        raise NotImplementedError
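For context, the SM count assigned above comes directly from the CUDA device properties. A minimal sketch of the same query, assuming a CUDA-enabled PyTorch build:

```python
import torch

# Query the number of streaming multiprocessors on the current CUDA device,
# which is the value the change above assigns to num_sms.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    num_sms = props.multi_processor_count
    print(f"{props.name}: {num_sms} SMs")
```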
@@ -6,14 +6,179 @@ import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

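The `has_deep_gemm` flag only records whether the optional `deep_gemm` package is importable. A hedged sketch of one common way such a flag is consumed, here as a pytest skip guard; this pattern is illustrative and not necessarily how vLLM wires it up:

```python
import importlib.util

import pytest

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

# Skip DeepGEMM-dependent tests when the optional package is absent.
requires_deep_gemm = pytest.mark.skipif(not has_deep_gemm,
                                        reason="deep_gemm not installed")
```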
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
        # Pointers ------------------------------------------------------------
        input_ptr,  # activations (E, T, 2*H)
        y_q_ptr,  # fp8 quantized activations (E, T, H)
        y_s_ptr,  # float32 scales (E, T, G)
        counts_ptr,  # int32 num tokens per expert (E)

        # Sizes ---------------------------------------------------------------
        H: tl.constexpr,  # hidden dimension (per output)
        GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)

        # Strides for input (elements) ----------------------------------------
        stride_i_e,
        stride_i_t,
        stride_i_h,

        # Strides for y_q (elements) -------------------------------------------
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,

        # Strides for y_s (elements) -------------------------------------------
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,

        # Stride for counts (elements)
        stride_counts_e,

        # Numeric params --------------------------------------------------------
        eps: tl.constexpr,
        fp8_min: tl.constexpr,
        fp8_max: tl.constexpr,

        # Meta -------------------------------------------------------------------
        BLOCK: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK)
    cols = cols.to(tl.int64)
    mask_h = cols < BLOCK

    t = tl.zeros([], tl.int64)
    while t < n_tokens:
        base_i_offset = (e * stride_i_e + t * stride_i_t +
                         g * GROUP_SIZE * stride_i_h)
        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
                          g * GROUP_SIZE * stride_yq_h)
        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g

        mask = mask_h
        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
                    mask=mask,
                    other=0.0).to(tl.float32)
        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
                     cols * stride_i_h,
                     mask=mask,
                     other=0.0).to(tl.float32)

        x = x * (1.0 / (1.0 + tl.exp(-x)))
        y = x * y2

        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset, y_s)

        t += 1
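The launch grid for this kernel is one program per (expert, group) pair; the token dimension is walked by the `while` loop inside the kernel. A small pure-Python sketch of that index mapping (illustrative only):

```python
# Emulate the pid -> (expert, group) mapping used by the kernel above.
E, H, group_size = 2, 256, 128
G = H // group_size          # groups per token
grid = E * G                 # one Triton program per (expert, group)

for pid in range(grid):
    e, g = pid // G, pid % G
    # Program `pid` quantizes columns [g*group_size, (g+1)*group_size) of
    # every valid token routed to expert `e`.
    print(pid, "-> expert", e, "group", g)
```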
def silu_mul_fp8_quant_deep_gemm(
    y: torch.Tensor,  # (E, T, 2*H) float32
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    group_size: int = 128,
    eps: float = 1e-10,
):
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales.

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.

    Returns `(y_q, y_s)` where
    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`.
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = H // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
        "tokens_per_expert must be shape (E,)"
    tokens_per_expert = tokens_per_expert.to(device=y.device,
                                             dtype=torch.int32)

    # allocate outputs
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # strides (elements)
    stride_i_e, stride_i_t, stride_i_h = y.stride()
    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

    # desired scale strides (elements): (T*G, 1, T)
    stride_ys_e = T * G
    stride_ys_t = 1
    stride_ys_g = T
    y_s = torch.empty_strided((E, T, G),
                              (stride_ys_e, stride_ys_t, stride_ys_g),
                              dtype=torch.float32,
                              device=y.device)

    stride_cnt_e = tokens_per_expert.stride()[0]

    # static grid over experts and H-groups;
    # a loop inside the kernel handles the token dim
    grid = (E * G, )

    f_info = torch.finfo(fp8_dtype)
    fp8_max = f_info.max
    fp8_min = f_info.min

    _silu_mul_fp8_quant_deep_gemm[grid](
        y,
        y_q,
        y_s,
        tokens_per_expert,
        H,
        group_size,
        stride_i_e,
        stride_i_t,
        stride_i_h,
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,
        stride_cnt_e,
        eps,
        fp8_min,
        fp8_max,
        BLOCK=group_size,
        num_warps=4,
    )

    return y_q, y_s
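The `(T*G, 1, T)` strides above lay the scales out so that, for a fixed expert and group, the per-token scales are contiguous in memory, presumably the layout the downstream masked DeepGEMM kernel expects. A small sketch of that layout with toy sizes:

```python
import torch

E, T, G = 2, 4, 3
# Scales allocated with strides (T*G, 1, T): for a fixed expert and group,
# the T per-token scales sit contiguously in memory.
y_s = torch.empty_strided((E, T, G), (T * G, 1, T), dtype=torch.float32)

print(y_s.stride())           # (12, 1, 4)
print(y_s[0, :, 0].stride())  # (1,): token dimension is contiguous per group
print(y_s.is_contiguous())    # False: not the default row-major layout
```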
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    # The Deep Gemm kernels only support block size of 128
@@ -96,7 +261,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                                  hidden_states, w1, w2, topk_ids)

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
        workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
@@ -109,19 +273,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                                 masked_m=expert_num_tokens,
                                                 expected_m=expected_m)

        # TODO (varun) [Optimization]: Use a batched version of activation.
        # Similarly for the quant below.
        self.activation(activation, workspace2, workspace1.view(-1, N))

        w2_hidden_size = workspace2.size(-1)
        workspace2 = workspace2.view(-1, w2_hidden_size)

        a2q_scale: Optional[torch.Tensor] = None
        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
                                                   self.block_shape[1],
                                                   column_major_scales=False)
        a2q = a2q.view(E, max_num_tokens, -1)
        a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
        assert expert_num_tokens is not None
        a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                      expert_num_tokens)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
                                                 (w2, w2_scale),
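This hunk replaces the separate activation and per-token-group quantization steps with a single fused call. A hedged shape walkthrough of that call, assuming the module path added in this diff and a CUDA device; sizes here are illustrative:

```python
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)

# workspace1 holds the w1 GEMM output of shape (E, max_num_tokens, N) with
# N = 2 * intermediate_size; the fused kernel returns FP8 activations of
# shape (E, max_num_tokens, N // 2) plus their group scales.
E, max_num_tokens, N = 4, 16, 512
workspace1 = torch.randn(E, max_num_tokens, N,
                         device="cuda", dtype=torch.float32)
expert_num_tokens = torch.randint(0, max_num_tokens, (E, ),
                                  dtype=torch.int32, device="cuda")

a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, expert_num_tokens)
assert a2q.shape == (E, max_num_tokens, N // 2)
assert a2q_scale.shape == (E, max_num_tokens, (N // 2) // 128)  # group_size=128
```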
@@ -45,7 +45,8 @@ if current_platform.is_cuda_alike():
    from .pplx_prepare_finalize import PplxPrepareAndFinalize
    if has_deepep:
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                 DeepEPLLPrepareAndFinalize)
else:
    fused_experts = None  # type: ignore
    FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -377,6 +378,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                all2all_manager.world_size)
            handle = all2all_manager.get_handle(all_to_all_args)

            # Note: We may want to use FP8 dispatch even otherwise, just to
            # reduce data movement.
            assert act_quant_block_size is not None
            use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
                                and act_quant_block_size[1]
                                == DEEPEP_QUANT_BLOCK_SIZE)

            # Note (varun): Whether to use FP8 dispatch or not needs some
            # profiling. Turning it off for now.
            prepare_finalize = DeepEPLLPrepareAndFinalize(
@@ -386,7 +394,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                max_tokens_per_rank=moe.max_num_tokens,
                quant_dtype=quant_dtype,
                block_shape=act_quant_block_size,
                use_fp8_dispatch=False,
                use_fp8_dispatch=use_fp8_dispatch,
            )

            self.topk_indices_dtype = None
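The `use_fp8_dispatch` predicate above simply checks that the activation quant dtype is FP8 and that its block size matches DeepEP's quant block. A standalone sketch of that decision; `torch.float8_e4m3fn` stands in for `current_platform.fp8_dtype()`, and the `DEEPEP_QUANT_BLOCK_SIZE` value of 128 is an assumption for illustration:

```python
import torch

DEEPEP_QUANT_BLOCK_SIZE = 128  # assumed value, for illustration only

def should_use_fp8_dispatch(quant_dtype: torch.dtype,
                            act_quant_block_size: tuple[int, int]) -> bool:
    # Dispatch activations in FP8 only when they are already quantized to FP8
    # with a group size DeepEP's low-latency kernels understand.
    return (quant_dtype == torch.float8_e4m3fn
            and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE)

print(should_use_fp8_dispatch(torch.float8_e4m3fn, (128, 128)))  # True
print(should_use_fp8_dispatch(torch.bfloat16, (128, 128)))       # False
```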