mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix to keep stride in return_and_correct_aliasing() (#117860)
Fixes #117794
Fix tripped the assert here: 86dedebeaf/torch/utils/_python_dispatch.py (L216)
From investigation: I found that functionalization of an in-place op (`mul_` in this test case) results in the strides of `TwoTensor`'s `a` / `b` components being mutated to be contiguous. This is not reflected in the outer tensor, causing the assert to be tripped.
After discussion with Brian, I address this in this PR by disallowing input mutations on non-contiguous tensor subclass inputs for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117860
Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
fa77829126
commit
e7eab2f07e
@ -917,6 +917,13 @@ def forward(self, x_a_1, x_b_1, y_1):
|
||||
return (mul, mul_1, add)
|
||||
""")
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/117794
|
||||
def test_return_and_correct_aliasing_gives_correct_stride(self):
|
||||
t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
|
||||
x = torch.randn(2, 2)
|
||||
# slicing should result in the same stride for TwoTensor as a dense tensor would give
|
||||
self.assertEqual(t[:, 0].stride(), x[:, 0].stride())
|
||||
|
||||
def test_make_wrapper_subclass_propagates_metadata(self) -> None:
|
||||
class WrapperTensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
Reference in New Issue
Block a user