ekamiti
2023-09-09 04:48:09 +00:00
committed by PyTorch MergeBot
parent d4230e5574
commit 0f88d93b10
3 changed files with 31 additions and 105 deletions

View File

@ -325,12 +325,7 @@ class TestFFT(TestCase):
# TODO: Remove torch.half error when complex32 is fully implemented
sample = first_sample(self, op.sample_inputs(device, dtype))
device_type = torch.device(device).type
# FIXME: https://github.com/pytorch/pytorch/issues/108204
default_msg = (
r"(Unsupported dtype|"
r"FFT doesn't support (tensors*|transforms) of type|"
r"expected scalar type \w+ but found|)"
)
default_msg = "Unsupported dtype"
if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
err_msg = default_msg
elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
@ -451,7 +446,7 @@ class TestFFT(TestCase):
def test_fftn_invalid(self, device, dtype, op):
a = torch.rand(10, 10, 10, device=device, dtype=dtype)
# FIXME: https://github.com/pytorch/pytorch/issues/108205
errMsg = r"(dims must be unique|duplicate value in the list of dims)"
errMsg = "dims must be unique"
with self.assertRaisesRegex(RuntimeError, errMsg):
op(a, dim=(0, 1, 0))

View File

@ -54,7 +54,9 @@ def _apply_norm(
return x * (1 / signal_numel) if normalize else x
def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype:
def _promote_type_fft(
dtype: torch.dtype, require_complex: bool, device: torch.device
) -> torch.dtype:
"""Helper to promote a dtype to one supported by the FFT primitives"""
if dtype.is_complex:
return dtype
@ -63,6 +65,13 @@ def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype:
if not dtype.is_floating_point:
dtype = torch.get_default_dtype()
allowed_types = [torch.float32, torch.float64]
maybe_support_half = device.type in ["cuda", "meta"] and not torch.version.hip
if maybe_support_half:
allowed_types.append(torch.float16)
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
if require_complex:
dtype = utils.corresponding_complex_dtype(dtype)
@ -74,7 +83,7 @@ def _maybe_promote_tensor_fft(
) -> TensorLikeType:
"""Helper to promote a tensor to a dtype supported by the FFT primitives"""
cur_type = t.dtype
new_type = _promote_type_fft(cur_type, require_complex)
new_type = _promote_type_fft(cur_type, require_complex, t.device)
return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value]
@ -116,8 +125,10 @@ def _fft_c2r(
input = _maybe_promote_tensor_fft(input, require_complex=True)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
num_points = 0 if n is None else n
torch._check(
last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
last_dim_size >= 1,
lambda: f"Invalid number of data points ({num_points}) specified",
)
if n is not None:
@ -146,6 +157,12 @@ def _fft_r2c(
)
input = _maybe_promote_tensor_fft(input)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
num_points = 0 if n is None else n
torch._check(
last_dim_size >= 1,
lambda: f"Invalid number of data points ({num_points}) specified",
)
if n is not None:
input = _resize_fft_input(input, dims, (n,))
@ -169,6 +186,12 @@ def _fft_c2c(
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
num_points = 0 if n is None else n
torch._check(
last_dim_size >= 1,
lambda: f"Invalid number of data points ({num_points}) specified",
)
if n is not None:
input = _resize_fft_input(input, dims, (n,))
@ -267,7 +290,9 @@ def _canonicalize_fft_shape_and_dim_args(
ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
# Check dims are unique
torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
torch._check(
len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
)
if shape is not None:
if not isinstance(shape, Sequence):

View File

@ -650,70 +650,26 @@ python_ref_db: List[OpInfo] = [
SpectralFuncPythonRefInfo(
"_refs.fft.fft",
torch_opinfo_name="fft.fft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifft",
torch_opinfo_name="fft.ifft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft",
torch_opinfo_name="fft.rfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.irfft",
torch_opinfo_name="fft.irfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
# TODO: internally promoted to complex64 so not rejected
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft",
torch_opinfo_name="fft.hfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft",
torch_opinfo_name="fft.ihfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.fftn",
@ -725,15 +681,6 @@ python_ref_db: List[OpInfo] = [
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifftn",
@ -745,15 +692,6 @@ python_ref_db: List[OpInfo] = [
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfftn",
@ -769,15 +707,6 @@ python_ref_db: List[OpInfo] = [
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfftn",
@ -789,21 +718,6 @@ python_ref_db: List[OpInfo] = [
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
# FIXME: https://github.com/pytorch/pytorch/issues/108205
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fftn_invalid",
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfftn",
@ -815,14 +729,6 @@ python_ref_db: List[OpInfo] = [
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108205
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fftn_invalid",
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.fft2",