mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
2ed8b6b3d0
commit
fb0571b077
@ -8,12 +8,77 @@
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "../dispatch_utils.h"
|
||||
#include "core/math.hpp"
|
||||
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
namespace batched_moe_align_block_size {
|
||||
|
||||
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
|
||||
static constexpr int32_t num_threads = 1024;
|
||||
static constexpr int32_t num_blocks = 1;
|
||||
__global__ void batched_moe_align_block_size_kernel(
|
||||
int32_t const num_batches, int32_t const max_tokens_per_batch,
|
||||
int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
|
||||
int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
|
||||
int32_t* __restrict__ num_tokens_post_pad) {
|
||||
// TODO(varun): This is a naive implementation. Could be optimized.
|
||||
|
||||
size_t const batch_id = threadIdx.x;
|
||||
size_t const stride = blockDim.x * gridDim.x;
|
||||
int32_t const num_blocks_per_batch =
|
||||
CEILDIV(max_tokens_per_batch, block_size);
|
||||
int32_t const sorted_ids_size =
|
||||
num_blocks_per_batch * num_batches * block_size;
|
||||
int32_t const block_ids_size = sorted_ids_size / block_size;
|
||||
int32_t const SENTINEL =
|
||||
num_batches * max_tokens_per_batch; // To denote invalid entries.
|
||||
// Intialize sorted_ids
|
||||
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
|
||||
sorted_ids[i] = SENTINEL;
|
||||
}
|
||||
// Intialize expert_ids with -1
|
||||
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
|
||||
block_ids[i] = -1;
|
||||
}
|
||||
|
||||
int32_t b_num_tokens = 0;
|
||||
if (batch_id < num_batches) {
|
||||
b_num_tokens = batch_num_tokens[batch_id];
|
||||
}
|
||||
int32_t const ceil_b_num_tokens =
|
||||
CEILDIV(b_num_tokens, block_size) * block_size;
|
||||
|
||||
// Compute prefix sum over token counts per expert
|
||||
using BlockScan = cub::BlockScan<int32_t, 1024>;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
int cumsum_val;
|
||||
BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
|
||||
__syncthreads();
|
||||
|
||||
bool const is_last_batch = batch_id == (num_batches - 1);
|
||||
if (is_last_batch) {
|
||||
*num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
|
||||
}
|
||||
|
||||
if (batch_id < num_batches) {
|
||||
int32_t const batch_offset = batch_id * max_tokens_per_batch;
|
||||
for (size_t i = 0; i < b_num_tokens; ++i) {
|
||||
sorted_ids[cumsum_val + i] = batch_offset + i;
|
||||
}
|
||||
|
||||
int32_t const block_start = cumsum_val / block_size;
|
||||
int32_t const num_blocks = ceil_b_num_tokens / block_size;
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
block_ids[block_start + i] = batch_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace batched_moe_align_block_size
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
});
|
||||
}
|
||||
|
||||
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
|
||||
int64_t block_size,
|
||||
torch::Tensor const& batch_num_tokens,
|
||||
torch::Tensor sorted_ids,
|
||||
torch::Tensor batch_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
int32_t const B = batch_num_tokens.size(0);
|
||||
int32_t const num_blocks_per_batch =
|
||||
round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
|
||||
int32_t const num_blocks = num_blocks_per_batch * B;
|
||||
int64_t const sorted_ids_size = num_blocks * block_size;
|
||||
|
||||
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
|
||||
TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
|
||||
TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
|
||||
TORCH_CHECK(B <= batched_kernel::num_threads);
|
||||
|
||||
batched_kernel::batched_moe_align_block_size_kernel<<<
|
||||
batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
|
||||
B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
|
||||
sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>());
|
||||
}
|
||||
|
||||
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
|
@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
|
||||
int64_t block_size,
|
||||
torch::Tensor const& expert_num_tokens,
|
||||
torch::Tensor sorted_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
|
@ -22,6 +22,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size, but for the batched case.
|
||||
m.def(
|
||||
"batched_moe_align_block_size(int max_tokens_per_batch,"
|
||||
" int block_size, Tensor expert_num_tokens,"
|
||||
" Tensor! sorted_token_ids,"
|
||||
" Tensor! experts_ids,"
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
m.impl("batched_moe_align_block_size", torch::kCUDA,
|
||||
&batched_moe_align_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
|
||||
|
Reference in New Issue
Block a user