mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix kaiser_window for lower precision data types on CPU (#117345)
Fixes #117230. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117345 Approved by: https://github.com/jgong5, https://github.com/soumith
This commit is contained in:
@ -2737,10 +2737,22 @@ class TestTensorCreation(TestCase):
|
||||
return
|
||||
for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
|
||||
for periodic in [True, False]:
|
||||
res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype)
|
||||
res = torch_method(
|
||||
size,
|
||||
periodic=periodic,
|
||||
layout=torch.strided,
|
||||
requires_grad=False,
|
||||
**kwargs,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# NB: scipy always returns a float64 result
|
||||
ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic))
|
||||
self.assertEqual(res, ref, exact_dtype=False)
|
||||
ref = torch.from_numpy(
|
||||
signal.get_window(
|
||||
(name, *(kwargs.values())), size, fftbins=periodic
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, ref.to(dtype))
|
||||
with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
|
||||
torch_method(3, layout=torch.sparse_coo)
|
||||
self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
|
||||
@ -2761,7 +2773,7 @@ class TestTensorCreation(TestCase):
|
||||
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
|
||||
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
|
||||
@dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
|
||||
@dtypes(torch.float, torch.double, torch.long)
|
||||
@dtypes(torch.float, torch.double, torch.long, torch.bfloat16, torch.float16)
|
||||
def test_kaiser_window(self, device, dtype):
|
||||
for num_test in range(50):
|
||||
self._test_signal_window_functions('kaiser', dtype, device, beta=random.random() * 30)
|
||||
|
Reference in New Issue
Block a user