mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix torch._numpy to match NumPy when empty ellipsis causes advanced indexing separation (#158297)
Fixes #141563 In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch an empty ellipsis doesn't cause a separation. This leads to differing behavior between Numpy and PyTorch in this edge case. This difference in behavior leads to a bug when using torch.compile: ```python >>> import numpy as np >>> f = lambda x: x[:,(0,1),...,(0,1)].shape >>> a = np.ones((3, 4, 5)) >>> f(a) (2, 3) >>> torch.compile(f)(a) (3, 2) ``` Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified. Notice that there are still some other bugs in PyTorch's advanced indexing, that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). But those need to be fixed at the ATen operator level. Examples: - #71673 - #107699 - #158125 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297 Approved by: https://github.com/soumith
This commit is contained in:
committed by
PyTorch MergeBot
parent
ddf502c988
commit
fb9a5d248f
@ -169,17 +169,22 @@ def _upcast_int_indices(index):
|
||||
return index
|
||||
|
||||
|
||||
def _has_advanced_indexing(index):
|
||||
"""Check if there's any advanced indexing"""
|
||||
return any(
|
||||
isinstance(idx, (Sequence, bool))
|
||||
or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
|
||||
for idx in index
|
||||
)
|
||||
|
||||
|
||||
def _numpy_compatible_indexing(index):
|
||||
"""Convert scalar indices to lists when advanced indexing is present for NumPy compatibility."""
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
|
||||
# Check if there's any advanced indexing (sequences, booleans, or tensors)
|
||||
has_advanced = any(
|
||||
isinstance(idx, (Sequence, bool))
|
||||
or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
|
||||
for idx in index
|
||||
)
|
||||
has_advanced = _has_advanced_indexing(index)
|
||||
|
||||
if not has_advanced:
|
||||
return index
|
||||
@ -206,6 +211,84 @@ def _numpy_compatible_indexing(index):
|
||||
return tuple(converted)
|
||||
|
||||
|
||||
def _get_bool_depth(s):
|
||||
"""Returns the depth of a boolean sequence/tensor"""
|
||||
if isinstance(s, bool):
|
||||
return True, 0
|
||||
if isinstance(s, torch.Tensor) and s.dtype == torch.bool:
|
||||
return True, s.ndim
|
||||
if not (isinstance(s, Sequence) and s and s[0] != s):
|
||||
return False, 0
|
||||
is_bool, depth = _get_bool_depth(s[0])
|
||||
return is_bool, depth + 1
|
||||
|
||||
|
||||
def _numpy_empty_ellipsis_patch(index, tensor_ndim):
|
||||
"""
|
||||
Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions.
|
||||
|
||||
In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array,
|
||||
it still acts as a separator between advanced indices. PyTorch doesn't have this behavior.
|
||||
|
||||
This function detects when we have:
|
||||
1. Advanced indexing on both sides of an ellipsis
|
||||
2. The ellipsis doesn't actually match any dimensions
|
||||
"""
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
|
||||
# Find ellipsis position
|
||||
ellipsis_pos = None
|
||||
for i, idx in enumerate(index):
|
||||
if idx is Ellipsis:
|
||||
ellipsis_pos = i
|
||||
break
|
||||
|
||||
# If no ellipsis, no patch needed
|
||||
if ellipsis_pos is None:
|
||||
return index, lambda x: x, lambda x: x
|
||||
|
||||
# Count non-ellipsis dimensions consumed by the index
|
||||
consumed_dims = 0
|
||||
for idx in index:
|
||||
is_bool, depth = _get_bool_depth(idx)
|
||||
if is_bool:
|
||||
consumed_dims += depth
|
||||
elif idx is Ellipsis or idx is None:
|
||||
continue
|
||||
else:
|
||||
consumed_dims += 1
|
||||
|
||||
# Calculate how many dimensions the ellipsis should match
|
||||
ellipsis_dims = tensor_ndim - consumed_dims
|
||||
|
||||
# Check if ellipsis doesn't match any dimensions
|
||||
if ellipsis_dims == 0:
|
||||
# Check if we have advanced indexing on both sides of ellipsis
|
||||
left_advanced = _has_advanced_indexing(index[:ellipsis_pos])
|
||||
right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :])
|
||||
|
||||
if left_advanced and right_advanced:
|
||||
# This is the case where NumPy and PyTorch differ
|
||||
# We need to ensure the advanced indices are treated as separated
|
||||
new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :]
|
||||
end_ndims = 1 + sum(
|
||||
1 for idx in index[ellipsis_pos + 1 :] if isinstance(idx, slice)
|
||||
)
|
||||
|
||||
def squeeze_fn(x):
|
||||
return x.squeeze(-end_ndims)
|
||||
|
||||
def unsqueeze_fn(x):
|
||||
if isinstance(x, torch.Tensor) and x.ndim >= end_ndims:
|
||||
return x.unsqueeze(-end_ndims)
|
||||
return x
|
||||
|
||||
return new_index, squeeze_fn, unsqueeze_fn
|
||||
|
||||
return index, lambda x: x, lambda x: x
|
||||
|
||||
|
||||
# Used to indicate that a parameter is unspecified (as opposed to explicitly
|
||||
# `None`)
|
||||
class _Unspecified:
|
||||
@ -507,19 +590,23 @@ class ndarray:
|
||||
index = _upcast_int_indices(index)
|
||||
# Apply NumPy-compatible indexing conversion
|
||||
index = _numpy_compatible_indexing(index)
|
||||
return ndarray(tensor.__getitem__(index))
|
||||
# Apply NumPy-compatible empty ellipsis behavior
|
||||
index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim)
|
||||
return maybe_squeeze(ndarray(tensor.__getitem__(index)))
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
index = _util.ndarrays_to_tensors(index)
|
||||
index = _upcast_int_indices(index)
|
||||
# Apply NumPy-compatible indexing conversion
|
||||
index = _numpy_compatible_indexing(index)
|
||||
# Apply NumPy-compatible empty ellipsis behavior
|
||||
index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim)
|
||||
|
||||
if not _dtypes_impl.is_scalar(value):
|
||||
value = normalize_array_like(value)
|
||||
value = _util.cast_if_needed(value, self.tensor.dtype)
|
||||
|
||||
return self.tensor.__setitem__(index, value)
|
||||
return self.tensor.__setitem__(index, maybe_unsqueeze(value))
|
||||
|
||||
take = _funcs.take
|
||||
put = _funcs.put
|
||||
|
||||
Reference in New Issue
Block a user