mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
FFT: Handle noop fftn calls gracefully (#117368)
Fixes #117252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117368 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
5cf481d1ac
commit
18bd5c05bc
@ -462,6 +462,31 @@ class TestFFT(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
|
||||
op(a, s=(10, 10, 10, 10))
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
|
||||
def test_fftn_noop_transform(self, device, dtype):
|
||||
skip_helper_for_fft(device, dtype)
|
||||
RESULT_TYPE = {
|
||||
torch.half: torch.chalf,
|
||||
torch.float: torch.cfloat,
|
||||
torch.double: torch.cdouble,
|
||||
}
|
||||
|
||||
for op in [
|
||||
torch.fft.fftn,
|
||||
torch.fft.ifftn,
|
||||
torch.fft.fft2,
|
||||
torch.fft.ifft2,
|
||||
]:
|
||||
inp = make_tensor((10, 10), device=device, dtype=dtype)
|
||||
out = torch.fft.fftn(inp, dim=[])
|
||||
|
||||
expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
|
||||
expect = inp.to(expect_dtype)
|
||||
self.assertEqual(expect, out)
|
||||
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@toleranceOverride({
|
||||
|
Reference in New Issue
Block a user