pytorch/torch/_numpy/fft.py
Edward Z. Yang 9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` in our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, this means that whenever we import a module that is not in the list of files passed to mypy, mypy still typechecks it as normal but suppresses all errors reported in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, setting `follow_imports = normal` should be equivalent, provided we put `# mypy: ignore-errors` on every file that was previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.
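Removing the directive can be scripted as well. The following is a minimal sketch, assuming the directive is the file's first line and is followed by the single blank line that the codemod script adds. `remove_ignore_errors` is a hypothetical helper, not part of PyTorch's tooling.

```python
from pathlib import Path

DIRECTIVE = "# mypy: ignore-errors"


def remove_ignore_errors(path: Path) -> bool:
    """Hypothetical helper: drop a leading ignore-errors directive.

    Returns True if the file was changed, False if it was already typechecked.
    """
    lines = path.read_text(encoding="utf-8").splitlines(keepends=True)
    if not lines or lines[0].rstrip("\r\n") != DIRECTIVE:
        return False  # no directive on the first line: nothing to do
    del lines[0]
    # The codemod inserted one blank line after the directive; drop it too.
    if lines and lines[0].strip() == "":
        del lines[0]
    path.write_text("".join(lines), encoding="utf-8")
    return True
```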

The codemod was done with this script authored by GPT-4:

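```python
import glob

exclude_patterns = [
    ...  # the actual glob patterns are elided here
]

for pattern in exclude_patterns:
    # recursive=True lets "**" in a pattern match nested directories.
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            # 'r+' opens the file for reading and writing without truncating it.
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                # Rewrite the file from the start with the directive prepended.
                f.write('# mypy: ignore-errors\n\n' + content)
```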
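As written, running the script twice would prepend the directive twice. Below is a minimal idempotent variant, sketched under the assumption that the directive is always the first line. `add_ignore_errors`, `run`, and the empty `exclude_patterns` list are hypothetical stand-ins, since the real patterns are elided in the script above.

```python
import glob

DIRECTIVE = "# mypy: ignore-errors"

# Hypothetical placeholder: the real pattern list is elided in the commit message.
exclude_patterns: list[str] = []


def add_ignore_errors(filepath: str) -> bool:
    """Prepend the directive unless the file already starts with it."""
    # newline="" disables newline translation, so the bytes written are exactly
    # the bytes read plus the prefix; the file only grows and needs no truncate().
    with open(filepath, "r+", encoding="utf-8", newline="") as f:
        content = f.read()
        if content.startswith(DIRECTIVE):
            return False  # already opted out; rerunning is a no-op
        f.seek(0)
        f.write(DIRECTIVE + "\n\n" + content)
        return True


def run(patterns: list[str]) -> int:
    """Apply the codemod to every .py file matched by the patterns."""
    changed = 0
    for pattern in patterns:
        for filepath in glob.glob(pattern, recursive=True):
            if filepath.endswith(".py") and add_ignore_errors(filepath):
                changed += 1
    return changed
```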

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00


```python
# mypy: ignore-errors

from __future__ import annotations

import functools

import torch

from . import _dtypes_impl, _util
from ._normalizations import ArrayLike, normalizer


def upcast(func):
    """NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""

    @functools.wraps(func)
    def wrapped(tensor, *args, **kwds):
        target_dtype = (
            _dtypes_impl.default_dtypes().complex_dtype
            if tensor.is_complex()
            else _dtypes_impl.default_dtypes().float_dtype
        )
        tensor = _util.cast_if_needed(tensor, target_dtype)
        return func(tensor, *args, **kwds)

    return wrapped
```
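The cast-then-call pattern behind `upcast` can be sketched without torch, assuming plain Python numbers stand in for tensors: complex input stays complex and everything else is widened to `float` (a 64-bit double in CPython). `upcast_scalars` and `total` are hypothetical names used only for illustration.

```python
import functools


def upcast_scalars(func):
    """Hypothetical analogue of `upcast`: widen inputs before calling `func`."""

    @functools.wraps(func)  # copies __name__/__doc__ and sets __wrapped__
    def wrapped(values, *args, **kwds):
        if any(isinstance(v, complex) for v in values):
            values = [complex(v) for v in values]
        else:
            values = [float(v) for v in values]
        return func(values, *args, **kwds)

    return wrapped


@upcast_scalars
def total(values):
    """Sum a sequence after upcasting."""
    return sum(values)
```

Because `functools.wraps` sets `__wrapped__`, `inspect.signature` on the wrapper reports the original parameters; that is likely why `upcast` uses it, since `@normalizer` sits on top of it in every wrapper.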
```python
@normalizer
@upcast
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
    return torch.fft.fft(a, n, dim=axis, norm=norm)


@normalizer
@upcast
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
    return torch.fft.ifft(a, n, dim=axis, norm=norm)


@normalizer
@upcast
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
    return torch.fft.rfft(a, n, dim=axis, norm=norm)


@normalizer
@upcast
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
    return torch.fft.irfft(a, n, dim=axis, norm=norm)
```
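The 1-D wrappers rename NumPy's `axis` to torch's `dim` and pass `n` and `norm` through unchanged. To show what those two arguments mean, here is a naive O(n^2) sketch of the DFT definition, not how torch computes it: `n` zero-pads or truncates the input, and `norm` selects where the 1/n factor goes ("backward", the default, scales only the inverse; "forward" scales only the forward transform; "ortho" scales both by 1/sqrt(n)). `dft` and `rdft` are hypothetical names.

```python
import cmath
import math


def dft(x, n=None, norm=None, inverse=False):
    """Naive O(n^2) DFT mirroring the `n` and `norm` arguments of fft/ifft.

    An illustration of the definition only, not how torch computes it.
    """
    x = list(x)
    if n is None:
        n = len(x)
    x = (x + [0] * n)[:n]  # zero-pad or truncate the input to length n
    sign = 1 if inverse else -1  # the inverse uses exp(+2*pi*i*j*k/n)
    out = [
        sum(x[j] * cmath.exp(sign * 2j * math.pi * j * k / n) for j in range(n))
        for k in range(n)
    ]
    norm = norm or "backward"  # None means the default mode
    if norm == "ortho":
        scale = 1 / math.sqrt(n)  # both directions scaled by 1/sqrt(n)
    elif norm == "forward":
        scale = 1.0 if inverse else 1 / n  # 1/n on the forward transform only
    elif norm == "backward":
        scale = 1 / n if inverse else 1.0  # 1/n on the inverse transform only
    else:
        raise ValueError(f"unknown norm: {norm!r}")
    return [scale * v for v in out]


def rdft(x, n=None, norm=None):
    """rfft analogue: keep only the n // 2 + 1 non-negative frequencies."""
    n = len(x) if n is None else n
    return dft(x, n, norm)[: n // 2 + 1]
```

For real input the discarded upper half is the complex conjugate of the kept half, which is why `rfft` can return only `n // 2 + 1` values and `irfft` can still reconstruct the signal.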
```python
@normalizer
@upcast
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
    return torch.fft.fftn(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
    return torch.fft.ifftn(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
    return torch.fft.rfftn(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
    return torch.fft.irfftn(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    return torch.fft.fft2(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    return torch.fft.ifft2(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    return torch.fft.rfft2(a, s, dim=axes, norm=norm)


@normalizer
@upcast
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    return torch.fft.irfft2(a, s, dim=axes, norm=norm)
```
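The n-D wrappers pass `s` (the signal length along each transformed axis) and `axes` through as torch's `s` and `dim`; the 2-D variants are the same transforms with `axes` defaulting to `(-2, -1)`. Underlying all of them is the fact that the multidimensional DFT is separable: a 2-D DFT equals 1-D DFTs along one axis followed by 1-D DFTs along the other. A naive pure-Python sketch of that property, using hypothetical names `dft1` and `dft2`:

```python
import cmath
import math


def dft1(x):
    """Unnormalized 1-D DFT (the "backward" default), O(n^2), illustration only."""
    n = len(x)
    return [
        sum(x[j] * cmath.exp(-2j * math.pi * j * k / n) for j in range(n))
        for k in range(n)
    ]


def dft2(rows):
    """2-D DFT of a list of equal-length rows via separable 1-D transforms."""
    rows = [dft1(r) for r in rows]              # 1-D DFT along the last axis
    cols = [dft1(list(c)) for c in zip(*rows)]  # then along the first axis
    return [list(r) for r in zip(*cols)]        # transpose back to row-major
```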
```python
@normalizer
@upcast
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
    return torch.fft.hfft(a, n, dim=axis, norm=norm)


@normalizer
@upcast
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
    return torch.fft.ihfft(a, n, dim=axis, norm=norm)
```
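`hfft` and `ihfft` swap the roles of the two domains: the time-domain signal is Hermitian-symmetric and the spectrum is real. For a real input `x` of length `n`, `ihfft` returns the `n // 2 + 1` non-negative-frequency entries of the inverse DFT, and because `x` is real, those equal `conj(rfft(x)) / n` under the default "backward" norm. A small check of that identity with a naive transform; all names are hypothetical:

```python
import cmath
import math


def dft(x, inverse=False):
    """Naive DFT with the default "backward" norm (1/n on the inverse only)."""
    n = len(x)
    sign = 1 if inverse else -1
    out = [
        sum(x[j] * cmath.exp(sign * 2j * math.pi * j * k / n) for j in range(n))
        for k in range(n)
    ]
    return [v / n for v in out] if inverse else out


def ihfft_naive(x):
    """Non-negative-frequency half of the inverse DFT of a real signal."""
    return dft(x, inverse=True)[: len(x) // 2 + 1]


def ihfft_via_rfft(x):
    """Same values via conj(rfft(x)) / n, valid only because x is real."""
    n = len(x)
    half = dft(x)[: n // 2 + 1]
    return [v.conjugate() / n for v in half]
```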
```python
@normalizer
def fftfreq(n, d=1.0):
    return torch.fft.fftfreq(n, d)


@normalizer
def rfftfreq(n, d=1.0):
    return torch.fft.rfftfreq(n, d)
```
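Both frequency helpers depend only on the window length `n` and the sample spacing `d`. The layout they return, which follows NumPy's documented formula, can be written out in pure Python; `py_fftfreq` and `py_rfftfreq` are hypothetical names for a sketch of the values, not the torch implementation.

```python
def py_fftfreq(n, d=1.0):
    """Sample frequencies in the order an n-point FFT produces them."""
    # [0, 1, ..., ceil(n/2) - 1, -floor(n/2), ..., -1] / (d * n)
    ks = list(range((n - 1) // 2 + 1)) + list(range(-(n // 2), 0))
    return [k / (d * n) for k in ks]


def py_rfftfreq(n, d=1.0):
    """Frequencies of the rfft output: only the n // 2 + 1 non-negative bins."""
    return [k / (d * n) for k in range(n // 2 + 1)]
```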
```python
@normalizer
def fftshift(x: ArrayLike, axes=None):
    return torch.fft.fftshift(x, axes)


@normalizer
def ifftshift(x: ArrayLike, axes=None):
    return torch.fft.ifftshift(x, axes)
```
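`fftshift` and `ifftshift` only reorder data. Along each selected dim (every dim when `axes` is None; the wrapper passes `axes` through as torch's `dim`), `fftshift` rotates the sequence right by `n // 2` so the zero-frequency bin lands in the center, and `ifftshift` rotates left by the same amount, undoing it for odd `n` too. A 1-D pure-Python sketch with hypothetical names:

```python
def py_fftshift(x):
    """Rotate a 1-D sequence right by n // 2 (zero frequency to the center)."""
    n = len(x)
    k = n // 2
    return x[n - k:] + x[:n - k]


def py_ifftshift(x):
    """Inverse of py_fftshift: rotate left by n // 2."""
    n = len(x)
    k = n // 2
    return x[k:] + x[:k]
```

Applied to the FFT frequency ordering `[0, 1, 2, -2, -1]`, `py_fftshift` yields the ascending order `[-2, -1, 0, 1, 2]`.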