diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh
index daa7f311ff23..21e21cacfaa2 100644
--- a/aten/src/ATen/cuda/cub.cuh
+++ b/aten/src/ATen/cuda/cub.cuh
@@ -7,6 +7,7 @@
 #include <...>
 #include <...>
+#include <...>
 
 #if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
 
@@ -291,6 +292,176 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 #endif
 }
 
+# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
+
+template <typename T>
+struct BlockPrefixCallbackOp
+{
+  public:
+  T running_total;
+
+  __host__ __device__ BlockPrefixCallbackOp(T running_total) : running_total(running_total) {}
+
+  // Callback operator to be entered by the first warp of threads in the block.
+  // Thread-0 is responsible for returning a value for seeding the block-wide scan.
+  __host__ __device__ T operator()(T block_aggregate)
+  {
+    T old_prefix = running_total;
+    running_total += block_aggregate;
+    return old_prefix;
+  }
+};
+
+template <int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
+__global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) {
+  if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
+  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+  d_out += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+
+  // Specialize BlockLoad type for our thread block (uses warp-striped loads for coalescing,
+  // then transposes in shared memory to a blocked arrangement)
+  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+      ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
+
+  // Specialize BlockStore type for our thread block (transposes in shared memory from a blocked
+  // arrangement, then uses warp-striped stores for coalescing)
+  using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+      ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_STORE_WARP_TRANSPOSE>;
+
+  // Specialize BlockScan and BlockReduce types for our thread block
+  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<T, BLOCK_THREADS>;
+  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;
+
+  // Shared memory
+  __shared__ union TempStorage
+  {
+    typename BlockLoadT::TempStorage load;
+    typename BlockStoreT::TempStorage store;
+    typename BlockScanT::TempStorage scan;
+    typename BlockReduceT::TempStorage reduce;
+  } temp_storage;
+
+  // load agg and reduce my starting value
+  T agg_data;
+  agg_data = threadIdx.x >= blockIdx.x ? T(0) : agg[threadIdx.x];
+  T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data);
+  __syncthreads();
+  BlockPrefixCallbackOp<T> prefix_op(aggregate);
+
+  // Per-thread tile data
+  T data[ITEMS_PER_THREAD];
+
+  int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+  for (int i = 0; i < iters_per_cta; i++) {
+    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
+      BlockLoadT(temp_storage.load).Load(d_in, data);
+    } else {
+      #pragma unroll
+      for (int j = 0; j < ITEMS_PER_THREAD; j++) {
+        data[j] = T(0);
+      }
+      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
+    }
+    __syncthreads();
+    BlockScanT(temp_storage.scan).InclusiveSum(data, data, prefix_op);
+    __syncthreads();
+    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
+      BlockStoreT(temp_storage.store).Store(d_out, data);
+    } else {
+      BlockStoreT(temp_storage.store).Store(d_out, data, remaining);
+    }
+    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
+    d_out += BLOCK_THREADS * ITEMS_PER_THREAD;
+    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
+    if (remaining <= 0) return;
+    __syncthreads();
+  }
+}
+
+template <int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
+__global__ void calc_block_sums(const T* d_in, T* agg, int64_t nelem, int iters_per_cta) {
+  if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
+  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+
+  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+      ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
+  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;
+
+  // Shared memory
+  __shared__ union TempStorage
+  {
+    typename BlockLoadT::TempStorage load;
+    typename BlockReduceT::TempStorage reduce;
+  } temp_storage;
+
+  T data[ITEMS_PER_THREAD];
+  T agg_val = 0;
+  int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+  for (int i = 0; i < iters_per_cta; i++) {
+    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
+      BlockLoadT(temp_storage.load).Load(d_in, data);
+      __syncthreads();
+      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
+    } else {
+      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
+      __syncthreads();
+      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
+    }
+    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
+    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
+    if (remaining <= 0) return;
+    __syncthreads();
+  }
+  if (threadIdx.x == 0) {
+    agg[blockIdx.x] = agg_val;
+  }
+}
+
+template <int size>
+constexpr int block_threads() {
+  if constexpr (size >= 16) {
+    return 128;
+  } else if constexpr (size >= 8) {
+    return 256;
+  } else {
+    return 512;
+  }
+}
+
+template <typename scalar_t, typename ScanOpT>
+inline void inclusive_deterministic_scan(const scalar_t* input, scalar_t* output, ScanOpT scan_op, int64_t num_items) {
+  static_assert(std::is_same<ScanOpT, std::plus<scalar_t>>::value, "");
+  constexpr int BLOCK_THREADS = block_threads<sizeof(scalar_t)>();
+  constexpr int ITEMS_PER_THREAD = 16;
+  auto grid_size = (num_items + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
+  const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
+  const int iters_per_cta = (grid_size + num_sms - 1) / num_sms;
+  grid_size = std::min(num_sms, grid_size);
+  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+  auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
+  calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD>
+      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+      input, (scalar_t*)agg.get(), num_items, iters_per_cta);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  final_scan_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
+      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+      input, output, (scalar_t*)agg.get(), num_items, iters_per_cta);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+#endif
+
 template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
 inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
 #if defined(USE_ROCM)
diff --git a/aten/src/ATen/native/cuda/ScanKernels.cpp b/aten/src/ATen/native/cuda/ScanKernels.cpp
index 3f89c022e3c1..7db58d474c00 100644
--- a/aten/src/ATen/native/cuda/ScanKernels.cpp
+++ b/aten/src/ATen/native/cuda/ScanKernels.cpp
@@ -89,11 +89,6 @@ Tensor _logcumsumexp_cuda(const Tensor& self, int64_t dim) {
 }
 
 void cumsum_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
-  if (self.is_floating_point() || self.is_complex()) {
-    // See Note [Writing Nondeterministic Operations]
-    // Issue reporting nondeterministic behavior: https://github.com/pytorch/pytorch/issues/75240
-    globalContext().alertNotDeterministic("cumsum_cuda_kernel");
-  }
   auto result_ = contiguous_out_arg(result);
   launch_cumsum_cuda_kernel(*result_, self, dim);
   if (!result.is_same(*result_)) {
diff --git a/aten/src/ATen/native/cuda/ScanUtils.cuh b/aten/src/ATen/native/cuda/ScanUtils.cuh
index 88cfa15abf60..1a9d12a753a8 100644
--- a/aten/src/ATen/native/cuda/ScanUtils.cuh
+++ b/aten/src/ATen/native/cuda/ScanUtils.cuh
@@ -447,7 +447,20 @@ void scan_dim(const TensorBase& self, const TensorBase& result,
   TORCH_INTERNAL_ASSERT(result.is_contiguous());
 
   if (self.numel() == self.size(dim)) {
-    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
+    if constexpr (std::is_same<BinaryFunction, std::plus<scalar_t>>::value) {
+      if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms()) && (self.is_floating_point() || self.is_complex())) {
+# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
+        cuda::cub::inclusive_deterministic_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
+#else
+        globalContext().alertNotDeterministic("cumsum_cuda_kernel");
+        cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
+#endif
+      } else {
+        cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
+      }
+    } else {
+      cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
+    }
   } else if (dim == ndim - 1) {
     scan_innermost_dim(*self_, result, init, binary_op);
   } else {
diff --git a/test/test_torch.py b/test/test_torch.py
index c2ec2aea19f6..2ab0e85df7f5 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1740,17 +1740,29 @@ else:
                                             'embedding_bag_backward_cuda_max',
                                             torch.device(device).type == 'cuda')
 
-    @dtypes(*all_types_and_complex_and(torch.bool))
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
-    def test_nondeterministic_alert_cumsum(self, device, dtype):
-        input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
-        should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)
+    @onlyCUDA
+    def test_deterministic_cumsum(self, device):
+        test_cases = [
+            # size, dim
+            [(1025,), 0],
+            [(8193,), 0],
+            [(8191,), 0],
+            [(128256,), 0],
+            [(1282560,), 0],
+            [(12825600,), 0],
+        ]
+        for size, dim in test_cases:
+            input = 100 * torch.rand(*size, device=device)
+            with DeterministicGuard(True):
+                res0 = input.cumsum(dim)
+                for _ in range(3):
+                    res1 = input.cumsum(dim)
+                    self.assertEqual(res0, res1, atol=0, rtol=0)
+
+            res_cpu = input.cpu().cumsum(dim)
+            self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2)
 
-        for op_call in [torch.Tensor.cumsum, torch.cumsum]:
-            self.check_nondeterministic_alert(
-                lambda: op_call(input, 0),
-                'cumsum_cuda_kernel',
-                should_alert)
 
     @expectedFailureMeta  # expected a non-determinitic error, but it was not raised
     @onlyNativeDeviceTypes
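A minimal usage sketch (not part of the patch), illustrating the behavior the new test checks: with deterministic algorithms enabled, repeated CUDA cumsum calls on the same input should return bitwise-identical results, and the result should stay close to the CPU reference. The tensor size and tolerances below are illustrative, mirroring the test above.

    import torch

    # With this change, cumsum on CUDA no longer raises a nondeterminism error;
    # the full-dimension (flat) case takes the new deterministic scan path.
    torch.use_deterministic_algorithms(True)

    x = 100 * torch.rand(12825600, device="cuda")
    ref = x.cumsum(0)
    for _ in range(3):
        assert torch.equal(ref, x.cumsum(0))  # bitwise identical across runs

    # Loose comparison against the CPU result, matching the tolerances used in the test
    torch.testing.assert_close(ref.cpu(), x.cpu().cumsum(0), atol=1e-3, rtol=1e-2)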