[Kernel][B200] mxfp4 fused cutlass moe (#23696)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -11,6 +1,7 @@ import torch
from packaging import version

from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
    "quark") is not None and version.parse(
@@ -19,6 +20,10 @@ QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)

HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
                               and current_platform.is_device_capability(90)
                               and has_flashinfer())

if TRTLLM_GEN_MXFP4_AVAILABLE:
    from flashinfer import (fp4_quantize, mxfp8_quantize,
                            next_positive_power_of_2,
@@ -542,3 +547,317 @@ def test_trtllm_gen_mxfp4_fused_moe(
        transpose_optimized=transpose_optimized)
    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)


def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales on the last dimension by groups of 4, matching
    the transformation in mxfp4.py's BF16 (Hopper) path."""
    s = scales.to(torch.uint8)
    s_shape = s.shape
    assert s_shape[-1] % 4 == 0
    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
    # Move the 4-group dimension before the row dimension
    permuted = s.permute(0, 2, 1, 3)
    # Merge the row dim with the 4-group dim
    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)

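For reference, a minimal standalone sketch (not part of the diff) of the shape transformation this helper performs, assuming it is in scope from the test module above:

import torch

# A [num_experts=2, rows=3, cols=8] uint8 scale tensor becomes
# [2, cols // 4 = 2, rows * 4 = 12] after interleaving: output row g is
# [row0[4g:4g+4], row1[4g:4g+4], row2[4g:4g+4]].
x = torch.arange(2 * 3 * 8, dtype=torch.uint8).reshape(2, 3, 8)
y = _interleave_scales_lastdim_by4(x)
assert y.shape == (2, 2, 12)
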
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not HOPPER_MXFP4_BF16_AVAILABLE,
    reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
):
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=torch.bfloat16)
    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
    w13_q = torch.randint(
        0,
        256, (num_experts, 2 * intermediate_size, hidden_size // 2),
        device=device,
        dtype=torch.uint8)
    w13_scale = torch.randint(
        118,
        123, (num_experts, 2 * intermediate_size, hidden_size // 32),
        device=device,
        dtype=torch.uint8)

    w2_q = torch.randint(0,
                         256,
                         (num_experts, hidden_size, intermediate_size // 2),
                         device=device,
                         dtype=torch.uint8)
    w2_scale = torch.randint(
        118,
        123, (num_experts, hidden_size, intermediate_size // 32),
        device=device,
        dtype=torch.uint8)
    # Bias contiguous [b1; b3]
    bias13 = (torch.randn(num_experts,
                          2 * intermediate_size,
                          device=device,
                          dtype=torch.bfloat16) * 10)
    bias2 = (torch.randn(
        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
    router_logits = torch.rand(num_tokens,
                               num_experts,
                               dtype=torch.float32,
                               device=device)

    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
        num_experts, 2 * intermediate_size, hidden_size)
    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
        num_experts, hidden_size, intermediate_size)
    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
                        hidden_states.to(torch.float32), w13_ref,
                        bias13.to(torch.float32), w2_ref,
                        bias2.to(torch.float32), alpha, beta, limit, 'bf16')

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
    w13_s = torch.cat([w3_s, w1_s], dim=1)
    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

    routing_weights = torch.nn.functional.softmax(router_logits,
                                                  dim=1,
                                                  dtype=torch.float32)
    token_final_scales, token_selected_experts = torch.topk(routing_weights,
                                                            topk,
                                                            dim=-1)
    token_final_scales = (token_final_scales /
                          token_final_scales.sum(dim=-1, keepdim=True))
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    if alpha is not None:
        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts, ), limit, device=hidden_states.device)

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped,
        fc2_expert_weights=w2_q,
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=[w13_s_inter.to(torch.uint8),
                      w2_s_inter.to(torch.uint8)],
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha,
        swiglu_beta=beta,
        swiglu_limit=limit,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_w4_group_scaling=True,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)

@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not (current_platform.is_cuda()
         and current_platform.is_device_capability(100) and has_flashinfer()),
    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: Optional[float],
    beta: Optional[float],
    limit: Optional[float],
):
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=torch.bfloat16)
    # Float weights in w13 format [w1; w3]
    w13 = (torch.randn(num_experts,
                       2 * intermediate_size,
                       hidden_size,
                       device=device,
                       dtype=torch.bfloat16) / 10)
    w2 = (torch.randn(num_experts,
                      hidden_size,
                      intermediate_size,
                      device=device,
                      dtype=torch.bfloat16) / 10)
    # Bias contiguous [b1; b3]
    bias13 = (torch.randn(num_experts,
                          2 * intermediate_size,
                          device=device,
                          dtype=torch.bfloat16) * 10)
    bias2 = (torch.randn(
        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
    router_logits = torch.rand(num_tokens,
                               num_experts,
                               dtype=torch.float32,
                               device=device)

    # Quantize weights to MXFP4 per expert (SM100 path)
    from flashinfer import mxfp4_quantize

    def quant_mxfp4_batches(a: torch.Tensor, e: int):
        qs, sfs = [], []
        for i in range(e):
            q, sf = mxfp4_quantize(a[i].cuda())
            qs.append(q)
            sfs.append(sf)
        return torch.stack(qs), torch.stack(sfs)

    def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
                              scale_tensor: torch.Tensor):
        num_batches = mat_fp4.size(0)
        scale_tensor = scale_tensor.view(num_batches, -1)
        from flashinfer import mxfp4_dequantize
        return torch.stack([
            mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
            for b in range(num_batches)
        ])

    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)

    # Reference result using dequantized tensors and reference_moe
    w13_ref = dequant_mxfp4_batches(
        w13_q.view(torch.uint8),
        w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
            num_experts, 2 * intermediate_size, hidden_size)
    w2_ref = dequant_mxfp4_batches(
        w2_q.view(torch.uint8),
        w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
            num_experts, hidden_size, intermediate_size)

    # Quantize activations for SM100 path and dequantize for reference
    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
                        hidden_states.to(torch.float32), w13_ref,
                        bias13.to(torch.float32), w2_ref,
                        bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')

    # Prepare inputs for FlashInfer CUTLASS fused MoE
    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Swap scales halves to match swapped weights
    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
    w13_scale_swapped = torch.cat([s3, s1], dim=1)

    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Build routing for kernel
    routing_weights = torch.nn.functional.softmax(router_logits,
                                                  dim=1,
                                                  dtype=torch.float32)
    token_final_scales, token_selected_experts = torch.topk(routing_weights,
                                                            topk,
                                                            dim=-1)
    token_final_scales = (token_final_scales /
                          token_final_scales.sum(dim=-1, keepdim=True))
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    if alpha is not None:
        alpha_t = torch.full((num_experts, ),
                             alpha,
                             device=hidden_states.device)
    else:
        alpha_t = None
    if beta is not None:
        beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
    else:
        beta_t = None
    if limit is not None:
        limit_t = torch.full((num_experts, ),
                             limit,
                             device=hidden_states.device)
    else:
        limit_t = None

    # Quant scales for SM100 MXFP8+MXFP4 path
    fake_input_scale = torch.ones(num_experts, device=device)
    quant_scales = [
        w13_scale_swapped.view(torch.int32),
        fake_input_scale,
        w2_scale.view(torch.int32),
        fake_input_scale,
    ]

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states_q,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
        fc2_expert_weights=w2_q.contiguous().view(torch.long),
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=quant_scales,
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha_t,
        swiglu_beta=beta_t,
        swiglu_limit=limit_t,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_mxfp8_act_scaling=True,
        input_sf=hidden_states_sf,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
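Both new tests rearrange the contiguous [w1; w3] gate/up halves into the [w3; w1] ordering the fused kernel expects before calling flashinfer_cutlass_fused_moe. A minimal illustration of that reordering (toy shapes, not part of the diff):

import torch

num_experts, intermediate_size, hidden_size = 2, 4, 8
w13 = torch.randn(num_experts, 2 * intermediate_size, hidden_size)
w1, w3 = torch.chunk(w13, 2, dim=1)   # [w1; w3] halves along dim 1
w31 = torch.cat([w3, w1], dim=1)      # [w3; w1] ordering passed to the kernel
assert torch.equal(w31[:, :intermediate_size], w3)
assert torch.equal(w31[:, intermediate_size:], w1)
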
vllm/envs.py
@@ -166,7 +166,8 @@ if TYPE_CHECKING:
    VLLM_HAS_FLASHINFER_CUBIN: bool = False
    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
    VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
    VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
    VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
    VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1004,6 +1005,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),

    # If set to 1, use the FlashInfer CUTLASS backend for
    # MXFP8 (activation) x MXFP4 (weight) MoE.
    # This is separate from the TRTLLMGEN path controlled by
    # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
    lambda: bool(int(
        os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")
    )),

    # If set to 1, use the FlashInfer
    # BF16 (activation) x MXFP4 (weight) MoE backend.
    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
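A minimal usage sketch for the new flag (not part of the diff), assuming vllm.envs resolves these lambdas lazily on attribute access:

import os

# Opt in to the FlashInfer CUTLASS MXFP8 x MXFP4 MoE path; must be set before
# vLLM reads its environment variables.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS"] = "1"

import vllm.envs as envs
assert envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS is True
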
@@ -1296,6 +1306,7 @@ def compute_hash() -> str:
        "VLLM_USE_FLASHINFER_MOE_FP8",
        "VLLM_USE_FLASHINFER_MOE_FP4",
        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
        "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
        "VLLM_USE_CUDNN_PREFILL",
        "VLLM_USE_TRTLLM_ATTENTION",
@@ -813,9 +813,16 @@ class FusedMoE(CustomOp):

        # we are padding globally so EP buffer allocation works
        if quant_config and quant_config.get_name() == "mxfp4":
            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
                should_use_flashinfer_mxfp4)
            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
            from vllm.model_executor.layers.quantization.mxfp4 import (
                Mxfp4Backend, get_mxfp4_backend)
            current_mxfp4_backend = get_mxfp4_backend()
            if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
                    or current_mxfp4_backend
                    == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
                hidden_size = round_up(hidden_size, 128)
            elif (current_platform.is_rocm() or current_mxfp4_backend
                  == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
                  current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                hidden_size = round_up(hidden_size, 256)

        # For smuggling this layer into the fused moe custom op
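A minimal sketch of the padding arithmetic used here (illustrative sizes; assumes round_up(x, m) rounds x up to the next multiple of m, as vllm.utils.round_up does):

def round_up(x: int, m: int) -> int:
    # Round x up to the next multiple of m.
    return ((x + m - 1) // m) * m

assert round_up(2880, 128) == 2944  # SM90 BF16 / SM100 CUTLASS backends
assert round_up(2880, 256) == 3072  # ROCm / SM100 TRTLLM / SM100 BF16 backends
assert round_up(3072, 256) == 3072  # already aligned: unchanged
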
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Callable, Optional, Union

import torch
@@ -33,33 +34,72 @@ from vllm.utils.flashinfer import has_flashinfer
logger = init_logger(__name__)


def _should_use_flashinfer_mxfp4_bf16():
    """Determine if FlashInfer MXFP4 BF16 should be used."""
    # If explicitly set, respect the setting
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
    NONE = 0

    # Enable by default on SM100 if MXFP8 is not explicitly enabled
    if (current_platform.is_device_capability(100) and has_flashinfer()
            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
        logger.info_once(
            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
            "For faster performance, consider setting "
            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
            "though this may impact accuracy.")
        return True
    # FlashInfer Backend
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    return False
    # Marlin Backend
    MARLIN = 5

    # Triton Backend
    TRITON = 6


def _should_use_flashinfer_mxfp4_mxfp8():
    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
def get_mxfp4_backend():
    # Backend Selection
    if current_platform.is_cuda():
        if (current_platform.is_device_capability(90) and has_flashinfer()
                and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16
        elif (current_platform.is_device_capability(100) and has_flashinfer()
              and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS):
            logger.info_once(
                "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
        elif (current_platform.is_device_capability(100) and has_flashinfer()
              and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
            logger.info_once(
                "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
                "for high concurrency throughput workloads consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
                "performance")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
        elif current_platform.is_device_capability(100) and has_flashinfer():
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy.")
            return Mxfp4Backend.SM100_FI_MXFP4_BF16
        elif ((current_platform.is_device_capability(100)
               or current_platform.is_device_capability(90))
              and not has_flashinfer()):
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results.")

        # If FlashInfer is not available, try either Marlin or Triton
        if current_platform.get_device_capability(
        )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
                "2.8.0"):
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        else:
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON
    elif current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

def should_use_flashinfer_mxfp4():
    return (_should_use_flashinfer_mxfp4_mxfp8()
            or _should_use_flashinfer_mxfp4_bf16())
    return Mxfp4Backend.NONE


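As a reading aid (not code from the diff), the selection above reduces to the following precedence on CUDA machines with FlashInfer installed:

# Condensed view of get_mxfp4_backend() for CUDA + FlashInfer setups.
SM100_PRECEDENCE = [
    ("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1", "SM100_FI_MXFP4_MXFP8_CUTLASS"),
    ("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1", "SM100_FI_MXFP4_MXFP8_TRTLLM"),
    ("(no MXFP4 env vars set)", "SM100_FI_MXFP4_BF16"),
]
SM90_CHOICE = ("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1", "SM90_FI_MXFP4_BF16")
# Without FlashInfer (or on pre-SM90 GPUs, missing triton_kernels, or
# torch < 2.8.0), get_mxfp4_backend() falls back to MARLIN or TRITON.
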
class Mxfp4Config(QuantizationConfig):
@@ -113,31 +153,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        super().__init__(moe)
        self.topk_indices_dtype = None
        self.moe = moe
        self.use_marlin = self._should_use_marlin()
        self.mxfp4_backend = get_mxfp4_backend()
        self.max_capture_size = get_current_vllm_config(
        ).compilation_config.max_capture_size

        if current_platform.is_device_capability(100) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results.")
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
            "Please check your environment and try again.")
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

    def _should_use_marlin(self):
        if envs.VLLM_MXFP4_USE_MARLIN is not None:
            return envs.VLLM_MXFP4_USE_MARLIN
        if current_platform.is_cuda() and \
                not current_platform.is_device_capability(100):
            if not current_platform.has_device_capability(90):
                # marlin kernel has better performance on ampere
                return True
            if not has_triton_kernels():
                return True
            if not is_torch_equal_or_newer("2.8.0"):
                return True
        return False

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
@@ -157,7 +181,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

        intermediate_size_per_partition_after_pad = \
            intermediate_size_per_partition
        if self.use_marlin:
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            # The moe marlin kernel requires that for each linear
            # n % 256 == 0 and k % 128 == 0.
            # In gate_up_proj:
@@ -175,16 +199,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = \
                intermediate_size_per_partition_after_pad
        elif should_use_flashinfer_mxfp4():
        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
              or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
            # pad the intermediate size to be a multiple of 2 * mxfp4_block
            # for to hold non-uniform sharded tensor as well as swizzling
            # other padding to increase performance
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif current_platform.is_rocm():
        elif current_platform.is_rocm() or (
                self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
                or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        else:
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 64)
@@ -264,9 +292,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_bias, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        if self.use_marlin:
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            prepare_moe_fp4_layer_for_marlin(layer)
        elif should_use_flashinfer_mxfp4():
        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
              or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
            from flashinfer.fp4_quantization import (
                nvfp4_block_scale_interleave)
            from flashinfer.fused_moe.core import (
@@ -429,7 +458,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                self.num_experts, -1),
                                      requires_grad=False)
        else:
        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
            layer.gemm1_alpha = Parameter(torch.tensor(
                [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                          requires_grad=False)
            layer.gemm1_beta = Parameter(torch.tensor(
                [1.0] * self.num_experts, dtype=torch.float32).cuda(),
                                         requires_grad=False)
            layer.gemm1_clamp_limit = Parameter(torch.tensor(
                [7.0] * self.num_experts, dtype=torch.float32).cuda(),
                                                requires_grad=False)

            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
            assert (layer.w13_weight.dim() == 3
                    and layer.w13_weight.shape[0] == self.num_experts
                    and layer.w13_weight.shape[1] == self.intermediate_size * 2
                    and layer.w13_weight.shape[2] == self.hidden_size // 2)
            assert (layer.w13_weight_scale.dim() == 3
                    and layer.w13_weight_scale.shape[0] == self.num_experts
                    and layer.w13_weight_scale.shape[1]
                    == self.intermediate_size * 2
                    and layer.w13_weight_scale.shape[2]
                    == self.hidden_size // sf_block_size)
            assert (layer.w2_weight.dim() == 3
                    and layer.w2_weight.shape[0] == self.num_experts
                    and layer.w2_weight.shape[1] == self.hidden_size and
                    layer.w2_weight.shape[2] == self.intermediate_size // 2)
            assert (layer.w2_weight_scale.dim() == 3
                    and layer.w2_weight_scale.shape[1] == self.hidden_size
                    and layer.w2_weight_scale.shape[2]
                    == self.intermediate_size // sf_block_size)
            assert (layer.w13_bias.dim() == 2
                    and layer.w13_bias.shape[0] == self.num_experts
                    and layer.w13_bias.shape[1] == self.intermediate_size * 2)
            assert (layer.w2_bias.dim() == 2
                    and layer.w2_bias.shape[0] == self.num_experts
                    and layer.w2_bias.shape[1] == self.hidden_size)

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
                    w13_scale_swapped.view(torch.uint8)).reshape(orig_shape)

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
                    w2_s.view(torch.uint8)).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped,
                                             requires_grad=False)
                layer.w13_weight_scale = Parameter(w13_scale_interleaved,
                                                   requires_grad=False)
                layer.w13_bias = Parameter(w13_bias_swapped,
                                           requires_grad=False)
                layer.w2_weight_scale = Parameter(w2_scale_interleaved,
                                                  requires_grad=False)
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
                    w_interleaved = w.reshape(w_shape[0], w_shape[1],
                                              (w_shape[2] // 4), 4)
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4)
                    return w_interleaved

                w31_scales = w13_scale_swapped.to(torch.uint8).view(
                    torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
                    w31_scales)

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
                    w2_scales)

                layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w],
                                                                dim=1),
                                                      requires_grad=False)
                layer.w13_bias = torch.nn.Parameter(w13_bias_swapped,
                                                    requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w31_scales_interleaved, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_scales_interleaved, requires_grad=False)
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
@@ -464,6 +602,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            layer.w13_weight = None
            layer.w2_weight = None
            torch.cuda.empty_cache()
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
        # Number of tokens in the input tensor.
@@ -500,7 +640,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            raise NotImplementedError(
                "Mxfp4 does not support batched experts format for EP")
        else:
            if should_use_flashinfer_mxfp4():
            if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
                    or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                # B200 code-path
                kwargs = {
                    "gemm1_alpha": layer.gemm1_alpha,
@@ -601,7 +742,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.use_marlin:
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
@@ -665,16 +806,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                logical_replica_count), (
                    "MXFP4 are not supported with this configuration.")

        if should_use_flashinfer_mxfp4():
            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
            if _should_use_flashinfer_mxfp4_bf16():
        if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
                or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
            from flashinfer import trtllm_fp4_block_scale_moe
            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
            elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
                from flashinfer import mxfp8_quantize
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                    *x.shape[:-1], -1)

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
@@ -706,7 +850,86 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                tune_max_num_tokens=self.max_capture_size,
            )[0]
            return trtllm_gen_output
        else:
        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
            )

            # Backend-specific preparation
            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:

                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, True, 32)

                fake_input_scale = torch.ones(self.num_experts,
                                              device=x.device)
                quant_scales = [
                    layer.w13_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                    layer.w2_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                ]

                fi_input = x_quant
                extra_kwargs = dict(
                    use_mxfp8_act_scaling=True,
                    input_sf=x_scale,
                    fc1_expert_weights=layer.w13_weight.contiguous().view(
                        torch.long),
                    fc2_expert_weights=layer.w2_weight.contiguous().view(
                        torch.long),
                )
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
                assert x.dtype == torch.bfloat16

                quant_scales = [
                    layer.w13_weight_scale,
                    layer.w2_weight_scale,
                ]

                fi_input = x
                extra_kwargs = dict(
                    use_w4_group_scaling=True,
                    fc1_expert_weights=layer.w13_weight,
                    fc2_expert_weights=layer.w2_weight,
                )

            output = torch.empty_like(x, dtype=torch.bfloat16)
            _ = flashinfer_cutlass_fused_moe(
                input=fi_input,
                token_selected_experts=topk_ids.to(torch.int).contiguous(),
                token_final_scales=topk_weights,
                output_dtype=torch.bfloat16,
                output=output,
                quant_scales=quant_scales,
                fc1_expert_biases=layer.w13_bias,
                fc2_expert_biases=layer.w2_bias,
                swiglu_alpha=layer.gemm1_alpha,
                swiglu_beta=layer.gemm1_beta,
                swiglu_limit=layer.gemm1_clamp_limit,
                tp_size=self.moe.tp_size,
                tp_rank=self.moe.tp_rank,
                ep_size=self.moe.ep_size,
                ep_rank=self.moe.ep_rank,
                tune_max_num_tokens=self.max_capture_size,
                **extra_kwargs,
            )

            return output
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward)
            return triton_kernel_moe_forward(
@@ -724,3 +947,5 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                w2_precision=self.w2_precision_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@@ -33,8 +33,8 @@ def kernel_warmup(worker: "Worker"):
        max_tokens = worker.scheduler_config.max_num_batched_tokens
        deep_gemm_warmup(model, max_tokens)

    # FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
    if has_flashinfer() and current_platform.is_device_capability(100):
    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
    if has_flashinfer() and current_platform.has_device_capability(90):
        flashinfer_autotune(worker.model_runner)

    # FlashInfer attention warmup