mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix][ROCm] Fix for warp_size uses on host (#21205)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
dde295a934
commit
90eeea8f85
@ -24,7 +24,7 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
|
@ -16,9 +16,8 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -75,7 +74,7 @@ void paged_attention_v1_launcher(
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len =
|
||||
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_seq_len * sizeof(float);
|
||||
|
@ -16,9 +16,8 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -79,7 +78,7 @@ void paged_attention_v2_launcher(
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
int logits_size = PARTITION_SIZE * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
|
@ -4,8 +4,35 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && defined(__GFX9__)
|
||||
#define WARP_SIZE 64
|
||||
#ifdef USE_ROCM
|
||||
struct Utils {
|
||||
static __host__ int get_warp_size() {
|
||||
static bool is_cached = false;
|
||||
static int result;
|
||||
|
||||
if (!is_cached) {
|
||||
int device_id;
|
||||
cudaDeviceProp deviceProp;
|
||||
cudaGetDevice(&device_id);
|
||||
cudaGetDeviceProperties(&deviceProp, device_id);
|
||||
|
||||
result = deviceProp.warpSize;
|
||||
is_cached = true;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static __device__ constexpr int get_warp_size() {
|
||||
#ifdef __GFX9__
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#define WARP_SIZE Utils::get_warp_size()
|
||||
#else
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
|
@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG>
|
||||
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||
stream);
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
switch (warpSize) { \
|
||||
case 32: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
case 64: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
|
||||
}
|
||||
|
||||
template <typename IndType>
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
|
||||
const int topk,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
auto warpSize = WARP_SIZE;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include <cmath>
|
||||
#include "core/math.hpp"
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "../../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "ggml-common.h"
|
||||
|
@ -19,7 +19,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include "../attention/dtype_fp8.cuh"
|
||||
|
@ -9,7 +9,7 @@
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
||||
|
Reference in New Issue
Block a user