mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 04:54:55 +08:00
fix norrow_copy correctness issue for non-contiguous input for cpu path(reland) (#91883)
This PR is about re-land https://github.com/pytorch/pytorch/pull/91789. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91883 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
d1cc64b2ac
commit
1892c75a45
@ -1217,7 +1217,9 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t
|
||||
// Should just use narrow_copy_out, but this API is used internally at Meta:
|
||||
// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
|
||||
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
|
||||
auto output = at::empty_like(self);
|
||||
// narrow_copy_dense_cpu_out always resize output's size, so there only create
|
||||
// a zero size tensor.
|
||||
auto output = at::empty({0}, self.options());
|
||||
return narrow_copy_dense_cpu_out(self, dim, start, length, output);
|
||||
}
|
||||
|
||||
|
||||
@ -3542,7 +3542,6 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
xfail('bitwise_left_shift', device_type='cpu'),
|
||||
decorate('bitwise_right_shift', device_type='cpu',
|
||||
decorator=expectedFailureIf(not (IS_MACOS and IS_X86))),
|
||||
xfail('narrow_copy', device_type='cpu'),
|
||||
|
||||
# UBSAN: runtime error: shift exponent -1 is negative
|
||||
decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
|
||||
@ -3721,11 +3720,6 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
xfail('le'),
|
||||
xfail('lt'),
|
||||
xfail('ne'),
|
||||
# AssertionError
|
||||
# Mismatched elements: 18 / 20 (90.0%)
|
||||
# Greatest absolute difference: 14.031710147857666 at index (0, 5) (up to 0.0001 allowed)
|
||||
# Greatest relative difference: 2.9177700113052603 at index (0, 3) (up to 0.0001 allowed)
|
||||
xfail('narrow_copy', device_type='cpu'),
|
||||
# UBSAN: runtime error: 1.27043e+262 is outside the range of representable values of type 'float'
|
||||
decorate('special.zeta', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
|
||||
# RuntimeError: Expected all tensors to be on the same device,
|
||||
|
||||
@ -2971,6 +2971,13 @@ else:
|
||||
sz[d] = 0
|
||||
self.assertEqual(sz, y.size())
|
||||
|
||||
def test_narrow_copy_non_contiguous(self, device):
|
||||
# see https://github.com/pytorch/pytorch/issues/91690.
|
||||
inp = torch.randn(10, 2, device=device).movedim(-1, 0)
|
||||
expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10)
|
||||
actual = torch.narrow_copy(inp, 1, 0, 10)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# FIXME: move to indexing test suite
|
||||
@parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
|
||||
Reference in New Issue
Block a user