use more efficient implementation for broadcasted indexing in determi… (#156744)

…nistic scatter_add

per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156744
Approved by: https://github.com/suo
This commit is contained in:
Natalia Gimelshein
2025-06-25 02:59:45 +00:00
committed by PyTorch MergeBot
parent 9b498d3bb2
commit beb52f5c0a
2 changed files with 60 additions and 72 deletions

View File

@ -380,6 +380,22 @@ class TestScatterGather(TestCase):
helper([50, 8, 7], 100)
helper([50, 3, 4, 5], 100)
@dtypes(torch.float32)
def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
for d in (0, 1):
inp = torch.randn(3, 4, device=device, dtype=dtype)
idx_1d = torch.randint(3, (10,), device=device)
src_shape = list(inp.shape)
src_shape[d] = 10
src = torch.randn(src_shape, device=device, dtype=dtype)
idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
print(idx.stride())
ref = inp.clone().scatter_add_(d, idx, src)
with DeterministicGuard(True):
res = inp.clone().scatter_add_(d, idx, src)
self.assertEqual(res, ref)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_gather_expanded_index(self, device, dtype):