mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU (#112132)
Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112132 Approved by: https://github.com/cpuhrsch
This commit is contained in:
@ -1435,6 +1435,18 @@ class TestReductions(TestCase):
|
||||
torch.prod(x, 1, out=res2)
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float16, torch.bfloat16)
|
||||
def test_prod_lowp(self, device, dtype):
|
||||
x = torch.rand(100, 100, dtype=dtype, device=device)
|
||||
x_ref = x.float()
|
||||
res1 = torch.prod(x, 1)
|
||||
res2 = torch.prod(x_ref, 1)
|
||||
self.assertEqual(res1, res2.to(dtype=dtype))
|
||||
res1 = torch.prod(x, 0)
|
||||
res2 = torch.prod(x_ref, 0)
|
||||
self.assertEqual(res1, res2.to(dtype=dtype))
|
||||
|
||||
def test_prod_bool(self, device):
|
||||
vals = [[True, True], [True, False], [False, False], []]
|
||||
for val in vals:
|
||||
|
Reference in New Issue
Block a user