mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (#24588)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
fafbe11af4
commit
e8db44f883
@ -288,7 +288,11 @@ class FusedMoEQuantConfig:
|
||||
|
||||
@property
|
||||
def use_mxfp4_w4a4(self) -> bool:
|
||||
return self.quant_dtype == "mxfp4"
|
||||
return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
|
||||
|
||||
@property
|
||||
def use_mxfp4_w4a16(self) -> bool:
|
||||
return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
|
||||
|
||||
@property
|
||||
def use_nvfp4_w4a4(self) -> bool:
|
||||
@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations and mxfp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(),
|
||||
_a2=FusedMoEQuantDesc(),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a4_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
|
@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
@ -18,6 +19,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def maybe_roundup_layer_hidden_size(hidden_size: int,
|
||||
dtype: torch.dtype) -> int:
|
||||
# Round up hidden size so it is compatible with DeepEP High Throughput
|
||||
# kernels.
|
||||
# DeepEP intranode kernels make copies in units of,
|
||||
# 32(warp-size) int4 elements. Round up hidden size to respect this.
|
||||
# For example, an input hidden size of 2880 with dtype torch.bfloat16
|
||||
# will be rounded up to 3072.
|
||||
hidden_size_bytes = hidden_size * dtype.itemsize
|
||||
xfer_atom_size = 512 # 32 * 16 (size(int4))
|
||||
if hidden_size_bytes % xfer_atom_size == 0:
|
||||
return hidden_size
|
||||
|
||||
hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
|
||||
return hidden_size_bytes // dtype.itemsize
|
||||
|
||||
def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
|
||||
dp_size: int, rank_expert_offset: int):
|
||||
super().__init__()
|
||||
|
@ -9,7 +9,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import has_triton_kernels
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -19,13 +20,55 @@ if has_triton_kernels():
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
|
||||
matmul_ogs)
|
||||
from triton_kernels.routing import routing
|
||||
from triton_kernels.routing import (RoutingData, routing,
|
||||
routing_from_bitmatrix)
|
||||
from triton_kernels.tensor import Bitmatrix
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
"version is compatible. Error: %s", e)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def pack_bitmatrix(
|
||||
bitmatrix,
|
||||
topk_ids,
|
||||
n_rows, # n_rows in bitmatrix / topk_ids
|
||||
bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix
|
||||
n_expts_act, # num_topk
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Packs topk_ids into a bitmatrix.
|
||||
code reference:
|
||||
https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
|
||||
"""
|
||||
pid_m = tl.program_id(0)
|
||||
offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offsets_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
|
||||
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
|
||||
indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
|
||||
div = indices // 32
|
||||
rem = indices % 32
|
||||
one = tl.cast(1, tl.uint32)
|
||||
|
||||
# Iterate through all the relevant bitmatrix columns.
|
||||
for i in range(bm_cols):
|
||||
# When BLOCK_SIZE_K=32, offs is just the column index.
|
||||
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
|
||||
# All topks that need to go into this column has the correct bit set.
|
||||
# Other bits are 0. x is a 2D tensor.
|
||||
x = tl.where(div[:, :, None] == offs[None, None, :],
|
||||
(one << rem)[:, :, None], 0)
|
||||
# Reduce x to get a single int32_t bitpack.
|
||||
y = tl.reduce_or(x, axis=1)
|
||||
bitmatrix_ptrs = bitmatrix + offsets_m[:,
|
||||
None] * bm_cols + offs[None, :]
|
||||
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
@ -124,34 +167,88 @@ def triton_kernel_fused_experts(
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def make_routing_data(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_local_experts: int,
|
||||
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
topk_ids = topk_ids.to(torch.int16)
|
||||
topk_weights = topk_weights.to(torch.bfloat16)
|
||||
|
||||
n_rows, num_topk = topk_ids.size()
|
||||
|
||||
BLOCK_SIZE_M = 512
|
||||
BLOCK_SIZE_K = 32
|
||||
|
||||
bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks
|
||||
bitmatrix = torch.zeros((n_rows, bm_cols),
|
||||
dtype=torch.uint32,
|
||||
device=topk_ids.device)
|
||||
|
||||
grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
|
||||
pack_bitmatrix[grid](
|
||||
bitmatrix,
|
||||
topk_ids,
|
||||
n_rows,
|
||||
bm_cols,
|
||||
num_topk,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
)
|
||||
|
||||
bitmatrix_shape = [n_rows, bm_cols * 32]
|
||||
bitmatrix_shape_max = [n_rows, None]
|
||||
bitmatrix = Bitmatrix(bitmatrix,
|
||||
shape=bitmatrix_shape,
|
||||
shape_max=bitmatrix_shape_max,
|
||||
scratchpad=None)
|
||||
|
||||
# matmul_ogs expects invalid topk_weights to be -1s
|
||||
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
|
||||
routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
|
||||
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
|
||||
|
||||
return routing_data, gather_indx, scatter_indx
|
||||
|
||||
|
||||
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
super().__init__(quant_config)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Weight application and reduction happens in the fused_experts kernel.
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def _make_routing_data(
|
||||
self,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_local_experts: int,
|
||||
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
|
||||
return make_routing_data(topk_ids, topk_weights, num_local_experts)
|
||||
|
||||
|
||||
class OAITritonExperts(BaseOAITritonExperts):
|
||||
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
# TODO (varun) : Enable activation quantization
|
||||
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
|
||||
super().__init__(quant_config)
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||
@ -159,13 +256,10 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# workspace are allocated inside the kernel
|
||||
assert a.dim() == 2
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = self.max_num_tokens
|
||||
workspace2 = (0, 0, 0)
|
||||
output = (num_experts, max_num_tokens * num_dp, N)
|
||||
return (output, workspace2, output, a.dtype)
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0, 0)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -185,17 +279,29 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
return triton_kernel_fused_experts(
|
||||
output,
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
routing_data, gather_indx, scatter_indx = self._make_routing_data(
|
||||
topk_ids, topk_weights, local_num_experts)
|
||||
|
||||
experts_output = triton_kernel_fused_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data=None,
|
||||
gather_indx=None,
|
||||
scatter_indx=None,
|
||||
routing_data,
|
||||
gather_indx,
|
||||
scatter_indx,
|
||||
activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
apply_router_weight_on_input=False,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
global_num_experts=local_num_experts,
|
||||
expert_map=None, # applied already
|
||||
a1q_scale=a1q_scale)
|
||||
|
||||
output.copy_(experts_output, non_blocking=True)
|
||||
|
@ -800,6 +800,49 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
||||
for local_index, global_index in zip(local_indices, global_indices))
|
||||
|
||||
|
||||
def maybe_roundup_hidden_size(
|
||||
hidden_size: int, act_dtype: torch.dtype,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
moe_parallel_config: FusedMoEParallelConfig) -> int:
|
||||
"""
|
||||
Given layer hidden size and MoE configurations, round up hidden_size
|
||||
if necessary.
|
||||
|
||||
Args:
|
||||
hidden_size(int): Layer hidden-size
|
||||
act_dtype: Data type of the layer activations.
|
||||
quant_config(FusedMoEQuantConfig): Fused MoE quantization configuration.
|
||||
moe_parallel_config(FusedMoEParallelConfig): Fused MoE parallelization
|
||||
strategy configuration.
|
||||
|
||||
Return:
|
||||
Rounded up hidden_size if rounding up is required based on the configs.
|
||||
Original hidden size otherwise.
|
||||
"""
|
||||
|
||||
if (moe_parallel_config.use_deepep_ht_kernels):
|
||||
hidden_size = (
|
||||
DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
|
||||
hidden_size, act_dtype))
|
||||
|
||||
# we are padding globally so EP buffer allocation works
|
||||
if quant_config and quant_config.get_name() == "mxfp4":
|
||||
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend, get_mxfp4_backend)
|
||||
current_mxfp4_backend = get_mxfp4_backend()
|
||||
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
or current_mxfp4_backend
|
||||
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif (current_platform.is_rocm() or current_mxfp4_backend
|
||||
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
return hidden_size
|
||||
|
||||
|
||||
@CustomOp.register("fused_moe")
|
||||
class FusedMoE(CustomOp):
|
||||
"""FusedMoE layer for MoE models.
|
||||
@ -856,6 +899,18 @@ class FusedMoE(CustomOp):
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
# FIXME (varun): We should have a better way of inferring the activation
|
||||
# datatype. This works for now as the tensor datatype entering the MoE
|
||||
# operation is typically unquantized (i.e. float16/bfloat16).
|
||||
if vllm_config.model_config is not None:
|
||||
moe_in_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||
# since model_config is not set in the pytest test.
|
||||
moe_in_dtype = params_dtype
|
||||
|
||||
tp_size_ = (tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size())
|
||||
dp_size_ = (dp_size
|
||||
@ -865,7 +920,6 @@ class FusedMoE(CustomOp):
|
||||
if self.is_sequence_parallel:
|
||||
self.sp_size = tp_size_
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size_,
|
||||
@ -874,19 +928,10 @@ class FusedMoE(CustomOp):
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
|
||||
# we are padding globally so EP buffer allocation works
|
||||
if quant_config and quant_config.get_name() == "mxfp4":
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend, get_mxfp4_backend)
|
||||
current_mxfp4_backend = get_mxfp4_backend()
|
||||
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
or current_mxfp4_backend
|
||||
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif (current_platform.is_rocm() or current_mxfp4_backend
|
||||
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
|
||||
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
# Round up hidden size if needed.
|
||||
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
|
||||
quant_config,
|
||||
self.moe_parallel_config)
|
||||
|
||||
# For smuggling this layer into the fused moe custom op
|
||||
compilation_config = vllm_config.compilation_config
|
||||
@ -967,20 +1012,13 @@ class FusedMoE(CustomOp):
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
|
||||
if vllm_config.model_config is not None:
|
||||
model_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
in_dtype=moe_in_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
|
@ -76,7 +76,7 @@ def _moe_problem_size(
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = w2.size(1)
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
|
@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
|
||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
|
||||
mxfp4_w4a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# FIXME warp need to be adjusted based on batch size
|
||||
# only apply to batched mode
|
||||
if self.moe.use_ep:
|
||||
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
|
||||
# (stored in self.fused_experts) to determine if the MoE has a
|
||||
# batched activation format. As self.fused_experts is not
|
||||
# initialized at this point, we resort to checking the MoE config
|
||||
# directly.
|
||||
is_batched_moe = (self.moe.use_pplx_kernels
|
||||
or self.moe.use_deepep_ll_kernels)
|
||||
if is_batched_moe:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
@ -640,16 +648,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
w1_scale = self.w13_precision_config
|
||||
w2_scale = self.w2_precision_config
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
else:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support batched experts format for EP")
|
||||
else:
|
||||
assert self.moe_quant_config is not None
|
||||
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
# B200 code-path
|
||||
@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# TODO(bnell): part of quant_config
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
assert self.moe_quant_config is not None
|
||||
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
||||
**kwargs)
|
||||
else:
|
||||
# Use matmul_ogs from triton_kernels here!
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support non-batched experts format for EP")
|
||||
return OAITritonExperts(self.moe_quant_config)
|
||||
|
||||
def _route_and_experts(
|
||||
self,
|
||||
@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
w13_weight = (self.w13_weight_triton_tensor
|
||||
if layer.w13_weight is None else layer.w13_weight)
|
||||
w2_weight = (self.w2_weight_triton_tensor
|
||||
if layer.w2_weight is None else layer.w2_weight)
|
||||
assert all([w is not None for w in [w13_weight, w2_weight]])
|
||||
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1=w13_weight,
|
||||
w2=w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
|
Reference in New Issue
Block a user