mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Use macro-guarded CUDA functions for backward compatibility in grouped_topk_kernel.cu (#25346)
Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
// Reports whether `val` is a finite number (neither infinite nor NaN).
//
// Which implementation is used depends on the compiler version encoded as
// MAJOR * 10000 + MINOR * 100:
//   * below 120800 (i.e. nvcc older than 12.8, or the version macros being
//     undefined and therefore evaluating to 0 in the preprocessor), the value
//     is first converted to float through cuda_cast and then checked with the
//     unqualified isfinite;
//   * otherwise, cuda::std::isfinite is applied to `val` directly.
//
// NOTE(review): presumably the fallback exists because cuda::std::isfinite in
// older toolkits does not accept every element type T used by this file --
// confirm against the toolkit release notes.
template <typename T>
__device__ inline bool is_finite(const T val) {
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120800)
  // Backward-compatible path: widen to float before testing.
  const float widened = cuda_cast<float, T>(val);
  return isfinite(widened);
#else
  // Newer toolkits: test the value in its native type.
  return cuda::std::isfinite(val);
#endif
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void topk_with_k2(T* output, T const* input,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
// The check is necessary to avoid abnormal input
|
||||
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
|
||||
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) &&
|
||||
cuda::std::isfinite(scores_with_bias[offset + i])
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
T candidates = (i < num_experts_per_group) &&
|
||||
is_finite(scores_with_bias[offset + i])
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
|
Reference in New Issue
Block a user