[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)

Signed-off-by: Shixian Cui <shixian@amazon.com>
shixianc
2025-08-20 07:35:26 -07:00
committed by GitHub
parent 4449235843
commit b17109beea
15 changed files with 369 additions and 121 deletions

View File

@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@ -111,6 +116,10 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@ -125,6 +134,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@ -136,6 +149,10 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@ -150,6 +167,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@ -194,6 +215,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@ -231,6 +256,10 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@ -289,6 +318,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@ -297,7 +330,7 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
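
For context on the four stride tensors threaded through this benchmark: with row-major weights w1 [E, 2n, k] and w2 [E, k, n], the first grouped GEMM reads [*, k] operands and writes [*, 2n] outputs, and the second reads [*, n] and writes [*, k]. A minimal sketch of that convention, assuming the shapes used in this benchmark (illustrative only; runs on CPU):

import torch

num_experts, n, k = 8, 1024, 2048
# GEMM 1: A [m, k] x B [k, 2n] -> C [m, 2n]; leading-dim strides k and 2n.
# GEMM 2: A [m, n] x B [n, k] -> C [m, k]; leading-dim strides n and k.
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, dtype=torch.int64)
print(ab_strides1[0].item(), c_strides1[0].item())  # 2048 4096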

View File

@ -45,8 +45,6 @@ void moe_permute(
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
@ -85,12 +83,14 @@ void moe_permute(
});
// get m_indices and update expert_first_token_offset with align block
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
// this is only required for DeepGEMM; CUTLASS grouped GEMM does not need it
if (align_block_size.has_value()) {
// update align_expert_first_token_offset
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}
void moe_unpermute(const torch::Tensor& input,
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indices,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
void moe_unpermute(
const torch::Tensor& permuted_hidden_states,
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
torch::Tensor& hidden_states) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}
@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}
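
The restructured moe_permute now computes the aligned expert_first_token_offset only when align_block_size is provided (the DeepGEMM path); the CUTLASS grouped GEMM path skips the extra kernel launch and the temporary tensor. A rough Python picture of what that alignment does, using a hypothetical align_offsets helper rather than the getMIndices CUDA kernel:

import torch

def align_offsets(expert_first_token_offset: torch.Tensor,
                  align_block_size: int) -> torch.Tensor:
    # Round each expert's token count up to a multiple of align_block_size,
    # then rebuild the prefix sum (what the DeepGEMM path needs).
    counts = expert_first_token_offset.diff()
    aligned = (counts + align_block_size - 1) // align_block_size \
        * align_block_size
    out = torch.zeros_like(expert_first_token_offset)
    out[1:] = aligned.cumsum(0)
    return out

offsets = torch.tensor([0, 3, 7, 12], dtype=torch.int64)
print(align_offsets(offsets, 4))  # tensor([ 0,  4,  8, 16])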

View File

@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_moe_mm_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,

View File

@ -10,7 +10,7 @@
template <typename ElementAB, typename ElementC, typename ElementAccumulator>
__global__ void get_group_gemm_starts(
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int64_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
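
The int32 to int64 switch for expert_offsets matters because these token offsets are multiplied by the hidden dimension to form element offsets, which can exceed the int32 range well before the token count itself does. A small illustration of the failure mode (plain PyTorch, not the kernel; the numbers are arbitrary):

import torch

offset_tokens = 1_000_000          # plausible cumulative token offset
hidden = 7168                      # plausible hidden size
as_int32 = torch.tensor(offset_tokens, dtype=torch.int32) * hidden
as_int64 = torch.tensor(offset_tokens, dtype=torch.int64) * hidden
# 1_000_000 * 7168 = 7_168_000_000 > 2**31 - 1, so the int32 result wraps.
print(as_int32.item(), as_int64.item())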

View File

@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
}
namespace {
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& atomic_buffer,
int64_t num_experts, int64_t n,
int64_t k, cudaStream_t stream,
const bool swap_ab) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
if (swap_ab) {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
} else {
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
}
}
} // namespace
void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
// Swap-AB should be disabled for FP4 path
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream,
may_swap_ab);
}
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
if (may_swap_ab) {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
} else {
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
}
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream,
may_swap_ab);
if (blockscale_offsets.has_value()) {
// fp4 path
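
The extracted launch_compute_problem_sizes helper keeps the swap_ab template dispatch in one place for both callers. The heuristic itself is host-side and simple; a sketch of the decision, where the SWAP_AB_THRESHOLD value of 64 is purely illustrative (the real constant lives in the C++ source):

import torch
from typing import Optional

SWAP_AB_THRESHOLD = 64  # illustrative; see the C++ source for the real value

def may_swap_ab(topk_ids: torch.Tensor,
                blockscale_offsets: Optional[torch.Tensor]) -> bool:
    # Swap-AB is a small-batch grouped-GEMM optimization; it must stay off
    # on the FP4 (blockscale) path and for large token*topk counts.
    return (blockscale_offsets is None
            and topk_ids.numel() <= SWAP_AB_THRESHOLD)

print(may_swap_ab(torch.zeros(16, 2, dtype=torch.int32), None))  # True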

View File

@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k,
blockscale_offsets);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ",
version_num, ". Required capability: 90 or 100");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,

View File

@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// A function that computes problem sizes for each expert's multiplication
// used by the two MMs called from the fused MoE operation. It takes topk_ids
// as an input, and computes problem_sizes1 and problem_sizes2 only.
ops.def(
"get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int num_experts, int n, int k, "
" Tensor? blockscale_offsets) -> ()",
{stride_tag});
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each

View File

@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids[0][1] = 1
workspace13_shape = (m * topk, max(2 * n, k))
workspace2_shape = (m * topk, n)
output_shape = (m * topk, k)
workspace2_shape = (m * topk, max(n, k))
output_shape = (m, k)
workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False, topk_weights)
workspace13.random_()
output_random_workspace = torch.empty(output_shape,
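
The reshaped test buffers mirror the new buffer plan in run_cutlass_moe_fp8 further down this diff: workspace2 must now also hold the second GEMM's (m * topk, k) output, and the final output shrinks to (m, k) because the fused unpermute reduces over topk. A sketch of the sizing with placeholder dimensions:

from math import prod

m, topk, n, k = 32, 4, 1024, 2048
workspace13_shape = (m * topk, max(2 * n, k))  # mm1 output, then quantized act
workspace2_shape = (m * topk, max(n, k))       # activation, then mm2 output
output_shape = (m, k)                          # after topk-weighted unpermute
print(prod(workspace13_shape), prod(workspace2_shape), prod(output_shape))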

View File

@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol=0,
rtol=0)
# check m_indices
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# the current kernel usage assumes only DeepGEMM requires align_block_size;
# when it is not provided, m_indices is not computed (the CUTLASS path)
if align_block_size is not None:
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states (only valid tokens)
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],

View File

@ -76,6 +76,7 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
@ -124,8 +125,27 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch)
out_dtype, per_act_token, per_out_ch,
ab_strides1, ab_strides2, c_strides1,
c_strides2)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,

View File

@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
expert_offsets = torch.zeros((num_experts + 1),
device=device,
dtype=torch.int32)
dtype=torch.int64)
problem_sizes = torch.zeros((num_experts, 3),
device=device,

View File

@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
blockscale_offsets)
def get_cutlass_moe_mm_problem_sizes(
topk_ids: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
num_experts: int,
n: int,
k: int,
blockscale_offsets: Optional[torch.Tensor] = None):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
multiplications used in CUTLASS-based fused MoE.
The function takes in topk_ids (token→expert mapping) and computes:
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
"""
return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
blockscale_offsets)
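
A minimal call sketch for the new wrapper, assuming a vLLM build with the CUTLASS MoE kernels compiled in; the problem-size buffers are preallocated int32 [num_experts, 3] tensors as in the tests above:

import torch
from vllm import _custom_ops as ops

num_experts, n, k, m, topk = 8, 1024, 2048, 16, 2
topk_ids = torch.randint(0, num_experts, (m, topk),
                         dtype=torch.int32, device="cuda")
problem_sizes1 = torch.empty((num_experts, 3),
                             dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty((num_experts, 3),
                             dtype=torch.int32, device="cuda")
ops.get_cutlass_moe_mm_problem_sizes(topk_ids, problem_sizes1,
                                     problem_sizes2, num_experts, n, k)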
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.

View File

@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types
@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
topk_weights: Optional[torch.Tensor],
):
a1q = hidden_states
@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.size(1)
local_E = w1.size(0)
if use_batched_format:
mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
act_out = _resize_cache(workspace2, (local_E * padded_M, N))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(local_E * padded_M, N))
mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
else:
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
# the original workspaces are sized for the input hidden_states dtype (bf16)
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M * topk, N))
mm2_out = _resize_cache(workspace2, (M * topk, K))
if use_batched_format:
assert expert_num_tokens is not None
@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
a1q = a1q.reshape(-1, a1q.size(2))
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
# c3x get_group_gemm_starts expects int64 to avoid overflow
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else:
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype=torch.int32,
device=device)
# With expert_map each rank processes only a subset of experts. As
# a result not all of the a_map and c2 tensors are filled. We fill them
# with zeros for correctness.
if expert_map is not None:
a_map = torch.zeros((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
else:
a_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
num_expert = global_num_experts if expert_map is None \
else expert_map.size(0)
# permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
a1q,
a1q_scale,
topk_ids,
num_expert,
local_E,
expert_map,
permuted_hidden_states=a1q_perm)
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
problem_sizes2,
global_num_experts, N, K)
if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid an imprecise scale calculation caused by
# random data in the unused part of the workspace. Parts go unused when
# this rank handles only a subset of tokens, or in the batched format.
c1.fill_(0)
mm1_out.fill_(0)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
activation_callable(c2, c1)
activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
act_out,
a2_scale,
use_per_token_if_dynamic=per_act_token,
output=quant_out)
if expert_map is not None:
c3.fill_(0)
mm2_out.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
if use_batched_format:
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# in non-chunking mode the output is a view resized from workspace13,
# so mm2_out must live in workspace2.
moe_unpermute(out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape,
))
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format)
use_batched_format, topk_weights)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
)
@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weighting and reduction are fused into the moe_unpermute CUDA kernel
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
a: torch.Tensor,
@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
)
assert max_experts_per_worker > 0
@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
),
)
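
Putting the new signature together, a condensed caller sketch under the parameter order shown in this diff, assuming a CUDA build with these kernels; the fp8 casts and unit scales stand in for real quantization, the fused_topk import path is assumed from vLLM's fused_moe module, and its return is unpacked defensively since only the first two outputs are needed:

import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

e, m, n, k, topk = 8, 16, 1024, 2048, 2
device, dtype = "cuda", torch.half
a = torch.randn(m, k, device=device, dtype=dtype) / 10
w1_q = (torch.randn(e, 2 * n, k, device=device, dtype=dtype) / 10).to(
    torch.float8_e4m3fn)
w2_q = (torch.randn(e, k, n, device=device, dtype=dtype) / 10).to(
    torch.float8_e4m3fn)
w1_scale = torch.ones(e, device=device, dtype=torch.float32)
w2_scale = torch.ones(e, device=device, dtype=torch.float32)
score = torch.randn(m, e, device=device, dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)[:2]
# Strides computed once at init time, not per forward pass.
ab_strides1 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
                      w1_scale, w2_scale,
                      ab_strides1, ab_strides2, c_strides1, c_strides2,
                      per_act_token=False)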

View File

@ -82,7 +82,8 @@ def moe_permute(
n_local_expert: int = -1,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1
fill_invalid_expert: int = -1,
permuted_hidden_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
@ -95,14 +96,17 @@ def moe_permute(
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts on the current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- align_block_size (Optional[int]): aligned group GEMM block size for DeepGEMM
- fill_invalid_expert (int): the expert id to fill into m_indices for invalid
experts, a workaround for DeepGEMM not supporting -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if the original scale is not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped GEMM. If 'align_block_size' is given,
expert_first_token_offset is aligned up to 'align_block_size'.
@ -122,11 +126,16 @@ def moe_permute(
1) // align_block_size * align_block_size
if n_local_expert == -1:
n_local_expert = n_expert
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if permuted_hidden_states is None:
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
f" but got {permuted_hidden_states.size()}")
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
@ -153,7 +162,8 @@ def moe_permute(
align_block_size, permuted_hidden_states,
expert_first_token_offset, inv_permuted_idx,
permuted_idx, m_indices)
if a1q_scale is not None:
if a1q_scale is not None and a1q_scale.dim() > 1:
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
topk]
return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
@ -185,6 +195,7 @@ def moe_unpermute(
n_hidden = permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermute kernel needs hidden dim aligned to 16B"
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
inv_permuted_idx, expert_first_token_offset,
topk, out)
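
The new optional permuted_hidden_states argument is what lets run_cutlass_moe_fp8 alias workspace2 instead of allocating a fresh tensor, and moe_unpermute's fused topk-weighted reduction is why the CUTLASS path can return TopKWeightAndReduceNoOP. A usage sketch under the shape rules asserted above (CUDA build required; sizes illustrative):

import torch
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute, moe_unpermute)

n_token, n_hidden, topk, n_expert = 16, 1024, 2, 8
device = "cuda"
hidden_states = torch.randn(n_token, n_hidden, device=device,
                            dtype=torch.half)
topk_ids = torch.randint(0, n_expert, (n_token, topk),
                         dtype=torch.int32, device=device)
topk_weights = torch.rand(n_token, topk, device=device, dtype=torch.float32)
# Preallocated buffer; without align_block_size it must be exactly
# (n_token * topk, n_hidden), per the assert added in moe_permute.
perm_buf = torch.empty(n_token * topk, n_hidden, device=device,
                       dtype=torch.half)
permuted, _, expert_offsets, inv_perm, _ = moe_permute(
    hidden_states, None, topk_ids, n_expert,
    permuted_hidden_states=perm_buf)
out = torch.empty(n_token, n_hidden, device=device, dtype=torch.half)
# out receives the topk-weighted recombination of the permuted rows.
moe_unpermute(out=out, permuted_hidden_states=permuted,
              topk_weights=topk_weights, inv_permuted_idx=inv_perm)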

View File

@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
if self.use_cutlass:
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are identical, so they share one tensor
self.ab_strides1_c_strides2 = torch.full(
(layer.local_num_experts, ),
layer.hidden_size,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full(
(layer.local_num_experts, ),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full(
(layer.local_num_experts, ),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
self.disable_expert_map = (num_dispatchers > 1
@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
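
The stride-sharing trick above follows directly from the weight shapes: the first GEMM's A/B stride and the second GEMM's C stride both equal hidden_size, so one tensor can serve both slots. A quick consistency check under assumed shapes (hypothetical sizes; runs on CPU):

import torch

local_num_experts, hidden_size, intermediate = 4, 4096, 14336
# GEMM 1 reads [*, hidden_size] activations; GEMM 2 writes [*, hidden_size]
# outputs, so both strides equal hidden_size and can share storage.
ab_strides1_c_strides2 = torch.full((local_num_experts,), hidden_size,
                                    dtype=torch.int64)
ab_strides2 = torch.full((local_num_experts,), intermediate,
                         dtype=torch.int64)
c_strides1 = torch.full((local_num_experts,), 2 * intermediate,
                        dtype=torch.int64)
print(ab_strides1_c_strides2.tolist())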