Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
decomposition spectral ops fixes (#108360)
Fixes https://github.com/pytorch/pytorch/issues/105986, https://github.com/pytorch/pytorch/issues/108204, and https://github.com/pytorch/pytorch/issues/108205.

Fixes all issues flagged while making the changes for https://github.com/pytorch/pytorch/pull/107421.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108360
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: d4230e5574
Commit: 0f88d93b10
@@ -325,12 +325,7 @@ class TestFFT(TestCase):
         # TODO: Remove torch.half error when complex32 is fully implemented
         sample = first_sample(self, op.sample_inputs(device, dtype))
         device_type = torch.device(device).type
-        # FIXME: https://github.com/pytorch/pytorch/issues/108204
-        default_msg = (
-            r"(Unsupported dtype|"
-            r"FFT doesn't support (tensors*|transforms) of type|"
-            r"expected scalar type \w+ but found|)"
-        )
+        default_msg = "Unsupported dtype"
         if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
             err_msg = default_msg
         elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
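Note: after this change the _refs decompositions reject unsupported low-precision dtypes through the torch._check shown later in this diff, so the test only needs the single "Unsupported dtype" pattern. A minimal sketch of that behavior (illustrative, not part of the PR; assumes a CPU build, where bfloat16 is rejected):

import torch
import torch._refs.fft as refs_fft

x = torch.rand(8, dtype=torch.bfloat16)
try:
    refs_fft.fft(x)  # decomposition path; hits the new dtype check
except RuntimeError as e:
    assert "Unsupported dtype" in str(e)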
@@ -451,7 +446,7 @@ class TestFFT(TestCase):
     def test_fftn_invalid(self, device, dtype, op):
         a = torch.rand(10, 10, 10, device=device, dtype=dtype)
-        # FIXME: https://github.com/pytorch/pytorch/issues/108205
-        errMsg = r"(dims must be unique|duplicate value in the list of dims)"
+        errMsg = "dims must be unique"
         with self.assertRaisesRegex(RuntimeError, errMsg):
             op(a, dim=(0, 1, 0))
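With the FIXME resolved, the tightened pattern assumes that eager ops and _refs decompositions report duplicate dims the same way. A minimal repro of what the test now expects (illustrative; assumes the eager error message also matches the pattern):

import torch

a = torch.rand(10, 10, 10)
try:
    # dim 0 appears twice; both implementations are expected to reject this
    torch.fft.fftn(a, dim=(0, 1, 0))
except RuntimeError as e:
    assert "dims must be unique" in str(e)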
@@ -54,7 +54,9 @@ def _apply_norm(
     return x * (1 / signal_numel) if normalize else x


-def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype:
+def _promote_type_fft(
+    dtype: torch.dtype, require_complex: bool, device: torch.device
+) -> torch.dtype:
     """Helper to promote a dtype to one supported by the FFT primitives"""
     if dtype.is_complex:
         return dtype
@@ -63,6 +65,13 @@ def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype:
     if not dtype.is_floating_point:
         dtype = torch.get_default_dtype()

+    allowed_types = [torch.float32, torch.float64]
+    maybe_support_half = device.type in ["cuda", "meta"] and not torch.version.hip
+
+    if maybe_support_half:
+        allowed_types.append(torch.float16)
+    torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
+
     if require_complex:
         dtype = utils.corresponding_complex_dtype(dtype)

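A short sketch of how the promotion helper behaves with its new device argument (it imports a private helper from the hunk above, so treat it as illustrative only; it also assumes the default dtype is torch.float32):

import torch
from torch._refs.fft import _promote_type_fft  # private helper changed above

cpu = torch.device("cpu")
# Integral dtypes promote to the default float dtype (assumed to be float32 here).
assert _promote_type_fft(torch.int64, require_complex=False, device=cpu) == torch.float32
# require_complex=True maps the promoted real dtype to its complex counterpart.
assert _promote_type_fft(torch.float32, require_complex=True, device=cpu) == torch.complex64
# float16 is only allowed on CUDA/meta devices (and not under ROCm); on CPU it
# fails the torch._check above with "Unsupported dtype torch.float16".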
@@ -74,7 +83,7 @@ def _maybe_promote_tensor_fft(
 ) -> TensorLikeType:
     """Helper to promote a tensor to a dtype supported by the FFT primitives"""
     cur_type = t.dtype
-    new_type = _promote_type_fft(cur_type, require_complex)
+    new_type = _promote_type_fft(cur_type, require_complex, t.device)
     return _maybe_convert_to_dtype(t, new_type)  # type: ignore[return-value]


@@ -116,8 +125,10 @@ def _fft_c2r(
     input = _maybe_promote_tensor_fft(input, require_complex=True)
     dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
     last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
+    num_points = 0 if n is None else n
     torch._check(
-        last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
+        last_dim_size >= 1,
+        lambda: f"Invalid number of data points ({num_points}) specified",
     )

     if n is not None:
@@ -146,6 +157,12 @@ def _fft_r2c(
     )
     input = _maybe_promote_tensor_fft(input)
     dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
+    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
+    num_points = 0 if n is None else n
+    torch._check(
+        last_dim_size >= 1,
+        lambda: f"Invalid number of data points ({num_points}) specified",
+    )

     if n is not None:
         input = _resize_fft_input(input, dims, (n,))
@@ -169,6 +186,12 @@ def _fft_c2c(
         lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
     )
     dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
+    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
+    num_points = 0 if n is None else n
+    torch._check(
+        last_dim_size >= 1,
+        lambda: f"Invalid number of data points ({num_points}) specified",
+    )

     if n is not None:
         input = _resize_fft_input(input, dims, (n,))
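The three hunks above add the same minimum-length check to the c2r, r2c, and c2c helpers; num_points exists so the error reports 0 rather than None when n is unspecified. A minimal sketch of the resulting error (illustrative; message text taken from the hunks above):

import torch
import torch._refs.fft as refs_fft

x = torch.rand(4)
try:
    refs_fft.fft(x, n=0)  # explicitly request a zero-length transform
except RuntimeError as e:
    assert "Invalid number of data points (0) specified" in str(e)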
@@ -267,7 +290,9 @@ def _canonicalize_fft_shape_and_dim_args(
         ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

         # Check dims are unique
-        torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
+        torch._check(
+            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
+        )

     if shape is not None:
         if not isinstance(shape, Sequence):
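Checking uniqueness on the canonicalized ret_dims instead of the raw dim argument matters when a negative and a positive index alias the same axis. A hypothetical illustration (not part of the PR):

import torch
import torch._refs.fft as refs_fft

a = torch.rand(4, 4, 4)
try:
    # dim 2 and dim -1 name the same axis of a 3-D tensor; after
    # canonicalization both become 2, so the uniqueness check fires.
    refs_fft.fftn(a, dim=(2, -1))
except RuntimeError as e:
    assert "FFT dims must be unique" in str(e)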
@@ -650,70 +650,26 @@ python_ref_db: List[OpInfo] = [
     SpectralFuncPythonRefInfo(
         "_refs.fft.fft",
         torch_opinfo_name="fft.fft",
-        skips=(
-            # _refs.fft.* functions have inconsistent behavior for empty tensors
-            # https://github.com/pytorch/pytorch/issues/105986
-            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ifft",
         torch_opinfo_name="fft.ifft",
-        skips=(
-            # _refs.fft.* functions have inconsistent behavior for empty tensors
-            # https://github.com/pytorch/pytorch/issues/105986
-            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.rfft",
         torch_opinfo_name="fft.rfft",
-        skips=(
-            # _refs.fft.* functions have inconsistent behavior for empty tensors
-            # https://github.com/pytorch/pytorch/issues/105986
-            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.irfft",
         torch_opinfo_name="fft.irfft",
-        skips=(
-            # _refs.fft.* functions have inconsistent behavior for empty tensors
-            # https://github.com/pytorch/pytorch/issues/105986
-            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
-            # TODO: internally promoted to complex64 so not rejected
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fft_half_and_bfloat16_errors",
-                dtypes=[torch.bfloat16],
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.hfft",
         torch_opinfo_name="fft.hfft",
-        skips=(
-            # _refs.fft.* functions have inconsistent behavior for empty tensors
-            # https://github.com/pytorch/pytorch/issues/105986
-            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
-            # FIXME: https://github.com/pytorch/pytorch/issues/108204
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fft_half_and_bfloat16_errors",
-                dtypes=[torch.bfloat16],
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ihfft",
         torch_opinfo_name="fft.ihfft",
-        skips=(
-            # _refs.fft.* functions have inconsistent behavior for empty tensors
-            # https://github.com/pytorch/pytorch/issues/105986
-            DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.fftn",
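The skips deleted above were expectedFailure markers for test_empty_fft; with the data-point checks added earlier in this diff, the decompositions are expected to reject empty transforms the same way the eager ops do (https://github.com/pytorch/pytorch/issues/105986). A rough sketch of the now-consistent behavior (illustrative; it assumes the eager error message matches, which is what the removed markers imply):

import torch
import torch._refs.fft as refs_fft

empty = torch.rand(0)
for fn in (torch.fft.fft, refs_fft.fft):
    try:
        fn(empty)
        raise AssertionError("expected a RuntimeError for an empty transform")
    except RuntimeError as e:
        # assumed to hold for both the eager op and the decomposition
        assert "Invalid number of data points" in str(e)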
@@ -725,15 +681,6 @@ python_ref_db: List[OpInfo] = [
                 "test_reference_nd",
             )
         ],
-        skips=(
-            # FIXME: https://github.com/pytorch/pytorch/issues/108204
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fft_half_and_bfloat16_errors",
-                dtypes=[torch.bfloat16],
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ifftn",
@@ -745,15 +692,6 @@ python_ref_db: List[OpInfo] = [
                 "test_reference_nd",
             )
         ],
-        skips=(
-            # FIXME: https://github.com/pytorch/pytorch/issues/108204
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fft_half_and_bfloat16_errors",
-                dtypes=[torch.bfloat16],
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.rfftn",
@@ -769,15 +707,6 @@ python_ref_db: List[OpInfo] = [
                 "test_reference_nd",
             )
         ],
-        skips=(
-            # FIXME: https://github.com/pytorch/pytorch/issues/108204
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fft_half_and_bfloat16_errors",
-                dtypes=[torch.bfloat16],
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.hfftn",
@@ -789,21 +718,6 @@ python_ref_db: List[OpInfo] = [
                 "test_reference_nd",
             )
         ],
-        skips=(
-            # FIXME: https://github.com/pytorch/pytorch/issues/108204
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fft_half_and_bfloat16_errors",
-                dtypes=[torch.bfloat16],
-            ),
-            # FIXME: https://github.com/pytorch/pytorch/issues/108205
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fftn_invalid",
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.ihfftn",
@@ -815,14 +729,6 @@ python_ref_db: List[OpInfo] = [
                 "test_reference_nd",
             )
         ],
-        skips=(
-            # FIXME: https://github.com/pytorch/pytorch/issues/108205
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestFFT",
-                "test_fftn_invalid",
-            ),
-        ),
     ),
     SpectralFuncPythonRefInfo(
         "_refs.fft.fft2",