Fix kaiser_window for lower precision data types on CPU (#117345)

Fixes #117230.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117345
Approved by: https://github.com/jgong5, https://github.com/soumith
This commit is contained in:
CaoE
2024-01-26 03:26:12 +00:00
committed by PyTorch MergeBot
parent ef29fe745f
commit 8467de4e97
2 changed files with 18 additions and 6 deletions
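For reference, a minimal standalone sketch (not part of the PR; it assumes SciPy is installed) of the behavior the updated test below exercises: kaiser_window is computed in a low-precision dtype on CPU and compared against SciPy's float64 reference cast down to that dtype.

import torch
from scipy import signal

# Sketch only: build the Kaiser window in low precision on CPU and compare it
# against SciPy's float64 reference, mirroring the updated test below.
beta, size = 12.0, 1024
for dtype in (torch.bfloat16, torch.float16):
    res = torch.kaiser_window(size, periodic=True, beta=beta,
                              device="cpu", dtype=dtype)
    ref = torch.from_numpy(signal.get_window(("kaiser", beta), size, fftbins=True))
    # Cast the float64 reference to the window's dtype before comparing,
    # as the revised assertion does with ref.to(dtype).
    torch.testing.assert_close(res, ref.to(dtype))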


@@ -2737,10 +2737,22 @@ class TestTensorCreation(TestCase):
             return
         for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
             for periodic in [True, False]:
-                res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype)
+                res = torch_method(
+                    size,
+                    periodic=periodic,
+                    layout=torch.strided,
+                    requires_grad=False,
+                    **kwargs,
+                    device=device,
+                    dtype=dtype,
+                )
                 # NB: scipy always returns a float64 result
-                ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic))
-                self.assertEqual(res, ref, exact_dtype=False)
+                ref = torch.from_numpy(
+                    signal.get_window(
+                        (name, *(kwargs.values())), size, fftbins=periodic
+                    )
+                )
+                self.assertEqual(res, ref.to(dtype))
         with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
             torch_method(3, layout=torch.sparse_coo)
         self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
@@ -2761,7 +2773,7 @@ class TestTensorCreation(TestCase):
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
     @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
-    @dtypes(torch.float, torch.double, torch.long)
+    @dtypes(torch.float, torch.double, torch.long, torch.bfloat16, torch.float16)
     def test_kaiser_window(self, device, dtype):
         for num_test in range(50):
             self._test_signal_window_functions('kaiser', dtype, device, beta=random.random() * 30)