fix torch.prod vectorized path for bool (#128009)

Fix https://github.com/pytorch/pytorch/issues/127866.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128009
Approved by: https://github.com/jgong5, https://github.com/albanD
This commit is contained in:
haozhe.zhu
2024-08-27 09:29:50 +08:00
committed by PyTorch MergeBot
parent 89929d9abc
commit 2ba60a1618
4 changed files with 41 additions and 1 deletions

View File

@ -1498,7 +1498,13 @@ class TestReductions(TestCase):
self.assertEqual(res1, res2.to(dtype=dtype))
def test_prod_bool(self, device):
vals = [[True, True], [True, False], [False, False], []]
vals = [
[True, True],
[True, False],
[False, False],
[],
[False] * 256, # https://github.com/pytorch/pytorch/issues/127866
]
for val in vals:
result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
expect = np.prod(np.array(val), dtype=bool)