fix numpy compatibility for 2d small list indices (#154806)

Will fix #119548 and linked issues once we switch from warning to the new behavior,
but for now, given how much this syntax was used in our test suite, we suspect a silent change will be disruptive.
We will change the behavior after 2.8 branch is cut.
Numpy behavior was changed at least in numpy 1.24 (more than 2 years ago)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154806
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
Natalia Gimelshein
2025-06-04 01:58:52 +00:00
committed by PyTorch MergeBot
parent e2760544fa
commit 34e3930401
13 changed files with 244 additions and 197 deletions

View File

@ -21,10 +21,10 @@ class TensorParallelRandomStateTests(DTensorTestBase):
assert shape[0] % n == 0 assert shape[0] % n == 0
local_shape = [shape[0] // n, shape[1]] local_shape = [shape[0] // n, shape[1]]
slice_idx = [ slice_idx = (
slice(idx * local_shape[0], (idx + 1) * local_shape[0]), slice(idx * local_shape[0], (idx + 1) * local_shape[0]),
slice(local_shape[1]), slice(local_shape[1]),
] )
return large_tensor[slice_idx] return large_tensor[slice_idx]
def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc): def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc):

View File

@ -65,12 +65,12 @@ class DistTensorRandomInitTest(DTensorTestBase):
# compare with local tensors from other ranks # compare with local tensors from other ranks
for other_rank in range(self.world_size): for other_rank in range(self.world_size):
if self.rank != other_rank: if self.rank != other_rank:
slice_idx = [ slice_idx = (
slice(input_size[0]), slice(input_size[0]),
slice( slice(
other_rank * input_size[1], (other_rank + 1) * input_size[1] other_rank * input_size[1], (other_rank + 1) * input_size[1]
), ),
] )
# other rank should have a different local tensor # other rank should have a different local tensor
self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor) self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor)
@ -537,9 +537,9 @@ class DistTensorRandomOpTest(DTensorTestBase):
slice(offset, offset + size) for offset, size in other_local_shard slice(offset, offset + size) for offset, size in other_local_shard
] ]
if local_shard_offset == other_local_shard_offset: if local_shard_offset == other_local_shard_offset:
self.assertEqual(full_tensor[slice_idx], local_tensor) self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor)
else: else:
self.assertNotEqual(full_tensor[slice_idx], local_tensor) self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor)
class DistTensorRandomOpsTest3D(DTensorTestBase): class DistTensorRandomOpsTest3D(DTensorTestBase):

View File

@ -322,13 +322,13 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
test_args = [ test_args = [
(3, ([1, 2],)), (3, ([1, 2],)),
(3, (slice(0, 3),)), (3, (slice(0, 3),)),
(3, ([slice(0, 3), 1],)), (3, ((slice(0, 3), 1),)),
(3, ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],)), (3, (([0, 2, 3], [1, 3, 3], [0, 0, 2]),)),
(3, ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],)), (3, (([0, 0, 3], [1, 1, 3], [0, 0, 2]),)),
(3, ([slice(None), slice(None), [0, 3]],)), (3, ((slice(None), slice(None), [0, 3]),)),
(3, ([slice(None), [0, 3], slice(None)],)), (3, ((slice(None), [0, 3], slice(None)),)),
(3, ([[0, 3], slice(None), slice(None)],)), (3, (([0, 3], slice(None), slice(None)),)),
(3, ([[0, 3], [1, 2], slice(None)],)), (3, (([0, 3], [1, 2], slice(None)),)),
( (
3, 3,
( (
@ -337,20 +337,20 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
], ],
), ),
), ),
(3, ([[0, 3], slice(None)],)), (3, (([0, 3], slice(None)),)),
(3, ([[0, 3], Ellipsis],)), (3, (([0, 3], Ellipsis),)),
(3, ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],)), (3, (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),)),
(4, ([slice(None), adv_idx, adv_idx, slice(None)],)), (4, ((slice(None), adv_idx, adv_idx, slice(None)),)),
(4, ([slice(None), adv_idx, slice(None), adv_idx],)), (4, ((slice(None), adv_idx, slice(None), adv_idx),)),
(4, ([adv_idx, slice(None), slice(None), adv_idx],)), (4, ((adv_idx, slice(None), slice(None), adv_idx),)),
(4, ([slice(None), slice(None), adv_idx, adv_idx],)), (4, ((slice(None), slice(None), adv_idx, adv_idx),)),
(4, ([Ellipsis, adv_idx, adv_idx],)), (4, ((Ellipsis, adv_idx, adv_idx),)),
(5, ([slice(None), slice(None), adv_idx, slice(None), adv_idx],)), (5, ((slice(None), slice(None), adv_idx, slice(None), adv_idx),)),
(5, ([slice(None), slice(None), adv_idx, adv_idx, slice(None)],)), (5, ((slice(None), slice(None), adv_idx, adv_idx, slice(None)),)),
(5, ([slice(None), slice(None), adv_idx, None, adv_idx, slice(None)],)), (5, ((slice(None), slice(None), adv_idx, None, adv_idx, slice(None)),)),
(6, ([slice(None), slice(None), slice(None), adv_idx, adv_idx],)), (6, ((slice(None), slice(None), slice(None), adv_idx, adv_idx),)),
(6, ([slice(None), slice(None), adv_idx, adv_idx, adv_idx],)), (6, ((slice(None), slice(None), adv_idx, adv_idx, adv_idx),)),
(6, ([slice(None), slice(None), None, adv_idx, adv_idx, adv_idx],)), (6, ((slice(None), slice(None), None, adv_idx, adv_idx, adv_idx),)),
] ]
def get_shape(dim): def get_shape(dim):
@ -400,20 +400,22 @@ def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs
adv_idx = torch.LongTensor([[0, 1], [2, 3]]) adv_idx = torch.LongTensor([[0, 1], [2, 3]])
# self_shape, indices # self_shape, indices
additional = [ additional = [
((5, 6, 7, 8), [None, adv_idx, adv_idx, None]), ((5, 6, 7, 8), (None, adv_idx, adv_idx, None)),
((5, 6, 7, 8), [None, adv_idx, None, adv_idx]), ((5, 6, 7, 8), (None, adv_idx, None, adv_idx)),
((5, 6, 7, 8), [adv_idx, None, None, adv_idx]), ((5, 6, 7, 8), (adv_idx, None, None, adv_idx)),
((5, 6, 7, 8), [None, None, adv_idx, adv_idx]), ((5, 6, 7, 8), (None, None, adv_idx, adv_idx)),
((5, 6, 7, 8, 9), [None, None, adv_idx, None, adv_idx]), ((5, 6, 7, 8, 9), (None, None, adv_idx, None, adv_idx)),
((5, 6, 7, 8, 9), [None, None, adv_idx, adv_idx, None]), ((5, 6, 7, 8, 9), (None, None, adv_idx, adv_idx, None)),
((5, 6, 7, 8, 9, 10), [None, None, None, adv_idx, adv_idx]), ((5, 6, 7, 8, 9, 10), (None, None, None, adv_idx, adv_idx)),
((5, 6, 7, 8, 9, 10), [None, None, adv_idx, adv_idx, adv_idx]), ((5, 6, 7, 8, 9, 10), (None, None, adv_idx, adv_idx, adv_idx)),
] ]
for self_shape, indices in additional: for self_shape, indices in additional:
for broadcast_value in [False, True]: for broadcast_value in [False, True]:
inp = make_arg(self_shape) inp = make_arg(self_shape)
tmp_indices = [slice(None) if idx is None else idx for idx in indices] tmp_indices = tuple(
[slice(None) if idx is None else idx for idx in indices]
)
values_shape = inp[tmp_indices].shape values_shape = inp[tmp_indices].shape
if broadcast_value: if broadcast_value:
values_shape = values_shape[3:] values_shape = values_shape[3:]

View File

@ -3028,8 +3028,8 @@ class TestAutograd(TestCase):
check_index(x, y, ([1, 2, 3], [0])) check_index(x, y, ([1, 2, 3], [0]))
check_index(x, y, ([1, 2], [2, 1])) check_index(x, y, ([1, 2], [2, 1]))
check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]])) check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
check_index(x, y, ([slice(None), [2, 3]])) check_index(x, y, ((slice(None), [2, 3])))
check_index(x, y, ([[2, 3], slice(None)])) check_index(x, y, (([2, 3], slice(None))))
# advanced indexing, with less dim, or ellipsis # advanced indexing, with less dim, or ellipsis
check_index(x, y, ([0])) check_index(x, y, ([0]))
@ -3061,8 +3061,8 @@ class TestAutograd(TestCase):
# advanced indexing, with a tensor wrapped in a variable # advanced indexing, with a tensor wrapped in a variable
z = torch.LongTensor([0, 1]) z = torch.LongTensor([0, 1])
zv = Variable(z, requires_grad=False) zv = Variable(z, requires_grad=False)
seq = [z, Ellipsis] seq = (z, Ellipsis)
seqv = [zv, Ellipsis] seqv = (zv, Ellipsis)
if y.grad is not None: if y.grad is not None:
with torch.no_grad(): with torch.no_grad():
@ -3086,7 +3086,7 @@ class TestAutograd(TestCase):
x = torch.arange(1.0, 17).view(4, 4) x = torch.arange(1.0, 17).view(4, 4)
y = Variable(x, requires_grad=True) y = Variable(x, requires_grad=True)
idx = [[1, 1, 3, 2, 1, 2], [0]] idx = ([1, 1, 3, 2, 1, 2], [0])
y[idx].sum().backward() y[idx].sum().backward()
expected_grad = torch.zeros(4, 4) expected_grad = torch.zeros(4, 4)
for i in idx[0]: for i in idx[0]:
@ -3097,7 +3097,7 @@ class TestAutograd(TestCase):
x = torch.arange(1.0, 17).view(4, 4) x = torch.arange(1.0, 17).view(4, 4)
y = Variable(x, requires_grad=True) y = Variable(x, requires_grad=True)
idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]] idx = ([[1, 2], [0, 0]], [[0, 1], [1, 1]])
y[idx].sum().backward() y[idx].sum().backward()
expected_grad = torch.tensor( expected_grad = torch.tensor(
[ [
@ -3112,7 +3112,7 @@ class TestAutograd(TestCase):
x = torch.arange(1.0, 65).view(4, 4, 4) x = torch.arange(1.0, 65).view(4, 4, 4)
y = Variable(x, requires_grad=True) y = Variable(x, requires_grad=True)
idx = [[1, 1, 1], slice(None), slice(None)] idx = ([1, 1, 1], slice(None), slice(None))
y[idx].sum().backward() y[idx].sum().backward()
expected_grad = torch.empty(4, 4, 4).zero_() expected_grad = torch.empty(4, 4, 4).zero_()
expected_grad[1].fill_(3) expected_grad[1].fill_(3)
@ -3541,32 +3541,32 @@ class TestAutograd(TestCase):
self._test_setitem((5, 5), 1) self._test_setitem((5, 5), 1)
self._test_setitem((5,), 1) self._test_setitem((5,), 1)
self._test_setitem((1,), 0) self._test_setitem((1,), 0)
self._test_setitem((10,), [[0, 4, 2]]) self._test_setitem((10,), ([0, 4, 2]))
self._test_setitem((5, 5), [[0, 4], [2, 2]]) self._test_setitem((5, 5), ([0, 4], [2, 2]))
self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]]) self._test_setitem((5, 5, 5), (slice(None), slice(None), [1, 3]))
self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)]) self._test_setitem((5, 5, 5), (slice(None), [1, 3], slice(None)))
self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)]) self._test_setitem((5, 5, 5), ([1, 3], slice(None), slice(None)))
self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]]) self._test_setitem((5, 5, 5), (slice(None), [2, 4], [1, 3]))
self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)]) self._test_setitem((5, 5, 5), ([1, 3], [2, 4], slice(None)))
self._test_setitem_tensor((5, 5), 3) self._test_setitem_tensor((5, 5), 3)
self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]]) self._test_setitem_tensor((5, 5), ([0, 1], [1, 0]))
self._test_setitem_tensor((5,), 3) self._test_setitem_tensor((5,), 3)
self._test_setitem_tensor( self._test_setitem_tensor(
(5,), Variable(torch.LongTensor([3]), requires_grad=False).sum() (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum()
) )
self._test_setitem_tensor((5,), [[0, 1, 2, 3]]) self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]]) self._test_setitem_tensor((5, 5, 5), (slice(None), slice(None), [1, 3]))
self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)]) self._test_setitem_tensor((5, 5, 5), (slice(None), [1, 3], slice(None)))
self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)]) self._test_setitem_tensor((5, 5, 5), ([1, 3], slice(None), slice(None)))
self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]]) self._test_setitem_tensor((5, 5, 5), (slice(None), [2, 4], [1, 3]))
self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)]) self._test_setitem_tensor((5, 5, 5), ([1, 3], [2, 4], slice(None)))
self._test_setitem_tensor( self._test_setitem_tensor(
(5, 5, 5), (5, 5, 5),
[ (
Variable(torch.LongTensor([1, 3]), requires_grad=False), Variable(torch.LongTensor([1, 3]), requires_grad=False),
[2, 4], [2, 4],
slice(None), slice(None),
], ),
) )
def test_setitem_mask(self): def test_setitem_mask(self):

View File

@ -250,7 +250,10 @@ class TestIndexing(TestCase):
reference = consec((10,)) reference = consec((10,))
strided = torch.tensor((), dtype=dtype, device=device) strided = torch.tensor((), dtype=dtype, device=device)
strided.set_( strided.set_(
reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2] reference.untyped_storage(),
storage_offset=0,
size=torch.Size([4]),
stride=[2],
) )
self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device)) self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
@ -274,7 +277,10 @@ class TestIndexing(TestCase):
# stride is [4, 8] # stride is [4, 8]
strided = torch.tensor((), dtype=dtype, device=device) strided = torch.tensor((), dtype=dtype, device=device)
strided.set_( strided.set_(
reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4] reference.untyped_storage(),
storage_offset=4,
size=torch.Size([2]),
stride=[4],
) )
self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device)) self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
self.assertEqual( self.assertEqual(
@ -309,15 +315,15 @@ class TestIndexing(TestCase):
self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
self.assertEqual( self.assertEqual(
reference[[ri([0, 0]), ri([0, 1])]], reference[(ri([0, 0]), ri([0, 1]))],
torch.tensor([1, 2], dtype=dtype, device=device), torch.tensor([1, 2], dtype=dtype, device=device),
) )
self.assertEqual( self.assertEqual(
reference[[ri([0, 1, 1, 0, 2]), ri([1])]], reference[(ri([0, 1, 1, 0, 2]), ri([1]))],
torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device), torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),
) )
self.assertEqual( self.assertEqual(
reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
torch.tensor([1, 2, 3, 3], dtype=dtype, device=device), torch.tensor([1, 2, 3, 3], dtype=dtype, device=device),
) )
@ -387,15 +393,15 @@ class TestIndexing(TestCase):
reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device) reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device)
) )
self.assertEqual( self.assertEqual(
reference[[ri([0, 0]), ri([0, 1])]], reference[(ri([0, 0]), ri([0, 1]))],
torch.tensor([0, 4], dtype=dtype, device=device), torch.tensor([0, 4], dtype=dtype, device=device),
) )
self.assertEqual( self.assertEqual(
reference[[ri([0, 1, 1, 0, 3]), ri([1])]], reference[(ri([0, 1, 1, 0, 3]), ri([1]))],
torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device), torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),
) )
self.assertEqual( self.assertEqual(
reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
torch.tensor([0, 4, 1, 1], dtype=dtype, device=device), torch.tensor([0, 4, 1, 1], dtype=dtype, device=device),
) )
@ -446,7 +452,9 @@ class TestIndexing(TestCase):
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device) strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]) strided.set_(
reference.untyped_storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]
)
self.assertEqual( self.assertEqual(
strided[ri([0, 1]), ri([0])], strided[ri([0, 1]), ri([0])],
@ -463,15 +471,15 @@ class TestIndexing(TestCase):
strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device) strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device)
) )
self.assertEqual( self.assertEqual(
strided[[ri([0, 0]), ri([0, 3])]], strided[(ri([0, 0]), ri([0, 3]))],
torch.tensor([1, 7], dtype=dtype, device=device), torch.tensor([1, 7], dtype=dtype, device=device),
) )
self.assertEqual( self.assertEqual(
strided[[ri([1]), ri([0, 1, 1, 0, 3])]], strided[(ri([1]), ri([0, 1, 1, 0, 3]))],
torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device), torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device),
) )
self.assertEqual( self.assertEqual(
strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], strided[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))],
torch.tensor([1, 3, 9, 9], dtype=dtype, device=device), torch.tensor([1, 3, 9, 9], dtype=dtype, device=device),
) )
@ -502,7 +510,9 @@ class TestIndexing(TestCase):
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device) strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) strided.set_(
reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
)
self.assertEqual( self.assertEqual(
strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device) strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device)
) )
@ -513,7 +523,9 @@ class TestIndexing(TestCase):
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device) strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) strided.set_(
reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
)
self.assertEqual( self.assertEqual(
strided[ri([0, 1]), ri([1, 0])], strided[ri([0, 1]), ri([1, 0])],
torch.tensor([11, 17], dtype=dtype, device=device), torch.tensor([11, 17], dtype=dtype, device=device),
@ -528,7 +540,9 @@ class TestIndexing(TestCase):
reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
strided = torch.tensor((), dtype=dtype, device=device) strided = torch.tensor((), dtype=dtype, device=device)
strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) strided.set_(
reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]
)
rows = ri([[0], [1]]) rows = ri([[0], [1]])
columns = ri([[0, 1], [0, 1]]) columns = ri([[0, 1], [0, 1]])
@ -642,19 +656,19 @@ class TestIndexing(TestCase):
indices_to_test = [ indices_to_test = [
# grab the second, fourth columns # grab the second, fourth columns
[slice(None), [1, 3]], (slice(None), [1, 3]),
# first, third rows, # first, third rows,
[[0, 2], slice(None)], ([0, 2], slice(None)),
# weird shape # weird shape
[slice(None), [[0, 1], [2, 3]]], (slice(None), [[0, 1], [2, 3]]),
# negatives # negatives
[[-1], [0]], ([-1], [0]),
[[0, 2], [-1]], ([0, 2], [-1]),
[slice(None), [-1]], (slice(None), [-1]),
] ]
# only test dupes on gets # only test dupes on gets
get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] get_indices_to_test = indices_to_test + [(slice(None), [0, 1, 1, 2, 2])]
for indexer in get_indices_to_test: for indexer in get_indices_to_test:
assert_get_eq(reference, indexer) assert_get_eq(reference, indexer)
@ -668,46 +682,46 @@ class TestIndexing(TestCase):
reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5) reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5)
indices_to_test = [ indices_to_test = [
[slice(None), slice(None), [0, 3, 4]], (slice(None), slice(None), (0, 3, 4)),
[slice(None), [2, 4, 5, 7], slice(None)], (slice(None), (2, 4, 5, 7), slice(None)),
[[2, 3], slice(None), slice(None)], ((2, 3), slice(None), slice(None)),
[slice(None), [0, 2, 3], [1, 3, 4]], (slice(None), (0, 2, 3), (1, 3, 4)),
[slice(None), [0], [1, 2, 4]], (slice(None), (0,), (1, 2, 4)),
[slice(None), [0, 1, 3], [4]], (slice(None), (0, 1, 3), (4,)),
[slice(None), [[0, 1], [1, 0]], [[2, 3]]], (slice(None), ((0, 1), (1, 0)), ((2, 3),)),
[slice(None), [[0, 1], [2, 3]], [[0]]], (slice(None), ((0, 1), (2, 3)), ((0,),)),
[slice(None), [[5, 6]], [[0, 3], [4, 4]]], (slice(None), ((5, 6),), ((0, 3), (4, 4))),
[[0, 2, 3], [1, 3, 4], slice(None)], ((0, 2, 3), (1, 3, 4), slice(None)),
[[0], [1, 2, 4], slice(None)], ((0,), (1, 2, 4), slice(None)),
[[0, 1, 3], [4], slice(None)], ((0, 1, 3), (4,), slice(None)),
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], (((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)),
[[[0, 1], [1, 0]], [[2, 3]], slice(None)], (((0, 1), (1, 0)), ((2, 3),), slice(None)),
[[[0, 1], [2, 3]], [[0]], slice(None)], (((0, 1), (2, 3)), ((0,),), slice(None)),
[[[2, 1]], [[0, 3], [4, 4]], slice(None)], (((2, 1),), ((0, 3), (4, 4)), slice(None)),
[[[2]], [[0, 3], [4, 1]], slice(None)], (((2,),), ((0, 3), (4, 1)), slice(None)),
# non-contiguous indexing subspace # non-contiguous indexing subspace
[[0, 2, 3], slice(None), [1, 3, 4]], ((0, 2, 3), slice(None), (1, 3, 4)),
# [...] # [...]
# less dim, ellipsis # less dim, ellipsis
[[0, 2]], ((0, 2),),
[[0, 2], slice(None)], ((0, 2), slice(None)),
[[0, 2], Ellipsis], ((0, 2), Ellipsis),
[[0, 2], slice(None), Ellipsis], ((0, 2), slice(None), Ellipsis),
[[0, 2], Ellipsis, slice(None)], ((0, 2), Ellipsis, slice(None)),
[[0, 2], [1, 3]], ((0, 2), (1, 3)),
[[0, 2], [1, 3], Ellipsis], ((0, 2), (1, 3), Ellipsis),
[Ellipsis, [1, 3], [2, 3]], (Ellipsis, (1, 3), (2, 3)),
[Ellipsis, [2, 3, 4]], (Ellipsis, (2, 3, 4)),
[Ellipsis, slice(None), [2, 3, 4]], (Ellipsis, slice(None), (2, 3, 4)),
[slice(None), Ellipsis, [2, 3, 4]], (slice(None), Ellipsis, (2, 3, 4)),
# ellipsis counts for nothing # ellipsis counts for nothing
[Ellipsis, slice(None), slice(None), [0, 3, 4]], (Ellipsis, slice(None), slice(None), (0, 3, 4)),
[slice(None), Ellipsis, slice(None), [0, 3, 4]], (slice(None), Ellipsis, slice(None), (0, 3, 4)),
[slice(None), slice(None), Ellipsis, [0, 3, 4]], (slice(None), slice(None), Ellipsis, (0, 3, 4)),
[slice(None), slice(None), [0, 3, 4], Ellipsis], (slice(None), slice(None), (0, 3, 4), Ellipsis),
[Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], (Ellipsis, ((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)),
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], (((0, 1), (1, 0)), ((2, 1), (3, 5)), Ellipsis, slice(None)),
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], (((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None), Ellipsis),
] ]
for indexer in indices_to_test: for indexer in indices_to_test:
@ -720,65 +734,65 @@ class TestIndexing(TestCase):
reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6) reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)
indices_to_test = [ indices_to_test = [
[slice(None), slice(None), slice(None), [0, 3, 4]], (slice(None), slice(None), slice(None), (0, 3, 4)),
[slice(None), slice(None), [2, 4, 5, 7], slice(None)], (slice(None), slice(None), (2, 4, 5, 7), slice(None)),
[slice(None), [2, 3], slice(None), slice(None)], (slice(None), (2, 3), slice(None), slice(None)),
[[1, 2], slice(None), slice(None), slice(None)], ((1, 2), slice(None), slice(None), slice(None)),
[slice(None), slice(None), [0, 2, 3], [1, 3, 4]], (slice(None), slice(None), (0, 2, 3), (1, 3, 4)),
[slice(None), slice(None), [0], [1, 2, 4]], (slice(None), slice(None), (0,), (1, 2, 4)),
[slice(None), slice(None), [0, 1, 3], [4]], (slice(None), slice(None), (0, 1, 3), (4,)),
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], (slice(None), slice(None), ((0, 1), (1, 0)), ((2, 3),)),
[slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], (slice(None), slice(None), ((0, 1), (2, 3)), ((0,),)),
[slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], (slice(None), slice(None), ((5, 6),), ((0, 3), (4, 4))),
[slice(None), [0, 2, 3], [1, 3, 4], slice(None)], (slice(None), (0, 2, 3), (1, 3, 4), slice(None)),
[slice(None), [0], [1, 2, 4], slice(None)], (slice(None), (0,), (1, 2, 4), slice(None)),
[slice(None), [0, 1, 3], [4], slice(None)], (slice(None), (0, 1, 3), (4,), slice(None)),
[slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], (slice(None), ((0, 1), (3, 4)), ((2, 3), (0, 1)), slice(None)),
[slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], (slice(None), ((0, 1), (3, 4)), ((2, 3),), slice(None)),
[slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], (slice(None), ((0, 1), (3, 2)), ((0,),), slice(None)),
[slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], (slice(None), ((2, 1),), ((0, 3), (6, 4)), slice(None)),
[slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], (slice(None), ((2,),), ((0, 3), (4, 2)), slice(None)),
[[0, 1, 2], [1, 3, 4], slice(None), slice(None)], ((0, 1, 2), (1, 3, 4), slice(None), slice(None)),
[[0], [1, 2, 4], slice(None), slice(None)], ((0,), (1, 2, 4), slice(None), slice(None)),
[[0, 1, 2], [4], slice(None), slice(None)], ((0, 1, 2), (4,), slice(None), slice(None)),
[[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], (((0, 1), (0, 2)), ((2, 4), (1, 5)), slice(None), slice(None)),
[[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], (((0, 1), (1, 2)), ((2, 0),), slice(None), slice(None)),
[[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], (((2, 2),), ((0, 3), (4, 5)), slice(None), slice(None)),
[[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], (((2,),), ((0, 3), (4, 5)), slice(None), slice(None)),
[slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], (slice(None), (3, 4, 6), (0, 2, 3), (1, 3, 4)),
[slice(None), [2, 3, 4], [1, 3, 4], [4]], (slice(None), (2, 3, 4), (1, 3, 4), (4,)),
[slice(None), [0, 1, 3], [4], [1, 3, 4]], (slice(None), (0, 1, 3), (4,), (1, 3, 4)),
[slice(None), [6], [0, 2, 3], [1, 3, 4]], (slice(None), (6,), (0, 2, 3), (1, 3, 4)),
[slice(None), [2, 3, 5], [3], [4]], (slice(None), (2, 3, 5), (3,), (4,)),
[slice(None), [0], [4], [1, 3, 4]], (slice(None), (0,), (4,), (1, 3, 4)),
[slice(None), [6], [0, 2, 3], [1]], (slice(None), (6,), (0, 2, 3), (1,)),
[slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], (slice(None), ((0, 3), (3, 6)), ((0, 1), (1, 3)), ((5, 3), (1, 2))),
[[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], ((2, 2, 1), (0, 2, 3), (1, 3, 4), slice(None)),
[[2, 0, 1], [1, 2, 3], [4], slice(None)], ((2, 0, 1), (1, 2, 3), (4,), slice(None)),
[[0, 1, 2], [4], [1, 3, 4], slice(None)], ((0, 1, 2), (4,), (1, 3, 4), slice(None)),
[[0], [0, 2, 3], [1, 3, 4], slice(None)], ((0,), (0, 2, 3), (1, 3, 4), slice(None)),
[[0, 2, 1], [3], [4], slice(None)], ((0, 2, 1), (3,), (4,), slice(None)),
[[0], [4], [1, 3, 4], slice(None)], ((0,), (4,), (1, 3, 4), slice(None)),
[[1], [0, 2, 3], [1], slice(None)], ((1,), (0, 2, 3), (1,), slice(None)),
[[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], (((1, 2), (1, 2)), ((0, 1), (2, 3)), ((2, 3), (3, 5)), slice(None)),
# less dim, ellipsis # less dim, ellipsis
[Ellipsis, [0, 3, 4]], (Ellipsis, (0, 3, 4)),
[Ellipsis, slice(None), [0, 3, 4]], (Ellipsis, slice(None), (0, 3, 4)),
[Ellipsis, slice(None), slice(None), [0, 3, 4]], (Ellipsis, slice(None), slice(None), (0, 3, 4)),
[slice(None), Ellipsis, [0, 3, 4]], (slice(None), Ellipsis, (0, 3, 4)),
[slice(None), slice(None), Ellipsis, [0, 3, 4]], (slice(None), slice(None), Ellipsis, (0, 3, 4)),
[slice(None), [0, 2, 3], [1, 3, 4]], (slice(None), (0, 2, 3), (1, 3, 4)),
[slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], (slice(None), (0, 2, 3), (1, 3, 4), Ellipsis),
[Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], (Ellipsis, (0, 2, 3), (1, 3, 4), slice(None)),
[[0], [1, 2, 4]], ((0,), (1, 2, 4)),
[[0], [1, 2, 4], slice(None)], ((0,), (1, 2, 4), slice(None)),
[[0], [1, 2, 4], Ellipsis], ((0,), (1, 2, 4), Ellipsis),
[[0], [1, 2, 4], Ellipsis, slice(None)], ((0,), (1, 2, 4), Ellipsis, slice(None)),
[[1]], ((1,),),
[[0, 2, 1], [3], [4]], ((0, 2, 1), (3,), (4,)),
[[0, 2, 1], [3], [4], slice(None)], ((0, 2, 1), (3,), (4,), slice(None)),
[[0, 2, 1], [3], [4], Ellipsis], ((0, 2, 1), (3,), (4,), Ellipsis),
[Ellipsis, [0, 2, 1], [3], [4]], (Ellipsis, (0, 2, 1), (3,), (4,)),
] ]
for indexer in indices_to_test: for indexer in indices_to_test:
@ -786,8 +800,8 @@ class TestIndexing(TestCase):
assert_set_eq(reference, indexer, 1333) assert_set_eq(reference, indexer, 1333)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
indices_to_test += [ indices_to_test += [
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], (slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]),
[slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], (slice(None), slice(None), [[2]], [[0, 3], [4, 4]]),
] ]
for indexer in indices_to_test: for indexer in indices_to_test:
assert_get_eq(reference, indexer) assert_get_eq(reference, indexer)
@ -866,6 +880,21 @@ class TestIndexing(TestCase):
) )
self.assertEqual(len(w), 1) self.assertEqual(len(w), 1)
def test_list_indices(self, device):
N = 1000
t = torch.randn(N, device=device)
# Set window size
W = 10
# Generate a list of lists, containing overlapping window indices
indices = [range(i, i + W) for i in range(0, N - W)]
for i in [len(indices), 100, 32]:
windowed_data = t[indices[:i]]
self.assertEqual(windowed_data.shape, (i, W))
with self.assertRaisesRegex(IndexError, "too many indices"):
windowed_data = t[indices[:31]]
def test_bool_indices_accumulate(self, device): def test_bool_indices_accumulate(self, device):
mask = torch.zeros(size=(10,), dtype=torch.bool, device=device) mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
y = torch.ones(size=(10, 10), device=device) y = torch.ones(size=(10, 10), device=device)

View File

@ -60,20 +60,22 @@ def _hermitian_conj(x, dim):
""" """
out = torch.empty_like(x) out = torch.empty_like(x)
mid = (x.size(dim) - 1) // 2 mid = (x.size(dim) - 1) // 2
idx = [slice(None)] * out.dim() idx = tuple([slice(None)] * out.dim())
idx_center = list(idx)
idx_center[dim] = 0
out[idx] = x[idx] out[idx] = x[idx]
idx_neg = list(idx) idx_neg = list(idx)
idx_neg[dim] = slice(-mid, None) idx_neg[dim] = slice(-mid, None)
idx_pos = idx idx_neg = tuple(idx_neg)
idx_pos = list(idx)
idx_pos[dim] = slice(1, mid + 1) idx_pos[dim] = slice(1, mid + 1)
idx_pos = tuple(idx_pos)
out[idx_pos] = x[idx_neg].flip(dim) out[idx_pos] = x[idx_neg].flip(dim)
out[idx_neg] = x[idx_pos].flip(dim) out[idx_neg] = x[idx_pos].flip(dim)
if (2 * mid + 1 < x.size(dim)): if (2 * mid + 1 < x.size(dim)):
idx = list(idx)
idx[dim] = mid + 1 idx[dim] = mid + 1
idx = tuple(idx)
out[idx] = x[idx] out[idx] = x[idx]
return out.conj() return out.conj()
@ -518,6 +520,7 @@ class TestFFT(TestCase):
lastdim_size = input.size(lastdim) // 2 + 1 lastdim_size = input.size(lastdim) // 2 + 1
idx = [slice(None)] * input_ndim idx = [slice(None)] * input_ndim
idx[lastdim] = slice(0, lastdim_size) idx[lastdim] = slice(0, lastdim_size)
idx = tuple(idx)
input = input[idx] input = input[idx]
s = [shape[dim] for dim in actual_dims] s = [shape[dim] for dim in actual_dims]
@ -558,6 +561,7 @@ class TestFFT(TestCase):
lastdim_size = expect.size(lastdim) // 2 + 1 lastdim_size = expect.size(lastdim) // 2 + 1
idx = [slice(None)] * input_ndim idx = [slice(None)] * input_ndim
idx[lastdim] = slice(0, lastdim_size) idx[lastdim] = slice(0, lastdim_size)
idx = tuple(idx)
expect = expect[idx] expect = expect[idx]
actual = torch.fft.ihfftn(input, dim=dim, norm="ortho") actual = torch.fft.ihfftn(input, dim=dim, norm="ortho")

View File

@ -941,7 +941,7 @@ def choose(
] ]
idx_list[0] = a idx_list[0] = a
return choices[idx_list].squeeze(0) return choices[tuple(idx_list)].squeeze(0)
# ### unique et al. ### # ### unique et al. ###

View File

@ -25,6 +25,7 @@
#include <ATen/TracerMode.h> #include <ATen/TracerMode.h>
#include <ATen/core/LegacyTypeDispatch.h> #include <ATen/core/LegacyTypeDispatch.h>
#include <c10/core/TensorOptions.h> #include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h> #include <c10/util/irange.h>
#include <c10/core/Layout.h> #include <c10/core/Layout.h>
@ -292,6 +293,13 @@ static bool treatSequenceAsTuple(PyObject* index) {
} }
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (n >= 32) { if (n >= 32) {
TORCH_WARN(
"Using a non-tuple sequence for "
"multidimensional indexing is deprecated and will be changed in "
"pytorch 2.9; use x[tuple(seq)] instead of "
"x[seq]. In pytorch 2.9 this will be interpreted as tensor index, "
"x[torch.tensor(seq)], which will result either in an error or a "
"different result");
return false; return false;
} }
for (Py_ssize_t i = 0; i < n; i++) { for (Py_ssize_t i = 0; i < n; i++) {

View File

@ -592,7 +592,7 @@ def _distribute_tensors(
] ]
if local_state.is_meta: if local_state.is_meta:
# Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost. # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost.
local_tensor = full_tensor[slices].detach().clone() local_tensor = full_tensor[tuple(slices)].detach().clone()
# TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example,
# one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)).
ret = DTensor.from_local( ret = DTensor.from_local(
@ -605,7 +605,7 @@ def _distribute_tensors(
else: else:
ret = local_state ret = local_state
# Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint. # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint.
ret.to_local().copy_(full_tensor[slices]) ret.to_local().copy_(full_tensor[tuple(slices)])
local_state_dict[key] = ret local_state_dict[key] = ret

View File

@ -394,6 +394,8 @@ class PruningContainer(BasePruningMethod):
raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}") raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")
# compute the new mask on the unpruned slice of the tensor t # compute the new mask on the unpruned slice of the tensor t
if isinstance(slc, list):
slc = tuple(slc)
partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)
@ -625,6 +627,7 @@ class RandomStructured(BasePruningMethod):
mask = torch.zeros_like(t) mask = torch.zeros_like(t)
slc = [slice(None)] * len(t.shape) slc = [slice(None)] * len(t.shape)
slc[dim] = channel_mask slc[dim] = channel_mask
slc = tuple(slc)
mask[slc] = 1 mask[slc] = 1
return mask return mask
@ -739,6 +742,7 @@ class LnStructured(BasePruningMethod):
# replace a None at position=dim with indices # replace a None at position=dim with indices
# e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
slc[dim] = indices slc[dim] = indices
slc = tuple(slc)
# use slc to slice mask and replace all its entries with 1s # use slc to slice mask and replace all its entries with 1s
# e.g.: mask[:, :, [0, 2, 3]] = 1 # e.g.: mask[:, :, [0, 2, 3]] = 1
mask[slc] = 1 mask[slc] = 1

View File

@ -124,7 +124,7 @@ def multidim_slicer(dims, slices, *tensors):
for d, d_slice in zip(dims, slices): for d, d_slice in zip(dims, slices):
if d is not None: if d is not None:
s[d] = d_slice s[d] = d_slice
yield t[s] yield t[tuple(s)]
def ptr_stride_extractor(*tensors): def ptr_stride_extractor(*tensors):

View File

@ -3266,17 +3266,17 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
test_args = [ test_args = [
([1, 2],), ([1, 2],),
(slice(0, 3),), (slice(0, 3),),
([slice(0, 3), 1],), ((slice(0, 3), 1),),
([[0, 2, 3], [1, 3, 3], [0, 0, 2]],), (([0, 2, 3], [1, 3, 3], [0, 0, 2]),),
([[0, 0, 3], [1, 1, 3], [0, 0, 2]],), (([0, 0, 3], [1, 1, 3], [0, 0, 2]),),
([slice(None), slice(None), [0, 3]],), ((slice(None), slice(None), [0, 3]),),
([slice(None), [0, 3], slice(None)],), ((slice(None), [0, 3], slice(None)),),
([[0, 3], slice(None), slice(None)],), (([0, 3], slice(None), slice(None)),),
([[0, 3], [1, 2], slice(None)],), (([0, 3], [1, 2], slice(None)),),
([[0, 3], ],), (([0, 3], ),),
([[0, 3], slice(None)],), (([0, 3], slice(None)),),
([[0, 3], Ellipsis],), (([0, 3], Ellipsis),),
([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],), (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),),
(index_variable(2, S, device=device),), (index_variable(2, S, device=device),),
(mask_not_all_zeros((S,)),), (mask_not_all_zeros((S,)),),
] ]
@ -3284,7 +3284,7 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
for args in test_args: for args in test_args:
yield SampleInput(make_arg((S, S, S)), args=args) yield SampleInput(make_arg((S, S, S)), args=args)
yield SampleInput(make_arg((S, S, S, S)), args=([slice(None), [0, 1], slice(None), [0, 1]],)) yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),))
def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

View File

@ -290,7 +290,7 @@ class FuzzedTensor:
raw_tensor = raw_tensor.permute(tuple(np.argsort(order))) raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))
slices = [slice(0, size * step, step) for size, step in zip(size, steps)] slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
tensor = raw_tensor[slices] tensor = raw_tensor[tuple(slices)]
properties = { properties = {
"numel": int(tensor.numel()), "numel": int(tensor.numel()),