[Bugfix] Match eager stride semantics for cloned tensors with preserve_format in compile (#163017)
Fixes #161010 by making `clone_meta` match the stride semantics of eager mode:

* Case 1: the tensor is non_overlapping_and_dense; the output stride should match the input tensor's stride.
* Case 2: otherwise, the output gets dense strides computed from the input tensor using `compute_elementwise_output_strides`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163017
Approved by: https://github.com/williamwen42, https://github.com/xmfan

Co-authored-by: morrison-turnansky <mturnans@redhat.com>
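For context, a minimal eager-mode sketch of the two cases described above (illustrative only; the example tensors mirror the ones used in the tests below):

import torch

# Case 1: a transposed contiguous tensor is non-overlapping and dense, so an
# eager clone with preserve_format keeps the input's strides.
dense = torch.randn(10, 8).t()                # shape (8, 10), strides (1, 8)
print(dense.clone(memory_format=torch.preserve_format).stride())     # (1, 8)

# Case 2: a strided slice leaves gaps in memory (not dense), so an eager clone
# allocates fresh dense storage; the output strides follow the input's stride
# order (what compute_elementwise_output_strides produces), here (1, 4).
non_dense = torch.randn(10, 8).t()[::2, ::2]  # shape (4, 5), strides (2, 16)
print(non_dense.clone(memory_format=torch.preserve_format).stride())  # (1, 4)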
Committed by: PyTorch MergeBot
Parent: bc7b17a36d
Commit: 979e10f7d6
@@ -4822,6 +4822,67 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             "encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
         ):
             f_compiled(a)
 
+    # See https://github.com/pytorch/pytorch/issues/161010
+    def test_preserve_stride_with_clone(self) -> None:
+        A = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
+        B = torch.rand(5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
+
+        def fn(
+            src: torch.Tensor, count: torch.Tensor
+        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
+            Q, R = torch.linalg.qr(src)
+            rhs = torch.ones(Q.shape[0], 1, device=src.device)
+            a = torch.linalg.solve_triangular(R, Q.T @ rhs, upper=True)
+            cloned = a.clone(memory_format=torch.preserve_format)
+            return a.stride(), cloned.stride()
+
+        a_stride, cloned_stride = fn(A, torch.zeros(1))
+        self.assertEqual(
+            a_stride,
+            cloned_stride,
+            f"Strides should match in eager: {a_stride} against {cloned_stride}",
+        )
+
+        compiled_a_stride, compiled_cloned_stride = torch.compile(fn, backend="eager")(
+            B, torch.zeros(1)
+        )
+        self.assertEqual(
+            compiled_a_stride,
+            compiled_cloned_stride,
+            f"Strides should match in eager: {compiled_a_stride} against {compiled_cloned_stride}",
+        )
+
+    # Extension of https://github.com/pytorch/pytorch/issues/161010
+    # in the non memory dense case
+    def test_clone_not_memory_dense(self):
+        def foo() -> torch.Tensor:
+            x = torch.randn(10, 8).t()[::2, ::2]
+            y = x.clone()
+            return y
+
+        y = foo()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Reference eager implementation should have stride (1, 4)",
+        )
+        y = torch.compile(foo, backend="eager")()
+        self.assertEqual(
+            y.stride(), (1, 4), "Compile with eager backend should have stride (1, 4)"
+        )
+        y = torch.compile(foo, backend="aot_eager")()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Compile with aot_eager backend should have stride (1, 4)",
+        )
+        y = torch.compile(foo, backend="inductor")()
+        self.assertEqual(
+            y.stride(),
+            (1, 4),
+            "Compile with inductor backend should have stride (1, 4)",
+        )
+
     # https://github.com/pytorch/pytorch/issues/146598
     @unittest.expectedFailure
@@ -342,6 +342,16 @@ $1: f32[2] = torch._ops.prims.sin.default($0)""")
         x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
         x + 1
 
+    def test_clone_meta_stride_preservation_dense(self):
+        tensor = torch.randn(1, 5).t()
+        meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
+        self.assertEqual(tensor.stride(), meta_clone.stride())
+
+    def test_clone_meta_stride_preservation_sparse(self):
+        tensor = torch.arange(12).float().view(3, 4)[1:, ::2]
+        meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
+        self.assertEqual(tensor.contiguous().stride(), meta_clone.stride())
+
     def test_check_deprecation_warning(self):
         with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
             torch._prims_common.check(True, lambda: 'message')
@@ -692,16 +692,22 @@ def _clone_meta(
             device=input.device,
             memory_format=memory_format,
         )
-
-    # memory_format == torch.preserve_format
-    strides = utils.compute_elementwise_output_strides(input)
-    return torch.empty_strided(
-        input.shape,
-        strides,
-        dtype=input.dtype,
-        layout=input.layout,
-        device=input.device,
-    )
+    else:
+        # Match eager behavior by preserving strides for non_overlapping_and_dense tensors
+        # If not, eager clone creates contiguous strides
+        computed_stride = None
+        if utils.is_non_overlapping_and_dense(input):
+            computed_stride = input.stride()
+        else:
+            computed_stride = utils.compute_elementwise_output_strides(input)
+
+        return torch.empty_strided(
+            input.shape,
+            computed_stride,
+            dtype=input.dtype,
+            layout=input.layout,
+            device=input.device,
+        )
 
 
 clone = _make_prim(
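Illustration only, not part of the patch: a standalone sketch of the stride selection that the preserve_format branch above performs, assuming (as in that hunk) that `utils` refers to `torch._prims_common`; the helper name `preserve_format_clone_strides` is invented for this example.

import torch
import torch._prims_common as utils

def preserve_format_clone_strides(t: torch.Tensor) -> tuple[int, ...]:
    # Keep the input strides when the tensor is non-overlapping and dense;
    # otherwise fall back to dense strides in the input's stride order.
    if utils.is_non_overlapping_and_dense(t):
        return tuple(t.stride())
    return tuple(utils.compute_elementwise_output_strides(t))

# Dense transposed tensor: strides are preserved (matches the "dense" test above).
print(preserve_format_clone_strides(torch.randn(1, 5).t()))  # (1, 5)
# Non-dense slice: dense strides in the input's stride order (matches the "sparse" test).
print(preserve_format_clone_strides(torch.arange(12).float().view(3, 4)[1:, ::2]))  # (2, 1)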