fix stride compare failed when size value equal to one in ForeachUtils.h (#134546)

When size value equal to one, tensor strides value need be skipped to compare.
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134546
Approved by: https://github.com/janeyx99
This commit is contained in:
Shan19900305
2024-09-19 18:43:39 +00:00
committed by PyTorch MergeBot
parent ccca3de0cd
commit 49723a8ff3
2 changed files with 35 additions and 1 deletions

View File

@ -545,6 +545,24 @@ class TestForeach(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/113156
torch._foreach_mul_(tensors, 1)
@onlyCUDA
@dtypes(torch.float32)
def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
# default tensor stride is (9, 9, 3, 1).
tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
strided_tensor = torch.ones(
(2, 1, 3, 3), device=device, dtype=dtype
).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
left_inputs = [tensor, strided_tensor]
right_inputs = [strided_tensor, tensor]
compare_result = tensor + strided_tensor
foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
out = foreach_add_check_(
(left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
)
for res in out:
self.assertEqual(res, compare_result)
@ops(
filter(lambda op: op.supports_out, foreach_binary_op_db),
dtypes=OpDTypes.supported,