mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
fix stride compare failed when size value equal to one in ForeachUtils.h (#134546)
When size value equal to one, tensor strides value need be skipped to compare. @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/134546 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
ccca3de0cd
commit
49723a8ff3
@ -545,6 +545,24 @@ class TestForeach(TestCase):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/113156
|
||||
torch._foreach_mul_(tensors, 1)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32)
|
||||
def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
|
||||
# default tensor stride is (9, 9, 3, 1).
|
||||
tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
|
||||
strided_tensor = torch.ones(
|
||||
(2, 1, 3, 3), device=device, dtype=dtype
|
||||
).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
|
||||
left_inputs = [tensor, strided_tensor]
|
||||
right_inputs = [strided_tensor, tensor]
|
||||
compare_result = tensor + strided_tensor
|
||||
foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
|
||||
out = foreach_add_check_(
|
||||
(left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
|
||||
)
|
||||
for res in out:
|
||||
self.assertEqual(res, compare_result)
|
||||
|
||||
@ops(
|
||||
filter(lambda op: op.supports_out, foreach_binary_op_db),
|
||||
dtypes=OpDTypes.supported,
|
||||
|
Reference in New Issue
Block a user