diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 46f0e67a5b7b..e9cc4e5026f5 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -4822,6 +4822,67 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             "encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
         ):
             f_compiled(a)
+
+    # See https://github.com/pytorch/pytorch/issues/161010
+    def test_preserve_stride_with_clone(self) -> None:
+        A = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
+        B = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
+
+        def fn(
+            src: torch.Tensor, count: torch.Tensor
+        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
+            Q, R = torch.linalg.qr(src)
+            rhs = torch.ones(Q.shape[0], 1, device=src.device)
+            a = torch.linalg.solve_triangular(R, Q.T @ rhs, upper=True)
+            cloned = a.clone(memory_format=torch.preserve_format)
+            return a.stride(), cloned.stride()
+
+        a_stride, cloned_stride = fn(A, torch.zeros(1))
+        self.assertEqual(
+            a_stride,
+            cloned_stride,
+            f"Strides should match in eager: {a_stride} against {cloned_stride}",
+        )
+
+        compiled_a_stride, compiled_cloned_stride = torch.compile(fn, backend="eager")(
+            B, torch.zeros(1)
+        )
+        self.assertEqual(
+            compiled_a_stride,
+            compiled_cloned_stride,
+            f"Strides should match under compile: {compiled_a_stride} against {compiled_cloned_stride}",
+        )
+
+    # Extension of https://github.com/pytorch/pytorch/issues/161010
+    # to the non-memory-dense case
+    def test_clone_not_memory_dense(self):
+        def foo() -> torch.Tensor:
+            x = torch.randn(10, 8).t()[::2, ::2]
+            y = x.clone()
+            return y
+
+        y = foo()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Reference eager implementation should have stride (1, 4)",
+        )
+        y = torch.compile(foo, backend="eager")()
+        self.assertEqual(
+            y.stride(), (1, 4), "Compile with eager backend should have stride (1, 4)"
+        )
+        y = torch.compile(foo, backend="aot_eager")()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Compile with aot_eager backend should have stride (1, 4)",
+        )
+        y = torch.compile(foo, backend="inductor")()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Compile with inductor backend should have stride (1, 4)",
+        )
 
     # https://github.com/pytorch/pytorch/issues/146598
     @unittest.expectedFailure
diff --git a/test/test_prims.py b/test/test_prims.py
index 58ed8a7dd758..e528a1eb2e4e 100644
--- a/test/test_prims.py
+++ b/test/test_prims.py
@@ -342,6 +342,16 @@ $1: f32[2] = torch._ops.prims.sin.default($0)""")
 
         x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
         x + 1
 
+    def test_clone_meta_stride_preservation_dense(self):
+        tensor = torch.randn(1, 5).t()
+        meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
+        self.assertEqual(tensor.stride(), meta_clone.stride())
+
+    def test_clone_meta_stride_preservation_sparse(self):
+        tensor = torch.arange(12).float().view(3, 4)[1:, ::2]
+        meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
+        self.assertEqual(tensor.contiguous().stride(), meta_clone.stride())
+
     def test_check_deprecation_warning(self):
         with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
             torch._prims_common.check(True, lambda: 'message')
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 5fef517dc59f..034263ea4849 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -692,16 +692,22 @@ def _clone_meta(
             device=input.device,
             memory_format=memory_format,
         )
+    else:
+        # Match eager behavior: preserve the input strides for
+        # non_overlapping_and_dense tensors; otherwise eager clone produces
+        # dense strides that follow the input's stride order.
+        if utils.is_non_overlapping_and_dense(input):
+            computed_stride = input.stride()
+        else:
+            computed_stride = utils.compute_elementwise_output_strides(input)
 
-    # memory_format == torch.preserve_format
-    strides = utils.compute_elementwise_output_strides(input)
-    return torch.empty_strided(
-        input.shape,
-        strides,
-        dtype=input.dtype,
-        layout=input.layout,
-        device=input.device,
-    )
+        return torch.empty_strided(
+            input.shape,
+            computed_stride,
+            dtype=input.dtype,
+            layout=input.layout,
+            device=input.device,
+        )
 
 
 clone = _make_prim(
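Note (not part of the patch): a minimal eager-mode sketch of the stride semantics the new `_clone_meta` branch is matching, using only public torch APIs. The tensors mirror the ones used in the tests above.

    import torch

    # Non-overlapping-and-dense but non-contiguous: a transposed tensor.
    # Eager clone with torch.preserve_format keeps the input strides.
    t = torch.randn(1, 5).t()  # shape (5, 1), stride (1, 5)
    assert t.clone(memory_format=torch.preserve_format).stride() == t.stride()

    # Not memory-dense: a strided view. Eager clone does not keep the input
    # strides (2, 16); it produces dense strides in the input's stride order,
    # which is what compute_elementwise_output_strides returns.
    v = torch.randn(10, 8).t()[::2, ::2]  # shape (4, 5), stride (2, 16)
    assert v.clone(memory_format=torch.preserve_format).stride() == (1, 4)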