mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCM] enable fft tests (#60313)
Summary: This PR enables fft tests on ROCM. It adds a helper function that generates valid input for fft tests that call hipfftExecC2R or hipfftExecZ2D. With this helper function we are able to fix a number of fft tests. This concludes the series of PRs enabling fft tests on ROCM. Pull Request resolved: https://github.com/pytorch/pytorch/pull/60313 Reviewed By: mruberry Differential Revision: D29463487 Pulled By: malfet fbshipit-source-id: d0903fbf12d24ba95a42c8b7589714fdb63353ed
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e2b42c6f52
commit
ce232e7847
@ -11,9 +11,9 @@ from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA,
|
||||
skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA, OpDTypes,
|
||||
skipCPUIfNoMkl, deviceCountAtLeast, onlyCUDA, OpDTypes,
|
||||
skipIf)
|
||||
from torch.testing._internal.common_methods_invocations import spectral_funcs
|
||||
from torch.testing._internal.common_methods_invocations import spectral_funcs, SpectralFuncInfo
|
||||
|
||||
from setuptools import distutils
|
||||
from typing import Optional, List
|
||||
@ -94,11 +94,37 @@ def _stft_reference(x, hop_length, window):
|
||||
X[:, m] = torch.fft.fft(slc * window)
|
||||
return X
|
||||
|
||||
|
||||
# Tests of functions related to Fourier analysis in the torch.fft namespace
|
||||
class TestFFT(TestCase):
|
||||
exact_dtype = True
|
||||
|
||||
# rocFFT requires/assumes that the input to hipfftExecC2R or hipfftExecZ2D
|
||||
# is of the form that is a valid output from a real to complex transform
|
||||
# (i.e. it cannot be a set of random numbers)
|
||||
# So for ROCm, call np.fft.rfftn and use its output as the input
|
||||
# for testing ops that call hipfftExecC2R
|
||||
def _generate_valid_rocfft_input(self, input, op):
    """Return an input that is safe to feed to rocFFT complex-to-real ops.

    rocFFT assumes that the input to hipfftExecC2R / hipfftExecZ2D is a
    valid output of a real-to-complex transform, so random complex data
    is not acceptable.  For ops that may reach those kernels, the input is
    replaced by ``np.fft.rfftn`` applied to its real part.

    Args:
        input: tensor originally generated for the test.
        op: either a ``SpectralFuncInfo`` or a plain callable that exposes
            ``__name__``.

    Returns:
        ``input`` itself when the op is filtered out below; otherwise a
        new tensor holding the ``rfftn`` output, placed on ``input``'s
        device.
    """
    # check if op can invoke hipfftExecC2R or hipfftExecZ2D
    if isinstance(op, SpectralFuncInfo):
        supported_dtypes = op.supported_dtypes("")
        # Only ops that accept both complex and real input are rewritten.
        if not all(ctype in supported_dtypes
                   for ctype in (torch.cfloat, torch.double)):
            return input
    elif op.__name__ in ["fft_rfft2"]:
        return input

    # if input is complex use the real part
    if torch.is_complex(input):
        np_input_real = input.real.cpu().numpy()
    else:
        np_input_real = input.cpu().numpy()

    # generate Hermitian symmetric input using rfftn
    rfft_output = np.fft.rfftn(np_input_real)

    # torch.from_numpy always yields a CPU tensor; move it back to the
    # original device so the op under test still runs there.
    return torch.from_numpy(rfft_output).to(input.device)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@ops([op for op in spectral_funcs if not op.ndimensional])
|
||||
def test_reference_1d(self, device, dtype, op):
|
||||
@ -133,12 +159,14 @@ class TestFFT(TestCase):
|
||||
input = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if torch.version.hip is not None:
|
||||
input = self._generate_valid_rocfft_input(input, op)
|
||||
|
||||
expected = op.ref(input.cpu().numpy(), *args)
|
||||
exact_dtype = dtype in (torch.double, torch.complex128)
|
||||
actual = op(input, *args)
|
||||
self.assertEqual(actual, expected, exact_dtype=exact_dtype)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoMkl
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
|
||||
@ -248,7 +276,6 @@ class TestFFT(TestCase):
|
||||
op(x)
|
||||
|
||||
# nd-fft tests
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
|
||||
@ops([op for op in spectral_funcs if op.ndimensional])
|
||||
@ -273,13 +300,16 @@ class TestFFT(TestCase):
|
||||
for input_ndim, s, dim in transform_desc:
|
||||
shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
|
||||
input = torch.randn(*shape, device=device, dtype=dtype)
|
||||
|
||||
if torch.version.hip is not None:
|
||||
input = self._generate_valid_rocfft_input(input, op)
|
||||
|
||||
for norm in norm_modes:
|
||||
expected = op.ref(input.cpu().numpy(), s, dim, norm)
|
||||
exact_dtype = dtype in (torch.double, torch.complex128)
|
||||
actual = op(input, s, dim, norm)
|
||||
self.assertEqual(actual, expected, exact_dtype=exact_dtype)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoMkl
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
|
||||
@ -346,7 +376,6 @@ class TestFFT(TestCase):
|
||||
# so don't require exhaustive testing.
|
||||
|
||||
@skipCPUIfNoMkl
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.double, torch.complex128)
|
||||
def test_fft2_numpy(self, device, dtype):
|
||||
@ -375,21 +404,25 @@ class TestFFT(TestCase):
|
||||
|
||||
torch_fns = (torch_fn, torch.jit.script(fn))
|
||||
|
||||
if torch.version.hip is not None:
|
||||
valid_input = self._generate_valid_rocfft_input(input, torch_fn)
|
||||
else:
|
||||
valid_input = input
|
||||
|
||||
# Once with dim defaulted
|
||||
input_np = input.cpu().numpy()
|
||||
input_np = valid_input.cpu().numpy()
|
||||
expected = numpy_fn(input_np, s, norm=norm)
|
||||
for fn in torch_fns:
|
||||
actual = fn(input, s, norm=norm)
|
||||
actual = fn(valid_input, s, norm=norm)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
# Once with explicit dims
|
||||
dim = (1, 0)
|
||||
expected = numpy_fn(input.cpu(), s, dim, norm)
|
||||
expected = numpy_fn(valid_input.cpu(), s, dim, norm)
|
||||
for fn in torch_fns:
|
||||
actual = fn(input, s, dim, norm)
|
||||
actual = fn(valid_input, s, dim, norm)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoMkl
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.float, torch.complex64)
|
||||
@ -596,7 +629,6 @@ class TestFFT(TestCase):
|
||||
_test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
|
||||
_test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoMkl
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.double)
|
||||
@ -742,8 +774,6 @@ class TestFFT(TestCase):
|
||||
_test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
|
||||
_test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
|
||||
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@ -786,7 +816,6 @@ class TestFFT(TestCase):
|
||||
length=x.size(-1), **common_kwargs)
|
||||
self.assertEqual(x_roundtrip, x)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@ -828,7 +857,6 @@ class TestFFT(TestCase):
|
||||
self.assertEqual(x_roundtrip, x)
|
||||
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.cdouble)
|
||||
def test_complex_stft_definition(self, device, dtype):
|
||||
@ -848,7 +876,6 @@ class TestFFT(TestCase):
|
||||
actual = torch.stft(*args, window=window, center=False)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.cdouble)
|
||||
@ -883,7 +910,6 @@ class TestFFT(TestCase):
|
||||
center=center, normalized=normalized)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.cdouble)
|
||||
def test_complex_istft_real_equiv(self, device, dtype):
|
||||
@ -959,7 +985,6 @@ class TestFFT(TestCase):
|
||||
_ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1))
|
||||
self.assertEqual(half_spectrum, half_spectrum_copy)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.double)
|
||||
@ -973,7 +998,6 @@ class TestFFT(TestCase):
|
||||
_test(torch.ones(4, dtype=dtype, device=device), 4, 4)
|
||||
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.double)
|
||||
@ -1076,7 +1100,6 @@ class TestFFT(TestCase):
|
||||
self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2)
|
||||
self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.double)
|
||||
@ -1111,7 +1134,6 @@ class TestFFT(TestCase):
|
||||
_test(amplitude=80, L=9, n=6)
|
||||
_test(amplitude=99, L=10, n=7)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@dtypes(torch.double)
|
||||
@ -1178,7 +1200,6 @@ class TestFFT(TestCase):
|
||||
for data_size, kwargs in patterns:
|
||||
_test(data_size, kwargs)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
def test_batch_istft(self, device):
|
||||
|
||||
Reference in New Issue
Block a user