Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add deterministic impl of scatter_add CUDA for all input sizes (#79466)
Fixes #50469
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79466
Approved by: https://github.com/ngimel
committed by PyTorch MergeBot
parent 039b0146f9
commit 5b58140d1a
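
For context, a minimal sketch (not part of this diff) of what the change enables: with deterministic algorithms turned on, scatter_add_ on a CUDA tensor of any dimensionality now takes a deterministic code path instead of raising a nondeterministic-operation error. The values below are purely illustrative.

import torch

torch.use_deterministic_algorithms(True)

if torch.cuda.is_available():
    out = torch.zeros(5, 4, device="cuda")
    index = torch.tensor([[0, 0, 0, 0]], device="cuda")  # repeated target rows
    src = torch.ones(1, 4, device="cuda")
    out.scatter_add_(0, index, src)  # 2-D self: previously only 1-D self was deterministic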
@@ -1504,16 +1504,101 @@ TORCH_IMPL_FUNC(scatter_add)
 
   if (index.numel() == 0) return;
 
-  if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA && self.dim() == 1) {
-    TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, "
-      "but their dims are ", index.dim(), " and ", src.dim(), ", respectively");
-    TORCH_CHECK(index.numel() == src.numel(), "index and src should have same number of elements for 1D tensors, "
-      "but got ", index.numel(), " versus ", src.numel());
-    TORCH_CHECK(dim == 0, "dim should be zero for 1D self tensor, but got ", dim);
-    torch::List<c10::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index);
-    mut_out.index_put_(indices, src, true);
+  // See Note [Enabling Deterministic Operations]
+  // Avoid gpuAtomicAdd for CUDA if deterministic mode is turned on
+  if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA) {
+    if (self.dim() == 1) {
+      // TODO: Pretty sure these checks can be removed, since they're done in
+      // `scatter_meta_impl`, which I think is always called before this
+      TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, "
+        "but their dims are ", index.dim(), " and ", src.dim(), ", respectively");
+      TORCH_CHECK(index.numel() == src.numel(), "index and src should have same number of elements for 1D tensors, "
+        "but got ", index.numel(), " versus ", src.numel());
+      TORCH_CHECK(dim == 0, "dim should be zero for 1D self tensor, but got ", dim);
+      torch::List<c10::optional<Tensor>> indices;
+      indices.reserve(1);
+      indices.push_back(index);
+      mut_out.index_put_(indices, src, true);
+    } else {
+      Tensor mut_out_contig = mut_out.contiguous();
+
+      auto index_coords_sizes = index.sizes().vec();
+      index_coords_sizes.push_back(self.dim());
+      auto index_coords = at::empty(
+        index_coords_sizes,
+        at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
+
+      for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
+        if (dim_other == dim) {
+          continue;
+        }
+        auto dim_coord_vals = at::arange(
+          index.size(dim_other),
+          at::TensorOptions().device(self.device()));
+
+        for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; dim_unsqueeze++) {
+          dim_coord_vals = dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
+        }
+
+        auto view_sizes = index.sizes().vec();
+        view_sizes.push_back(1);
+        auto view_strides = index_coords.strides().vec();
+        view_strides[self.dim()] = self.dim();
+
+        at::as_strided(
+          index_coords,
+          view_sizes,
+          view_strides,
+          dim_other
+        ).copy_(dim_coord_vals.unsqueeze(-1));
+      }
+
+      auto view_sizes = index.sizes().vec();
+      view_sizes.push_back(1);
+      auto view_strides = index_coords.strides().vec();
+      view_strides[self.dim()] = self.dim();
+
+      at::as_strided(
+        index_coords,
+        view_sizes,
+        view_strides,
+        dim
+      ).copy_(index.unsqueeze(-1));
+
+      Tensor index_coords_flat = index_coords.flatten(0, -2);
+
+      // Copy mut_out_contig's strides into a tensor
+      // TODO: Is there a utility function that already does this?
+      IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
+      Tensor coord_strides = at::empty(
+        {mut_out_contig.dim()},
+        TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
+      std::memcpy(
+        coord_strides.data_ptr(),
+        mut_out_contig_strides.data(),
+        coord_strides.nbytes());
+      coord_strides = coord_strides.to(mut_out_contig.device());
+
+      // `index_flat` contains the 1-D indices corresponding with the
+      // flattened `mut_out`
+      Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
+      Tensor mut_out_flat = mut_out_contig.flatten();
+      Tensor src_flat = at::as_strided(
+        src,
+        index.sizes(),
+        src.strides()
+      ).flatten();
+
+      torch::List<c10::optional<Tensor>> indices;
+      indices.reserve(1);
+      indices.push_back(index_flat);
+
+      mut_out_flat.index_put_(indices, src_flat, true);
+
+      if (!mut_out.is_contiguous()) {
+        mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
+      }
+    }
   } else {
     scatter_add_stub(self.device().type(), mut_out, dim, index, src);
   }
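
The multi-dimensional branch added above reduces scatter_add to a deterministic index_put_ with accumulate=True: it builds the full coordinate of every scattered element, converts those coordinates to flat offsets using the contiguous output's strides, and accumulates into the flattened output. A rough Python rendering of the same idea (illustrative only; the names and helper are mine, not the ATen code):

import torch

def deterministic_scatter_add_sketch(out, dim, index, src):
    # Coordinates of each scattered element: along `dim` they come from `index`,
    # along every other dimension they are the element's own position.
    coords = []
    for d in range(out.dim()):
        if d == dim:
            coords.append(index)
        else:
            shape = [1] * index.dim()
            shape[d] = index.size(d)
            coords.append(torch.arange(index.size(d), device=out.device)
                          .view(shape).expand_as(index))
    # Flat offsets into the contiguous output: sum over dims of coordinate * stride.
    assert out.is_contiguous()
    flat_index = sum(c.flatten() * s for c, s in zip(coords, out.stride()))
    # index_put_ with accumulate=True is deterministic on CUDA in deterministic mode.
    src_view = src[tuple(slice(s) for s in index.shape)].flatten()
    out_flat = out.flatten()
    out_flat.index_put_((flat_index,), src_view, accumulate=True)
    return out_flat.view(out.shape)

For a non-contiguous output, the ATen code above does the same on a contiguous copy (mut_out_contig) and copies the result back at the end.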
@@ -7,7 +7,7 @@ import torch
 
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (parametrize, run_tests, TestCase,)
+    (parametrize, run_tests, TestCase, DeterministicGuard)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, dtypesIfCUDA,
      toleranceOverride, tol,)
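
DeterministicGuard, newly imported here, is the test-suite context manager that temporarily toggles torch's deterministic-algorithms flag and restores the previous setting on exit. The updated tests below use it roughly like this (sketch, not the test file itself):

from torch.testing._internal.common_utils import DeterministicGuard

for deterministic in [False, True]:
    with DeterministicGuard(deterministic):
        ...  # run the scatter_add checks under both modes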
@@ -185,19 +185,23 @@ class TestScatterGather(TestCase):
 
     @dtypes(torch.float16, torch.float32, torch.complex64)
     def test_scatter_add_(self, device, dtype):
-        self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
-                                is_scalar=False, reduction=None)
+        for deterministic in [False, True]:
+            with DeterministicGuard(deterministic):
+                self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
+                                        is_scalar=False, reduction=None)
 
     @dtypes(torch.float32)
     def test_scatter_add_mult_index_base(self, device, dtype):
-        m, n = 30, 40
-        idx = torch.zeros(m, n, device=device, dtype=torch.long)
-        src = torch.ones(m, n, device=device, dtype=dtype)
-        res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
-        res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)
-
-        self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
-        self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)
+        for deterministic in [False, True]:
+            with DeterministicGuard(deterministic):
+                m, n = 30, 40
+                idx = torch.zeros(m, n, device=device, dtype=torch.long)
+                src = torch.ones(m, n, device=device, dtype=dtype)
+                res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
+                res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)
+
+                self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
+                self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)
 
     # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
     @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
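
The tests above now run the scatter checks under both modes. A minimal standalone check of the new guarantee, assuming a CUDA device (values are illustrative):

import torch

def run_once():
    torch.manual_seed(0)
    out = torch.zeros(100, 50, device="cuda")
    index = torch.randint(0, 100, (10000, 50), device="cuda")
    src = torch.randn(10000, 50, device="cuda")
    return out.scatter_add_(0, index, src)

torch.use_deterministic_algorithms(True)
assert torch.equal(run_once(), run_once())  # bitwise-identical across runs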
@@ -1440,23 +1440,6 @@ else:
         test_func(torch.Tensor.cumsum)
         test_func(torch.cumsum)
 
-    def test_nondeterministic_alert_scatter_add(self, device):
-        def test_func(op_call):
-            input = torch.randn(5, 4, device=device)
-            dim = 0
-            index = torch.tensor([[3]], device=device)
-            src = torch.tensor([[1.0]], device=device)
-
-            @expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
-            def forward_func(slf, device):
-                op_call(input, dim, index, src)
-
-            forward_func(self, device)
-
-        test_func(torch.Tensor.scatter_add_)
-        test_func(torch.Tensor.scatter_add)
-        test_func(torch.scatter_add)
-
     @expectedFailureMeta  # expected a non-determinitic error, but it was not raised
     @onlyNativeDeviceTypes
     def test_nondeterministic_alert_put(self, device):
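
This test is removed because the alert it expected no longer fires: with the new code path, scatter_add on a multi-dimensional CUDA tensor runs under deterministic mode instead of raising. Roughly, the scenario the old test covered now just works (sketch, assuming a CUDA device):

import torch

torch.use_deterministic_algorithms(True)
input = torch.randn(5, 4, device="cuda")
index = torch.tensor([[3]], device="cuda")
src = torch.tensor([[1.0]], device="cuda")
torch.scatter_add(input, 0, index, src)  # no longer alerts on scatter_add_cuda_kernel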
@@ -1540,24 +1523,6 @@ else:
             test_func(self, device, 'method')
             test_func(self, device, 'out')
 
-    @onlyNativeDeviceTypes
-    def test_nondeterministic_alert_gather(self, device):
-        def test_func(op_call):
-            a = torch.randn(3, 3, device=device, requires_grad=True)
-            dim = 0
-            index = torch.tensor([[0]], device=device)
-            res = op_call(a, dim, index)
-            grad = torch.ones_like(res)
-
-            @expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
-            def backward_func(slf, device):
-                res.backward(grad)
-
-            backward_func(self, device)
-
-        test_func(torch.gather)
-        test_func(torch.Tensor.gather)
-
     @skipIfMps
     def test_nondeterministic_alert_grid_sample_2d(self, device):
         input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True)
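
Likewise for this removed test: torch.gather's backward is a scatter_add, so differentiating gather on a CUDA tensor now also takes the deterministic path rather than alerting. A minimal sketch, assuming a CUDA device:

import torch

torch.use_deterministic_algorithms(True)
a = torch.randn(3, 3, device="cuda", requires_grad=True)
res = torch.gather(a, 0, torch.tensor([[0, 0, 0]], device="cuda"))
res.backward(torch.ones_like(res))  # previously alerted on scatter_add_cuda_kernel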
@@ -399,10 +399,8 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
           tensor
         * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
           tensor
-        * :func:`torch.Tensor.scatter_add_` when ``input`` dimension is one and called
-          on a CUDA tensor
-        * :func:`torch.gather` when ``input`` dimension is one and called
-          on a CUDA tensor that requires grad
+        * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
+        * :func:`torch.gather` when called on a CUDA tensor that requires grad
         * :func:`torch.index_add` when called on CUDA tensor
         * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
         * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
@@ -436,10 +434,6 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
         * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
         * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
          ``mode='max'``
-        * :func:`torch.Tensor.scatter_add_` when ``input`` dimension is larger than one
-          and called on a CUDA tensor
-        * :func:`torch.gather` when ``input`` dimension is larger than one
-          and called on a CUDA tensor that requires grad
         * :func:`torch.Tensor.put_` when ``accumulate=False``
         * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
         * :func:`torch.histc` when called on a CUDA tensor
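
The two docstring edits above move scatter_add_ and gather (on CUDA) from the "throws an error" list to the "use deterministic algorithms" list. For code that still hits other nondeterministic ops, the same function shown in the hunk header accepts warn_only:

import torch

torch.use_deterministic_algorithms(True, warn_only=True)  # warn instead of erroring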