Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Implement deterministic scan (#140887)
Fixes #89492. Uses block-wise cub primitives. On large inputs, this implementation is approximately 25% slower than the device-wide cub implementation, so it is enabled only in the cases where cub would have been used anyway (floating-point inputs, and a cumsum that is effectively 1d).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140887
Approved by: https://github.com/ezyang, https://github.com/kurtamohler
committed by PyTorch MergeBot
parent 6ccd35ccb8
commit 0443398f5b
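As a quick illustration of what this change enables (a minimal sketch, not part of the commit; it assumes a CUDA device and uses only public PyTorch APIs), a floating-point, effectively-1d cumsum run under deterministic mode now produces bitwise-identical results across runs instead of raising a nondeterministic-operation alert:

import torch

# Sketch only: enable deterministic algorithms, then repeat a 1d floating-point
# cumsum on CUDA and check that the results are bitwise identical across runs.
torch.use_deterministic_algorithms(True)

x = 100 * torch.rand(8193, device="cuda")    # arbitrary size and values
ref = x.cumsum(0)
for _ in range(3):
    assert torch.equal(ref, x.cumsum(0))     # exact, elementwise-equal results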
@@ -7,6 +7,7 @@
#include <limits>

#include <ATen/cuda/cub_definitions.cuh>
#include <ATen/cuda/CUDAContextLight.h>

#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
@@ -291,6 +292,176 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
#endif
}

#if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)

template<typename T>
struct BlockPrefixCallbackOp
{
  public:
    T running_total;

    __host__ __device__ BlockPrefixCallbackOp(T running_total) : running_total(running_total) {}

    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __host__ __device__ T operator()(T block_aggregate)
    {
      T old_prefix = running_total;
      running_total += block_aggregate;
      return old_prefix;
    }
};

template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) {
  if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  d_out += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;

  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;

  // Specialize BlockStore type for our thread block (uses warp-striped loads for coalescing, then transposes in shared
  // memory to a blocked arrangement)
  using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_STORE_WARP_TRANSPOSE>;

  // Specialize BlockScan type for our thread block
  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<T, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;

  // Shared memory
  __shared__ union TempStorage
  {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
    typename BlockReduceT::TempStorage reduce;
  } temp_storage;

  // load agg and reduce my starting value
  T agg_data;
  agg_data = threadIdx.x >= blockIdx.x ? T(0) : agg[threadIdx.x];
  T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data);
  __syncthreads();
  BlockPrefixCallbackOp prefix_op(aggregate);

  // Per-thread tile data
  T data[ITEMS_PER_THREAD];

  int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  for (int i=0; i<iters_per_cta; i++){
    // Load items into a blocked arrangement
    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
      BlockLoadT(temp_storage.load).Load(d_in, data);
    } else {
      #pragma unroll
      for (int j=0; j<ITEMS_PER_THREAD; j++) {
        data[j] = 0;
      }
      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
    }

    // Barrier for smem reuse
    __syncthreads();

    // Compute inclusive prefix sum
    BlockScanT(temp_storage.scan).InclusiveSum(data, data, prefix_op);

    // Barrier for smem reuse
    __syncthreads();

    // Store items from a blocked arrangement
    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
      BlockStoreT(temp_storage.store).Store(d_out, data);
    } else {
      BlockStoreT(temp_storage.store).Store(d_out, data, remaining);
    }
    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
    d_out += BLOCK_THREADS * ITEMS_PER_THREAD;
    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
    if (remaining <= 0) return;
    __syncthreads();
  }

}

template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iters_per_cta){
  if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;

  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;
  // Shared memory
  __shared__ union TempStorage
  {
    typename BlockLoadT::TempStorage load;
    typename BlockReduceT::TempStorage reduce;
  } temp_storage;
  T data[ITEMS_PER_THREAD];
  T agg_val = 0;
  int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
  for (int i=0; i<iters_per_cta; i++){
    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
      BlockLoadT(temp_storage.load).Load(d_in, data);
      __syncthreads();
      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);

    } else {
      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
      __syncthreads();
      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
    }
    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
    if (remaining <= 0) return;
    __syncthreads();

  }
  if (threadIdx.x == 0) {
    agg[blockIdx.x] = agg_val;
  }

}

template<int size>
constexpr int block_threads(){
  if constexpr (size >=16) {
    return 128;
  } else if constexpr (size >=8) {
    return 256;
  } else {
    return 512;
  }
}

template<typename scalar_t, typename ScanOpT>
inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * output, ScanOpT scan_op, int64_t num_items) {
  static_assert(std::is_same<ScanOpT, std::plus<scalar_t>>::value, "");
  constexpr int BLOCK_THREADS = block_threads<sizeof(scalar_t)>();
  constexpr int ITEMS_PER_THREAD = 16;
  auto grid_size = (num_items + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
  const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  const int iters_per_cta = (grid_size + num_sms - 1)/num_sms;
  grid_size = std::min(num_sms, grid_size);
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
  calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD>
    <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
      input, (scalar_t*)agg.get(), num_items, iters_per_cta);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  final_scan_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
    <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
      input, output, (scalar_t*)agg.get(), num_items, iters_per_cta);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

#endif

template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM)
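The two kernels above implement a decoupled two-pass scan: calc_block_sums first writes one partial sum per CTA into agg, and final_scan_kernel then seeds each CTA's block-wide inclusive scan with the sum of the aggregates of all preceding CTAs, so the reduction order is fixed regardless of how blocks are scheduled. A minimal host-side sketch of that data flow (plain NumPy with a hypothetical helper name and arbitrary chunk size, not the GPU code itself):

import numpy as np

def chunked_inclusive_scan(x, chunk):
    # Pass 1 (cf. calc_block_sums): one aggregate per fixed-size chunk.
    chunks = [x[i:i + chunk] for i in range(0, len(x), chunk)]
    agg = [c.sum() for c in chunks]
    # Pass 2 (cf. final_scan_kernel): scan each chunk locally, seeded with
    # the sum of the aggregates of all earlier chunks, in a fixed order.
    out = []
    for b, c in enumerate(chunks):
        out.append(sum(agg[:b]) + np.cumsum(c))
    return np.concatenate(out)

x = np.random.rand(10_000)
assert np.allclose(chunked_inclusive_scan(x, 2048), np.cumsum(x))

For reference, with the constants above a float32 input uses block_threads<4>() = 512 threads and ITEMS_PER_THREAD = 16, i.e. a tile of 8192 elements per iteration; the grid is capped at the device's SM count, with iters_per_cta = ceil(grid_size / num_sms).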
@@ -89,11 +89,6 @@ Tensor _logcumsumexp_cuda(const Tensor& self, int64_t dim) {
}

void cumsum_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
  if (self.is_floating_point() || self.is_complex()) {
    // See Note [Writing Nondeterministic Operations]
    // Issue reporting nondeterministic behavior: https://github.com/pytorch/pytorch/issues/75240
    globalContext().alertNotDeterministic("cumsum_cuda_kernel");
  }
  auto result_ = contiguous_out_arg(result);
  launch_cumsum_cuda_kernel(*result_, self, dim);
  if (!result.is_same(*result_)) {
@@ -447,7 +447,20 @@ void scan_dim(const TensorBase& self, const TensorBase& result,
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
    if constexpr (std::is_same<BinaryFunction, std::plus<scalar_t>>::value) {
      if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms()) && (self.is_floating_point() || self.is_complex())) {
#if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
        cuda::cub::inclusive_deterministic_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
#else
        globalContext().alertNotDeterministic("cumsum_cuda_kernel");
        cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
#endif
      } else {
        cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
      }
    } else {
      cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
    }
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
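The dispatch above routes to the new kernel only in a narrow case, matching the commit message: an effectively-1d scan (self.numel() == self.size(dim)), a plus binary op, deterministic algorithms enabled, and floating-point or complex data, gated on the CUDA/ROCm version; everything else keeps the existing cub path. A small sketch of that decision logic (plain Python, hypothetical names, for illustration only):

def pick_cumsum_path(is_effectively_1d, is_plus_op, deterministic_mode,
                     is_float_or_complex, cub_version_ok):
    # Mirrors the branch structure in scan_dim above (names are illustrative).
    if not is_effectively_1d:
        return "scan_innermost_dim / generic scan"
    if is_plus_op and deterministic_mode and is_float_or_complex:
        return ("inclusive_deterministic_scan" if cub_version_ok
                else "alertNotDeterministic + inclusive_scan")
    return "inclusive_scan"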
@@ -1740,17 +1740,29 @@ else:
                'embedding_bag_backward_cuda_max',
                torch.device(device).type == 'cuda')

    @dtypes(*all_types_and_complex_and(torch.bool))
    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
    def test_nondeterministic_alert_cumsum(self, device, dtype):
        input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
        should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)
    @onlyCUDA
    def test_deterministic_cumsum(self, device):
        test_cases = [
            # size, dim
            [(1025,), 0],
            [(8193,), 0],
            [(8191,), 0],
            [(128256,), 0],
            [(1282560,), 0],
            [(12825600,), 0],
        ]
        for size, dim in test_cases:
            input = 100 * torch.rand(*size, device=device)
            with DeterministicGuard(True):
                res0 = input.cumsum(dim)
                for _ in range(3):
                    res1 = input.cumsum(dim)
                    self.assertEqual(res0, res1, atol=0, rtol=0)

            res_cpu = input.cpu().cumsum(dim)
            self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2)

        for op_call in [torch.Tensor.cumsum, torch.cumsum]:
            self.check_nondeterministic_alert(
                lambda: op_call(input, 0),
                'cumsum_cuda_kernel',
                should_alert)

    @expectedFailureMeta  # expected a non-deterministic error, but it was not raised
    @onlyNativeDeviceTypes