[Bugfix] Match eager stride semantics for cloned tensors with preserve_format in compile (#163017)
Fixes #161010 by making `clone_meta` match the stride semantics of eager mode:

* Case 1: the tensor is non-overlapping and dense; the output stride matches the input tensor's stride.
* Case 2: otherwise, the output gets fresh dense strides computed from the input tensor via `compute_elementwise_output_strides`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163017
Approved by: https://github.com/williamwen42, https://github.com/xmfan

Co-authored-by: morrison-turnansky <mturnans@redhat.com>
committed by PyTorch MergeBot
parent bc7b17a36d
commit 979e10f7d6
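
The eager semantics being matched are easy to observe directly; a minimal sketch (shapes chosen for illustration, mirroring the tests below):

import torch

# Case 1: non-overlapping and dense but not contiguous (a transpose).
# Eager clone(memory_format=torch.preserve_format) keeps the input's strides.
dense = torch.randn(1, 5).t()              # shape (5, 1), strides (1, 5)
print(dense.clone(memory_format=torch.preserve_format).stride())   # (1, 5)

# Case 2: a strided slice leaves gaps in storage, so it is not dense.
# Eager clone discards the input strides and picks a fresh dense layout.
sparse = torch.randn(10, 8).t()[::2, ::2]  # shape (4, 5), strides (2, 16)
print(sparse.clone(memory_format=torch.preserve_format).stride())  # (1, 4)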
@@ -4822,6 +4822,67 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             "encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
         ):
             f_compiled(a)
 
+    # See https://github.com/pytorch/pytorch/issues/161010
+    def test_preserve_stride_with_clone(self) -> None:
+        A = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
+        B = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
+
+        def fn(
+            src: torch.Tensor, count: torch.Tensor
+        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
+            Q, R = torch.linalg.qr(src)
+            rhs = torch.ones(Q.shape[0], 1, device=src.device)
+            a = torch.linalg.solve_triangular(R, Q.T @ rhs, upper=True)
+            cloned = a.clone(memory_format=torch.preserve_format)
+            return a.stride(), cloned.stride()
+
+        a_stride, cloned_stride = fn(A, torch.zeros(1))
+        self.assertEqual(
+            a_stride,
+            cloned_stride,
+            f"Strides should match in eager: {a_stride} against {cloned_stride}",
+        )
+
+        compiled_a_stride, compiled_cloned_stride = torch.compile(fn, backend="eager")(
+            B, torch.zeros(1)
+        )
+        self.assertEqual(
+            compiled_a_stride,
+            compiled_cloned_stride,
+            f"Strides should match under compile: {compiled_a_stride} against {compiled_cloned_stride}",
+        )
+
+    # Extension of https://github.com/pytorch/pytorch/issues/161010
+    # in the non memory dense case
+    def test_clone_not_memory_dense(self):
+        def foo() -> torch.Tensor:
+            x = torch.randn(10, 8).t()[::2, ::2]
+            y = x.clone()
+            return y
+
+        y = foo()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Reference eager implementation should have stride (1, 4)",
+        )
+        y = torch.compile(foo, backend="eager")()
+        self.assertEqual(
+            y.stride(), (1, 4), "Compile with eager backend should have stride (1, 4)"
+        )
+        y = torch.compile(foo, backend="aot_eager")()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Compile with aot_eager backend should have stride (1, 4)",
+        )
+        y = torch.compile(foo, backend="inductor")()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Compile with inductor backend should have stride (1, 4)",
+        )
+
     # https://github.com/pytorch/pytorch/issues/146598
     @unittest.expectedFailure
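
Outside the test suite, the same regression can be checked with a smaller function; a sketch mirroring the test above (on builds without this fix the assertion is expected to fail, since Dynamo specializes the stride() values it traces from the meta implementation):

import torch

def strides_of_clone(x: torch.Tensor) -> tuple[int, ...]:
    return x.clone(memory_format=torch.preserve_format).stride()

t = torch.randn(4, 6).t()   # dense but not contiguous: strides (1, 6)
eager = strides_of_clone(t)
compiled = torch.compile(strides_of_clone, backend="eager")(t)
assert eager == compiled, f"{eager} != {compiled}"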
@@ -342,6 +342,16 @@ $1: f32[2] = torch._ops.prims.sin.default($0)""")
         x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
         x + 1
 
+    def test_clone_meta_stride_preservation_dense(self):
+        tensor = torch.randn(1, 5).t()
+        meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
+        self.assertEqual(tensor.stride(), meta_clone.stride())
+
+    def test_clone_meta_stride_preservation_sparse(self):
+        tensor = torch.arange(12).float().view(3, 4)[1:, ::2]
+        meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
+        self.assertEqual(tensor.contiguous().stride(), meta_clone.stride())
+
     def test_check_deprecation_warning(self):
         with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
             torch._prims_common.check(True, lambda: 'message')
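
The dense/non-dense split that both new tests target comes from the `is_non_overlapping_and_dense` predicate; a quick illustration (assuming the `utils` alias in the hunk below refers to `torch._prims_common`, which exports this helper):

import torch
from torch._prims_common import is_non_overlapping_and_dense

t = torch.randn(3, 4)
print(is_non_overlapping_and_dense(t.t()))   # True: a transpose permutes strides but covers storage densely
print(is_non_overlapping_and_dense(t[::2]))  # False: skipping rows leaves unused gaps in storage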
@@ -692,16 +692,22 @@ def _clone_meta(
             device=input.device,
             memory_format=memory_format,
         )
-
-    # memory_format == torch.preserve_format
-    strides = utils.compute_elementwise_output_strides(input)
-    return torch.empty_strided(
-        input.shape,
-        strides,
-        dtype=input.dtype,
-        layout=input.layout,
-        device=input.device,
-    )
+    else:
+        # Match eager behavior by preserving strides for non_overlapping_and_dense tensors
+        # If not, eager clone creates contiguous strides
+        computed_stride = None
+        if utils.is_non_overlapping_and_dense(input):
+            computed_stride = input.stride()
+        else:
+            computed_stride = utils.compute_elementwise_output_strides(input)
+
+        return torch.empty_strided(
+            input.shape,
+            computed_stride,
+            dtype=input.dtype,
+            layout=input.layout,
+            device=input.device,
+        )
 
 
 clone = _make_prim(
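
For Case 2, `compute_elementwise_output_strides` picks a fresh dense layout whose dimension ordering follows the input's (the fastest-varying dimension stays fastest); a sketch of the observable behavior, assuming the helper is exported from `torch._prims_common`:

import torch
from torch._prims_common import compute_elementwise_output_strides

x = torch.randn(10, 8).t()[::2, ::2]          # shape (4, 5), strides (2, 16): dim 0 varies fastest
print(compute_elementwise_output_strides(x))  # (1, 4): dense layout, dim 0 still fastest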