[ROCM] enable fft tests (#60313)

Summary:
This PR enables fft tests on ROCM. It contains a function that generates a valid input for fft tests that call hipfftExecC2R or hipfftExecZ2D. With this helper function we are able to fix a number of fft tests. This brings a close to the series of fft PRs enabling fft tests on ROCM.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60313

Reviewed By: mruberry

Differential Revision: D29463487

Pulled By: malfet

fbshipit-source-id: d0903fbf12d24ba95a42c8b7589714fdb63353ed
This commit is contained in:
Michael Melesse
2021-06-29 22:42:18 -07:00
committed by Facebook GitHub Bot
parent e2b42c6f52
commit ce232e7847

View File

@ -11,9 +11,9 @@ from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA,
skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA, OpDTypes,
skipCPUIfNoMkl, deviceCountAtLeast, onlyCUDA, OpDTypes,
skipIf)
from torch.testing._internal.common_methods_invocations import spectral_funcs
from torch.testing._internal.common_methods_invocations import spectral_funcs, SpectralFuncInfo
from setuptools import distutils
from typing import Optional, List
@ -94,11 +94,37 @@ def _stft_reference(x, hop_length, window):
X[:, m] = torch.fft.fft(slc * window)
return X
# Tests of functions related to Fourier analysis in the torch.fft namespace
class TestFFT(TestCase):
exact_dtype = True
# rocFFT requires/assumes that the input to hipfftExecC2R or hipfftExecZ2D
# is of the form that is a valid output from a real to complex transform
# (i.e. it cannot be a set of random numbers)
# So for ROCm, call np.fft.rfftn and use its output as the input
# for testing ops that call hipfftExecC2R
def _generate_valid_rocfft_input(self, input, op):
# check if op can invoke hipfftExecC2R or hipfftExecZ2D
if type(op) == SpectralFuncInfo:
supported_ops = op.supported_dtypes("")
if not all(ctype in supported_ops for ctype in [torch.cfloat, torch.double]):
return input
else:
if op.__name__ in ["fft_rfft2"]:
return input
# if input is complex use the real part
if torch.is_complex(input):
np_input_real = input.real.cpu().numpy()
else:
np_input_real = input.cpu().numpy()
# generate Hermitian symmetric input using rfftn
rfft_output = np.fft.rfftn(np_input_real)
return torch.from_numpy(rfft_output)
@onlyOnCPUAndCUDA
@ops([op for op in spectral_funcs if not op.ndimensional])
def test_reference_1d(self, device, dtype, op):
@ -133,12 +159,14 @@ class TestFFT(TestCase):
input = args[0]
args = args[1:]
if torch.version.hip is not None:
input = self._generate_valid_rocfft_input(input, op)
expected = op.ref(input.cpu().numpy(), *args)
exact_dtype = dtype in (torch.double, torch.complex128)
actual = op(input, *args)
self.assertEqual(actual, expected, exact_dtype=exact_dtype)
@skipCUDAIfRocm
@skipCPUIfNoMkl
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@ -248,7 +276,6 @@ class TestFFT(TestCase):
op(x)
# nd-fft tests
@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional])
@ -273,13 +300,16 @@ class TestFFT(TestCase):
for input_ndim, s, dim in transform_desc:
shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
input = torch.randn(*shape, device=device, dtype=dtype)
if torch.version.hip is not None:
input = self._generate_valid_rocfft_input(input, op)
for norm in norm_modes:
expected = op.ref(input.cpu().numpy(), s, dim, norm)
exact_dtype = dtype in (torch.double, torch.complex128)
actual = op(input, s, dim, norm)
self.assertEqual(actual, expected, exact_dtype=exact_dtype)
@skipCUDAIfRocm
@skipCPUIfNoMkl
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@ -346,7 +376,6 @@ class TestFFT(TestCase):
# so don't require exhaustive testing.
@skipCPUIfNoMkl
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@dtypes(torch.double, torch.complex128)
def test_fft2_numpy(self, device, dtype):
@ -375,21 +404,25 @@ class TestFFT(TestCase):
torch_fns = (torch_fn, torch.jit.script(fn))
if torch.version.hip is not None:
valid_input = self._generate_valid_rocfft_input(input, torch_fn)
else:
valid_input = input
# Once with dim defaulted
input_np = input.cpu().numpy()
input_np = valid_input.cpu().numpy()
expected = numpy_fn(input_np, s, norm=norm)
for fn in torch_fns:
actual = fn(input, s, norm=norm)
actual = fn(valid_input, s, norm=norm)
self.assertEqual(actual, expected)
# Once with explicit dims
dim = (1, 0)
expected = numpy_fn(input.cpu(), s, dim, norm)
expected = numpy_fn(valid_input.cpu(), s, dim, norm)
for fn in torch_fns:
actual = fn(input, s, dim, norm)
actual = fn(valid_input, s, dim, norm)
self.assertEqual(actual, expected)
@skipCUDAIfRocm
@skipCPUIfNoMkl
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.complex64)
@ -596,7 +629,6 @@ class TestFFT(TestCase):
_test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
_test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])
@skipCUDAIfRocm
@skipCPUIfNoMkl
@onlyOnCPUAndCUDA
@dtypes(torch.double)
@ -742,8 +774,6 @@ class TestFFT(TestCase):
_test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
_test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.double, torch.cdouble)
@ -786,7 +816,6 @@ class TestFFT(TestCase):
length=x.size(-1), **common_kwargs)
self.assertEqual(x_roundtrip, x)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.double, torch.cdouble)
@ -828,7 +857,6 @@ class TestFFT(TestCase):
self.assertEqual(x_roundtrip, x)
@skipCUDAIfRocm
@skipCPUIfNoMkl
@dtypes(torch.cdouble)
def test_complex_stft_definition(self, device, dtype):
@ -848,7 +876,6 @@ class TestFFT(TestCase):
actual = torch.stft(*args, window=window, center=False)
self.assertEqual(actual, expected)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.cdouble)
@ -883,7 +910,6 @@ class TestFFT(TestCase):
center=center, normalized=normalized)
self.assertEqual(expected, actual)
@skipCUDAIfRocm
@skipCPUIfNoMkl
@dtypes(torch.cdouble)
def test_complex_istft_real_equiv(self, device, dtype):
@ -959,7 +985,6 @@ class TestFFT(TestCase):
_ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1))
self.assertEqual(half_spectrum, half_spectrum_copy)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.double)
@ -973,7 +998,6 @@ class TestFFT(TestCase):
_test(torch.ones(4, dtype=dtype, device=device), 4, 4)
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.double)
@ -1076,7 +1100,6 @@ class TestFFT(TestCase):
self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.double)
@ -1111,7 +1134,6 @@ class TestFFT(TestCase):
_test(amplitude=80, L=9, n=6)
_test(amplitude=99, L=10, n=7)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@dtypes(torch.double)
@ -1178,7 +1200,6 @@ class TestFFT(TestCase):
for data_size, kwargs in patterns:
_test(data_size, kwargs)
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
def test_batch_istft(self, device):