Remove FFT from stride incorrect ops (#145080)

I gotta say, the FFT implementation is completely insane; there's gotta be a better way to do this than repeatedly inplace restriding the output tensor. Anyway, this is a faithful translation of both the MKL and cuFFT paths to Python.

Fixes https://github.com/pytorch/pytorch/issues/135087

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145080
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #145530
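What "stride incorrect" means here: under fake/meta evaluation, an op's meta kernel should report the same output strides that the real kernel produces, otherwise layout planning downstream goes wrong. A minimal sketch of the property this change is meant to establish for FFT ops (the shapes and dims below are illustrative, not taken from the PR):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

x = torch.randn(4, 8, 6, dtype=torch.complex64)
eager = torch.fft.fftn(x, dim=(0, 2))

with FakeTensorMode() as mode:
    fake_x = mode.from_tensor(x)
    fake = torch.fft.fftn(fake_x, dim=(0, 2))

# With the MKL/cuFFT stride logic translated into the meta kernels,
# the fake result should carry eager's strides, not contiguous ones.
assert fake.shape == eager.shape
assert fake.stride() == eager.stride()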
commit 87fdadde1d (parent b75afa2e2e), committed by PyTorch MergeBot
@@ -6509,21 +6509,6 @@ symbolic_aot_autograd_failures = {
         "linalg.householder_product",
         decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
     ),
-    # many complex operators incorrect striding, metadata
-    xfail("fft.fft", ""),
-    xfail("fft.hfft2", ""),
-    xfail("fft.hfft", ""),
-    xfail("fft.hfftn", ""),
-    xfail("fft.ifft", ""),
-    xfail("fft.ihfft2", ""),
-    xfail("fft.ihfft", ""),
-    xfail("fft.ihfftn", ""),
-    xfail("fft.irfft2", ""),
-    xfail("fft.irfft", ""),
-    xfail("fft.irfftn", ""),
-    xfail("fft.rfft2", ""),
-    xfail("fft.rfft", ""),
-    xfail("fft.rfftn", ""),
-    xfail("stft", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
 }
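For readers outside the test harness: an xfail entry runs the test and requires it to fail, so deleting entries is how a fix is asserted rather than just skipped over. A rough stand-in with stock unittest (class and method names invented for illustration):

import unittest

class XfailAnalogy(unittest.TestCase):
    # Once the underlying bug is fixed, an xfail-marked test reports an
    # "unexpected success", which is the cue to delete the entry, as the
    # hunks above and below do for the FFT ops.
    @unittest.expectedFailure
    def test_stale_xfail(self):
        self.assertEqual(2 + 2, 4)  # now passes, so the marker is stale

if __name__ == "__main__":
    unittest.main()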
@@ -2014,24 +2014,6 @@ symbolic_tensor_failures = {
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition

     xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...

-    # many complex operators incorrect striding, metadata
-    xfail('fft.fft', ''),
-    xfail('fft.hfft2', ''),
-    xfail('fft.hfft', ''),
-    xfail('fft.hfftn', ''),
-    xfail('fft.ifft', ''),
-    xfail('fft.ihfft2', ''),
-    xfail('fft.ihfft', ''),
-    xfail('fft.ihfftn', ''),
-    xfail('fft.ihfft2', ''),
-    xfail('fft.irfft2', ''),
-    xfail('fft.irfft', ''),
-    xfail('fft.irfftn', ''),
-    xfail('fft.rfft2', ''),
-    xfail('fft.rfft', ''),
-    xfail('fft.rfftn', ''),
-    xfail('stft', '')
 }

 symbolic_tensor_segfaults = {
     skip('nn.functional.batch_norm')  # Segfault??
@@ -2058,10 +2040,6 @@ out_symbolic_tensor_failures = {
     xfail('angle', ''),
     xfail('argmax', ''),
     xfail('argmin', ''),
-    xfail('fft.fft2', ''),
-    xfail('fft.fftn', ''),
-    xfail('fft.ifft2', ''),
-    xfail('fft.ifftn', ''),
     xfail('gather', ''),
     xfail('linalg.pinv', ''),
     xfail('linalg.pinv', 'hermitian'),
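This third table covers the out= overloads; they dispatch through the .out variants of the same kernels, which is why fft.fft2 and friends appear here separately. A sketch of the call shape those tests exercise (sizes illustrative):

import torch

x = torch.randn(4, 4, dtype=torch.complex64)
out = torch.empty_like(x)
# The out= form routes through aten._fft_c2c.out, so its meta kernel must
# model the destination's sizes and strides correctly as well.
torch.fft.fft2(x, out=out)
torch.testing.assert_close(out, torch.fft.fft2(x))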
@@ -223,8 +223,12 @@ def logcumsumexp(self, dim):
     return torch.empty_like(self).contiguous()


-# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
-def _exec_fft(out, self, out_sizes, dim, forward):
+# Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
+# and aten/src/ATen/cuda/SpectralOps.cpp
+#
+# Although the actual FFT launch is different, all the permuting code appears
+# to be the same
+def _exec_fft(out, self, out_sizes, dim, *, forward):
     ndim = self.ndim
     signal_ndim = len(dim)
     batch_dims = ndim - signal_ndim
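The new comment points at both stride-logic sources, and the signature gains a bare `*`, which makes `forward` keyword-only so every call site has to name the transform direction. A quick standalone illustration of that Python semantics:

def f(out, self, out_sizes, dim, *, forward):  # same shape as the new signature
    return forward

assert f(None, None, [], [], forward=True) is True
# f(None, None, [], [], True)  # TypeError: takes 4 positional arguments but 5 were given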
@@ -258,12 +262,12 @@ def _exec_fft(out, self, out_sizes, dim, forward):

     batch_size = input.size(0)
     batched_sizes[0] = batch_size
-    batched_out_sizes = batched_sizes
+    batched_out_sizes = list(batched_sizes)
     for i in range(len(dim)):
         batched_out_sizes[i + 1] = out_sizes[dim[i]]
-    out = out.reshape(batched_out_sizes)
+    out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)

-    # Reshaping to original batch shape and inverting the dimension permutation
+    # Inplace reshaping to original batch shape and inverting the dimension permutation
     out_strides = [0 for _ in range(ndim)]
     batch_numel = 1
     i = batch_dims - 1
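The reshape-to-resize_ swap is the load-bearing part of this hunk: Tensor.reshape may return a fresh tensor, so rebinding `out` inside the helper would silently break aliasing with the caller's buffer, while resize_ mutates `out` in place. A standalone demonstration (not the PR's code):

import torch

a = torch.arange(12.0)
b = a.reshape(3, 4)
assert b is not a  # a new Tensor object; `a` itself is unchanged

c = torch.arange(12.0)
c.resize_((3, 4), memory_format=torch.contiguous_format)
assert c.shape == (3, 4)  # same object, resized in place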
@@ -273,7 +277,18 @@ def _exec_fft(out, self, out_sizes, dim, forward):
         i -= 1
     for i in range(batch_dims, ndim):
         out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
-    return out.as_strided(out_sizes, out_strides, out.storage_offset())
+    out.as_strided_(out_sizes, out_strides, out.storage_offset())
+
+    return out
+
+
+def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
+    sorted_dims = list(dim)
+    self_strides = self.stride()
+    sorted_dims[: len(sorted_dims) - int(exclude_last)].sort(
+        key=lambda i: self_strides[i]
+    )
+    return sorted_dims


 # See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
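_sort_dims orders the transform dims by ascending stride, optionally leaving the last entry pinned. One Python subtlety when reading it: `lst[:k].sort()` sorts a temporary copy of the slice, not `lst` itself, which is worth keeping in mind around the exclude_last path. A sketch of the intended ordering (contiguous tensor assumed):

import torch

x = torch.empty(4, 5, 6)  # contiguous, so strides are (30, 6, 1)
strides = x.stride()

dims = [0, 2]
assert sorted(dims, key=lambda i: strides[i]) == [2, 0]  # innermost dim first

# The slice subtlety: this sorts a copy and leaves `dims` untouched.
dims[:2].sort(key=lambda i: strides[i])
assert dims == [0, 2]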
@@ -281,35 +296,82 @@ def _exec_fft(out, self, out_sizes, dim, forward):
 @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
 @out_wrapper()
 def meta_fft_c2c(self, dim, normalization, forward):
-    assert self.dtype.is_complex
-
-    out_sizes = self.shape
-    output = self.new_empty(out_sizes)
+    torch._check(self.dtype.is_complex)

     if not dim:
-        return output
+        return self.clone()

-    sorted_dims = dim[:]
-    self_strides = self.stride()
-    sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
-    output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
-
-    return output
+    sorted_dims = _sort_dims(self, dim)
+    out = self.new_empty(self.size())
+    return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)
+
+
+cufft_max_ndim = 3
+
+
+def use_optimized_cufft_path(dim: list[int]):
+    if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1):
+        return False
+    else:
+        return True


 @register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
 @out_wrapper()
 def meta_fft_r2c(self, dim, normalization, onesided):
-    assert self.dtype.is_floating_point
-    output_sizes = list(self.size())
+    torch._check(self.dtype.is_floating_point)
+    input_sizes = list(self.size())
+    out_sizes = list(input_sizes)
+    last_dim = dim[-1]
+    last_dim_halfsize = input_sizes[last_dim] // 2 + 1
+    onesided_sizes = list(input_sizes)
+    onesided_sizes[last_dim] = last_dim_halfsize

     if onesided:
-        last_dim = dim[-1]
-        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
-        output_sizes[last_dim] = last_dim_halfsize
+        out_sizes[last_dim] = last_dim_halfsize
+
+    if device_hint(self) == "cuda":
+        # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
+        output = self.new_empty(
+            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+        )
+
+        working_tensor = self
+        if use_optimized_cufft_path(dim):
+            _exec_fft(output, working_tensor, out_sizes, dim, forward=True)
+        else:
+            # First do the R2C transform on the last dimension
+            target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
+            _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
+            if len(dim) > 1:
+                working_tensor = self.new_empty(
+                    out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+                )
+
+            # Then any remaining C2C transforms
+            sorted_dims = dim[:-1]
+            while sorted_dims:
+                output, working_tensor = working_tensor, output
+                strides = working_tensor.stride()
+                sorted_dims.sort(
+                    key=lambda i: strides[i], reverse=True
+                )  # NB reverse! Not sure if this is og bug
+                max_dims = min(cufft_max_ndim, len(sorted_dims))
+                last_dims = sorted_dims[len(sorted_dims) - max_dims :]
+                _exec_fft(
+                    output, working_tensor, onesided_sizes, last_dims, forward=True
+                )
+                sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]
+
+        if not onesided:
+            if output.size(last_dim) != out_sizes[last_dim]:
+                working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
+                output = working_tensor

-    return self.new_empty(
-        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
-    )
+        return output
+
+    else:
+        return self.new_empty(
+            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+        )
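The onesided bookkeeping above encodes Hermitian symmetry: a real-input transform has only n // 2 + 1 independent complex bins along the last transformed dim, so that is all the R2C kernel materializes when `onesided` is set. The size relation is easy to check through the public API (sizes illustrative):

import torch

x = torch.randn(4, 10)
full = torch.fft.fft(x, dim=-1)   # all 10 complex bins, half of them redundant
half = torch.fft.rfft(x, dim=-1)  # onesided: 10 // 2 + 1 == 6 bins

assert half.size(-1) == x.size(-1) // 2 + 1
torch.testing.assert_close(half, full[..., : half.size(-1)])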
@@ -375,11 +437,43 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=

 @register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
 @out_wrapper()
-def meta_fft_c2r(self, dim, normalization, lastdim):
-    assert self.dtype.is_complex
-    output_sizes = list(self.size())
-    output_sizes[dim[-1]] = lastdim
-    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
+def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
+    # _fft_c2r_mkl
+    torch._check(self.dtype.is_complex)
+
+    if device_hint(self) == "cuda":
+        out_sizes = list(self.size())
+        out_sizes[dim[-1]] = lastdim
+
+        output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
+
+        if use_optimized_cufft_path(dim):
+            return _exec_fft(
+                output,
+                self.clone(memory_format=torch.contiguous_format),
+                out_sizes,
+                dim,
+                forward=False,
+            )
+        else:
+            # First complete any C2C transforms
+            if len(dim) > 1:
+                temp = meta_fft_c2c(self, dim[:-1], 0, lastdim)  # fft_norm_mode::none
+            else:
+                temp = self.clone(memory_format=torch.contiguous_format)
+            return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)
+
+    else:
+        input = self
+        if len(dim) > 1:
+            c2c_dims = dim[:-1]
+            input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
+            dim = dim[-1:]
+
+        out_sizes = list(input.size())
+        out_sizes[dim[-1]] = lastdim
+        out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
+        return _exec_fft(out, input, out_sizes, dim, forward=False)


 @register_meta(aten.copy_.default)
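The C2R path needs `lastdim` passed in explicitly because the real output length is not recoverable from the complex input: inputs of length 2k and 2k + 1 both yield k + 1 onesided bins. The public irfft surfaces the same ambiguity through its `n` argument (sizes illustrative):

import torch

freq = torch.randn(4, 6, dtype=torch.complex64)  # 6 == 10 // 2 + 1 == 11 // 2 + 1
even = torch.fft.irfft(freq, n=10, dim=-1)
odd = torch.fft.irfft(freq, n=11, dim=-1)
assert even.size(-1) == 10 and odd.size(-1) == 11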
@@ -222,14 +222,6 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):


 def stride_incorrect_op(op):
     if op.namespace not in ("aten", "prims"):
         return False
-    if op is aten._fft_c2c.default:
-        return False
-
-    op_name = op.name()
-    if "fft" in op_name:
-        return True
     return False
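With the FFT branch gone, the predicate no longer flags any aten/prims op, so fake tensor mode trusts the strides the FFT meta kernels now compute. A hypothetical sketch of the gating this predicate feeds (the helper below is invented for illustration and is not FakeTensorMode's real internals):

def choose_strides(flagged: bool, meta_strides: tuple, recomputed: tuple) -> tuple:
    # A flagged op would have its meta-computed strides discarded and
    # recomputed; unflagged ops (now including FFT) keep them as-is.
    return recomputed if flagged else meta_strides

assert choose_strides(False, (48, 6, 1), (1, 4, 24)) == (48, 6, 1)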