pytorch/test/test_indexing.py
Natalia Gimelshein 34e3930401 fix numpy compatibility for 2d small list indices (#154806)
Will fix #119548 and linked issues once we switch from warning to the new behavior;
for now, given how widely this syntax is used in our test suite, we suspect a silent change would be disruptive.
We will change the behavior after the 2.8 branch is cut.
NumPy's behavior changed at least as of NumPy 1.24 (more than two years ago).
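
For illustration, a hedged sketch of the affected pattern (the exact semantics
are tracked in the linked issue):

    x = torch.rand(3, 3)
    y = x[[[0, 1], [1, 0]]]
    # Legacy PyTorch behavior: the outer list is unpacked like a tuple,
    # i.e. x[[0, 1], [1, 0]], selecting elements (0, 1) and (1, 0) -> shape (2,).
    # NumPy >= 1.24 (and the future PyTorch behavior): the nested list is a
    # single 2-D integer index array -> shape (2, 2, 3).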

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154806
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/albanD
2025-06-04 01:58:52 +00:00

# Owner(s): ["module: tests"]
import operator
import random
import unittest
import warnings
from functools import reduce
import numpy as np
import torch
from torch import tensor
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCPU,
dtypesIfCUDA,
instantiate_device_type_tests,
onlyCUDA,
onlyNativeDeviceTypes,
skipXLA,
)
from torch.testing._internal.common_utils import (
DeterministicGuard,
run_tests,
serialTest,
skipIfTorchDynamo,
TEST_CUDA,
TestCase,
xfailIfTorchDynamo,
)


class TestIndexing(TestCase):
def test_index(self, device):
def consec(size, start=1):
sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
sequence.add_(start - 1)
return sequence.view(*size)
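# consec((3, 3, 3)) counts 1..27 in row-major order, so every indexing
# result below has a closed-form expected value.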
reference = consec((3, 3, 3)).to(device)
# empty tensor indexing
self.assertEqual(
reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)
)
self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0)
self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0)
self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0)
self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0)
self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0)
self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0)
self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0)
# indexing with Ellipsis
self.assertEqual(
reference[..., 2],
torch.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]),
atol=0,
rtol=0,
)
self.assertEqual(
reference[0, ..., 2], torch.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0
)
self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0)
self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0)
self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0)
self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0)
self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0)
self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0)
self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0)
self.assertEqual(reference[...], reference, atol=0, rtol=0)
reference_5d = consec((3, 3, 3, 3, 3)).to(device)
self.assertEqual(
reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0
)
self.assertEqual(
reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0
)
self.assertEqual(
reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0
)
self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0)
# LongTensor indexing
reference = consec((5, 5, 5)).to(device)
idx = torch.LongTensor([2, 4]).to(device)
self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
# TODO: enable once indexing is implemented like in NumPy
# self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
# self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])
# None indexing
self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
self.assertEqual(
reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)
)
self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
self.assertEqual(
reference[None, 2, None, None],
reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0),
)
self.assertEqual(
reference[None, 2:5, None, None],
reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2),
)
# indexing 0-length slice
self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])
# indexing with step
reference = consec((10, 10, 10)).to(device)
self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
self.assertEqual(
reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)
)
self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
self.assertEqual(
reference[2:4, 1:5:2],
torch.stack([reference[2:4, 1], reference[2:4, 3]], 1),
)
self.assertEqual(
reference[3, 1:6:2],
torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0),
)
self.assertEqual(
reference[None, 2, 1:9:4],
torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0),
)
self.assertEqual(
reference[:, 2, 1:6:2],
torch.stack(
[reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1
),
)
lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
tensor = torch.DoubleTensor(lst).to(device)
for _i in range(100):
idx1_start = random.randrange(10)
idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
idx1_step = random.randrange(1, 8)
idx1 = slice(idx1_start, idx1_end, idx1_step)
if random.randrange(2) == 0:
idx2_start = random.randrange(10)
idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
idx2_step = random.randrange(1, 8)
idx2 = slice(idx2_start, idx2_end, idx2_step)
lst_indexed = [l[idx2] for l in lst[idx1]]
tensor_indexed = tensor[idx1, idx2]
else:
lst_indexed = lst[idx1]
tensor_indexed = tensor[idx1]
self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)
self.assertRaises(ValueError, lambda: reference[1:9:0])
self.assertRaises(ValueError, lambda: reference[1:9:-1])
self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
self.assertRaises(IndexError, lambda: reference[0.0])
self.assertRaises(TypeError, lambda: reference[0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
def delitem():
del reference[0]
self.assertRaises(TypeError, delitem)
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.double)
def test_advancedindex(self, device, dtype):
# Tests for Integer Array Indexing, Part I - Purely integer array
# indexing
def consec(size, start=1):
# Creates the sequence in float since CPU half doesn't support the
# needed operations. Converts to dtype before returning.
numel = reduce(operator.mul, size, 1)
sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
sequence.add_(start - 1)
return sequence.view(*size).to(dtype=dtype)
# pick a random valid indexer type
def ri(indices):
choice = random.randint(0, 2)
if choice == 0:
return torch.LongTensor(indices).to(device)
elif choice == 1:
return list(indices)
else:
return tuple(indices)
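# ri wraps the same indices as a LongTensor, a list, or a tuple at random;
# the three containers are interchangeable for advanced indexing, so each
# run exercises one of the three paths.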
def validate_indexing(x):
self.assertEqual(x[[0]], consec((1,)))
self.assertEqual(x[ri([0]),], consec((1,)))
self.assertEqual(x[ri([3]),], consec((1,), 4))
self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
self.assertEqual(x[ri([2, 3, 4]),], consec((3,), 3))
self.assertEqual(
x[ri([0, 2, 4]),], torch.tensor([1, 3, 5], dtype=dtype, device=device)
)
def validate_setting(x):
x[[0]] = -2
self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device))
x[[0]] = -1
self.assertEqual(
x[ri([0]),], torch.tensor([-1], dtype=dtype, device=device)
)
x[[2, 3, 4]] = 4
self.assertEqual(
x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device)
)
x[ri([2, 3, 4]),] = 3
self.assertEqual(
x[ri([2, 3, 4]),], torch.tensor([3, 3, 3], dtype=dtype, device=device)
)
x[ri([0, 2, 4]),] = torch.tensor([5, 4, 3], dtype=dtype, device=device)
self.assertEqual(
x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device)
)
# Only validate indexing and setting for half (float16)
if dtype == torch.half:
reference = consec((10,))
validate_indexing(reference)
validate_setting(reference)
return
# Case 1: Purely Integer Array Indexing
reference = consec((10,))
validate_indexing(reference)
# setting values
validate_setting(reference)
# Tensor with stride != 1
# strided is [1, 3, 5, 7]
reference = consec((10,))
strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(
reference.untyped_storage(),
storage_offset=0,
size=torch.Size([4]),
stride=[2],
)
self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
self.assertEqual(
strided[ri([0]),], torch.tensor([1], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([3]),], torch.tensor([7], dtype=dtype, device=device)
)
self.assertEqual(
strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([1, 2]),], torch.tensor([3, 5], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([[2, 1], [0, 3]]),],
torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device),
)
# strided is [5, 9] (storage offset 4, stride 4)
strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(
reference.untyped_storage(),
storage_offset=4,
size=torch.Size([2]),
stride=[4],
)
self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
self.assertEqual(
strided[ri([0]),], torch.tensor([5], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([1]),], torch.tensor([9], dtype=dtype, device=device)
)
self.assertEqual(
strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([0, 1]),], torch.tensor([5, 9], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([[0, 1], [1, 0]]),],
torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device),
)
# reference is 1 2
# 3 4
# 5 6
reference = consec((3, 2))
self.assertEqual(
reference[ri([0, 1, 2]), ri([0])],
torch.tensor([1, 3, 5], dtype=dtype, device=device),
)
self.assertEqual(
reference[ri([0, 1, 2]), ri([1])],
torch.tensor([2, 4, 6], dtype=dtype, device=device),
)
self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
self.assertEqual(
reference[(ri([0, 0]), ri([0, 1]))],
torch.tensor([1, 2], dtype=dtype, device=device),
)
self.assertEqual(
reference[(ri([0, 1, 1, 0, 2]), ri([1]))],
torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),
)
self.assertEqual(
reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
torch.tensor([1, 2, 3, 3], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 2]])
columns = ([0],)
self.assertEqual(
reference[rows, columns],
torch.tensor([[1, 1], [3, 5]], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 2]])
columns = ri([1, 0])
self.assertEqual(
reference[rows, columns],
torch.tensor([[2, 1], [4, 5]], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 2]])
columns = ri([[0, 1], [1, 0]])
self.assertEqual(
reference[rows, columns],
torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device),
)
# setting values
reference[ri([0]), ri([1])] = -1
self.assertEqual(
reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
)
reference[ri([0, 1, 2]), ri([0])] = torch.tensor(
[-1, 2, -4], dtype=dtype, device=device
)
self.assertEqual(
reference[ri([0, 1, 2]), ri([0])],
torch.tensor([-1, 2, -4], dtype=dtype, device=device),
)
reference[rows, columns] = torch.tensor(
[[4, 6], [2, 3]], dtype=dtype, device=device
)
self.assertEqual(
reference[rows, columns],
torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
)
# Verify still works with Transposed (i.e. non-contiguous) Tensors
reference = torch.tensor(
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device
).t_()
# Transposed: [[0, 4, 8],
# [1, 5, 9],
# [2, 6, 10],
# [3, 7, 11]]
self.assertEqual(
reference[ri([0, 1, 2]), ri([0])],
torch.tensor([0, 1, 2], dtype=dtype, device=device),
)
self.assertEqual(
reference[ri([0, 1, 2]), ri([1])],
torch.tensor([4, 5, 6], dtype=dtype, device=device),
)
self.assertEqual(
reference[ri([0]), ri([0])], torch.tensor([0], dtype=dtype, device=device)
)
self.assertEqual(
reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device)
)
self.assertEqual(
reference[(ri([0, 0]), ri([0, 1]))],
torch.tensor([0, 4], dtype=dtype, device=device),
)
self.assertEqual(
reference[(ri([0, 1, 1, 0, 3]), ri([1]))],
torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),
)
self.assertEqual(
reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
torch.tensor([0, 4, 1, 1], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 2]])
columns = ([0],)
self.assertEqual(
reference[rows, columns],
torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 2]])
columns = ri([1, 0])
self.assertEqual(
reference[rows, columns],
torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 3]])
columns = ri([[0, 1], [1, 2]])
self.assertEqual(
reference[rows, columns],
torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device),
)
# setting values
reference[ri([0]), ri([1])] = -1
self.assertEqual(
reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
)
reference[ri([0, 1, 2]), ri([0])] = torch.tensor(
[-1, 2, -4], dtype=dtype, device=device
)
self.assertEqual(
reference[ri([0, 1, 2]), ri([0])],
torch.tensor([-1, 2, -4], dtype=dtype, device=device),
)
reference[rows, columns] = torch.tensor(
[[4, 6], [2, 3]], dtype=dtype, device=device
)
self.assertEqual(
reference[rows, columns],
torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
)
# stride != 1
# strided is [[1 3 5 7],
# [9 11 13 15]]
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(
reference.untyped_storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]
)
self.assertEqual(
strided[ri([0, 1]), ri([0])],
torch.tensor([1, 9], dtype=dtype, device=device),
)
self.assertEqual(
strided[ri([0, 1]), ri([1])],
torch.tensor([3, 11], dtype=dtype, device=device),
)
self.assertEqual(
strided[ri([0]), ri([0])], torch.tensor([1], dtype=dtype, device=device)
)
self.assertEqual(
strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device)
)
self.assertEqual(
strided[(ri([0, 0]), ri([0, 3]))],
torch.tensor([1, 7], dtype=dtype, device=device),
)
self.assertEqual(
strided[(ri([1]), ri([0, 1, 1, 0, 3]))],
torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device),
)
self.assertEqual(
strided[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
torch.tensor([1, 3, 9, 9], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 1]])
columns = ([0],)
self.assertEqual(
strided[rows, columns],
torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device),
)
rows = ri([[0, 1], [1, 0]])
columns = ri([1, 2])
self.assertEqual(
strided[rows, columns],
torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device),
)
rows = ri([[0, 0], [1, 1]])
columns = ri([[0, 1], [1, 2]])
self.assertEqual(
strided[rows, columns],
torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device),
)
# setting values
# strided is [[10, 11],
# [17, 18]]
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(
reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
)
self.assertEqual(
strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device)
)
strided[ri([0]), ri([1])] = -1
self.assertEqual(
strided[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
)
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(
reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
)
self.assertEqual(
strided[ri([0, 1]), ri([1, 0])],
torch.tensor([11, 17], dtype=dtype, device=device),
)
strided[ri([0, 1]), ri([1, 0])] = torch.tensor(
[-1, 2], dtype=dtype, device=device
)
self.assertEqual(
strided[ri([0, 1]), ri([1, 0])],
torch.tensor([-1, 2], dtype=dtype, device=device),
)
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(
reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
)
rows = ri([[0], [1]])
columns = ri([[0, 1], [0, 1]])
self.assertEqual(
strided[rows, columns],
torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device),
)
strided[rows, columns] = torch.tensor(
[[4, 6], [2, 3]], dtype=dtype, device=device
)
self.assertEqual(
strided[rows, columns],
torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
)
# Tests using less than the number of dims, and ellipsis
# reference is 1 2
# 3 4
# 5 6
reference = consec((3, 2))
self.assertEqual(
reference[ri([0, 2]),],
torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device),
)
self.assertEqual(
reference[ri([1]), ...], torch.tensor([[3, 4]], dtype=dtype, device=device)
)
self.assertEqual(
reference[..., ri([1])],
torch.tensor([[2], [4], [6]], dtype=dtype, device=device),
)
# verify too many indices fails
with self.assertRaises(IndexError):
reference[ri([1]), ri([0, 2]), ri([3])]
# test invalid index fails
reference = torch.empty(10, dtype=dtype, device=device)
# can't test CUDA because out-of-bounds indexing triggers a device-side assert
if not reference.is_cuda:
for err_idx in (10, -11):
with self.assertRaisesRegex(IndexError, r"out of"):
reference[err_idx]
with self.assertRaisesRegex(IndexError, r"out of"):
reference[torch.LongTensor([err_idx]).to(device)]
with self.assertRaisesRegex(IndexError, r"out of"):
reference[[err_idx]]
def tensor_indices_to_np(tensor, indices):
# convert the Torch Tensor to a numpy array
tensor = tensor.to(device="cpu")
npt = tensor.numpy()
# convert indices
idxs = tuple(
i.tolist() if isinstance(i, torch.LongTensor) else i for i in indices
)
return npt, idxs
def get_numpy(tensor, indices):
npt, idxs = tensor_indices_to_np(tensor, indices)
# index and return as a Torch Tensor
return torch.tensor(npt[idxs], dtype=dtype, device=device)
def set_numpy(tensor, indices, value):
if not isinstance(value, int):
if self.device_type != "cpu":
value = value.cpu()
value = value.numpy()
npt, idxs = tensor_indices_to_np(tensor, indices)
npt[idxs] = value
return npt
def assert_get_eq(tensor, indexer):
self.assertEqual(tensor[indexer], get_numpy(tensor, indexer))
def assert_set_eq(tensor, indexer, val):
pyt = tensor.clone()
numt = tensor.clone()
pyt[indexer] = val
numt = torch.tensor(
set_numpy(numt, indexer, val), dtype=dtype, device=device
)
self.assertEqual(pyt, numt)
def assert_backward_eq(tensor, indexer):
cpu = tensor.float().detach().clone().requires_grad_(True)
outcpu = cpu[indexer]
gOcpu = torch.rand_like(outcpu)
outcpu.backward(gOcpu)
dev = cpu.to(device).detach().requires_grad_(True)
outdev = dev[indexer]
outdev.backward(gOcpu.to(device))
self.assertEqual(cpu.grad, dev.grad)
def get_set_tensor(indexed, indexer):
set_size = indexed[indexer].size()
set_count = indexed[indexer].numel()
set_tensor = torch.randperm(set_count).view(set_size).double().to(device)
return set_tensor
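# The helpers above form a small differential-testing harness against NumPy:
# assert_get_eq checks reads, assert_set_eq checks writes (via a cloned
# pair), assert_backward_eq checks gradients between CPU and device, and
# get_set_tensor builds a distinct-valued RHS of the indexed shape.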
# Tensor is 0 1 2 3 4
# 5 6 7 8 9
# 10 11 12 13 14
# 15 16 17 18 19
reference = torch.arange(0.0, 20, dtype=dtype, device=device).view(4, 5)
indices_to_test = [
# grab the second, fourth columns
(slice(None), [1, 3]),
# first, third rows,
([0, 2], slice(None)),
# weird shape
(slice(None), [[0, 1], [2, 3]]),
# negatives
([-1], [0]),
([0, 2], [-1]),
(slice(None), [-1]),
]
# only test dupes on gets
get_indices_to_test = indices_to_test + [(slice(None), [0, 1, 1, 2, 2])]
for indexer in get_indices_to_test:
assert_get_eq(reference, indexer)
if self.device_type != "cpu":
assert_backward_eq(reference, indexer)
for indexer in indices_to_test:
assert_set_eq(reference, indexer, 44)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5)
indices_to_test = [
(slice(None), slice(None), (0, 3, 4)),
(slice(None), (2, 4, 5, 7), slice(None)),
((2, 3), slice(None), slice(None)),
(slice(None), (0, 2, 3), (1, 3, 4)),
(slice(None), (0,), (1, 2, 4)),
(slice(None), (0, 1, 3), (4,)),
(slice(None), ((0, 1), (1, 0)), ((2, 3),)),
(slice(None), ((0, 1), (2, 3)), ((0,),)),
(slice(None), ((5, 6),), ((0, 3), (4, 4))),
((0, 2, 3), (1, 3, 4), slice(None)),
((0,), (1, 2, 4), slice(None)),
((0, 1, 3), (4,), slice(None)),
(((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)),
(((0, 1), (1, 0)), ((2, 3),), slice(None)),
(((0, 1), (2, 3)), ((0,),), slice(None)),
(((2, 1),), ((0, 3), (4, 4)), slice(None)),
(((2,),), ((0, 3), (4, 1)), slice(None)),
# non-contiguous indexing subspace
((0, 2, 3), slice(None), (1, 3, 4)),
# [...]
# less dim, ellipsis
((0, 2),),
((0, 2), slice(None)),
((0, 2), Ellipsis),
((0, 2), slice(None), Ellipsis),
((0, 2), Ellipsis, slice(None)),
((0, 2), (1, 3)),
((0, 2), (1, 3), Ellipsis),
(Ellipsis, (1, 3), (2, 3)),
(Ellipsis, (2, 3, 4)),
(Ellipsis, slice(None), (2, 3, 4)),
(slice(None), Ellipsis, (2, 3, 4)),
# ellipsis counts for nothing
(Ellipsis, slice(None), slice(None), (0, 3, 4)),
(slice(None), Ellipsis, slice(None), (0, 3, 4)),
(slice(None), slice(None), Ellipsis, (0, 3, 4)),
(slice(None), slice(None), (0, 3, 4), Ellipsis),
(Ellipsis, ((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)),
(((0, 1), (1, 0)), ((2, 1), (3, 5)), Ellipsis, slice(None)),
(((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None), Ellipsis),
]
for indexer in indices_to_test:
assert_get_eq(reference, indexer)
assert_set_eq(reference, indexer, 212)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
if torch.cuda.is_available():
assert_backward_eq(reference, indexer)
reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)
indices_to_test = [
(slice(None), slice(None), slice(None), (0, 3, 4)),
(slice(None), slice(None), (2, 4, 5, 7), slice(None)),
(slice(None), (2, 3), slice(None), slice(None)),
((1, 2), slice(None), slice(None), slice(None)),
(slice(None), slice(None), (0, 2, 3), (1, 3, 4)),
(slice(None), slice(None), (0,), (1, 2, 4)),
(slice(None), slice(None), (0, 1, 3), (4,)),
(slice(None), slice(None), ((0, 1), (1, 0)), ((2, 3),)),
(slice(None), slice(None), ((0, 1), (2, 3)), ((0,),)),
(slice(None), slice(None), ((5, 6),), ((0, 3), (4, 4))),
(slice(None), (0, 2, 3), (1, 3, 4), slice(None)),
(slice(None), (0,), (1, 2, 4), slice(None)),
(slice(None), (0, 1, 3), (4,), slice(None)),
(slice(None), ((0, 1), (3, 4)), ((2, 3), (0, 1)), slice(None)),
(slice(None), ((0, 1), (3, 4)), ((2, 3),), slice(None)),
(slice(None), ((0, 1), (3, 2)), ((0,),), slice(None)),
(slice(None), ((2, 1),), ((0, 3), (6, 4)), slice(None)),
(slice(None), ((2,),), ((0, 3), (4, 2)), slice(None)),
((0, 1, 2), (1, 3, 4), slice(None), slice(None)),
((0,), (1, 2, 4), slice(None), slice(None)),
((0, 1, 2), (4,), slice(None), slice(None)),
(((0, 1), (0, 2)), ((2, 4), (1, 5)), slice(None), slice(None)),
(((0, 1), (1, 2)), ((2, 0),), slice(None), slice(None)),
(((2, 2),), ((0, 3), (4, 5)), slice(None), slice(None)),
(((2,),), ((0, 3), (4, 5)), slice(None), slice(None)),
(slice(None), (3, 4, 6), (0, 2, 3), (1, 3, 4)),
(slice(None), (2, 3, 4), (1, 3, 4), (4,)),
(slice(None), (0, 1, 3), (4,), (1, 3, 4)),
(slice(None), (6,), (0, 2, 3), (1, 3, 4)),
(slice(None), (2, 3, 5), (3,), (4,)),
(slice(None), (0,), (4,), (1, 3, 4)),
(slice(None), (6,), (0, 2, 3), (1,)),
(slice(None), ((0, 3), (3, 6)), ((0, 1), (1, 3)), ((5, 3), (1, 2))),
((2, 2, 1), (0, 2, 3), (1, 3, 4), slice(None)),
((2, 0, 1), (1, 2, 3), (4,), slice(None)),
((0, 1, 2), (4,), (1, 3, 4), slice(None)),
((0,), (0, 2, 3), (1, 3, 4), slice(None)),
((0, 2, 1), (3,), (4,), slice(None)),
((0,), (4,), (1, 3, 4), slice(None)),
((1,), (0, 2, 3), (1,), slice(None)),
(((1, 2), (1, 2)), ((0, 1), (2, 3)), ((2, 3), (3, 5)), slice(None)),
# less dim, ellipsis
(Ellipsis, (0, 3, 4)),
(Ellipsis, slice(None), (0, 3, 4)),
(Ellipsis, slice(None), slice(None), (0, 3, 4)),
(slice(None), Ellipsis, (0, 3, 4)),
(slice(None), slice(None), Ellipsis, (0, 3, 4)),
(slice(None), (0, 2, 3), (1, 3, 4)),
(slice(None), (0, 2, 3), (1, 3, 4), Ellipsis),
(Ellipsis, (0, 2, 3), (1, 3, 4), slice(None)),
((0,), (1, 2, 4)),
((0,), (1, 2, 4), slice(None)),
((0,), (1, 2, 4), Ellipsis),
((0,), (1, 2, 4), Ellipsis, slice(None)),
((1,),),
((0, 2, 1), (3,), (4,)),
((0, 2, 1), (3,), (4,), slice(None)),
((0, 2, 1), (3,), (4,), Ellipsis),
(Ellipsis, (0, 2, 1), (3,), (4,)),
]
for indexer in indices_to_test:
assert_get_eq(reference, indexer)
assert_set_eq(reference, indexer, 1333)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
indices_to_test += [
(slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]),
(slice(None), slice(None), [[2]], [[0, 3], [4, 4]]),
]
for indexer in indices_to_test:
assert_get_eq(reference, indexer)
assert_set_eq(reference, indexer, 1333)
if self.device_type != "cpu":
assert_backward_eq(reference, indexer)
def test_advancedindex_big(self, device):
reference = torch.arange(0, 123344, dtype=torch.int, device=device)
self.assertEqual(
reference[[0, 123, 44488, 68807, 123343],],
torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int),
)
def test_set_item_to_scalar_tensor(self, device):
m = random.randint(1, 10)
n = random.randint(1, 10)
z = torch.randn([m, n], device=device)
a = 1.0
w = torch.tensor(a, requires_grad=True, device=device)
z[:, 0] = w
z.sum().backward()
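# w is written into the m entries of column 0, so d(z.sum())/dw accumulates
# once per row: w.grad == m (and since a == 1.0, m * a == m).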
self.assertEqual(w.grad, m * a)
def test_single_int(self, device):
v = torch.randn(5, 7, 3, device=device)
self.assertEqual(v[4].shape, (7, 3))
def test_multiple_int(self, device):
v = torch.randn(5, 7, 3, device=device)
self.assertEqual(v[4].shape, (7, 3))
self.assertEqual(v[4, :, 1].shape, (7,))
def test_none(self, device):
v = torch.randn(5, 7, 3, device=device)
self.assertEqual(v[None].shape, (1, 5, 7, 3))
self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
def test_step(self, device):
v = torch.arange(10, device=device)
self.assertEqual(v[::1], v)
self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
self.assertEqual(v[::11].tolist(), [0])
self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
def test_step_assignment(self, device):
v = torch.zeros(4, 4, device=device)
v[0, 1::2] = torch.tensor([3.0, 4.0], device=device)
self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
self.assertEqual(v[1:].sum(), 0)
def test_bool_indices(self, device):
v = torch.randn(5, 7, 3, device=device)
boolIndices = torch.tensor(
[True, False, True, True, False], dtype=torch.bool, device=device
)
self.assertEqual(v[boolIndices].shape, (3, 7, 3))
self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
boolIndices = torch.tensor(
[True, False, False], dtype=torch.bool, device=device
)
uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
with warnings.catch_warnings(record=True) as w:
v1 = v[boolIndices]
v2 = v[uint8Indices]
self.assertEqual(v1.shape, v2.shape)
self.assertEqual(v1, v2)
self.assertEqual(
v[boolIndices], tensor([True], dtype=torch.bool, device=device)
)
self.assertEqual(len(w), 1)
def test_list_indices(self, device):
N = 1000
t = torch.randn(N, device=device)
# Set window size
W = 10
# Generate a list of lists, containing overlapping window indices
indices = [range(i, i + W) for i in range(0, N - W)]
for i in [len(indices), 100, 32]:
windowed_data = t[indices[:i]]
self.assertEqual(windowed_data.shape, (i, W))
with self.assertRaisesRegex(IndexError, "too many indices"):
windowed_data = t[indices[:31]]
def test_bool_indices_accumulate(self, device):
mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
y = torch.ones(size=(10, 10), device=device)
y.index_put_((mask,), y[mask], accumulate=True)
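# The all-False mask selects zero rows, so the accumulating index_put_ is a
# no-op and y must come back unchanged.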
self.assertEqual(y, torch.ones(size=(10, 10), device=device))
def test_multiple_bool_indices(self, device):
v = torch.randn(5, 7, 3, device=device)
# note: these broadcast together and are transposed to the first dim
mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
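# mask1 keeps rows {0, 2, 3} and mask2 keeps all 3 entries of the last dim;
# the two index arrays broadcast and are zipped into 3 (row, channel) pairs,
# with the untouched middle dim of size 7 remaining -> shape (3, 7).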
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
def test_byte_mask(self, device):
v = torch.randn(5, 7, 3, device=device)
mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
with warnings.catch_warnings(record=True) as w:
res = v[mask]
self.assertEqual(res.shape, (3, 7, 3))
self.assertEqual(res, torch.stack([v[0], v[2], v[3]]))
self.assertEqual(len(w), 1)
v = torch.tensor([1.0], device=device)
self.assertEqual(v[v == 0], torch.tensor([], device=device))
def test_byte_mask_accumulate(self, device):
mask = torch.zeros(size=(10,), dtype=torch.uint8, device=device)
y = torch.ones(size=(10, 10), device=device)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
y.index_put_((mask,), y[mask], accumulate=True)
self.assertEqual(y, torch.ones(size=(10, 10), device=device))
self.assertEqual(len(w), 2)
@skipIfTorchDynamo(
"This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
)
@serialTest(TEST_CUDA)
def test_index_put_accumulate_large_tensor(self, device):
# This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
N = (1 << 31) + 5
dt = torch.int8
a = torch.ones(N, dtype=dt, device=device)
indices = torch.tensor(
[-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long
)
values = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device)
a.index_put_((indices,), values, accumulate=True)
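# Worked arithmetic: on top of the initial 1, index 0 accumulates 5 + 5 -> 11,
# index 1 accumulates 11 -> 12, index -2 accumulates 6 + 6 -> 13, and
# index -1 accumulates 6 + 7 -> 14.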
self.assertEqual(a[0], 11)
self.assertEqual(a[1], 12)
self.assertEqual(a[2], 1)
self.assertEqual(a[-3], 1)
self.assertEqual(a[-2], 13)
self.assertEqual(a[-1], 14)
a = torch.ones((2, N), dtype=dt, device=device)
indices0 = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long)
indices1 = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long)
values = torch.tensor([12, 13, 10, 11], dtype=dt, device=device)
a.index_put_((indices0, indices1), values, accumulate=True)
self.assertEqual(a[0, 0], 11)
self.assertEqual(a[0, 1], 1)
self.assertEqual(a[1, 0], 1)
self.assertEqual(a[1, 1], 12)
self.assertEqual(a[:, 2], torch.ones(2, dtype=torch.int8))
self.assertEqual(a[:, -3], torch.ones(2, dtype=torch.int8))
self.assertEqual(a[0, -2], 13)
self.assertEqual(a[1, -2], 1)
self.assertEqual(a[-1, -1], 14)
self.assertEqual(a[0, -1], 1)
@onlyNativeDeviceTypes
def test_index_put_accumulate_expanded_values(self, device):
# checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227
# and verifies consistency with CPU result
t = torch.zeros((5, 2))
t_dev = t.to(device)
indices = [torch.tensor([0, 1, 2, 3]), torch.tensor([1])]
indices_dev = [i.to(device) for i in indices]
values0d = torch.tensor(1.0)
values1d = torch.tensor([1.0])
out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
out_cpu = t.index_put_(indices, values0d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)
out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
out_cpu = t.index_put_(indices, values1d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)
t = torch.zeros(4, 3, 2)
t_dev = t.to(device)
indices = [
torch.tensor([0]),
torch.arange(3)[:, None],
torch.arange(2)[None, :],
]
indices_dev = [i.to(device) for i in indices]
values1d = torch.tensor([-1.0, -2.0])
values2d = torch.tensor([[-1.0, -2.0]])
out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
out_cpu = t.index_put_(indices, values1d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)
out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
out_cpu = t.index_put_(indices, values2d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyCUDA
def test_index_put_large_indices(self, device):
def generate_indices(num_indices: int, index_range: int):
indices = []
for _ in range(num_indices):
x = random.randint(0, index_range - 1)
indices.append(x)
return torch.tensor(indices)
num_indices = 401988
max_index_range = 2000
target_index_range = [16, 256, 2000]
# BFloat16
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
a = torch.randn(a_tensor_size, dtype=torch.bfloat16)
b = generate_indices(
num_indices=num_indices, index_range=generated_index_range
)
c_tensor_size = (num_indices, 256)
c = torch.randn(c_tensor_size, dtype=torch.bfloat16)
# create GPU copies
a_dev = a.to(device)
b_dev = b.to(device)
c_dev = c.to(device)
# run
a.index_put_(indices=[b], values=c, accumulate=True)
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
# Float32
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
a = torch.randn(a_tensor_size, dtype=torch.float32)
b = generate_indices(
num_indices=num_indices, index_range=generated_index_range
)
c_tensor_size = (num_indices, 256)
c = torch.randn(c_tensor_size, dtype=torch.float32)
# create GPU copies
a_dev = a.to(device)
b_dev = b.to(device)
c_dev = c.to(device)
# run
torch.use_deterministic_algorithms(True)
a.index_put_(indices=[b], values=c, accumulate=True)
torch.use_deterministic_algorithms(False)
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
@onlyCUDA
def test_index_put_accumulate_non_contiguous(self, device):
t = torch.zeros((5, 2, 2))
t_dev = t.to(device)
t1 = t_dev[:, 0, :]
t2 = t[:, 0, :]
self.assertTrue(not t1.is_contiguous())
self.assertTrue(not t2.is_contiguous())
indices = [torch.tensor([0, 1])]
indices_dev = [i.to(device) for i in indices]
value = torch.randn(2, 2)
out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
out_cpu = t2.index_put_(indices, value, accumulate=True)
self.assertTrue(not t1.is_contiguous())
self.assertTrue(not t2.is_contiguous())
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyCUDA
def test_index_put_deterministic_with_optional_tensors(self, device):
def func(x, i, v):
with DeterministicGuard(True):
x[..., i] = v
return x
def func1(x, i, v):
with DeterministicGuard(True):
x[i] = v
return x
n = 4
t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
t_dev = t.to(device)
indices = torch.tensor([1, 0])
indices_dev = indices.to(device)
value0d = torch.tensor(10.0)
value1d = torch.tensor([1.0, 2.0])
values2d = torch.randn(n, 1)
for val in (value0d, value1d, values2d):
out_cuda = func(t_dev, indices_dev, val.to(device))
out_cpu = func(t, indices, val)
self.assertEqual(out_cuda.cpu(), out_cpu)
t = torch.zeros((5, 4))
t_dev = t.to(device)
indices = torch.tensor([1, 4, 3])
indices_dev = indices.to(device)
val = torch.randn(4)
out_cuda = func1(t_dev, indices_dev, val.cuda())
out_cpu = func1(t, indices, val)
self.assertEqual(out_cuda.cpu(), out_cpu)
t = torch.zeros(2, 3, 4)
ind = torch.tensor([0, 1])
val = torch.randn(6, 2)
with self.assertRaisesRegex(RuntimeError, "shape mismatch"):
func(t, ind, val)
with self.assertRaisesRegex(RuntimeError, "must match"):
func(t.to(device), ind.to(device), val.to(device))
val = torch.randn(2, 3, 1)
out_cuda = func1(t.to(device), ind.to(device), val.to(device))
out_cpu = func1(t, ind, val)
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyNativeDeviceTypes
def test_index_put_accumulate_duplicate_indices(self, device):
for i in range(1, 512):
# generate indices by random walk, this will create indices with
# lots of duplicates interleaved with each other
delta = torch.empty(i, dtype=torch.double, device=device).uniform_(-1, 1)
indices = delta.cumsum(0).long()
input = torch.randn(indices.abs().max() + 1, device=device)
values = torch.randn(indices.size(0), device=device)
output = input.index_put((indices,), values, accumulate=True)
input_list = input.tolist()
indices_list = indices.tolist()
values_list = values.tolist()
for i, v in zip(indices_list, values_list):
input_list[i] += v
self.assertEqual(output, input_list)
@onlyNativeDeviceTypes
def test_index_ind_dtype(self, device):
x = torch.randn(4, 4, device=device)
ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
ind_int = ind_long.int()
src = torch.randn(4, device=device)
ref = x[ind_long, ind_long]
res = x[ind_int, ind_int]
self.assertEqual(ref, res)
ref = x[ind_long, :]
res = x[ind_int, :]
self.assertEqual(ref, res)
ref = x[:, ind_long]
res = x[:, ind_int]
self.assertEqual(ref, res)
# no repeating indices for index_put
ind_long = torch.arange(4, dtype=torch.long, device=device)
ind_int = ind_long.int()
for accum in (True, False):
inp_ref = x.clone()
inp_res = x.clone()
torch.index_put_(inp_ref, (ind_long, ind_long), src, accum)
torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
self.assertEqual(inp_ref, inp_res)
@skipXLA
def test_index_put_accumulate_empty(self, device):
# Regression test for https://github.com/pytorch/pytorch/issues/94667
input = torch.rand([], dtype=torch.float32, device=device)
with self.assertRaises(RuntimeError):
input.index_put([], torch.tensor([1.0], device=device), True)
def test_multiple_byte_mask(self, device):
v = torch.randn(5, 7, 3, device=device)
# note: these broadcast together and are transposed to the first dim
mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
mask2 = torch.ByteTensor([1, 1, 1]).to(device)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
self.assertEqual(len(w), 2)
def test_byte_mask2d(self, device):
v = torch.randn(5, 7, 3, device=device)
c = torch.randn(5, 7, device=device)
num_ones = (c > 0).sum()
r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_jit_indexing(self, device):
def fn1(x):
x[x < 50] = 1.0
return x
def fn2(x):
x[0:50] = 1.0
return x
scripted_fn1 = torch.jit.script(fn1)
scripted_fn2 = torch.jit.script(fn2)
data = torch.arange(100, device=device, dtype=torch.float)
out = scripted_fn1(data.detach().clone())
ref = torch.tensor(
np.concatenate((np.ones(50), np.arange(50, 100))),
device=device,
dtype=torch.float,
)
self.assertEqual(out, ref)
out = scripted_fn2(data.detach().clone())
self.assertEqual(out, ref)
def test_int_indices(self, device):
v = torch.randn(5, 7, 3, device=device)
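# The index's shape replaces the indexed dim: e.g. a (2, 2) nested list on
# dim 1 of a (5, 7, 3) tensor yields shape (5, 2, 2, 3).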
self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
@dtypes(
torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool
)
@dtypesIfCPU(
torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16
)
@dtypesIfCUDA(
torch.cfloat,
torch.cdouble,
torch.half,
torch.long,
torch.bool,
torch.bfloat16,
torch.float8_e5m2,
torch.float8_e4m3fn,
)
def test_index_put_src_datatype(self, device, dtype):
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
indices = (torch.tensor([0, 2, 1]),)
res = src.index_put_(indices, vals, accumulate=True)
self.assertEqual(res.shape, src.shape)
@dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
@dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
@dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
def test_index_src_datatype(self, device, dtype):
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
# test index
res = src[[0, 2, 1], :, :]
self.assertEqual(res.shape, src.shape)
# test index_put, no accum
src[[0, 2, 1], :, :] = res
self.assertEqual(res.shape, src.shape)
def test_int_indices2d(self, device):
# From the NumPy indexing example
x = torch.arange(0, 12, device=device).view(4, 3)
rows = torch.tensor([[0, 0], [3, 3]], device=device)
columns = torch.tensor([[0, 2], [0, 2]], device=device)
self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
def test_int_indices_broadcast(self, device):
# From the NumPy indexing example
x = torch.arange(0, 12, device=device).view(4, 3)
rows = torch.tensor([0, 3], device=device)
columns = torch.tensor([0, 2], device=device)
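# rows[:, None] has shape (2, 1) and columns has shape (2,); they broadcast
# to (2, 2), pairing each row with every column (the four corner elements).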
result = x[rows[:, None], columns]
self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
def test_empty_index(self, device):
x = torch.arange(0, 12, device=device).view(4, 3)
idx = torch.tensor([], dtype=torch.long, device=device)
self.assertEqual(x[idx].numel(), 0)
# empty assignment should have no effect but not throw an exception
y = x.clone()
y[idx] = -1
self.assertEqual(x, y)
mask = torch.zeros(4, 3, device=device).bool()
y[mask] = -1
self.assertEqual(x, y)
def test_empty_ndim_index(self, device):
x = torch.randn(5, device=device)
self.assertEqual(
torch.empty(0, 2, device=device),
x[torch.empty(0, 2, dtype=torch.int64, device=device)],
)
x = torch.randn(2, 3, 4, 5, device=device)
self.assertEqual(
torch.empty(2, 0, 6, 4, 5, device=device),
x[:, torch.empty(0, 6, dtype=torch.int64, device=device)],
)
x = torch.empty(10, 0, device=device)
self.assertEqual(x[[1, 2]].shape, (2, 0))
self.assertEqual(x[[], []].shape, (0,))
with self.assertRaisesRegex(IndexError, "for dimension with size 0"):
x[:, [0, 1]]
def test_empty_ndim_index_bool(self, device):
x = torch.randn(5, device=device)
self.assertRaises(
IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)]
)
def test_empty_slice(self, device):
x = torch.randn(2, 3, 4, 5, device=device)
y = x[:, :, :, 1]
z = y[:, 1:1, :]
self.assertEqual((2, 0, 4), z.shape)
# this isn't technically necessary, but matches NumPy stride calculations.
self.assertEqual((60, 20, 5), z.stride())
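# With zero elements, any stride pattern counts as contiguous, so z reports
# contiguous even while inheriting y's strides.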
self.assertTrue(z.is_contiguous())
def test_index_getitem_copy_bools_slices(self, device):
true = torch.tensor(1, dtype=torch.uint8, device=device)
false = torch.tensor(0, dtype=torch.uint8, device=device)
tensors = [torch.randn(2, 3, device=device), torch.tensor(3.0, device=device)]
for a in tensors:
self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
self.assertEqual(torch.empty(0, *a.shape), a[False])
self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
self.assertEqual(torch.empty(0, *a.shape), a[false])
self.assertEqual(a.data_ptr(), a[None].data_ptr())
self.assertEqual(a.data_ptr(), a[...].data_ptr())
def test_index_setitem_bools_slices(self, device):
true = torch.tensor(1, dtype=torch.uint8, device=device)
false = torch.tensor(0, dtype=torch.uint8, device=device)
tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
for a in tensors:
# prefix with 1, 1 to ensure we are compatible with NumPy, which cuts off prefix 1s
# (some of these ops already prefix a 1 to the size)
neg_ones = torch.ones_like(a) * -1
neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
a[True] = neg_ones_expanded
self.assertEqual(a, neg_ones)
a[False] = 5
self.assertEqual(a, neg_ones)
a[true] = neg_ones_expanded * 2
self.assertEqual(a, neg_ones * 2)
a[false] = 5
self.assertEqual(a, neg_ones * 2)
a[None] = neg_ones_expanded * 3
self.assertEqual(a, neg_ones * 3)
a[...] = neg_ones_expanded * 4
self.assertEqual(a, neg_ones * 4)
if a.dim() == 0:
with self.assertRaises(IndexError):
a[:] = neg_ones_expanded * 5
def test_index_scalar_with_bool_mask(self, device):
a = torch.tensor(1, device=device)
uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
boolMask = torch.tensor(True, dtype=torch.bool, device=device)
self.assertEqual(a[uintMask], a[boolMask])
self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
a = torch.tensor(True, dtype=torch.bool, device=device)
self.assertEqual(a[uintMask], a[boolMask])
self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
def test_setitem_expansion_error(self, device):
true = torch.tensor(True, device=device)
a = torch.randn(2, 3, device=device)
# check prefix with non-1s doesn't work
a_expanded = a.expand(torch.Size([5, 1]) + a.size())
# NumPy: ValueError
with self.assertRaises(RuntimeError):
a[True] = a_expanded
with self.assertRaises(RuntimeError):
a[true] = a_expanded
def test_getitem_scalars(self, device):
zero = torch.tensor(0, dtype=torch.int64, device=device)
one = torch.tensor(1, dtype=torch.int64, device=device)
# non-scalar indexed with scalars
a = torch.randn(2, 3, device=device)
self.assertEqual(a[0], a[zero])
self.assertEqual(a[0][1], a[zero][one])
self.assertEqual(a[0, 1], a[zero, one])
self.assertEqual(a[0, one], a[zero, 1])
# indexing by a scalar should slice (not copy)
self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
# scalar indexed with scalar
r = torch.randn((), device=device)
with self.assertRaises(IndexError):
r[:]
with self.assertRaises(IndexError):
r[zero]
self.assertEqual(r, r[...])
def test_setitem_scalars(self, device):
zero = torch.tensor(0, dtype=torch.int64)
# non-scalar indexed with scalars
a = torch.randn(2, 3, device=device)
a_set_with_number = a.clone()
a_set_with_scalar = a.clone()
b = torch.randn(3, device=device)
a_set_with_number[0] = b
a_set_with_scalar[zero] = b
self.assertEqual(a_set_with_number, a_set_with_scalar)
a[1, zero] = 7.7
self.assertEqual(7.7, a[1, 0])
# scalar indexed with scalars
r = torch.randn((), device=device)
with self.assertRaises(IndexError):
r[:] = 8.8
with self.assertRaises(IndexError):
r[zero] = 8.8
r[...] = 9.9
self.assertEqual(9.9, r)
def test_basic_advanced_combined(self, device):
# From the NumPy indexing example
x = torch.arange(0, 12, device=device).view(4, 3)
self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
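# Mixing basic (slice) and advanced (list) indices returns a copy on read,
# while the same key as an assignment target writes into x in place.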
# Check that it is a copy
unmodified = x.clone()
x[1:2, [1, 2]].zero_()
self.assertEqual(x, unmodified)
# But assignment should modify the original
unmodified = x.clone()
x[1:2, [1, 2]] = 0
self.assertNotEqual(x, unmodified)
def test_int_assignment(self, device):
x = torch.arange(0, 4, device=device).view(2, 2)
x[1] = 5
self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
x = torch.arange(0, 4, device=device).view(2, 2)
x[1] = torch.arange(5, 7, device=device)
self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
def test_byte_tensor_assignment(self, device):
x = torch.arange(0.0, 16, device=device).view(4, 4)
b = torch.ByteTensor([True, False, True, False]).to(device)
value = torch.tensor([3.0, 4.0, 5.0, 6.0], device=device)
with warnings.catch_warnings(record=True) as w:
x[b] = value
self.assertEqual(len(w), 1)
self.assertEqual(x[0], value)
self.assertEqual(x[1], torch.arange(4.0, 8, device=device))
self.assertEqual(x[2], value)
self.assertEqual(x[3], torch.arange(12.0, 16, device=device))
def test_variable_slicing(self, device):
x = torch.arange(0, 16, device=device).view(4, 4)
indices = torch.IntTensor([0, 1]).to(device)
i, j = indices
self.assertEqual(x[i:j], x[0:1])
def test_ellipsis_tensor(self, device):
x = torch.arange(0, 9, device=device).view(3, 3)
idx = torch.tensor([0, 2], device=device)
self.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]])
self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]])
def test_unravel_index_errors(self, device):
with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"):
torch.unravel_index(torch.tensor(0.5, device=device), (2, 2))
with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"):
torch.unravel_index(torch.tensor([], device=device), (10, 3, 5))
with self.assertRaisesRegex(
TypeError, r"expected 'shape' to be int or sequence"
):
torch.unravel_index(
torch.tensor([1], device=device, dtype=torch.int64),
torch.tensor([1, 2, 3]),
)
with self.assertRaisesRegex(
TypeError, r"expected 'shape' sequence to only contain ints"
):
torch.unravel_index(
torch.tensor([1], device=device, dtype=torch.int64), (1, 2, 2.0)
)
with self.assertRaisesRegex(
ValueError, r"'shape' cannot have negative values, but got \(2, -3\)"
):
torch.unravel_index(torch.tensor(0, device=device), (2, -3))
def test_invalid_index(self, device):
x = torch.arange(0, 16, device=device).view(4, 4)
self.assertRaisesRegex(TypeError, "slice indices", lambda: x["0":"1"])
def test_out_of_bound_index(self, device):
x = torch.arange(0, 100, device=device).view(2, 5, 10)
self.assertRaisesRegex(
IndexError,
"index 5 is out of bounds for dimension 1 with size 5",
lambda: x[0, 5],
)
self.assertRaisesRegex(
IndexError,
"index 4 is out of bounds for dimension 0 with size 2",
lambda: x[4, 5],
)
self.assertRaisesRegex(
IndexError,
"index 15 is out of bounds for dimension 2 with size 10",
lambda: x[0, 1, 15],
)
self.assertRaisesRegex(
IndexError,
"index 12 is out of bounds for dimension 2 with size 10",
lambda: x[:, :, 12],
)
def test_zero_dim_index(self, device):
x = torch.tensor(10, device=device)
self.assertEqual(x, x.item())
def runner():
print(x[0])
return x[0]
self.assertRaisesRegex(IndexError, "invalid index", runner)
@onlyCUDA
def test_invalid_device(self, device):
idx = torch.tensor([0, 1])
b = torch.zeros(5, device=device)
c = torch.tensor([1.0, 2.0], device="cpu")
for accumulate in [True, False]:
self.assertRaises(
RuntimeError,
lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate),
)
@onlyCUDA
def test_cpu_indices(self, device):
idx = torch.tensor([0, 1])
b = torch.zeros(2, device=device)
x = torch.ones(10, device=device)
x[idx] = b # index_put_
ref = torch.ones(10, device=device)
ref[:2] = 0
self.assertEqual(x, ref, atol=0, rtol=0)
out = x[idx] # index
self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
@dtypes(torch.long, torch.float32)
def test_take_along_dim(self, device, dtype):
def _test_against_numpy(t, indices, dim):
actual = torch.take_along_dim(t, indices, dim=dim)
t_np = t.cpu().numpy()
indices_np = indices.cpu().numpy()
expected = np.take_along_axis(t_np, indices_np, axis=dim)
self.assertEqual(actual, expected, atol=0, rtol=0)
for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
for noncontiguous in [True, False]:
t = make_tensor(
shape, device=device, dtype=dtype, noncontiguous=noncontiguous
)
for dim in list(range(t.ndim)) + [None]:
if dim is None:
indices = torch.argsort(t.view(-1))
else:
indices = torch.argsort(t, dim=dim)
_test_against_numpy(t, indices, dim)
# test broadcasting
t = torch.ones((3, 4, 1), device=device)
indices = torch.ones((1, 2, 5), dtype=torch.long, device=device)
_test_against_numpy(t, indices, 1)
# test empty indices
t = torch.ones((3, 4, 5), device=device)
indices = torch.ones((3, 0, 5), dtype=torch.long, device=device)
_test_against_numpy(t, indices, 1)
@dtypes(torch.long, torch.float)
def test_take_along_dim_invalid(self, device, dtype):
shape = (2, 3, 1, 4)
dim = 0
t = make_tensor(shape, device=device, dtype=dtype)
indices = torch.argsort(t, dim=dim)
# dim of `t` and `indices` does not match
with self.assertRaisesRegex(
RuntimeError, "input and indices should have the same number of dimensions"
):
torch.take_along_dim(t, indices[0], dim=0)
# invalid `indices` dtype
with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
torch.take_along_dim(t, indices.to(torch.bool), dim=0)
with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
torch.take_along_dim(t, indices.to(torch.float), dim=0)
with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
torch.take_along_dim(t, indices.to(torch.int32), dim=0)
# invalid axis
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
torch.take_along_dim(t, indices, dim=-7)
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
torch.take_along_dim(t, indices, dim=7)
@onlyCUDA
@dtypes(torch.float)
def test_gather_take_along_dim_cross_device(self, device, dtype):
shape = (2, 3, 1, 4)
dim = 0
t = make_tensor(shape, device=device, dtype=dtype)
indices = torch.argsort(t, dim=dim)
with self.assertRaisesRegex(
RuntimeError, "Expected all tensors to be on the same device"
):
torch.gather(t, 0, indices.cpu())
with self.assertRaisesRegex(
RuntimeError,
r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()",
):
torch.take_along_dim(t, indices.cpu(), dim=0)
with self.assertRaisesRegex(
RuntimeError, "Expected all tensors to be on the same device"
):
torch.gather(t.cpu(), 0, indices)
with self.assertRaisesRegex(
RuntimeError,
r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()",
):
torch.take_along_dim(t.cpu(), indices, dim=0)
@onlyCUDA
def test_cuda_broadcast_index_use_deterministic_algorithms(self, device):
with DeterministicGuard(True):
idx1 = torch.tensor([0])
idx2 = torch.tensor([2, 6])
idx3 = torch.tensor([1, 5, 7])
tensor_a = torch.rand(13, 11, 12, 13, 12).cpu()
tensor_b = tensor_a.to(device=device)
tensor_a[idx1] = 1.0
tensor_a[idx1, :, idx2, idx2, :] = 2.0
tensor_a[:, idx1, idx3, :, idx3] = 3.0
tensor_b[idx1] = 1.0
tensor_b[idx1, :, idx2, idx2, :] = 2.0
tensor_b[:, idx1, idx3, :, idx3] = 3.0
self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
tensor_a = torch.rand(10, 11).cpu()
tensor_b = tensor_a.to(device=device)
tensor_a[idx3] = 1.0
tensor_a[idx2, :] = 2.0
tensor_a[:, idx2] = 3.0
tensor_a[:, idx1] = 4.0
tensor_b[idx3] = 1.0
tensor_b[idx2, :] = 2.0
tensor_b[:, idx2] = 3.0
tensor_b[:, idx1] = 4.0
self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
tensor_a = torch.rand(10, 10).cpu()
tensor_b = tensor_a.to(device=device)
tensor_a[[8]] = 1.0
tensor_b[[8]] = 1.0
self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
tensor_a = torch.rand(10).cpu()
tensor_b = tensor_a.to(device=device)
tensor_a[6] = 1.0
tensor_b[6] = 1.0
self.assertEqual(tensor_a, tensor_b.cpu(), atol=0, rtol=0)
def test_index_limits(self, device):
# Regression test for https://github.com/pytorch/pytorch/issues/115415
t = torch.tensor([], device=device)
idx_min = torch.iinfo(torch.int64).min
idx_max = torch.iinfo(torch.int64).max
self.assertRaises(IndexError, lambda: t[idx_min])
self.assertRaises(IndexError, lambda: t[idx_max])
# The tests below are from NumPy test_indexing.py with some modifications to
# make them compatible with PyTorch. It's licensed under the BSD license below:
#
# Copyright (c) 2005-2017, NumPy Developers.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# * Neither the name of the NumPy Developers nor the names of any
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


class NumpyTests(TestCase):
def test_index_no_floats(self, device):
a = torch.tensor([[[5.0]]], device=device)
self.assertRaises(IndexError, lambda: a[0.0])
self.assertRaises(IndexError, lambda: a[0, 0.0])
self.assertRaises(IndexError, lambda: a[0.0, 0])
self.assertRaises(IndexError, lambda: a[0.0, :])
self.assertRaises(IndexError, lambda: a[:, 0.0])
self.assertRaises(IndexError, lambda: a[:, 0.0, :])
self.assertRaises(IndexError, lambda: a[0.0, :, :])
self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
self.assertRaises(IndexError, lambda: a[-1.4])
self.assertRaises(IndexError, lambda: a[0, -1.4])
self.assertRaises(IndexError, lambda: a[-1.4, 0])
self.assertRaises(IndexError, lambda: a[-1.4, :])
self.assertRaises(IndexError, lambda: a[:, -1.4])
self.assertRaises(IndexError, lambda: a[:, -1.4, :])
self.assertRaises(IndexError, lambda: a[-1.4, :, :])
self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
# self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
# self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])
def test_none_index(self, device):
# `None` index adds newaxis
a = tensor([1, 2, 3], device=device)
self.assertEqual(a[None].dim(), a.dim() + 1)

    def test_empty_tuple_index(self, device):
# Empty tuple index creates a view
a = tensor([1, 2, 3], device=device)
self.assertEqual(a[()], a)
self.assertEqual(a[()].data_ptr(), a.data_ptr())

    def test_empty_fancy_index(self, device):
# Empty list index creates an empty array
a = tensor([1, 2, 3], device=device)
self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))
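
        # An empty LongTensor index behaves the same as an empty list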
b = tensor([], device=device).long()
        self.assertEqual(a[b], torch.tensor([], dtype=torch.long, device=device))
b = tensor([], device=device).float()
self.assertRaises(IndexError, lambda: a[b])

    def test_ellipsis_index(self, device):
a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
self.assertIsNot(a[...], a)
self.assertEqual(a[...], a)
# `a[...]` was `a` in numpy <1.9.
self.assertEqual(a[...].data_ptr(), a.data_ptr())
# Slicing with ellipsis can skip an
# arbitrary number of dimensions
self.assertEqual(a[0, ...], a[0])
self.assertEqual(a[0, ...], a[0, :])
self.assertEqual(a[..., 0], a[:, 0])
# In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
# we don't have separate 0-dim arrays and scalars.
self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))
# Assignment with `(Ellipsis,)` on 0-d arrays
b = torch.tensor(1)
b[(Ellipsis,)] = 2
self.assertEqual(b, 2)

    def test_single_int_index(self, device):
# Single integer index selects one row
a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
self.assertEqual(a[0], [1, 2, 3])
self.assertEqual(a[-1], [7, 8, 9])
# Index out of bounds produces IndexError
self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        # Index overflow produces an exception; NB: a different exception
        # type, since 1 << 64 does not fit in int64
self.assertRaises(Exception, a.__getitem__, 1 << 64)

    def test_single_bool_index(self, device):
# Single boolean index
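        # A scalar True adds a leading dimension of size 1 (like `None`);
        # a scalar False adds a leading dimension of size 0.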
a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
self.assertEqual(a[True], a[None])
self.assertEqual(a[False], a[None][0:0])

    def test_boolean_shape_mismatch(self, device):
arr = torch.ones((5, 4, 3), device=device)
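        # A boolean mask must match the shape of the dimensions it indexes;
        # here the first dimension has size 5, so masks of length 1 or 6 fail.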
index = tensor([True], device=device)
self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
index = tensor([False] * 6, device=device)
self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
index = torch.ByteTensor(4, 4).to(device).zero_()
self.assertRaisesRegex(IndexError, "mask", lambda: arr[index])
self.assertRaisesRegex(IndexError, "mask", lambda: arr[(slice(None), index)])

    def test_boolean_indexing_onedim(self, device):
# Indexing a 2-dimensional array with
# boolean array of length one
a = tensor([[0.0, 0.0, 0.0]], device=device)
b = tensor([True], device=device)
self.assertEqual(a[b], a)
# boolean assignment
a[b] = 1.0
self.assertEqual(a, tensor([[1.0, 1.0, 1.0]], device=device))

    # https://github.com/pytorch/pytorch/issues/127003
@xfailIfTorchDynamo
def test_boolean_assignment_value_mismatch(self, device):
# A boolean assignment should fail when the shape of the values
# cannot be broadcast to the subscription. (see also gh-3458)
a = torch.arange(0, 4, device=device)

        def f(a, v):
            a[a > -1] = tensor(v).to(device)

self.assertRaisesRegex(Exception, "shape mismatch", f, a, [])
self.assertRaisesRegex(Exception, "shape mismatch", f, a, [1, 2, 3])
self.assertRaisesRegex(Exception, "shape mismatch", f, a[:1], [1, 2, 3])

    def test_boolean_indexing_twodim(self, device):
# Indexing a 2-dimensional array with
# 2-dimensional boolean array
a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
b = tensor(
[[True, False, True], [False, True, False], [True, False, True]],
device=device,
)
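        # a[b] gathers the elements where b is True, in row-major order,
        # into a 1-D tensor.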
self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device))
self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device))
self.assertEqual(a[b[0]], a[b[2]])
# boolean assignment
a[b] = 0
self.assertEqual(a, tensor([[0, 2, 0], [4, 0, 6], [0, 8, 0]], device=device))

    def test_boolean_indexing_weirdness(self, device):
# Weird boolean indexing things
a = torch.ones((2, 3, 4), device=device)
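        # Scalar boolean indices broadcast together and contribute a single
        # leading dimension: size 1 if all are True, size 0 otherwise.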
self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
self.assertEqual(
torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]]
)
self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])

    def test_boolean_indexing_weirdness_tensors(self, device):
# Weird boolean indexing things
false = torch.tensor(False, device=device)
true = torch.tensor(True, device=device)
a = torch.ones((2, 3, 4), device=device)
self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
self.assertEqual(
torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]]
)
self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])

    def test_boolean_indexing_alldims(self, device):
true = torch.tensor(True, device=device)
a = torch.ones((2, 3), device=device)
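        # All-True scalar indices contribute a single leading dimension of size 1.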
self.assertEqual((1, 2, 3), a[True, True].shape)
self.assertEqual((1, 2, 3), a[true, true].shape)

    def test_boolean_list_indexing(self, device):
# Indexing a 2-dimensional array with
# boolean lists
a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
b = [True, False, False]
c = [True, True, False]
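        # A boolean list is converted to the integer positions of its True
        # entries, so a[c, c] is equivalent to a[[0, 1], [0, 1]].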
self.assertEqual(a[b], tensor([[1, 2, 3]], device=device))
self.assertEqual(a[b, b], tensor([1], device=device))
self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device))
self.assertEqual(a[c, c], tensor([1, 5], device=device))

    def test_everything_returns_views(self, device):
        # Before, `a[...]` would return `a` itself.
a = tensor([5], device=device)
self.assertIsNot(a, a[()])
self.assertIsNot(a, a[...])
self.assertIsNot(a, a[:])

    def test_broaderrors_indexing(self, device):
a = torch.zeros(5, 5, device=device)
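        # Advanced indices of shapes (2,) and (3,) cannot be broadcast together.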
self.assertRaisesRegex(
IndexError, "shape mismatch", a.__getitem__, ([0, 1], [0, 1, 2])
)
self.assertRaisesRegex(
IndexError, "shape mismatch", a.__setitem__, ([0, 1], [0, 1, 2]), 0
)

    def test_trivial_fancy_out_of_bounds(self, device):
a = torch.zeros(5, device=device)
ind = torch.ones(20, dtype=torch.int64, device=device)
if a.is_cuda:
raise unittest.SkipTest("CUDA asserts instead of raising an exception")
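        # A single out-of-bounds entry anywhere in the index must raise,
        # whether it appears at the end or at the front.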
ind[-1] = 10
self.assertRaises(IndexError, a.__getitem__, ind)
self.assertRaises(IndexError, a.__setitem__, ind, 0)
ind = torch.ones(20, dtype=torch.int64, device=device)
ind[0] = 11
self.assertRaises(IndexError, a.__getitem__, ind)
self.assertRaises(IndexError, a.__setitem__, ind, 0)

    def test_index_is_larger(self, device):
# Simple case of fancy index broadcasting of the index.
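        # The (3, 1) and (3,) indices broadcast to shape (3, 3); the (3,)
        # value is broadcast across the rows of that result.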
a = torch.zeros((5, 5), device=device)
a[[[0], [1], [2]], [0, 1, 2]] = tensor([2.0, 3.0, 4.0], device=device)
self.assertTrue((a[:3, :3] == tensor([2.0, 3.0, 4.0], device=device)).all())

    def test_broadcast_subspace(self, device):
a = torch.zeros((100, 100), device=device)
v = torch.arange(0.0, 100, device=device)[:, None]
b = torch.arange(99, -1, -1, device=device).long()
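        # Row a[b[i]] receives row i of v; since b reverses the order,
        # row j of `a` ends up filled with the value 99 - j, i.e. b[j].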
a[b] = v
expected = b.float().unsqueeze(1).expand(100, 100)
self.assertEqual(a, expected)

    def test_truncate_leading_1s(self, device):
col_max = torch.randn(1, 4)
kernel = col_max.T * col_max # [4, 4] tensor
kernel2 = kernel.clone()
# Set the diagonal
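        # torch.square(col_max) has shape (1, 4); the leading 1 is truncated
        # when assigning to the (4,)-shaped diagonal selection (hence the name).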
kernel[range(len(kernel)), range(len(kernel))] = torch.square(col_max)
torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
self.assertEqual(kernel, kernel2)


instantiate_device_type_tests(TestIndexing, globals(), except_for="meta")
instantiate_device_type_tests(NumpyTests, globals(), except_for="meta")

if __name__ == "__main__":
    run_tests()