[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (#24588)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath
Date: 2025-09-23 00:01:09 -04:00 (committed by GitHub)
Parent: fafbe11af4
Commit: e8db44f883
6 changed files with 275 additions and 76 deletions

@@ -288,7 +288,11 @@ class FusedMoEQuantConfig:
@property
def use_mxfp4_w4a4(self) -> bool:
return self.quant_dtype == "mxfp4"
return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
@property
def use_mxfp4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
@property
def use_nvfp4_w4a4(self) -> bool:
@@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
)
def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations and mxfp4 weights.
"""
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(),
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
)
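Note: the new use_mxfp4_w4a16 property is what separates this path (mxfp4 weights with unquantized activations) from the existing w4a4 case. A minimal sketch of the two predicates, using an illustrative stand-in for FusedMoEQuantDesc rather than the real vLLM class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Desc:  # stand-in for FusedMoEQuantDesc; only the dtype field matters here
    dtype: Optional[str] = None

def use_mxfp4_w4a16(a1: Desc, w1: Desc) -> bool:
    # mxfp4 weights, unquantized (e.g. bf16) activations: the Triton matmul_ogs path
    return a1.dtype is None and w1.dtype == "mxfp4"

def use_mxfp4_w4a4(a1: Desc, w1: Desc) -> bool:
    # mxfp4 weights and mxfp4 activations
    return a1.dtype == "mxfp4" and w1.dtype == "mxfp4"

# A config built by mxfp4_w4a16_moe_quant_config leaves the activation
# descriptors empty, so only the w4a16 predicate fires.
a1, w1 = Desc(None), Desc("mxfp4")
assert use_mxfp4_w4a16(a1, w1) and not use_mxfp4_w4a4(a1, w1)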
def mxfp4_w4a4_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -18,6 +19,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
Prepare/Finalize using DeepEP High-Throughput kernels.
"""
@staticmethod
def maybe_roundup_layer_hidden_size(hidden_size: int,
dtype: torch.dtype) -> int:
# Round up hidden size so it is compatible with DeepEP High Throughput
# kernels.
# DeepEP intranode kernels make copies in units of
# 32 (warp-size) int4 elements. Round up the hidden size to respect this.
# For example, an input hidden size of 2880 with dtype torch.bfloat16
# will be rounded up to 3072.
hidden_size_bytes = hidden_size * dtype.itemsize
xfer_atom_size = 512  # 32 (warp size) * 16 bytes (sizeof(int4))
if hidden_size_bytes % xfer_atom_size == 0:
return hidden_size
hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
return hidden_size_bytes // dtype.itemsize
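For a quick sanity check of the rounding rule, a self-contained sketch (local reimplementation, not imported from vLLM) that reproduces the 2880 -> 3072 bfloat16 example from the comment above:

import torch

def roundup_hidden_size(hidden_size: int, dtype: torch.dtype) -> int:
    xfer_atom_size = 512  # 32 (warp size) * 16 bytes (sizeof(int4))
    hidden_size_bytes = hidden_size * dtype.itemsize
    padded_bytes = ((hidden_size_bytes + xfer_atom_size - 1)
                    // xfer_atom_size) * xfer_atom_size
    return padded_bytes // dtype.itemsize

assert roundup_hidden_size(2880, torch.bfloat16) == 3072  # 5760 B -> 6144 B
assert roundup_hidden_size(4096, torch.bfloat16) == 4096  # already aligned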
def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
dp_size: int, rank_expert_offset: int):
super().__init__()

@@ -9,7 +9,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceNoOP)
from vllm.triton_utils import tl, triton
from vllm.utils import has_triton_kernels
logger = init_logger(__name__)
@@ -19,13 +20,55 @@ if has_triton_kernels():
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs)
from triton_kernels.routing import routing
from triton_kernels.routing import (RoutingData, routing,
routing_from_bitmatrix)
from triton_kernels.tensor import Bitmatrix
except (ModuleNotFoundError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible. Error: %s", e)
@triton.jit
def pack_bitmatrix(
bitmatrix,
topk_ids,
n_rows, # n_rows in bitmatrix / topk_ids
bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix
n_expts_act, # num_topk
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""
Packs topk_ids into a bitmatrix.
code reference:
https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
"""
pid_m = tl.program_id(0)
offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offsets_k = tl.arange(0, BLOCK_SIZE_K)
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
div = indices // 32
rem = indices % 32
one = tl.cast(1, tl.uint32)
# Iterate through all the relevant bitmatrix columns.
for i in range(bm_cols):
# When BLOCK_SIZE_K=32, offs is just the column index.
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
# Every top-k id that maps to this column gets its bit set; all other
# bits are 0. x is a 3D tensor.
x = tl.where(div[:, :, None] == offs[None, None, :],
(one << rem)[:, :, None], 0)
# Reduce x to get a single int32_t bitpack.
y = tl.reduce_or(x, axis=1)
bitmatrix_ptrs = bitmatrix + offsets_m[:,
None] * bm_cols + offs[None, :]
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
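To make the packed layout concrete, here is a hedged pure-PyTorch reference of the same packing on a tiny input (the helper name is illustrative, and int64 stands in for the kernel's uint32 bitpacks):

import torch

def pack_bitmatrix_reference(topk_ids: torch.Tensor,
                             num_local_experts: int) -> torch.Tensor:
    n_rows, _ = topk_ids.shape
    bm_cols = (num_local_experts + 31) // 32  # one bitpack per 32 experts
    bitmatrix = torch.zeros((n_rows, bm_cols), dtype=torch.int64)
    for row in range(n_rows):
        for expert in topk_ids[row].tolist():
            if expert < 0:  # -1 marks an invalid / masked slot
                continue
            bitmatrix[row, expert // 32] |= 1 << (expert % 32)
    return bitmatrix

topk_ids = torch.tensor([[3, 35], [0, -1]], dtype=torch.int16)
print(pack_bitmatrix_reference(topk_ids, num_local_experts=64))
# tensor([[8, 8],    row 0: bit 3 in pack 0, bit 35 % 32 = 3 in pack 1
#         [1, 0]])   row 1: bit 0 in pack 0; the -1 slot is skipped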
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
@@ -124,34 +167,88 @@ def triton_kernel_fused_experts(
return intermediate_cache3
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def make_routing_data(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
num_local_experts: int,
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
topk_ids = topk_ids.to(torch.int16)
topk_weights = topk_weights.to(torch.bfloat16)
n_rows, num_topk = topk_ids.size()
BLOCK_SIZE_M = 512
BLOCK_SIZE_K = 32
bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks
bitmatrix = torch.zeros((n_rows, bm_cols),
dtype=torch.uint32,
device=topk_ids.device)
grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
pack_bitmatrix[grid](
bitmatrix,
topk_ids,
n_rows,
bm_cols,
num_topk,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
bitmatrix_shape = [n_rows, bm_cols * 32]
bitmatrix_shape_max = [n_rows, None]
bitmatrix = Bitmatrix(bitmatrix,
shape=bitmatrix_shape,
shape_max=bitmatrix_shape_max,
scratchpad=None)
# matmul_ogs expects invalid topk_weights to be -1s
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
return routing_data, gather_indx, scatter_indx
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happen in the fused_experts kernel.
return TopKWeightAndReduceNoOP()
def _make_routing_data(
self,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
num_local_experts: int,
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
return make_routing_data(topk_ids, topk_weights, num_local_experts)
class OAITritonExperts(BaseOAITritonExperts):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return True
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@@ -159,13 +256,10 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# workspaces are allocated inside the kernel
assert a.dim() == 2
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
workspace2 = (0, 0, 0)
output = (num_experts, max_num_tokens * num_dp, N)
return (output, workspace2, output, a.dtype)
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
self,
@@ -185,17 +279,29 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
return triton_kernel_fused_experts(
output,
if expert_map is not None:
topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
routing_data, gather_indx, scatter_indx = self._make_routing_data(
topk_ids, topk_weights, local_num_experts)
experts_output = triton_kernel_fused_experts(
None,
hidden_states,
w1,
w2,
routing_data=None,
gather_indx=None,
scatter_indx=None,
routing_data,
gather_indx,
scatter_indx,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
global_num_experts=local_num_experts,
expert_map=None, # applied already
a1q_scale=a1q_scale)
output.copy_(experts_output, non_blocking=True)
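A small sketch of the expert-map handling used in apply() above: under expert parallelism, topk_ids carry global expert ids, expert_map translates them to local ids (-1 for experts owned by other ranks), and invalid slots get a weight of -1 as matmul_ogs expects. The values below are made up:

import torch

expert_map = torch.tensor([-1, -1, 0, 1])      # this rank owns global experts 2 and 3
topk_ids = torch.tensor([[2, 0], [3, 2]])      # global expert ids from top-k routing
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])

local_topk_ids = expert_map[topk_ids]          # -> [[0, -1], [1, 0]]
# matmul_ogs expects invalid slots to carry a weight of -1
local_weights = torch.where(local_topk_ids == -1, -1.0, topk_weights)
print(local_topk_ids)
print(local_weights)                           # -> [[0.7, -1.0], [0.6, 0.4]]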

@@ -800,6 +800,49 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
for local_index, global_index in zip(local_indices, global_indices))
def maybe_roundup_hidden_size(
hidden_size: int, act_dtype: torch.dtype,
quant_config: Optional[QuantizationConfig],
moe_parallel_config: FusedMoEParallelConfig) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
if necessary.
Args:
hidden_size(int): Layer hidden size.
act_dtype: Data type of the layer activations.
quant_config(QuantizationConfig): Layer quantization configuration.
moe_parallel_config(FusedMoEParallelConfig): Fused MoE parallelization
strategy configuration.
Returns:
Rounded up hidden_size if rounding up is required based on the configs.
Original hidden size otherwise.
"""
if (moe_parallel_config.use_deepep_ht_kernels):
hidden_size = (
DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size, act_dtype))
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
return hidden_size
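As a quick numeric check of the alignment constants above, using a hidden size of 2880 (the GPT-OSS value cited in the DeepEP comment) and a local round_up reimplementation:

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

print(round_up(2880, 128))  # 2944: SM90_FI_MXFP4_BF16 / SM100 CUTLASS backends
print(round_up(2880, 256))  # 3072: ROCm / SM100 TRTLLM / SM100 BF16 backends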
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
@@ -856,6 +899,18 @@ class FusedMoE(CustomOp):
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
vllm_config = get_current_vllm_config()
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
# operation is typically unquantized (i.e. float16/bfloat16).
if vllm_config.model_config is not None:
moe_in_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
@@ -865,7 +920,6 @@ class FusedMoE(CustomOp):
if self.is_sequence_parallel:
self.sp_size = tp_size_
vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=tp_size_,
@@ -874,19 +928,10 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
quant_config,
self.moe_parallel_config)
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
@@ -967,20 +1012,13 @@ class FusedMoE(CustomOp):
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)

@@ -76,7 +76,7 @@ def _moe_problem_size(
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = w2.size(1)
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).

@@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
@@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# FIXME: num_warps needs to be adjusted based on the batch size
# only apply to batched mode
if self.moe.use_ep:
# Ideally we'd use the FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = (self.moe.use_pplx_kernels
or self.moe.use_deepep_ll_kernels)
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
@@ -640,16 +648,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
return mxfp4_w4a4_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
return mxfp4_w4a4_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
def select_gemm_impl(
self,
@@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP")
else:
assert self.moe_quant_config is not None
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
# B200 code-path
@@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
assert self.moe_quant_config is not None
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
else:
# Use matmul_ogs from triton_kernels here!
raise NotImplementedError(
"Mxfp4 does not support non-batched experts format for EP")
return OAITritonExperts(self.moe_quant_config)
def _route_and_experts(
self,
@@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
w13_weight = (self.w13_weight_triton_tensor
if layer.w13_weight is None else layer.w13_weight)
w2_weight = (self.w2_weight_triton_tensor
if layer.w2_weight is None else layer.w2_weight)
assert all([w is not None for w in [w13_weight, w2_weight]])
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1=w13_weight,
w2=w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,