@@ -16,12 +16,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"

#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

@@ -36,16 +47,27 @@ template <
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
float data[N];
struct alignas(Alignment) AlignedArray {
T data[N];
};
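With the member changed to `T data[N];`, the default `Alignment = sizeof(T) * N` is what keeps the vectorized 128-bit loads further down legal for 16-bit gating types as well. A minimal host-side sketch of that expectation (illustrative static_asserts only, assuming just the `AlignedArray` template above):

// Sketch: size/alignment implied by the default Alignment = sizeof(T) * N.
// Illustrative only; not part of the patched file.
static_assert(sizeof(AlignedArray<float, 4>) == 16 &&
              alignof(AlignedArray<float, 4>) == 16,
              "4 floats -> one 16-byte (128-bit) load");
static_assert(sizeof(AlignedArray<__nv_bfloat16, 8>) == 16 &&
              alignof(AlignedArray<__nv_bfloat16, 8>) == 16,
              "8 bf16 values -> one 16-byte (128-bit) load");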

template <typename T>
__device__ __forceinline__ float toFloat(T value) {
if constexpr (std::is_same_v<T, float>) {
return value;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(value);
} else if constexpr (std::is_same_v<T, __half>) {
return __half2float(value);
}
}

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -66,7 +88,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
const float val = toFloat(input[idx]);
threadData = max(val, threadData);
}

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
@@ -81,7 +104,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
const float val = toFloat(input[idx]);
threadData += expf(val - float_max);
}

const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
@@ -95,8 +119,9 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
const float val = toFloat(input[idx]);
const float softmax_val = expf(val - float_max) * normalizing_factor;
output[idx] = softmax_val;
}
}
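Taken together, the three passes of moeSoftmax compute a numerically stable row softmax: find the row max, accumulate expf(x - max), then scale each term by the reciprocal of the sum. A host-side reference of the same math, useful as a correctness check (a sketch only; the function name is illustrative and not part of the kernel):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference row-wise softmax matching the kernel's max-subtract / expf / scale
// passes. Purely illustrative; not part of the CUDA file.
std::vector<float> softmax_row_reference(const std::vector<float>& row) {
    const float row_max = *std::max_element(row.begin(), row.end());
    float z = 0.f;
    for (float v : row) z += std::exp(v - row_max);
    const float inv_z = 1.f / z;
    std::vector<float> out(row.size());
    for (size_t i = 0; i < row.size(); ++i) {
        out[i] = std::exp(row[i] - row_max) * inv_z;
    }
    return out;
}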

@@ -110,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
const int end_expert,
const bool renormalize)
{

using cub_kvp = cub::KeyValuePair<int, float>;
@@ -125,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(

const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
@@ -163,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK(
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row;
if (renormalize) {
selected_sum += result_kvp.value;
}
}
__syncthreads();
}

// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (threadIdx.x == 0) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * block_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
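The renormalize path added here rescales the k selected gate values of a row so they sum to 1, leaving the row untouched when the sum is zero. The same post-processing in scalar form (illustrative sketch only):

// Sketch: renormalize the k selected weights of one row in place, as the
// kernel does after its top-k loop; guard against a zero sum.
inline void renormalize_topk(float* topk_weights, int k) {
    float sum = 0.f;
    for (int i = 0; i < k; ++i) sum += topk_weights[i];
    const float denom = sum > 0.f ? sum : 1.f;
    for (int i = 0; i < k; ++i) topk_weights[i] /= denom;
}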

// ====================== TopK softmax things ===============================
@@ -184,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize)
{
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");

// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}

// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
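To make the load geometry concrete, take one plausible configuration: bf16 gating, 128 experts, 16-byte loads. Then ELTS_PER_LDG = 16 / 2 = 8, and with VPT = 8 each thread issues a single vector load while 16 threads cooperate on one row. A compile-time sketch of that arithmetic (assumed example values, not taken from the file):

// Worked example of the compile-time load geometry for one assumed
// configuration (bf16 gating, 128 experts, 16-byte loads).
constexpr int kBytesPerLdg   = 16;
constexpr int kNumExperts    = 128;
constexpr int kEltsPerLdg    = kBytesPerLdg / 2;    // sizeof(__nv_bfloat16) == 2 -> 8
constexpr int kVpt           = 8;                   // one 16-byte vector per thread
constexpr int kThreadsPerRow = kNumExperts / kVpt;  // 16 threads cover one row
constexpr int kLdgPerThread  = kVpt / kEltsPerLdg;  // 1 load per thread
static_assert(kEltsPerLdg == 8 && kThreadsPerRow == 16 && kLdgPerThread == 1, "");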
@@ -236,27 +286,71 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__

// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

// Finally, we pull in the data from global mem
float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);

// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
{
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
}
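The 16-bit branches above rely on packed pair conversion: two adjacent bf16 (or fp16) values are reinterpreted as a __nv_bfloat162 / __half2 and widened to a float2 with one intrinsic. A stand-alone device sketch of that pattern (assumes only the bf16/fp16 headers included at the top of the file and a 4-byte-aligned input pointer; not part of the patch):

// Sketch: widen pairs of bf16 values to float2, as done inside the
// ELTS_PER_LDG >= 2 branches. Illustrative kernel only.
__global__ void widen_pairs(const __nv_bfloat16* in, float2* out, int n_pairs) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n_pairs) {
        // Reinterpret two consecutive bf16 values as one __nv_bfloat162 and
        // convert both lanes to float in a single intrinsic (assumes `in` is
        // at least 4-byte aligned).
        const __nv_bfloat162 packed =
            *reinterpret_cast<const __nv_bfloat162*>(in + 2 * i);
        out[i] = __bfloat1622float2(packed);
    }
}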

// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
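The cross-thread half of that reduction is the usual warp butterfly: each step exchanges values with the lane an XOR-distance away via warp shuffles. A generic sketch using the raw CUDA intrinsic (the kernel itself goes through the file's portability wrappers; assumes a power-of-two group size aligned within the warp):

// Sketch: butterfly max-reduction across the threads that share one row.
__device__ inline float warp_group_max(float v, int threads_per_row) {
    for (int mask = threads_per_row / 2; mask > 0; mask /= 2) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, mask));
    }
    return v;  // every lane in the group now holds the group max
}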
@@ -310,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
@@ -363,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}

// Finally, we clear the value in the thread with the current max if there is another iteration to run.
@@ -380,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
}

// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (thread_group_idx == 0)
{
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}

namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
@@ -397,20 +508,21 @@ struct TopkConstants
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
cudaStream_t stream)
{
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize);
}

#ifndef USE_ROCM
@@ -418,26 +530,26 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif

template <typename IndType>
template <typename IndType, typename InputType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
const InputType* gating_output,
float* topk_weights,
IndType* topk_indices,
int* token_expert_indices,
@@ -445,11 +557,15 @@ void topkGatingSoftmaxKernelLauncher(
const int num_tokens,
const int num_experts,
const int topk,
const bool renormalize,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
#endif
switch (num_experts) {
case 1:
@@ -506,11 +622,11 @@ void topkGatingSoftmaxKernelLauncher(
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
static constexpr int TPB = 256;
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
num_experts, topk, 0, num_experts);
num_experts, topk, 0, num_experts, renormalize);
}
}
}
@@ -518,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher(
} // namespace moe
} // namespace vllm


template<typename ComputeType>
void dispatch_topk_softmax_launch(
torch::Tensor& gating_output,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& softmax_workspace,
int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
{
if (topk_indices.scalar_type() == at::ScalarType::Int) {
vllm::moe::topkGatingSoftmaxKernelLauncher<int, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
vllm::moe::topkGatingSoftmaxKernelLauncher<uint32_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher<int64_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
}
}

void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output) // [num_tokens, num_experts]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize)
{
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
@@ -534,45 +689,19 @@ void topk_softmax(

const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);

if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
}
}
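For completeness, a hedged host-side usage sketch of the updated entry point: bf16 or fp16 gating logits can now be passed directly, and the new renormalize flag replaces a separate weight-normalization step. Tensor shapes follow the parameter comments above; the example assumes this function's declaration is visible and is illustrative only.

// Illustrative call site; not part of the patch.
void run_topk_softmax_example() {
    const int64_t num_tokens = 8, num_experts = 64, topk = 2;
    auto opts_f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto opts_i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

    // bf16 gating logits are accepted directly; no fp32 cast is needed.
    torch::Tensor gating_output = torch::randn({num_tokens, num_experts}, opts_f32)
                                      .to(torch::kBFloat16);
    torch::Tensor topk_weights = torch::empty({num_tokens, topk}, opts_f32);
    torch::Tensor topk_indices = torch::empty({num_tokens, topk}, opts_i32);
    torch::Tensor token_expert_indices = torch::empty({num_tokens, topk}, opts_i32);

    topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output,
                 /*renormalize=*/true);
}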