# mypy: ignore-errors import unittest from functools import partial import numpy as np import torch from torch.testing import make_tensor from torch.testing._internal.common_cuda import SM53OrLater from torch.testing._internal.common_device_type import precisionOverride from torch.testing._internal.common_dtype import ( all_types_and, all_types_and_complex_and, ) from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM from torch.testing._internal.opinfo.core import ( DecorateInfo, ErrorInput, OpInfo, sample_inputs_spectral_ops, SampleInput, SpectralFuncInfo, SpectralFuncType, ) from torch.testing._internal.opinfo.refs import ( _find_referenced_opinfo, _inherit_constructor_args, PythonRefInfo, ) has_scipy_fft = False if TEST_SCIPY: try: import scipy.fft has_scipy_fft = True except ModuleNotFoundError: pass class SpectralFuncPythonRefInfo(SpectralFuncInfo): """ An OpInfo for a Python reference of an elementwise unary operation. """ def __init__( self, name, # the stringname of the callable Python reference *, op=None, # the function variant of the operation, populated as torch. if None torch_opinfo_name, # the string name of the corresponding torch opinfo torch_opinfo_variant="", **kwargs, ): # additional kwargs override kwargs inherited from the torch opinfo self.torch_opinfo_name = torch_opinfo_name self.torch_opinfo = _find_referenced_opinfo( torch_opinfo_name, torch_opinfo_variant, op_db=op_db ) assert isinstance(self.torch_opinfo, SpectralFuncInfo) inherited = self.torch_opinfo._original_spectral_func_args ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) super().__init__(**ukwargs) def error_inputs_fft(op_info, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float32) # Zero-dimensional tensor has no dimension to take FFT of yield ErrorInput( SampleInput(make_arg()), error_type=IndexError, error_regex="Dimension specified as -1 but tensor has no dimensions", ) def error_inputs_fftn(op_info, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float32) # Specifying a dimension on a zero-dimensional tensor yield ErrorInput( SampleInput(make_arg(), dim=(0,)), error_type=IndexError, error_regex="Dimension specified as 0 but tensor has no dimensions", ) def sample_inputs_fft_with_min( op_info, device, dtype, requires_grad=False, *, min_size, **kwargs ): yield from sample_inputs_spectral_ops( op_info, device, dtype, requires_grad, **kwargs ) if TEST_WITH_ROCM: # FIXME: Causes floating point exception on ROCm return # Check the "Invalid number of data points" error isn't too strict # https://github.com/pytorch/pytorch/pull/109083 a = make_tensor(min_size, dtype=dtype, device=device, requires_grad=requires_grad) yield SampleInput(a) def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs): def mt(shape, **kwargs): return make_tensor( shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs ) yield SampleInput(mt((9, 10))) yield SampleInput(mt((50,)), kwargs=dict(dim=0)) yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,))) yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1))) yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2))) # Operator database op_db: list[OpInfo] = [ SpectralFuncInfo( "fft.fft", aten_name="fft_fft", decomp_aten_name="_fft_c2c", ref=np.fft.fft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1), error_inputs_func=error_inputs_fft, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, ), SpectralFuncInfo( "fft.fft2", aten_name="fft_fft2", ref=np.fft.fft2, decomp_aten_name="_fft_c2c", ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})], skips=( DecorateInfo( unittest.skip("Skipped!"), "TestCommon", "test_complex_half_reference_testing", device_type="cuda", dtypes=[torch.complex32], active_if=TEST_WITH_ROCM, ), ), ), SpectralFuncInfo( "fft.fftn", aten_name="fft_fftn", decomp_aten_name="_fft_c2c", ref=np.fft.fftn, ndimensional=SpectralFuncType.ND, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})], ), SpectralFuncInfo( "fft.hfft", aten_name="fft_hfft", decomp_aten_name="_fft_c2r", ref=np.fft.hfft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=2), error_inputs_func=error_inputs_fft, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, check_batched_gradgrad=False, skips=( # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 DecorateInfo( unittest.skip("Skipped!"), "TestSchemaCheckModeOpInfo", "test_schema_correctness", dtypes=(torch.complex64, torch.complex128), ), ), ), SpectralFuncInfo( "fft.hfft2", aten_name="fft_hfft2", decomp_aten_name="_fft_c2r", ref=scipy.fft.hfft2 if has_scipy_fft else None, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, decorators=[ DecorateInfo( precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), "TestFFT", "test_reference_nd", ), ], skips=( # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 DecorateInfo( unittest.skip("Skipped!"), "TestSchemaCheckModeOpInfo", "test_schema_correctness", ), # FIXME: errors are too large; needs investigation DecorateInfo( unittest.skip("Skipped!"), "TestCommon", "test_complex_half_reference_testing", device_type="cuda", ), ), ), SpectralFuncInfo( "fft.hfftn", aten_name="fft_hfftn", decomp_aten_name="_fft_c2r", ref=scipy.fft.hfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.ND, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, decorators=[ DecorateInfo( precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), "TestFFT", "test_reference_nd", ), ], skips=( # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 DecorateInfo( unittest.skip("Skipped!"), "TestSchemaCheckModeOpInfo", "test_schema_correctness", ), ), ), SpectralFuncInfo( "fft.rfft", aten_name="fft_rfft", decomp_aten_name="_fft_r2c", ref=np.fft.rfft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and( torch.bool, *(() if (not SM53OrLater) else (torch.half,)) ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1), error_inputs_func=error_inputs_fft, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, skips=(), check_batched_gradgrad=False, ), SpectralFuncInfo( "fft.rfft2", aten_name="fft_rfft2", decomp_aten_name="_fft_r2c", ref=np.fft.rfft2, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and( torch.bool, *(() if (not SM53OrLater) else (torch.half,)) ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, check_batched_gradgrad=False, decorators=[ precisionOverride({torch.float: 1e-4}), ], ), SpectralFuncInfo( "fft.rfftn", aten_name="fft_rfftn", decomp_aten_name="_fft_r2c", ref=np.fft.rfftn, ndimensional=SpectralFuncType.ND, dtypes=all_types_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and( torch.bool, *(() if (not SM53OrLater) else (torch.half,)) ), sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, check_batched_gradgrad=False, decorators=[ precisionOverride({torch.float: 1e-4}), ], ), SpectralFuncInfo( "fft.ifft", aten_name="fft_ifft", decomp_aten_name="_fft_c2c", ref=np.fft.ifft, ndimensional=SpectralFuncType.OneD, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1), error_inputs_func=error_inputs_fft, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), ), SpectralFuncInfo( "fft.ifft2", aten_name="fft_ifft2", decomp_aten_name="_fft_c2c", ref=np.fft.ifft2, ndimensional=SpectralFuncType.TwoD, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncInfo( "fft.ifftn", aten_name="fft_ifftn", decomp_aten_name="_fft_c2c", ref=np.fft.ifftn, ndimensional=SpectralFuncType.ND, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncInfo( "fft.ihfft", aten_name="fft_ihfft", decomp_aten_name="_fft_r2c", ref=np.fft.ihfft, ndimensional=SpectralFuncType.OneD, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fft, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and( torch.bool, *(() if (not SM53OrLater) else (torch.half,)) ), skips=(), check_batched_grad=False, ), SpectralFuncInfo( "fft.ihfft2", aten_name="fft_ihfft2", decomp_aten_name="_fft_r2c", ref=scipy.fft.ihfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.TwoD, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and( torch.bool, *(() if (not SM53OrLater) else (torch.half,)) ), check_batched_grad=False, check_batched_gradgrad=False, decorators=( # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]). DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"), DecorateInfo( precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd" ), # Mismatched elements! DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"), DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warnings"), ), ), SpectralFuncInfo( "fft.ihfftn", aten_name="fft_ihfftn", decomp_aten_name="_fft_r2c", ref=scipy.fft.ihfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.ND, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archss dtypesIfCUDA=all_types_and( torch.bool, *(() if (not SM53OrLater) else (torch.half,)) ), check_batched_grad=False, check_batched_gradgrad=False, decorators=[ # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]). DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"), # Mismatched elements! DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"), DecorateInfo( precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd" ), ], ), SpectralFuncInfo( "fft.irfft", aten_name="fft_irfft", decomp_aten_name="_fft_c2r", ref=np.fft.irfft, ndimensional=SpectralFuncType.OneD, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)), error_inputs_func=error_inputs_fft, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), check_batched_gradgrad=False, ), SpectralFuncInfo( "fft.irfft2", aten_name="fft_irfft2", decomp_aten_name="_fft_c2r", ref=np.fft.irfft2, ndimensional=SpectralFuncType.TwoD, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), check_batched_gradgrad=False, decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncInfo( "fft.irfftn", aten_name="fft_irfftn", decomp_aten_name="_fft_c2r", ref=np.fft.irfftn, ndimensional=SpectralFuncType.ND, sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)), error_inputs_func=error_inputs_fftn, # https://github.com/pytorch/pytorch/issues/80411 gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, dtypes=all_types_and_complex_and(torch.bool), # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *(() if (not SM53OrLater) else (torch.half, torch.complex32)), ), check_batched_gradgrad=False, decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), OpInfo( "fft.fftshift", dtypes=all_types_and_complex_and( torch.bool, torch.bfloat16, torch.half, torch.chalf ), sample_inputs_func=sample_inputs_fftshift, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), OpInfo( "fft.ifftshift", dtypes=all_types_and_complex_and( torch.bool, torch.bfloat16, torch.half, torch.chalf ), sample_inputs_func=sample_inputs_fftshift, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), ] python_ref_db: list[OpInfo] = [ SpectralFuncPythonRefInfo( "_refs.fft.fft", torch_opinfo_name="fft.fft", ), SpectralFuncPythonRefInfo( "_refs.fft.ifft", torch_opinfo_name="fft.ifft", ), SpectralFuncPythonRefInfo( "_refs.fft.rfft", torch_opinfo_name="fft.rfft", ), SpectralFuncPythonRefInfo( "_refs.fft.irfft", torch_opinfo_name="fft.irfft", ), SpectralFuncPythonRefInfo( "_refs.fft.hfft", torch_opinfo_name="fft.hfft", ), SpectralFuncPythonRefInfo( "_refs.fft.ihfft", torch_opinfo_name="fft.ihfft", ), SpectralFuncPythonRefInfo( "_refs.fft.fftn", torch_opinfo_name="fft.fftn", decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.ifftn", torch_opinfo_name="fft.ifftn", decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.rfftn", torch_opinfo_name="fft.rfftn", ), SpectralFuncPythonRefInfo( "_refs.fft.irfftn", torch_opinfo_name="fft.irfftn", decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.hfftn", torch_opinfo_name="fft.hfftn", decorators=[ DecorateInfo( precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.ihfftn", torch_opinfo_name="fft.ihfftn", decorators=[ DecorateInfo( precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd", ), # AssertionError: Reference result was farther (0.09746177145360499) from the precise # computation than the torch result was (0.09111555632069855) DecorateInfo( unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback", dtypes=(torch.float16,), device_type="cuda", ), # AssertionError: Reference result was farther (0.0953431016138116) from the precise # computation than the torch result was (0.09305490684430734) DecorateInfo( unittest.skip("Skipped!"), "TestCommon", "test_python_ref_executor", dtypes=(torch.float16,), device_type="cuda", ), ], ), SpectralFuncPythonRefInfo( "_refs.fft.fft2", torch_opinfo_name="fft.fft2", ), SpectralFuncPythonRefInfo( "_refs.fft.ifft2", torch_opinfo_name="fft.ifft2", decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.rfft2", torch_opinfo_name="fft.rfft2", ), SpectralFuncPythonRefInfo( "_refs.fft.irfft2", torch_opinfo_name="fft.irfft2", decorators=[ DecorateInfo( precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.hfft2", torch_opinfo_name="fft.hfft2", decorators=[ DecorateInfo( precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), "TestFFT", "test_reference_nd", ) ], ), SpectralFuncPythonRefInfo( "_refs.fft.ihfft2", torch_opinfo_name="fft.ihfft2", decorators=[ DecorateInfo( precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd", ), # FIXME: # Reference result was farther (0.0953431016138116) from the precise computation # than the torch result was (0.09305490684430734)! DecorateInfo( unittest.skip("Skipped!"), "TestCommon", "test_python_ref_executor", device_type="cuda", ), ], ), PythonRefInfo( "_refs.fft.fftshift", op_db=op_db, torch_opinfo_name="fft.fftshift", ), PythonRefInfo( "_refs.fft.ifftshift", op_db=op_db, torch_opinfo_name="fft.ifftshift", ), ]