Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Implement nonzero for large inputs (#141592)

Fixes #51871
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141592
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: aa827e319e
Commit: 4ae1c4cbb5
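Before the diff itself, a note on the approach: the cub calls used here take an item count that must fit in a 32-bit int, so the new implementation processes the input in chunks of at most 2^30 elements and accumulates per-chunk nonzero counts on the host. A minimal standalone sketch of that chunking arithmetic (illustrative only; the variable names mirror the diff, but this is not the ATen code):

// Illustrative sketch of the chunking arithmetic used in the diff below.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  const int64_t numel = (int64_t{1} << 32) + 123; // more elements than INT_MAX
  int64_t chunk_size, num_chunks;
  if (numel < std::numeric_limits<int>::max()) {
    chunk_size = numel;
    num_chunks = 1;
  } else {
    chunk_size = std::numeric_limits<int>::max() / 2 + 1; // 2^30
    num_chunks = (numel + chunk_size - 1) / chunk_size;   // ceiling division
  }
  for (int64_t idx = 0; idx < num_chunks; idx++) {
    // Each chunk is small enough to be indexed with a 32-bit int.
    int64_t remaining = std::min(chunk_size, numel - idx * chunk_size);
    std::cout << "chunk " << idx << ": " << remaining << " elements\n";
  }
  return 0;
}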
@@ -1,13 +1,13 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/EmptyTensor.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh> //for MAX_DIMS
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh> //for MAX_DIMS

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
@@ -16,20 +16,18 @@
#include <ATen/ops/nonzero_native.h>
#endif

namespace at::native {

namespace{
template<typename T>
struct NonZeroOp
{
  __host__ __device__ __forceinline__ bool operator()(const T& a) const {
    return (a!=T(0));
  }
namespace {
template <typename T>
struct NonZeroOp {
  __host__ __device__ __forceinline__ bool operator()(const T& a) const {
    return (a != T(0));
  }
};

//TODO: actually support int64_t index_t
template<typename index_t>
// TODO: actually support int64_t index_t
template <typename index_t>
struct TensorDims {
  index_t sizes[MAX_DIMS];
};
@@ -55,86 +53,167 @@ __global__ void write_indices(
  }
}

} //anonymous namespace
} // anonymous namespace

template<typename scalar_t>
void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
template <typename scalar_t>
void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
  Tensor self_ = self.contiguous();
  int N = self_.numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // compute number of nonzero elements
  size_t temp_storage_bytes=0;
  int64_t chunk_size, num_chunks;
  if (self.numel() < std::numeric_limits<int>::max()) {
    chunk_size = self.numel();
    num_chunks = 1;
  } else {
    chunk_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
    num_chunks = (self.numel() + chunk_size - 1) / chunk_size;
  }
  // compute number of nonzero elements
  size_t temp_storage_bytes = 0;
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto num_nonzeros = allocator.allocate(sizeof(int));
  cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(self_.const_data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
  cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
  auto temp_storage = allocator.allocate(temp_storage_bytes);
  cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
  int num_nonzeros_h;
  auto num_nonzeros = allocator.allocate(sizeof(int) * num_chunks);
  for (int64_t idx = 0; idx < num_chunks; idx++) {
    int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
    cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(
        self_.const_data_ptr<scalar_t>() + idx * chunk_size,
        NonZeroOp<scalar_t>());
    cub::DeviceReduce::Sum(
        nullptr,
        temp_storage_bytes,
        itr,
        ((int*)num_nonzeros.get()) + idx,
        remaining,
        stream);
    auto temp_storage = allocator.allocate(temp_storage_bytes);
    cub::DeviceReduce::Sum(
        temp_storage.get(),
        temp_storage_bytes,
        itr,
        ((int*)num_nonzeros.get()) + idx,
        remaining,
        stream);
  }
  auto pinned_num_nonzeros_h = at::detail::empty_cpu(
      {1}, /* size */
      c10::CppTypeToScalarType<int>(), /* dtype */
      std::nullopt, /* layout */
      std::nullopt, /* device */
      true, /* pin_memory */
      std::nullopt /* memory format */
  );
  at::cuda::memcpy_and_sync((void *)pinned_num_nonzeros_h.const_data_ptr<int>(), num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
  num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr<int>());
  //expected output size is num_nonzeros x ndim
  //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output)
  //we are able to directly use passed output with this size and strides, and we can also (per contract)
  //resize passed output with incorrect sizes anyway we want.
  //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced.
  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
  at::Tensor out_temp = need_to_copy ?
      Tensor(at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options())) :
      out.resize_({self.dim(), num_nonzeros_h});
  //Scalars are expected to produce output of size (1,0), so we can't write to it
      {num_chunks}, /* size */
      c10::CppTypeToScalarType<int>(), /* dtype */
      std::nullopt, /* layout */
      std::nullopt, /* device */
      true, /* pin_memory */
      std::nullopt /* memory format */
  );
  at::cuda::memcpy_and_sync(
      (void*)pinned_num_nonzeros_h.const_data_ptr<int>(),
      num_nonzeros.get(),
      sizeof(int) * num_chunks,
      cudaMemcpyDeviceToHost,
      stream);
  int64_t num_nonzeros_h = 0;

  for (int64_t idx = 0; idx < num_chunks; idx++) {
    num_nonzeros_h += (int)*(pinned_num_nonzeros_h.const_data_ptr<int>() + idx);
  }
  // num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr<int>());
  // expected output size is num_nonzeros x ndim
  // we are producing output with size {num_nonzeros, ndim} and strides {1,
  // num_nonzeros} (that is, transposed ndim x num_nonzeros output) we are able
  // to directly use passed output with this size and strides, and we can also
  // (per contract) resize passed output with incorrect sizes anyway we want.
  // However, out with correct sizes and incorrect strides will have to be
  // copied to from the intermediate we've produced.
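  // Worked example of the layout contract above (illustrative, not part of the
  // diff): for a 2-D input with 3 nonzero elements, out_temp is allocated as
  // {2, 3} and filled row-major, e.g. {{r0, r1, r2}, {c0, c1, c2}}. The final
  // out = out_temp.t() then has sizes {3, 2} and strides {1, 3}, so each of
  // the 3 rows is one (row, col) coordinate pair, without any extra copy.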
  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h &&
      out.sizes()[1] == self.dim() && !out.t().is_contiguous();
  at::Tensor out_temp = need_to_copy
      ? Tensor(
            at::detail::empty_cuda({self.dim(), num_nonzeros_h}, out.options()))
      : out.resize_({self.dim(), num_nonzeros_h});
  // Scalars are expected to produce output of size (1,0), so we can't write to
  // it
  int64_t curr_nonzeros = 0;
  if (self.dim() > 0) {
    cub::CountingInputIterator<int64_t> counting_itr(0);
    temp_storage_bytes = 0;
    cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,
        out_temp.mutable_data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
    temp_storage = allocator.allocate(temp_storage_bytes);
    cub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr,
        out_temp.mutable_data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
    if (num_nonzeros_h > 0 && self.dim() > 1){
      TensorDims<int> dims;
      for (int i=0; i<self.dim(); i++){
        dims.sizes[i] = self.sizes()[i];
      }
      const int nthreads = 256;
      const int nblocks = (num_nonzeros_h + nthreads -1)/nthreads;
      write_indices<<<nblocks, nthreads, 0, stream>>>(out_temp.mutable_data_ptr<int64_t>(),
          dims, self.dim(), num_nonzeros_h);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    for (int64_t idx = 0; idx < num_chunks; idx++) {
      int remaining = std::min(chunk_size, self.numel() - idx * chunk_size);

      cub::CountingInputIterator<int64_t> counting_itr(idx * chunk_size);
      cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*>
          itr(self_.const_data_ptr<scalar_t>() + idx * chunk_size,
              NonZeroOp<scalar_t>());
      temp_storage_bytes = 0;
      cub::DeviceSelect::Flagged(
          nullptr,
          temp_storage_bytes,
          counting_itr,
          itr,
          out_temp.mutable_data_ptr<int64_t>(),
          ((int*)num_nonzeros.get()) + idx,
          remaining,
          stream);
      auto temp_storage = allocator.allocate(temp_storage_bytes);
      cub::DeviceSelect::Flagged(
          temp_storage.get(),
          temp_storage_bytes,
          counting_itr,
          itr,
          out_temp.mutable_data_ptr<int64_t>() + curr_nonzeros,
          ((int*)num_nonzeros.get()) + idx,
          remaining,
          stream);
      curr_nonzeros +=
          (int)*(pinned_num_nonzeros_h.const_data_ptr<int>() + idx);
    }
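    // Illustrative note (not part of the diff): at this point out_temp's first
    // row holds flat (linearized) indices; write_indices below rewrites them
    // as per-dimension coordinates. E.g. for sizes {2, 3}, the flat index 4
    // decomposes row-major as 4 / 3 = 1 (dim 0) and 4 % 3 = 1 (dim 1), so row
    // 0 stores 1 and row 1 stores 1 for that element.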
    if (num_nonzeros_h > 0 && self.dim() > 1) {
      TensorDims<int64_t> dims;
      for (int i = 0; i < self.dim(); i++) {
        dims.sizes[i] = self.sizes()[i];
      }
      const int nthreads = 256;
      const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads;
      write_indices<<<nblocks, nthreads, 0, stream>>>(
          out_temp.mutable_data_ptr<int64_t>(),
          dims,
          self.dim(),
          num_nonzeros_h);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
  if (need_to_copy) {
    out.copy_(out_temp.t());
  } else {
    //transpose out so it is correct size
    // transpose out so it is correct size
    Tensor out_ = out_temp.t();
    out.set_(out_);
  }
}

Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out){
  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
  See https://github.com/pytorch/pytorch/issues/51871");
  TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype());
  TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ",
              out.device(), " and self on ", self.device());
  TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      self.scalar_type(), "nonzero_cuda",
      [&] {nonzero_cuda_out_impl<scalar_t>(self, out);});
Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
  TORCH_CHECK(
      out.dtype() == at::kLong,
      "Expected object of scalar type ",
      at::kLong,
      " as out, but got ",
      out.dtype());
  TORCH_CHECK(
      self.device() == out.device(),
      "expected self and out to be on the same device, but got out on ",
      out.device(),
      " and self on ",
      self.device());
  TORCH_CHECK(
      self.dim() <= MAX_DIMS,
      "nonzero is not supported for tensor with more than ",
      MAX_DIMS,
      " dimensions");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::ComplexHalf,
      at::ScalarType::Bool,
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      self.scalar_type(),
      "nonzero_cuda",
      [&] { nonzero_cuda_out_impl<scalar_t>(self, out); });
  return out;
}

Tensor nonzero_cuda(const Tensor& self){
Tensor nonzero_cuda(const Tensor& self) {
  Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
  return at::native::nonzero_out_cuda(self, out);
}
} //namespace at::native
} // namespace at::native
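Both the counting loop (cub::DeviceReduce::Sum) and the selection loop (cub::DeviceSelect::Flagged) above use cub's standard two-pass calling convention: a first call with a null temporary-storage pointer only reports the required workspace size, which is allocated before the second, real call. A minimal standalone CUDA sketch of that pattern, counting nonzeros the same way NonZeroOp does (illustrative example, not the ATen code):

// Minimal standalone sketch (not the ATen code) of the two-pass cub pattern
// used above: pass 1 with a null workspace queries the temp-storage size,
// pass 2 performs the actual reduction over 0/1 flags.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

struct IsNonZero {
  __host__ __device__ __forceinline__ bool operator()(const float& a) const {
    return a != 0.0f;
  }
};

int main() {
  const int N = 8;
  const float h_in[N] = {0, 1, 0, 2, 0, 0, 3, 0};
  float* d_in = nullptr;
  int* d_count = nullptr;
  cudaMalloc(&d_in, N * sizeof(float));
  cudaMalloc(&d_count, sizeof(int));
  cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);

  // Map each element to a bool flag, like NonZeroOp in the diff.
  cub::TransformInputIterator<bool, IsNonZero, const float*> flags(
      d_in, IsNonZero());

  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: d_temp == nullptr, so cub only writes the required size.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, flags, d_count, N);
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: the real reduction; d_count receives the number of nonzeros.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, flags, d_count, N);

  int h_count = 0;
  cudaMemcpy(&h_count, d_count, sizeof(int), cudaMemcpyDeviceToHost);
  printf("nonzero count: %d\n", h_count); // prints 3

  cudaFree(d_temp);
  cudaFree(d_count);
  cudaFree(d_in);
  return 0;
}

In the diff, the same query/allocate/run sequence is repeated once per chunk, with each chunk's count written to ((int*)num_nonzeros.get()) + idx and later copied to pinned host memory in a single memcpy_and_sync.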
@@ -16,6 +16,7 @@ from torch.testing._internal.common_device_type import (
    dtypesIfCPU,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
@@ -44,6 +45,7 @@ from torch.testing._internal.common_utils import (
    numpy_to_torch_dtype_dict,
    run_tests,
    skipIfNoSciPy,
    skipIfRocm,
    slowTest,
    suppress_warnings,
    TEST_SCIPY,
@@ -1555,7 +1557,28 @@ class TestUnaryUfuncs(TestCase):
        self.assertEqual(1, len(z))
        self.assertEqual(torch.empty(0, dtype=torch.long), z[0])

    @onlyCUDA
    @dtypes(torch.int8)
    @largeTensorTest("8GB")
    @skipIfRocm(msg="ROCM tries to allocate 60GB")
    def test_nonzero_large(self, device, dtype):
        indices = (
            torch.tensor((0, 2, 3, 4, 6, 100, 103, 2**30, 2**31 - 3, 2**31 - 2)),
            torch.tensor((0, 1, 1, 1, 0, 1, 0, 1, 0, 0)),
        )

        x = torch.zeros(2**31 - 1, 2, device=device, dtype=dtype)
        x[indices[0], indices[1]] = 1
        y = torch.nonzero(x, as_tuple=True)
        self.assertEqual(y, indices)
        x = x.view(-1).fill_(0)
        indices = indices[0] * 2
        x[indices] = 1
        y = torch.nonzero(x)
        self.assertEqual(y.view(-1), indices)

    # TODO: rationalize with exp OpInfo

    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
    @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16))
    def test_exp(self, device, dtype):