FFT: Handle noop fftn calls gracefully (#117368)

Fixes #117252
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117368
Approved by: https://github.com/malfet
This commit is contained in:
Peter Bell
2024-01-12 17:18:04 +00:00
committed by PyTorch MergeBot
parent 5cf481d1ac
commit 18bd5c05bc
2 changed files with 33 additions and 0 deletions

View File

@ -462,6 +462,31 @@ class TestFFT(TestCase):
with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
op(a, s=(10, 10, 10, 10))
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_fftn_noop_transform(self, device, dtype):
skip_helper_for_fft(device, dtype)
RESULT_TYPE = {
torch.half: torch.chalf,
torch.float: torch.cfloat,
torch.double: torch.cdouble,
}
for op in [
torch.fft.fftn,
torch.fft.ifftn,
torch.fft.fft2,
torch.fft.ifft2,
]:
inp = make_tensor((10, 10), device=device, dtype=dtype)
out = torch.fft.fftn(inp, dim=[])
expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
expect = inp.to(expect_dtype)
self.assertEqual(expect, out)
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@toleranceOverride({