Files
pytorch/torch/_numpy/linalg.py
Edward Z. Yang 9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` in our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, this means that whenever we import a module that is not listed as a file to be typechecked by mypy, mypy typechecks it as normal but suppresses all errors reported in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00

240 lines
5.5 KiB
Python

# mypy: ignore-errors
from __future__ import annotations
import functools
import math
from typing import Sequence
import torch
from . import _dtypes_impl, _util
from ._normalizations import ArrayLike, KeepDims, normalizer
class LinAlgError(Exception):
    """Error type raised by the linear-algebra wrappers in this module.

    ``linalg_errors`` raises it in place of ``torch._C._LinAlgError`` and
    ``cond`` raises it directly for empty inputs.
    """
def _atleast_float_1(a):
if not (a.dtype.is_floating_point or a.dtype.is_complex):
a = a.to(_dtypes_impl.default_dtypes().float_dtype)
return a
def _atleast_float_2(a, b):
    """Cast ``a`` and ``b`` to one shared floating-point or complex dtype.

    The shared dtype is ``_dtypes_impl.result_type_impl(a, b)``; when that
    dtype is neither floating-point nor complex, the default float dtype from
    ``_dtypes_impl.default_dtypes()`` is used instead.

    Returns the pair ``(a, b)`` after ``_util.cast_if_needed`` has been applied
    to each.
    """
    common = _dtypes_impl.result_type_impl(a, b)
    is_inexact = common.is_floating_point or common.is_complex
    target = common if is_inexact else _dtypes_impl.default_dtypes().float_dtype
    return _util.cast_if_needed(a, target), _util.cast_if_needed(b, target)
def linalg_errors(func):
    """Decorator translating ``torch._C._LinAlgError`` into ``LinAlgError``.

    The wrapped callable is invoked with the arguments it was given. If it
    raises ``torch._C._LinAlgError``, a ``LinAlgError`` carrying the same
    ``args`` is raised instead; any other exception propagates unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except torch._C._LinAlgError as err:
            raise LinAlgError(*err.args)  # noqa: TRY200
        return result

    return wrapper
# ### Matrix and vector products ###
@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Raise ``a`` to the power ``n`` with ``torch.linalg.matrix_power``.

    ``a`` is first promoted to a floating-point/complex dtype by
    ``_atleast_float_1``.
    """
    return torch.linalg.matrix_power(_atleast_float_1(a), n)
@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Multiply the sequence ``inputs`` together via ``torch.linalg.multi_dot``.

    ``out`` is not referenced in this body.
    NOTE(review): presumably it is consumed by ``@normalizer`` — confirm.
    """
    product = torch.linalg.multi_dot(inputs)
    return product
# ### Solving equations and inverting matrices ###
@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system defined by ``a`` and ``b`` with ``torch.linalg.solve``.

    Both operands are first brought to a common floating-point/complex dtype
    by ``_atleast_float_2``.
    """
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.solve(lhs, rhs)
@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Compute a least-squares solution with ``torch.linalg.lstsq``.

    Both operands are first brought to a common floating-point/complex dtype.
    ``rcond`` is forwarded unchanged. The LAPACK driver is ``"gelsd"`` unless
    either operand lives on a CUDA device, in which case ``"gels"`` is used.
    """
    a, b = _atleast_float_2(a, b)
    # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    # on CUDA, only `gels` is available though, so use it instead
    if a.is_cuda or b.is_cuda:
        driver = "gels"
    else:
        driver = "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Invert ``a`` with ``torch.linalg.inv`` after promoting it to an inexact dtype."""
    return torch.linalg.inv(_atleast_float_1(a))
@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Compute the pseudo-inverse of ``a`` with ``torch.linalg.pinv``.

    ``rcond`` is passed to torch as ``rtol`` and ``hermitian`` is forwarded
    unchanged. ``a`` is first promoted to a floating-point/complex dtype.
    """
    promoted = _atleast_float_1(a)
    return torch.linalg.pinv(promoted, rtol=rcond, hermitian=hermitian)
@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Solve the tensor equation for ``a`` and ``b`` with ``torch.linalg.tensorsolve``.

    ``axes`` is handed to torch as ``dims``. Both operands are first brought
    to a common floating-point/complex dtype.
    """
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.tensorsolve(lhs, rhs, dims=axes)
@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """Compute the tensor inverse of ``a`` with ``torch.linalg.tensorinv``.

    ``ind`` is forwarded unchanged; ``a`` is first promoted to a
    floating-point/complex dtype.
    """
    return torch.linalg.tensorinv(_atleast_float_1(a), ind=ind)
# ### Norms and other numbers ###
@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Return the determinant of ``a`` via ``torch.linalg.det``.

    ``a`` is first promoted to a floating-point/complex dtype.
    """
    return torch.linalg.det(_atleast_float_1(a))
@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Return the result of ``torch.linalg.slogdet`` for ``a``.

    ``a`` is first promoted to a floating-point/complex dtype.
    """
    return torch.linalg.slogdet(_atleast_float_1(a))
@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Compute the condition number of ``x`` with ``torch.linalg.cond``.

    ``x`` is first promoted to a floating-point/complex dtype and ``p`` is
    forwarded unchanged. NaN entries in the result are replaced by ``inf``.

    Raises:
        LinAlgError: if ``x`` has no elements and the product of its last two
            dimensions is zero.
    """
    x = _atleast_float_1(x)

    # Emptiness check,
    # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    is_empty = x.numel() == 0 and math.prod(x.shape[-2:]) == 0
    if is_empty:
        raise LinAlgError("cond is not defined on empty arrays")

    cond_number = torch.linalg.cond(x, p=p)

    # Every nan becomes inf here. NumPy instead does this in a data-dependent
    # way (depending on whether the input array has nans or not).
    # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    nan_mask = torch.isnan(cond_number)
    return torch.where(nan_mask, float("inf"), cond_number)
@normalizer
@linalg_errors
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
    """Compute the rank of ``a`` with ``torch.linalg.matrix_rank``.

    ``a`` is first promoted to a floating-point/complex dtype. Inputs with
    fewer than two dimensions short-circuit: the result is ``1`` if any element
    is non-zero and ``0`` otherwise.

    When ``tol`` is ``None`` a purely relative tolerance of
    ``max(a.shape[-2:]) * eps`` is used (``eps`` from ``torch.finfo`` of the
    promoted dtype); otherwise ``tol`` is used as a purely absolute tolerance.
    ``hermitian`` is forwarded unchanged.
    """
    a = _atleast_float_1(a)

    if a.ndim < 2:
        return int((a != 0).any())

    if tol is not None:
        atol, rtol = tol, 0
    else:
        # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
        atol = 0
        rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps

    return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)
@normalizer
@linalg_errors
def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
    """Compute a norm of ``x`` with ``torch.linalg.norm``.

    ``ord`` is forwarded unchanged and ``axis`` is passed to torch as ``dim``.
    ``x`` is first promoted to a floating-point/complex dtype.

    ``keepdims`` is not referenced in this body.
    NOTE(review): presumably it is applied by ``@normalizer`` via the
    ``KeepDims`` annotation — confirm.
    """
    return torch.linalg.norm(_atleast_float_1(x), ord=ord, dim=axis)
# ### Decompositions ###
@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Return the Cholesky factor of ``a`` as computed by ``torch.linalg.cholesky``.

    ``a`` is first promoted to a floating-point/complex dtype.
    """
    return torch.linalg.cholesky(_atleast_float_1(a))
@normalizer
@linalg_errors
def qr(a: ArrayLike, mode="reduced"):
    """Compute a QR decomposition of ``a`` with ``torch.linalg.qr``.

    ``a`` is first promoted to a floating-point/complex dtype and ``mode`` is
    forwarded unchanged. For ``mode == "r"`` only the ``R`` field of the torch
    result is returned; for any other mode the full torch result is returned.
    """
    a = _atleast_float_1(a)
    decomposition = torch.linalg.qr(a, mode=mode)
    # match NumPy: in "r" mode hand back only the R factor
    return decomposition.R if mode == "r" else decomposition
@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Compute the singular value decomposition of ``a``.

    ``a`` is first promoted to a floating-point/complex dtype. With
    ``compute_uv`` truthy the result of ``torch.linalg.svd`` (with
    ``full_matrices`` forwarded) is returned; otherwise only the singular
    values from ``torch.linalg.svdvals`` are returned.

    ``hermitian`` is accepted but ignored (no pytorch equivalent).
    """
    a = _atleast_float_1(a)
    if compute_uv:
        # NB: ignore the hermitian= argument (no pytorch equivalent)
        return torch.linalg.svd(a, full_matrices=full_matrices)
    return torch.linalg.svdvals(a)
# ### Eigenvalues and eigenvectors ###
@normalizer
@linalg_errors
def eig(a: ArrayLike):
    """Compute eigenvalues and eigenvectors of ``a`` with ``torch.linalg.eig``.

    ``a`` is first promoted to a floating-point/complex dtype. If the promoted
    input is not complex, the eigenvalues are complex, and every eigenvalue has
    a zero imaginary part, then the real parts of both outputs are returned.

    Returns:
        A pair ``(eigenvalues, eigenvectors)``.
    """
    a = _atleast_float_1(a)
    eigenvalues, eigenvectors = torch.linalg.eig(a)

    # Only drop the imaginary component when it carries no information.
    drop_imag = (
        not a.is_complex()
        and eigenvalues.is_complex()
        and (eigenvalues.imag == 0).all()
    )
    if drop_imag:
        eigenvalues = eigenvalues.real
        eigenvectors = eigenvectors.real

    return eigenvalues, eigenvectors
@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Return the result of ``torch.linalg.eigh`` for ``a``.

    ``UPLO`` is forwarded unchanged; ``a`` is first promoted to a
    floating-point/complex dtype.
    """
    return torch.linalg.eigh(_atleast_float_1(a), UPLO=UPLO)
@normalizer
@linalg_errors
def eigvals(a: ArrayLike):
    """Compute the eigenvalues of ``a`` with ``torch.linalg.eigvals``.

    ``a`` is first promoted to a floating-point/complex dtype. If the promoted
    input is not complex, the eigenvalues are complex, and every imaginary part
    is zero, only the real parts are returned.
    """
    a = _atleast_float_1(a)
    eigenvalues = torch.linalg.eigvals(a)

    # Only drop the imaginary component when it carries no information.
    drop_imag = (
        not a.is_complex()
        and eigenvalues.is_complex()
        and (eigenvalues.imag == 0).all()
    )
    if drop_imag:
        return eigenvalues.real
    return eigenvalues
@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Return the result of ``torch.linalg.eigvalsh`` for ``a``.

    ``UPLO`` is forwarded unchanged; ``a`` is first promoted to a
    floating-point/complex dtype.
    """
    return torch.linalg.eigvalsh(_atleast_float_1(a), UPLO=UPLO)