[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
@ -630,6 +630,7 @@ steps:
|
||||
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/fusion.py
|
||||
- vllm/compilation/fusion_attn.py
|
||||
@ -650,6 +651,7 @@ steps:
|
||||
# Fusion
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
248
tests/kernels/moe/test_flashinfer.py
Normal file
248
tests/kernels/moe/test_flashinfer.py
Normal file
@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
input_to_float8)
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe(
|
||||
) or not current_platform.has_device_capability(100):
|
||||
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True)
|
||||
|
||||
NUM_EXPERTS = [16]
|
||||
TOP_KS = [1]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(256, 8192, 5120),
|
||||
(256, 4096, 5120),
|
||||
(127, 8192, 5120),
|
||||
(127, 4096, 5120),
|
||||
(10, 8192, 5120),
|
||||
(10, 4096, 5120),
|
||||
(1, 8192, 5120),
|
||||
(1, 4096, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def quant_fp8_per_tensor_batches(a):
|
||||
num_batches = a.size(0)
|
||||
a_quant = []
|
||||
a_scales = []
|
||||
|
||||
for i in range(num_batches):
|
||||
a_fp8, a_global_sf = input_to_float8(a[i])
|
||||
a_global_sf = 1.0 / a_global_sf
|
||||
a_quant.append(a_fp8)
|
||||
a_scales.append(a_global_sf)
|
||||
|
||||
result_a_quant = torch.stack(a_quant)
|
||||
result_a_scales = torch.stack(a_scales)
|
||||
|
||||
return result_a_quant, result_a_scales
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestData:
|
||||
hidden_states: torch.Tensor
|
||||
w13_quantized: torch.Tensor
|
||||
w2_quantized: torch.Tensor
|
||||
a1_scale: torch.Tensor
|
||||
a2_scale: torch.Tensor
|
||||
w13_weight_scale: torch.Tensor
|
||||
w2_weight_scale: torch.Tensor
|
||||
layer: torch.nn.Module
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
reorder: bool) -> "TestData":
|
||||
hidden_states = torch.randn(
|
||||
(m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
_, a1_scale = input_to_float8(hidden_states)
|
||||
a1_scale = 1.0 / a1_scale
|
||||
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
|
||||
dtype=torch.float32)
|
||||
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
|
||||
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
|
||||
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = w13_quantized.clone()
|
||||
layer.w2_weight = w2_quantized.clone()
|
||||
layer.w13_input_scale = a1_scale
|
||||
layer.w2_input_scale = a2_scale
|
||||
layer.w13_weight_scale = w13_weight_scale
|
||||
layer.w2_weight_scale = w2_weight_scale
|
||||
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
# flashinfer expects swapped rows for w13
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if reorder:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
layer.custom_routing_function = Llama4MoE.custom_routing_function
|
||||
layer.intermediate_size_per_partition = n
|
||||
layer.ep_rank = 0
|
||||
layer.local_num_experts = e
|
||||
|
||||
return TestData(
|
||||
hidden_states=hidden_states,
|
||||
w13_quantized=w13_quantized,
|
||||
w2_quantized=w2_quantized,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w13_weight_scale=w13_weight_scale,
|
||||
w2_weight_scale=w2_weight_scale,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
use_grouped_topk=False,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=td.layer,
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
routing_bias=None,
|
||||
global_num_experts=e,
|
||||
top_k=topk,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
apply_router_weight_on_input=True)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
|
||||
)
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
use_grouped_topk=False,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
td.layer.dp_size = 1
|
||||
|
||||
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
|
||||
td.hidden_states,
|
||||
td.layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation="silu",
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_cutlass_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
@ -61,8 +61,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
))
|
||||
assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
|
||||
"currently supported.")
|
||||
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
|
||||
"Only nvfp4,fp8 quantization are currently supported.")
|
||||
self.ep_rank = ep_rank
|
||||
self.ep_size = ep_size
|
||||
self.tp_rank = tp_rank
|
||||
@ -122,7 +122,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"""
|
||||
aq_m, aq_n = aq.shape
|
||||
workspace2 = ()
|
||||
output_shape = (aq_m, aq_n * 2)
|
||||
output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
|
||||
torch.float8_e4m3fn else (aq_m, aq_n)
|
||||
workspace_dtype = a.dtype
|
||||
workspace1 = output_shape
|
||||
# The workspace is determined by `aq`, since it comes after any
|
||||
@ -151,29 +152,39 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: Optional[bool],
|
||||
):
|
||||
# Flashinfer CUTLASS kernel takes scalar global scales,
|
||||
# min because inv_scale.
|
||||
if self.quant_dtype == torch.float8_e4m3fn:
|
||||
quant_scales = [
|
||||
self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
|
||||
]
|
||||
|
||||
# Ensure w1_scale and w2_scale are not None before calling view
|
||||
assert w1_scale is not None and w2_scale is not None, (
|
||||
"w1_scale and w2_scale must not "
|
||||
"be None for FlashInferExperts")
|
||||
a1q_scale = None # not passing input_sf in fp8
|
||||
fc1_expert_weights = w1
|
||||
fc2_expert_weights = w2
|
||||
else:
|
||||
# Ensure w1_scale and w2_scale are not None before calling view
|
||||
assert w1_scale is not None and w2_scale is not None, (
|
||||
"w1_scale and w2_scale must not "
|
||||
"be None for FlashInferExperts")
|
||||
# Flashinfer CUTLASS kernel takes scalar global scales,
|
||||
# min because inv_scale.
|
||||
quant_scales = [
|
||||
self.a1_gscale,
|
||||
w1_scale.view(torch.int32),
|
||||
self.g1_alphas,
|
||||
self.a2_gscale,
|
||||
w2_scale.view(torch.int32),
|
||||
self.g2_alphas,
|
||||
]
|
||||
# FlashInfer API requires weight to be long for nvfp4
|
||||
fc1_expert_weights = w1.view(torch.long)
|
||||
fc2_expert_weights = w2.view(torch.long)
|
||||
|
||||
quant_scales = [
|
||||
self.a1_gscale,
|
||||
w1_scale.view(torch.int32),
|
||||
self.g1_alphas,
|
||||
self.a2_gscale,
|
||||
w2_scale.view(torch.int32),
|
||||
self.g2_alphas,
|
||||
]
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
input=hidden_states,
|
||||
token_selected_experts=topk_ids.to(torch.int),
|
||||
token_final_scales=topk_weights,
|
||||
# FlashInfer API requires weight to be long for nvfp4
|
||||
fc1_expert_weights=w1.view(torch.long),
|
||||
fc2_expert_weights=w2.view(torch.long),
|
||||
fc1_expert_weights=fc1_expert_weights,
|
||||
fc2_expert_weights=fc2_expert_weights,
|
||||
output_dtype=self.out_dtype,
|
||||
quant_scales=quant_scales,
|
||||
input_sf=a1q_scale,
|
||||
|
@ -9,6 +9,7 @@ from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@ -23,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
|
||||
FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
@ -145,7 +149,7 @@ class Fp8Config(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self, layer.moe_config)
|
||||
return Fp8MoEMethod(self, layer)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
@ -482,16 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||
super().__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
self.flashinfer_moe_enabled = False
|
||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
"Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
|
||||
self.flashinfer_moe_enabled = True
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
@ -531,6 +539,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"CutlassBlockScaledGroupedGemm not supported on the current "
|
||||
"platform.")
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
layer=self.layer,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@ -678,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale_inv,
|
||||
layer.w2_input_scale)
|
||||
elif self.flashinfer_moe_enabled:
|
||||
elif self.flashinfer_moe_backend is not None:
|
||||
# NOTE: weights have to be swapped since the activation is
|
||||
# applied on different half for flashinfer vs vllm
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
@ -686,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale_inv.data)
|
||||
w2_weight = layer.w2_weight.data
|
||||
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
|
||||
if not self.block_quant:
|
||||
register_moe_scaling_factors(layer)
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
else:
|
||||
w13_weight = layer.w13_weight.data
|
||||
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
|
||||
@ -834,6 +853,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
# NOTE: weights have to be swapped since the activation is
|
||||
# applied on different half for flashinfer vs vllm
|
||||
assert not self.block_quant
|
||||
register_moe_scaling_factors(layer)
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if self.flashinfer_moe_backend == \
|
||||
FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
layer.w13_weight.data = w13_weight.data
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
@ -892,6 +922,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
per_act_token_quant=False,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
self.layer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
else:
|
||||
logger.debug(
|
||||
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
|
||||
@ -930,25 +967,66 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
if not self.flashinfer_moe_enabled:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
if self.block_quant:
|
||||
assert (renormalize and use_grouped_topk
|
||||
and custom_routing_function is None)
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
routed_scaling=1.0,
|
||||
)
|
||||
else:
|
||||
assert (not renormalize
|
||||
and custom_routing_function is not None)
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
@ -988,63 +1066,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
elif self.flashinfer_moe_enabled:
|
||||
assert activation == 'silu'
|
||||
assert scoring_func == 'sigmoid'
|
||||
if self.block_quant:
|
||||
assert (renormalize and use_grouped_topk
|
||||
and custom_routing_function is None)
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert self.block_quant is None
|
||||
assert (not renormalize and custom_routing_function is not None)
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
routed_scaling=1.0,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
assert (not renormalize
|
||||
and custom_routing_function is not None)
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
elif self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
|
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -27,8 +26,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
|
||||
select_nvfp4_gemm_impl)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
|
||||
FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
|
||||
@ -49,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
|
||||
class FlashinferMoeBackend(Enum):
|
||||
TENSORRT_LLM = "TensorRT-LLM"
|
||||
CUTLASS = "CUTLASS"
|
||||
|
||||
|
||||
class ModelOptFp8Config(QuantizationConfig):
|
||||
"""Config class for ModelOpt FP8."""
|
||||
|
||||
@ -179,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptFp8MoEMethod(self, layer.moe_config)
|
||||
return ModelOptFp8MoEMethod(self, layer)
|
||||
return None
|
||||
|
||||
|
||||
@ -278,18 +275,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ModelOptFp8Config,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> None:
|
||||
super().__init__(moe)
|
||||
super().__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported)
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.flashinfer_moe_enabled = False
|
||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
"Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
|
||||
self.flashinfer_moe_enabled = True
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.fused_experts is not None or \
|
||||
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
layer=self.layer,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
self.layer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -433,11 +461,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
|
||||
requires_grad=False)
|
||||
|
||||
if self.flashinfer_moe_enabled:
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
register_moe_scaling_factors(layer)
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -461,14 +490,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||
|
||||
if self.flashinfer_moe_enabled:
|
||||
assert activation == 'silu'
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert not renormalize
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
@ -495,6 +523,36 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert not renormalize
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts)
|
||||
return fused_experts(
|
||||
@ -951,20 +1009,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self.flashinfer_moe_backend = None
|
||||
|
||||
if self.allow_flashinfer:
|
||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_moe_backend == "throughput":
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
logger.info_once("Using FlashInfer CUTLASS kernels for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
elif flashinfer_moe_backend == "latency":
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
else:
|
||||
allowed_backends = ["throughput", "latency"]
|
||||
raise ValueError(
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
" for ModelOptNvFp4FusedMoE.")
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
|
@ -1,9 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashinferMoeBackend(Enum):
|
||||
TENSORRT_LLM = "TensorRT-LLM"
|
||||
CUTLASS = "CUTLASS"
|
||||
|
||||
|
||||
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||
|
||||
@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter(
|
||||
'output2_scales_scalar',
|
||||
torch.nn.Parameter(output2_scales, requires_grad=False))
|
||||
layer.register_parameter(
|
||||
'w2_input_scale_inv',
|
||||
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||
use_dp, a1_gscale=layer.w13_input_scale)
|
||||
|
||||
|
||||
def select_cutlass_fp8_gemm_impl(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
|
||||
if moe is not None:
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=moe.in_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
tp_size=moe.moe_parallel_config.tp_size,
|
||||
)
|
||||
|
||||
assert out_dtype is not None, (
|
||||
"If moe config is None, out_dtype must be passed")
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=out_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_cutlass_moe_fp8(
|
||||
hidden_states: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
|
||||
layer=layer),
|
||||
select_cutlass_fp8_gemm_impl(moe=None,
|
||||
layer=layer,
|
||||
out_dtype=hidden_states.dtype))
|
||||
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_moe_backend == "throughput":
|
||||
return FlashinferMoeBackend.CUTLASS
|
||||
elif flashinfer_moe_backend == "latency":
|
||||
return FlashinferMoeBackend.TENSORRT_LLM
|
||||
|
||||
allowed_backends = ["throughput", "latency"]
|
||||
raise ValueError(
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
|
Reference in New Issue
Block a user