fix numpy compatibility for 2d small list indices (#154806)
Will fix #119548 and linked issues once we switch from the warning to the new behavior. For now, given how heavily this syntax is used in our test suite, we suspect a silent change would be disruptive, so the behavior will change only after the 2.8 branch is cut. NumPy itself changed this behavior at least as of numpy 1.24 (more than two years ago).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154806
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/albanD
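For illustration, a minimal sketch of the incompatibility being addressed; the numpy side assumes numpy >= 1.24, and the printed values are worked out by hand here rather than quoted from the PR:

import numpy as np
import torch

idx = [[0, 1], [2, 3]]  # a 2d "small list" index

a = np.arange(16).reshape(4, 4)
print(a[idx].shape)  # (2, 2, 4): numpy treats idx as a single array index

t = torch.arange(16).view(4, 4)
print(t[idx])  # tensor([2, 7]): torch still unpacks idx into the tuple
               # ([0, 1], [2, 3]) and, with this PR, emits a deprecation warning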
Committed by: PyTorch MergeBot
Parent: e2760544fa
Commit: 34e3930401
@@ -21,10 +21,10 @@ class TensorParallelRandomStateTests(DTensorTestBase):
         assert shape[0] % n == 0
         local_shape = [shape[0] // n, shape[1]]

-        slice_idx = [
+        slice_idx = (
             slice(idx * local_shape[0], (idx + 1) * local_shape[0]),
             slice(local_shape[1]),
-        ]
+        )
         return large_tensor[slice_idx]

     def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc):
@@ -65,12 +65,12 @@ class DistTensorRandomInitTest(DTensorTestBase):
             # compare with local tensors from other ranks
             for other_rank in range(self.world_size):
                 if self.rank != other_rank:
-                    slice_idx = [
+                    slice_idx = (
                         slice(input_size[0]),
                         slice(
                             other_rank * input_size[1], (other_rank + 1) * input_size[1]
                         ),
-                    ]
+                    )
                     # other rank should have a different local tensor
                     self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor)

@@ -537,9 +537,9 @@ class DistTensorRandomOpTest(DTensorTestBase):
                     slice(offset, offset + size) for offset, size in other_local_shard
                 ]
                 if local_shard_offset == other_local_shard_offset:
-                    self.assertEqual(full_tensor[slice_idx], local_tensor)
+                    self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor)
                 else:
-                    self.assertNotEqual(full_tensor[slice_idx], local_tensor)
+                    self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor)


 class DistTensorRandomOpsTest3D(DTensorTestBase):
@@ -322,13 +322,13 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
     test_args = [
         (3, ([1, 2],)),
         (3, (slice(0, 3),)),
-        (3, ([slice(0, 3), 1],)),
-        (3, ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],)),
-        (3, ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],)),
-        (3, ([slice(None), slice(None), [0, 3]],)),
-        (3, ([slice(None), [0, 3], slice(None)],)),
-        (3, ([[0, 3], slice(None), slice(None)],)),
-        (3, ([[0, 3], [1, 2], slice(None)],)),
+        (3, ((slice(0, 3), 1),)),
+        (3, (([0, 2, 3], [1, 3, 3], [0, 0, 2]),)),
+        (3, (([0, 0, 3], [1, 1, 3], [0, 0, 2]),)),
+        (3, ((slice(None), slice(None), [0, 3]),)),
+        (3, ((slice(None), [0, 3], slice(None)),)),
+        (3, (([0, 3], slice(None), slice(None)),)),
+        (3, (([0, 3], [1, 2], slice(None)),)),
         (
             3,
             (
@@ -337,20 +337,20 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
                 ],
             ),
         ),
-        (3, ([[0, 3], slice(None)],)),
-        (3, ([[0, 3], Ellipsis],)),
-        (3, ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],)),
-        (4, ([slice(None), adv_idx, adv_idx, slice(None)],)),
-        (4, ([slice(None), adv_idx, slice(None), adv_idx],)),
-        (4, ([adv_idx, slice(None), slice(None), adv_idx],)),
-        (4, ([slice(None), slice(None), adv_idx, adv_idx],)),
-        (4, ([Ellipsis, adv_idx, adv_idx],)),
-        (5, ([slice(None), slice(None), adv_idx, slice(None), adv_idx],)),
-        (5, ([slice(None), slice(None), adv_idx, adv_idx, slice(None)],)),
-        (5, ([slice(None), slice(None), adv_idx, None, adv_idx, slice(None)],)),
-        (6, ([slice(None), slice(None), slice(None), adv_idx, adv_idx],)),
-        (6, ([slice(None), slice(None), adv_idx, adv_idx, adv_idx],)),
-        (6, ([slice(None), slice(None), None, adv_idx, adv_idx, adv_idx],)),
+        (3, (([0, 3], slice(None)),)),
+        (3, (([0, 3], Ellipsis),)),
+        (3, (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),)),
+        (4, ((slice(None), adv_idx, adv_idx, slice(None)),)),
+        (4, ((slice(None), adv_idx, slice(None), adv_idx),)),
+        (4, ((adv_idx, slice(None), slice(None), adv_idx),)),
+        (4, ((slice(None), slice(None), adv_idx, adv_idx),)),
+        (4, ((Ellipsis, adv_idx, adv_idx),)),
+        (5, ((slice(None), slice(None), adv_idx, slice(None), adv_idx),)),
+        (5, ((slice(None), slice(None), adv_idx, adv_idx, slice(None)),)),
+        (5, ((slice(None), slice(None), adv_idx, None, adv_idx, slice(None)),)),
+        (6, ((slice(None), slice(None), slice(None), adv_idx, adv_idx),)),
+        (6, ((slice(None), slice(None), adv_idx, adv_idx, adv_idx),)),
+        (6, ((slice(None), slice(None), None, adv_idx, adv_idx, adv_idx),)),
     ]

    def get_shape(dim):
@@ -400,20 +400,22 @@ def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs
     adv_idx = torch.LongTensor([[0, 1], [2, 3]])
     # self_shape, indices
     additional = [
-        ((5, 6, 7, 8), [None, adv_idx, adv_idx, None]),
-        ((5, 6, 7, 8), [None, adv_idx, None, adv_idx]),
-        ((5, 6, 7, 8), [adv_idx, None, None, adv_idx]),
-        ((5, 6, 7, 8), [None, None, adv_idx, adv_idx]),
-        ((5, 6, 7, 8, 9), [None, None, adv_idx, None, adv_idx]),
-        ((5, 6, 7, 8, 9), [None, None, adv_idx, adv_idx, None]),
-        ((5, 6, 7, 8, 9, 10), [None, None, None, adv_idx, adv_idx]),
-        ((5, 6, 7, 8, 9, 10), [None, None, adv_idx, adv_idx, adv_idx]),
+        ((5, 6, 7, 8), (None, adv_idx, adv_idx, None)),
+        ((5, 6, 7, 8), (None, adv_idx, None, adv_idx)),
+        ((5, 6, 7, 8), (adv_idx, None, None, adv_idx)),
+        ((5, 6, 7, 8), (None, None, adv_idx, adv_idx)),
+        ((5, 6, 7, 8, 9), (None, None, adv_idx, None, adv_idx)),
+        ((5, 6, 7, 8, 9), (None, None, adv_idx, adv_idx, None)),
+        ((5, 6, 7, 8, 9, 10), (None, None, None, adv_idx, adv_idx)),
+        ((5, 6, 7, 8, 9, 10), (None, None, adv_idx, adv_idx, adv_idx)),
     ]
     for self_shape, indices in additional:
         for broadcast_value in [False, True]:
             inp = make_arg(self_shape)

-            tmp_indices = [slice(None) if idx is None else idx for idx in indices]
+            tmp_indices = tuple(
+                [slice(None) if idx is None else idx for idx in indices]
+            )
             values_shape = inp[tmp_indices].shape
             if broadcast_value:
                 values_shape = values_shape[3:]
@@ -3028,8 +3028,8 @@ class TestAutograd(TestCase):
         check_index(x, y, ([1, 2, 3], [0]))
         check_index(x, y, ([1, 2], [2, 1]))
         check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
-        check_index(x, y, ([slice(None), [2, 3]]))
-        check_index(x, y, ([[2, 3], slice(None)]))
+        check_index(x, y, ((slice(None), [2, 3])))
+        check_index(x, y, (([2, 3], slice(None))))

         # advanced indexing, with less dim, or ellipsis
         check_index(x, y, ([0]))
@@ -3061,8 +3061,8 @@ class TestAutograd(TestCase):
         # advanced indexing, with a tensor wrapped in a variable
         z = torch.LongTensor([0, 1])
         zv = Variable(z, requires_grad=False)
-        seq = [z, Ellipsis]
-        seqv = [zv, Ellipsis]
+        seq = (z, Ellipsis)
+        seqv = (zv, Ellipsis)

         if y.grad is not None:
             with torch.no_grad():
@@ -3086,7 +3086,7 @@ class TestAutograd(TestCase):
         x = torch.arange(1.0, 17).view(4, 4)
         y = Variable(x, requires_grad=True)

-        idx = [[1, 1, 3, 2, 1, 2], [0]]
+        idx = ([1, 1, 3, 2, 1, 2], [0])
         y[idx].sum().backward()
         expected_grad = torch.zeros(4, 4)
         for i in idx[0]:
@@ -3097,7 +3097,7 @@ class TestAutograd(TestCase):

         x = torch.arange(1.0, 17).view(4, 4)
         y = Variable(x, requires_grad=True)
-        idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
+        idx = ([[1, 2], [0, 0]], [[0, 1], [1, 1]])
         y[idx].sum().backward()
         expected_grad = torch.tensor(
             [
@@ -3112,7 +3112,7 @@ class TestAutograd(TestCase):
         x = torch.arange(1.0, 65).view(4, 4, 4)
         y = Variable(x, requires_grad=True)

-        idx = [[1, 1, 1], slice(None), slice(None)]
+        idx = ([1, 1, 1], slice(None), slice(None))
         y[idx].sum().backward()
         expected_grad = torch.empty(4, 4, 4).zero_()
         expected_grad[1].fill_(3)
@@ -3541,32 +3541,32 @@ class TestAutograd(TestCase):
         self._test_setitem((5, 5), 1)
         self._test_setitem((5,), 1)
         self._test_setitem((1,), 0)
-        self._test_setitem((10,), [[0, 4, 2]])
-        self._test_setitem((5, 5), [[0, 4], [2, 2]])
-        self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]])
-        self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)])
-        self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)])
-        self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]])
-        self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)])
+        self._test_setitem((10,), ([0, 4, 2]))
+        self._test_setitem((5, 5), ([0, 4], [2, 2]))
+        self._test_setitem((5, 5, 5), (slice(None), slice(None), [1, 3]))
+        self._test_setitem((5, 5, 5), (slice(None), [1, 3], slice(None)))
+        self._test_setitem((5, 5, 5), ([1, 3], slice(None), slice(None)))
+        self._test_setitem((5, 5, 5), (slice(None), [2, 4], [1, 3]))
+        self._test_setitem((5, 5, 5), ([1, 3], [2, 4], slice(None)))
         self._test_setitem_tensor((5, 5), 3)
-        self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]])
+        self._test_setitem_tensor((5, 5), ([0, 1], [1, 0]))
         self._test_setitem_tensor((5,), 3)
         self._test_setitem_tensor(
             (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum()
         )
         self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
-        self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]])
-        self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)])
-        self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)])
-        self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]])
-        self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)])
+        self._test_setitem_tensor((5, 5, 5), (slice(None), slice(None), [1, 3]))
+        self._test_setitem_tensor((5, 5, 5), (slice(None), [1, 3], slice(None)))
+        self._test_setitem_tensor((5, 5, 5), ([1, 3], slice(None), slice(None)))
+        self._test_setitem_tensor((5, 5, 5), (slice(None), [2, 4], [1, 3]))
+        self._test_setitem_tensor((5, 5, 5), ([1, 3], [2, 4], slice(None)))
         self._test_setitem_tensor(
             (5, 5, 5),
-            [
+            (
                 Variable(torch.LongTensor([1, 3]), requires_grad=False),
                 [2, 4],
                 slice(None),
-            ],
+            ),
         )

     def test_setitem_mask(self):
@@ -250,7 +250,10 @@ class TestIndexing(TestCase):
         reference = consec((10,))
         strided = torch.tensor((), dtype=dtype, device=device)
         strided.set_(
-            reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2]
+            reference.untyped_storage(),
+            storage_offset=0,
+            size=torch.Size([4]),
+            stride=[2],
         )

         self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
@@ -274,7 +277,10 @@ class TestIndexing(TestCase):
         # stride is [4, 8]
         strided = torch.tensor((), dtype=dtype, device=device)
         strided.set_(
-            reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4]
+            reference.untyped_storage(),
+            storage_offset=4,
+            size=torch.Size([2]),
+            stride=[4],
         )
         self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
         self.assertEqual(
@@ -309,15 +315,15 @@ class TestIndexing(TestCase):
         self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
         self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
         self.assertEqual(
-            reference[[ri([0, 0]), ri([0, 1])]],
+            reference[(ri([0, 0]), ri([0, 1]))],
             torch.tensor([1, 2], dtype=dtype, device=device),
         )
         self.assertEqual(
-            reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
+            reference[(ri([0, 1, 1, 0, 2]), ri([1]))],
             torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),
         )
         self.assertEqual(
-            reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
+            reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
             torch.tensor([1, 2, 3, 3], dtype=dtype, device=device),
         )

@@ -387,15 +393,15 @@ class TestIndexing(TestCase):
             reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device)
         )
         self.assertEqual(
-            reference[[ri([0, 0]), ri([0, 1])]],
+            reference[(ri([0, 0]), ri([0, 1]))],
             torch.tensor([0, 4], dtype=dtype, device=device),
         )
         self.assertEqual(
-            reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
+            reference[(ri([0, 1, 1, 0, 3]), ri([1]))],
             torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),
         )
         self.assertEqual(
-            reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
+            reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
             torch.tensor([0, 4, 1, 1], dtype=dtype, device=device),
         )

@@ -446,7 +452,9 @@ class TestIndexing(TestCase):

         reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
         strided = torch.tensor((), dtype=dtype, device=device)
-        strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2])
+        strided.set_(
+            reference.untyped_storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]
+        )

         self.assertEqual(
             strided[ri([0, 1]), ri([0])],
@@ -463,15 +471,15 @@ class TestIndexing(TestCase):
             strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device)
         )
         self.assertEqual(
-            strided[[ri([0, 0]), ri([0, 3])]],
+            strided[(ri([0, 0]), ri([0, 3]))],
             torch.tensor([1, 7], dtype=dtype, device=device),
         )
         self.assertEqual(
-            strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
+            strided[(ri([1]), ri([0, 1, 1, 0, 3]))],
             torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device),
         )
         self.assertEqual(
-            strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
+            strided[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
             torch.tensor([1, 3, 9, 9], dtype=dtype, device=device),
         )

@@ -502,7 +510,9 @@ class TestIndexing(TestCase):

         reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
         strided = torch.tensor((), dtype=dtype, device=device)
-        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
+        strided.set_(
+            reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
+        )
         self.assertEqual(
             strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device)
         )
@@ -513,7 +523,9 @@ class TestIndexing(TestCase):

         reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
         strided = torch.tensor((), dtype=dtype, device=device)
-        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
+        strided.set_(
+            reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
+        )
         self.assertEqual(
             strided[ri([0, 1]), ri([1, 0])],
             torch.tensor([11, 17], dtype=dtype, device=device),
@@ -528,7 +540,9 @@ class TestIndexing(TestCase):

         reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
         strided = torch.tensor((), dtype=dtype, device=device)
-        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1])
+        strided.set_(
+            reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
+        )

         rows = ri([[0], [1]])
         columns = ri([[0, 1], [0, 1]])
@@ -642,19 +656,19 @@ class TestIndexing(TestCase):

         indices_to_test = [
             # grab the second, fourth columns
-            [slice(None), [1, 3]],
+            (slice(None), [1, 3]),
             # first, third rows,
-            [[0, 2], slice(None)],
+            ([0, 2], slice(None)),
             # weird shape
-            [slice(None), [[0, 1], [2, 3]]],
+            (slice(None), [[0, 1], [2, 3]]),
             # negatives
-            [[-1], [0]],
-            [[0, 2], [-1]],
-            [slice(None), [-1]],
+            ([-1], [0]),
+            ([0, 2], [-1]),
+            (slice(None), [-1]),
         ]

         # only test dupes on gets
-        get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
+        get_indices_to_test = indices_to_test + [(slice(None), [0, 1, 1, 2, 2])]

         for indexer in get_indices_to_test:
             assert_get_eq(reference, indexer)
@@ -668,46 +682,46 @@ class TestIndexing(TestCase):
         reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5)

         indices_to_test = [
-            [slice(None), slice(None), [0, 3, 4]],
-            [slice(None), [2, 4, 5, 7], slice(None)],
-            [[2, 3], slice(None), slice(None)],
-            [slice(None), [0, 2, 3], [1, 3, 4]],
-            [slice(None), [0], [1, 2, 4]],
-            [slice(None), [0, 1, 3], [4]],
-            [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
-            [slice(None), [[0, 1], [2, 3]], [[0]]],
-            [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
-            [[0, 2, 3], [1, 3, 4], slice(None)],
-            [[0], [1, 2, 4], slice(None)],
-            [[0, 1, 3], [4], slice(None)],
-            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
-            [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
-            [[[0, 1], [2, 3]], [[0]], slice(None)],
-            [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
-            [[[2]], [[0, 3], [4, 1]], slice(None)],
+            (slice(None), slice(None), (0, 3, 4)),
+            (slice(None), (2, 4, 5, 7), slice(None)),
+            ((2, 3), slice(None), slice(None)),
+            (slice(None), (0, 2, 3), (1, 3, 4)),
+            (slice(None), (0,), (1, 2, 4)),
+            (slice(None), (0, 1, 3), (4,)),
+            (slice(None), ((0, 1), (1, 0)), ((2, 3),)),
+            (slice(None), ((0, 1), (2, 3)), ((0,),)),
+            (slice(None), ((5, 6),), ((0, 3), (4, 4))),
+            ((0, 2, 3), (1, 3, 4), slice(None)),
+            ((0,), (1, 2, 4), slice(None)),
+            ((0, 1, 3), (4,), slice(None)),
+            (((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)),
+            (((0, 1), (1, 0)), ((2, 3),), slice(None)),
+            (((0, 1), (2, 3)), ((0,),), slice(None)),
+            (((2, 1),), ((0, 3), (4, 4)), slice(None)),
+            (((2,),), ((0, 3), (4, 1)), slice(None)),
             # non-contiguous indexing subspace
-            [[0, 2, 3], slice(None), [1, 3, 4]],
+            ((0, 2, 3), slice(None), (1, 3, 4)),
             # [...]
             # less dim, ellipsis
-            [[0, 2]],
-            [[0, 2], slice(None)],
-            [[0, 2], Ellipsis],
-            [[0, 2], slice(None), Ellipsis],
-            [[0, 2], Ellipsis, slice(None)],
-            [[0, 2], [1, 3]],
-            [[0, 2], [1, 3], Ellipsis],
-            [Ellipsis, [1, 3], [2, 3]],
-            [Ellipsis, [2, 3, 4]],
-            [Ellipsis, slice(None), [2, 3, 4]],
-            [slice(None), Ellipsis, [2, 3, 4]],
+            ((0, 2),),
+            ((0, 2), slice(None)),
+            ((0, 2), Ellipsis),
+            ((0, 2), slice(None), Ellipsis),
+            ((0, 2), Ellipsis, slice(None)),
+            ((0, 2), (1, 3)),
+            ((0, 2), (1, 3), Ellipsis),
+            (Ellipsis, (1, 3), (2, 3)),
+            (Ellipsis, (2, 3, 4)),
+            (Ellipsis, slice(None), (2, 3, 4)),
+            (slice(None), Ellipsis, (2, 3, 4)),
             # ellipsis counts for nothing
-            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
-            [slice(None), Ellipsis, slice(None), [0, 3, 4]],
-            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
-            [slice(None), slice(None), [0, 3, 4], Ellipsis],
-            [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
-            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
-            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
+            (Ellipsis, slice(None), slice(None), (0, 3, 4)),
+            (slice(None), Ellipsis, slice(None), (0, 3, 4)),
+            (slice(None), slice(None), Ellipsis, (0, 3, 4)),
+            (slice(None), slice(None), (0, 3, 4), Ellipsis),
+            (Ellipsis, ((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)),
+            (((0, 1), (1, 0)), ((2, 1), (3, 5)), Ellipsis, slice(None)),
+            (((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None), Ellipsis),
         ]

         for indexer in indices_to_test:
@@ -720,65 +734,65 @@ class TestIndexing(TestCase):
         reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)

         indices_to_test = [
-            [slice(None), slice(None), slice(None), [0, 3, 4]],
-            [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
-            [slice(None), [2, 3], slice(None), slice(None)],
-            [[1, 2], slice(None), slice(None), slice(None)],
-            [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
-            [slice(None), slice(None), [0], [1, 2, 4]],
-            [slice(None), slice(None), [0, 1, 3], [4]],
-            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
-            [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
-            [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
-            [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
-            [slice(None), [0], [1, 2, 4], slice(None)],
-            [slice(None), [0, 1, 3], [4], slice(None)],
-            [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
-            [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
-            [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
-            [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
-            [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
-            [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
-            [[0], [1, 2, 4], slice(None), slice(None)],
-            [[0, 1, 2], [4], slice(None), slice(None)],
-            [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
-            [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
-            [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
-            [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
-            [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
-            [slice(None), [2, 3, 4], [1, 3, 4], [4]],
-            [slice(None), [0, 1, 3], [4], [1, 3, 4]],
-            [slice(None), [6], [0, 2, 3], [1, 3, 4]],
-            [slice(None), [2, 3, 5], [3], [4]],
-            [slice(None), [0], [4], [1, 3, 4]],
-            [slice(None), [6], [0, 2, 3], [1]],
-            [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
-            [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
-            [[2, 0, 1], [1, 2, 3], [4], slice(None)],
-            [[0, 1, 2], [4], [1, 3, 4], slice(None)],
-            [[0], [0, 2, 3], [1, 3, 4], slice(None)],
-            [[0, 2, 1], [3], [4], slice(None)],
-            [[0], [4], [1, 3, 4], slice(None)],
-            [[1], [0, 2, 3], [1], slice(None)],
-            [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
+            (slice(None), slice(None), slice(None), (0, 3, 4)),
+            (slice(None), slice(None), (2, 4, 5, 7), slice(None)),
+            (slice(None), (2, 3), slice(None), slice(None)),
+            ((1, 2), slice(None), slice(None), slice(None)),
+            (slice(None), slice(None), (0, 2, 3), (1, 3, 4)),
+            (slice(None), slice(None), (0,), (1, 2, 4)),
+            (slice(None), slice(None), (0, 1, 3), (4,)),
+            (slice(None), slice(None), ((0, 1), (1, 0)), ((2, 3),)),
+            (slice(None), slice(None), ((0, 1), (2, 3)), ((0,),)),
+            (slice(None), slice(None), ((5, 6),), ((0, 3), (4, 4))),
+            (slice(None), (0, 2, 3), (1, 3, 4), slice(None)),
+            (slice(None), (0,), (1, 2, 4), slice(None)),
+            (slice(None), (0, 1, 3), (4,), slice(None)),
+            (slice(None), ((0, 1), (3, 4)), ((2, 3), (0, 1)), slice(None)),
+            (slice(None), ((0, 1), (3, 4)), ((2, 3),), slice(None)),
+            (slice(None), ((0, 1), (3, 2)), ((0,),), slice(None)),
+            (slice(None), ((2, 1),), ((0, 3), (6, 4)), slice(None)),
+            (slice(None), ((2,),), ((0, 3), (4, 2)), slice(None)),
+            ((0, 1, 2), (1, 3, 4), slice(None), slice(None)),
+            ((0,), (1, 2, 4), slice(None), slice(None)),
+            ((0, 1, 2), (4,), slice(None), slice(None)),
+            (((0, 1), (0, 2)), ((2, 4), (1, 5)), slice(None), slice(None)),
+            (((0, 1), (1, 2)), ((2, 0),), slice(None), slice(None)),
+            (((2, 2),), ((0, 3), (4, 5)), slice(None), slice(None)),
+            (((2,),), ((0, 3), (4, 5)), slice(None), slice(None)),
+            (slice(None), (3, 4, 6), (0, 2, 3), (1, 3, 4)),
+            (slice(None), (2, 3, 4), (1, 3, 4), (4,)),
+            (slice(None), (0, 1, 3), (4,), (1, 3, 4)),
+            (slice(None), (6,), (0, 2, 3), (1, 3, 4)),
+            (slice(None), (2, 3, 5), (3,), (4,)),
+            (slice(None), (0,), (4,), (1, 3, 4)),
+            (slice(None), (6,), (0, 2, 3), (1,)),
+            (slice(None), ((0, 3), (3, 6)), ((0, 1), (1, 3)), ((5, 3), (1, 2))),
+            ((2, 2, 1), (0, 2, 3), (1, 3, 4), slice(None)),
+            ((2, 0, 1), (1, 2, 3), (4,), slice(None)),
+            ((0, 1, 2), (4,), (1, 3, 4), slice(None)),
+            ((0,), (0, 2, 3), (1, 3, 4), slice(None)),
+            ((0, 2, 1), (3,), (4,), slice(None)),
+            ((0,), (4,), (1, 3, 4), slice(None)),
+            ((1,), (0, 2, 3), (1,), slice(None)),
+            (((1, 2), (1, 2)), ((0, 1), (2, 3)), ((2, 3), (3, 5)), slice(None)),
             # less dim, ellipsis
-            [Ellipsis, [0, 3, 4]],
-            [Ellipsis, slice(None), [0, 3, 4]],
-            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
-            [slice(None), Ellipsis, [0, 3, 4]],
-            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
-            [slice(None), [0, 2, 3], [1, 3, 4]],
-            [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
-            [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
-            [[0], [1, 2, 4]],
-            [[0], [1, 2, 4], slice(None)],
-            [[0], [1, 2, 4], Ellipsis],
-            [[0], [1, 2, 4], Ellipsis, slice(None)],
-            [[1]],
-            [[0, 2, 1], [3], [4]],
-            [[0, 2, 1], [3], [4], slice(None)],
-            [[0, 2, 1], [3], [4], Ellipsis],
-            [Ellipsis, [0, 2, 1], [3], [4]],
+            (Ellipsis, (0, 3, 4)),
+            (Ellipsis, slice(None), (0, 3, 4)),
+            (Ellipsis, slice(None), slice(None), (0, 3, 4)),
+            (slice(None), Ellipsis, (0, 3, 4)),
+            (slice(None), slice(None), Ellipsis, (0, 3, 4)),
+            (slice(None), (0, 2, 3), (1, 3, 4)),
+            (slice(None), (0, 2, 3), (1, 3, 4), Ellipsis),
+            (Ellipsis, (0, 2, 3), (1, 3, 4), slice(None)),
+            ((0,), (1, 2, 4)),
+            ((0,), (1, 2, 4), slice(None)),
+            ((0,), (1, 2, 4), Ellipsis),
+            ((0,), (1, 2, 4), Ellipsis, slice(None)),
+            ((1,),),
+            ((0, 2, 1), (3,), (4,)),
+            ((0, 2, 1), (3,), (4,), slice(None)),
+            ((0, 2, 1), (3,), (4,), Ellipsis),
+            (Ellipsis, (0, 2, 1), (3,), (4,)),
         ]

         for indexer in indices_to_test:
@@ -786,8 +800,8 @@ class TestIndexing(TestCase):
             assert_set_eq(reference, indexer, 1333)
             assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
         indices_to_test += [
-            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
-            [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
+            (slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]),
+            (slice(None), slice(None), [[2]], [[0, 3], [4, 4]]),
         ]
         for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
@@ -866,6 +880,21 @@ class TestIndexing(TestCase):
         )
         self.assertEqual(len(w), 1)

+    def test_list_indices(self, device):
+        N = 1000
+        t = torch.randn(N, device=device)
+        # Set window size
+        W = 10
+        # Generate a list of lists, containing overlapping window indices
+        indices = [range(i, i + W) for i in range(0, N - W)]
+
+        for i in [len(indices), 100, 32]:
+            windowed_data = t[indices[:i]]
+            self.assertEqual(windowed_data.shape, (i, W))
+
+        with self.assertRaisesRegex(IndexError, "too many indices"):
+            windowed_data = t[indices[:31]]
+
     def test_bool_indices_accumulate(self, device):
         mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
         y = torch.ones(size=(10, 10), device=device)
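A reading of the new test above (an interpretation of the indexing heuristic, not text from the commit): an outer sequence of 32 or more sub-sequences is already converted to a single tensor index (see the n >= 32 check in treatSequenceAsTuple below), while a shorter one is still unpacked as a per-dimension tuple, which is why only the 31-window prefix raises. A hypothetical standalone repro:

import torch

t = torch.randn(1000)
windows = [list(range(i, i + 10)) for i in range(64)]

print(t[windows[:32]].shape)  # torch.Size([32, 10]): treated as one tensor index
try:
    t[windows[:31]]  # unpacked as a tuple: 31 indices into a 1-D tensor
except IndexError as e:
    print(e)  # "too many indices ..." (exact message may vary)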
@@ -60,20 +60,22 @@ def _hermitian_conj(x, dim):
     """
     out = torch.empty_like(x)
     mid = (x.size(dim) - 1) // 2
-    idx = [slice(None)] * out.dim()
-    idx_center = list(idx)
-    idx_center[dim] = 0
+    idx = tuple([slice(None)] * out.dim())
     out[idx] = x[idx]

     idx_neg = list(idx)
     idx_neg[dim] = slice(-mid, None)
-    idx_pos = idx
+    idx_neg = tuple(idx_neg)
+    idx_pos = list(idx)
     idx_pos[dim] = slice(1, mid + 1)
+    idx_pos = tuple(idx_pos)

     out[idx_pos] = x[idx_neg].flip(dim)
     out[idx_neg] = x[idx_pos].flip(dim)
     if (2 * mid + 1 < x.size(dim)):
+        idx = list(idx)
         idx[dim] = mid + 1
+        idx = tuple(idx)
         out[idx] = x[idx]
     return out.conj()

@@ -518,6 +520,7 @@ class TestFFT(TestCase):
             lastdim_size = input.size(lastdim) // 2 + 1
             idx = [slice(None)] * input_ndim
             idx[lastdim] = slice(0, lastdim_size)
+            idx = tuple(idx)
             input = input[idx]

         s = [shape[dim] for dim in actual_dims]
@@ -558,6 +561,7 @@ class TestFFT(TestCase):
             lastdim_size = expect.size(lastdim) // 2 + 1
             idx = [slice(None)] * input_ndim
             idx[lastdim] = slice(0, lastdim_size)
+            idx = tuple(idx)
             expect = expect[idx]

         actual = torch.fft.ihfftn(input, dim=dim, norm="ortho")
@@ -941,7 +941,7 @@ def choose(
     ]

     idx_list[0] = a
-    return choices[idx_list].squeeze(0)
+    return choices[tuple(idx_list)].squeeze(0)


 # ### unique et al. ###
@@ -25,6 +25,7 @@
 #include <ATen/TracerMode.h>
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <c10/core/TensorOptions.h>
+#include <c10/util/Exception.h>
 #include <c10/util/irange.h>

 #include <c10/core/Layout.h>
@@ -292,6 +293,13 @@ static bool treatSequenceAsTuple(PyObject* index) {
   }
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
   if (n >= 32) {
+    TORCH_WARN(
+        "Using a non-tuple sequence for "
+        "multidimensional indexing is deprecated and will be changed in "
+        "pytorch 2.9; use x[tuple(seq)] instead of "
+        "x[seq]. In pytorch 2.9 this will be interpreted as tensor index, "
+        "x[torch.tensor(seq)], which will result either in an error or a "
+        "different result");
     return false;
   }
   for (Py_ssize_t i = 0; i < n; i++) {
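The warning text above recommends the explicit spellings; a short torch-only sketch of the two meanings it distinguishes (values worked out by hand, not output quoted from the PR):

import torch

x = torch.arange(16).view(4, 4)
seq = [[0, 1], [2, 3]]

print(x[tuple(seq)])  # tensor([2, 7]): today's tuple interpretation, no warning
print(x[torch.tensor(seq)].shape)  # torch.Size([2, 2, 4]): the tensor-index
                                   # interpretation planned for pytorch 2.9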
@@ -592,7 +592,7 @@ def _distribute_tensors(
         ]
         if local_state.is_meta:
             # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost.
-            local_tensor = full_tensor[slices].detach().clone()
+            local_tensor = full_tensor[tuple(slices)].detach().clone()
             # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example,
             # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)).
             ret = DTensor.from_local(
@@ -605,7 +605,7 @@ def _distribute_tensors(
         else:
             ret = local_state
         # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint.
-        ret.to_local().copy_(full_tensor[slices])
+        ret.to_local().copy_(full_tensor[tuple(slices)])
         local_state_dict[key] = ret


@@ -394,6 +394,8 @@ class PruningContainer(BasePruningMethod):
             raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")

         # compute the new mask on the unpruned slice of the tensor t
+        if isinstance(slc, list):
+            slc = tuple(slc)
         partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
         new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

@@ -625,6 +627,7 @@ class RandomStructured(BasePruningMethod):
             mask = torch.zeros_like(t)
             slc = [slice(None)] * len(t.shape)
             slc[dim] = channel_mask
+            slc = tuple(slc)
             mask[slc] = 1
             return mask

@@ -739,6 +742,7 @@ class LnStructured(BasePruningMethod):
         # replace a None at position=dim with indices
         # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
         slc[dim] = indices
+        slc = tuple(slc)
         # use slc to slice mask and replace all its entries with 1s
         # e.g.: mask[:, :, [0, 2, 3]] = 1
         mask[slc] = 1
@@ -124,7 +124,7 @@ def multidim_slicer(dims, slices, *tensors):
         for d, d_slice in zip(dims, slices):
             if d is not None:
                 s[d] = d_slice
-        yield t[s]
+        yield t[tuple(s)]


 def ptr_stride_extractor(*tensors):
@@ -3266,17 +3266,17 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
     test_args = [
         ([1, 2],),
         (slice(0, 3),),
-        ([slice(0, 3), 1],),
-        ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
-        ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
-        ([slice(None), slice(None), [0, 3]],),
-        ([slice(None), [0, 3], slice(None)],),
-        ([[0, 3], slice(None), slice(None)],),
-        ([[0, 3], [1, 2], slice(None)],),
-        ([[0, 3], ],),
-        ([[0, 3], slice(None)],),
-        ([[0, 3], Ellipsis],),
-        ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
+        ((slice(0, 3), 1),),
+        (([0, 2, 3], [1, 3, 3], [0, 0, 2]),),
+        (([0, 0, 3], [1, 1, 3], [0, 0, 2]),),
+        ((slice(None), slice(None), [0, 3]),),
+        ((slice(None), [0, 3], slice(None)),),
+        (([0, 3], slice(None), slice(None)),),
+        (([0, 3], [1, 2], slice(None)),),
+        (([0, 3], ),),
+        (([0, 3], slice(None)),),
+        (([0, 3], Ellipsis),),
+        (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),),
         (index_variable(2, S, device=device),),
         (mask_not_all_zeros((S,)),),
     ]
@@ -3284,7 +3284,7 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
     for args in test_args:
         yield SampleInput(make_arg((S, S, S)), args=args)

-    yield SampleInput(make_arg((S, S, S, S)), args=([slice(None), [0, 1], slice(None), [0, 1]],))
+    yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),))

 def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -290,7 +290,7 @@ class FuzzedTensor:
         raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))

         slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
-        tensor = raw_tensor[slices]
+        tensor = raw_tensor[tuple(slices)]

         properties = {
             "numel": int(tensor.numel()),