Add deterministic impl of scatter_add CUDA for all input sizes (#79466)

Fixes #50469

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79466
Approved by: https://github.com/ngimel
Author: Kurt Mohler
Date: 2022-09-07 03:12:49 +00:00
Committed by: PyTorch MergeBot
Parent: 039b0146f9
Commit: 5b58140d1a
4 changed files with 111 additions and 63 deletions


@@ -1504,16 +1504,101 @@ TORCH_IMPL_FUNC(scatter_add)
if (index.numel() == 0) return;
if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA && self.dim() == 1) {
TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, "
"but their dims are ", index.dim(), " and ", src.dim(), ", respectively");
TORCH_CHECK(index.numel() == src.numel(), "index and src should have same number of elements for 1D tensors, "
"but got ", index.numel(), " versus ", src.numel());
TORCH_CHECK(dim == 0, "dim should be zero for 1D self tensor, but got ", dim);
torch::List<c10::optional<Tensor>> indices;
indices.reserve(1);
indices.push_back(index);
mut_out.index_put_(indices, src, true);
// See Note [Enabling Deterministic Operations]
// Avoid gpuAtomicAdd for CUDA if deterministic mode is turned on
if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA) {
if (self.dim() == 1) {
// TODO: Pretty sure these checks can be removed, since they're done in
// `scatter_meta_impl`, which I think is always called before this
TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, "
"but their dims are ", index.dim(), " and ", src.dim(), ", respectively");
TORCH_CHECK(index.numel() == src.numel(), "index and src should have same number of elements for 1D tensors, "
"but got ", index.numel(), " versus ", src.numel());
TORCH_CHECK(dim == 0, "dim should be zero for 1D self tensor, but got ", dim);
torch::List<c10::optional<Tensor>> indices;
indices.reserve(1);
indices.push_back(index);
mut_out.index_put_(indices, src, true);
} else {
Tensor mut_out_contig = mut_out.contiguous();
auto index_coords_sizes = index.sizes().vec();
index_coords_sizes.push_back(self.dim());
auto index_coords = at::empty(
index_coords_sizes,
at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
if (dim_other == dim) {
continue;
}
auto dim_coord_vals = at::arange(
index.size(dim_other),
at::TensorOptions().device(self.device()));
for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; dim_unsqueeze++) {
dim_coord_vals = dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
}
auto view_sizes = index.sizes().vec();
view_sizes.push_back(1);
auto view_strides = index_coords.strides().vec();
view_strides[self.dim()] = self.dim();
at::as_strided(
index_coords,
view_sizes,
view_strides,
dim_other
).copy_(dim_coord_vals.unsqueeze(-1));
}
auto view_sizes = index.sizes().vec();
view_sizes.push_back(1);
auto view_strides = index_coords.strides().vec();
view_strides[self.dim()] = self.dim();
at::as_strided(
index_coords,
view_sizes,
view_strides,
dim
).copy_(index.unsqueeze(-1));
Tensor index_coords_flat = index_coords.flatten(0, -2);
// Copy mut_out_contig's strides into a tensor
// TODO: Is there a utility function that already does this?
IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
Tensor coord_strides = at::empty(
{mut_out_contig.dim()},
TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
std::memcpy(
coord_strides.data_ptr(),
mut_out_contig_strides.data(),
coord_strides.nbytes());
coord_strides = coord_strides.to(mut_out_contig.device());
// `index_flat` contains the 1-D indices corresponding with the
// flattened `mut_out`
Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
Tensor mut_out_flat = mut_out_contig.flatten();
Tensor src_flat = at::as_strided(
src,
index.sizes(),
src.strides()
).flatten();
torch::List<c10::optional<Tensor>> indices;
indices.reserve(1);
indices.push_back(index_flat);
mut_out_flat.index_put_(indices, src_flat, true);
if (!mut_out.is_contiguous()) {
mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
}
}
} else {
scatter_add_stub(self.device().type(), mut_out, dim, index, src);
}

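The non-1-D branch above computes, for every element of `index`, its coordinate along each dimension of `self`, substitutes the value from `index` for the coordinate along `dim`, combines the coordinates with the contiguous output's strides into flat 1-D indices, and then accumulates through `index_put_(..., accumulate=True)`, the same deterministic path the existing 1-D branch already used. A minimal Python sketch of that idea, restricted to 2-D inputs for brevity (the function name is mine and is not part of the change):

```python
import torch

def deterministic_scatter_add_2d(self_t, dim, index, src):
    # 2-D sketch of the branch above: build the full coordinate of every
    # `index` entry, replace the coordinate along `dim` with the value from
    # `index`, dot with the contiguous output's strides to get flat 1-D
    # indices, and accumulate via index_put_(..., accumulate=True).
    out = self_t.contiguous().clone()
    rows = torch.arange(index.size(0), device=self_t.device).unsqueeze(1).expand_as(index)
    cols = torch.arange(index.size(1), device=self_t.device).unsqueeze(0).expand_as(index)
    coords = (index, cols) if dim == 0 else (rows, index)
    flat_index = coords[0] * out.stride(0) + coords[1] * out.stride(1)
    # Only the leading index-shaped block of `src` participates, mirroring
    # as_strided(src, index.sizes(), src.strides()) above.
    src_vals = src[:index.size(0), :index.size(1)]
    out_flat = out.flatten()
    out_flat.index_put_((flat_index.flatten(),), src_vals.flatten(), accumulate=True)
    return out_flat.view(self_t.shape)
```

For valid inputs this matches `Tensor.scatter_add`; the actual kernel above additionally handles arbitrary dimensionality and copies the result back when `mut_out` is not contiguous.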

@@ -7,7 +7,7 @@ import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
(parametrize, run_tests, TestCase,)
(parametrize, run_tests, TestCase, DeterministicGuard)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, dtypesIfCUDA,
toleranceOverride, tol,)
@@ -185,19 +185,23 @@ class TestScatterGather(TestCase):
@dtypes(torch.float16, torch.float32, torch.complex64)
def test_scatter_add_(self, device, dtype):
self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
is_scalar=False, reduction=None)
for deterministic in [False, True]:
with DeterministicGuard(deterministic):
self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
is_scalar=False, reduction=None)
@dtypes(torch.float32)
def test_scatter_add_mult_index_base(self, device, dtype):
m, n = 30, 40
idx = torch.zeros(m, n, device=device, dtype=torch.long)
src = torch.ones(m, n, device=device, dtype=dtype)
res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)
for deterministic in [False, True]:
with DeterministicGuard(deterministic):
m, n = 30, 40
idx = torch.zeros(m, n, device=device, dtype=torch.long)
src = torch.ones(m, n, device=device, dtype=dtype)
res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)
self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)
self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)
# FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))

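The test changes above import `DeterministicGuard` and rerun the existing scatter_add checks with the deterministic CUDA path both disabled and enabled. `DeterministicGuard` is a test-internal context manager around `torch.use_deterministic_algorithms`; roughly the same coverage using only public APIs might look like the sketch below (the helper name is mine; a CUDA device is assumed):

```python
import torch

def check_scatter_add_both_paths(device="cuda"):
    # Run the multi-index scatter_add_ check with the deterministic
    # CUDA path disabled and then enabled, restoring the global flag after.
    prev = torch.are_deterministic_algorithms_enabled()
    try:
        for deterministic in (False, True):
            torch.use_deterministic_algorithms(deterministic)
            m, n = 30, 40
            idx = torch.zeros(m, n, device=device, dtype=torch.long)
            src = torch.ones(m, n, device=device)
            res0 = torch.zeros(m, n, device=device).scatter_add_(0, idx, src)
            # All m rows accumulate into row 0.
            assert torch.equal(res0[0, :], torch.full((n,), float(m), device=device))
    finally:
        torch.use_deterministic_algorithms(prev)
```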

@@ -1440,23 +1440,6 @@ else:
test_func(torch.Tensor.cumsum)
test_func(torch.cumsum)
def test_nondeterministic_alert_scatter_add(self, device):
def test_func(op_call):
input = torch.randn(5, 4, device=device)
dim = 0
index = torch.tensor([[3]], device=device)
src = torch.tensor([[1.0]], device=device)
@expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
def forward_func(slf, device):
op_call(input, dim, index, src)
forward_func(self, device)
test_func(torch.Tensor.scatter_add_)
test_func(torch.Tensor.scatter_add)
test_func(torch.scatter_add)
@expectedFailureMeta # expected a non-deterministic error, but it was not raised
@onlyNativeDeviceTypes
def test_nondeterministic_alert_put(self, device):
@@ -1540,24 +1523,6 @@ else:
test_func(self, device, 'method')
test_func(self, device, 'out')
@onlyNativeDeviceTypes
def test_nondeterministic_alert_gather(self, device):
def test_func(op_call):
a = torch.randn(3, 3, device=device, requires_grad=True)
dim = 0
index = torch.tensor([[0]], device=device)
res = op_call(a, dim, index)
grad = torch.ones_like(res)
@expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
def backward_func(slf, device):
res.backward(grad)
backward_func(self, device)
test_func(torch.gather)
test_func(torch.Tensor.gather)
@skipIfMps
def test_nondeterministic_alert_grid_sample_2d(self, device):
input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True)

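The deletions above remove the tests asserting that `scatter_add` and the backward of `gather` raise a nondeterministic-operation alert on CUDA; with this change both take the deterministic path instead of alerting, so the tests no longer apply. For comparison, the alert still fires for an op without a deterministic CUDA kernel, such as `torch.histc`, which the docs change below keeps in the error-raising list (a sketch assuming a CUDA device):

```python
import torch

torch.use_deterministic_algorithms(True)
x = torch.rand(100, device="cuda")
try:
    # histc still has no deterministic CUDA implementation, so this raises.
    torch.histc(x, bins=10)
except RuntimeError as err:
    print(err)  # the message names the offending CUDA kernel
finally:
    torch.use_deterministic_algorithms(False)
```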

@@ -399,10 +399,8 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
tensor
* :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
tensor
* :func:`torch.Tensor.scatter_add_` when ``input`` dimension is one and called
on a CUDA tensor
* :func:`torch.gather` when ``input`` dimension is one and called
on a CUDA tensor that requires grad
* :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
* :func:`torch.gather` when called on a CUDA tensor that requires grad
* :func:`torch.index_add` when called on CUDA tensor
* :func:`torch.index_select` when attempting to differentiate a CUDA tensor
* :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
@@ -436,10 +434,6 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
* :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
* :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
``mode='max'``
* :func:`torch.Tensor.scatter_add_` when ``input`` dimension is larger than one
and called on a CUDA tensor
* :func:`torch.gather` when ``input`` dimension is larger than one
and called on a CUDA tensor that requires grad
* :func:`torch.Tensor.put_` when ``accumulate=False``
* :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
* :func:`torch.histc` when called on a CUDA tensor
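With the documentation change above, `scatter_add_` and `gather` move from the list of operations that throw a `RuntimeError` in deterministic mode to the list that has a deterministic alternative, regardless of dimensionality. Concretely, code like the following used to raise for a multi-dimensional CUDA tensor and now runs deterministically (a sketch assuming a CUDA device):

```python
import torch

torch.use_deterministic_algorithms(True)
out = torch.zeros(4, 5, device="cuda")
index = torch.tensor([[0, 1, 2, 0, 1]], device="cuda")
src = torch.ones(1, 5, device="cuda")
# Multi-dimensional CUDA scatter_add_ previously raised in deterministic
# mode; it now routes through the deterministic index_put_-based path.
out.scatter_add_(0, index, src)
torch.use_deterministic_algorithms(False)
```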