Files
pytorch/torch/_numpy/_ndarray.py
Manuel Candales fb9a5d248f Fix torch._numpy to match NumPy when empty ellipsis causes advanced indexing separation (#158297)
Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch, an empty ellipsis doesn't cause a separation. This leads to differing behavior between NumPy and PyTorch in this edge case.

This difference in behavior leads to a bug when using torch.compile:
```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```

As with #157676, this PR doesn't change PyTorch's behavior; instead, it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch's behavior isn't modified.

Note that there are still other bugs in PyTorch's advanced indexing that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present), but those need to be fixed at the ATen operator level. Examples:
- #71673
- #107699
- #158125

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297
Approved by: https://github.com/soumith
2025-07-16 08:11:53 +00:00

721 lines
21 KiB
Python

# mypy: ignore-errors
from __future__ import annotations
import builtins
import math
import operator
from collections.abc import Sequence
import torch
from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
ArrayLike,
normalize_array_like,
normalizer,
NotImplementedType,
)
# Mirrors ``numpy.newaxis``, which is simply ``None``.
newaxis = None

# Every flag name understood by ``Flags`` (as an item key, or lower-cased as an
# attribute).
FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

# Short item keys accepted by ``Flags`` and the full flag name each stands for.
SHORTHAND_TO_FLAGS = dict(
    C="C_CONTIGUOUS",
    F="F_CONTIGUOUS",
    O="OWNDATA",
    W="WRITEABLE",
    A="ALIGNED",
    X="WRITEBACKIFCOPY",
    B="BEHAVED",
    CA="CARRAY",
    FA="FARRAY",
)
class Flags:
    """Read-only, NumPy-style container of array flags.

    A flag is read by item (``flags["C_CONTIGUOUS"]`` or a shorthand key such
    as ``flags["C"]``) or by lower-case attribute (``flags.c_contiguous``).
    Any name in ``FLAGS`` is valid, but only names present in the mapping given
    to the constructor carry a value; reading any other valid name raises
    ``NotImplementedError``. Unknown names raise ``KeyError`` (item access) or
    ``AttributeError`` (attribute access). Assigning a flag always raises.
    """

    def __init__(self, flag_to_value: dict):
        # sanity check: only known flag names may be supplied
        assert all(flag in FLAGS for flag in flag_to_value)
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Only called when normal lookup fails, i.e. for flag-style attributes.
        upper = attr.upper()
        if not (attr.islower() and upper in FLAGS):
            raise AttributeError(f"No flag attribute '{attr}'")
        return self[upper]

    def __getitem__(self, key):
        # Translate shorthand keys ("C", "CA", ...) to the full flag name.
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # A valid flag that this array does not provide a value for.
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        # Ordinary (non-flag) attributes, e.g. ``_flag_to_value``, are stored normally.
        if not (attr.islower() and attr.upper() in FLAGS):
            super().__setattr__(attr, value)
            return
        self[attr.upper()] = value

    def __setitem__(self, key, value):
        known = key in FLAGS or key in SHORTHAND_TO_FLAGS
        if known:
            raise NotImplementedError("Modifying flags is not implemented")
        raise KeyError(f"No flag key '{key}'")
def create_method(fn, name=None):
    """Wrap ``fn`` in a new function suitable for use as an ``ndarray`` method.

    The wrapper forwards all positional and keyword arguments to ``fn``
    unchanged (so, used as a method, the instance becomes ``fn``'s first
    argument). Its ``__name__`` is ``name``, or ``fn.__name__`` when ``name``
    is falsy, and its ``__qualname__`` is ``"ndarray.<that name>"``.
    """
    method_name = name if name else fn.__name__

    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper.__name__ = method_name
    wrapper.__qualname__ = "ndarray." + method_name
    return wrapper
# ndarray methods backed by a ``_funcs`` function, as {method_name: func_name}.
# Every entry here uses the same name for both, so every value is None
# ("same name as the method").
methods = dict.fromkeys(
    [
        "clip",
        "nonzero",
        "repeat",
        "round",
        "squeeze",
        "swapaxes",
        "ravel",
        # linalg
        "diagonal",
        "dot",
        "trace",
        # sorting
        "argsort",
        "searchsorted",
        # reductions
        "argmax",
        "argmin",
        "any",
        "all",
        "max",
        "min",
        "ptp",
        "sum",
        "prod",
        "mean",
        "var",
        "std",
        # scans
        "cumsum",
        "cumprod",
        # advanced indexing
        "take",
        "choose",
    ]
)

# Dunder methods ``__<key>__`` backed by the ``_ufuncs`` function named by the
# value (None means the ufunc has the same name as the key).
dunder = dict(
    abs="absolute",
    invert=None,
    pos="positive",
    neg="negative",
    gt="greater",
    lt="less",
    ge="greater_equal",
    le="less_equal",
)

# Binary dunders that additionally get reflected (``__r<key>__``) and in-place
# (``__i<key>__``) variants; values map to ``_ufuncs`` names as above.
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}
def _upcast_int_indices(index):
    """Promote narrow integer index tensors to ``torch.int64``.

    A tuple (possibly nested) is rebuilt element by element. A tensor of dtype
    int8, int16, int32 or uint8 is cast to int64. Anything else, including
    lists and int64/bool tensors, is returned unchanged (same object).
    """
    if isinstance(index, tuple):
        return tuple(map(_upcast_int_indices, index))
    # NOTE(review): uint8 is presumably included so byte tensors index by value
    # (NumPy semantics) rather than as masks -- confirm.
    narrow_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.uint8)
    if isinstance(index, torch.Tensor) and index.dtype in narrow_int_dtypes:
        return index.to(torch.int64)
    return index
def _has_advanced_indexing(index):
    """Check if there's any advanced indexing.

    ``index`` is an iterable of index entries. An entry counts as advanced if
    it is a ``Sequence``, a Python ``bool``, a boolean tensor, or a tensor with
    at least one dimension.
    """

    def is_advanced(entry):
        if isinstance(entry, (Sequence, bool)):
            return True
        if not isinstance(entry, torch.Tensor):
            return False
        return entry.dtype == torch.bool or entry.ndim > 0

    return any(map(is_advanced, index))
def _numpy_compatible_indexing(index):
    """Convert scalar indices to lists when advanced indexing is present for NumPy compatibility.

    A non-tuple ``index`` is first wrapped in a 1-tuple. If it contains no
    advanced index, the tuple is returned as-is. Otherwise every integer scalar
    (a Python ``int`` that is not a ``bool``, or a 0-d tensor that is neither
    floating-point nor boolean) is replaced by a one-element list, so that it
    takes part in advanced indexing. All other entries are kept unchanged.
    """
    if not isinstance(index, tuple):
        index = (index,)

    if not _has_advanced_indexing(index):
        return index

    def is_integer_scalar(entry):
        # Boolean scalars are left alone: True/False have special meaning in NumPy.
        if isinstance(entry, bool):
            return False
        if isinstance(entry, int):
            return True
        return (
            isinstance(entry, torch.Tensor)
            and entry.ndim == 0
            and not torch.is_floating_point(entry)
            and entry.dtype != torch.bool
        )

    return tuple([entry] if is_integer_scalar(entry) else entry for entry in index)
def _get_bool_depth(s):
    """Return ``(is_bool, depth)`` describing a (possibly nested) boolean index.

    A Python ``bool`` gives ``(True, 0)`` and a boolean tensor gives
    ``(True, ndim)``. Each level of non-empty sequence nesting wrapped around
    either adds 1 to the depth, following the first element at each level. For
    non-boolean input ``is_bool`` is False, and ``depth`` is just the number of
    levels descended, so it is only meaningful when ``is_bool`` is True.
    """
    depth = 0
    node = s
    while True:
        if isinstance(node, bool):
            return True, depth
        if isinstance(node, torch.Tensor) and node.dtype == torch.bool:
            return True, depth + node.ndim
        # ``node[0] != node`` stops the descent on objects whose first element
        # equals the object itself, e.g. a one-character string.
        if not (isinstance(node, Sequence) and node and node[0] != node):
            return False, depth
        node = node[0]
        depth += 1
def _numpy_empty_ellipsis_patch(index, tensor_ndim):
    """
    Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions.

    In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array,
    it still acts as a separator between advanced indices. PyTorch doesn't have this behavior.

    This function detects when we have:
    1. Advanced indexing on both sides of an ellipsis
    2. The ellipsis doesn't actually match any dimensions

    In that case the ellipsis is replaced by ``None``. The inserted axis makes
    PyTorch treat the advanced indices as separated, at the cost of one extra
    size-1 dimension that the returned callables remove / account for.

    Args:
        index: an index entry or a tuple of them (a non-tuple is wrapped in a 1-tuple).
        tensor_ndim: number of dimensions of the tensor being indexed.

    Returns:
        ``(new_index, squeeze_fn, unsqueeze_fn)`` where
        - ``new_index`` is the index tuple to pass to PyTorch;
        - ``squeeze_fn`` drops the extra axis from a ``__getitem__`` result;
        - ``unsqueeze_fn`` inserts the extra axis into a ``__setitem__`` value
          (non-tensors and tensors with too few dims are returned unchanged).
        When no patch is needed, ``new_index`` is the (tuple) index and both
        callables are the identity.
    """
    if not isinstance(index, tuple):
        index = (index,)

    def identity(x):
        return x

    # Find the (first) ellipsis position
    ellipsis_pos = None
    for i, idx in enumerate(index):
        if idx is Ellipsis:
            ellipsis_pos = i
            break

    # If no ellipsis, no patch needed
    if ellipsis_pos is None:
        return index, identity, identity

    # Count tensor dims consumed by the non-ellipsis entries: a boolean index
    # consumes as many dims as its depth, ``None`` consumes none, and every other
    # entry consumes exactly one.
    consumed_dims = 0
    for idx in index:
        is_bool, depth = _get_bool_depth(idx)
        if is_bool:
            consumed_dims += depth
        elif idx is Ellipsis or idx is None:
            continue
        else:
            consumed_dims += 1

    # Only an ellipsis that matches no dimensions needs patching
    ellipsis_dims = tensor_ndim - consumed_dims
    if ellipsis_dims != 0:
        return index, identity, identity

    # NumPy and PyTorch only differ with advanced indices on both sides
    left_advanced = _has_advanced_indexing(index[:ellipsis_pos])
    right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :])
    if not (left_advanced and right_advanced):
        return index, identity, identity

    # Replace the ellipsis with a new axis so the advanced indices are
    # treated as separated, as NumPy does.
    new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :]

    # Position, counted from the end, of the axis inserted above. Once the
    # advanced dims are moved to the front, it is followed only by the output
    # dims of the basic indices to its right: one per slice AND one per
    # ``None``/newaxis. For example, in ``x[:, [0, 1], ..., [0, 1], :, None]``
    # the trailing ``None`` adds a dim after the inserted axis.
    end_ndims = 1 + sum(
        1
        for idx in index[ellipsis_pos + 1 :]
        if isinstance(idx, slice) or idx is None
    )

    def squeeze_fn(x):
        return x.squeeze(-end_ndims)

    def unsqueeze_fn(x):
        # Values with fewer dims broadcast without reaching the extra axis.
        if isinstance(x, torch.Tensor) and x.ndim >= end_ndims:
            return x.unsqueeze(-end_ndims)
        return x

    return new_index, squeeze_fn, unsqueeze_fn
# Sentinel type marking a parameter that was not passed at all, as opposed to
# one passed explicitly as ``None``.
class _Unspecified:
    """Marker type; its single shared instance is ``_Unspecified.unspecified``."""


_Unspecified.unspecified = _Unspecified()
###############################################################
#                      ndarray class                          #
###############################################################


class ndarray:
    """NumPy-style array backed by a ``torch.Tensor`` (stored as ``self.tensor``).

    Many methods are generated from the ``methods``, ``dunder`` and
    ``ri_dunder`` tables and delegate to ``_funcs`` / ``_ufuncs``.
    """

    def __init__(self, t=None):
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            # The two literals are concatenated, hence the trailing space.
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        # `fn=fn` binds the current ufunc; a plain closure would late-bind to
        # the last value of the loop variable.
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def size(self):
        return self.tensor.numel()

    @property
    def ndim(self):
        return self.tensor.ndim

    @property
    def dtype(self):
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        # torch strides count elements; NumPy strides count bytes
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        return self.tensor.element_size()

    @property
    def flags(self):
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        # NumPy defines nbytes as size * itemsize. The size of the underlying
        # storage would over-count for views (e.g. slices) of a larger tensor.
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order} is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting} is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok} is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck} is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        tensor = self.tensor

        def neg_step(i, s):
            """Rewrite a negative-step slice at index position ``i`` as a
            positive-step slice over ``tensor`` flipped along dimension ``i``."""
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            # NOTE(review): `i` is the position of `s` in the index; it equals
            # the tensor dimension only if no earlier entry is None, Ellipsis or
            # a multi-dimensional boolean mask -- confirm / handle separately.
            tensor = torch.flip(tensor, (i,))

            # After flipping, original position `j` (negative or not) lives at
            # position `-j - 1`, so each bound maps through `b -> -b - 1`.
            # Python's slice clipping then handles out-of-range bounds exactly
            # as the original negative-step slice would. Compare against None
            # explicitly: 0 is a valid bound.
            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None
            start = None if s.start is None else -s.start - 1
            stop = None if s.stop is None else -s.stop - 1

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        # Apply NumPy-compatible indexing conversion
        index = _numpy_compatible_indexing(index)
        # Apply NumPy-compatible empty ellipsis behavior
        index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim)
        return maybe_squeeze(ndarray(tensor.__getitem__(index)))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        # Apply NumPy-compatible indexing conversion
        index = _numpy_compatible_indexing(index)
        # Apply NumPy-compatible empty ellipsis behavior
        index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, maybe_unsqueeze(value))

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
def _tolist(obj):
    """Recursively convert tensors into lists.

    Iterates over ``obj``. Nested lists/tuples are converted recursively,
    ``ndarray`` elements are replaced by ``elem.tensor.tolist()``, and every
    other element is kept as-is. Always returns a new list.
    """
    converted = []
    for item in obj:
        if isinstance(item, (list, tuple)):
            item = _tolist(item)
        converted.append(item.tensor.tolist() if isinstance(item, ndarray) else item)
    return converted
# This is ideally the only place that talks to ndarray directly.
# Everything else goes through asarray (preferred) or array.
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Create an ``ndarray`` from ``obj``, mirroring ``numpy.array``.

    Only ``order="K"``, ``subok=False`` and ``like=None`` are supported; any
    other value raises ``NotImplementedError``. An ``ndarray`` input is
    returned unchanged when ``copy is False``, no ``dtype`` is requested and it
    already has at least ``ndmin`` dimensions.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    # Fast path: nothing to convert, copy or reshape.
    reuse_as_is = (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    )
    if reuse_as_is:
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME: should also require the tensors to share dtype, device, etc.
        all_tensors = bool(obj) and all(isinstance(x, torch.Tensor) for x in obj)
        # A list of arrays: *under torch.Dynamo* these are FakeTensors.
        # Otherwise turn nested ndarrays, e.g. [1, [2, 3], ndarray(4)], into
        # plain lists of lists (XXX: remove tolist).
        obj = torch.stack(obj) if all_tensors else _tolist(obj)

    # Unwrap an ndarray into its tensor.
    if isinstance(obj, ndarray):
        obj = obj.tensor

    torch_dtype = None if dtype is None else _dtypes.dtype(dtype).torch_dtype
    return ndarray(_util._coerce_to_tensor(obj, torch_dtype, copy, ndmin))
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert ``a`` to an ``ndarray``: ``array(..., copy=False, ndmin=0)``."""
    return array(a, dtype, copy=False, order=order, ndmin=0, like=like)
def ascontiguousarray(a, dtype=None, *, like=None):
    """Return an ``ndarray`` with the contents of ``a`` and a contiguous tensor.

    If the tensor of ``asarray(a, dtype=dtype, like=like)`` is already
    contiguous, that array is returned. Otherwise a new ``ndarray`` wrapping a
    contiguous copy is returned, and the input is left untouched.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if not arr.tensor.is_contiguous():
        # Build a new array instead of rebinding ``arr.tensor``: ``asarray`` may
        # return the caller's own ndarray, which must remain a view of its data.
        arr = ndarray(arr.tensor.contiguous())
    return arr
def from_dlpack(x, /):
    """Wrap the tensor produced by ``torch.from_dlpack(x)`` in an ``ndarray``."""
    return ndarray(torch.from_dlpack(x))
def _extract_dtype(entry):
    """Interpret ``entry`` as a dtype; failing that, as array-like and use its dtype."""
    try:
        return _dtypes.dtype(entry)
    except Exception:
        # Not dtype-like: fall back to the dtype of the converted array.
        return asarray(entry).dtype
def can_cast(from_, to, casting="safe"):
    """Return whether ``from_`` can be cast to ``to`` under the ``casting`` rule.

    Each argument may be dtype-like or array-like; it is resolved with
    ``_extract_dtype``, and the check is delegated to
    ``_dtypes_impl.can_cast_impl`` on the torch dtypes.
    """
    src, dst = _extract_dtype(from_), _extract_dtype(to)
    return _dtypes_impl.can_cast_impl(src.torch_dtype, dst.torch_dtype, casting)
def result_type(*arrays_and_dtypes):
    """Return the dtype that results from combining all arguments.

    Each argument is converted with ``asarray``. An argument that cannot be
    converted is treated as a dtype and represented by an uninitialized
    one-element tensor of that dtype. The promoted torch dtype comes from
    ``_dtypes_impl.result_type_impl`` and is wrapped as a ``_dtypes.dtype``.
    """

    def as_tensor(entry):
        try:
            return asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            return torch.empty(1, dtype=_dtypes.dtype(entry).torch_dtype)

    tensors = [as_tensor(entry) for entry in arrays_and_dtypes]
    return _dtypes.dtype(_dtypes_impl.result_type_impl(*tensors))