Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU (#112132)

Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112132
Approved by: https://github.com/cpuhrsch
This commit is contained in:
CaoE
2023-11-05 12:31:38 +00:00
committed by PyTorch MergeBot
parent 64f326097b
commit 26b5e27ace
7 changed files with 44 additions and 23 deletions

View File

@ -1435,6 +1435,18 @@ class TestReductions(TestCase):
torch.prod(x, 1, out=res2)
self.assertEqual(res1, res2)
@onlyCPU
@dtypes(torch.float16, torch.bfloat16)
def test_prod_lowp(self, device, dtype):
x = torch.rand(100, 100, dtype=dtype, device=device)
x_ref = x.float()
res1 = torch.prod(x, 1)
res2 = torch.prod(x_ref, 1)
self.assertEqual(res1, res2.to(dtype=dtype))
res1 = torch.prod(x, 0)
res2 = torch.prod(x_ref, 0)
self.assertEqual(res1, res2.to(dtype=dtype))
def test_prod_bool(self, device):
vals = [[True, True], [True, False], [False, False], []]
for val in vals: