Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)
Signed-off-by: Shixian Cui <shixian@amazon.com>
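This commit replaces the index-based a_map/c_map gather/scatter in the CUTLASS FP8 MoE path with the dedicated CUDA moe_permute/moe_unpermute kernels, and lifts the per-expert stride tensors (ab_strides1/2, c_strides1/2) out of the kernel runner so callers construct them once. A condensed sketch of the resulting flow, pieced together from the hunks below (illustrative, not the verbatim implementation):

    # permute: group quantized tokens by expert, reusing workspace2 as output
    a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
        a1q, a1q_scale, topk_ids, num_expert, local_E, expert_map,
        permuted_hidden_states=a1q_perm)
    # grouped gemm 1 -> activation -> quant -> grouped gemm 2 ...
    # unpermute: scatter rows back and apply topk_weights in one fused kernel
    moe_unpermute(out=output, permuted_hidden_states=mm2_out,
                  topk_weights=topk_weights, inv_permuted_idx=inv_perm)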
@@ -80,6 +80,11 @@ def bench_run(
         a, score, topk, renormalize=False
     )
 
+    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
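A note on these stride tensors: CUTLASS grouped gemm takes one leading-dimension stride per expert, hence shape (num_experts,) and int64 dtype. A hypothetical helper mirroring the four torch.full calls above (it assumes row-major weights, w1 as [E, 2n, k] and w2 as [E, k, n], so gemm1 inputs stride by k, gemm2 inputs by n, gemm1 outputs by 2n, and gemm2 outputs by k):

    import torch

    def make_cutlass_moe_strides(num_experts: int, n: int, k: int,
                                 device: str = "cuda"):
        # one stride entry per expert, int64 as the kernels expect
        f = lambda v: torch.full((num_experts,), v,
                                 device=device, dtype=torch.int64)
        # ab_strides1, ab_strides2, c_strides1, c_strides2
        return f(k), f(n), f(2 * n), f(k)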
@@ -111,6 +116,10 @@ def bench_run(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
                 topk_ids,
                 w1_scale,
                 w2_scale,
+                ab_strides1,
+                ab_strides2,
+                c_strides1,
+                c_strides2,
                 per_act_token,
                 a1_scale=None,
             )
@@ -136,6 +149,10 @@ def bench_run(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -150,6 +167,10 @@ def bench_run(
                 topk_ids,
                 w1_scale,
                 w2_scale,
+                ab_strides1,
+                ab_strides2,
+                c_strides1,
+                c_strides2,
                 per_act_token,
                 a1_scale=None,
             )
@@ -194,6 +215,10 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -231,6 +256,10 @@ def bench_run(
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
+        "ab_strides1": ab_strides1,
+        "ab_strides2": ab_strides2,
+        "c_strides1": c_strides1,
+        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
             per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
 
         results.append(
             benchmark.Timer(
-                stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+                stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
                 globals=globals,
                 label=label,
                 sub_label=sub_label,
@@ -45,8 +45,6 @@ void moe_permute(
   auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
   auto permuted_experts_id = torch::empty_like(topk_ids);
   auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
-  auto align_expert_first_token_offset =
-      torch::zeros_like(expert_first_token_offset);
 
   CubKeyValueSorter sorter{};
   int64_t* valid_num_ptr = nullptr;
@@ -85,12 +83,14 @@ void moe_permute(
   });
 
-  // get m_indices and update expert_first_token_offset with align block
-  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
-              get_ptr<int64_t>(align_expert_first_token_offset),
-              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
-              stream);
+  // this is only required for DeepGemm and not required for CUTLASS group gemm
   if (align_block_size.has_value()) {
     // update align_expert_first_token_offset
+    auto align_expert_first_token_offset =
+        torch::zeros_like(expert_first_token_offset);
+    getMIndices(get_ptr<int64_t>(expert_first_token_offset),
+                get_ptr<int64_t>(align_expert_first_token_offset),
+                get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
+                stream);
     expert_first_token_offset.copy_(align_expert_first_token_offset);
   }
 }
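For reference, the alignment getMIndices applies to the per-expert token offsets is a plain round-up to a multiple of align_block_size, which only the DeepGemm path needs (the same rounding appears in the Python moe_permute further below). A minimal sketch of that rounding:

    def align_up(x: int, block: int) -> int:
        # round x up to the next multiple of block, e.g. align_up(5, 4) == 8
        return (x + block - 1) // block * block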
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
                  torch::Tensor& expert_first_token_offset,
                  torch::Tensor& src_row_id2dst_row_id_map,
                  torch::Tensor& m_indices) {
-  TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
+  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
 }
 
-void moe_unpermute(const torch::Tensor& input,
-                   const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
-                   const torch::Tensor& token_expert_indices,
-                   const std::optional<torch::Tensor>& expert_map,
-                   int64_t n_expert, int64_t n_local_expert, int64_t topk,
-                   const std::optional<int64_t>& align_block_size,
-                   torch::Tensor& permuted_input,
-                   torch::Tensor& expert_first_token_offset,
-                   torch::Tensor& src_row_id2dst_row_id_map,
-                   torch::Tensor& m_indices) {
+void moe_unpermute(
+    const torch::Tensor& permuted_hidden_states,
+    const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
+    const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
+    torch::Tensor& hidden_states) {
   TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
 }
 
@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("moe_permute", &moe_permute);
   m.impl("moe_unpermute", &moe_unpermute);
 }
 }
@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);
 
+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,
@@ -10,7 +10,7 @@
 
 template <typename ElementAB, typename ElementC, typename ElementAccumulator>
 __global__ void get_group_gemm_starts(
-    int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
+    int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
     ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
     ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
     ElementAB* b_base_as_int, ElementC* out_base_as_int,
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
   else if (out_tensors.dtype() == TENSOR_C_TYPE) {                  \
     get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>     \
         <<<1, num_experts, 0, stream>>>(                            \
-            static_cast<int32_t*>(expert_offsets.data_ptr()),       \
+            static_cast<int64_t*>(expert_offsets.data_ptr()),       \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),              \
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  // expect int64_t to avoid overflow during offset calculations
+  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
 
   int num_experts = static_cast<int>(expert_offsets.size(0));
   bool per_act_token = a_scales.numel() != 1;
@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
   }
 }
 
+namespace {
+inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         torch::Tensor& atomic_buffer,
+                                         int64_t num_experts, int64_t n,
+                                         int64_t k, cudaStream_t stream,
+                                         const bool swap_ab) {
+  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
+
+  const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
+  int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
+  int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
+  int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
+
+  if (swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  } else {
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  }
+}
+}  // namespace
+
+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
+  auto options_int32 =
+      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
+  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
+
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);
+}
+
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
   bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                      (topk_ids.numel() <= SWAP_AB_THRESHOLD);
 
-  if (may_swap_ab) {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  } else {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  }
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);
 
   if (blockscale_offsets.has_value()) {
     // fp4 path
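As a reference for what the compute_problem_sizes kernel produces: each expert's grouped-gemm problem is sized by how many tokens were routed to it. A rough Python sketch, assuming problem-size rows are laid out as (tokens_for_expert, 2n, k) for gemm1 and (tokens_for_expert, k, n) for gemm2, with swap-AB exchanging the first two entries (the exact layout is an assumption here):

    import torch

    def reference_problem_sizes(topk_ids: torch.Tensor, num_experts: int,
                                n: int, k: int, swap_ab: bool):
        # tokens routed to each expert, counted over all topk slots
        counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
        ps1 = torch.empty((num_experts, 3), dtype=torch.int32)
        ps2 = torch.empty((num_experts, 3), dtype=torch.int32)
        for e in range(num_experts):
            m = int(counts[e])
            ps1[e] = torch.tensor([2 * n, m, k] if swap_ab else [m, 2 * n, k])
            ps2[e] = torch.tensor([k, m, n] if swap_ab else [m, k, n])
        return ps1, ps2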
@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);
 
+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                          torch::Tensor& problem_sizes1,
                                          torch::Tensor& problem_sizes2,
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
       version_num, ". Required capability: 90 or 100");
 }
 
+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  int32_t version_num = get_sm_version_num();
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+  get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
+                                          problem_sizes2, num_experts, n, k,
+                                          blockscale_offsets);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
+      "kernel for CUDA device capability: ",
+      version_num, ". Required capability: 90 or 100");
+}
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,
@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       {stride_tag});
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
 
+  // A function that computes problem sizes for each expert's multiplication
+  // used by the two mms called from fused MoE operation. It takes topk_ids as
+  // an input, and computes problem_sizes1 and problem_sizes2 only.
+  ops.def(
+      "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
+      "                                 Tensor! problem_sizes1, "
+      "                                 Tensor! problem_sizes2, "
+      "                                 int num_experts, int n, int k, "
+      "                                 Tensor? blockscale_offsets) -> ()",
+      {stride_tag});
+  ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
+           &get_cutlass_moe_mm_problem_sizes);
+
   // A function that computes data required to run fused MoE with w8a8 grouped
   // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
   // as an input, and computes expert_offsets (token start indices of each
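The ops.def/ops.impl pair registers the kernel with PyTorch's dispatcher; in the schema, Tensor! marks an in-place output and Tensor? an optional input. Once registered it is reachable from Python, e.g. (hypothetical direct call; in practice the vllm._custom_ops wrapper added further below is used):

    torch.ops._C.get_cutlass_moe_mm_problem_sizes(
        topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, None)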
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
+        'ab_strides1': moe_tensors.ab_strides1,
+        'ab_strides2': moe_tensors.ab_strides2,
+        'c_strides1': moe_tensors.c_strides1,
+        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  #moe_tensors.a_scale
     }
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
     topk_ids[0][1] = 1
 
     workspace13_shape = (m * topk, max(2 * n, k))
-    workspace2_shape = (m * topk, n)
-    output_shape = (m * topk, k)
+    workspace2_shape = (m * topk, max(n, k))
+    output_shape = (m, k)
 
     workspace13 = torch.empty(prod(workspace13_shape),
                               device="cuda",
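These reshapes follow from the new in-place pipeline: workspace2 now hosts both the activation output (n columns) and the second gemm output (k columns), and the final output shrinks because the unpermute kernel reduces the topk rows per token while scattering them back. In sketch form:

    # workspace2 is shared, so it must fit the larger of its two tenants
    workspace2_shape = (m * topk, max(n, k))  # act_out (n) and mm2_out (k)
    output_shape = (m, k)                     # unpermute reduces topk rows per token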
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
         expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
 
+    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
-        per_act_token, per_out_channel, False)
+        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+        workspace13, workspace2, None, mt.a.dtype, per_act_token,
+        per_out_channel, False, topk_weights)
 
     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                                atol=0,
                                rtol=0)
     # check mindice
-    torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+    # current kernel usage assumes deepgemm requires align_block_size
+    # when it's not provided then we don't compute m_indices (for cutlass)
+    if align_block_size is not None:
+        torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+
     # check permuted_hidden_states, only valid token
     torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
                                permuted_hidden_states[valid_row_idx],
@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
     assert torch.cuda.current_device() == pgi.local_rank
 
     num_tokens, hidden_dim = a.shape
+    intermediate_dim = w2.shape[2]
     num_experts = w1.shape[0]
     block_size = hidden_dim  # TODO support more cases
     device = pgi.device
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
         num_local_experts=num_local_experts,
         num_dispatchers=num_dispatchers)
 
+    ab_strides1 = torch.full((num_local_experts, ),
+                             hidden_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    ab_strides2 = torch.full((num_local_experts, ),
+                             intermediate_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    c_strides1 = torch.full((num_local_experts, ),
+                            2 * intermediate_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+    c_strides2 = torch.full((num_local_experts, ),
+                            hidden_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+
     experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
-                                       out_dtype, per_act_token, per_out_ch)
+                                       out_dtype, per_act_token, per_out_ch,
+                                       ab_strides1, ab_strides2, c_strides1,
+                                       c_strides2)
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
 
     expert_offsets = torch.zeros((num_experts + 1),
                                  device=device,
-                                 dtype=torch.int32)
+                                 dtype=torch.int64)
 
     problem_sizes = torch.zeros((num_experts, 3),
                                 device=device,
@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
                                              blockscale_offsets)
 
 
+def get_cutlass_moe_mm_problem_sizes(
+        topk_ids: torch.Tensor,
+        problem_sizes1: torch.Tensor,
+        problem_sizes2: torch.Tensor,
+        num_experts: int,
+        n: int,
+        k: int,
+        blockscale_offsets: Optional[torch.Tensor] = None):
+    """
+    Compute only the per-expert problem sizes needed by the two grouped matrix
+    multiplications used in CUTLASS-based fused MoE.
+
+    The function takes in topk_ids (token→expert mapping) and computes:
+    - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
+      multiplication for the two grouped MMs
+      used in the fused MoE operation.
+    """
+    return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
+        topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
+        blockscale_offsets)
+
+
 def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
     """
     Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
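A minimal usage sketch of the new wrapper (shapes follow the test code in this commit; the concrete values are hypothetical):

    import torch
    from vllm import _custom_ops as ops

    num_experts, n, k = 8, 2048, 4096
    topk_ids = torch.randint(0, num_experts, (16, 2),
                             dtype=torch.int32, device="cuda")
    problem_sizes1 = torch.empty((num_experts, 3),
                                 dtype=torch.int32, device="cuda")
    problem_sizes2 = torch.empty((num_experts, 3),
                                 dtype=torch.int32, device="cuda")
    ops.get_cutlass_moe_mm_problem_sizes(topk_ids, problem_sizes1,
                                         problem_sizes2, num_experts, n, k)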
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
+    moe_permute, moe_unpermute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.scalar_type import scalar_types
 
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
     per_act_token: bool,
     per_out_ch: bool,
     use_batched_format: bool,
+    topk_weights: Optional[torch.Tensor],
 ):
     a1q = hidden_states
 
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
     topk = local_topk_ids.size(1)
     local_E = w1.size(0)
 
+    if use_batched_format:
+        mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
+        act_out = _resize_cache(workspace2, (local_E * padded_M, N))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (local_E * padded_M, N))
+        mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
+    else:
+        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+                                 (M * topk, K))
+        mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
+        act_out = _resize_cache(workspace2, (M * topk, N))
+        # original workspace are based on input hidden_states dtype (bf16)
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (M * topk, N))
+        mm2_out = _resize_cache(workspace2, (M * topk, K))
+
     if use_batched_format:
         assert expert_num_tokens is not None
 
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
         w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
         a1q = a1q.reshape(-1, a1q.size(2))
         a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
-
+        # c3x get_group_gemm_starts expects int64 to avoid overflow
+        # during offset calculations
+        expert_offsets = expert_offsets.to(torch.int64)
     else:
-        expert_offsets = torch.empty((global_num_experts + 1),
-                                     dtype=torch.int32,
-                                     device=device)
         problem_sizes1 = torch.empty((global_num_experts, 3),
                                      dtype=torch.int32,
                                      device=device)
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
                                      dtype=torch.int32,
                                      device=device)
 
-        # With expert_map each Rank processes only a subset of experts. As
-        # a result not all of a_map and c2 tensors are filled. We fill it
-        # zeros for correctness.
-        if expert_map is not None:
-            a_map = torch.zeros((local_topk_ids.numel()),
-                                dtype=torch.int32,
-                                device=device)
-        else:
-            a_map = torch.empty((local_topk_ids.numel()),
-                                dtype=torch.int32,
-                                device=device)
-
-        c_map = torch.empty((local_topk_ids.numel()),
-                            dtype=torch.int32,
-                            device=device)
-
-        ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
-                                    problem_sizes1, problem_sizes2, a_map,
-                                    c_map, global_num_experts, N, K)
-
-        a1q = _fp8_perm(a1q, a_map)
-        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+        num_expert = global_num_experts if expert_map is None \
+            else expert_map.size(0)
+        # permuted a1q reuses workspace2
+        a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
+            a1q,
+            a1q_scale,
+            topk_ids,
+            num_expert,
+            local_E,
+            expert_map,
+            permuted_hidden_states=a1q_perm)
+        expert_offsets = expert_offsets[:-1]
 
-    ab_strides1 = torch.full((w1.size(0), ),
-                             K,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides1 = torch.full((w1.size(0), ),
-                            2 * N,
-                            device=device,
-                            dtype=torch.int64)
-    ab_strides2 = torch.full((w1.size(0), ),
-                             N,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides2 = torch.full((w1.size(0), ),
-                            K,
-                            device=device,
-                            dtype=torch.int64)
-
-    if use_batched_format:
-        c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
-        c2 = _resize_cache(workspace2, (local_E * padded_M, N))
-        c3 = _resize_cache(workspace13, (local_E * padded_M, K))
-    else:
-        c1 = _resize_cache(workspace13, (M * topk, N * 2))
-        c2 = _resize_cache(workspace2, (M * topk, N))
-        c3 = _resize_cache(workspace13, (M * topk, K))
+        ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
+                                             problem_sizes2,
+                                             global_num_experts, N, K)
 
     if not per_act_token and (expert_map is not None or use_batched_format):
         # this is necessary to avoid imprecise scale calculation caused by
         # random data in the unused workspace. The workspace is unused when
         # this rank handles only partial tokens, or when it is batched .
-        c1.fill_(0)
+        mm1_out.fill_(0)
 
-    ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
+    ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
                        problem_sizes1, ab_strides1, ab_strides1, c_strides1,
                        per_act_token, per_out_ch)
 
-    activation_callable(c2, c1)
+    activation_callable(act_out, mm1_out)
 
     a2q, a2q_scale = ops.scaled_fp8_quant(
-        c2, a2_scale, use_per_token_if_dynamic=per_act_token)
+        act_out,
+        a2_scale,
+        use_per_token_if_dynamic=per_act_token,
+        output=quant_out)
 
     if expert_map is not None:
-        c3.fill_(0)
+        mm2_out.fill_(0)
 
-    ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
+    ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
                        problem_sizes2, ab_strides2, ab_strides2, c_strides2,
                        per_act_token, per_out_ch)
 
     if use_batched_format:
-        output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
+        output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
     else:
-        # We can't do this inplace because output may point to the same tensor
-        # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        # for non-chunking mode the output is resized from workspace13
+        # so we need to make sure mm2_out uses workspace2.
+        moe_unpermute(out=output,
+                      permuted_hidden_states=mm2_out,
+                      topk_weights=topk_weights,
+                      inv_permuted_idx=inv_perm)
 
 
 class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
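Worth noting about the buffer plan in the hunk above: the two workspaces are round-robined so no stage reads and writes the same buffer. A compressed view of the aliasing in the non-batched path (illustration only):

    # workspace2:  a1q_perm (fp8 view) -> act_out -> mm2_out
    # workspace13: mm1_out -> quant_out (fp8 view)
    # mm1   reads a1q_perm  (ws2),  writes mm1_out   (ws13)
    # act   reads mm1_out   (ws13), writes act_out   (ws2)
    # quant reads act_out   (ws2),  writes quant_out (ws13)
    # mm2   reads quant_out (ws13), writes mm2_out   (ws2)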
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
     ):
         super().__init__(
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
                 block_shape=block_shape,
             ))
         self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
-            use_batched_format)
+            use_batched_format, topk_weights)
 
 
 class CutlassExpertsFp8(CutlassExpertsFp8Base):
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
     ):
         super().__init__(
             out_dtype,
             per_act_token_quant,
             per_out_ch_quant,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             block_shape,
         )
 
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
     def supports_expert_map(self) -> bool:
         return True
 
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # topk weights and reduction are fused in moe_unpermute cuda kernel
+        return TopKWeightAndReduceNoOP()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M * topk, max(N, K))
-        workspace2 = (M * topk, N // 2)
-        output = (M * topk, K)
+        workspace2 = (M * topk, max(N // 2, K))
+        output = (M, K)
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)
 
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
     ):
         super().__init__(
             out_dtype,
             per_act_token_quant,
             per_out_ch_quant,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             block_shape,
         )
         assert max_experts_per_worker > 0
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         assert num_dp is not None
         workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
                       max(N, K))
-        workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
+        workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
+                      max(N // 2, K))
         output = (self.max_experts_per_worker, padded_M, K)
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
     - per_act_token (Optional[bool]): Whether the scale is per-token or
         per-tensor.
     - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
+            ab_strides1=ab_strides1,
+            ab_strides2=ab_strides2,
+            c_strides1=c_strides1,
+            c_strides2=c_strides2,
         ),
     )
 
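A hedged end-to-end usage sketch of cutlass_moe_fp8 with the new stride arguments (argument order inferred from the call sites in the benchmark above; a, w1_q, w2_q, scales, and routing tensors are assumed to exist):

    import torch
    from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

    e, n, k = 8, 2048, 4096  # experts, intermediate dim, hidden dim
    ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
    ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
    c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
    c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)

    out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
                          w1_scale, w2_scale,
                          ab_strides1, ab_strides2, c_strides1, c_strides2,
                          per_act_token=per_act_token)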
@@ -82,7 +82,8 @@ def moe_permute(
     n_local_expert: int = -1,
     expert_map: Optional[torch.Tensor] = None,
     align_block_size: Optional[int] = None,
-    fill_invalid_expert: int = -1
+    fill_invalid_expert: int = -1,
+    permuted_hidden_states: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
            torch.Tensor]:
     """
@@ -95,14 +96,17 @@ def moe_permute(
     - n_expert (int): The number of expert.
     - n_local_expert (int): The number of expert in current EP rank.
     - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
-    from the global expert space to the local expert space of the expert
+      from the global expert space to the local expert space of the expert
       parallel shard.
     - align_block_size (Optional[int]): align group gemm block size for deepgemm
     - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
       to workaround DeepGemm unsupported -1 in m_indices
+    - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
+      If None, the output tensor will be created in this function.
     Returns:
     - permuted_hidden_states (torch.Tensor): permuted activation.
-    - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
+    - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
       if original scale not per-tensor scaling
     - expert_first_token_offset (torch.Tensor): offset of the first token
       of each expert for standard grouped gemm. if enable 'align_block_size'
       expert_first_token_offset will align up to 'align_block_size'.
@@ -122,11 +126,16 @@ def moe_permute(
                              1) // align_block_size * align_block_size
     if n_local_expert == -1:
         n_local_expert = n_expert
-    permuted_hidden_states = torch.empty(
-        (permuted_row_size, n_hidden),
-        dtype=hidden_states.dtype,
-        device=hidden_states.device,
-    )
+    if permuted_hidden_states is None:
+        permuted_hidden_states = torch.empty(
+            (permuted_row_size, n_hidden),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
+        f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
+        f" but got {permuted_hidden_states.size()}")
+
     token_expert_indices = torch.arange(0,
                                         n_token * topk,
                                         dtype=torch.int32,
@@ -153,7 +162,8 @@ def moe_permute(
                                      align_block_size, permuted_hidden_states,
                                      expert_first_token_offset, inv_permuted_idx,
                                      permuted_idx, m_indices)
-    if a1q_scale is not None:
+
+    if a1q_scale is not None and a1q_scale.dim() > 1:
         a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
                               topk]
     return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
@@ -185,6 +195,7 @@ def moe_unpermute(
     n_hidden = permuted_hidden_states.size(-1)
     assert (n_hidden * permuted_hidden_states.element_size()
             ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
+
     torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
                                    inv_permuted_idx, expert_first_token_offset,
                                    topk, out)
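The fused kernel pair can be exercised standalone; a sketch of the contract, mirroring the call sites in this commit (tensors are assumed to exist and live on CUDA):

    permuted, scale, offsets, inv_idx, m_indices = moe_permute(
        a1q, a1q_scale, topk_ids, n_expert, n_local_expert, expert_map)
    # ... run the grouped gemms over `permuted` ...
    out = torch.empty_like(hidden_states)
    moe_unpermute(out=out, permuted_hidden_states=permuted,
                  topk_weights=topk_weights, inv_permuted_idx=inv_idx)
    # moe_unpermute both scatters rows back and applies topk_weights, which
    # is why CutlassExpertsFp8 now returns TopKWeightAndReduceNoOP().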
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts
 
+        if self.use_cutlass:
+            device = layer.w13_weight.device
+            # ab_strides1 and c_strides2 are the same
+            self.ab_strides1_c_strides2 = torch.full(
+                (layer.local_num_experts, ),
+                layer.hidden_size,
+                device=device,
+                dtype=torch.int64)
+            self.ab_strides2 = torch.full(
+                (layer.local_num_experts, ),
+                layer.intermediate_size_per_partition,
+                device=device,
+                dtype=torch.int64)
+            self.c_strides1 = torch.full(
+                (layer.local_num_experts, ),
+                2 * layer.intermediate_size_per_partition,
+                device=device,
+                dtype=torch.int64)
+
     def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
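The sharing trick in the __init__ hunk above follows from the layouts: gemm1 reads inputs with leading dimension hidden_size and gemm2 writes outputs with leading dimension hidden_size, so one tensor can serve both roles. In sketch form (E and hidden_size stand in for the layer attributes):

    shared = torch.full((E, ), hidden_size, device="cuda", dtype=torch.int64)
    # passed as both ab_strides1 and c_strides2 in the call sites below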
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 moe.in_dtype,
                 self.input_quant.strategy == QuantizationStrategy.TOKEN,
                 self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+                ab_strides1=self.ab_strides1_c_strides2,
+                ab_strides2=self.ab_strides2,
+                c_strides1=self.c_strides1,
+                c_strides2=self.ab_strides1_c_strides2,
             )
         else:
             logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 moe.in_dtype,
                 self.input_quant.strategy == QuantizationStrategy.TOKEN,
                 self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+                ab_strides1=self.ab_strides1_c_strides2,
+                ab_strides2=self.ab_strides2,
+                c_strides1=self.c_strides1,
+                c_strides2=self.ab_strides1_c_strides2,
             )
 
         self.disable_expert_map = (num_dispatchers > 1
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 expert_map=None if self.disable_expert_map else expert_map,
                 w1_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
+                ab_strides1=self.ab_strides1_c_strides2,
+                ab_strides2=self.ab_strides2,
+                c_strides1=self.c_strides1,
+                c_strides2=self.ab_strides1_c_strides2,
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )