Fix mask construction when dispatching index_put to masked_fill (#158472)

Fixes #158413
Previously trailing Nones in the index were incorrectly handled as implicit broadcasting dims in the mask, whereas they should just be ignored.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158472
Approved by: https://github.com/ezyang
This commit is contained in:
Natalia Gimelshein
2025-07-17 04:21:43 +00:00
committed by PyTorch MergeBot
parent ebf83b8b77
commit 8eaa9f2701
2 changed files with 6 additions and 1 deletions

View File

@ -964,6 +964,9 @@ class TestIndexing(TestCase):
mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool, device=device)
v[:, mask, :] = 0
self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device))
v = torch.tensor([[[[1], [2]], [[3], [4]]]], device=device)
torch.ops.aten.index_put_(v, [None, mask, None], torch.tensor(0))
self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device))
def test_byte_mask(self, device):
v = torch.randn(5, 7, 3, device=device)