mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix deterministic scatter_add path for multi-d tensors (#162866)
PReviously for more than 2d tensor `select` didn't work correctly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162866 Approved by: https://github.com/valentinandrei
This commit is contained in:
committed by
PyTorch MergeBot
parent
814ba34fa6
commit
bf6b40da3e
@ -2174,7 +2174,7 @@ static void _scatter_via_index_put(
|
||||
if (self.dim() == 1 || broadcast_index) {
|
||||
Tensor squeezed = index;
|
||||
if (broadcast_index && index.dim() > 1) {
|
||||
for (const auto d : c10::irange(index.dim())) {
|
||||
for (int64_t d = index.dim() - 1; d >= 0; --d) {
|
||||
if (d == dim) {
|
||||
continue;
|
||||
}
|
||||
|
@ -383,13 +383,14 @@ class TestScatterGather(TestCase):
|
||||
@dtypes(torch.float32)
|
||||
def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
|
||||
for d in (0, 1):
|
||||
inp = torch.randn(3, 4, device=device, dtype=dtype)
|
||||
inp = torch.randn(3, 4, 5, device=device, dtype=dtype)
|
||||
idx_1d = torch.randint(3, (10,), device=device)
|
||||
src_shape = list(inp.shape)
|
||||
src_shape[d] = 10
|
||||
src = torch.randn(src_shape, device=device, dtype=dtype)
|
||||
idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
|
||||
print(idx.stride())
|
||||
idx_view_shape = [1] * inp.ndim
|
||||
idx_view_shape[d] = 10
|
||||
idx = idx_1d.view(idx_view_shape).expand(src_shape)
|
||||
ref = inp.clone().scatter_add_(d, idx, src)
|
||||
with DeterministicGuard(True):
|
||||
res = inp.clone().scatter_add_(d, idx, src)
|
||||
|
Reference in New Issue
Block a user