[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
ElizaWszola
2025-06-07 03:26:11 +02:00
committed by GitHub
parent 6e0cd10f72
commit 84166fee97
26 changed files with 918 additions and 409 deletions

View File

@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output.
# on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"

View File

@ -7,8 +7,8 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8,
fused_experts,
fused_topk,
)
@ -70,18 +70,9 @@ def bench_run(
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
@ -122,10 +113,6 @@ def bench_run(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
@ -133,14 +120,10 @@ def bench_run(
a,
w1,
w2,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale,
w2_scale,
a1_scale=a_scale,
)
@ -153,10 +136,6 @@ def bench_run(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
@ -165,14 +144,10 @@ def bench_run(
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale,
w2_scale,
a1_scale=a_scale,
)
@ -218,10 +193,6 @@ def bench_run(
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
)
torch.cuda.synchronize()
@ -230,8 +201,8 @@ def bench_run(
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(
a,
w1_q_notransp,
w2_q_notransp,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
@ -250,18 +221,12 @@ def bench_run(
"w2": w2,
"score": score,
"topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params
"a_scale": a_scale,
"w1_q": w1_q,
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"ab_strides1": ab_strides1,
"c_strides1": c_strides1,
"ab_strides2": ab_strides2,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@ -279,8 +244,8 @@ def bench_run(
# Warmup
run_triton_moe(
a,
w1_q_notransp,
w2_q_notransp,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
@ -291,7 +256,7 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@ -322,16 +287,12 @@ def bench_run(
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup,
)
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,

View File

@ -236,7 +236,8 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,

View File

@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides);
c_strides, per_act_token, per_out_ch);
}

View File

@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int =

View File

@ -7,7 +7,7 @@
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
}
}
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1));
}
__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}

View File

@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@ -207,12 +216,13 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides);
c_strides, per_act_token, per_out_ch);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,

View File

@ -435,7 +435,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides) -> ()",
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()",
{stride_tag});
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
@ -454,6 +455,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()",
{stride_tag});
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
&get_cutlass_pplx_moe_mm_data);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "

View File

@ -193,14 +193,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
kwargs = {
'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr]
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'c_strides1': moe_tensors.c_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides2': moe_tensors.c_strides2,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale

View File

@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]
def rank_chunk(num, r, w):
rem = num % w
return (num // w) + (1 if r < rem else 0)
def chunk_by_rank(t, r, w):
num = t.shape[0]
chunk = rank_chunk(num, r, w)
rem = num % w
if rem == 0 or r < rem:
return t[(r * chunk):(r + 1) * chunk].contiguous()
else:
long_chunks = (num // w + 1) * rem
short_chunks = (r - rem) * chunk
start = long_chunks + short_chunks
return t[start:start + chunk].contiguous()
def pplx_cutlass_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_scale: torch.Tensor,
out_dtype,
per_act_token: bool,
per_out_ch: bool,
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
topk = topk_ids.shape[1]
if block_size == hidden_dim:
scale_elems = 4 # hack to circumvent pplx data format requirements
else:
scale_elems = (hidden_dim + block_size - 1) // block_size
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=pgi.world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
)
w1 = w1.to(device)
w2 = w2.to(device)
w1_scale = w1_scale.to(device)
w2_scale = w2_scale.to(device)
a1_scale = a1_scale.to(device)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
pgi.world_size,
rank,
dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
out_dtype, per_act_token, per_out_ch)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank,
world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank,
world_size).to(torch.uint32).to(device)
out = fused_cutlass_experts(
a_chunk,
chunk_by_rank(w1, rank, world_size),
chunk_by_rank(w2, rank, world_size),
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
torch.cuda.synchronize()
ata.destroy()
return out[:rank_num_tokens]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_scale: torch.Tensor,
out_dtype,
a_full: torch.Tensor,
w1_full: torch.Tensor,
w2_full: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
nvshmem_finalize()
@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
@requires_pplx
def test_cutlass_moe_pplx(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0
n_b_scales = 2 * n if per_out_ch else 1
k_b_scales = k if per_out_ch else 1
w1_q = torch.empty((e, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_ch)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_ch)
w1_d = torch.empty_like(w1)
w2_d = torch.empty_like(w2)
for expert in range(e):
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
world_size, dp_size = world_dp_size
a_scale1 = torch.randn(
(m if per_act_token else 1, 1), device="cuda",
dtype=torch.float32) / 10.0
if not per_act_token:
a_scale1 = a_scale1.repeat(world_size, 1)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch)

View File

@ -4,10 +4,7 @@
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import dataclasses
import os
import traceback
from typing import Callable, Optional
from typing import Optional
import pytest
import torch
@ -21,10 +18,7 @@ try:
except ImportError:
has_pplx = False
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
@ -57,122 +56,6 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
def torch_prepare(
a: torch.Tensor,

View File

@ -632,7 +632,8 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
b_tensors_stacked, a_scales_tensors_stacked,
b_scales_tensors_stacked, expert_offsets[:-1],
problem_sizes, ab_strides, ab_strides, c_strides)
problem_sizes, ab_strides, ab_strides, c_strides,
per_act_token, per_out_ch)
# Validate each group's result against the baseline
for g in range(num_experts):

123
tests/pplx_utils.py Normal file
View File

@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Callable
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)

View File

@ -899,11 +899,36 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
return output_tensor
def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
expert_num_tokens: torch.Tensor,
num_local_experts: int, padded_m: int, n: int,
k: int):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in expert_num_tokens (token count per expert) and
non_zero_expert_idxs (consecutive indices of experts with non-zero token
counts) and uses them to compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation.
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
"""
return torch.ops._C.get_cutlass_pplx_moe_mm_data(
expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k)
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
b_tensors: torch.Tensor, a_scales: torch.Tensor,
b_scales: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes: torch.Tensor, a_strides: torch.Tensor,
b_strides: torch.Tensor, c_strides: torch.Tensor):
b_strides: torch.Tensor, c_strides: torch.Tensor,
per_act_token: bool, per_out_ch: bool):
"""
A single grouped matrix multiplication used in CUTLASS-based fused MoE.
The function executes fp8-quantized OUT = AB matrix multiplication.
@ -918,7 +943,7 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides,
c_strides)
c_strides, per_act_token, per_out_ch)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,

View File

@ -39,6 +39,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,

View File

@ -67,6 +67,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@ -78,11 +79,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
return self.batched_deep_gemm_experts.workspace_shapes(
a, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, num_experts)
else:
assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes(
a, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, num_experts)
def apply(
self,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Optional
from typing import Callable, Optional
import torch
@ -13,110 +13,109 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def run_cutlass_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation_callable: Callable,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
) -> torch.Tensor:
a1q = hidden_states
def __init__(
self,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.ab_strides1 = ab_strides1
self.c_strides1 = c_strides1
self.ab_strides2 = ab_strides2
self.c_strides2 = c_strides2
self.out_dtype = out_dtype
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
if expert_num_tokens is None:
assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
else:
assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[1], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[1], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a2_scale is None or a2_scale.dim(
) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
0], "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
assert expert_num_tokens is None
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note that K, N are transposed
N, K = K, N
workspace1 = M * topk * max(2 * N, K)
workspace2 = M * topk * N
return (workspace1, workspace2, self.out_dtype)
# We have two modes: PPLX and non-PPLX. We differentiate them by checking
# if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
# uses to track the number of tokens per expert).
# In the non-PPLX mode, the input tokens are not padded: thus, the shape
# of the input is [total_num_tokens, hidden_size]. The input and output
# require shuffling by a_map and c_map such that the tokens assigned to
# each expert are contiguous.
# In the PPLX mode, the input tokens are padded per expert to ensure that
# the PPLX dispatch and combine functions work correctly: thus, the shape
# of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
# The PPLX input and output require no shuffling by a_map and c_map since
# their tokens are already contiguous for each expert as a result of
# the dispatch function.
is_pplx = expert_num_tokens is not None
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
a1q = hidden_states
M = a1q.shape[0] # no pplx
padded_M = a1q.shape[1] # pplx
_, K, N = w2.shape
device = a1q.device
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[2], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert self.ab_strides1.shape[0] == w1.shape[
0], "AB Strides 1 expert number mismatch"
assert self.c_strides1.shape[0] == w1.shape[
0], "C Strides 1 expert number mismatch"
assert self.ab_strides2.shape[0] == w2.shape[
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[
0], "C Strides 2 expert number mismatch"
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype"
assert w1.shape[2] == K
assert global_num_experts != -1
assert a1q_scale is not None
M = a1q.shape[0]
_, N, K = w2.shape # because w1 + w2 are transposed
device = a1q.device
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
assert w1.shape[1] == K
assert global_num_experts != -1
assert a1q_scale is not None
topk = local_topk_ids.shape[1]
local_E = w1.shape[0]
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
if is_pplx:
expert_offsets = torch.empty((local_E),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
topk = local_topk_ids.shape[1]
ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
local_E, padded_M, N, K)
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
a1q = a1q.reshape(-1, a1q.shape[2])
a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
else:
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
@ -149,50 +148,130 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.shape[0], ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.shape[0], ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
if is_pplx:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
expert_offsets[:-1], problem_sizes1,
self.ab_strides1, self.ab_strides1, self.c_strides1)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
self.activation(activation, c2, c1)
activation_callable(c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
if expert_map is not None:
c3.fill_(0)
if expert_map is not None:
c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
expert_offsets[:-1], problem_sizes2,
self.ab_strides2, self.ab_strides2, self.c_strides2)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
c3 = c3[c_map]
return c3
if is_pplx:
return c3.reshape(local_E, padded_M, K)
else:
return c3[c_map].view(M, topk, K)
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
):
super().__init__()
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.per_act_token = per_act_token
self.per_out_ch = per_out_ch
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
padded_M = aq.shape[1]
workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch)
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@ -207,25 +286,17 @@ def cutlass_moe_fp8(
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.dtype): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
@ -233,24 +304,27 @@ def cutlass_moe_fp8(
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.shape[0]
out_dtype = a.dtype
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
per_channel_quant=per_act_token,
quant_dtype=torch.float8_e4m3fn,
per_channel_quant=per_act_token,
),
CutlassExpertsFp8(
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
out_dtype,
max_experts_per_worker=global_num_experts,
out_dtype=out_dtype,
per_act_token=per_act_token,
per_out_ch=per_out_ch,
),
)
@ -260,9 +334,12 @@ def cutlass_moe_fp8(
w2_q,
topk_weights,
topk_ids,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
False,
activation,
global_num_experts if global_num_experts != -1 else w1_q.size(0),
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,

View File

@ -73,6 +73,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,

View File

@ -521,6 +521,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@ -632,6 +633,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,

View File

@ -1545,6 +1545,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,

View File

@ -9,6 +9,9 @@ from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@ -210,6 +213,7 @@ class MoEConfig:
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
quant_dtype: torch.dtype = None
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
class FusedMoEMethodBase(QuantizeMethodBase):
moe: MoEConfig
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.moe = moe
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@ -297,13 +316,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
hidden_dim_scale_bytes=(
0 if moe.quant_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
)
# Intranode pplx a2a takes a group name while internode does not.
@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
handle = all2all_manager.get_handle(all_to_all_args)
input_activations = get_quant_config_input_activations(
quant_config)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
quant_dtype=moe.quant_dtype,
per_act_token=(input_activations.strategy
== QuantizationStrategy.TOKEN
if input_activations is not None else False),
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.topk_indices_dtype = None
if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize)
experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize
) -> FusedMoEPermuteExpertsUnpermute:
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]):
assert self.fused_experts == fused_experts
@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu",
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Only support float8 for now.
quant_dtype = params_dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8
and input_activations.type == QuantizationType.FLOAT):
quant_dtype = torch.float8_e4m3fn
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe

View File

@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time

View File

@ -21,7 +21,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
rank: int,
dp_size: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
block_shape: Optional[list[int]] = None,
per_act_token: bool = False):
super().__init__()
assert max_num_tokens > 0
self.a2a = a2a
@ -31,6 +32,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.rank = rank
self.dp_size = dp_size
self.quant_dtype = quant_dtype
self.per_act_token = per_act_token
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
@ -66,13 +68,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
repeat_cols = 4
repeat_rows = 1 if self.per_act_token else a1.shape[0]
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
self.per_act_token, self.block_shape)
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
per_act_token,
self.block_shape)
if a1q_scale is not None:
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
# rem_experts need to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size
@ -100,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else 1) * float32_size
expert_x_scale = torch.empty(
(
num_experts,
num_local_experts,
expert_x.size(1),
(expert_x.size(2) + block_size - 1) // block_size,
),
@ -121,6 +124,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
indices=rank_topk_ids,
bound_m=bound_m,
)
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
return expert_x, expert_x_scale, expert_num_tokens, None, None

View File

@ -37,6 +37,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@ -49,9 +50,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, num_experts)
else:
return self.triton_expert.workspace_shapes(a, M, N, K, topk,
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
num_experts)
def apply(

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import importlib
from enum import Enum
from typing import Callable, Optional
@ -11,7 +12,6 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
@ -30,6 +30,15 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
if has_pplx:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
logger = init_logger(__name__)
@ -77,8 +86,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and layer.activation == "silu"):
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@ -421,6 +429,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
self.fused_experts = cutlass_moe_fp8 # type: ignore
self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -499,25 +512,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
device = w13_weight.device
# TODO strides can be shared across multiple layers
self.ab_strides1 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full((num_experts, ),
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full((num_experts, ),
intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides2 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
@ -558,6 +552,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8)
assert moe is not None
max_experts_per_worker = (
(moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size)
experts = CutlassExpertsFp8(
max_experts_per_worker, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
if has_pplx and isinstance(
prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# no expert_map support in this case
self.disable_expert_map = True
return experts
def apply(
self,
layer: torch.nn.Module,
@ -577,9 +592,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", (
f"{activation} not supported for Cutlass MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -590,27 +602,22 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32)
from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
return cutlass_moe_fp8(
return self.fused_experts(
x,
layer.w13_weight.transpose(1, 2),
layer.w2_weight.transpose(1, 2),
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
self.ab_strides1,
self.c_strides1,
self.ab_strides2,
self.c_strides2,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -769,7 +769,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize):
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)