fix numpy compatibility for 2d small list indices (#154806)

Will fix #119548 and linked issues once we switch from warning to the new behavior,
but for now, given how much this syntax was used in our test suite, we suspect a silent change will be disruptive.
We will change the behavior after 2.8 branch is cut.
Numpy behavior was changed at least in numpy 1.24 (more than 2 years ago)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154806
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
Natalia Gimelshein
2025-06-04 01:58:52 +00:00
committed by PyTorch MergeBot
parent e2760544fa
commit 34e3930401
13 changed files with 244 additions and 197 deletions

View File

@ -124,7 +124,7 @@ def multidim_slicer(dims, slices, *tensors):
for d, d_slice in zip(dims, slices):
if d is not None:
s[d] = d_slice
yield t[s]
yield t[tuple(s)]
def ptr_stride_extractor(*tensors):