Add PocketFFT support (#60976)
Summary: Needed on platforms that do not have MKL, such as aarch64 and M1.
- Add `AT_POCKETFFT_ENABLED()` to Config.h.in
- Introduce `torch._C.has_spectral`, which is true if PyTorch was compiled with either MKL or PocketFFT
- Modify the spectral tests to use `skipCPUIfNoFFT` instead of `skipCPUIfNoMkl`
- Share the implementation of the `_out` functions, as well as `fft_fill_with_conjugate_symmetry_stub`, between the MKL and PocketFFT implementations

Fixes https://github.com/pytorch/pytorch/issues/41592

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60976
Reviewed By: walterddr, driazati, janeyx99, samestep
Differential Revision: D29466530
Pulled By: malfet
fbshipit-source-id: ac5edb3d40e7c413267825f92a5e8bc4bb249caf
Committed by Facebook GitHub Bot · parent 2d0c6e60a7 · commit 4036820506
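The new flag can be queried at runtime to tell whether a given build has any CPU FFT backend. A minimal sketch (only `torch._C.has_spectral` is introduced by this commit; everything else is the existing public API):

import torch

# True when PyTorch was built with either MKL or PocketFFT,
# i.e. when CPU spectral ops such as torch.fft.fft are usable.
if torch._C.has_spectral:
    x = torch.randn(8)
    print(torch.fft.fft(x))
else:
    print("this build has no CPU FFT backend")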
@ -516,6 +516,7 @@ header_template_rule(
    "@AT_MKLDNN_ENABLED@": "1",
    "@AT_MKL_ENABLED@": "0",
    "@AT_FFTW_ENABLED@": "0",
    "@AT_POCKETFFT_ENABLED@": "0",
    "@AT_NNPACK_ENABLED@": "0",
    "@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
    "@AT_BUILD_WITH_BLAS@": "1",
@ -9,6 +9,7 @@
#define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
#define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@
#define AT_POCKETFFT_ENABLED() @AT_POCKETFFT_ENABLED@
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
#define CAFFE2_STATIC_LINK_CUDA() @CAFFE2_STATIC_LINK_CUDA_INT@
#define AT_BUILD_WITH_BLAS() @AT_BUILD_WITH_BLAS@
@ -6,66 +6,10 @@
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#if !AT_MKL_ENABLED()

namespace at { namespace native {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn);

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

}}

#else // AT_MKL_ENABLED

#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
#include <ATen/Parallel.h>
#include <ATen/Utils.h>

#include <ATen/native/TensorIterator.h>

#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>

namespace at { namespace native {

// In real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralUtils.h for more details.
// The following structs are used to fill in the other half with symmetry in
@ -208,6 +152,186 @@ REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fil
REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
  if (onesided) {
    resize_output(out, result.sizes());
    return out.copy_(result);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  out_slice.copy_(result);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  auto result = _fft_c2c_mkl(self, dim, normalization, forward);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

}} // namespace at::native
#endif /* AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() */
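The shared `_fft_r2c_mkl_out` always computes the onesided half-spectrum first and, when a two-sided result is requested, reconstructs the remaining bins via conjugate symmetry. The same relationship is visible from the public API (a sketch using `torch.fft`, not the internal `_out` kernels):

import torch

x = torch.randn(8)
half = torch.fft.rfft(x)  # onesided: 8 // 2 + 1 = 5 bins
full = torch.fft.fft(x)   # two-sided: all 8 bins

# For real input, the upper half of the full spectrum is the lower
# half conjugated and reversed -- exactly what the conjugate-symmetry
# fill reconstructs.
assert torch.allclose(full[:5], half)
assert torch.allclose(full[5:], half[1:4].flip(0).conj())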
#if AT_POCKETFFT_ENABLED()
#include <pocketfft_hdronly.h>

namespace at { namespace native {

namespace {
using namespace pocketfft;

stride_t stride_from_tensor(const Tensor& t) {
  stride_t stride(t.strides().begin(), t.strides().end());
  for(auto& s: stride) {
    s *= t.element_size();
  }
  return stride;
}
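PocketFFT expects strides in bytes, while ATen reports them in elements; `stride_from_tensor` multiplies by the element size to convert. The same arithmetic from Python (a sketch, not part of this commit):

import torch

t = torch.empty(4, 3, dtype=torch.float32)
print(t.stride())  # (3, 1) -- strides in elements
byte_strides = [s * t.element_size() for s in t.stride()]
print(byte_strides)  # [12, 4] -- what pocketfft's stride_t expects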
inline shape_t shape_from_tensor(const Tensor& t) {
  return shape_t(t.sizes().begin(), t.sizes().end());
}

template<typename T>
inline std::complex<T> *tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(t.data<c10::complex<T>>());
}

template<typename T>
inline const std::complex<T> *tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(t.data<c10::complex<T>>());
}

template<typename T>
T compute_fct(int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none: return one;
    case fft_norm_mode::by_n: return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
  }
  AT_ERROR("Unsupported normalization type", normalization);
}

template<typename T>
T compute_fct(const Tensor& t, IntArrayRef dim, int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for(auto idx: dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(n, normalization);
}

} // anonymous namespace
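`compute_fct` turns a normalization mode into the scalar factor PocketFFT applies to the result. In terms of the public API, the internal `fft_norm_mode` values correspond to the `norm=` argument of the `torch.fft` functions (a sketch; the mode-to-string mapping is shown in the comments):

import math
import torch

x = torch.randn(16)
n = x.numel()

# norm="backward" -> fft_norm_mode::none       (factor 1)
# norm="forward"  -> fft_norm_mode::by_n       (factor 1/n)
# norm="ortho"    -> fft_norm_mode::by_root_n  (factor 1/sqrt(n))
assert torch.allclose(torch.fft.fft(x, norm="forward"),
                      torch.fft.fft(x, norm="backward") / n)
assert torch.allclose(torch.fft.fft(x, norm="ortho"),
                      torch.fft.fft(x, norm="backward") / math.sqrt(n))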
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toValueType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<float>(self),
                   out.data<float>(), compute_fct<float>(out, dim, normalization));
  } else {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<double>(self),
                   out.data<double>(), compute_fct<double>(out, dim, normalization));
  }
  return out;
}
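The explicit `last_dim_size` is needed because a onesided spectrum of length `n // 2 + 1` is ambiguous: it could have come from a real signal of even or odd length. The public API exposes this through `torch.fft.irfft`'s `n=` argument (a sketch):

import torch

half = torch.fft.rfft(torch.randn(7))  # 7 // 2 + 1 = 4 complex bins
print(half.shape)  # torch.Size([4])

# Four bins could come from a length-6 or a length-7 signal, so the
# inverse transform must be told the intended output length.
print(torch.fft.irfft(half, n=7).shape)  # torch.Size([7])
print(torch.fft.irfft(half).shape)       # torch.Size([6]), default 2*(4-1)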
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kFloat) {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.data<float>(),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.data<double>(),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  auto out = at::empty(self.sizes(), self.options());
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<float>(self),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<double>(self),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  return out;
}

}}
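A single c2c kernel serves both directions; only the `forward` flag and the normalization factor differ between `fft` and `ifft`. This is observable from Python via the classic conjugation identity (a sketch):

import torch

x = torch.randn(8, dtype=torch.complex64)

# ifft(x) == conj(fft(conj(x))) / n: the backward transform is the
# forward kernel applied to conjugated data, plus the default 1/n scaling.
assert torch.allclose(torch.fft.ifft(x),
                      torch.fft.fft(x.conj()).conj() / 8)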
#elif AT_MKL_ENABLED()
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Utils.h>

#include <ATen/native/TensorIterator.h>

#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>

namespace at { namespace native {

// Constructs an mkl-fft plan descriptor representing the desired transform
// For complex types, strides are in units of 2 * element_size(dtype)
// sizes are for the full signal, including batch size and always two-sided
@ -401,13 +525,6 @@ Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
  return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false);
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

// n-dimensional real to complex FFT
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());

@ -429,24 +546,6 @@ Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
  return out;
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
  if (onesided) {
    resize_output(out, result.sizes());
    return out.copy_(result);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  out_slice.copy_(result);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());

@ -455,13 +554,40 @@ Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
  return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward);
}

}} // namespace at::native
#else

namespace at { namespace native {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn);

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

}} // namespace at::native

#endif
@ -518,6 +518,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
  set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp PROPERTIES COMPILE_FLAGS -Wno-init-list-lifetime)
endif()

# Pass path to PocketFFT
if(AT_POCKETFFT_ENABLED)
  set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/mkl/SpectralOps.cpp PROPERTIES INCLUDE_DIRECTORIES "${POCKETFFT_INCLUDE_DIR}")
endif()

if(NOT INTERN_DISABLE_MOBILE_INTERP)
  set(MOBILE_SRCS
    ${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
@ -216,6 +216,18 @@ if(USE_FFTW OR NOT MKL_FOUND)
  endif()
endif()

# --- [ PocketFFT
set(AT_POCKETFFT_ENABLED 0)
if(NOT MKL_FOUND)
  find_path(POCKETFFT_INCLUDE_DIR NAMES pocketfft_hdronly.h
            PATHS /usr/local/include
            PATHS $ENV{POCKETFFT_HOME}
  )
  if(POCKETFFT_INCLUDE_DIR)
    set(AT_POCKETFFT_ENABLED 1)
  endif()
endif()

# ---[ Dependencies
# NNPACK and family (QNNPACK, PYTORCH_QNNPACK, and XNNPACK) can download and
# compile their dependencies in isolation as part of their build. These dependencies
@ -11,8 +11,7 @@ from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA,
     skipCPUIfNoMkl, deviceCountAtLeast, onlyCUDA, OpDTypes,
     skipIf)
     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
from torch.testing._internal.common_methods_invocations import spectral_funcs, SpectralFuncInfo

from setuptools import distutils

@ -167,7 +166,7 @@ class TestFFT(TestCase):
        actual = op(input, *args)
        self.assertEqual(actual, expected, exact_dtype=exact_dtype)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_fft_round_trip(self, device, dtype):

@ -228,7 +227,7 @@ class TestFFT(TestCase):
        with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"):
            torch.fft.ihfft(t)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128)
    def test_fft_type_promotion(self, device, dtype):

@ -310,7 +309,7 @@ class TestFFT(TestCase):
        actual = op(input, s, dim, norm)
        self.assertEqual(actual, expected, exact_dtype=exact_dtype)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_fftn_round_trip(self, device, dtype):

@ -375,7 +374,7 @@ class TestFFT(TestCase):
    # NOTE: 2d transforms are only thin wrappers over n-dim transforms,
    # so don't require exhaustive testing.

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.double, torch.complex128)
    def test_fft2_numpy(self, device, dtype):

@ -423,7 +422,7 @@ class TestFFT(TestCase):
        actual = fn(valid_input, s, dim, norm)
        self.assertEqual(actual, expected)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.float, torch.complex64)
    def test_fft2_fftn_equivalence(self, device, dtype):

@ -460,7 +459,7 @@ class TestFFT(TestCase):

        self.assertEqual(actual, expect)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    def test_fft2_invalid(self, device):
        a = torch.rand(10, 10, 10, device=device)

@ -486,7 +485,7 @@ class TestFFT(TestCase):

    # Helper functions

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)

@ -512,7 +511,7 @@ class TestFFT(TestCase):
        actual = torch_fn(*args, device=device, dtype=dtype)
        self.assertEqual(actual, expected, exact_dtype=False)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.float, torch.double)
    def test_fftfreq_out(self, device, dtype):

@ -524,7 +523,7 @@ class TestFFT(TestCase):
        self.assertEqual(actual, expect)


    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)

@ -550,7 +549,7 @@ class TestFFT(TestCase):
        actual = torch_fn(input, dim=dim)
        self.assertEqual(actual, expected)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)

@ -629,7 +628,7 @@ class TestFFT(TestCase):
        _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
        _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.double)
    def test_fft_ifft_rfft_irfft(self, device, dtype):

@ -708,7 +707,7 @@ class TestFFT(TestCase):
        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1

    # passes on ROCm w/ python 2.7, fails w/ python 3.6
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @onlyOnCPUAndCUDA
    @dtypes(torch.double)
    def test_stft(self, device, dtype):

@ -775,7 +774,7 @@ class TestFFT(TestCase):
        _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_complex_stft_roundtrip(self, device, dtype):
        test_args = list(product(

@ -817,7 +816,7 @@ class TestFFT(TestCase):
        self.assertEqual(x_roundtrip, x)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_stft_roundtrip_complex_window(self, device, dtype):
        test_args = list(product(

@ -857,7 +856,7 @@ class TestFFT(TestCase):
        self.assertEqual(x_roundtrip, x)


    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_definition(self, device, dtype):
        test_args = list(product(

@ -877,7 +876,7 @@ class TestFFT(TestCase):
        self.assertEqual(actual, expected)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_real_equiv(self, device, dtype):
        test_args = list(product(

@ -910,7 +909,7 @@ class TestFFT(TestCase):
                             center=center, normalized=normalized)
        self.assertEqual(expected, actual)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_istft_real_equiv(self, device, dtype):
        test_args = list(product(

@ -936,7 +935,7 @@ class TestFFT(TestCase):
                           return_complex=True)
        self.assertEqual(expected, actual)

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    def test_complex_stft_onesided(self, device):
        # stft of complex input cannot be onesided
        for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2):

@ -958,14 +957,14 @@ class TestFFT(TestCase):

    # stft is currently warning that it requires return-complex while an upgrader is written
    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    def test_stft_requires_complex(self, device):
        x = torch.rand(100)
        y = x.stft(10, pad_mode='constant')
        # with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
        #     y = x.stft(10, pad_mode='constant')

    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    def test_fft_input_modification(self, device):
        # FFT functions should not modify their input (gh-34551)

@ -986,7 +985,7 @@ class TestFFT(TestCase):
        self.assertEqual(half_spectrum, half_spectrum_copy)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_simple_cases(self, device, dtype):
"""stft -> istft should recover the original signale"""
|
||||
@ -999,7 +998,7 @@ class TestFFT(TestCase):
|
||||
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@skipCPUIfNoMkl
|
||||
@skipCPUIfNoFFT
|
||||
@dtypes(torch.double)
|
||||
def test_istft_round_trip_various_params(self, device, dtype):
|
||||
"""stft -> istft should recover the original signale"""
|
||||
@ -1101,7 +1100,7 @@ class TestFFT(TestCase):
        self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_of_sine(self, device, dtype):
        def _test(amplitude, L, n):

@ -1135,7 +1134,7 @@ class TestFFT(TestCase):
        _test(amplitude=99, L=10, n=7)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_linearity(self, device, dtype):
        num_trials = 100

@ -1201,7 +1200,7 @@ class TestFFT(TestCase):
            _test(data_size, kwargs)

    @onlyOnCPUAndCUDA
    @skipCPUIfNoMkl
    @skipCPUIfNoFFT
    def test_batch_istft(self, device):
        original = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],

@ -1282,7 +1281,7 @@ def generate_doc_test(doc_test):
        runner.summarize()
        self.fail('Doctest failed')

    setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoMkl(test))
    setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test))

for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
    generate_doc_test(doc_test)
@ -619,6 +619,7 @@ has_lapack: _bool
has_cuda: _bool
has_mkldnn: _bool
has_cudnn: _bool
has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
default_generator: Generator
@ -928,6 +928,13 @@ PyObject* initModule() {
#endif
  ASSERT_TRUE(set_module_attr("has_cudnn", has_cudnn));

#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
  PyObject *has_spectral = Py_True;
#else
  PyObject *has_spectral = Py_False;
#endif
  ASSERT_TRUE(set_module_attr("has_spectral", has_spectral));

  // force ATen to initialize because it handles
  // setting up TH Errors so that they throw C++ exceptions
  at::init();
@ -1094,6 +1094,11 @@ def skipCPUIfNoLapack(fn):
    return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)


# Skips a test on CPU if FFT is not available.
def skipCPUIfNoFFT(fn):
    return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")(fn)


# Skips a test on CPU if MKL is not available.
def skipCPUIfNoMkl(fn):
    return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn)
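In a device-generic test class, the new decorator slots in exactly like the other skip helpers. A hedged sketch (the test class and its body are hypothetical, not part of this commit):

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, skipCPUIfNoFFT)
from torch.testing._internal.common_utils import TestCase, run_tests

class TestMySpectral(TestCase):
    @skipCPUIfNoFFT  # skipped on CPU unless torch._C.has_spectral is True
    def test_roundtrip(self, device):
        x = torch.randn(16, device=device)
        self.assertEqual(torch.fft.ifft(torch.fft.fft(x)), x.to(torch.cfloat))

instantiate_device_type_tests(TestMySpectral, globals())

if __name__ == '__main__':
    run_tests()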
@ -22,7 +22,7 @@ from torch.testing import \
from .._core import _dispatch_dtypes
from torch.testing._internal.common_device_type import \
    (skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
     skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
     skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater
from torch.testing._internal.common_utils import \
    (is_iterable_of_tensors,

@ -2647,7 +2647,7 @@ class SpectralFuncInfo(OpInfo):
                 **kwargs):
        decorators = list(decorators) if decorators is not None else []
        decorators += [
            skipCPUIfNoMkl,
            skipCPUIfNoFFT,
            skipCUDAIfRocm,
            # gradgrad is quite slow
            DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),