Fix torch._numpy to match NumPy when empty ellipsis causes advanced indexing separation (#158297)

Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch an empty ellipsis doesn't cause a separation. This leads to differing behavior between NumPy and PyTorch in this edge case.

This difference in behavior leads to a bug when using torch.compile:
```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```

Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified.

Notice that there are still some other bugs in PyTorch's advanced indexing, that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). But those need to be fixed at the ATen operator level. Examples:
- #71673
- #107699
- #158125

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297
Approved by: https://github.com/soumith
This commit is contained in:
Manuel Candales
2025-07-16 08:11:50 +00:00
committed by PyTorch MergeBot
parent ddf502c988
commit fb9a5d248f
2 changed files with 165 additions and 7 deletions

View File

@ -169,17 +169,22 @@ def _upcast_int_indices(index):
return index
def _has_advanced_indexing(index):
"""Check if there's any advanced indexing"""
return any(
isinstance(idx, (Sequence, bool))
or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
for idx in index
)
def _numpy_compatible_indexing(index):
"""Convert scalar indices to lists when advanced indexing is present for NumPy compatibility."""
if not isinstance(index, tuple):
index = (index,)
# Check if there's any advanced indexing (sequences, booleans, or tensors)
has_advanced = any(
isinstance(idx, (Sequence, bool))
or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
for idx in index
)
has_advanced = _has_advanced_indexing(index)
if not has_advanced:
return index
@ -206,6 +211,84 @@ def _numpy_compatible_indexing(index):
return tuple(converted)
def _get_bool_depth(s):
"""Returns the depth of a boolean sequence/tensor"""
if isinstance(s, bool):
return True, 0
if isinstance(s, torch.Tensor) and s.dtype == torch.bool:
return True, s.ndim
if not (isinstance(s, Sequence) and s and s[0] != s):
return False, 0
is_bool, depth = _get_bool_depth(s[0])
return is_bool, depth + 1
def _numpy_empty_ellipsis_patch(index, tensor_ndim):
    """Emulate NumPy's separator semantics for a zero-width ellipsis.

    In NumPy an ``...`` always separates advanced indices, even when it
    matches zero dimensions of the array; PyTorch only separates when the
    ellipsis actually consumes dimensions. When that edge case is detected
    (advanced indices on both sides of an ellipsis spanning no dims), the
    ellipsis is replaced by ``None`` so PyTorch still treats the advanced
    indices as separated, and compensating squeeze/unsqueeze callables are
    returned to undo the extra unit dimension.

    Returns a triple ``(index, squeeze_fn, unsqueeze_fn)``; the last two
    are identity functions whenever no patching was needed.
    """
    if not isinstance(index, tuple):
        index = (index,)

    def identity(x):
        return x

    # Position of the first ellipsis, or None when there is none.
    ellipsis_pos = next(
        (pos for pos, idx in enumerate(index) if idx is Ellipsis), None
    )
    if ellipsis_pos is None:
        return index, identity, identity

    # Count the tensor dimensions consumed by everything except the
    # ellipsis: boolean indices consume as many dims as their nesting
    # depth, None and Ellipsis consume nothing, anything else consumes one.
    consumed_dims = 0
    for idx in index:
        is_bool, depth = _get_bool_depth(idx)
        if is_bool:
            consumed_dims += depth
        elif idx is not Ellipsis and idx is not None:
            consumed_dims += 1

    # Only patch when the ellipsis matches exactly zero dimensions ...
    if tensor_ndim - consumed_dims == 0:
        before = index[:ellipsis_pos]
        after = index[ellipsis_pos + 1 :]
        # ... and advanced indexing appears on both sides of it. This is
        # precisely where NumPy and PyTorch disagree.
        if _has_advanced_indexing(before) and _has_advanced_indexing(after):
            # Swap the ellipsis for None: PyTorch then keeps the advanced
            # indices separated, at the cost of one extra unit dimension.
            patched = before + (None,) + after
            # The inserted unit dim lands this far from the end of the
            # result: one slot plus one per trailing slice.
            end_ndims = 1 + sum(isinstance(idx, slice) for idx in after)

            def squeeze_fn(result):
                # Drop the unit dim introduced by the None (getitem path).
                return result.squeeze(-end_ndims)

            def unsqueeze_fn(value):
                # Mirror image for setitem: give the value the unit dim so
                # it lines up with the patched index.
                if isinstance(value, torch.Tensor) and value.ndim >= end_ndims:
                    return value.unsqueeze(-end_ndims)
                return value

            return patched, squeeze_fn, unsqueeze_fn

    return index, identity, identity
# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
@ -507,19 +590,23 @@ class ndarray:
index = _upcast_int_indices(index)
# Apply NumPy-compatible indexing conversion
index = _numpy_compatible_indexing(index)
return ndarray(tensor.__getitem__(index))
# Apply NumPy-compatible empty ellipsis behavior
index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim)
return maybe_squeeze(ndarray(tensor.__getitem__(index)))
def __setitem__(self, index, value):
    """Assign ``value`` at ``index`` with NumPy-compatible index semantics.

    The index goes through the same normalization pipeline as
    ``__getitem__`` (ndarray->tensor conversion, int upcasting, scalar->list
    conversion under advanced indexing, empty-ellipsis separation) before
    being handed to the underlying tensor.
    """
    index = _util.ndarrays_to_tensors(index)
    index = _upcast_int_indices(index)
    # Apply NumPy-compatible indexing conversion
    index = _numpy_compatible_indexing(index)
    # Apply NumPy-compatible empty ellipsis behavior; maybe_unsqueeze
    # compensates the value's shape for the None inserted into the index.
    index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim)
    if not _dtypes_impl.is_scalar(value):
        value = normalize_array_like(value)
        value = _util.cast_if_needed(value, self.tensor.dtype)
    # Single return path: the stale pre-patch return (which bypassed
    # maybe_unsqueeze and made the line below unreachable) is removed.
    return self.tensor.__setitem__(index, maybe_unsqueeze(value))
take = _funcs.take
put = _funcs.put