# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib.metadata
from dataclasses import dataclass
from importlib.util import find_spec

import pytest
import torch
from packaging import version

from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")

TRTLLM_GEN_MXFP4_AVAILABLE = (
    current_platform.is_cuda() and current_platform.is_device_capability(100)
)

HOPPER_MXFP4_BF16_AVAILABLE = (
    current_platform.is_cuda()
    and current_platform.is_device_capability(90)
    and has_flashinfer()
)

if TRTLLM_GEN_MXFP4_AVAILABLE:
    from flashinfer import (
        fp4_quantize,
        mxfp8_quantize,
        next_positive_power_of_2,
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        trtllm_fp4_block_scale_moe,
    )
    from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache


@dataclass
class ModelCase:
    model_id: str
    tp: int


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function."""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize(
    "model_case",
    [
        ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
        ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
        ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
    ],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    if torch.cuda.device_count() < model_case.tp:
        pytest.skip(
            f"This test requires >={model_case.tp} gpus, got only "
            f"{torch.cuda.device_count()}"
        )

    # `cuda_graph_sizes=[16]` to reduce load time.
    with vllm_runner(
        model_case.model_id,
        tensor_parallel_size=model_case.tp,
        load_format="dummy",
        cuda_graph_sizes=[16],
    ) as llm:
        # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
        # def check_model(model):
        #     from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
        #         QuarkLinearMethod)
        #     from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX  # noqa: E501
        #     from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
        #         QuarkOCP_MX_MoEMethod)

        #     layer = model.model.layers[0]

        #     qkv_proj = layer.self_attn.qkv_proj

        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        #     assert isinstance(qkv_proj.scheme, QuarkOCP_MX)

        #     assert isinstance(layer.mlp.experts.quant_method,
        #                       QuarkOCP_MX_MoEMethod)

        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
        #     llm.apply_model(check_model)

        output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
        assert output


def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    return out_glu * (x_linear + beta)

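# The 16 entries below are the values representable in FP4 E2M1
# (1 sign bit, 2 exponent bits, 1 mantissa bit), indexed by the 4-bit code.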
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]


def mxfp4_dequantize(x, scale):
    assert x.dtype == torch.uint8
    x = x.view(torch.uint8).to(torch.int32)
    x_unpacked = torch.zeros(
        *x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
    )
    x_unpacked[..., 0::2].copy_(x & 0xF)
    x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)

    x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
    for i, val in enumerate(fp4_lookup_table):
        x_float[x_unpacked == i] = val

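    # The scale bytes are E8M0 exponents: shifting them into the float32 exponent
    # field (bit 23) reconstructs the power-of-two scale 2**(byte - 127), which is
    # then broadcast over each 32-element block.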
    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale


def mxfp8_dequantize(x, scale):
    assert x.dtype == torch.float8_e4m3fn
    x_float = x.to(torch.float32)

    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale


def reference_moe(
    routing_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    # renormalize routing
    experts = torch.topk(routing_logits, k=topk, dim=-1, sorted=True)
    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
    expert_indices = experts.indices
    t = hidden_states.clone()
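    # einsum index convention below: b = token, e = selected expert,
    # c = output feature, k = input feature.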
    # MLP #1
    mlp1_weight = w13[expert_indices, ...]
    mlp1_bias = bias13[expert_indices, ...]
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
    t = swiglu(t, alpha=alpha, beta=beta, limit=limit)

    if act_type == "mxfp8":
        t_quantized, t_scale = mxfp8_quantize(
            t.to(torch.bfloat16), is_sf_swizzled_layout=False
        )
        t = mxfp8_dequantize(t_quantized, t_scale)
    # MLP #2
    mlp2_weight = w2[expert_indices, ...]
    mlp2_bias = bias2[expert_indices, ...]
    t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)
    assert t.shape == hidden_states.shape
    return t.to(torch.bfloat16)


def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
    # Number of tokens in the input tensor.
    num_tokens = x.shape[0]
    # Factor to account for the imbalance of the experts.
    # The factor equals
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
    # - 1.0 means perfect expert distribution.
    # - > 1.0 means some experts have more
    #   tokens than the perfect distribution.
    # - < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Calculate the number of tokens per expert
    # assuming perfect distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile
    # as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim


def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
    transpose_optimized: bool = False,
) -> torch.Tensor:
    sf_block_size = 32
    assert (
        w13_weight.dim() == 3
        and w13_weight.shape[0] == num_experts
        and w13_weight.shape[1] == intermediate_size * 2
        and w13_weight.shape[2] == hidden_size // 2
    )
    assert (
        w13_weight_scale.dim() == 3
        and w13_weight_scale.shape[0] == num_experts
        and w13_weight_scale.shape[1] == intermediate_size * 2
        and w13_weight_scale.shape[2] == hidden_size // sf_block_size
    )
    assert (
        w2_weight.dim() == 3
        and w2_weight.shape[0] == num_experts
        and w2_weight.shape[1] == hidden_size
        and w2_weight.shape[2] == intermediate_size // 2
    )
    assert (
        w2_weight_scale.dim() == 3
        and w2_weight_scale.shape[1] == hidden_size
        and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
    )
    assert (
        w13_bias.dim() == 2
        and w13_bias.shape[0] == num_experts
        and w13_bias.shape[1] == intermediate_size * 2
    )
    assert (
        w2_bias.dim() == 2
        and w2_bias.shape[0] == num_experts
        and w2_bias.shape[1] == hidden_size
    )

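    # w13 packs the two projections as two halves along dim 1; the in-place
    # copies below exchange those halves (and the matching scales and bias).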
    # Swap w1 and w3, as the definition of
    # swiglu differs in the trtllm-gen kernel.
    w13_weight_scale_ = w13_weight_scale.clone()
    w13_weight_ = w13_weight.clone()
    w13_bias_ = w13_bias.clone()
    w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
    w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
    w13_weight_scale[:, :intermediate_size, :].copy_(
        w13_weight_scale_[:, intermediate_size:, :]
    )
    w13_weight_scale[:, intermediate_size:, :].copy_(
        w13_weight_scale_[:, :intermediate_size, :]
    )
    w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
    w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])

    # Interleave the weights and scaling factors for the gated activation
    w13_weight_interleaved = []
    w13_weight_scale_interleaved = []
    w13_bias_interleaved = []
    for i in range(num_experts):
        w13_weight_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
        )
        w13_weight_scale_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
        )
        w13_bias_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
        )
    w13_weight = torch.stack(w13_weight_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2
    )
    w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 32
    )
    w13_bias = torch.stack(w13_bias_interleaved).reshape(
        num_experts, 2 * intermediate_size
    )

    # Shuffle weights and scaling factors for transposed mma output
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    gemm1_bias_shuffled = []
    gemm2_bias_shuffled = []
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
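    # Two preprocessing paths follow; they are expected to produce equivalent
    # kernel inputs: an index-cache based shuffle (transpose_optimized=True) and
    # the shuffle_matrix_a / shuffle_matrix_sf_a helpers.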
    if transpose_optimized:
        for i in range(num_experts):
            # w13 weight shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale shuffling
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias shuffling
            permute_bias_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )
            # w2 weight shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale shuffling
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )

    else:
        for i in range(num_experts):
            gemm1_weights_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm1_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )

            gemm2_weights_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm2_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
            )
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
            )

    w13_weight = torch.stack(gemm1_weights_shuffled)
    w13_weight_scale = (
        torch.stack(gemm1_scales_shuffled)
        .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = (
        torch.stack(gemm2_scales_shuffled)
        .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True,
    )[0]
    return tg_result


def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent."""
    if torch.any(torch.isnan(a)):
        raise Exception("NaN in reference output")
    if torch.any(torch.isnan(b)):
        raise Exception("NaN in actual output")
    if torch.any(torch.isinf(a)):
        raise Exception("Inf in reference output")
    if torch.any(torch.isinf(b)):
        raise Exception("Inf in actual output")
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

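    # An element counts as a mismatch when |a - b| > atol + rtol * |b|; the check
    # fails only if more than (1 - percent) of all elements mismatch.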
    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if mismatch_percent > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1 - percent:.4f})"
        )


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="an NVIDIA GPU with compute capability SM100 is required for this test",
)
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: float | None,
    act_type: str,
    transpose_optimized: bool,
):
    seed = 42
    torch.manual_seed(seed)
    hidden_states = torch.randn(
        num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
    )
    w13 = torch.randn(
        num_experts,
        intermediate_size * 2,
        hidden_size,
        device="cuda:0",
        dtype=torch.bfloat16,
    )
    w2 = torch.randn(
        num_experts,
        hidden_size,
        intermediate_size,
        device="cuda:0",
        dtype=torch.bfloat16,
    )
    bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
    bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
    router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()

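    # Quantize the weights to MXFP4 with per-32-element UE8M0 scales; the
    # non-swizzled scale layout lets the scales be reshaped to
    # [num_experts, rows, cols // 32] directly.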
    w13, w13_scale = fp4_quantize(
        w13,
        torch.tensor(1.0, device="cuda:0"),
        32,
        sf_use_ue8m0=True,
        is_sf_swizzled_layout=False,
    )
    w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, intermediate_size * 2, hidden_size // 32
    )
    w2, w2_scale = fp4_quantize(
        w2,
        torch.tensor(1.0, device="cuda:0"),
        32,
        sf_use_ue8m0=True,
        is_sf_swizzled_layout=False,
    )
    w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 32
    )
    if act_type == "mxfp8":
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False
        )
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
    else:
        hidden_states_scale = None

    # reference result
    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
    bias13_ref = bias13
    bias2_ref = bias2
    if act_type == "mxfp8":
        hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
            torch.float32
        )
    else:
        hidden_states_ref = hidden_states.to(torch.float32)
    # Process tokens in chunks of 32 to reduce memory usage
    chunk_size = 32
    num_chunks = (num_tokens + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, num_tokens)
        chunk_result = reference_moe(
            router_logits[start_idx:end_idx].to(torch.float32),
            topk,
            num_experts,
            hidden_states_ref[start_idx:end_idx],
            w13_ref,
            bias13_ref,
            w2_ref,
            bias2_ref,
            alpha,
            beta,
            limit,
            act_type,
        )
        ref_result[start_idx:end_idx].copy_(chunk_result)

    # trtllm-gen result
    if alpha is not None:
        alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts,), limit, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts,), beta, device=hidden_states.device)
    tg_result = tg_mxfp4_moe(
        router_logits,
        topk,
        num_experts,
        intermediate_size,
        hidden_size,
        hidden_states,
        hidden_states_scale,
        w13,
        w13_scale,
        bias13,
        w2,
        w2_scale,
        bias2,
        act_type,
        alpha=alpha,
        beta=beta,
        limit=limit,
        transpose_optimized=transpose_optimized,
    )
    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)


def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales on the last dimension by groups of 4, matching
    the transformation in mxfp4.py's BF16 (Hopper) path."""
    s = scales.to(torch.uint8)
    s_shape = s.shape
    assert s_shape[-1] % 4 == 0
    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
    # Move the 4-group dimension before the row dimension
    permuted = s.permute(0, 2, 1, 3)
    # Merge the row dim with the 4-group dim
    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not HOPPER_MXFP4_BF16_AVAILABLE,
    reason="an NVIDIA GPU with SM90 and FlashInfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: float | None,
):
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )
    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
    w13_q = torch.randint(
        0,
        256,
        (num_experts, 2 * intermediate_size, hidden_size // 2),
        device=device,
        dtype=torch.uint8,
    )
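    # The random scale bytes are interpreted as UE8M0 exponents, so values in
    # 118..122 correspond to block scales of 2**-9 .. 2**-5.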
    w13_scale = torch.randint(
        118,
        123,
        (num_experts, 2 * intermediate_size, hidden_size // 32),
        device=device,
        dtype=torch.uint8,
    )

    w2_q = torch.randint(
        0,
        256,
        (num_experts, hidden_size, intermediate_size // 2),
        device=device,
        dtype=torch.uint8,
    )
    w2_scale = torch.randint(
        118,
        123,
        (num_experts, hidden_size, intermediate_size // 32),
        device=device,
        dtype=torch.uint8,
    )
    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
        num_experts, 2 * intermediate_size, hidden_size
    )
    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
        num_experts, hidden_size, intermediate_size
    )
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "bf16",
    )

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
    w13_s = torch.cat([w3_s, w1_s], dim=1)
    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

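    # Softmax over all experts followed by top-k and renormalization over the
    # selected set is equivalent to the reference's top-k-then-softmax routing.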
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    if alpha is not None:
        alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts,), beta, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts,), limit, device=hidden_states.device)

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped,
        fc2_expert_weights=w2_q,
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha,
        swiglu_beta=beta,
        swiglu_limit=limit,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_w4_group_scaling=True,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and has_flashinfer()
    ),
    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float | None,
    beta: float | None,
    limit: float | None,
):
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )
    # Float weights in w13 format [w1; w3]
    w13 = (
        torch.randn(
            num_experts,
            2 * intermediate_size,
            hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    w2 = (
        torch.randn(
            num_experts,
            hidden_size,
            intermediate_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    # Quantize weights to MXFP4 per expert (SM100 path)
    from flashinfer import mxfp4_quantize

    def quant_mxfp4_batches(a: torch.Tensor, e: int):
        qs, sfs = [], []
        for i in range(e):
            q, sf = mxfp4_quantize(a[i].cuda())
            qs.append(q)
            sfs.append(sf)
        return torch.stack(qs), torch.stack(sfs)

    def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
        num_batches = mat_fp4.size(0)
        scale_tensor = scale_tensor.view(num_batches, -1)
        from flashinfer import mxfp4_dequantize

        return torch.stack(
            [
                mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
                for b in range(num_batches)
            ]
        )

    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)

    # Reference result using dequantized tensors and reference_moe
    w13_ref = (
        dequant_mxfp4_batches(
            w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, 2 * intermediate_size, hidden_size)
        .to(device)
    )
    w2_ref = (
        dequant_mxfp4_batches(
            w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, hidden_size, intermediate_size)
        .to(device)
    )

    # Quantize activations for SM100 path and dequantize for reference
    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "mxfp8",
    )

    # Prepare inputs for FlashInfer CUTLASS fused MoE
    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Swap scale halves to match the swapped weights
    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
    w13_scale_swapped = torch.cat([s3, s1], dim=1)

    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Build routing for kernel
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    if alpha is not None:
        alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
    else:
        alpha_t = None
    if beta is not None:
        beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
    else:
        beta_t = None
    if limit is not None:
        limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
    else:
        limit_t = None

    # Quant scales for SM100 MXFP8+MXFP4 path
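    # The ones tensors below stand in for per-expert input scales
    # (hence the name fake_input_scale); the real block scales are the
    # UE8M0 tensors reinterpreted as int32.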
    fake_input_scale = torch.ones(num_experts, device=device)
    quant_scales = [
        w13_scale_swapped.view(torch.int32),
        fake_input_scale,
        w2_scale.view(torch.int32),
        fake_input_scale,
    ]

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states_q,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
        fc2_expert_weights=w2_q.contiguous().view(torch.long),
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=quant_scales,
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha_t,
        swiglu_beta=beta_t,
        swiglu_limit=limit_t,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_mxfp8_act_scaling=True,
        input_sf=hidden_states_sf,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)