pytorch/torch/_numpy/_normalizations.py
# mypy: ignore-errors
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on."""
from __future__ import annotations

import functools
import inspect
import operator
import typing

import torch

from . import _dtypes, _dtypes_impl, _util

ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]
DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
KeepDims = typing.TypeVar("KeepDims")
# OutArray is to annotate the out= array argument.
#
# This one is special in several respects:
# First, it needs to be an NDArray, and we need to preserve the `result is out`
# semantics. Therefore, we cannot just extract the Tensor from the out array.
# So we never pass the out array to implementer functions and handle it in the
# `normalizer` below.
# Second, the out= argument can be either keyword or positional argument, and
# as a positional arg, it can be anywhere in the signature.
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
#
OutArray = typing.TypeVar("OutArray")
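
# Sketch (hypothetical signature, for illustration only): a function taking an
# out array would be annotated as
#
#     def divide(x1: ArrayLike, x2: ArrayLike, out: Optional[OutArray] = None): ...
#
# and the `normalizer` decorator, not the implementer function, copies the
# result into out and returns the very same out object.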

try:
    from typing import NotImplementedType
except ImportError:
    NotImplementedType = typing.TypeVar("NotImplementedType")


def normalize_array_like(x, parm=None):  # codespell:ignore
    from ._ndarray import asarray

    return asarray(x).tensor
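
# For instance (illustrative): normalize_array_like([1, 2, 3]) returns the
# torch.Tensor backing asarray([1, 2, 3]), i.e. tensor([1, 2, 3]).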


def normalize_array_like_or_scalar(x, parm=None):  # codespell:ignore
    if _dtypes_impl.is_scalar_or_symbolic(x):
        return x
    return normalize_array_like(x, parm)  # codespell:ignore


def normalize_optional_array_like_or_scalar(x, parm=None):  # codespell:ignore
    if x is None:
        return None
    return normalize_array_like_or_scalar(x, parm)  # codespell:ignore


def normalize_optional_array_like(x, parm=None):  # codespell:ignore
    # This explicit normalizer is needed because otherwise normalize_array_like
    # does not run for a parameter annotated as Optional[ArrayLike]
    return None if x is None else normalize_array_like(x, parm)  # codespell:ignore


def normalize_seq_array_like(x, parm=None):  # codespell:ignore
    return tuple(normalize_array_like(value) for value in x)


def normalize_dtype(dtype, parm=None):  # codespell:ignore
    # cf _decorators.dtype_to_torch
    torch_dtype = None
    if dtype is not None:
        dtype = _dtypes.dtype(dtype)
        torch_dtype = dtype.torch_dtype
    return torch_dtype
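
# For instance (illustrative): normalize_dtype("float32") maps to
# torch.float32 via _dtypes.dtype, while normalize_dtype(None) stays None.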


def normalize_not_implemented(arg, parm):  # codespell:ignore
    if arg != parm.default:  # codespell:ignore
        raise NotImplementedError(
            f"'{parm.name}' parameter is not supported."  # codespell:ignore
        )
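
# Sketch (hypothetical signature): annotating a parameter as
#
#     def empty_like(prototype: ArrayLike, subok: NotImplementedType = False): ...
#
# lets callers pass the default value silently, but raises
# NotImplementedError as soon as a non-default value is supplied.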


def normalize_axis_like(arg, parm=None):  # codespell:ignore
    from ._ndarray import ndarray

    if isinstance(arg, ndarray):
        arg = operator.index(arg)
    return arg


def normalize_ndarray(arg, parm=None):  # codespell:ignore
    # check the arg is an ndarray, extract its tensor attribute
    if arg is None:
        return arg

    from ._ndarray import ndarray

    if not isinstance(arg, ndarray):
        raise TypeError(f"'{parm.name}' must be an array")  # codespell:ignore
    return arg.tensor


def normalize_outarray(arg, parm=None):  # codespell:ignore
    # almost normalize_ndarray, only return the array, not its tensor
    if arg is None:
        return arg

    from ._ndarray import ndarray

    # Dynamo can pass torch tensors as out arguments,
    # wrap them in an ndarray before processing
    if isinstance(arg, torch.Tensor):
        arg = ndarray(arg)
    if not isinstance(arg, ndarray):
        raise TypeError(f"'{parm.name}' must be an array")  # codespell:ignore
    return arg
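
# Note the contrast (descriptive): normalize_ndarray hands implementer
# functions the raw arg.tensor, while normalize_outarray keeps the ndarray
# wrapper so that maybe_copy_to below can return the very object the caller
# passed as out=.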


def normalize_casting(arg, parm=None):  # codespell:ignore
    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
        raise ValueError(
            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
        )
    return arg


normalizers = {
    "ArrayLike": normalize_array_like,
    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
    "Optional[ArrayLike]": normalize_optional_array_like,
    "Sequence[ArrayLike]": normalize_seq_array_like,
    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
    "Optional[NDArray]": normalize_ndarray,
    "Optional[OutArray]": normalize_outarray,
    "NDArray": normalize_ndarray,
    "Optional[DTypeLike]": normalize_dtype,
    "AxisLike": normalize_axis_like,
    "NotImplementedType": normalize_not_implemented,
    "Optional[CastingModes]": normalize_casting,
}
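
# NB: the keys are strings because `from __future__ import annotations` turns
# every annotation into its string form; maybe_normalize below looks up
# parm.annotation verbatim.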


def maybe_normalize(arg, parm):  # codespell:ignore
    """Normalize arg if a normalizer is registered."""
    normalizer = normalizers.get(parm.annotation, None)  # codespell:ignore
    return normalizer(arg, parm) if normalizer else arg  # codespell:ignore


# ### Return value helpers ###


def maybe_copy_to(out, result, promote_scalar_result=False):
    # NB: here out is either an ndarray or None
    if out is None:
        return result
    elif isinstance(result, torch.Tensor):
        if result.shape != out.shape:
            can_fit = result.numel() == 1 and out.ndim == 0
            if promote_scalar_result and can_fit:
                result = result.squeeze()
            else:
                raise ValueError(
                    f"Bad size of the out array: out.shape = {out.shape}"
                    f" while result.shape = {result.shape}."
                )
        out.tensor.copy_(result)
        return out
    elif isinstance(result, (tuple, list)):
        return type(result)(
            maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
        )
    else:
        raise AssertionError  # We should never hit this path
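
# For instance (illustrative): a one-element result can be squeezed into a
# 0-d out array when promote_scalar_result=True; any other shape mismatch
# raises ValueError instead of broadcasting.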


def wrap_tensors(result):
    from ._ndarray import ndarray

    if isinstance(result, torch.Tensor):
        return ndarray(result)
    elif isinstance(result, (tuple, list)):
        result = type(result)(wrap_tensors(x) for x in result)
    return result


def array_or_scalar(values, py_type=float, return_scalar=False):
    if return_scalar:
        return py_type(values.item())
    else:
        from ._ndarray import ndarray

        return ndarray(values)
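
# For instance (illustrative): array_or_scalar(torch.tensor(2.0), float,
# return_scalar=True) gives the Python float 2.0; with return_scalar=False
# it gives an ndarray wrapping the tensor.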


# ### The main decorator to normalize arguments / postprocess the output ###


def normalizer(_func=None, *, promote_scalar_result=False):
    def normalizer_inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwds):
            sig = inspect.signature(func)
            params = sig.parameters
            first_param = next(iter(params.values()))

            # NumPy's API does not have positional args before variadic positional args
            if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
                args = [maybe_normalize(arg, first_param) for arg in args]
            else:
                # NB: extra unknown arguments: pass through, will raise in func(*args) below
                args = (
                    tuple(
                        maybe_normalize(arg, parm)  # codespell:ignore
                        for arg, parm in zip(args, params.values())  # codespell:ignore
                    )
                    + args[len(params.values()) :]
                )

            kwds = {
                name: maybe_normalize(arg, params[name]) if name in params else arg
                for name, arg in kwds.items()
            }

            result = func(*args, **kwds)

            # keepdims
            bound_args = None
            if "keepdims" in params and params["keepdims"].annotation == "KeepDims":
                # keepdims can be in any position so we need sig.bind
                bound_args = sig.bind(*args, **kwds).arguments
                if bound_args.get("keepdims", False):
                    # In this case the first arg is the initial tensor and
                    # the second arg is (optionally) the axis
                    tensor = args[0]
                    axis = bound_args.get("axis")
                    result = _util.apply_keepdims(result, axis, tensor.ndim)

            # out
            if "out" in params:
                # out can be in any position so we need sig.bind
                if bound_args is None:
                    bound_args = sig.bind(*args, **kwds).arguments
                out = bound_args.get("out")
                result = maybe_copy_to(out, result, promote_scalar_result)
            result = wrap_tensors(result)

            return result

        return wrapped

    if _func is None:
        return normalizer_inner
    else:
        return normalizer_inner(_func)
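
# A minimal end-to-end sketch (hypothetical function, not from this module;
# assumes `from typing import Optional` so annotations stringify to the exact
# keys in `normalizers`):
#
#     @normalizer
#     def clip(
#         a: ArrayLike,
#         min: Optional[ArrayLike] = None,
#         max: Optional[ArrayLike] = None,
#         out: Optional[OutArray] = None,
#     ):
#         # a, min, max arrive here as torch.Tensors (or None); out= copying
#         # and re-wrapping into ndarrays are done by the decorator.
#         return torch.clamp(a, min, max)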