[complex32] fft support (cuda only) (#74857)

`half` and `complex32` support for `torch.fft.{fft, fft2, fftn, hfft, hfft2, hfftn, ifft, ifft2, ifftn, ihfft, ihfft2, ihfftn, irfft, irfft2, irfftn, rfft, rfft2, rfftn}`

* We only add support for `CUDA`, as `cuFFT` supports these precisions.
* We still error out on `CPU` and `ROCm`, as their respective backends don't support these precisions.

For `cuFFT`, the following constraints apply to these precisions:
* Minimum GPU architecture is SM_53
* Sizes are restricted to powers of two only
* Strides on the real part of real-to-complex and complex-to-real transforms are not supported
* More than one GPU is not supported
* Transforms spanning more than 4 billion elements are not supported

Ref: https://docs.nvidia.com/cuda/cufft/#half-precision-transforms
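
For illustration, a minimal usage sketch (not part of the diff; it assumes a CUDA device with SM_53 or later and power-of-two sizes):

```python
import torch

# Power-of-two size, half precision, on a CUDA device with compute capability >= 5.3.
x = torch.randn(64, device="cuda", dtype=torch.half)
X = torch.fft.fft(x)    # output dtype is torch.complex32 (torch.chalf)
y = torch.fft.ifft(X)   # round trip stays in complex32

# Compare in a wider dtype; half/complex32 transforms have reduced precision.
torch.testing.assert_close(y.to(torch.complex64), x.to(torch.complex64), rtol=1e-2, atol=1e-2)
```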

TODO:
* [x] Update docs about the restrictions
* [x] Check the correct way to check for a `hip` device (it seems `device.is_cuda()` is true for hip as well; see the sketch below). (Thanks @peterbell10)

Ref for second point in TODO: e424e7d214/aten/src/ATen/native/SpectralOps.cpp (L31)
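
As context for the second TODO item: `device.is_cuda()` is also true under HIP, so the device type alone cannot distinguish ROCm builds, which is why the diff below adds `CUDAHooks::hasROCM()` on the C++ side. A rough Python-level analogue (a sketch only, not what the diff uses) might be:

```python
import torch

def is_rocm_build() -> bool:
    # torch.version.hip is a version string on ROCm builds and None on CUDA builds.
    return torch.version.hip is not None
```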
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74857
Approved by: https://github.com/anjali411, https://github.com/peterbell10
Author: kshitij12345
Date: 2022-05-12 04:28:55 +00:00
Committed by: PyTorch MergeBot
Parent: b825e1d472
Commit: ada65fdd67
8 changed files with 351 additions and 55 deletions


@ -139,6 +139,14 @@ bool CUDAHooks::hasCuSOLVER() const {
#endif
}
bool CUDAHooks::hasROCM() const {
// Currently, this is the same as `compiledWithMIOpen`.
// But in the future, if there are ROCm builds without MIOpen,
// then `hasROCM` should return true while `compiledWithMIOpen`
// should return false.
return AT_ROCM_ENABLED();
}
#if defined(USE_DIRECT_NVRTC)
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
return std::make_pair(nullptr, at::cuda::load_nvrtc());


@ -27,6 +27,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasMAGMA() const override;
bool hasCuDNN() const override;
bool hasCuSOLVER() const override;
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
int64_t current_device() const override;
bool hasPrimaryContext(int64_t device_index) const override;


@ -107,6 +107,10 @@ struct TORCH_API CUDAHooksInterface {
return false;
}
virtual bool hasROCM() const {
return false;
}
virtual const at::cuda::NVRTC& nvrtc() const {
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}


@ -19,7 +19,7 @@ namespace {
// * Integers are promoted to the default floating type
// * If require_complex=True, all types are promoted to complex
// * Raises an error for half-precision dtypes to allow future support
ScalarType promote_type_fft(ScalarType type, bool require_complex) {
ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device) {
if (at::isComplexType(type)) {
return type;
}
@ -28,7 +28,11 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
type = c10::typeMetaToScalarType(c10::get_default_dtype());
}
TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
if (device.is_cuda() && !at::detail::getCUDAHooks().hasROCM()) {
TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type);
} else {
TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
}
if (!require_complex) {
return type;
@ -36,6 +40,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
// Promote to complex
switch (type) {
case kHalf: return kComplexHalf;
case kFloat: return kComplexFloat;
case kDouble: return kComplexDouble;
default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
@ -45,7 +50,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
// Promote a tensor's dtype according to promote_type_fft
Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
auto cur_type = t.scalar_type();
auto new_type = promote_type_fft(cur_type, require_complex);
auto new_type = promote_type_fft(cur_type, require_complex, t.device());
return (cur_type == new_type) ? t : t.to(new_type);
}
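
The effect of this `promote_type_fft` change is visible from Python; a quick sketch (assuming a CUDA device with SM_53 or later):

```python
import torch

t = torch.randn(64, device="cuda", dtype=torch.half)    # 64 is a power of two
print(torch.fft.fft(t).dtype)          # torch.complex32: half promotes to complex half
print(torch.fft.rfft(t).dtype)         # torch.complex32
print(torch.fft.fft(t.float()).dtype)  # torch.complex64: existing behaviour unchanged
```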


@ -106,17 +106,17 @@ void _fft_fill_with_conjugate_symmetry_cuda_(
signal_half_sizes, out_strides, mirror_dims, element_size);
const auto numel = c10::multiply_integers(signal_half_sizes);
AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
using namespace cuda::detail;
_fft_conjugate_copy_kernel<<<
GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
numel,
static_cast<scalar_t*>(out_data),
static_cast<const scalar_t*>(in_data),
input_offset_calculator,
output_offset_calculator);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
using namespace cuda::detail;
_fft_conjugate_copy_kernel<<<
GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
numel,
static_cast<scalar_t*>(out_data),
static_cast<const scalar_t*>(in_data),
input_offset_calculator,
output_offset_calculator);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);


@ -10,12 +10,14 @@ import doctest
import inspect
from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
make_tensor)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
from torch.testing._internal.common_methods_invocations import (
spectral_funcs, SpectralFuncType)
from torch.testing._internal.common_cuda import SM53OrLater
from setuptools import distutils
from typing import Optional, List
@ -110,6 +112,20 @@ def _stft_reference(x, hop_length, window):
X[:, m] = torch.fft.fft(slc * window)
return X
def skip_helper_for_fft(device, dtype):
device_type = torch.device(device).type
if dtype not in (torch.half, torch.complex32):
return
if device_type == 'cpu':
raise unittest.SkipTest("half and complex32 are not supported on CPU")
if TEST_WITH_ROCM:
raise unittest.SkipTest("half and complex32 are not supported on ROCM")
if not SM53OrLater:
raise unittest.SkipTest("half and complex32 are only supported on CUDA devices with SM53 or greater")
# Tests of functions related to Fourier analysis in the torch.fft namespace
class TestFFT(TestCase):
exact_dtype = True
@ -157,20 +173,39 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@toleranceOverride({
torch.half : tol(1e-2, 1e-2),
torch.chalf : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
def test_fft_round_trip(self, device, dtype):
skip_helper_for_fft(device, dtype)
# Test that round trip through ifft(fft(x)) is the identity
test_args = list(product(
# input
(torch.randn(67, device=device, dtype=dtype),
torch.randn(80, device=device, dtype=dtype),
torch.randn(12, 14, device=device, dtype=dtype),
torch.randn(9, 6, 3, device=device, dtype=dtype)),
# dim
(-1, 0),
# norm
(None, "forward", "backward", "ortho")
))
if dtype not in (torch.half, torch.complex32):
test_args = list(product(
# input
(torch.randn(67, device=device, dtype=dtype),
torch.randn(80, device=device, dtype=dtype),
torch.randn(12, 14, device=device, dtype=dtype),
torch.randn(9, 6, 3, device=device, dtype=dtype)),
# dim
(-1, 0),
# norm
(None, "forward", "backward", "ortho")
))
else:
# cuFFT supports powers of 2 for half and complex half precision
test_args = list(product(
# input
(torch.randn(64, device=device, dtype=dtype),
torch.randn(128, device=device, dtype=dtype),
torch.randn(4, 16, device=device, dtype=dtype),
torch.randn(8, 6, 2, device=device, dtype=dtype)),
# dim
(-1, 0),
# norm
(None, "forward", "backward", "ortho")
))
fft_functions = [(torch.fft.fft, torch.fft.ifft)]
# Real-only functions
@ -189,13 +224,17 @@ class TestFFT(TestCase):
}
y = backward(forward(x, **kwargs), **kwargs)
if x.dtype is torch.half and y.dtype is torch.complex32:
# Since type promotion currently doesn't work with complex32
# manually promote `x` to complex32
x = x.to(torch.complex32)
# For real input, ifft(fft(x)) will convert to complex
self.assertEqual(x, y, exact_dtype=(
forward != torch.fft.fft or x.is_complex()))
# Note: NumPy will throw a ValueError for an empty input
@onlyNativeDeviceTypes
@ops(spectral_funcs, allowed_dtypes=(torch.float, torch.cfloat))
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
def test_empty_fft(self, device, dtype, op):
t = torch.empty(1, 0, device=device, dtype=dtype)
match = r"Invalid number of data points \([-\d]*\) specified"
@ -228,8 +267,11 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128)
@dtypes(torch.int8, torch.half, torch.float, torch.double,
torch.complex32, torch.complex64, torch.complex128)
def test_fft_type_promotion(self, device, dtype):
skip_helper_for_fft(device, dtype)
if dtype.is_complex or dtype.is_floating_point:
t = torch.randn(64, device=device, dtype=dtype)
else:
@ -237,8 +279,10 @@ class TestFFT(TestCase):
PROMOTION_MAP = {
torch.int8: torch.complex64,
torch.half: torch.complex32,
torch.float: torch.complex64,
torch.double: torch.complex128,
torch.complex32: torch.complex32,
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
}
@ -247,17 +291,27 @@ class TestFFT(TestCase):
PROMOTION_MAP_C2R = {
torch.int8: torch.float,
torch.half: torch.half,
torch.float: torch.float,
torch.double: torch.double,
torch.complex32: torch.half,
torch.complex64: torch.float,
torch.complex128: torch.double,
}
R = torch.fft.hfft(t)
if dtype in (torch.half, torch.complex32):
# cuFFT supports powers of 2 for half and complex half precision
# NOTE: With hfft and default args where output_size n=2*(input_size - 1),
# we make sure that logical fft size is a power of two.
x = torch.randn(65, device=device, dtype=dtype)
R = torch.fft.hfft(x)
else:
R = torch.fft.hfft(t)
self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])
if not dtype.is_complex:
PROMOTION_MAP_R2C = {
torch.int8: torch.complex64,
torch.half: torch.complex32,
torch.float: torch.complex64,
torch.double: torch.complex128,
}
@ -269,9 +323,32 @@ class TestFFT(TestCase):
allowed_dtypes=[torch.half, torch.bfloat16])
def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
# TODO: Remove torch.half error when complex32 is fully implemented
x = torch.randn(8, 8, device=device).to(dtype)
with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
op(x)
sample = first_sample(self, op.sample_inputs(device, dtype))
device_type = torch.device(device).type
if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
err_msg = "Unsupported dtype "
elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
else:
err_msg = "Unsupported dtype "
with self.assertRaisesRegex(RuntimeError, err_msg):
op(sample.input, *sample.args, **sample.kwargs)
@onlyNativeDeviceTypes
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
t = make_tensor(13, 13, device=device, dtype=dtype)
err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
with self.assertRaisesRegex(RuntimeError, err_msg):
op(t)
if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
kwargs = {'s': (12, 12)}
else:
kwargs = {'n': 12}
with self.assertRaisesRegex(RuntimeError, err_msg):
op(t, **kwargs)
# nd-fft tests
@onlyNativeDeviceTypes
@ -308,8 +385,15 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@toleranceOverride({
torch.half : tol(1e-2, 1e-2),
torch.chalf : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double,
torch.complex32, torch.complex64, torch.complex128)
def test_fftn_round_trip(self, device, dtype):
skip_helper_for_fft(device, dtype)
norm_modes = (None, "forward", "backward", "ortho")
# input_ndim, dim
@ -331,7 +415,11 @@ class TestFFT(TestCase):
(torch.fft.ihfftn, torch.fft.hfftn)]
for input_ndim, dim in transform_desc:
shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
if dtype in (torch.half, torch.complex32):
# cuFFT supports powers of 2 for half and complex half precision
shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
else:
shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
x = torch.randn(*shape, device=device, dtype=dtype)
for (forward, backward), norm in product(fft_functions, norm_modes):
@ -343,8 +431,13 @@ class TestFFT(TestCase):
kwargs = {'s': s, 'dim': dim, 'norm': norm}
y = backward(forward(x, **kwargs), **kwargs)
# For real input, ifftn(fftn(x)) will convert to complex
self.assertEqual(x, y, exact_dtype=(
forward != torch.fft.fftn or x.is_complex()))
if x.dtype is torch.half and y.dtype is torch.chalf:
# Since type promotion currently doesn't work with complex32
# manually promote `x` to complex32
self.assertEqual(x.to(torch.chalf), y)
else:
self.assertEqual(x, y, exact_dtype=(
forward != torch.fft.fftn or x.is_complex()))
@onlyNativeDeviceTypes
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
@ -369,8 +462,13 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
@toleranceOverride({
torch.half : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double)
def test_hfftn(self, device, dtype):
skip_helper_for_fft(device, dtype)
# input_ndim, dim
transform_desc = [
*product(range(2, 5), (None, (0,), (0, -1))),
@ -383,8 +481,10 @@ class TestFFT(TestCase):
for input_ndim, dim in transform_desc:
actual_dims = list(range(input_ndim)) if dim is None else dim
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
if dtype is torch.half:
shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
else:
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
expect = torch.randn(*shape, device=device, dtype=dtype)
input = torch.fft.ifftn(expect, dim=dim, norm="ortho")
@ -401,8 +501,13 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
@toleranceOverride({
torch.half : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double)
def test_ihfftn(self, device, dtype):
skip_helper_for_fft(device, dtype)
# input_ndim, dim
transform_desc = [
*product(range(2, 5), (None, (0,), (0, -1))),
@ -414,7 +519,11 @@ class TestFFT(TestCase):
]
for input_ndim, dim in transform_desc:
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
if dtype is torch.half:
shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
else:
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
input = torch.randn(*shape, device=device, dtype=dtype)
expect = torch.fft.ifftn(input, dim=dim, norm="ortho")


@ -20,7 +20,6 @@ fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
Computes the one dimensional discrete Fourier transform of :attr:`input`.
Note:
The Fourier domain representation of any real signal satisfies the
Hermitian property: `X[i] = conj(X[-i])`. This function always returns both
the positive and negative frequency terms even though, for real inputs, the
@ -28,6 +27,10 @@ Note:
more compact one-sided representation where only the positive frequencies
are returned.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
n (int, optional): Signal length. If given, the input will either be zero-padded
@ -68,6 +71,10 @@ ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
n (int, optional): Signal length. If given, the input will either be zero-padded
@ -111,6 +118,10 @@ Note:
:func:`~torch.fft.rfft2` returns the more compact one-sided representation
where only the positive frequencies of the last dimension are returned.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -157,6 +168,10 @@ ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.
Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -203,7 +218,6 @@ fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
Computes the N dimensional discrete Fourier transform of :attr:`input`.
Note:
The Fourier domain representation of any real signal satisfies the
Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
function always returns all positive and negative frequency terms even
@ -211,6 +225,10 @@ Note:
:func:`~torch.fft.rfftn` returns the more compact one-sided representation
where only the positive frequencies of the last dimension are returned.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -256,6 +274,10 @@ ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -305,6 +327,10 @@ The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so
the output contains only the positive frequencies below the Nyquist frequency.
To compute the full output, use :func:`~torch.fft.fft`
Note:
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the real input tensor
n (int, optional): Signal length. If given, the input will either be zero-padded
@ -367,6 +393,12 @@ Note:
signal is assumed to be even length and odd signals will not round-trip
properly. So, it is recommended to always pass the signal length :attr:`n`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
With default arguments, the size of the transformed dimension should be (2^n + 1), as argument
`n` defaults to an even output size of 2 * (transformed_dim_size - 1).
Args:
input (Tensor): the input tensor representing a half-Hermitian signal
n (int, optional): Output signal length. This determines the length of the
@ -424,6 +456,10 @@ so the full :func:`~torch.fft.fft2` output contains redundant information.
:func:`~torch.fft.rfft2` instead omits the negative frequencies in the last
dimension.
Note:
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -496,6 +532,12 @@ Note:
signal is assumed to be even length and odd signals will not round-trip
properly. So, it is recommended to always pass the signal shape :attr:`s`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
With default arguments, the size of the last dimension should be (2^n + 1), as argument
`s` defaults to an even output size of 2 * (last_dim_size - 1).
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -557,6 +599,10 @@ The FFT of a real signal is Hermitian-symmetric,
:func:`~torch.fft.rfftn` instead omits the negative frequencies in the
last dimension.
Note:
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -628,6 +674,12 @@ Note:
signal is assumed to be even length and odd signals will not round-trip
properly. So, it is recommended to always pass the signal shape :attr:`s`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
With default arguments, the size of the last dimension should be (2^n + 1), as argument
`s` defaults to an even output size of 2 * (last_dim_size - 1).
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -709,6 +761,12 @@ Note:
signal is assumed to be even length and odd signals will not round-trip
properly. So, it is recommended to always pass the signal length :attr:`n`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
With default arguments, the size of the transformed dimension should be (2^n + 1), as argument
`n` defaults to an even output size of 2 * (transformed_dim_size - 1).
Args:
input (Tensor): the input tensor representing a half-Hermitian signal
n (int, optional): Output signal length. This determines the length of the
@ -771,6 +829,10 @@ The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
positive frequencies below the Nyquist frequency are included. To compute the
full output, use :func:`~torch.fft.ifft`.
Note:
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the real input tensor
n (int, optional): Signal length. If given, the input will either be zero-padded
@ -818,6 +880,12 @@ transforms the last two dimensions by default.
:attr:`input` is interpreted as a one-sided Hermitian signal in the time
domain. By the Hermitian property, the Fourier transform will be real-valued.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
With default arguments, the size of the last dimension should be (2^n + 1), as argument
`s` defaults to an even output size of 2 * (last_dim_size - 1).
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -878,6 +946,10 @@ Computes the 2-dimensional inverse discrete Fourier transform of real
:attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the
two last dimensions by default.
Note:
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -960,6 +1032,12 @@ Note:
signal is assumed to be even length and odd signals will not round-trip
properly. It is recommended to always pass the signal shape :attr:`s`.
Note:
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
With default arguments, the size of the last dimension should be (2^n + 1), as argument
`s` defaults to an even output size of 2 * (last_dim_size - 1).
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
@ -1025,6 +1103,10 @@ this in the one-sided form where only the positive frequencies below the
Nyquist frequency are included in the last signal dimension. To compute the
full output, use :func:`~torch.fft.ifftn`.
Note:
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
However it only supports powers of 2 signal length in every transformed dimension.
Args:
input (Tensor): the input tensor
s (Tuple[int], optional): Signal size in the transformed dimensions.
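
To make the `(2^n + 1)` notes above concrete, here is a small sketch (assuming a CUDA device with SM_53 or later):

```python
import torch

# hfft/irfft default the output length to 2 * (input_size - 1), so for half/chalf
# inputs the input size should be a power of two plus one.
x = torch.randn(65, device="cuda", dtype=torch.complex32)
y = torch.fft.hfft(x)     # logical FFT size is 2 * (65 - 1) = 128, a power of two
print(y.shape, y.dtype)   # torch.Size([128]) torch.float16

# An input of size 64 would imply a logical size of 126 and raise
# "cuFFT only supports dimensions whose sizes are powers of two".
```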


@ -5810,15 +5810,33 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
return wrapped_fn
def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
dtype=dtype, requires_grad=requires_grad)
oned_tensor = partial(make_tensor, (31,), device=device,
dtype=dtype, requires_grad=requires_grad)
is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half
if not is_fp16_or_chalf:
nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
dtype=dtype, requires_grad=requires_grad)
oned_tensor = partial(make_tensor, (31,), device=device,
dtype=dtype, requires_grad=requires_grad)
else:
# cuFFT supports powers of 2 for half and complex half precision
# NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args
# where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two
if self.name in ['fft.hfft', 'fft.irfft']:
shapes = ((2, 9, 9), (33,))
elif self.name in ['fft.hfft2', 'fft.irfft2']:
shapes = ((2, 8, 9), (33,))
elif self.name in ['fft.hfftn', 'fft.irfftn']:
shapes = ((2, 2, 33), (33,))
else:
shapes = ((2, 8, 16), (32,))
nd_tensor = partial(make_tensor, shapes[0], device=device,
dtype=dtype, requires_grad=requires_grad)
oned_tensor = partial(make_tensor, shapes[1], device=device,
dtype=dtype, requires_grad=requires_grad)
if self.ndimensional == SpectralFuncType.ND:
return [
SampleInput(nd_tensor(),
kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
SampleInput(nd_tensor(),
kwargs=dict(norm='ortho')),
SampleInput(nd_tensor(),
@ -5832,11 +5850,11 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg
elif self.ndimensional == SpectralFuncType.TwoD:
return [
SampleInput(nd_tensor(),
kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
SampleInput(nd_tensor(),
kwargs=dict(norm='ortho')),
SampleInput(nd_tensor(),
kwargs=dict(s=(6, 8))),
kwargs=dict(s=(6, 8) if not is_fp16_or_chalf else (4, 8))),
SampleInput(nd_tensor(),
kwargs=dict(dim=0)),
SampleInput(nd_tensor(),
@ -5847,11 +5865,12 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg
else:
return [
SampleInput(nd_tensor(),
kwargs=dict(n=10, dim=1, norm='ortho')),
kwargs=dict(n=10 if not is_fp16_or_chalf else 8, dim=1, norm='ortho')),
SampleInput(nd_tensor(),
kwargs=dict(norm='ortho')),
SampleInput(nd_tensor(),
kwargs=dict(n=7)),
kwargs=dict(n=7 if not is_fp16_or_chalf else 8)
),
SampleInput(oned_tensor()),
*(SampleInput(nd_tensor(),
@ -5887,6 +5906,8 @@ class SpectralFuncInfo(OpInfo):
decorators = list(decorators) if decorators is not None else []
decorators += [
skipCPUIfNoFFT,
DecorateInfo(toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}),
"TestCommon", "test_complex_half_reference_testing")
]
super().__init__(name=name,
@ -10628,6 +10649,10 @@ op_db: List[OpInfo] = [
ref=np.fft.fft,
ndimensional=SpectralFuncType.OneD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
@ -10636,6 +10661,10 @@ op_db: List[OpInfo] = [
ref=np.fft.fft2,
ndimensional=SpectralFuncType.TwoD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[precisionOverride(
@ -10646,6 +10675,10 @@ op_db: List[OpInfo] = [
ref=np.fft.fftn,
ndimensional=SpectralFuncType.ND,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[precisionOverride(
@ -10656,6 +10689,10 @@ op_db: List[OpInfo] = [
ref=np.fft.hfft,
ndimensional=SpectralFuncType.OneD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_gradgrad=False),
@ -10664,6 +10701,10 @@ op_db: List[OpInfo] = [
ref=scipy.fft.hfft2 if has_scipy_fft else None,
ndimensional=SpectralFuncType.TwoD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_gradgrad=False,
@ -10677,6 +10718,10 @@ op_db: List[OpInfo] = [
ref=scipy.fft.hfftn if has_scipy_fft else None,
ndimensional=SpectralFuncType.ND,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_gradgrad=False,
@ -10690,6 +10735,9 @@ op_db: List[OpInfo] = [
ref=np.fft.rfft,
ndimensional=SpectralFuncType.OneD,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_grad=False,
@ -10701,6 +10749,9 @@ op_db: List[OpInfo] = [
ref=np.fft.rfft2,
ndimensional=SpectralFuncType.TwoD,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_grad=False,
@ -10713,6 +10764,9 @@ op_db: List[OpInfo] = [
ref=np.fft.rfftn,
ndimensional=SpectralFuncType.ND,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_grad=False,
@ -10726,7 +10780,11 @@ op_db: List[OpInfo] = [
ndimensional=SpectralFuncType.OneD,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and_complex_and(torch.bool)),
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),),
SpectralFuncInfo('fft.ifft2',
aten_name='fft_ifft2',
ref=np.fft.ifft2,
@ -10734,6 +10792,10 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
@ -10746,6 +10808,10 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
@ -10758,6 +10824,9 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
skips=(
),
check_batched_grad=False),
@ -10768,6 +10837,9 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
check_batched_grad=False,
check_batched_gradgrad=False,
decorators=(
@ -10784,6 +10856,9 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
check_batched_grad=False,
check_batched_gradgrad=False,
decorators=[
@ -10802,6 +10877,10 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
check_batched_gradgrad=False),
SpectralFuncInfo('fft.irfft2',
aten_name='fft_irfft2',
@ -10810,6 +10889,10 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
check_batched_gradgrad=False,
decorators=[
DecorateInfo(
@ -10823,6 +10906,10 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
check_batched_gradgrad=False,
decorators=[
DecorateInfo(