mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
use more efficient implementation for broadcasted indexing in deterministic scatter_add (#156744)
Per title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156744. Approved by: https://github.com/suo
This commit is contained in:
committed by
PyTorch MergeBot
parent
9b498d3bb2
commit
beb52f5c0a
@@ -380,6 +380,22 @@ class TestScatterGather(TestCase):
|
||||
helper([50, 8, 7], 100)
|
||||
helper([50, 3, 4, 5], 100)
|
||||
|
||||
@dtypes(torch.float32)
def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
    """Check deterministic ``scatter_add_`` with a broadcasted (expanded) index.

    For each scatter dimension ``d`` of a 2-D input, a 1-D index tensor is
    unsqueezed and expanded to the shape of ``src``. The result computed
    inside ``DeterministicGuard(True)`` is compared against the result of
    the same call made outside the guard.
    """
    num_updates = 10  # length of ``src`` / ``idx`` along the scatter dim
    for d in (0, 1):
        inp = torch.randn(3, 4, device=device, dtype=dtype)
        # Indices are drawn from [0, 3), which is in range for both
        # dimensions of the (3, 4) input.
        idx_1d = torch.randint(3, (num_updates,), device=device)
        src_shape = list(inp.shape)
        src_shape[d] = num_updates
        src = torch.randn(src_shape, device=device, dtype=dtype)
        # Lay the 1-D index along dim ``d`` and broadcast it across the
        # other dim with ``expand`` so it matches ``src``'s shape.
        idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
        ref = inp.clone().scatter_add_(d, idx, src)
        with DeterministicGuard(True):
            res = inp.clone().scatter_add_(d, idx, src)
        self.assertEqual(res, ref)
|
||||
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
def test_gather_expanded_index(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user