mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix index_add for int64 input + zerodim index (#161511)
Fixes #161446 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161511 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
07a4e9fea8
commit
d51486616c
@ -2029,6 +2029,18 @@ class TestIndexing(TestCase):
|
||||
|
||||
self.assertEqual(output, input_list)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
def test_index_add_zerodim_index_floating_alpha(self, device) -> None:
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/161446
|
||||
x = torch.ones([2, 3], dtype=torch.int64, device=device)
|
||||
index = torch.tensor(0, dtype=torch.int64, device=device)
|
||||
src = torch.full([1, 3], 2, dtype=torch.int64, device=device)
|
||||
alpha = 1.5
|
||||
x.index_add_(0, index, src, alpha=alpha)
|
||||
self.assertEqual(
|
||||
x, torch.tensor([[3, 3, 3], [1, 1, 1]], dtype=torch.int64, device=device)
|
||||
)
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@expectedFailureMPS
|
||||
def test_index_fill(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user