Remove FFT from stride incorrect ops (#145080)

I gotta say, the FFT implementation is completely insane; there's gotta be a better way to do this than repeatedly restriding the output tensor in place. Anyway, this is a faithful translation of both the MKL and cuFFT paths to Python.
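
The "stride incorrect" classification meant fake tensors were allowed to report different strides than eager mode for FFT outputs. A minimal sketch of the mismatch being fixed (illustrative shapes; assumes a CPU build whose eager FFT restrides its output):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Eager: an FFT over a non-last dim comes back in a permuted,
# non-contiguous layout on the MKL path.
x = torch.randn(4, 5, dtype=torch.complex64)
eager_strides = torch.fft.fft(x, dim=0).stride()

# Fake/meta: previously reported plain contiguous strides for the same op,
# which is why these ops sat on the stride-incorrect list.
with FakeTensorMode():
    fake = torch.fft.fft(torch.randn(4, 5, dtype=torch.complex64), dim=0)

# With the meta kernels in this patch, the two should agree.
print(eager_strides, fake.stride())
```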

Fixes https://github.com/pytorch/pytorch/issues/135087

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145080
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #145530
Author: Edward Z. Yang <ezyang@meta.com>
Date: 2025-01-24 07:13:58 -08:00
Committed by: PyTorch MergeBot
Parent: b75afa2e2e
Commit: 87fdadde1d

4 changed files with 124 additions and 75 deletions

test/functorch/test_aotdispatch.py

@@ -6509,21 +6509,6 @@ symbolic_aot_autograd_failures = {
         "linalg.householder_product",
         decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
     ),
-    # many complex operators incorrect striding, metadata
-    xfail("fft.fft", ""),
-    xfail("fft.hfft2", ""),
-    xfail("fft.hfft", ""),
-    xfail("fft.hfftn", ""),
-    xfail("fft.ifft", ""),
-    xfail("fft.ihfft2", ""),
-    xfail("fft.ihfft", ""),
-    xfail("fft.ihfftn", ""),
-    xfail("fft.irfft2", ""),
-    xfail("fft.irfft", ""),
-    xfail("fft.irfftn", ""),
-    xfail("fft.rfft2", ""),
-    xfail("fft.rfft", ""),
-    xfail("fft.rfftn", ""),
     xfail("stft", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
 }

test/test_proxy_tensor.py

@@ -2014,24 +2014,6 @@ symbolic_tensor_failures = {
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
-    # many complex operators incorrect striding, metadata
-    xfail('fft.fft', ''),
-    xfail('fft.hfft2', ''),
-    xfail('fft.hfft', ''),
-    xfail('fft.hfftn', ''),
-    xfail('fft.ifft', ''),
-    xfail('fft.ihfft2', ''),
-    xfail('fft.ihfft', ''),
-    xfail('fft.ihfftn', ''),
-    xfail('fft.ihfft2', ''),
-    xfail('fft.irfft2', ''),
-    xfail('fft.irfft', ''),
-    xfail('fft.irfftn', ''),
-    xfail('fft.rfft2', ''),
-    xfail('fft.rfft', ''),
-    xfail('fft.rfftn', ''),
-    xfail('stft', '')
 }
 
 symbolic_tensor_segfaults = {
     skip('nn.functional.batch_norm')  # Segfault??
@@ -2058,10 +2040,6 @@ out_symbolic_tensor_failures = {
     xfail('angle', ''),
     xfail('argmax', ''),
     xfail('argmin', ''),
-    xfail('fft.fft2', ''),
-    xfail('fft.fftn', ''),
-    xfail('fft.ifft2', ''),
-    xfail('fft.ifftn', ''),
     xfail('gather', ''),
     xfail('linalg.pinv', ''),
     xfail('linalg.pinv', 'hermitian'),

torch/_meta_registrations.py

@@ -223,8 +223,12 @@ def logcumsumexp(self, dim):
     return torch.empty_like(self).contiguous()
 
 
-# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
-def _exec_fft(out, self, out_sizes, dim, forward):
+# Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
+# and aten/src/ATen/native/cuda/SpectralOps.cpp
+#
+# Although the actual FFT launch is different, all the permuting code appears
+# to be the same.
+def _exec_fft(out, self, out_sizes, dim, *, forward):
     ndim = self.ndim
     signal_ndim = len(dim)
     batch_dims = ndim - signal_ndim
@@ -258,12 +262,12 @@ def _exec_fft(out, self, out_sizes, dim, forward):
     batch_size = input.size(0)
     batched_sizes[0] = batch_size
-    batched_out_sizes = batched_sizes
+    batched_out_sizes = list(batched_sizes)
     for i in range(len(dim)):
         batched_out_sizes[i + 1] = out_sizes[dim[i]]
-    out = out.reshape(batched_out_sizes)
+    out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)
 
-    # Reshaping to original batch shape and inverting the dimension permutation
+    # Inplace reshaping to original batch shape and inverting the dimension permutation
     out_strides = [0 for _ in range(ndim)]
     batch_numel = 1
     i = batch_dims - 1
@@ -273,7 +277,18 @@ def _exec_fft(out, self, out_sizes, dim, forward):
         i -= 1
     for i in range(batch_dims, ndim):
         out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
-    return out.as_strided(out_sizes, out_strides, out.storage_offset())
+    out.as_strided_(out_sizes, out_strides, out.storage_offset())
+    return out
+
+
+def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
+    sorted_dims = list(dim)
+    self_strides = self.stride()
+    # Sort the dims (optionally excluding the last) by ascending stride.
+    # NB: assign back into the slice; calling .sort() on a slice would only
+    # sort a temporary copy.
+    n = len(sorted_dims) - int(exclude_last)
+    sorted_dims[:n] = sorted(sorted_dims[:n], key=lambda i: self_strides[i])
+    return sorted_dims
+
+
 # See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
@@ -281,35 +296,82 @@ def _exec_fft(out, self, out_sizes, dim, forward):
 @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
 @out_wrapper()
 def meta_fft_c2c(self, dim, normalization, forward):
-    assert self.dtype.is_complex
-    out_sizes = self.shape
-    output = self.new_empty(out_sizes)
+    torch._check(self.dtype.is_complex)
     if not dim:
-        return output
+        return self.clone()
 
-    sorted_dims = dim[:]
-    self_strides = self.stride()
-    sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
-    output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
-
-    return output
+    sorted_dims = _sort_dims(self, dim)
+    out = self.new_empty(self.size())
+    return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)
+
+
+cufft_max_ndim = 3
+
+
+def use_optimized_cufft_path(dim: list[int]):
+    if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1):
+        return False
+    else:
+        return True
 
 
 @register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
 @out_wrapper()
 def meta_fft_r2c(self, dim, normalization, onesided):
-    assert self.dtype.is_floating_point
-    output_sizes = list(self.size())
+    torch._check(self.dtype.is_floating_point)
+    input_sizes = list(self.size())
+    out_sizes = list(input_sizes)
+    last_dim = dim[-1]
+    last_dim_halfsize = input_sizes[last_dim] // 2 + 1
+    onesided_sizes = list(input_sizes)
+    onesided_sizes[last_dim] = last_dim_halfsize
 
     if onesided:
-        last_dim = dim[-1]
-        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
-        output_sizes[last_dim] = last_dim_halfsize
+        out_sizes[last_dim] = last_dim_halfsize
 
-    return self.new_empty(
-        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
-    )
+    if device_hint(self) == "cuda":
+        # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
+        output = self.new_empty(
+            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+        )
+        working_tensor = self
+        if use_optimized_cufft_path(dim):
+            _exec_fft(output, working_tensor, out_sizes, dim, forward=True)
+        else:
+            # First do the R2C transform on the last dimension
+            target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
+            _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
+            if len(dim) > 1:
+                working_tensor = self.new_empty(
+                    out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+                )
+
+            # Then any remaining C2C transforms
+            sorted_dims = dim[:-1]
+            while sorted_dims:
+                output, working_tensor = working_tensor, output
+                strides = working_tensor.stride()
+                sorted_dims.sort(
+                    key=lambda i: strides[i], reverse=True
+                )  # NB reverse! Not sure if this is og bug
+                max_dims = min(cufft_max_ndim, len(sorted_dims))
+                last_dims = sorted_dims[len(sorted_dims) - max_dims :]
+                _exec_fft(
+                    output, working_tensor, onesided_sizes, last_dims, forward=True
+                )
+                sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]
+
+        if not onesided:
+            if output.size(last_dim) != out_sizes[last_dim]:
+                working_tensor.resize_(
+                    out_sizes, memory_format=torch.contiguous_format
+                )
+                output = working_tensor
+
+        return output
+    else:
+        return self.new_empty(
+            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+        )
@@ -375,11 +437,43 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=
 @register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
 @out_wrapper()
-def meta_fft_c2r(self, dim, normalization, lastdim):
-    assert self.dtype.is_complex
-    output_sizes = list(self.size())
-    output_sizes[dim[-1]] = lastdim
-    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
+def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
+    # _fft_c2r_mkl
+    torch._check(self.dtype.is_complex)
+
+    if device_hint(self) == "cuda":
+        out_sizes = list(self.size())
+        out_sizes[dim[-1]] = lastdim
+        output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
+        if use_optimized_cufft_path(dim):
+            return _exec_fft(
+                output,
+                self.clone(memory_format=torch.contiguous_format),
+                out_sizes,
+                dim,
+                forward=False,
+            )
+        else:
+            # First complete any C2C transforms
+            if len(dim) > 1:
+                temp = meta_fft_c2c(
+                    self, dim[:-1], 0, False  # fft_norm_mode::none, forward=False
+                )
+            else:
+                temp = self.clone(memory_format=torch.contiguous_format)
+            return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)
+    else:
+        input = self
+        if len(dim) > 1:
+            c2c_dims = dim[:-1]
+            input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
+            dim = dim[-1:]
+
+        out_sizes = list(input.size())
+        out_sizes[dim[-1]] = lastdim
+        out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
+        return _exec_fft(out, input, out_sizes, dim, forward=False)
 
 
 @register_meta(aten.copy_.default)

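The core of the `_exec_fft` translation above is that the FFT result is computed into a contiguous (batch, signal) buffer and then reshaped and restrided in place, rather than materializing permuted copies. A standalone sketch of that restriding trick (hypothetical sizes, not the PR's code):

```python
import torch

# FFT kernels want the transformed dim last and contiguous, so the result
# lands in a contiguous (batch=4, signal=5) working buffer...
buf = torch.empty(4, 5, dtype=torch.complex64)

# ...which is then restrided in place back to the logical layout: here a
# transpose to shape (5, 4) with strides (1, 5). No data is copied.
buf.as_strided_((5, 4), (1, 5))
assert buf.shape == (5, 4)
assert buf.stride() == (1, 5)
```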
torch/_subclasses/fake_tensor.py

@@ -222,14 +222,6 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
 def stride_incorrect_op(op):
     if op.namespace not in ("aten", "prims"):
         return False
-    if op is aten._fft_c2c.default:
-        return False
-    op_name = op.name()
-    if "fft" in op_name:
-        return True
     return False
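
With the meta kernels above computing real output strides, the predicate no longer needs to flag FFT ops at all. A quick illustration of the new behavior (`stride_incorrect_op` is an internal helper, so treat this as a sketch rather than supported API):

```python
import torch
from torch._subclasses.fake_tensor import stride_incorrect_op

# FFT ops no longer route through the incorrect-stride workaround.
print(stride_incorrect_op(torch.ops.aten._fft_c2c.default))  # False
```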