[Bugfix] Match eager stride semantics for cloned tensors with preserve_format in compile (#163017)

Fixes #161010 by making `clone_meta` match the semantics of strides for eager mode.

This is:
  * Case 1: Tensor is_non_overlapping_and_dense; in this case, stride should match input tensor stride
  * Case 2: Otherwise, stride should be contiguous computed from input tensor using `compute_elementwise_output_strides`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163017
Approved by: https://github.com/williamwen42, https://github.com/xmfan

Co-authored-by: morrison-turnansky <mturnans@redhat.com>
This commit is contained in:
Lucas Kabela
2025-09-19 19:41:29 +00:00
committed by PyTorch MergeBot
parent bc7b17a36d
commit 979e10f7d6
3 changed files with 86 additions and 9 deletions

View File

@ -342,6 +342,16 @@ $1: f32[2] = torch._ops.prims.sin.default($0)""")
x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
x + 1
def test_clone_meta_stride_preservation_dense(self):
tensor = torch.randn(1, 5).t()
meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
self.assertEqual(tensor.stride(), meta_clone.stride())
def test_clone_meta_stride_preservation_sparse(self):
tensor = torch.arange(12).float().view(3, 4)[1:, ::2]
meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
self.assertEqual(tensor.contiguous().stride(), meta_clone.stride())
def test_check_deprecation_warning(self):
with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
torch._prims_common.check(True, lambda: 'message')