From 35e51893bd2ee2966503ed5f426e2323328a9a0b Mon Sep 17 00:00:00 2001
From: Yuanyuan Chen
Date: Sat, 18 Oct 2025 20:05:50 +0000
Subject: [PATCH] Remove CUDA 11 workarounds for CUB_SUPPORTS_SCAN_BY_KEY and CUB_SUPPORTS_UNIQUE_BY_KEY (#164637)

`CUB_SUPPORTS_SCAN_BY_KEY` and `CUB_SUPPORTS_UNIQUE_BY_KEY` are true since CUDA 12.
This PR removes the old branches and source files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164637
Approved by: https://github.com/ezyang
---
 aten/src/ATen/cuda/cub.cuh                    |  4 -
 aten/src/ATen/cuda/cub_definitions.cuh        | 16 ----
 aten/src/ATen/native/cuda/Embedding.cu        | 12 ---
 .../native/cuda/EmbeddingBackwardKernel.cu    | 19 ----
 aten/src/ATen/native/cuda/EmbeddingBag.cu     | 12 ---
 .../ATen/native/cuda/LegacyThrustHelpers.cu   | 90 -------------------
 aten/src/ATen/native/cuda/TensorTopK.cpp      | 12 +--
 aten/src/ATen/native/cuda/TensorTopK.cu       | 45 ----------
 8 files changed, 1 insertion(+), 209 deletions(-)
 delete mode 100644 aten/src/ATen/native/cuda/LegacyThrustHelpers.cu

diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh
index 23a3ff8c8958..7828c3917fc4 100644
--- a/aten/src/ATen/cuda/cub.cuh
+++ b/aten/src/ATen/cuda/cub.cuh
@@ -177,7 +177,6 @@ inline void segmented_sort_pairs(
   }
 }
 
-#if CUB_SUPPORTS_UNIQUE_BY_KEY()
 template
 inline void unique_by_key(
   KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@@ -193,7 +192,6 @@ inline void unique_by_key(
   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
 }
-#endif
 
 namespace impl {
 
@@ -579,7 +577,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 #endif
 }
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 template
 inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@@ -607,7 +604,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
 #endif
 }
 
-#endif
 
 template
 void unique(InputIteratorT input, OutputIteratorT output,
diff --git a/aten/src/ATen/cuda/cub_definitions.cuh b/aten/src/ATen/cuda/cub_definitions.cuh
index b80951269209..0d76ae6e8dcf 100644
--- a/aten/src/ATen/cuda/cub_definitions.cuh
+++ b/aten/src/ATen/cuda/cub_definitions.cuh
@@ -28,22 +28,6 @@
 #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
 #endif
 
-// cub support for UniqueByKey is added to cub 1.16 in:
-// https://github.com/NVIDIA/cub/pull/405
-#if CUB_VERSION >= 101600
-#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
-#else
-#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
-#endif
-
-// cub support for scan by key is added to cub 1.15
-// in https://github.com/NVIDIA/cub/pull/376
-#if CUB_VERSION >= 101500
-#define CUB_SUPPORTS_SCAN_BY_KEY() 1
-#else
-#define CUB_SUPPORTS_SCAN_BY_KEY() 0
-#endif
-
 // cub support for cub::FutureValue is added to cub 1.15 in:
 // https://github.com/NVIDIA/cub/pull/305
 #if CUB_VERSION >= 101500
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index adc300a5a9ef..65b0e1441de7 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -15,9 +15,7 @@
 #include
 #include
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 #include
-#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -240,10 +238,6 @@ __global__ void renorm_kernel(
 
 } // anonymous namespace
 
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-#endif
 
 Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
                                      int64_t num_weights, int64_t padding_idx,
@@ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
   if (scale_grad_by_freq) {
     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
         num_indices
       );
     });
-#else
-    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
-      embedding_dense_backward_cuda_scan(sorted_indices, count);
-    });
-#endif
   }
 
   return embedding_backward_cuda_kernel(grad, orig_indices,
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
index 4f67696bd022..6ce419137345 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -10,9 +10,7 @@
 #include
 
-#if CUB_SUPPORTS_UNIQUE_BY_KEY()
 #include
-#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
     partials_per_segment_offset[num_of_segments-1];
 }
 
-#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
-__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
-  *num_of_segments_ptr = num_of_segments;
-}
-#endif
 
 } // anon namespace
 
-#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
-template
-int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
-#endif
 
 Tensor embedding_backward_cuda_kernel(
     const Tensor &grad,
@@ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel(
   auto segment_offsets = at::empty({numel}, orig_indices.options());
   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr();
-#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
-  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
-    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key(sorted_indices, segment_offsets);
-    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  });
-#else
   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
     cuda::cub::unique_by_key(
       sorted_indices.const_data_ptr(), thrust::make_counting_iterator(0),
       segment_offsets.mutable_data_ptr(), num_of_segments_ptr, sorted_indices.numel());
   });
-#endif
 
   int64_t max_segments = std::min(numel, num_weights);
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index fb92c7488a15..ab3747df031e 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -31,16 +31,10 @@
 #include
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 #include
-#endif
 
 namespace at::native {
 
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-#endif
 
 namespace {
 
@@ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
   if (scale_grad_by_freq) {
     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
         num_indices
       );
     });
-#else
-    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
-      embedding_dense_backward_cuda_scan(sorted_indices, count);
-    });
-#endif
   }
 
   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, count,
                                         num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,
diff --git a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
deleted file mode 100644
index 6a549ac3d62c..000000000000
--- a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
+++ /dev/null
@@ -1,90 +0,0 @@
-#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include
-#include
-#include
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include
-#else
-#include
-#endif
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace at::native {
-
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  at::cuda::ThrustAllocator allocator;
-  auto policy = thrust::cuda::par(allocator).on(stream);
-
-  auto num_indices = count.numel();
-
-  // Compute an increasing sequence per unique item in sortedIndices:
-  // sorted: 2 5 5 5 7 7 8 9 9
-  // count:  1 1 2 3 1 2 1 1 2
-  auto sorted_data = thrust::device_ptr(sorted_indices.const_data_ptr());
-  auto count_data = thrust::device_ptr(count.mutable_data_ptr());
-  thrust::inclusive_scan_by_key(
-    policy,
-    sorted_data,
-    sorted_data + num_indices,
-    thrust::make_constant_iterator(1),
-    count_data
-  );
-
-  // Take the maximum of each count per unique key in reverse:
-  // sorted: 2 5 5 5 7 7 8 9 9
-  // count:  1 3 3 3 2 2 1 2 2
-  thrust::inclusive_scan_by_key(
-    policy,
-    thrust::make_reverse_iterator(sorted_data + num_indices),
-    thrust::make_reverse_iterator(sorted_data),
-    thrust::make_reverse_iterator(count_data + num_indices),
-    thrust::make_reverse_iterator(count_data + num_indices),
-    thrust::equal_to(),
-    thrust::maximum()
-  );
-}
-
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-
-#endif
-
-template
-int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
-  auto stream = at::cuda::getCurrentCUDAStream();
-  at::cuda::ThrustAllocator allocator;
-  auto policy = thrust::cuda::par(allocator).on(stream);
-  const ptrdiff_t numel = sorted_indices.numel();
-  auto sorted_indices_dev = thrust::device_ptr(sorted_indices.const_data_ptr());
-  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  auto dummy_dev = thrust::device_ptr(dummy.mutable_data_ptr());
-  auto ends = thrust::unique_by_key_copy(
-    policy,
-    sorted_indices_dev,
-    sorted_indices_dev + numel,
-    thrust::make_counting_iterator(0),
-    dummy_dev,
-    thrust::device_ptr(segment_offsets.mutable_data_ptr()));
-  return thrust::get<0>(ends) - dummy_dev;
-}
-
-template
-int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
-template
-int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
-
-} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorTopK.cpp b/aten/src/ATen/native/cuda/TensorTopK.cpp
index f47e7a887ebe..bc609f829a26 100644
--- a/aten/src/ATen/native/cuda/TensorTopK.cpp
+++ b/aten/src/ATen/native/cuda/TensorTopK.cpp
@@ -19,7 +19,6 @@
 
 namespace at::native {
 
-// TODO: remove this when CUDA <11.6 is no longer supported
 void topk_out_with_sort(
   const Tensor& self,
   int64_t k, int64_t dim, bool largest,
@@ -31,21 +30,12 @@ void topk_out_with_sort(
   indices.copy_(sorted_indices.narrow(dim, 0, k));
 }
 
-// TODO: remove this when CUDA <11.6 is no longer supported
-bool disable_sort_for_topk();
 bool should_use_sort(const Tensor& self, int64_t dim) {
 #if defined(USE_ROCM)
   if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972
   return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387
 #else
-  if (disable_sort_for_topk()) return false;
-  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
-  if (self.dim() == 0) return false;
-  if (self.dtype() == kBool) return false; // Bool is not support by topk
-  int64_t slice_size = self.size(dim);
-  if (slice_size == 0) return false;
-  int64_t num_slices = self.numel() / slice_size;
-  return num_slices <= 10 && slice_size >= 100000;
+  return false;
 #endif
 }
 
diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu
index 3f57281ebf56..d95d85bf0237 100644
--- a/aten/src/ATen/native/cuda/TensorTopK.cu
+++ b/aten/src/ATen/native/cuda/TensorTopK.cu
@@ -21,11 +21,6 @@ using namespace at::native;
 
 namespace at::native {
 
-// TODO: remove this when CUDA <11.6 is no longer supported
-bool disable_sort_for_topk() {
-  return CUB_SUPPORTS_SCAN_BY_KEY();
-}
-
 namespace sbtopk { // single_block_topk
 
 template
@@ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts(
   }
   __syncthreads();
 
-#if !CUB_SUPPORTS_SCAN_BY_KEY()
-  return;
-#endif
-
   Bitwise desired_digit = at::cuda::Bitfield::getBitfield(desired, current_bit, RADIX_BITS);
   // if largest, then only threads that has tidx > desired_digit are active
@@ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts(
   }
 }
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
 // Assumption: slice_size can not be larger than UINT32_MAX
 template
 __global__ void computeBlockwiseKthCounts(
@@ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo inpu
     }
   }
 }
-#endif
 
 int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
   // occupancy of this kernel is limited by registers per threads
@@ -687,16 +676,12 @@ void launch(
   uint32_t* digit_cum_sum = reinterpret_cast(digit_cum_sum_buffer.get());
   AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
   uint32_t* withinKCounts = reinterpret_cast(withinKCounts_buffer.get());
   AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));
 
   auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
   uint32_t* kthCounts = reinterpret_cast(kthCounts_buffer.get());
-#else
-  uint32_t* withinKCounts = nullptr;
-#endif
 
   Bitwise desiredMask = 0;
   dim3 grid;
@@ -743,7 +728,6 @@ void launch(
     }
     desired = desired_in;
 
-#if CUB_SUPPORTS_SCAN_BY_KEY()
   computeBlockwiseKthCounts<<>>(
     desired, counts, num_blocks, blocks_per_slice, kthCounts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -759,28 +743,6 @@ void launch(
     topK, topKWithinSliceStride, indices, indicesWithinSliceStride,
     items_per_thread, blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-#else
-  // Find topk values based on kth values
-  {
-    dim3 grid;
-    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
-    int warp_size = at::cuda::warp_size();
-    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
-    sbtopk::gatherTopK<<>>(
-      input,
-      inputSliceSize,
-      outputSliceSize,
-      largest,
-      numInputSlices,
-      inputWithinSliceStride,
-      topK,
-      topKWithinSliceStride,
-      indices,
-      indicesWithinSliceStride,
-      kthValues);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  }
-#endif
 }
 
 } // namespace mbtopk
 
@@ -788,7 +750,6 @@ void launch(
 bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
   if (num_slices > std::numeric_limits::max() ||
       slice_size > std::numeric_limits::max()) return false;
-#if CUB_SUPPORTS_SCAN_BY_KEY()
   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
   return (num_slices <= 20 && slice_size >= 20000) ||
          (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
@@ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
          (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
          (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
          (num_slices > 4000 && slice_size >= 400);
-#else
-  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
-  return (num_slices <= 400 && slice_size >= 5000) ||
-         (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
-         (num_slices >= 4000 && slice_size >= 300);
-#endif
 }
 
 void launch_gather_topk_kernel(
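
Editor's note: the retained branches above rely on cub::DeviceSelect::UniqueByKey (through the at::cuda::cub::unique_by_key wrapper), which is always available in the CUB bundled with CUDA 12; that is why the Thrust fallback in LegacyThrustHelpers.cu can be deleted. Below is a minimal standalone sketch, not part of the patch, of the same "segment offsets from sorted indices" computation. It assumes an nvcc from a CUDA 12 toolkit; the toy input and all variable names are illustrative only.

// Illustrative sketch only -- not part of the PR. Build: nvcc -std=c++17 unique_by_key_demo.cu
#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <cstdio>
#include <vector>

int main() {
  // Sorted embedding indices, as the backward pass would produce after its radix sort.
  const std::vector<int64_t> h_sorted = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::device_vector<int64_t> sorted_indices(h_sorted.begin(), h_sorted.end());
  const int num_items = static_cast<int>(sorted_indices.size());

  thrust::device_vector<int64_t> unique_keys(num_items);      // one entry per distinct index
  thrust::device_vector<int64_t> segment_offsets(num_items);  // start position of each run
  thrust::device_vector<int64_t> num_segments(1);

  // Positions 0..n-1 are the "values"; UniqueByKey keeps the value paired with the
  // first element of every run of equal keys, i.e. the offset where each segment starts.
  cub::CountingInputIterator<int64_t> positions(0);

  // CUB's usual two-phase pattern: query temporary storage size, then run for real.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSelect::UniqueByKey(d_temp, temp_bytes,
                                 sorted_indices.data().get(), positions,
                                 unique_keys.data().get(), segment_offsets.data().get(),
                                 num_segments.data().get(), num_items);
  thrust::device_vector<unsigned char> temp(temp_bytes);
  cub::DeviceSelect::UniqueByKey(temp.data().get(), temp_bytes,
                                 sorted_indices.data().get(), positions,
                                 unique_keys.data().get(), segment_offsets.data().get(),
                                 num_segments.data().get(), num_items);

  const int64_t n = num_segments[0];  // copies the segment count back to the host
  for (int64_t i = 0; i < n; ++i) {
    std::printf("segment %lld starts at offset %lld\n",
                static_cast<long long>(i),
                static_cast<long long>(static_cast<int64_t>(segment_offsets[i])));
  }
  return 0;  // expected offsets for the toy input: 0 1 4 6 7
}

For the toy input this prints the start offset of each run of equal indices (0, 1, 4, 6, 7), which is the role segment_offsets plays in embedding_backward_cuda_kernel; the wrapper in cub.cuh packages the same call, with the two-phase temporary-storage handling hidden behind the CUB_WRAPPER macro shown in the diff.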