mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix the check for can_use_expanded_index_path (#140351)
Fixes #129093 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140351 Approved by: https://github.com/mingfeima, https://github.com/cpuhrsch
This commit is contained in:
@ -349,6 +349,26 @@ class TestScatterGather(TestCase):
|
||||
|
||||
self.assertEqual(out, out2)
|
||||
|
||||
# test unsqueezed index
|
||||
# expanded_index kernel can not handle the case:
|
||||
# the size > 1 and stride == 1 at a dimension.
|
||||
# for example: the index with size of [1, 8, 7], stride of [1, 1, 0].
|
||||
# see https://github.com/pytorch/pytorch/issues/129093
|
||||
def unsqueeze_helper(idx, dim):
|
||||
if dim == 2:
|
||||
return idx.unsqueeze(1).t()
|
||||
else:
|
||||
return unsqueeze_helper(idx, dim - 1).unsqueeze(dim - 1)
|
||||
|
||||
idx = torch.randint(0, dim_size, (input.shape[1],))
|
||||
idx = unsqueeze_helper(idx, len(input_size))
|
||||
expanded_shape[0] = 1
|
||||
idx = idx.expand(expanded_shape)
|
||||
idx2 = idx.contiguous()
|
||||
out = torch.gather(input, 0, idx)
|
||||
out2 = torch.gather(input2, 0, idx2)
|
||||
self.assertEqual(out, out2)
|
||||
|
||||
helper([50, 17], 100)
|
||||
helper([50, 1], 100)
|
||||
helper([50, 8, 7], 100)
|
||||
|
Reference in New Issue
Block a user