Add PocketFFT support (#60976)

Summary:
Needed on platforms that do not have MKL, such as aarch64 and M1.
- Add `AT_POCKETFFT_ENABLED()` to Config.h.in
- Introduce `torch._C.has_spectral`, which is `True` if PyTorch was compiled with either MKL or PocketFFT (see the usage sketch below)
- Modify the spectral tests to use `skipCPUIfNoFFT` instead of `skipCPUIfNoMkl`
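
For illustration only, a minimal sketch of how the new flag can be consumed outside the device-generic test framework; `torch._C.has_spectral` is the only name introduced by this PR, the test itself is hypothetical:

```python
import unittest
import torch

# True when PyTorch was built with MKL or PocketFFT (exposed by this PR)
HAS_FFT = torch._C.has_spectral

class SpectralSmokeTest(unittest.TestCase):
    @unittest.skipIf(not HAS_FFT, "PyTorch is built without FFT support")
    def test_rfft_onesided_length(self):
        # A one-sided rfft of a length-8 real signal has 8 // 2 + 1 = 5 bins.
        x = torch.randn(8)
        self.assertEqual(torch.fft.rfft(x).shape, torch.Size([5]))

if __name__ == "__main__":
    unittest.main()
```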

Share the implementation of the `_out` functions, as well as `fft_fill_with_conjugate_symmetry_stub`, between the MKL and PocketFFT backends.
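
The resulting preprocessor layout of `aten/src/ATen/native/mkl/SpectralOps.cpp` is roughly the following (a sketch of the structure only, not the literal file; see the diff below for the actual code):

```cpp
// Rough layout of SpectralOps.cpp after this change (sketch only)
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
// Shared between both backends:
//   - CPU registration of fft_fill_with_conjugate_symmetry_stub
//   - _fft_{r2c,c2r,c2c}_mkl_out wrappers that call the functional kernels
//     and copy the result into `out`
#endif

#if AT_POCKETFFT_ENABLED()
// PocketFFT-backed _fft_{c2r,r2c,c2c}_mkl kernels
#elif AT_MKL_ENABLED()
// MKL-DFTI-backed _fft_{c2r,r2c,c2c}_mkl kernels
#else
// Stubs that raise "fft: ATen not compiled with FFT support"
#endif
```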

Fixes https://github.com/pytorch/pytorch/issues/41592

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60976

Reviewed By: walterddr, driazati, janeyx99, samestep

Differential Revision: D29466530

Pulled By: malfet

fbshipit-source-id: ac5edb3d40e7c413267825f92a5e8bc4bb249caf
Author: Nikita Shulga
Date: 2021-06-30 16:27:07 -07:00
Committed by: Facebook GitHub Bot
Parent: 2d0c6e60a7
Commit: 4036820506
10 changed files with 273 additions and 116 deletions

View File

@ -516,6 +516,7 @@ header_template_rule(
"@AT_MKLDNN_ENABLED@": "1",
"@AT_MKL_ENABLED@": "0",
"@AT_FFTW_ENABLED@": "0",
"@AT_POCKETFFT_ENABLED@": "0",
"@AT_NNPACK_ENABLED@": "0",
"@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
"@AT_BUILD_WITH_BLAS@": "1",

View File

@ -9,6 +9,7 @@
#define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
#define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@
#define AT_POCKETFFT_ENABLED() @AT_POCKETFFT_ENABLED@
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
#define CAFFE2_STATIC_LINK_CUDA() @CAFFE2_STATIC_LINK_CUDA_INT@
#define AT_BUILD_WITH_BLAS() @AT_BUILD_WITH_BLAS@

View File

@ -6,66 +6,10 @@
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#if !AT_MKL_ENABLED()
namespace at { namespace native {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn);
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
AT_ERROR("fft: ATen not compiled with MKL support");
}
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
AT_ERROR("fft: ATen not compiled with MKL support");
}
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
AT_ERROR("fft: ATen not compiled with MKL support");
}
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool onesided, Tensor& out) {
AT_ERROR("fft: ATen not compiled with MKL support");
}
Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
int64_t last_dim_size, Tensor& out) {
AT_ERROR("fft: ATen not compiled with MKL support");
}
Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool forward, Tensor& out) {
AT_ERROR("fft: ATen not compiled with MKL support");
}
}}
#else // AT_MKL_ENABLED
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/native/TensorIterator.h>
#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>
#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>
namespace at { namespace native {
// In real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralUtils.h for more details.
// The following structs are used to fill in the other half with symmetry in
@ -208,6 +152,186 @@ REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fil
REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool onesided, Tensor& out) {
auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
if (onesided) {
resize_output(out, result.sizes());
return out.copy_(result);
}
resize_output(out, self.sizes());
auto last_dim = dim.back();
auto last_dim_halfsize = result.sizes()[last_dim];
auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
out_slice.copy_(result);
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
return out;
}
Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
int64_t last_dim_size, Tensor& out) {
auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
resize_output(out, result.sizes());
return out.copy_(result);
}
Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool forward, Tensor& out) {
auto result = _fft_c2c_mkl(self, dim, normalization, forward);
resize_output(out, result.sizes());
return out.copy_(result);
}
}} // namespace at::native
#endif /* AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() */
#if AT_POCKETFFT_ENABLED()
#include <pocketfft_hdronly.h>
namespace at { namespace native {
namespace {
using namespace pocketfft;
stride_t stride_from_tensor(const Tensor& t) {
stride_t stride(t.strides().begin(), t.strides().end());
for(auto& s: stride) {
s *= t.element_size();
}
return stride;
}
inline shape_t shape_from_tensor(const Tensor& t) {
return shape_t(t.sizes().begin(), t.sizes().end());
}
template<typename T>
inline std::complex<T> *tensor_cdata(Tensor& t) {
return reinterpret_cast<std::complex<T>*>(t.data<c10::complex<T>>());
}
template<typename T>
inline const std::complex<T> *tensor_cdata(const Tensor& t) {
return reinterpret_cast<const std::complex<T>*>(t.data<c10::complex<T>>());
}
template<typename T>
T compute_fct(int64_t size, int64_t normalization) {
constexpr auto one = static_cast<T>(1);
switch (static_cast<fft_norm_mode>(normalization)) {
case fft_norm_mode::none: return one;
case fft_norm_mode::by_n: return one / static_cast<T>(size);
case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
}
AT_ERROR("Unsupported normalization type", normalization);
}
template<typename T>
T compute_fct(const Tensor& t, IntArrayRef dim, int64_t normalization) {
if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
return static_cast<T>(1);
}
const auto& sizes = t.sizes();
int64_t n = 1;
for(auto idx: dim) {
n *= sizes[idx];
}
return compute_fct<T>(n, normalization);
}
} // anonymous namespace
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
auto in_sizes = self.sizes();
DimVector out_sizes(in_sizes.begin(), in_sizes.end());
out_sizes[dim.back()] = last_dim_size;
auto out = at::empty(out_sizes, self.options().dtype(c10::toValueType(self.scalar_type())));
pocketfft::shape_t axes(dim.begin(), dim.end());
if (self.scalar_type() == kComplexFloat) {
pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
tensor_cdata<float>(self),
out.data<float>(), compute_fct<float>(out, dim, normalization));
} else {
pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
tensor_cdata<double>(self),
out.data<double>(), compute_fct<double>(out, dim, normalization));
}
return out;
}
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
TORCH_CHECK(self.is_floating_point());
auto input_sizes = self.sizes();
DimVector out_sizes(input_sizes.begin(), input_sizes.end());
auto last_dim = dim.back();
auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
if (onesided) {
out_sizes[last_dim] = last_dim_halfsize;
}
auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
pocketfft::shape_t axes(dim.begin(), dim.end());
if (self.scalar_type() == kFloat) {
pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
self.data<float>(),
tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
} else {
pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
self.data<double>(),
tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
}
if (!onesided) {
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
}
return out;
}
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
TORCH_CHECK(self.is_complex());
auto out = at::empty(self.sizes(), self.options());
pocketfft::shape_t axes(dim.begin(), dim.end());
if (self.scalar_type() == kComplexFloat) {
pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
tensor_cdata<float>(self),
tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
} else {
pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
tensor_cdata<double>(self),
tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
}
return out;
}
}}
#elif AT_MKL_ENABLED()
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Utils.h>
#include <ATen/native/TensorIterator.h>
#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>
#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>
namespace at { namespace native {
// Constructs an mkl-fft plan descriptor representing the desired transform
// For complex types, strides are in units of 2 * element_size(dtype)
// sizes are for the full signal, including batch size and always two-sided
@ -401,13 +525,6 @@ Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false);
}
Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
int64_t last_dim_size, Tensor& out) {
auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
resize_output(out, result.sizes());
return out.copy_(result);
}
// n-dimensional real to complex FFT
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
TORCH_CHECK(self.is_floating_point());
@ -429,24 +546,6 @@ Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
return out;
}
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool onesided, Tensor& out) {
auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
if (onesided) {
resize_output(out, result.sizes());
return out.copy_(result);
}
resize_output(out, self.sizes());
auto last_dim = dim.back();
auto last_dim_halfsize = result.sizes()[last_dim];
auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
out_slice.copy_(result);
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
return out;
}
// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
TORCH_CHECK(self.is_complex());
@ -455,13 +554,40 @@ Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward);
}
}} // namespace at::native
#else
namespace at { namespace native {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn);
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
AT_ERROR("fft: ATen not compiled with FFT support");
}
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
AT_ERROR("fft: ATen not compiled with FFT support");
}
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
AT_ERROR("fft: ATen not compiled with FFT support");
}
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool onesided, Tensor& out) {
AT_ERROR("fft: ATen not compiled with FFT support");
}
Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
int64_t last_dim_size, Tensor& out) {
AT_ERROR("fft: ATen not compiled with FFT support");
}
Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
bool forward, Tensor& out) {
AT_ERROR("fft: ATen not compiled with FFT support");
}
}} // namespace at::native
#endif

View File

@ -518,6 +518,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp PROPERTIES COMPILE_FLAGS -Wno-init-list-lifetime)
endif()
# Pass path to PocketFFT
if(AT_POCKETFFT_ENABLED)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/mkl/SpectralOps.cpp PROPERTIES INCLUDE_DIRECTORIES "${POCKETFFT_INCLUDE_DIR}")
endif()
if(NOT INTERN_DISABLE_MOBILE_INTERP)
set(MOBILE_SRCS
${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp

View File

@ -216,6 +216,18 @@ if(USE_FFTW OR NOT MKL_FOUND)
endif()
endif()
# --- [ PocketFFT
set(AT_POCKETFFT_ENABLED 0)
if(NOT MKL_FOUND)
find_path(POCKETFFT_INCLUDE_DIR NAMES pocketfft_hdronly.h
PATHS /usr/local/include
PATHS $ENV{POCKETFFT_HOME}
)
if(POCKETFFT_INCLUDE_DIR)
set(AT_POCKETFFT_ENABLED 1)
endif()
endif()
# ---[ Dependencies
# NNPACK and family (QNNPACK, PYTORCH_QNNPACK, and XNNPACK) can download and
# compile their dependencies in isolation as part of their build. These dependencies

View File

@ -11,8 +11,7 @@ from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA,
skipCPUIfNoMkl, deviceCountAtLeast, onlyCUDA, OpDTypes,
skipIf)
skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
from torch.testing._internal.common_methods_invocations import spectral_funcs, SpectralFuncInfo
from setuptools import distutils
@ -167,7 +166,7 @@ class TestFFT(TestCase):
actual = op(input, *args)
self.assertEqual(actual, expected, exact_dtype=exact_dtype)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_fft_round_trip(self, device, dtype):
@ -228,7 +227,7 @@ class TestFFT(TestCase):
with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"):
torch.fft.ihfft(t)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128)
def test_fft_type_promotion(self, device, dtype):
@ -310,7 +309,7 @@ class TestFFT(TestCase):
actual = op(input, s, dim, norm)
self.assertEqual(actual, expected, exact_dtype=exact_dtype)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_fftn_round_trip(self, device, dtype):
@ -375,7 +374,7 @@ class TestFFT(TestCase):
# NOTE: 2d transforms are only thin wrappers over n-dim transforms,
# so don't require exhaustive testing.
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.double, torch.complex128)
def test_fft2_numpy(self, device, dtype):
@ -423,7 +422,7 @@ class TestFFT(TestCase):
actual = fn(valid_input, s, dim, norm)
self.assertEqual(actual, expected)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.complex64)
def test_fft2_fftn_equivalence(self, device, dtype):
@ -460,7 +459,7 @@ class TestFFT(TestCase):
self.assertEqual(actual, expect)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
def test_fft2_invalid(self, device):
a = torch.rand(10, 10, 10, device=device)
@ -486,7 +485,7 @@ class TestFFT(TestCase):
# Helper functions
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double)
@ -512,7 +511,7 @@ class TestFFT(TestCase):
actual = torch_fn(*args, device=device, dtype=dtype)
self.assertEqual(actual, expected, exact_dtype=False)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double)
def test_fftfreq_out(self, device, dtype):
@ -524,7 +523,7 @@ class TestFFT(TestCase):
self.assertEqual(actual, expect)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@ -550,7 +549,7 @@ class TestFFT(TestCase):
actual = torch_fn(input, dim=dim)
self.assertEqual(actual, expected)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double)
@ -629,7 +628,7 @@ class TestFFT(TestCase):
_test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
_test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.double)
def test_fft_ifft_rfft_irfft(self, device, dtype):
@ -708,7 +707,7 @@ class TestFFT(TestCase):
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1
# passes on ROCm w/ python 2.7, fails w/ python 3.6
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@onlyOnCPUAndCUDA
@dtypes(torch.double)
def test_stft(self, device, dtype):
@ -775,7 +774,7 @@ class TestFFT(TestCase):
_test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.double, torch.cdouble)
def test_complex_stft_roundtrip(self, device, dtype):
test_args = list(product(
@ -817,7 +816,7 @@ class TestFFT(TestCase):
self.assertEqual(x_roundtrip, x)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.double, torch.cdouble)
def test_stft_roundtrip_complex_window(self, device, dtype):
test_args = list(product(
@ -857,7 +856,7 @@ class TestFFT(TestCase):
self.assertEqual(x_roundtrip, x)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.cdouble)
def test_complex_stft_definition(self, device, dtype):
test_args = list(product(
@ -877,7 +876,7 @@ class TestFFT(TestCase):
self.assertEqual(actual, expected)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.cdouble)
def test_complex_stft_real_equiv(self, device, dtype):
test_args = list(product(
@ -910,7 +909,7 @@ class TestFFT(TestCase):
center=center, normalized=normalized)
self.assertEqual(expected, actual)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.cdouble)
def test_complex_istft_real_equiv(self, device, dtype):
test_args = list(product(
@ -936,7 +935,7 @@ class TestFFT(TestCase):
return_complex=True)
self.assertEqual(expected, actual)
@skipCPUIfNoMkl
@skipCPUIfNoFFT
def test_complex_stft_onesided(self, device):
# stft of complex input cannot be onesided
for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2):
@ -958,14 +957,14 @@ class TestFFT(TestCase):
# stft is currently warning that it requires return-complex while an upgrader is written
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
def test_stft_requires_complex(self, device):
x = torch.rand(100)
y = x.stft(10, pad_mode='constant')
# with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
# y = x.stft(10, pad_mode='constant')
@skipCPUIfNoMkl
@skipCPUIfNoFFT
def test_fft_input_modification(self, device):
# FFT functions should not modify their input (gh-34551)
@ -986,7 +985,7 @@ class TestFFT(TestCase):
self.assertEqual(half_spectrum, half_spectrum_copy)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_round_trip_simple_cases(self, device, dtype):
"""stft -> istft should recover the original signale"""
@ -999,7 +998,7 @@ class TestFFT(TestCase):
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_round_trip_various_params(self, device, dtype):
"""stft -> istft should recover the original signale"""
@ -1101,7 +1100,7 @@ class TestFFT(TestCase):
self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_of_sine(self, device, dtype):
def _test(amplitude, L, n):
@ -1135,7 +1134,7 @@ class TestFFT(TestCase):
_test(amplitude=99, L=10, n=7)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_linearity(self, device, dtype):
num_trials = 100
@ -1201,7 +1200,7 @@ class TestFFT(TestCase):
_test(data_size, kwargs)
@onlyOnCPUAndCUDA
@skipCPUIfNoMkl
@skipCPUIfNoFFT
def test_batch_istft(self, device):
original = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
@ -1282,7 +1281,7 @@ def generate_doc_test(doc_test):
runner.summarize()
self.fail('Doctest failed')
setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoMkl(test))
setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test))
for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
generate_doc_test(doc_test)

View File

@ -619,6 +619,7 @@ has_lapack: _bool
has_cuda: _bool
has_mkldnn: _bool
has_cudnn: _bool
has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
default_generator: Generator

View File

@ -928,6 +928,13 @@ PyObject* initModule() {
#endif
ASSERT_TRUE(set_module_attr("has_cudnn", has_cudnn));
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
PyObject *has_spectral = Py_True;
#else
PyObject *has_spectral = Py_False;
#endif
ASSERT_TRUE(set_module_attr("has_spectral", has_spectral));
// force ATen to initialize because it handles
// setting up TH Errors so that they throw C++ exceptions
at::init();

View File

@ -1094,6 +1094,11 @@ def skipCPUIfNoLapack(fn):
return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
# Skips a test on CPU if FFT is not available.
def skipCPUIfNoFFT(fn):
return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")(fn)
# Skips a test on CPU if MKL is not available.
def skipCPUIfNoMkl(fn):
return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn)

View File

@ -22,7 +22,7 @@ from torch.testing import \
from .._core import _dispatch_dtypes
from torch.testing._internal.common_device_type import \
(skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater
from torch.testing._internal.common_utils import \
(is_iterable_of_tensors,
@ -2647,7 +2647,7 @@ class SpectralFuncInfo(OpInfo):
**kwargs):
decorators = list(decorators) if decorators is not None else []
decorators += [
skipCPUIfNoMkl,
skipCPUIfNoFFT,
skipCUDAIfRocm,
# gradgrad is quite slow
DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),