Fix the check for can_use_expanded_index_path (#140351)

Fixes #129093

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140351
Approved by: https://github.com/mingfeima, https://github.com/cpuhrsch
This commit is contained in:
CaoE
2024-11-15 05:52:19 +00:00
committed by PyTorch MergeBot
parent 8043e67026
commit 6c0a2d8bbf
2 changed files with 26 additions and 1 deletions

View File

@ -349,6 +349,26 @@ class TestScatterGather(TestCase):
self.assertEqual(out, out2)
# test unsqueezed index
# expanded_index kernel can not handle the case:
# the size > 1 and stride == 1 at a dimension.
# for example: the index with size of [1, 8, 7], stride of [1, 1, 0].
# see https://github.com/pytorch/pytorch/issues/129093
def unsqueeze_helper(idx, dim):
if dim == 2:
return idx.unsqueeze(1).t()
else:
return unsqueeze_helper(idx, dim - 1).unsqueeze(dim - 1)
idx = torch.randint(0, dim_size, (input.shape[1],))
idx = unsqueeze_helper(idx, len(input_size))
expanded_shape[0] = 1
idx = idx.expand(expanded_shape)
idx2 = idx.contiguous()
out = torch.gather(input, 0, idx)
out2 = torch.gather(input2, 0, idx2)
self.assertEqual(out, out2)
helper([50, 17], 100)
helper([50, 1], 100)
helper([50, 8, 7], 100)