mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch, an empty ellipsis doesn't cause a separation. This leads to differing behavior between NumPy and PyTorch in this edge case.

This difference in behavior leads to a bug when using torch.compile:

```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```

Similar to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified.

Notice that there are still some other bugs in PyTorch's advanced indexing that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). But those need to be fixed at the ATen operator level. Examples:
- #71673
- #107699
- #158125

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297
Approved by: https://github.com/soumith
489 lines
17 KiB
Python
489 lines
17 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import numpy
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
run_tests,
|
|
TEST_WITH_TORCHDYNAMO,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
if TEST_WITH_TORCHDYNAMO:
|
|
import numpy as np
|
|
from numpy.testing import assert_array_equal
|
|
else:
|
|
import torch._numpy as np
|
|
from torch._numpy.testing import assert_array_equal
|
|
|
|
|
|
@instantiate_parametrized_tests
|
|
class TestAdvancedIndexing(TestCase):
|
|
"""Test advanced indexing for NumPy compatibility (separated advanced indices)."""
|
|
|
|
def _test_cases(self, test_cases, description=""):
|
|
"""Unified test runner for multiple test cases."""
|
|
for case in test_cases:
|
|
shape, index = case["shape"], case["index"]
|
|
name = case.get("name", f"{description}: {index}")
|
|
|
|
with self.subTest(name=name):
|
|
# Create test arrays
|
|
torch_arr = np.arange(numpy.prod(shape)).reshape(shape)
|
|
numpy_arr = numpy.arange(numpy.prod(shape)).reshape(shape)
|
|
|
|
# Test getitem
|
|
tr, nr = torch_arr[index], numpy_arr[index]
|
|
self.assertEqual(tr.shape, nr.shape, f"{name} getitem shape mismatch")
|
|
assert_array_equal(
|
|
tr.tensor.numpy() if hasattr(tr, "tensor") else tr, nr
|
|
)
|
|
|
|
# Test setitem
|
|
tc, nc = torch_arr.copy(), numpy_arr.copy()
|
|
value = np.ones(tr.shape) * 999
|
|
val_np = value.tensor.numpy() if hasattr(value, "tensor") else value
|
|
tc[index], nc[index] = value, val_np
|
|
assert_array_equal(
|
|
tc.tensor.numpy() if hasattr(tc, "tensor") else tc,
|
|
nc,
|
|
f"{name} setitem mismatch",
|
|
)
|
|
|
|
def _generate_standard_indices(self, shape):
|
|
"""Generate standard index patterns for given shape."""
|
|
if len(shape) == 1:
|
|
return [([0, 2, 4],), ([1, 3, 5, 7],), ([0, 0, 1, 1],)]
|
|
elif len(shape) == 2:
|
|
return [
|
|
(slice(None), [0]),
|
|
([0], slice(None)),
|
|
(slice(1, 3), [0, 2]),
|
|
([0, 2], slice(1, 4)),
|
|
([0, 2], 1),
|
|
(1, [0, 2, 4]),
|
|
]
|
|
elif len(shape) == 3:
|
|
return [
|
|
([0, 2], slice(None), slice(None)),
|
|
(slice(None), [0, 2], slice(None)),
|
|
(slice(None), slice(None), [0, 2, 4]),
|
|
([0, 1], slice(None), 2),
|
|
([0, 2], 1, slice(None)),
|
|
(1, [1, 3], slice(None)),
|
|
(slice(None), [0, 2], 3),
|
|
([0, 1], slice(None), [0, 2]),
|
|
([1, 2], slice(1, 3), [1, 3]),
|
|
([0, 1], 1, [1, 2]),
|
|
]
|
|
elif len(shape) == 4:
|
|
return [
|
|
(slice(None), [0], 0, slice(None)),
|
|
(slice(None), [0, 1], slice(None), 0),
|
|
([0, 1], slice(None), [0, 2], slice(None)),
|
|
(slice(None), [0, 2], slice(None), [0, 3]),
|
|
([0], slice(None), slice(None), [1, 3]),
|
|
(slice(None), slice(None), [0, 2], [1, 4]),
|
|
]
|
|
return []
|
|
|
|
def _generate_broadcast_indices(self, shape):
|
|
"""Generate broadcast-style index patterns."""
|
|
if len(shape) >= 2:
|
|
# Only generate patterns that fit within the shape bounds
|
|
patterns = []
|
|
if shape[0] >= 2 and shape[1] >= 4:
|
|
patterns.append(([[0, 1], [0, 1]], [[0, 1], [2, 3]])) # 2D broadcasting
|
|
if shape[0] >= 3 and shape[1] >= 2:
|
|
patterns.append(([0, 1, 2], [0, 1, 0])) # Element selection
|
|
return patterns
|
|
return []
|
|
|
|
def test_comprehensive_indexing(self):
|
|
"""Test comprehensive indexing patterns across multiple dimensions."""
|
|
test_shapes = [(10,), (4, 6), (3, 4, 5), (2, 3, 4, 5), (5, 6, 7, 8, 9)]
|
|
|
|
all_cases = []
|
|
|
|
for shape in test_shapes:
|
|
# Standard patterns
|
|
for idx in self._generate_standard_indices(shape):
|
|
all_cases.append({"shape": shape, "index": idx})
|
|
|
|
# Broadcast patterns (for 2D+)
|
|
if len(shape) >= 2:
|
|
for idx in self._generate_broadcast_indices(shape):
|
|
all_cases.append({"shape": shape, "index": idx})
|
|
|
|
# Separated indices patterns (for 3D+)
|
|
if len(shape) >= 3:
|
|
all_cases.extend(
|
|
[
|
|
{
|
|
"shape": shape,
|
|
"index": ([0, 1], slice(None), [1, 2]),
|
|
"name": "Separated indices",
|
|
},
|
|
{
|
|
"shape": shape,
|
|
"index": ([0, 1], [1, 2]),
|
|
"name": "Adjacent indices",
|
|
},
|
|
]
|
|
)
|
|
|
|
# Edge cases with negative indices
|
|
if len(shape) >= 3:
|
|
all_cases.extend(
|
|
[
|
|
{"shape": shape, "index": ([0], slice(None), [1])},
|
|
{"shape": shape, "index": ([-1], slice(None), [-1])},
|
|
{"shape": shape, "index": ([0, -1], slice(None), [1, -1])},
|
|
]
|
|
)
|
|
|
|
self._test_cases(all_cases, "Comprehensive indexing")
|
|
|
|
def test_advanced_separation_patterns(self):
|
|
"""Test advanced separation patterns and edge cases."""
|
|
cases = [
|
|
# Complex multi-dimensional separations
|
|
{
|
|
"shape": (3, 4, 5, 6, 7, 8),
|
|
"index": (
|
|
[1, 2],
|
|
slice(None),
|
|
[2, 4],
|
|
slice(None),
|
|
[5, 2],
|
|
slice(None),
|
|
),
|
|
"name": "6D multiple separations",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5, 6, 7, 8, 9, 10),
|
|
"index": (
|
|
[1, 2],
|
|
slice(None),
|
|
slice(None),
|
|
[2, 4],
|
|
slice(None),
|
|
slice(None),
|
|
[5, 2],
|
|
slice(None),
|
|
),
|
|
"name": "8D multiple separations",
|
|
},
|
|
# Current logic issues
|
|
{
|
|
"shape": (2, 3, 4, 5),
|
|
"index": (0, slice(None), [1], slice(None)),
|
|
"name": "Logic test: [0, :, [1], :]",
|
|
},
|
|
{
|
|
"shape": (2, 3, 4, 5),
|
|
"index": (0, [1], slice(None), 0),
|
|
"name": "Logic test: [0, [1], :, 0]",
|
|
},
|
|
# High-dimensional edge cases
|
|
{
|
|
"shape": (5, 6, 7, 8, 9),
|
|
"index": (slice(None), slice(None), [1, 2], slice(None), slice(None)),
|
|
"name": "Single advanced index at middle position",
|
|
},
|
|
{
|
|
"shape": (5, 6, 7, 8, 9),
|
|
"index": ([1, 2], slice(None), slice(None), slice(None), slice(None)),
|
|
"name": "Single advanced index at start",
|
|
},
|
|
{
|
|
"shape": (5, 6, 7, 8, 9),
|
|
"index": (slice(None), slice(None), slice(None), slice(None), [1, 2]),
|
|
"name": "Single advanced index at end",
|
|
},
|
|
{
|
|
"shape": (5, 6, 7, 8, 9),
|
|
"index": (2, [1, 2], slice(None), slice(None), slice(None)),
|
|
"name": "Advanced index with scalar before",
|
|
},
|
|
{
|
|
"shape": (5, 6, 7, 8, 9),
|
|
"index": (2, 3, [1, 2], 4, slice(None)),
|
|
"name": "Advanced index with scalars around",
|
|
},
|
|
]
|
|
self._test_cases(cases, "Advanced separation patterns")
|
|
|
|
def test_broadcast_and_numpy_compatibility(self):
|
|
"""Test broadcasting patterns and NumPy documentation examples."""
|
|
base_shape = (10, 20, 30, 40, 50)
|
|
|
|
# Standard broadcast indices
|
|
indices_2d = {
|
|
"ind_1": [[1, 2], [3, 4], [5, 6]],
|
|
"ind_2": [[7, 8], [9, 10], [11, 12]],
|
|
}
|
|
indices_3d = {"ind_1": [[[1, 2]], [[3, 4]]], "ind_2": [[[5, 6]], [[7, 8]]]}
|
|
indices_4d = {
|
|
"ind_1": [[[[1, 2]], [[3, 4]]], [[[5, 6]], [[7, 8]]]],
|
|
"ind_2": [[[[9, 10]], [[11, 12]]], [[[13, 14]], [[15, 16]]]],
|
|
}
|
|
|
|
cases = [
|
|
# 2D broadcast patterns
|
|
{
|
|
"shape": base_shape,
|
|
"index": (slice(None), indices_2d["ind_1"], indices_2d["ind_2"]),
|
|
"name": "Adjacent 2D broadcast indices",
|
|
},
|
|
{
|
|
"shape": base_shape,
|
|
"index": (
|
|
slice(None),
|
|
indices_2d["ind_1"],
|
|
slice(None),
|
|
indices_2d["ind_2"],
|
|
),
|
|
"name": "Separated 2D broadcast indices",
|
|
},
|
|
# 3D broadcast patterns
|
|
{
|
|
"shape": base_shape,
|
|
"index": (
|
|
slice(None),
|
|
indices_3d["ind_1"],
|
|
slice(None),
|
|
indices_3d["ind_2"],
|
|
),
|
|
"name": "3D broadcast creating 8D output",
|
|
},
|
|
# 4D broadcast patterns
|
|
{
|
|
"shape": base_shape,
|
|
"index": (
|
|
slice(None),
|
|
indices_4d["ind_1"],
|
|
slice(None),
|
|
indices_4d["ind_2"],
|
|
),
|
|
"name": "4D broadcast creating 9D output",
|
|
},
|
|
# NumPy documentation examples
|
|
{
|
|
"shape": (5, 7),
|
|
"index": ([0, 2, 4], slice(1, 3)),
|
|
"name": "NumPy docs: mixed advanced and slice",
|
|
},
|
|
{
|
|
"shape": (4, 3),
|
|
"index": ([0, 1, 2], [0, 1, 0]),
|
|
"name": "NumPy docs: element selection",
|
|
},
|
|
{
|
|
"shape": (4, 3),
|
|
"index": ([[0], [3]], [0, 2]),
|
|
"name": "NumPy docs: broadcasting case",
|
|
},
|
|
]
|
|
self._test_cases(cases, "Broadcast and NumPy compatibility")
|
|
|
|
def test_special_index_types(self):
|
|
"""Test special index types including tensors, ellipsis, and newaxis."""
|
|
cases = [
|
|
# Ellipsis handling
|
|
{
|
|
"shape": (3, 4, 5, 6),
|
|
"index": (..., [1, 2], slice(None)),
|
|
"name": "Ellipsis with advanced index",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5, 6),
|
|
"index": ([0, 1], ..., 2),
|
|
"name": "Advanced index with ellipsis",
|
|
},
|
|
# Newaxis handling
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (None, [1, 2], slice(None)),
|
|
"name": "Newaxis with advanced index",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": ([1, 2], None, slice(None)),
|
|
"name": "Advanced index with newaxis",
|
|
},
|
|
# Complex mixing
|
|
{
|
|
"shape": (3, 4, 5, 6),
|
|
"index": (0, [1, 2], slice(None), 3),
|
|
"name": "Integer-advanced-slice-integer",
|
|
},
|
|
{
|
|
"shape": (4, 5, 6),
|
|
"index": ([0, 1], 2, [3, 4]),
|
|
"name": "Advanced-integer-advanced",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5, 6, 7),
|
|
"index": ([0, 1], slice(None), [2, 3], slice(1, 3), [4, 5]),
|
|
"name": "Multiple separated advanced indices",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": ([0, 1], [2, 3], [1, 4]),
|
|
"name": "All advanced indices",
|
|
},
|
|
# Boolean indexing cases (issue #158134)
|
|
# Tests for boolean values that trigger advanced indexing
|
|
{
|
|
"shape": (3, 4),
|
|
"index": (True,),
|
|
"name": "Boolean True indexing",
|
|
},
|
|
{
|
|
"shape": (3, 4),
|
|
"index": (False,),
|
|
"name": "Boolean False indexing",
|
|
},
|
|
{
|
|
"shape": (2, 3, 4),
|
|
"index": (True, slice(None)),
|
|
"name": "Boolean True with slice",
|
|
},
|
|
# Tuple indexing cases
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": ((0, 1), slice(None)),
|
|
"name": "Tuple indexing with slice",
|
|
},
|
|
{
|
|
"shape": (4, 5),
|
|
"index": ((0, 1, 2),),
|
|
"name": "Tuple indexing only",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": ((0, 1), (2, 3)),
|
|
"name": "Multiple tuple indexing",
|
|
},
|
|
]
|
|
|
|
# Handle torch tensor cases separately
|
|
torch_cases = [
|
|
{
|
|
"shape": (2, 3, 4),
|
|
"index": (torch.tensor(1), [1, 2]),
|
|
"name": "Zero-dim tensor with list",
|
|
},
|
|
{
|
|
"shape": (2, 3, 4),
|
|
"index": (torch.tensor(1), slice(None), [1, 2]),
|
|
"name": "Zero-dim tensor with slice and list",
|
|
},
|
|
{
|
|
"shape": (2, 3, 4),
|
|
"index": (torch.tensor(0), torch.tensor(1), [1, 2]),
|
|
"name": "Multiple zero-dim tensors",
|
|
},
|
|
{
|
|
"shape": (3, 4),
|
|
"index": (torch.tensor(True),),
|
|
"name": "Zero-dim bool tensor",
|
|
},
|
|
{
|
|
"shape": (2, 3, 4),
|
|
"index": (torch.tensor(False),),
|
|
"name": "Zero-dim False tensor",
|
|
},
|
|
]
|
|
|
|
# Convert torch tensor indices for numpy comparison
|
|
numpy_torch_cases = []
|
|
for case in torch_cases:
|
|
numpy_index = tuple(
|
|
idx.item() if isinstance(idx, torch.Tensor) and idx.ndim == 0 else idx
|
|
for idx in case["index"]
|
|
)
|
|
numpy_torch_cases.append(
|
|
{"shape": case["shape"], "index": numpy_index, "name": case["name"]}
|
|
)
|
|
|
|
self._test_cases(cases + numpy_torch_cases, "Special index types")
|
|
|
|
def test_ellipsis(self):
|
|
"""Tests containing ellipsis."""
|
|
cases = [
|
|
# Ellipsis + Basic indexing
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (slice(None), 0, ..., slice(None)),
|
|
"name": "empty ellipsis without advanced indexing",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (slice(None), ..., 0),
|
|
"name": "non-empty ellipsis without advanced indexing",
|
|
},
|
|
# Ellipsis + Advanced indexing without separation
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (slice(None), ..., slice(None), (0, 1)),
|
|
"name": "empty ellipsis without separation",
|
|
},
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (slice(None), ..., (0, 1)),
|
|
"name": "non-empty ellipsis without separation",
|
|
},
|
|
# Ellipsis + Advanced indexing with separation
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (slice(None), (0, 1), ..., (0, 1)),
|
|
"name": "empty ellipsis separation",
|
|
},
|
|
{
|
|
"shape": (1, 3, 4, 5),
|
|
"index": (slice(None), (0, 1), ..., (0, 1)),
|
|
"name": "non-empty ellipsis separation",
|
|
},
|
|
{
|
|
"shape": (4, 3, 5),
|
|
"index": (slice(None), ((0,), (1,)), ..., (0, 1)),
|
|
"name": "empty ellipsis separation with 2-depth int sequence",
|
|
},
|
|
{
|
|
"shape": (4, 3, 5, 6),
|
|
"index": (slice(None), ((0,), (1,)), ..., (0, 1), slice(None)),
|
|
"name": "empty ellipsis separation with 2-depth int sequence and end slice",
|
|
},
|
|
{
|
|
"shape": (4, 3, 5, 6),
|
|
"index": (slice(None), ((0,), (1,)), ..., (0, 1), (((0, 1), (1, 2)),)),
|
|
"name": "empty ellipsis separation with 2 and 3-depth int sequence",
|
|
},
|
|
# Ellipsis + Boolean masks in advanced indexing with separation
|
|
{
|
|
"shape": (3, 4, 5),
|
|
"index": (slice(None), True, True, True, ..., 0, 0),
|
|
"name": "empty ellipsis separation with 0-dim boolean masks",
|
|
},
|
|
{
|
|
"shape": (4, 3, 5),
|
|
"index": (slice(None), (True, True, False), ..., (0, 1)),
|
|
"name": "empty ellipsis separation with 1-dim boolean masks",
|
|
},
|
|
# TODO(manuelcandales) Fix issue #71673 and enable this case
|
|
# {
|
|
# "shape": (1, 2, 2, 4, 5),
|
|
# "index": (slice(None), ((True, False), (True, True)), (0, 1, 2), ..., (0,)),
|
|
# "name": "empty ellipsis separation with 2-dim boolean masks",
|
|
# },
|
|
]
|
|
self._test_cases(cases, "Ellipsis and advanced indexing separation")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to PyTorch's shared test runner.
    run_tests()
|