mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[complex32] fft support (cuda only) (#74857)
`half` and `complex32` support for `torch.fft.{fft, fft2, fftn, hfft, hfft2, hfftn, ifft, ifft2, ifftn, ihfft, ihfft2, ihfftn, irfft, irfft2, irfftn, rfft, rfft2, rfftn}`
* We only add support for `CUDA` as `cuFFT` supports these precision.
* We still error out on `CPU` and `ROCm` as their respective backends don't support this precision
For `cuFFT` following are the constraints for these precisions
* Minimum GPU architecture is SM_53
* Sizes are restricted to powers of two only
* Strides on the real part of real-to-complex and complex-to-real transforms are not supported
* More than one GPU is not supported
* Transforms spanning more than 4 billion elements are not supported
Ref: https://docs.nvidia.com/cuda/cufft/#half-precision-transforms
TODO:
* [x] Update docs about the restrictions
* [x] Check the correct way to check for `hip` device. (seems like `device.is_cuda()` is true for hip as well) (Thanks @peterbell10 )
Ref for second point in TODO:e424e7d214/aten/src/ATen/native/SpectralOps.cpp (L31)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74857
Approved by: https://github.com/anjali411, https://github.com/peterbell10
This commit is contained in:
committed by
PyTorch MergeBot
parent
b825e1d472
commit
ada65fdd67
@ -139,6 +139,14 @@ bool CUDAHooks::hasCuSOLVER() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDAHooks::hasROCM() const {
|
||||
// Currently, this is same as `compiledWithMIOpen`.
|
||||
// But in future if there are ROCm builds without MIOpen,
|
||||
// then `hasROCM` should return true while `compiledWithMIOpen`
|
||||
// should return false
|
||||
return AT_ROCM_ENABLED();
|
||||
}
|
||||
|
||||
#if defined(USE_DIRECT_NVRTC)
|
||||
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
|
||||
return std::make_pair(nullptr, at::cuda::load_nvrtc());
|
||||
|
@ -27,6 +27,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
bool hasMAGMA() const override;
|
||||
bool hasCuDNN() const override;
|
||||
bool hasCuSOLVER() const override;
|
||||
bool hasROCM() const override;
|
||||
const at::cuda::NVRTC& nvrtc() const override;
|
||||
int64_t current_device() const override;
|
||||
bool hasPrimaryContext(int64_t device_index) const override;
|
||||
|
@ -107,6 +107,10 @@ struct TORCH_API CUDAHooksInterface {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool hasROCM() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual const at::cuda::NVRTC& nvrtc() const {
|
||||
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ namespace {
|
||||
// * Integers are promoted to the default floating type
|
||||
// * If require_complex=True, all types are promoted to complex
|
||||
// * Raises an error for half-precision dtypes to allow future support
|
||||
ScalarType promote_type_fft(ScalarType type, bool require_complex) {
|
||||
ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device) {
|
||||
if (at::isComplexType(type)) {
|
||||
return type;
|
||||
}
|
||||
@ -28,7 +28,11 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
|
||||
type = c10::typeMetaToScalarType(c10::get_default_dtype());
|
||||
}
|
||||
|
||||
TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
|
||||
if (device.is_cuda() && !at::detail::getCUDAHooks().hasROCM()) {
|
||||
TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type);
|
||||
} else {
|
||||
TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
|
||||
}
|
||||
|
||||
if (!require_complex) {
|
||||
return type;
|
||||
@ -36,6 +40,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
|
||||
|
||||
// Promote to complex
|
||||
switch (type) {
|
||||
case kHalf: return kComplexHalf;
|
||||
case kFloat: return kComplexFloat;
|
||||
case kDouble: return kComplexDouble;
|
||||
default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
|
||||
@ -45,7 +50,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
|
||||
// Promote a tensor's dtype according to promote_type_fft
|
||||
Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
|
||||
auto cur_type = t.scalar_type();
|
||||
auto new_type = promote_type_fft(cur_type, require_complex);
|
||||
auto new_type = promote_type_fft(cur_type, require_complex, t.device());
|
||||
return (cur_type == new_type) ? t : t.to(new_type);
|
||||
}
|
||||
|
||||
|
@ -106,17 +106,17 @@ void _fft_fill_with_conjugate_symmetry_cuda_(
|
||||
signal_half_sizes, out_strides, mirror_dims, element_size);
|
||||
|
||||
const auto numel = c10::multiply_integers(signal_half_sizes);
|
||||
AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
|
||||
using namespace cuda::detail;
|
||||
_fft_conjugate_copy_kernel<<<
|
||||
GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
numel,
|
||||
static_cast<scalar_t*>(out_data),
|
||||
static_cast<const scalar_t*>(in_data),
|
||||
input_offset_calculator,
|
||||
output_offset_calculator);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
|
||||
using namespace cuda::detail;
|
||||
_fft_conjugate_copy_kernel<<<
|
||||
GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
numel,
|
||||
static_cast<scalar_t*>(out_data),
|
||||
static_cast<const scalar_t*>(in_data),
|
||||
input_offset_calculator,
|
||||
output_offset_calculator);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);
|
||||
|
@ -10,12 +10,14 @@ import doctest
|
||||
import inspect
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
|
||||
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
|
||||
make_tensor)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
|
||||
skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
|
||||
skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
spectral_funcs, SpectralFuncType)
|
||||
from torch.testing._internal.common_cuda import SM53OrLater
|
||||
|
||||
from setuptools import distutils
|
||||
from typing import Optional, List
|
||||
@ -110,6 +112,20 @@ def _stft_reference(x, hop_length, window):
|
||||
X[:, m] = torch.fft.fft(slc * window)
|
||||
return X
|
||||
|
||||
|
||||
def skip_helper_for_fft(device, dtype):
|
||||
device_type = torch.device(device).type
|
||||
if dtype not in (torch.half, torch.complex32):
|
||||
return
|
||||
|
||||
if device_type == 'cpu':
|
||||
raise unittest.SkipTest("half and complex32 are not supported on CPU")
|
||||
if TEST_WITH_ROCM:
|
||||
raise unittest.SkipTest("half and complex32 are not supported on ROCM")
|
||||
if not SM53OrLater:
|
||||
raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")
|
||||
|
||||
|
||||
# Tests of functions related to Fourier analysis in the torch.fft namespace
|
||||
class TestFFT(TestCase):
|
||||
exact_dtype = True
|
||||
@ -157,20 +173,39 @@ class TestFFT(TestCase):
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
|
||||
@toleranceOverride({
|
||||
torch.half : tol(1e-2, 1e-2),
|
||||
torch.chalf : tol(1e-2, 1e-2),
|
||||
})
|
||||
@dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
|
||||
def test_fft_round_trip(self, device, dtype):
|
||||
skip_helper_for_fft(device, dtype)
|
||||
# Test that round trip through ifft(fft(x)) is the identity
|
||||
test_args = list(product(
|
||||
# input
|
||||
(torch.randn(67, device=device, dtype=dtype),
|
||||
torch.randn(80, device=device, dtype=dtype),
|
||||
torch.randn(12, 14, device=device, dtype=dtype),
|
||||
torch.randn(9, 6, 3, device=device, dtype=dtype)),
|
||||
# dim
|
||||
(-1, 0),
|
||||
# norm
|
||||
(None, "forward", "backward", "ortho")
|
||||
))
|
||||
if dtype not in (torch.half, torch.complex32):
|
||||
test_args = list(product(
|
||||
# input
|
||||
(torch.randn(67, device=device, dtype=dtype),
|
||||
torch.randn(80, device=device, dtype=dtype),
|
||||
torch.randn(12, 14, device=device, dtype=dtype),
|
||||
torch.randn(9, 6, 3, device=device, dtype=dtype)),
|
||||
# dim
|
||||
(-1, 0),
|
||||
# norm
|
||||
(None, "forward", "backward", "ortho")
|
||||
))
|
||||
else:
|
||||
# cuFFT supports powers of 2 for half and complex half precision
|
||||
test_args = list(product(
|
||||
# input
|
||||
(torch.randn(64, device=device, dtype=dtype),
|
||||
torch.randn(128, device=device, dtype=dtype),
|
||||
torch.randn(4, 16, device=device, dtype=dtype),
|
||||
torch.randn(8, 6, 2, device=device, dtype=dtype)),
|
||||
# dim
|
||||
(-1, 0),
|
||||
# norm
|
||||
(None, "forward", "backward", "ortho")
|
||||
))
|
||||
|
||||
fft_functions = [(torch.fft.fft, torch.fft.ifft)]
|
||||
# Real-only functions
|
||||
@ -189,13 +224,17 @@ class TestFFT(TestCase):
|
||||
}
|
||||
|
||||
y = backward(forward(x, **kwargs), **kwargs)
|
||||
if x.dtype is torch.half and y.dtype is torch.complex32:
|
||||
# Since type promotion currently doesn't work with complex32
|
||||
# manually promote `x` to complex32
|
||||
x = x.to(torch.complex32)
|
||||
# For real input, ifft(fft(x)) will convert to complex
|
||||
self.assertEqual(x, y, exact_dtype=(
|
||||
forward != torch.fft.fft or x.is_complex()))
|
||||
|
||||
# Note: NumPy will throw a ValueError for an empty input
|
||||
@onlyNativeDeviceTypes
|
||||
@ops(spectral_funcs, allowed_dtypes=(torch.float, torch.cfloat))
|
||||
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
|
||||
def test_empty_fft(self, device, dtype, op):
|
||||
t = torch.empty(1, 0, device=device, dtype=dtype)
|
||||
match = r"Invalid number of data points \([-\d]*\) specified"
|
||||
@ -228,8 +267,11 @@ class TestFFT(TestCase):
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128)
|
||||
@dtypes(torch.int8, torch.half, torch.float, torch.double,
|
||||
torch.complex32, torch.complex64, torch.complex128)
|
||||
def test_fft_type_promotion(self, device, dtype):
|
||||
skip_helper_for_fft(device, dtype)
|
||||
|
||||
if dtype.is_complex or dtype.is_floating_point:
|
||||
t = torch.randn(64, device=device, dtype=dtype)
|
||||
else:
|
||||
@ -237,8 +279,10 @@ class TestFFT(TestCase):
|
||||
|
||||
PROMOTION_MAP = {
|
||||
torch.int8: torch.complex64,
|
||||
torch.half: torch.complex32,
|
||||
torch.float: torch.complex64,
|
||||
torch.double: torch.complex128,
|
||||
torch.complex32: torch.complex32,
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
}
|
||||
@ -247,17 +291,27 @@ class TestFFT(TestCase):
|
||||
|
||||
PROMOTION_MAP_C2R = {
|
||||
torch.int8: torch.float,
|
||||
torch.half: torch.half,
|
||||
torch.float: torch.float,
|
||||
torch.double: torch.double,
|
||||
torch.complex32: torch.half,
|
||||
torch.complex64: torch.float,
|
||||
torch.complex128: torch.double,
|
||||
}
|
||||
R = torch.fft.hfft(t)
|
||||
if dtype in (torch.half, torch.complex32):
|
||||
# cuFFT supports powers of 2 for half and complex half precision
|
||||
# NOTE: With hfft and default args where output_size n=2*(input_size - 1),
|
||||
# we make sure that logical fft size is a power of two.
|
||||
x = torch.randn(65, device=device, dtype=dtype)
|
||||
R = torch.fft.hfft(x)
|
||||
else:
|
||||
R = torch.fft.hfft(t)
|
||||
self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])
|
||||
|
||||
if not dtype.is_complex:
|
||||
PROMOTION_MAP_R2C = {
|
||||
torch.int8: torch.complex64,
|
||||
torch.half: torch.complex32,
|
||||
torch.float: torch.complex64,
|
||||
torch.double: torch.complex128,
|
||||
}
|
||||
@ -269,9 +323,32 @@ class TestFFT(TestCase):
|
||||
allowed_dtypes=[torch.half, torch.bfloat16])
|
||||
def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
|
||||
# TODO: Remove torch.half error when complex32 is fully implemented
|
||||
x = torch.randn(8, 8, device=device).to(dtype)
|
||||
with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
|
||||
op(x)
|
||||
sample = first_sample(self, op.sample_inputs(device, dtype))
|
||||
device_type = torch.device(device).type
|
||||
if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
|
||||
err_msg = "Unsupported dtype "
|
||||
elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
|
||||
err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
|
||||
else:
|
||||
err_msg = "Unsupported dtype "
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
|
||||
def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
|
||||
t = make_tensor(13, 13, device=device, dtype=dtype)
|
||||
err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
op(t)
|
||||
|
||||
if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
|
||||
kwargs = {'s': (12, 12)}
|
||||
else:
|
||||
kwargs = {'n': 12}
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
op(t, **kwargs)
|
||||
|
||||
# nd-fft tests
|
||||
@onlyNativeDeviceTypes
|
||||
@ -308,8 +385,15 @@ class TestFFT(TestCase):
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
|
||||
@toleranceOverride({
|
||||
torch.half : tol(1e-2, 1e-2),
|
||||
torch.chalf : tol(1e-2, 1e-2),
|
||||
})
|
||||
@dtypes(torch.half, torch.float, torch.double,
|
||||
torch.complex32, torch.complex64, torch.complex128)
|
||||
def test_fftn_round_trip(self, device, dtype):
|
||||
skip_helper_for_fft(device, dtype)
|
||||
|
||||
norm_modes = (None, "forward", "backward", "ortho")
|
||||
|
||||
# input_ndim, dim
|
||||
@ -331,7 +415,11 @@ class TestFFT(TestCase):
|
||||
(torch.fft.ihfftn, torch.fft.hfftn)]
|
||||
|
||||
for input_ndim, dim in transform_desc:
|
||||
shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
|
||||
if dtype in (torch.half, torch.complex32):
|
||||
# cuFFT supports powers of 2 for half and complex half precision
|
||||
shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
|
||||
else:
|
||||
shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
|
||||
x = torch.randn(*shape, device=device, dtype=dtype)
|
||||
|
||||
for (forward, backward), norm in product(fft_functions, norm_modes):
|
||||
@ -343,8 +431,13 @@ class TestFFT(TestCase):
|
||||
kwargs = {'s': s, 'dim': dim, 'norm': norm}
|
||||
y = backward(forward(x, **kwargs), **kwargs)
|
||||
# For real input, ifftn(fftn(x)) will convert to complex
|
||||
self.assertEqual(x, y, exact_dtype=(
|
||||
forward != torch.fft.fftn or x.is_complex()))
|
||||
if x.dtype is torch.half and y.dtype is torch.chalf:
|
||||
# Since type promotion currently doesn't work with complex32
|
||||
# manually promote `x` to complex32
|
||||
self.assertEqual(x.to(torch.chalf), y)
|
||||
else:
|
||||
self.assertEqual(x, y, exact_dtype=(
|
||||
forward != torch.fft.fftn or x.is_complex()))
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
|
||||
@ -369,8 +462,13 @@ class TestFFT(TestCase):
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double)
|
||||
@toleranceOverride({
|
||||
torch.half : tol(1e-2, 1e-2),
|
||||
})
|
||||
@dtypes(torch.half, torch.float, torch.double)
|
||||
def test_hfftn(self, device, dtype):
|
||||
skip_helper_for_fft(device, dtype)
|
||||
|
||||
# input_ndim, dim
|
||||
transform_desc = [
|
||||
*product(range(2, 5), (None, (0,), (0, -1))),
|
||||
@ -383,8 +481,10 @@ class TestFFT(TestCase):
|
||||
|
||||
for input_ndim, dim in transform_desc:
|
||||
actual_dims = list(range(input_ndim)) if dim is None else dim
|
||||
|
||||
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
|
||||
if dtype is torch.half:
|
||||
shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
|
||||
else:
|
||||
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
|
||||
expect = torch.randn(*shape, device=device, dtype=dtype)
|
||||
input = torch.fft.ifftn(expect, dim=dim, norm="ortho")
|
||||
|
||||
@ -401,8 +501,13 @@ class TestFFT(TestCase):
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double)
|
||||
@toleranceOverride({
|
||||
torch.half : tol(1e-2, 1e-2),
|
||||
})
|
||||
@dtypes(torch.half, torch.float, torch.double)
|
||||
def test_ihfftn(self, device, dtype):
|
||||
skip_helper_for_fft(device, dtype)
|
||||
|
||||
# input_ndim, dim
|
||||
transform_desc = [
|
||||
*product(range(2, 5), (None, (0,), (0, -1))),
|
||||
@ -414,7 +519,11 @@ class TestFFT(TestCase):
|
||||
]
|
||||
|
||||
for input_ndim, dim in transform_desc:
|
||||
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
|
||||
if dtype is torch.half:
|
||||
shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
|
||||
else:
|
||||
shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
|
||||
|
||||
input = torch.randn(*shape, device=device, dtype=dtype)
|
||||
expect = torch.fft.ifftn(input, dim=dim, norm="ortho")
|
||||
|
||||
|
@ -20,7 +20,6 @@ fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
|
||||
Computes the one dimensional discrete Fourier transform of :attr:`input`.
|
||||
|
||||
Note:
|
||||
|
||||
The Fourier domain representation of any real signal satisfies the
|
||||
Hermitian property: `X[i] = conj(X[-i])`. This function always returns both
|
||||
the positive and negative frequency terms even though, for real inputs, the
|
||||
@ -28,6 +27,10 @@ Note:
|
||||
more compact one-sided representation where only the positive frequencies
|
||||
are returned.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimension.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
n (int, optional): Signal length. If given, the input will either be zero-padded
|
||||
@ -68,6 +71,10 @@ ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
|
||||
|
||||
Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimension.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
n (int, optional): Signal length. If given, the input will either be zero-padded
|
||||
@ -111,6 +118,10 @@ Note:
|
||||
:func:`~torch.fft.rfft2` returns the more compact one-sided representation
|
||||
where only the positive frequencies of the last dimension are returned.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -157,6 +168,10 @@ ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
|
||||
Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.
|
||||
Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -203,7 +218,6 @@ fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
|
||||
Computes the N dimensional discrete Fourier transform of :attr:`input`.
|
||||
|
||||
Note:
|
||||
|
||||
The Fourier domain representation of any real signal satisfies the
|
||||
Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
|
||||
function always returns all positive and negative frequency terms even
|
||||
@ -211,6 +225,10 @@ Note:
|
||||
:func:`~torch.fft.rfftn` returns the more compact one-sided representation
|
||||
where only the positive frequencies of the last dimension are returned.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -256,6 +274,10 @@ ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
|
||||
|
||||
Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -305,6 +327,10 @@ The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so
|
||||
the output contains only the positive frequencies below the Nyquist frequency.
|
||||
To compute the full output, use :func:`~torch.fft.fft`
|
||||
|
||||
Note:
|
||||
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimension.
|
||||
|
||||
Args:
|
||||
input (Tensor): the real input tensor
|
||||
n (int, optional): Signal length. If given, the input will either be zero-padded
|
||||
@ -367,6 +393,12 @@ Note:
|
||||
signal is assumed to be even length and odd signals will not round-trip
|
||||
properly. So, it is recommended to always pass the signal length :attr:`n`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimension.
|
||||
With default arguments, size of the transformed dimension should be (2^n + 1) as argument
|
||||
`n` defaults to even output size = 2 * (transformed_dim_size - 1)
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor representing a half-Hermitian signal
|
||||
n (int, optional): Output signal length. This determines the length of the
|
||||
@ -424,6 +456,10 @@ so the full :func:`~torch.fft.fft2` output contains redundant information.
|
||||
:func:`~torch.fft.rfft2` instead omits the negative frequencies in the last
|
||||
dimension.
|
||||
|
||||
Note:
|
||||
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -496,6 +532,12 @@ Note:
|
||||
signal is assumed to be even length and odd signals will not round-trip
|
||||
properly. So, it is recommended to always pass the signal shape :attr:`s`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
With default arguments, the size of last dimension should be (2^n + 1) as argument
|
||||
`s` defaults to even output size = 2 * (last_dim_size - 1)
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -557,6 +599,10 @@ The FFT of a real signal is Hermitian-symmetric,
|
||||
:func:`~torch.fft.rfftn` instead omits the negative frequencies in the
|
||||
last dimension.
|
||||
|
||||
Note:
|
||||
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -628,6 +674,12 @@ Note:
|
||||
signal is assumed to be even length and odd signals will not round-trip
|
||||
properly. So, it is recommended to always pass the signal shape :attr:`s`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
With default arguments, the size of last dimension should be (2^n + 1) as argument
|
||||
`s` defaults to even output size = 2 * (last_dim_size - 1)
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -709,6 +761,12 @@ Note:
|
||||
signal is assumed to be even length and odd signals will not round-trip
|
||||
properly. So, it is recommended to always pass the signal length :attr:`n`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimension.
|
||||
With default arguments, size of the transformed dimension should be (2^n + 1) as argument
|
||||
`n` defaults to even output size = 2 * (transformed_dim_size - 1)
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor representing a half-Hermitian signal
|
||||
n (int, optional): Output signal length. This determines the length of the
|
||||
@ -771,6 +829,10 @@ The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
|
||||
positive frequencies below the Nyquist frequency are included. To compute the
|
||||
full output, use :func:`~torch.fft.ifft`.
|
||||
|
||||
Note:
|
||||
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimension.
|
||||
|
||||
Args:
|
||||
input (Tensor): the real input tensor
|
||||
n (int, optional): Signal length. If given, the input will either be zero-padded
|
||||
@ -818,6 +880,12 @@ transforms the last two dimensions by default.
|
||||
:attr:`input` is interpreted as a one-sided Hermitian signal in the time
|
||||
domain. By the Hermitian property, the Fourier transform will be real-valued.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
With default arguments, the size of last dimension should be (2^n + 1) as argument
|
||||
`s` defaults to even output size = 2 * (last_dim_size - 1)
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -878,6 +946,10 @@ Computes the 2-dimensional inverse discrete Fourier transform of real
|
||||
:attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the
|
||||
two last dimensions by default.
|
||||
|
||||
Note:
|
||||
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -960,6 +1032,12 @@ Note:
|
||||
signal is assumed to be even length and odd signals will not round-trip
|
||||
properly. It is recommended to always pass the signal shape :attr:`s`.
|
||||
|
||||
Note:
|
||||
Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
With default arguments, the size of last dimension should be (2^n + 1) as argument
|
||||
`s` defaults to even output size = 2 * (last_dim_size - 1)
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
@ -1025,6 +1103,10 @@ this in the one-sided form where only the positive frequencies below the
|
||||
Nyquist frequency are included in the last signal dimension. To compute the
|
||||
full output, use :func:`~torch.fft.ifftn`.
|
||||
|
||||
Note:
|
||||
Supports torch.half on CUDA with GPU Architecture SM53 or greater.
|
||||
However it only supports powers of 2 signal length in every transformed dimensions.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
s (Tuple[int], optional): Signal size in the transformed dimensions.
|
||||
|
@ -5810,15 +5810,33 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
|
||||
return wrapped_fn
|
||||
|
||||
def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
|
||||
nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
|
||||
dtype=dtype, requires_grad=requires_grad)
|
||||
oned_tensor = partial(make_tensor, (31,), device=device,
|
||||
dtype=dtype, requires_grad=requires_grad)
|
||||
is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half
|
||||
if not is_fp16_or_chalf:
|
||||
nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
|
||||
dtype=dtype, requires_grad=requires_grad)
|
||||
oned_tensor = partial(make_tensor, (31,), device=device,
|
||||
dtype=dtype, requires_grad=requires_grad)
|
||||
else:
|
||||
# cuFFT supports powers of 2 for half and complex half precision
|
||||
# NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args
|
||||
# where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two
|
||||
if self.name in ['fft.hfft', 'fft.irfft']:
|
||||
shapes = ((2, 9, 9), (33,))
|
||||
elif self.name in ['fft.hfft2', 'fft.irfft2']:
|
||||
shapes = ((2, 8, 9), (33,))
|
||||
elif self.name in ['fft.hfftn', 'fft.irfftn']:
|
||||
shapes = ((2, 2, 33), (33,))
|
||||
else:
|
||||
shapes = ((2, 8, 16), (32,))
|
||||
nd_tensor = partial(make_tensor, shapes[0], device=device,
|
||||
dtype=dtype, requires_grad=requires_grad)
|
||||
oned_tensor = partial(make_tensor, shapes[1], device=device,
|
||||
dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
if self.ndimensional == SpectralFuncType.ND:
|
||||
return [
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
|
||||
kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(norm='ortho')),
|
||||
SampleInput(nd_tensor(),
|
||||
@ -5832,11 +5850,11 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg
|
||||
elif self.ndimensional == SpectralFuncType.TwoD:
|
||||
return [
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
|
||||
kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(norm='ortho')),
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(s=(6, 8))),
|
||||
kwargs=dict(s=(6, 8) if not is_fp16_or_chalf else (4, 8))),
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(dim=0)),
|
||||
SampleInput(nd_tensor(),
|
||||
@ -5847,11 +5865,12 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg
|
||||
else:
|
||||
return [
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(n=10, dim=1, norm='ortho')),
|
||||
kwargs=dict(n=10 if not is_fp16_or_chalf else 8, dim=1, norm='ortho')),
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(norm='ortho')),
|
||||
SampleInput(nd_tensor(),
|
||||
kwargs=dict(n=7)),
|
||||
kwargs=dict(n=7 if not is_fp16_or_chalf else 8)
|
||||
),
|
||||
SampleInput(oned_tensor()),
|
||||
|
||||
*(SampleInput(nd_tensor(),
|
||||
@ -5887,6 +5906,8 @@ class SpectralFuncInfo(OpInfo):
|
||||
decorators = list(decorators) if decorators is not None else []
|
||||
decorators += [
|
||||
skipCPUIfNoFFT,
|
||||
DecorateInfo(toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}),
|
||||
"TestCommon", "test_complex_half_reference_testing")
|
||||
]
|
||||
|
||||
super().__init__(name=name,
|
||||
@ -10628,6 +10649,10 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.fft,
|
||||
ndimensional=SpectralFuncType.OneD,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
),
|
||||
@ -10636,6 +10661,10 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.fft2,
|
||||
ndimensional=SpectralFuncType.TwoD,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[precisionOverride(
|
||||
@ -10646,6 +10675,10 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.fftn,
|
||||
ndimensional=SpectralFuncType.ND,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[precisionOverride(
|
||||
@ -10656,6 +10689,10 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.hfft,
|
||||
ndimensional=SpectralFuncType.OneD,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
check_batched_gradgrad=False),
|
||||
@ -10664,6 +10701,10 @@ op_db: List[OpInfo] = [
|
||||
ref=scipy.fft.hfft2 if has_scipy_fft else None,
|
||||
ndimensional=SpectralFuncType.TwoD,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
check_batched_gradgrad=False,
|
||||
@ -10677,6 +10718,10 @@ op_db: List[OpInfo] = [
|
||||
ref=scipy.fft.hfftn if has_scipy_fft else None,
|
||||
ndimensional=SpectralFuncType.ND,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
check_batched_gradgrad=False,
|
||||
@ -10690,6 +10735,9 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.rfft,
|
||||
ndimensional=SpectralFuncType.OneD,
|
||||
dtypes=all_types_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
check_batched_grad=False,
|
||||
@ -10701,6 +10749,9 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.rfft2,
|
||||
ndimensional=SpectralFuncType.TwoD,
|
||||
dtypes=all_types_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
check_batched_grad=False,
|
||||
@ -10713,6 +10764,9 @@ op_db: List[OpInfo] = [
|
||||
ref=np.fft.rfftn,
|
||||
ndimensional=SpectralFuncType.ND,
|
||||
dtypes=all_types_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
check_batched_grad=False,
|
||||
@ -10726,7 +10780,11 @@ op_db: List[OpInfo] = [
|
||||
ndimensional=SpectralFuncType.OneD,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool)),
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),),
|
||||
SpectralFuncInfo('fft.ifft2',
|
||||
aten_name='fft_ifft2',
|
||||
ref=np.fft.ifft2,
|
||||
@ -10734,6 +10792,10 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
decorators=[
|
||||
DecorateInfo(
|
||||
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
|
||||
@ -10746,6 +10808,10 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
decorators=[
|
||||
DecorateInfo(
|
||||
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
|
||||
@ -10758,6 +10824,9 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
|
||||
skips=(
|
||||
),
|
||||
check_batched_grad=False),
|
||||
@ -10768,6 +10837,9 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
|
||||
check_batched_grad=False,
|
||||
check_batched_gradgrad=False,
|
||||
decorators=(
|
||||
@ -10784,6 +10856,9 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archss
|
||||
dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
|
||||
check_batched_grad=False,
|
||||
check_batched_gradgrad=False,
|
||||
decorators=[
|
||||
@ -10802,6 +10877,10 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
check_batched_gradgrad=False),
|
||||
SpectralFuncInfo('fft.irfft2',
|
||||
aten_name='fft_irfft2',
|
||||
@ -10810,6 +10889,10 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
check_batched_gradgrad=False,
|
||||
decorators=[
|
||||
DecorateInfo(
|
||||
@ -10823,6 +10906,10 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool),
|
||||
# rocFFT doesn't support Half/Complex Half Precision FFT
|
||||
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
|
||||
dtypesIfCUDA=all_types_and_complex_and(
|
||||
torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
|
||||
check_batched_gradgrad=False,
|
||||
decorators=[
|
||||
DecorateInfo(
|
||||
|
Reference in New Issue
Block a user