From bf6b40da3e3be7718b8ddc94eed2da8cabaa5e86 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Mon, 15 Sep 2025 06:50:00 +0000 Subject: [PATCH] fix deterministic scatter_add path for multi-d tensors (#162866) PReviously for more than 2d tensor `select` didn't work correctly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162866 Approved by: https://github.com/valentinandrei --- aten/src/ATen/native/TensorAdvancedIndexing.cpp | 2 +- test/test_scatter_gather_ops.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 408faea1b764..7d613fc02312 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -2174,7 +2174,7 @@ static void _scatter_via_index_put( if (self.dim() == 1 || broadcast_index) { Tensor squeezed = index; if (broadcast_index && index.dim() > 1) { - for (const auto d : c10::irange(index.dim())) { + for (int64_t d = index.dim() - 1; d >= 0; --d) { if (d == dim) { continue; } diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index d2a0e8bd1ccc..ba967c142f1e 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -383,13 +383,14 @@ class TestScatterGather(TestCase): @dtypes(torch.float32) def test_scatter_add_broadcasted_index_deterministic(self, device, dtype): for d in (0, 1): - inp = torch.randn(3, 4, device=device, dtype=dtype) + inp = torch.randn(3, 4, 5, device=device, dtype=dtype) idx_1d = torch.randint(3, (10,), device=device) src_shape = list(inp.shape) src_shape[d] = 10 src = torch.randn(src_shape, device=device, dtype=dtype) - idx = idx_1d.unsqueeze(1 - d).expand(src_shape) - print(idx.stride()) + idx_view_shape = [1] * inp.ndim + idx_view_shape[d] = 10 + idx = idx_1d.view(idx_view_shape).expand(src_shape) ref = inp.clone().scatter_add_(d, idx, src) with DeterministicGuard(True): res = inp.clone().scatter_add_(d, idx, src)