[Kernel][Performance] Fuse float cast and renormalize to topk softmax kernel (#26717)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
This commit is contained in:
zhrrr
2025-10-17 15:30:35 +08:00
committed by GitHub
parent 5550ff9c25
commit 75c7ad9918
5 changed files with 221 additions and 94 deletions

View File

@ -4,7 +4,7 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
torch::Tensor& gating_output, bool renormalize);
void moe_sum(torch::Tensor& input, torch::Tensor& output);

View File

@ -16,12 +16,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@ -36,16 +47,27 @@ template <
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
float data[N];
struct alignas(Alignment) AlignedArray {
T data[N];
};
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
if constexpr (std::is_same_v<T, float>) {
return value;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(value);
} else if constexpr (std::is_same_v<T, __half>) {
return __half2float(value);
}
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@ -66,7 +88,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
const float val = toFloat(input[idx]);
threadData = max(val, threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
@ -81,7 +104,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
const float val = toFloat(input[idx]);
threadData += expf(val - float_max);
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
@ -95,8 +119,9 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
const float val = toFloat(input[idx]);
const float softmax_val = expf(val - float_max) * normalizing_factor;
output[idx] = softmax_val;
}
}
@ -110,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
const int end_expert,
const bool renormalize)
{
using cub_kvp = cub::KeyValuePair<int, float>;
@ -125,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
@ -163,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK(
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row;
if (renormalize) {
selected_sum += result_kvp.value;
}
}
__syncthreads();
}
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (threadIdx.x == 0) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * block_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
// ====================== TopK softmax things ===============================
@ -184,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize)
{
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
@ -236,28 +286,72 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
{
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
@ -310,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
@ -363,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
@ -380,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
}
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (thread_group_idx == 0)
{
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
@ -397,20 +508,21 @@ struct TopkConstants
};
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
cudaStream_t stream)
{
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize);
}
#ifndef USE_ROCM
@ -418,26 +530,26 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template <typename IndType>
template <typename IndType, typename InputType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
const InputType* gating_output,
float* topk_weights,
IndType* topk_indices,
int* token_expert_indices,
@ -445,11 +557,15 @@ void topkGatingSoftmaxKernelLauncher(
const int num_tokens,
const int num_experts,
const int topk,
const bool renormalize,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
#endif
switch (num_experts) {
case 1:
@ -506,11 +622,11 @@ void topkGatingSoftmaxKernelLauncher(
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
static constexpr int TPB = 256;
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
num_experts, topk, 0, num_experts);
num_experts, topk, 0, num_experts, renormalize);
}
}
}
@ -518,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher(
} // namespace moe
} // namespace vllm
template<typename ComputeType>
void dispatch_topk_softmax_launch(
torch::Tensor& gating_output,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& softmax_workspace,
int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
{
if (topk_indices.scalar_type() == at::ScalarType::Int) {
vllm::moe::topkGatingSoftmaxKernelLauncher<int, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
vllm::moe::topkGatingSoftmaxKernelLauncher<uint32_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher<int64_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
}
}
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output) // [num_tokens, num_experts]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize)
{
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
@ -534,45 +689,19 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);
if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
}
}

View File

@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results

View File

@ -1850,9 +1850,10 @@ def topk_softmax(
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
) -> None:
torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indices, gating_output
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)

View File

@ -1074,9 +1074,8 @@ def vllm_topk_softmax(
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
@ -1113,11 +1112,9 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices