Files
pytorch/test/torch_np/test_indexing.py
Manuel Candales fb9a5d248f Fix torch._numpy to match NumPy when empty ellipsis causes advanced indexing separation (#158297)
Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch, an empty ellipsis doesn't cause a separation. This leads to differing behavior between NumPy and PyTorch in this edge case.

This difference in behavior leads to a bug when using torch.compile:
```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```

Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified.

Note that there are still other bugs in PyTorch's advanced indexing that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). Those, however, need to be fixed at the ATen operator level. Examples:
- #71673
- #107699
- #158125

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297
Approved by: https://github.com/soumith
2025-07-16 08:11:53 +00:00

489 lines
17 KiB
Python

# Owner(s): ["module: dynamo"]
import numpy
import torch
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
if TEST_WITH_TORCHDYNAMO:
import numpy as np
from numpy.testing import assert_array_equal
else:
import torch._numpy as np
from torch._numpy.testing import assert_array_equal
@instantiate_parametrized_tests
class TestAdvancedIndexing(TestCase):
"""Test advanced indexing for NumPy compatibility (separated advanced indices)."""
def _test_cases(self, test_cases, description=""):
"""Unified test runner for multiple test cases."""
for case in test_cases:
shape, index = case["shape"], case["index"]
name = case.get("name", f"{description}: {index}")
with self.subTest(name=name):
# Create test arrays
torch_arr = np.arange(numpy.prod(shape)).reshape(shape)
numpy_arr = numpy.arange(numpy.prod(shape)).reshape(shape)
# Test getitem
tr, nr = torch_arr[index], numpy_arr[index]
self.assertEqual(tr.shape, nr.shape, f"{name} getitem shape mismatch")
assert_array_equal(
tr.tensor.numpy() if hasattr(tr, "tensor") else tr, nr
)
# Test setitem
tc, nc = torch_arr.copy(), numpy_arr.copy()
value = np.ones(tr.shape) * 999
val_np = value.tensor.numpy() if hasattr(value, "tensor") else value
tc[index], nc[index] = value, val_np
assert_array_equal(
tc.tensor.numpy() if hasattr(tc, "tensor") else tc,
nc,
f"{name} setitem mismatch",
)
def _generate_standard_indices(self, shape):
"""Generate standard index patterns for given shape."""
if len(shape) == 1:
return [([0, 2, 4],), ([1, 3, 5, 7],), ([0, 0, 1, 1],)]
elif len(shape) == 2:
return [
(slice(None), [0]),
([0], slice(None)),
(slice(1, 3), [0, 2]),
([0, 2], slice(1, 4)),
([0, 2], 1),
(1, [0, 2, 4]),
]
elif len(shape) == 3:
return [
([0, 2], slice(None), slice(None)),
(slice(None), [0, 2], slice(None)),
(slice(None), slice(None), [0, 2, 4]),
([0, 1], slice(None), 2),
([0, 2], 1, slice(None)),
(1, [1, 3], slice(None)),
(slice(None), [0, 2], 3),
([0, 1], slice(None), [0, 2]),
([1, 2], slice(1, 3), [1, 3]),
([0, 1], 1, [1, 2]),
]
elif len(shape) == 4:
return [
(slice(None), [0], 0, slice(None)),
(slice(None), [0, 1], slice(None), 0),
([0, 1], slice(None), [0, 2], slice(None)),
(slice(None), [0, 2], slice(None), [0, 3]),
([0], slice(None), slice(None), [1, 3]),
(slice(None), slice(None), [0, 2], [1, 4]),
]
return []
def _generate_broadcast_indices(self, shape):
"""Generate broadcast-style index patterns."""
if len(shape) >= 2:
# Only generate patterns that fit within the shape bounds
patterns = []
if shape[0] >= 2 and shape[1] >= 4:
patterns.append(([[0, 1], [0, 1]], [[0, 1], [2, 3]])) # 2D broadcasting
if shape[0] >= 3 and shape[1] >= 2:
patterns.append(([0, 1, 2], [0, 1, 0])) # Element selection
return patterns
return []
def test_comprehensive_indexing(self):
"""Test comprehensive indexing patterns across multiple dimensions."""
test_shapes = [(10,), (4, 6), (3, 4, 5), (2, 3, 4, 5), (5, 6, 7, 8, 9)]
all_cases = []
for shape in test_shapes:
# Standard patterns
for idx in self._generate_standard_indices(shape):
all_cases.append({"shape": shape, "index": idx})
# Broadcast patterns (for 2D+)
if len(shape) >= 2:
for idx in self._generate_broadcast_indices(shape):
all_cases.append({"shape": shape, "index": idx})
# Separated indices patterns (for 3D+)
if len(shape) >= 3:
all_cases.extend(
[
{
"shape": shape,
"index": ([0, 1], slice(None), [1, 2]),
"name": "Separated indices",
},
{
"shape": shape,
"index": ([0, 1], [1, 2]),
"name": "Adjacent indices",
},
]
)
# Edge cases with negative indices
if len(shape) >= 3:
all_cases.extend(
[
{"shape": shape, "index": ([0], slice(None), [1])},
{"shape": shape, "index": ([-1], slice(None), [-1])},
{"shape": shape, "index": ([0, -1], slice(None), [1, -1])},
]
)
self._test_cases(all_cases, "Comprehensive indexing")
def test_advanced_separation_patterns(self):
"""Test advanced separation patterns and edge cases."""
cases = [
# Complex multi-dimensional separations
{
"shape": (3, 4, 5, 6, 7, 8),
"index": (
[1, 2],
slice(None),
[2, 4],
slice(None),
[5, 2],
slice(None),
),
"name": "6D multiple separations",
},
{
"shape": (3, 4, 5, 6, 7, 8, 9, 10),
"index": (
[1, 2],
slice(None),
slice(None),
[2, 4],
slice(None),
slice(None),
[5, 2],
slice(None),
),
"name": "8D multiple separations",
},
# Current logic issues
{
"shape": (2, 3, 4, 5),
"index": (0, slice(None), [1], slice(None)),
"name": "Logic test: [0, :, [1], :]",
},
{
"shape": (2, 3, 4, 5),
"index": (0, [1], slice(None), 0),
"name": "Logic test: [0, [1], :, 0]",
},
# High-dimensional edge cases
{
"shape": (5, 6, 7, 8, 9),
"index": (slice(None), slice(None), [1, 2], slice(None), slice(None)),
"name": "Single advanced index at middle position",
},
{
"shape": (5, 6, 7, 8, 9),
"index": ([1, 2], slice(None), slice(None), slice(None), slice(None)),
"name": "Single advanced index at start",
},
{
"shape": (5, 6, 7, 8, 9),
"index": (slice(None), slice(None), slice(None), slice(None), [1, 2]),
"name": "Single advanced index at end",
},
{
"shape": (5, 6, 7, 8, 9),
"index": (2, [1, 2], slice(None), slice(None), slice(None)),
"name": "Advanced index with scalar before",
},
{
"shape": (5, 6, 7, 8, 9),
"index": (2, 3, [1, 2], 4, slice(None)),
"name": "Advanced index with scalars around",
},
]
self._test_cases(cases, "Advanced separation patterns")
def test_broadcast_and_numpy_compatibility(self):
"""Test broadcasting patterns and NumPy documentation examples."""
base_shape = (10, 20, 30, 40, 50)
# Standard broadcast indices
indices_2d = {
"ind_1": [[1, 2], [3, 4], [5, 6]],
"ind_2": [[7, 8], [9, 10], [11, 12]],
}
indices_3d = {"ind_1": [[[1, 2]], [[3, 4]]], "ind_2": [[[5, 6]], [[7, 8]]]}
indices_4d = {
"ind_1": [[[[1, 2]], [[3, 4]]], [[[5, 6]], [[7, 8]]]],
"ind_2": [[[[9, 10]], [[11, 12]]], [[[13, 14]], [[15, 16]]]],
}
cases = [
# 2D broadcast patterns
{
"shape": base_shape,
"index": (slice(None), indices_2d["ind_1"], indices_2d["ind_2"]),
"name": "Adjacent 2D broadcast indices",
},
{
"shape": base_shape,
"index": (
slice(None),
indices_2d["ind_1"],
slice(None),
indices_2d["ind_2"],
),
"name": "Separated 2D broadcast indices",
},
# 3D broadcast patterns
{
"shape": base_shape,
"index": (
slice(None),
indices_3d["ind_1"],
slice(None),
indices_3d["ind_2"],
),
"name": "3D broadcast creating 8D output",
},
# 4D broadcast patterns
{
"shape": base_shape,
"index": (
slice(None),
indices_4d["ind_1"],
slice(None),
indices_4d["ind_2"],
),
"name": "4D broadcast creating 9D output",
},
# NumPy documentation examples
{
"shape": (5, 7),
"index": ([0, 2, 4], slice(1, 3)),
"name": "NumPy docs: mixed advanced and slice",
},
{
"shape": (4, 3),
"index": ([0, 1, 2], [0, 1, 0]),
"name": "NumPy docs: element selection",
},
{
"shape": (4, 3),
"index": ([[0], [3]], [0, 2]),
"name": "NumPy docs: broadcasting case",
},
]
self._test_cases(cases, "Broadcast and NumPy compatibility")
def test_special_index_types(self):
"""Test special index types including tensors, ellipsis, and newaxis."""
cases = [
# Ellipsis handling
{
"shape": (3, 4, 5, 6),
"index": (..., [1, 2], slice(None)),
"name": "Ellipsis with advanced index",
},
{
"shape": (3, 4, 5, 6),
"index": ([0, 1], ..., 2),
"name": "Advanced index with ellipsis",
},
# Newaxis handling
{
"shape": (3, 4, 5),
"index": (None, [1, 2], slice(None)),
"name": "Newaxis with advanced index",
},
{
"shape": (3, 4, 5),
"index": ([1, 2], None, slice(None)),
"name": "Advanced index with newaxis",
},
# Complex mixing
{
"shape": (3, 4, 5, 6),
"index": (0, [1, 2], slice(None), 3),
"name": "Integer-advanced-slice-integer",
},
{
"shape": (4, 5, 6),
"index": ([0, 1], 2, [3, 4]),
"name": "Advanced-integer-advanced",
},
{
"shape": (3, 4, 5, 6, 7),
"index": ([0, 1], slice(None), [2, 3], slice(1, 3), [4, 5]),
"name": "Multiple separated advanced indices",
},
{
"shape": (3, 4, 5),
"index": ([0, 1], [2, 3], [1, 4]),
"name": "All advanced indices",
},
# Boolean indexing cases (issue #158134)
# Tests for boolean values that trigger advanced indexing
{
"shape": (3, 4),
"index": (True,),
"name": "Boolean True indexing",
},
{
"shape": (3, 4),
"index": (False,),
"name": "Boolean False indexing",
},
{
"shape": (2, 3, 4),
"index": (True, slice(None)),
"name": "Boolean True with slice",
},
# Tuple indexing cases
{
"shape": (3, 4, 5),
"index": ((0, 1), slice(None)),
"name": "Tuple indexing with slice",
},
{
"shape": (4, 5),
"index": ((0, 1, 2),),
"name": "Tuple indexing only",
},
{
"shape": (3, 4, 5),
"index": ((0, 1), (2, 3)),
"name": "Multiple tuple indexing",
},
]
# Handle torch tensor cases separately
torch_cases = [
{
"shape": (2, 3, 4),
"index": (torch.tensor(1), [1, 2]),
"name": "Zero-dim tensor with list",
},
{
"shape": (2, 3, 4),
"index": (torch.tensor(1), slice(None), [1, 2]),
"name": "Zero-dim tensor with slice and list",
},
{
"shape": (2, 3, 4),
"index": (torch.tensor(0), torch.tensor(1), [1, 2]),
"name": "Multiple zero-dim tensors",
},
{
"shape": (3, 4),
"index": (torch.tensor(True),),
"name": "Zero-dim bool tensor",
},
{
"shape": (2, 3, 4),
"index": (torch.tensor(False),),
"name": "Zero-dim False tensor",
},
]
# Convert torch tensor indices for numpy comparison
numpy_torch_cases = []
for case in torch_cases:
numpy_index = tuple(
idx.item() if isinstance(idx, torch.Tensor) and idx.ndim == 0 else idx
for idx in case["index"]
)
numpy_torch_cases.append(
{"shape": case["shape"], "index": numpy_index, "name": case["name"]}
)
self._test_cases(cases + numpy_torch_cases, "Special index types")
def test_ellipsis(self):
"""Tests containing ellipsis."""
cases = [
# Ellipsis + Basic indexing
{
"shape": (3, 4, 5),
"index": (slice(None), 0, ..., slice(None)),
"name": "empty ellipsis without advanced indexing",
},
{
"shape": (3, 4, 5),
"index": (slice(None), ..., 0),
"name": "non-empty ellipsis without advanced indexing",
},
# Ellipsis + Advanced indexing without separation
{
"shape": (3, 4, 5),
"index": (slice(None), ..., slice(None), (0, 1)),
"name": "empty ellipsis without separation",
},
{
"shape": (3, 4, 5),
"index": (slice(None), ..., (0, 1)),
"name": "non-empty ellipsis without separation",
},
# Ellipsis + Advanced indexing with separation
{
"shape": (3, 4, 5),
"index": (slice(None), (0, 1), ..., (0, 1)),
"name": "empty ellipsis separation",
},
{
"shape": (1, 3, 4, 5),
"index": (slice(None), (0, 1), ..., (0, 1)),
"name": "non-empty ellipsis separation",
},
{
"shape": (4, 3, 5),
"index": (slice(None), ((0,), (1,)), ..., (0, 1)),
"name": "empty ellipsis separation with 2-depth int sequence",
},
{
"shape": (4, 3, 5, 6),
"index": (slice(None), ((0,), (1,)), ..., (0, 1), slice(None)),
"name": "empty ellipsis separation with 2-depth int sequence and end slice",
},
{
"shape": (4, 3, 5, 6),
"index": (slice(None), ((0,), (1,)), ..., (0, 1), (((0, 1), (1, 2)),)),
"name": "empty ellipsis separation with 2 and 3-depth int sequence",
},
# Ellipsis + Boolean masks in advanced indexing with separation
{
"shape": (3, 4, 5),
"index": (slice(None), True, True, True, ..., 0, 0),
"name": "empty ellipsis separation with 0-dim boolean masks",
},
{
"shape": (4, 3, 5),
"index": (slice(None), (True, True, False), ..., (0, 1)),
"name": "empty ellipsis separation with 1-dim boolean masks",
},
# TODO(manuelcandales) Fix issue #71673 and enable this case
# {
# "shape": (1, 2, 2, 4, 5),
# "index": (slice(None), ((True, False), (True, True)), (0, 1, 2), ..., (0,)),
# "name": "empty ellipsis separation with 2-dim boolean masks",
# },
]
self._test_cases(cases, "Ellipsis and advanced indexing separation")
# Run this module's tests through run_tests() when executed as a script.
if __name__ == "__main__":
    run_tests()