diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index 5a444376cc8f..93a23ec6a730 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -139,6 +139,14 @@ bool CUDAHooks::hasCuSOLVER() const {
 #endif
 }
 
+bool CUDAHooks::hasROCM() const {
+  // Currently, this is the same as `compiledWithMIOpen`.
+  // But in the future, if there are ROCm builds without MIOpen,
+  // then `hasROCM` should return true while `compiledWithMIOpen`
+  // should return false.
+  return AT_ROCM_ENABLED();
+}
+
 #if defined(USE_DIRECT_NVRTC)
 static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
   return std::make_pair(nullptr, at::cuda::load_nvrtc());
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index a0d175df27c0..1c61aa709b97 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -27,6 +27,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasMAGMA() const override;
   bool hasCuDNN() const override;
   bool hasCuSOLVER() const override;
+  bool hasROCM() const override;
   const at::cuda::NVRTC& nvrtc() const override;
   int64_t current_device() const override;
   bool hasPrimaryContext(int64_t device_index) const override;
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index 7a55740b7914..1303b9f8c8bf 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -107,6 +107,10 @@ struct TORCH_API CUDAHooksInterface {
     return false;
   }
 
+  virtual bool hasROCM() const {
+    return false;
+  }
+
   virtual const at::cuda::NVRTC& nvrtc() const {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index af000cc70d9f..9c0ebed7551a 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -19,7 +19,7 @@ namespace {
 // * Integers are promoted to the default floating type
 // * If require_complex=True, all types are promoted to complex
 // * Raises an error for half-precision dtypes to allow future support
-ScalarType promote_type_fft(ScalarType type, bool require_complex) {
+ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device) {
   if (at::isComplexType(type)) {
     return type;
   }
@@ -28,7 +28,11 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
     type = c10::typeMetaToScalarType(c10::get_default_dtype());
   }
 
-  TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  if (device.is_cuda() && !at::detail::getCUDAHooks().hasROCM()) {
+    TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  } else {
+    TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  }
 
   if (!require_complex) {
     return type;
@@ -36,6 +40,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
 
   // Promote to complex
   switch (type) {
+    case kHalf: return kComplexHalf;
     case kFloat: return kComplexFloat;
     case kDouble: return kComplexDouble;
     default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
@@ -45,7 +50,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
 // Promote a tensor's dtype according to promote_type_fft
 Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
   auto cur_type = t.scalar_type();
-  auto new_type = promote_type_fft(cur_type, require_complex);
+  auto new_type = promote_type_fft(cur_type, require_complex, t.device());
   return (cur_type == new_type) ? t : t.to(new_type);
 }
diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu
index df51fe46afea..2f5c13006578 100644
--- a/aten/src/ATen/native/cuda/SpectralOps.cu
+++ b/aten/src/ATen/native/cuda/SpectralOps.cu
@@ -106,17 +106,17 @@ void _fft_fill_with_conjugate_symmetry_cuda_(
       signal_half_sizes, out_strides, mirror_dims, element_size);
 
   const auto numel = c10::multiply_integers(signal_half_sizes);
-  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
-    using namespace cuda::detail;
-    _fft_conjugate_copy_kernel<<<
-      GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
-        numel,
-        static_cast<scalar_t*>(out_data),
-        static_cast<const scalar_t*>(in_data),
-        input_offset_calculator,
-        output_offset_calculator);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  });
+  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
+    using namespace cuda::detail;
+    _fft_conjugate_copy_kernel<<<
+      GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+        numel,
+        static_cast<scalar_t*>(out_data),
+        static_cast<const scalar_t*>(in_data),
+        input_offset_calculator,
+        output_offset_calculator);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
 }
 
 REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);
diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py
index 03324d1c7ee2..b4f37cc1558e 100644
--- a/test/test_spectral_ops.py
+++ b/test/test_spectral_ops.py
@@ -10,12 +10,14 @@ import doctest
 import inspect
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
+    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
+     make_tensor)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
-     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
+     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
 from torch.testing._internal.common_methods_invocations import (
     spectral_funcs, SpectralFuncType)
+from torch.testing._internal.common_cuda import SM53OrLater
 from setuptools import distutils
 from typing import Optional, List
@@ -110,6 +112,20 @@ def _stft_reference(x, hop_length, window):
         X[:, m] = torch.fft.fft(slc * window)
     return X
 
+
+def skip_helper_for_fft(device, dtype):
+    device_type = torch.device(device).type
+    if dtype not in (torch.half, torch.complex32):
+        return
+
+    if device_type == 'cpu':
+        raise unittest.SkipTest("half and complex32 are not supported on CPU")
+    if TEST_WITH_ROCM:
+        raise unittest.SkipTest("half and complex32 are not supported on ROCM")
+    if not SM53OrLater:
+        raise unittest.SkipTest("half and complex32 are only supported on CUDA devices with SM53 or later")
+
+
 # Tests of functions related to Fourier analysis in the torch.fft namespace
 class TestFFT(TestCase):
     exact_dtype = True
@@ -157,20 +173,39 @@ class TestFFT(TestCase):
     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+        torch.chalf : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
     def test_fft_round_trip(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
         # Test that round trip through ifft(fft(x)) is the identity
-        test_args = list(product(
-            # input
-
(torch.randn(67, device=device, dtype=dtype), - torch.randn(80, device=device, dtype=dtype), - torch.randn(12, 14, device=device, dtype=dtype), - torch.randn(9, 6, 3, device=device, dtype=dtype)), - # dim - (-1, 0), - # norm - (None, "forward", "backward", "ortho") - )) + if dtype not in (torch.half, torch.complex32): + test_args = list(product( + # input + (torch.randn(67, device=device, dtype=dtype), + torch.randn(80, device=device, dtype=dtype), + torch.randn(12, 14, device=device, dtype=dtype), + torch.randn(9, 6, 3, device=device, dtype=dtype)), + # dim + (-1, 0), + # norm + (None, "forward", "backward", "ortho") + )) + else: + # cuFFT supports powers of 2 for half and complex half precision + test_args = list(product( + # input + (torch.randn(64, device=device, dtype=dtype), + torch.randn(128, device=device, dtype=dtype), + torch.randn(4, 16, device=device, dtype=dtype), + torch.randn(8, 6, 2, device=device, dtype=dtype)), + # dim + (-1, 0), + # norm + (None, "forward", "backward", "ortho") + )) fft_functions = [(torch.fft.fft, torch.fft.ifft)] # Real-only functions @@ -189,13 +224,17 @@ class TestFFT(TestCase): } y = backward(forward(x, **kwargs), **kwargs) + if x.dtype is torch.half and y.dtype is torch.complex32: + # Since type promotion currently doesn't work with complex32 + # manually promote `x` to complex32 + x = x.to(torch.complex32) # For real input, ifft(fft(x)) will convert to complex self.assertEqual(x, y, exact_dtype=( forward != torch.fft.fft or x.is_complex())) # Note: NumPy will throw a ValueError for an empty input @onlyNativeDeviceTypes - @ops(spectral_funcs, allowed_dtypes=(torch.float, torch.cfloat)) + @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat)) def test_empty_fft(self, device, dtype, op): t = torch.empty(1, 0, device=device, dtype=dtype) match = r"Invalid number of data points \([-\d]*\) specified" @@ -228,8 +267,11 @@ class TestFFT(TestCase): @skipCPUIfNoFFT @onlyNativeDeviceTypes - @dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128) + @dtypes(torch.int8, torch.half, torch.float, torch.double, + torch.complex32, torch.complex64, torch.complex128) def test_fft_type_promotion(self, device, dtype): + skip_helper_for_fft(device, dtype) + if dtype.is_complex or dtype.is_floating_point: t = torch.randn(64, device=device, dtype=dtype) else: @@ -237,8 +279,10 @@ class TestFFT(TestCase): PROMOTION_MAP = { torch.int8: torch.complex64, + torch.half: torch.complex32, torch.float: torch.complex64, torch.double: torch.complex128, + torch.complex32: torch.complex32, torch.complex64: torch.complex64, torch.complex128: torch.complex128, } @@ -247,17 +291,27 @@ class TestFFT(TestCase): PROMOTION_MAP_C2R = { torch.int8: torch.float, + torch.half: torch.half, torch.float: torch.float, torch.double: torch.double, + torch.complex32: torch.half, torch.complex64: torch.float, torch.complex128: torch.double, } - R = torch.fft.hfft(t) + if dtype in (torch.half, torch.complex32): + # cuFFT supports powers of 2 for half and complex half precision + # NOTE: With hfft and default args where output_size n=2*(input_size - 1), + # we make sure that logical fft size is a power of two. 
+ x = torch.randn(65, device=device, dtype=dtype) + R = torch.fft.hfft(x) + else: + R = torch.fft.hfft(t) self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype]) if not dtype.is_complex: PROMOTION_MAP_R2C = { torch.int8: torch.complex64, + torch.half: torch.complex32, torch.float: torch.complex64, torch.double: torch.complex128, } @@ -269,9 +323,32 @@ class TestFFT(TestCase): allowed_dtypes=[torch.half, torch.bfloat16]) def test_fft_half_and_bfloat16_errors(self, device, dtype, op): # TODO: Remove torch.half error when complex32 is fully implemented - x = torch.randn(8, 8, device=device).to(dtype) - with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "): - op(x) + sample = first_sample(self, op.sample_inputs(device, dtype)) + device_type = torch.device(device).type + if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM: + err_msg = "Unsupported dtype " + elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater: + err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53" + else: + err_msg = "Unsupported dtype " + with self.assertRaisesRegex(RuntimeError, err_msg): + op(sample.input, *sample.args, **sample.kwargs) + + @onlyNativeDeviceTypes + @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf)) + def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): + t = make_tensor(13, 13, device=device, dtype=dtype) + err_msg = "cuFFT only supports dimensions whose sizes are powers of two" + with self.assertRaisesRegex(RuntimeError, err_msg): + op(t) + + if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD): + kwargs = {'s': (12, 12)} + else: + kwargs = {'n': 12} + + with self.assertRaisesRegex(RuntimeError, err_msg): + op(t, **kwargs) # nd-fft tests @onlyNativeDeviceTypes @@ -308,8 +385,15 @@ class TestFFT(TestCase): @skipCPUIfNoFFT @onlyNativeDeviceTypes - @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) + @toleranceOverride({ + torch.half : tol(1e-2, 1e-2), + torch.chalf : tol(1e-2, 1e-2), + }) + @dtypes(torch.half, torch.float, torch.double, + torch.complex32, torch.complex64, torch.complex128) def test_fftn_round_trip(self, device, dtype): + skip_helper_for_fft(device, dtype) + norm_modes = (None, "forward", "backward", "ortho") # input_ndim, dim @@ -331,7 +415,11 @@ class TestFFT(TestCase): (torch.fft.ihfftn, torch.fft.hfftn)] for input_ndim, dim in transform_desc: - shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) + if dtype in (torch.half, torch.complex32): + # cuFFT supports powers of 2 for half and complex half precision + shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim) + else: + shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) x = torch.randn(*shape, device=device, dtype=dtype) for (forward, backward), norm in product(fft_functions, norm_modes): @@ -343,8 +431,13 @@ class TestFFT(TestCase): kwargs = {'s': s, 'dim': dim, 'norm': norm} y = backward(forward(x, **kwargs), **kwargs) # For real input, ifftn(fftn(x)) will convert to complex - self.assertEqual(x, y, exact_dtype=( - forward != torch.fft.fftn or x.is_complex())) + if x.dtype is torch.half and y.dtype is torch.chalf: + # Since type promotion currently doesn't work with complex32 + # manually promote `x` to complex32 + self.assertEqual(x.to(torch.chalf), y) + else: + self.assertEqual(x, y, exact_dtype=( + forward != torch.fft.fftn or x.is_complex())) @onlyNativeDeviceTypes @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], @@ 
-369,8 +462,13 @@ class TestFFT(TestCase): @skipCPUIfNoFFT @onlyNativeDeviceTypes - @dtypes(torch.float, torch.double) + @toleranceOverride({ + torch.half : tol(1e-2, 1e-2), + }) + @dtypes(torch.half, torch.float, torch.double) def test_hfftn(self, device, dtype): + skip_helper_for_fft(device, dtype) + # input_ndim, dim transform_desc = [ *product(range(2, 5), (None, (0,), (0, -1))), @@ -383,8 +481,10 @@ class TestFFT(TestCase): for input_ndim, dim in transform_desc: actual_dims = list(range(input_ndim)) if dim is None else dim - - shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) + if dtype is torch.half: + shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) + else: + shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) expect = torch.randn(*shape, device=device, dtype=dtype) input = torch.fft.ifftn(expect, dim=dim, norm="ortho") @@ -401,8 +501,13 @@ class TestFFT(TestCase): @skipCPUIfNoFFT @onlyNativeDeviceTypes - @dtypes(torch.float, torch.double) + @toleranceOverride({ + torch.half : tol(1e-2, 1e-2), + }) + @dtypes(torch.half, torch.float, torch.double) def test_ihfftn(self, device, dtype): + skip_helper_for_fft(device, dtype) + # input_ndim, dim transform_desc = [ *product(range(2, 5), (None, (0,), (0, -1))), @@ -414,7 +519,11 @@ class TestFFT(TestCase): ] for input_ndim, dim in transform_desc: - shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) + if dtype is torch.half: + shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) + else: + shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) + input = torch.randn(*shape, device=device, dtype=dtype) expect = torch.fft.ifftn(input, dim=dim, norm="ortho") diff --git a/torch/fft/__init__.py b/torch/fft/__init__.py index 6ad15de6dfec..a9a6e3e84650 100644 --- a/torch/fft/__init__.py +++ b/torch/fft/__init__.py @@ -20,7 +20,6 @@ fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor Computes the one dimensional discrete Fourier transform of :attr:`input`. Note: - The Fourier domain representation of any real signal satisfies the Hermitian property: `X[i] = conj(X[-i])`. This function always returns both the positive and negative frequency terms even though, for real inputs, the @@ -28,6 +27,10 @@ Note: more compact one-sided representation where only the positive frequencies are returned. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + Args: input (Tensor): the input tensor n (int, optional): Signal length. If given, the input will either be zero-padded @@ -68,6 +71,10 @@ ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor Computes the one dimensional inverse discrete Fourier transform of :attr:`input`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + Args: input (Tensor): the input tensor n (int, optional): Signal length. If given, the input will either be zero-padded @@ -111,6 +118,10 @@ Note: :func:`~torch.fft.rfft2` returns the more compact one-sided representation where only the positive frequencies of the last dimension are returned. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. 
+ Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -157,6 +168,10 @@ ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`. Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -203,7 +218,6 @@ fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor Computes the N dimensional discrete Fourier transform of :attr:`input`. Note: - The Fourier domain representation of any real signal satisfies the Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This function always returns all positive and negative frequency terms even @@ -211,6 +225,10 @@ Note: :func:`~torch.fft.rfftn` returns the more compact one-sided representation where only the positive frequencies of the last dimension are returned. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -256,6 +274,10 @@ ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor Computes the N dimensional inverse discrete Fourier transform of :attr:`input`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -305,6 +327,10 @@ The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so the output contains only the positive frequencies below the Nyquist frequency. To compute the full output, use :func:`~torch.fft.fft` +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + Args: input (Tensor): the real input tensor n (int, optional): Signal length. If given, the input will either be zero-padded @@ -367,6 +393,12 @@ Note: signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal length :attr:`n`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + With default arguments, size of the transformed dimension should be (2^n + 1) as argument + `n` defaults to even output size = 2 * (transformed_dim_size - 1) + Args: input (Tensor): the input tensor representing a half-Hermitian signal n (int, optional): Output signal length. This determines the length of the @@ -424,6 +456,10 @@ so the full :func:`~torch.fft.fft2` output contains redundant information. :func:`~torch.fft.rfft2` instead omits the negative frequencies in the last dimension. +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. 
+ Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -496,6 +532,12 @@ Note: signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal shape :attr:`s`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -557,6 +599,10 @@ The FFT of a real signal is Hermitian-symmetric, :func:`~torch.fft.rfftn` instead omits the negative frequencies in the last dimension. +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -628,6 +674,12 @@ Note: signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal shape :attr:`s`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -709,6 +761,12 @@ Note: signal is assumed to be even length and odd signals will not round-trip properly. So, it is recommended to always pass the signal length :attr:`n`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + With default arguments, size of the transformed dimension should be (2^n + 1) as argument + `n` defaults to even output size = 2 * (transformed_dim_size - 1) + Args: input (Tensor): the input tensor representing a half-Hermitian signal n (int, optional): Output signal length. This determines the length of the @@ -771,6 +829,10 @@ The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``. positive frequencies below the Nyquist frequency are included. To compute the full output, use :func:`~torch.fft.ifft`. +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + Args: input (Tensor): the real input tensor n (int, optional): Signal length. If given, the input will either be zero-padded @@ -818,6 +880,12 @@ transforms the last two dimensions by default. :attr:`input` is interpreted as a one-sided Hermitian signal in the time domain. By the Hermitian property, the Fourier transform will be real-valued. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. 
+ With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -878,6 +946,10 @@ Computes the 2-dimensional inverse discrete Fourier transform of real :attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the two last dimensions by default. +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -960,6 +1032,12 @@ Note: signal is assumed to be even length and odd signals will not round-trip properly. It is recommended to always pass the signal shape :attr:`s`. +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. @@ -1025,6 +1103,10 @@ this in the one-sided form where only the positive frequencies below the Nyquist frequency are included in the last signal dimension. To compute the full output, use :func:`~torch.fft.ifftn`. +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + Args: input (Tensor): the input tensor s (Tuple[int], optional): Signal size in the transformed dimensions. 
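
Illustrative usage sketch (not part of the diff): the half/complex-half notes added to the torch.fft docstrings above translate into calls like the following. This assumes a CUDA device with compute capability SM53 or later, and sizes are chosen as powers of two per the documented cuFFT restriction.

import torch

# Half inputs promote to complex32 (torch.chalf) on CUDA, per the promotion rules above.
x = torch.randn(64, device='cuda', dtype=torch.half)   # power-of-2 length
X = torch.fft.fft(x)                                    # dtype: torch.complex32
y = torch.fft.ifft(X)                                   # round trip back to the signal

# hfft with default arguments produces n = 2 * (input_size - 1) outputs, so an input
# length of 2**k + 1 (here 65) keeps the logical FFT size a power of two (128).
h = torch.randn(65, device='cuda', dtype=torch.half)
H = torch.fft.hfft(h)

# Non power-of-2 sizes are rejected for these dtypes with
# "cuFFT only supports dimensions whose sizes are powers of two".
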
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8ef8aad6d078..3c76c8fc3f36 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5810,15 +5810,33 @@ def np_unary_ufunc_integer_promotion_wrapper(fn): return wrapped_fn def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs): - nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device, - dtype=dtype, requires_grad=requires_grad) - oned_tensor = partial(make_tensor, (31,), device=device, - dtype=dtype, requires_grad=requires_grad) + is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half + if not is_fp16_or_chalf: + nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device, + dtype=dtype, requires_grad=requires_grad) + oned_tensor = partial(make_tensor, (31,), device=device, + dtype=dtype, requires_grad=requires_grad) + else: + # cuFFT supports powers of 2 for half and complex half precision + # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args + # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two + if self.name in ['fft.hfft', 'fft.irfft']: + shapes = ((2, 9, 9), (33,)) + elif self.name in ['fft.hfft2', 'fft.irfft2']: + shapes = ((2, 8, 9), (33,)) + elif self.name in ['fft.hfftn', 'fft.irfftn']: + shapes = ((2, 2, 33), (33,)) + else: + shapes = ((2, 8, 16), (32,)) + nd_tensor = partial(make_tensor, shapes[0], device=device, + dtype=dtype, requires_grad=requires_grad) + oned_tensor = partial(make_tensor, shapes[1], device=device, + dtype=dtype, requires_grad=requires_grad) if self.ndimensional == SpectralFuncType.ND: return [ SampleInput(nd_tensor(), - kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')), + kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')), SampleInput(nd_tensor(), kwargs=dict(norm='ortho')), SampleInput(nd_tensor(), @@ -5832,11 +5850,11 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg elif self.ndimensional == SpectralFuncType.TwoD: return [ SampleInput(nd_tensor(), - kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')), + kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')), SampleInput(nd_tensor(), kwargs=dict(norm='ortho')), SampleInput(nd_tensor(), - kwargs=dict(s=(6, 8))), + kwargs=dict(s=(6, 8) if not is_fp16_or_chalf else (4, 8))), SampleInput(nd_tensor(), kwargs=dict(dim=0)), SampleInput(nd_tensor(), @@ -5847,11 +5865,12 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg else: return [ SampleInput(nd_tensor(), - kwargs=dict(n=10, dim=1, norm='ortho')), + kwargs=dict(n=10 if not is_fp16_or_chalf else 8, dim=1, norm='ortho')), SampleInput(nd_tensor(), kwargs=dict(norm='ortho')), SampleInput(nd_tensor(), - kwargs=dict(n=7)), + kwargs=dict(n=7 if not is_fp16_or_chalf else 8) + ), SampleInput(oned_tensor()), *(SampleInput(nd_tensor(), @@ -5887,6 +5906,8 @@ class SpectralFuncInfo(OpInfo): decorators = list(decorators) if decorators is not None else [] decorators += [ skipCPUIfNoFFT, + DecorateInfo(toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}), + "TestCommon", "test_complex_half_reference_testing") ] super().__init__(name=name, @@ -10628,6 +10649,10 @@ op_db: List[OpInfo] = [ ref=np.fft.fft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half 
Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), @@ -10636,6 +10661,10 @@ op_db: List[OpInfo] = [ ref=np.fft.fft2, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=[precisionOverride( @@ -10646,6 +10675,10 @@ op_db: List[OpInfo] = [ ref=np.fft.fftn, ndimensional=SpectralFuncType.ND, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=[precisionOverride( @@ -10656,6 +10689,10 @@ op_db: List[OpInfo] = [ ref=np.fft.hfft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False), @@ -10664,6 +10701,10 @@ op_db: List[OpInfo] = [ ref=scipy.fft.hfft2 if has_scipy_fft else None, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False, @@ -10677,6 +10718,10 @@ op_db: List[OpInfo] = [ ref=scipy.fft.hfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.ND, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False, @@ -10690,6 +10735,9 @@ op_db: List[OpInfo] = [ ref=np.fft.rfft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, @@ -10701,6 +10749,9 @@ op_db: List[OpInfo] = [ ref=np.fft.rfft2, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports 
Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, @@ -10713,6 +10764,9 @@ op_db: List[OpInfo] = [ ref=np.fft.rfftn, ndimensional=SpectralFuncType.ND, dtypes=all_types_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, @@ -10726,7 +10780,11 @@ op_db: List[OpInfo] = [ ndimensional=SpectralFuncType.OneD, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bool)), + dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),), SpectralFuncInfo('fft.ifft2', aten_name='fft_ifft2', ref=np.fft.ifft2, @@ -10734,6 +10792,10 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), @@ -10746,6 +10808,10 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), @@ -10758,6 +10824,9 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), skips=( ), check_batched_grad=False), @@ -10768,6 +10837,9 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), check_batched_grad=False, check_batched_gradgrad=False, decorators=( @@ -10784,6 +10856,9 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archss + dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), 
check_batched_grad=False, check_batched_gradgrad=False, decorators=[ @@ -10802,6 +10877,10 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), check_batched_gradgrad=False), SpectralFuncInfo('fft.irfft2', aten_name='fft_irfft2', @@ -10810,6 +10889,10 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), check_batched_gradgrad=False, decorators=[ DecorateInfo( @@ -10823,6 +10906,10 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bool), + # rocFFT doesn't support Half/Complex Half Precision FFT + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), check_batched_gradgrad=False, decorators=[ DecorateInfo(
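
Illustrative sketch (not part of the diff): every dtypesIfCUDA entry above uses the same conditional-splat idiom. A minimal expansion of what it evaluates to, assuming the testing-internal names already imported in common_methods_invocations.py:

# Extra dtypes are only registered when neither ROCm nor a pre-SM53 GPU is in use.
extra = () if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)
dtypes_if_cuda = all_types_and_complex_and(torch.bool, *extra)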