pytorch/torch/_numpy/_reductions_impl.py
Edward Z. Yang 9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` in our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, this means that whenever there is an import of a module which is not listed as a file to be typechecked by mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal` if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since the exclusion from typechecking is now baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.
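For illustration, here is roughly what a previously-excluded file looks like after this change (the module below is made up; only the first line matters):

```
# mypy: ignore-errors

"""Hypothetical module that is not yet clean under mypy."""


def frobnicate(x):
    # mypy still typechecks this body, but any errors it reports are
    # suppressed by the directive at the top of the file. Deleting that
    # first line opts the file back into normal typechecking.
    return x + 1
```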

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00


# mypy: ignore-errors

""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
in the 'public' layer.

Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
"""
from __future__ import annotations

import functools
from typing import Optional

import torch

from . import _dtypes_impl, _util
from ._normalizations import (
    ArrayLike,
    AxisLike,
    DTypeLike,
    KeepDims,
    NotImplementedType,
    OutArray,
)
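
# The ArrayLike / AxisLike / KeepDims / OutArray annotations are consumed by the
# wrapping "public" layer (see `_normalizations`); the implementations below
# receive and return plain torch objects.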


def _deco_axis_expand(func):
    """
    Generically handle axis arguments in reductions.
    axis is *always* the 2nd arg in the function so no need to have a look at its signature
    """

    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        if axis is not None:
            axis = _util.normalize_axis_tuple(axis, a.ndim)

        if axis == ():
            # So we insert a length-one axis and run the reduction along it.
            # We cannot return a.clone() as this would sidestep the checks inside the function
            newshape = _util.expand_shape(a.shape, axis=0)
            a = a.reshape(newshape)
            axis = (0,)

        return func(a, axis, *args, **kwds)

    return wrapped


def _atleast_float(dtype, other_dtype):
    """Return a dtype that is real or complex floating-point.

    For inputs that are boolean or integer dtypes, this returns the default
    float dtype; inputs that are complex get converted to the default complex
    dtype; real floating-point dtypes (`float*`) get passed through unchanged
    """
    if dtype is None:
        dtype = other_dtype
    if not (dtype.is_floating_point or dtype.is_complex):
        return _dtypes_impl.default_dtypes().float_dtype
    return dtype


@_deco_axis_expand
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False):
    return a.count_nonzero(axis)


@_deco_axis_expand
def argmax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmax with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmax_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmax(a, axis)


@_deco_axis_expand
def argmin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmin with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmin_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmin(a, axis)


@_deco_axis_expand
def any(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.any(a, **axis_kw)


@_deco_axis_expand
def all(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.all(a, **axis_kw)


@_deco_axis_expand
def amax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amax with dtype={a.dtype}")

    return a.amax(axis)


max = amax


@_deco_axis_expand
def amin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amin with dtype={a.dtype}")

    return a.amin(axis)


min = amin
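

# ptp ("peak to peak") is the range of the data along the axis: amax - amin.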
@_deco_axis_expand
def ptp(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
):
    return a.amax(axis) - a.amin(axis)


@_deco_axis_expand
def sum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    assert dtype is None or isinstance(dtype, torch.dtype)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.sum(dtype=dtype, **axis_kw)


@_deco_axis_expand
def prod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.prod(dtype=dtype, **axis_kw)


product = prod


@_deco_axis_expand
def mean(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    dtype = _atleast_float(dtype, a.dtype)

    axis_kw = {} if axis is None else {"dim": axis}
    result = a.mean(dtype=dtype, **axis_kw)

    return result
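

# In std/var, NumPy's ddof keeps its usual meaning (the variance divisor is N - ddof)
# and is forwarded to torch as the `correction` argument of Tensor.std / Tensor.var.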
@_deco_axis_expand
def std(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.std(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


@_deco_axis_expand
def var(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.var(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


# cumsum / cumprod are almost reductions:
# 1. no keepdims
# 2. axis=None flattens


def cumsum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumsum(axis=axis, dtype=dtype)


def cumprod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumprod(axis=axis, dtype=dtype)


cumproduct = cumprod
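

# average() is not wrapped in _deco_axis_expand; keepdims and the optional
# (result, wsum) return value are handled manually in the body below.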
def average(
    a: ArrayLike,
    axis=None,
    weights: ArrayLike = None,
    returned=False,
    *,
    keepdims=False,
):
    if weights is None:
        result = mean(a, axis=axis)
        wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
    else:
        if not a.dtype.is_floating_point:
            a = a.double()

        # axis & weights
        if a.shape != weights.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if weights.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if weights.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # setup weight to broadcast along axis
            weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
            weights = weights.swapaxes(-1, axis)

        # do the work
        result_dtype = _dtypes_impl.result_type_impl(a, weights)
        numerator = sum(a * weights, axis, dtype=result_dtype)
        wsum = sum(weights, axis, dtype=result_dtype)
        result = numerator / wsum

    # We process keepdims manually because the decorator does not deal with variadic returns
    if keepdims:
        result = _util.apply_keepdims(result, axis, a.ndim)

    if returned:
        if wsum.shape != result.shape:
            wsum = torch.broadcast_to(wsum, result.shape).clone()
        return result, wsum
    else:
        return result


# Not using deco_axis_expand as it assumes that axis is the second arg
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    if overwrite_input:
        # raise NotImplementedError("overwrite_input in quantile not implemented.")
        # NumPy documents that `overwrite_input` MAY modify inputs:
        # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
        # Here we choose to work out-of-place because why not.
        pass

    if not a.dtype.is_floating_point:
        dtype = _dtypes_impl.default_dtypes().float_dtype
        a = a.to(dtype)

    # edge case: torch.quantile only supports float32 and float64
    if a.dtype == torch.float16:
        a = a.to(torch.float32)

    if axis is None:
        a = a.flatten()
        q = q.flatten()
        axis = (0,)
    else:
        axis = _util.normalize_axis_tuple(axis, a.ndim)

    # FIXME(Mario) Doesn't np.quantile accept a tuple?
    # torch.quantile does accept a number. If we don't want to implement the tuple behaviour
    # (it's deffo low prio) change `normalize_axis_tuple` into a normalize_axis index above.
    axis = _util.allow_only_single_axis(axis)

    q = _util.cast_if_needed(q, a.dtype)

    return torch.quantile(a, q, axis=axis, interpolation=method)


def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32
    if _dtypes_impl.python_type_for_torch(q.dtype) == int:
        q = q.to(_dtypes_impl.default_dtypes().float_dtype)
    qq = q / 100.0

    return quantile(
        a,
        qq,
        axis=axis,
        overwrite_input=overwrite_input,
        method=method,
        keepdims=keepdims,
        interpolation=interpolation,
    )


def median(
    a: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    keepdims: KeepDims = False,
):
    return quantile(
        a,
        torch.as_tensor(0.5),
        axis=axis,
        overwrite_input=overwrite_input,
        out=out,
        keepdims=keepdims,
    )