Remove CUDA 11 workarounds for CUB_SUPPORTS_SCAN_BY_KEY and CUB_SUPPORTS_UNIQUE_BY_KEY (#164637)

`CUB_SUPPORTS_SCAN_BY_KEY` and `CUB_SUPPORTS_UNIQUE_BY_KEY` have been unconditionally true since CUDA 12, whose bundled CUB (2.x) is well past the 1.15 and 1.16 thresholds the macros test. This PR removes the dead `#if` branches and the legacy Thrust fallback sources that only built when the macros were false.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164637
Approved by: https://github.com/ezyang
Author:     Yuanyuan Chen
Date:       2025-10-18 20:05:50 +00:00
Committer:  PyTorch MergeBot
Parent:     1f43d17ce6
Commit:     35e51893bd

8 changed files with 1 addition and 209 deletions

View File

@@ -177,7 +177,6 @@ inline void segmented_sort_pairs(
 }
 }
-#if CUB_SUPPORTS_UNIQUE_BY_KEY()
 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
 inline void unique_by_key(
     KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@@ -193,7 +192,6 @@ inline void unique_by_key(
   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
 }
-#endif
 namespace impl {
@@ -579,7 +577,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 #endif
 }
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
 inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@@ -607,7 +604,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
 #endif
 }
-#endif
 template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
 void unique(InputIteratorT input, OutputIteratorT output,
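With the gate removed, `unique_by_key` is available unconditionally. A minimal caller sketch (a hypothetical helper, not part of this PR; it follows the five-argument wrapper signature above, which allocates its own keys-out buffer internally):

    #include <ATen/cuda/cub.cuh>
    #include <thrust/iterator/counting_iterator.h>

    // Hypothetical helper: for each run of equal keys in a sorted device
    // array, record the index where the run starts and the number of runs.
    template <typename index_t>
    void run_starts(const index_t* sorted_keys,  // device, sorted
                    index_t* starts,             // device, output
                    int64_t* num_runs,           // device scalar, output
                    int64_t n) {
      at::cuda::cub::unique_by_key(
          sorted_keys,
          thrust::make_counting_iterator<index_t>(0),  // values: 0, 1, 2, ...
          starts,      // value kept from the first element of each run
          num_runs,
          n);
    }

This is the same pattern the EmbeddingBackwardKernel.cu hunk below now uses unconditionally.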

View File

@@ -28,22 +28,6 @@
 #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
 #endif
-// cub support for UniqueByKey is added to cub 1.16 in:
-// https://github.com/NVIDIA/cub/pull/405
-#if CUB_VERSION >= 101600
-#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
-#else
-#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
-#endif
-// cub support for scan by key is added to cub 1.15
-// in https://github.com/NVIDIA/cub/pull/376
-#if CUB_VERSION >= 101500
-#define CUB_SUPPORTS_SCAN_BY_KEY() 1
-#else
-#define CUB_SUPPORTS_SCAN_BY_KEY() 0
-#endif
 // cub support for cub::FutureValue is added to cub 1.15 in:
 // https://github.com/NVIDIA/cub/pull/305
 #if CUB_VERSION >= 101500
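CUB encodes `CUB_VERSION` as `major * 100000 + minor * 100 + patch`, so the deleted thresholds 101500 and 101600 correspond to CUB 1.15.0 and 1.16.0. Every CUDA 12 toolkit bundles CUB 2.x, which is why both gates were tautological. A compile-time sanity check, as an illustrative sketch only:

    #include <cub/version.cuh>  // defines CUB_VERSION = major*100000 + minor*100 + patch

    // On any CUDA 12 toolchain both removed thresholds hold trivially:
    static_assert(CUB_VERSION >= 101500, "scan_by_key requires CUB >= 1.15");
    static_assert(CUB_VERSION >= 101600, "unique_by_key requires CUB >= 1.16");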

View File

@@ -15,9 +15,7 @@
 #include <ATen/native/cuda/block_reduce.cuh>
 #include <ATen/native/cuda/thread_constants.h>
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 #include <thrust/iterator/reverse_iterator.h>
-#endif
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -240,10 +238,6 @@ __global__ void renorm_kernel(
 } // anonymous namespace
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-template<typename index_t>
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-#endif
 Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
     int64_t num_weights, int64_t padding_idx,
@@ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
   if (scale_grad_by_freq) {
     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-#if CUB_SUPPORTS_SCAN_BY_KEY()
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
         num_indices
       );
     });
-#else
-    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
-      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
-    });
-#endif
   }
   return embedding_backward_cuda_kernel(grad, orig_indices,
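Both the kept and the deleted branch run under `AT_DISPATCH_INDEX_TYPES`, which expands to a switch over the index dtype and binds `index_t` inside the lambda. A simplified sketch of that expansion (illustrative, not the actual macro output):

    // Roughly what AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "name", [&] {...}) does:
    switch (indices.scalar_type()) {
      case at::ScalarType::Int:  { using index_t = int32_t; /* lambda body */ break; }
      case at::ScalarType::Long: { using index_t = int64_t; /* lambda body */ break; }
      default: TORCH_CHECK(false, "name", ": unsupported index type");
    }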

View File

@@ -10,9 +10,7 @@
 #include <c10/macros/Macros.h>
-#if CUB_SUPPORTS_UNIQUE_BY_KEY()
 #include <thrust/iterator/counting_iterator.h>
-#endif
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
     partials_per_segment_offset[num_of_segments-1];
 }
-#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
-__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
-  *num_of_segments_ptr = num_of_segments;
-}
-#endif
 } // anon namespace
-#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
-template<typename index_t>
-int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
-#endif
 Tensor embedding_backward_cuda_kernel(
     const Tensor &grad,
@@ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel(
   auto segment_offsets = at::empty({numel}, orig_indices.options());
   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
-#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
-  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
-    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
-    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  });
-#else
   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
     cuda::cub::unique_by_key(
       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
       segment_offsets.mutable_data_ptr<index_t>(),
       num_of_segments_ptr, sorted_indices.numel());
   });
-#endif
   int64_t max_segments = std::min<int64_t>(numel, num_weights);
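For intuition: with a counting iterator as the values, `unique_by_key` keeps the value paired with the first element of each run of equal keys, i.e. the offset where each segment of equal indices begins. A worked trace, using the sample sequence from the deleted helper's comments:

    sorted_indices:   2 5 5 5 7 7 8 9 9
    counting values:  0 1 2 3 4 5 6 7 8
    segment_offsets:  0 1 4 6 7          (first position of 2, 5, 7, 8, 9)
    num_of_segments:  5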

View File

@@ -31,16 +31,10 @@
 #include <c10/macros/Macros.h>
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 #include <thrust/iterator/reverse_iterator.h>
-#endif
 namespace at::native {
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-template<typename index_t>
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-#endif
 namespace {
@@ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
   if (scale_grad_by_freq) {
     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-#if CUB_SUPPORTS_SCAN_BY_KEY()
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
         num_indices
       );
     });
-#else
-    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
-      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
-    });
-#endif
   }
   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
     count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,

View File

@@ -1,90 +0,0 @@
-#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
-#include <ATen/native/cuda/SortingCommon.cuh>
-#include <ATen/cuda/cub_definitions.cuh>
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#else
-#include <ATen/ops/empty_like.h>
-#endif
-#include <ATen/cuda/ThrustAllocator.h>
-#include <thrust/device_ptr.h>
-#include <thrust/execution_policy.h>
-#include <thrust/sort.h>
-#include <thrust/unique.h>
-#include <thrust/device_ptr.h>
-#include <thrust/iterator/constant_iterator.h>
-namespace at::native {
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-template<typename index_t>
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  at::cuda::ThrustAllocator allocator;
-  auto policy = thrust::cuda::par(allocator).on(stream);
-  auto num_indices = count.numel();
-  // Compute an increasing sequence per unique item in sortedIndices:
-  // sorted: 2 5 5 5 7 7 8 9 9
-  // count:  1 1 2 3 1 2 1 1 2
-  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
-  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
-  thrust::inclusive_scan_by_key(
-    policy,
-    sorted_data,
-    sorted_data + num_indices,
-    thrust::make_constant_iterator(1),
-    count_data
-  );
-  // Take the maximum of each count per unique key in reverse:
-  // sorted: 2 5 5 5 7 7 8 9 9
-  // count:  1 3 3 3 2 2 1 2 2
-  thrust::inclusive_scan_by_key(
-    policy,
-    thrust::make_reverse_iterator(sorted_data + num_indices),
-    thrust::make_reverse_iterator(sorted_data),
-    thrust::make_reverse_iterator(count_data + num_indices),
-    thrust::make_reverse_iterator(count_data + num_indices),
-    thrust::equal_to<index_t>(),
-    thrust::maximum<index_t>()
-  );
-}
-template
-void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
-template
-void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
-#endif
-template<typename index_t>
-int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
-  auto stream = at::cuda::getCurrentCUDAStream();
-  at::cuda::ThrustAllocator allocator;
-  auto policy = thrust::cuda::par(allocator).on(stream);
-  const ptrdiff_t numel = sorted_indices.numel();
-  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
-  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
-  auto ends = thrust::unique_by_key_copy(
-    policy,
-    sorted_indices_dev,
-    sorted_indices_dev + numel,
-    thrust::make_counting_iterator(0),
-    dummy_dev,
-    thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
-  return thrust::get<0>(ends) - dummy_dev;
-}
-template
-int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
-template
-int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);
-} // namespace at::native
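The CUB path that replaces this helper performs the same two scans through the wrappers in the first file above. A minimal sketch, assuming those wrapper signatures (`inclusive_scan_by_key` additionally takes a scan op) and using raw device pointers and a hypothetical function name for brevity:

    #include <ATen/cuda/cub.cuh>
    #include <thrust/functional.h>
    #include <thrust/iterator/constant_iterator.h>
    #include <thrust/iterator/reverse_iterator.h>

    // Sketch: per-element frequency counts over a sorted index array.
    template <typename index_t>
    void count_frequencies(const index_t* sorted, index_t* count, int64_t n) {
      // Pass 1 - occurrence number within each run of equal keys:
      //   sorted: 2 5 5 5 7 7 8 9 9   ->   count: 1 1 2 3 1 2 1 1 2
      at::cuda::cub::inclusive_sum_by_key(
          sorted, thrust::make_constant_iterator<index_t>(1), count, n);
      // Pass 2 - reverse max-scan broadcasts each run's total over the run:
      //   count becomes 1 3 3 3 2 2 1 2 2
      at::cuda::cub::inclusive_scan_by_key(
          thrust::make_reverse_iterator(sorted + n),
          thrust::make_reverse_iterator(count + n),
          thrust::make_reverse_iterator(count + n),
          thrust::maximum<index_t>(),
          n);
    }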

View File

@@ -19,7 +19,6 @@
 namespace at::native {
-// TODO: remove this when CUDA <11.6 is no longer supported
 void topk_out_with_sort(
     const Tensor& self,
     int64_t k, int64_t dim, bool largest,
@@ -31,21 +30,12 @@ void topk_out_with_sort(
   indices.copy_(sorted_indices.narrow(dim, 0, k));
 }
-// TODO: remove this when CUDA <11.6 is no longer supported
-bool disable_sort_for_topk();
 bool should_use_sort(const Tensor& self, int64_t dim) {
 #if defined(USE_ROCM)
   if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972
   return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387
 #else
-  if (disable_sort_for_topk()) return false;
-  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
-  if (self.dim() == 0) return false;
-  if (self.dtype() == kBool) return false; // Bool is not support by topk
-  int64_t slice_size = self.size(dim);
-  if (slice_size == 0) return false;
-  int64_t num_slices = self.numel() / slice_size;
-  return num_slices <= 10 && slice_size >= 100000;
+  return false;
 #endif
 }
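The deleted non-ROCm body was already dead code: `disable_sort_for_topk()` returned `CUB_SUPPORTS_SCAN_BY_KEY()`, which is now always true, so the first `if` always fired and the heuristic below it was unreachable; the branch constant-folds to `return false`.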

View File

@@ -21,11 +21,6 @@ using namespace at::native;
 namespace at::native {
-// TODO: remove this when CUDA <11.6 is no longer supported
-bool disable_sort_for_topk() {
-  return CUB_SUPPORTS_SCAN_BY_KEY();
-}
 namespace sbtopk { // single_block_topk
 template <typename T>
@@ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts(
   }
   __syncthreads();
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-  return;
-#endif
   Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);
   // if largest, then only threads that has tidx > desired_digit are active
@@ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts(
   }
 }
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 // Assumption: slice_size can not be larger than UINT32_MAX
 template <typename Bitwise>
 __global__ void computeBlockwiseKthCounts(
@@ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu
     }
   }
 }
-#endif
 int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
   // occupancy of this kernel is limited by registers per threads
@@ -687,16 +676,12 @@ void launch(
   uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get());
   AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));
-#if CUB_SUPPORTS_SCAN_BY_KEY()
   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
   uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
   AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));
   auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
   uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
-#else
-  uint32_t* withinKCounts = nullptr;
-#endif
   Bitwise desiredMask = 0;
   dim3 grid;
@@ -743,7 +728,6 @@ void launch(
   }
   desired = desired_in;
-#if CUB_SUPPORTS_SCAN_BY_KEY()
   computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
     desired, counts, num_blocks, blocks_per_slice, kthCounts);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -759,28 +743,6 @@ void launch(
     topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
     blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-#else
-  // Find topk values based on kth values
-  {
-    dim3 grid;
-    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
-    int warp_size = at::cuda::warp_size();
-    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
-    sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
-      input,
-      inputSliceSize,
-      outputSliceSize,
-      largest,
-      numInputSlices,
-      inputWithinSliceStride,
-      topK,
-      topKWithinSliceStride,
-      indices,
-      indicesWithinSliceStride,
-      kthValues);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  }
-#endif
 }
 } // namespace mbtopk
@@ -788,7 +750,6 @@ void launch(
 bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
   if (num_slices > std::numeric_limits<uint32_t>::max() ||
     slice_size > std::numeric_limits<uint32_t>::max()) return false;
-#if CUB_SUPPORTS_SCAN_BY_KEY()
   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
   return (num_slices <= 20 && slice_size >= 20000) ||
     (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
@@ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
     (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
     (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
     (num_slices > 4000 && slice_size >= 400);
-#else
-  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
-  return (num_slices <= 400 && slice_size >= 5000) ||
-    (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
-    (num_slices >= 4000 && slice_size >= 300);
-#endif
 }
 void launch_gather_topk_kernel(
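A worked reading of the surviving heuristic: `num_slices = 30, slice_size = 15000` matches the `num_slices > 20 && num_slices <= 40 && slice_size >= 10000` clause, so the multi-block kernel is chosen, while `num_slices = 30, slice_size = 5000` matches no clause (the remaining clauses all require `num_slices > 40`) and falls back to the single-block path.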