Implement deterministic scan (#140887)

Fixes #89492
Uses block-wise cub primitives
On large inputs, this implementation is approximately 25% slower than the device-wide cub implementation, so it is enabled only in the cases where cub would have been used (floating-point inputs, cumsum that is effectively 1d).
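A minimal usage sketch of the new path (assumes a build containing this change and CUDA newer than 11.4; the size is arbitrary):

import torch

# Deterministic mode is what routes effectively-1d floating-point cumsum to the block-wise scan
torch.use_deterministic_algorithms(True)

x = torch.rand(1_000_000, device="cuda")
ref = x.cumsum(0)
for _ in range(3):
    # run-to-run results are now bitwise identical
    assert torch.equal(ref, x.cumsum(0))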

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140887
Approved by: https://github.com/ezyang, https://github.com/kurtamohler
Authored by Natalia Gimelshein on 2024-11-19 23:43:24 +00:00, committed by PyTorch MergeBot
parent 6ccd35ccb8
commit 0443398f5b
4 changed files with 206 additions and 15 deletions

View File

@@ -7,6 +7,7 @@
#include <limits>
#include <ATen/cuda/cub_definitions.cuh>
#include <ATen/cuda/CUDAContextLight.h>
#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
@@ -291,6 +292,176 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
#endif
}
#if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)

template<typename T>
struct BlockPrefixCallbackOp
{
public:
  T running_total;

  __host__ __device__ BlockPrefixCallbackOp(T running_total) : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the block.
  // Thread-0 is responsible for returning a value for seeding the block-wide scan.
  __host__ __device__ T operator()(T block_aggregate)
  {
    T old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) {
  if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  d_out += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  // Specialize BlockLoad type for our thread block (uses warp-striped loads for coalescing, then transposes in shared
  // memory to a blocked arrangement)
  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
  // Specialize BlockStore type for our thread block (transposes from a blocked arrangement in shared
  // memory, then uses warp-striped stores for coalescing)
  using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_STORE_WARP_TRANSPOSE>;
  // Specialize BlockScan type for our thread block
  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<T, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;

  // Shared memory
  __shared__ union TempStorage
  {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
    typename BlockReduceT::TempStorage reduce;
  } temp_storage;

  // Load the aggregates of the preceding blocks and reduce them into this block's starting prefix
  T agg_data;
  agg_data = threadIdx.x >= blockIdx.x ? T(0) : agg[threadIdx.x];
  T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data);
  __syncthreads();
  BlockPrefixCallbackOp prefix_op(aggregate);

  // Per-thread tile data
  T data[ITEMS_PER_THREAD];
  int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  for (int i = 0; i < iters_per_cta; i++) {
    // Load items into a blocked arrangement
    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
      BlockLoadT(temp_storage.load).Load(d_in, data);
    } else {
      #pragma unroll
      for (int j = 0; j < ITEMS_PER_THREAD; j++) {
        data[j] = 0;
      }
      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
    }
    // Barrier for smem reuse
    __syncthreads();
    // Compute inclusive prefix sum
    BlockScanT(temp_storage.scan).InclusiveSum(data, data, prefix_op);
    // Barrier for smem reuse
    __syncthreads();
    // Store items from a blocked arrangement
    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
      BlockStoreT(temp_storage.store).Store(d_out, data);
    } else {
      BlockStoreT(temp_storage.store).Store(d_out, data, remaining);
    }
    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
    d_out += BLOCK_THREADS * ITEMS_PER_THREAD;
    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
    if (remaining <= 0) return;
    __syncthreads();
  }
}

template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void calc_block_sums(const T* d_in, T* agg, int64_t nelem, int iters_per_cta) {
  if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;

  // Shared memory
  __shared__ union TempStorage
  {
    typename BlockLoadT::TempStorage load;
    typename BlockReduceT::TempStorage reduce;
  } temp_storage;

  T data[ITEMS_PER_THREAD];
  T agg_val = 0;
  int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  for (int i = 0; i < iters_per_cta; i++) {
    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
      BlockLoadT(temp_storage.load).Load(d_in, data);
      __syncthreads();
      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
    } else {
      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
      __syncthreads();
      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
    }
    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
    if (remaining <= 0) return;
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    agg[blockIdx.x] = agg_val;
  }
}

template<int size>
constexpr int block_threads() {
  if constexpr (size >= 16) {
    return 128;
  } else if constexpr (size >= 8) {
    return 256;
  } else {
    return 512;
  }
}

template<typename scalar_t, typename ScanOpT>
inline void inclusive_deterministic_scan(const scalar_t* input, scalar_t* output, ScanOpT scan_op, int64_t num_items) {
  static_assert(std::is_same<ScanOpT, std::plus<scalar_t>>::value, "");
  constexpr int BLOCK_THREADS = block_threads<sizeof(scalar_t)>();
  constexpr int ITEMS_PER_THREAD = 16;
  auto grid_size = (num_items + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
  const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  const int iters_per_cta = (grid_size + num_sms - 1) / num_sms;
  grid_size = std::min(num_sms, grid_size);
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
  calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD>
      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
          input, (scalar_t*)agg.get(), num_items, iters_per_cta);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  final_scan_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
          input, output, (scalar_t*)agg.get(), num_items, iters_per_cta);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM)
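The two kernels above implement a two-pass decomposition: calc_block_sums writes one partial sum per CTA, and final_scan_kernel reduces the aggregates of all preceding CTAs into a starting prefix before running a block-wide inclusive scan with a fixed order of additions. A host-side NumPy sketch of the same idea, for illustration only (tile_size stands in for the per-CTA range BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta):

import numpy as np

def deterministic_inclusive_scan(x, tile_size):
    # Pass 1: one aggregate per tile (what calc_block_sums produces)
    tiles = [x[i:i + tile_size] for i in range(0, len(x), tile_size)]
    aggs = [t.sum() for t in tiles]
    out = np.empty_like(x)
    for b, t in enumerate(tiles):
        # Pass 2: seed each tile with the sum of the preceding aggregates,
        # then scan the tile locally (what final_scan_kernel does)
        prefix = sum(aggs[:b])
        out[b * tile_size:b * tile_size + len(t)] = prefix + np.cumsum(t)
    return out

x = np.random.rand(10_000)
np.testing.assert_allclose(deterministic_inclusive_scan(x, 2048), np.cumsum(x))

Because the summation order depends only on the tile geometry, the result is identical from run to run, unlike a decoupled-lookback device scan whose order of floating-point additions can depend on timing.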

View File

@@ -89,11 +89,6 @@ Tensor _logcumsumexp_cuda(const Tensor& self, int64_t dim) {
}

void cumsum_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
  if (self.is_floating_point() || self.is_complex()) {
    // See Note [Writing Nondeterministic Operations]
    // Issue reporting nondeterministic behavior: https://github.com/pytorch/pytorch/issues/75240
    globalContext().alertNotDeterministic("cumsum_cuda_kernel");
  }
  auto result_ = contiguous_out_arg(result);
  launch_cumsum_cuda_kernel(*result_, self, dim);
  if (!result.is_same(*result_)) {

View File

@@ -447,7 +447,20 @@ void scan_dim(const TensorBase& self, const TensorBase& result,
  TORCH_INTERNAL_ASSERT(result.is_contiguous());
  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
    if constexpr (std::is_same<BinaryFunction, std::plus<scalar_t>>::value) {
      if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms()) && (self.is_floating_point() || self.is_complex())) {
#if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
        cuda::cub::inclusive_deterministic_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
#else
        globalContext().alertNotDeterministic("cumsum_cuda_kernel");
        cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
#endif
      } else {
        cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
      }
    } else {
      cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
    }
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
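How the dispatch above plays out from Python with deterministic mode enabled, as a sketch (assumes CUDA newer than 11.4 so the first branch of the #if is taken):

import torch

torch.use_deterministic_algorithms(True)

a = torch.rand(2**20, device="cuda")
a.cumsum(0)    # effectively 1d, floating point: new block-wise deterministic scan

b = torch.randint(0, 9, (2**20,), device="cuda")
b.cumsum(0)    # integer input: existing cub scan (integer addition is associative)

c = torch.rand(512, 2048, device="cuda")
c.cumsum(1)    # not effectively 1d: existing scan_innermost_dim path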

View File

@@ -1740,17 +1740,29 @@ else:
            'embedding_bag_backward_cuda_max',
            torch.device(device).type == 'cuda')

    @dtypes(*all_types_and_complex_and(torch.bool))
    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
    def test_nondeterministic_alert_cumsum(self, device, dtype):
        input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
        should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)

    @onlyCUDA
    def test_deterministic_cumsum(self, device):
        test_cases = [
            # size, dim
            [(1025,), 0],
            [(8193,), 0],
            [(8191,), 0],
            [(128256,), 0],
            [(1282560,), 0],
            [(12825600,), 0],
        ]
        for size, dim in test_cases:
            input = 100 * torch.rand(*size, device=device)
            with DeterministicGuard(True):
                res0 = input.cumsum(dim)
                for _ in range(3):
                    res1 = input.cumsum(dim)
                    self.assertEqual(res0, res1, atol=0, rtol=0)

            res_cpu = input.cpu().cumsum(dim)
            self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2)

        for op_call in [torch.Tensor.cumsum, torch.cumsum]:
            self.check_nondeterministic_alert(
                lambda: op_call(input, 0),
                'cumsum_cuda_kernel',
                should_alert)

    @expectedFailureMeta  # expected a non-deterministic error, but it was not raised
    @onlyNativeDeviceTypes