fix deterministic scatter_add path for multi-d tensors (#162866)

PReviously for more than 2d tensor `select` didn't work correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162866
Approved by: https://github.com/valentinandrei
This commit is contained in:
Natalia Gimelshein
2025-09-15 06:50:00 +00:00
committed by PyTorch MergeBot
parent 814ba34fa6
commit bf6b40da3e
2 changed files with 5 additions and 4 deletions

View File

@ -383,13 +383,14 @@ class TestScatterGather(TestCase):
@dtypes(torch.float32)
def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
for d in (0, 1):
inp = torch.randn(3, 4, device=device, dtype=dtype)
inp = torch.randn(3, 4, 5, device=device, dtype=dtype)
idx_1d = torch.randint(3, (10,), device=device)
src_shape = list(inp.shape)
src_shape[d] = 10
src = torch.randn(src_shape, device=device, dtype=dtype)
idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
print(idx.stride())
idx_view_shape = [1] * inp.ndim
idx_view_shape[d] = 10
idx = idx_1d.view(idx_view_shape).expand(src_shape)
ref = inp.clone().scatter_add_(d, idx, src)
with DeterministicGuard(True):
res = inp.clone().scatter_add_(d, idx, src)