mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix mask construction when dispatching index_put to masked_fill (#158472)
Fixes #158413 Previously trailing Nones in the index were incorrectly handled as implicit broadcasting dims in the mask, whereas they should just be ignored. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158472 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ebf83b8b77
commit
8eaa9f2701
@ -964,6 +964,9 @@ class TestIndexing(TestCase):
|
||||
mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool, device=device)
|
||||
v[:, mask, :] = 0
|
||||
self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device))
|
||||
v = torch.tensor([[[[1], [2]], [[3], [4]]]], device=device)
|
||||
torch.ops.aten.index_put_(v, [None, mask, None], torch.tensor(0))
|
||||
self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device))
|
||||
|
||||
def test_byte_mask(self, device):
|
||||
v = torch.randn(5, 7, 3, device=device)
|
||||
|
Reference in New Issue
Block a user