[Bugfix] Match eager stride semantics for cloned tensors with preserve_format in compile (#163017)

Fixes #161010 by making `clone_meta` match the semantics of strides for eager mode.

This is:
  * Case 1: Tensor is_non_overlapping_and_dense; in this case, stride should match input tensor stride
  * Case 2: Otherwise, stride should be contiguous computed from input tensor using `compute_elementwise_output_strides`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163017
Approved by: https://github.com/williamwen42, https://github.com/xmfan

Co-authored-by: morrison-turnansky <mturnans@redhat.com>
This commit is contained in:
Lucas Kabela
2025-09-19 19:41:29 +00:00
committed by PyTorch MergeBot
parent bc7b17a36d
commit 979e10f7d6
3 changed files with 86 additions and 9 deletions

View File

@ -4822,6 +4822,67 @@ class ReproTests(torch._dynamo.test_case.TestCase):
"encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
):
f_compiled(a)
# See https://github.com/pytorch/pytorch/issues/161010
def test_preserve_stride_with_clone(self) -> None:
A = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
B = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
def fn(
src: torch.Tensor, count: torch.Tensor
) -> tuple[tuple[int, ...], tuple[int, ...]]:
Q, R = torch.linalg.qr(src)
rhs = torch.ones(Q.shape[0], 1, device=src.device)
a = torch.linalg.solve_triangular(R, Q.T @ rhs, upper=True)
cloned = a.clone(memory_format=torch.preserve_format)
return a.stride(), cloned.stride()
a_stride, cloned_stride = fn(A, torch.zeros(1))
self.assertEqual(
a_stride,
cloned_stride,
f"Strides should match in eager: {a_stride} against {cloned_stride}",
)
compiled_a_stride, compiled_cloned_stride = torch.compile(fn, backend="eager")(
B, torch.zeros(1)
)
self.assertEqual(
compiled_a_stride,
compiled_cloned_stride,
f"Strides should match in eager: {compiled_a_stride} against {compiled_cloned_stride}",
)
# Extension of https://github.com/pytorch/pytorch/issues/161010
# in the non memory dense case
def test_clone_not_memory_dense(self):
def foo() -> torch.Tensor:
x = torch.randn(10, 8).t()[::2, ::2]
y = x.clone()
return y
y = foo()
self.assertEqual(
y.stride(),
(1, 4),
"Reference eager implementation should have stride (1, 4)",
)
y = torch.compile(foo, backend="eager")()
self.assertEqual(
y.stride(), (1, 4), "Compile with eager backend should have stride (1, 4)"
)
y = torch.compile(foo, backend="aot_eager")()
self.assertEqual(
y.stride(),
(1, 4),
"Compile with aot_eager backend should have stride (1, 4)",
)
y = torch.compile(foo, backend="inductor")()
self.assertEqual(
y.stride(),
(1, 4),
"Compile with inductor backend should have stride (1, 4)",
)
# https://github.com/pytorch/pytorch/issues/146598
@unittest.expectedFailure

View File

@ -342,6 +342,16 @@ $1: f32[2] = torch._ops.prims.sin.default($0)""")
x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
x + 1
def test_clone_meta_stride_preservation_dense(self):
tensor = torch.randn(1, 5).t()
meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
self.assertEqual(tensor.stride(), meta_clone.stride())
def test_clone_meta_stride_preservation_sparse(self):
tensor = torch.arange(12).float().view(3, 4)[1:, ::2]
meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
self.assertEqual(tensor.contiguous().stride(), meta_clone.stride())
def test_check_deprecation_warning(self):
with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
torch._prims_common.check(True, lambda: 'message')

View File

@ -692,12 +692,18 @@ def _clone_meta(
device=input.device,
memory_format=memory_format,
)
else:
# Match eager behavior by preserving strides for non_overlapping_and_dense tensors
# If not, eager clone creates contiguous strides
computed_stride = None
if utils.is_non_overlapping_and_dense(input):
computed_stride = input.stride()
else:
computed_stride = utils.compute_elementwise_output_strides(input)
# memory_format == torch.preserve_format
strides = utils.compute_elementwise_output_strides(input)
return torch.empty_strided(
input.shape,
strides,
computed_stride,
dtype=input.dtype,
layout=input.layout,
device=input.device,