diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index efa1875afdf4..6a721a079a63 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -3187,9 +3187,10 @@ class TestUbackedOps(TestCase):
             f = y.item()
             t1 = x.view((f, f))
             t2 = x.reshape((f, f))
+            t3 = torch._ops.ops.aten.view_copy(x, (f, f))
             # TODO avoid _check_is_size here.
             torch._check_is_size(f)
-            return t1 * 10, t2 * 10
+            return t1 * 10, t2 * 10, t3
 
         compiled_func = torch.compile(
             fullgraph=True,
@@ -3229,10 +3230,12 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
         eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1;  arg1_1 = pow_1 = None
         _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'");  eq = _assert_scalar_2 = None
         view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
-        view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]);  arg3_1 = _local_scalar_dense = None
-        mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10);  view = None
-        mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10);  view_1 = None
-        return (mul_9, mul_12)""",  # noqa: B950
+        view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
+        view_2: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]);  arg3_1 = _local_scalar_dense = None
+        clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2);  view_2 = None
+        mul_11: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10);  view = None
+        mul_14: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10);  view_1 = None
+        return (mul_11, mul_14, clone)""",  # noqa: B950
             ignore_comments=True,
             ignore_empty_lines=True,
         )
@@ -3268,10 +3271,12 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
         eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1;  arg1_1 = pow_1 = None
         _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'");  eq = _assert_scalar_2 = None
         view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
-        view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]);  arg2_1 = _local_scalar_dense = None
-        mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10);  view = None
-        mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10);  view_1 = None
-        return (mul_4, mul_7)""",  # noqa: B950
+        view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
+        view_2: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]);  arg2_1 = _local_scalar_dense = None
+        clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2);  view_2 = None
+        mul_6: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10);  view = None
+        mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10);  view_1 = None
+        return (mul_6, mul_9, clone)""",  # noqa: B950
             ignore_comments=True,
             ignore_empty_lines=True,
         )
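Note, not part of the diff: the updated expected graphs make the contract visible. `t1`/`t2` stay as `aten.view` nodes whose strides follow the input (`[s7*u0, s7]` for the strided input, `[Max(1, u0), 1]` for the contiguous one), while `t3` traces as `aten.view` followed by `aten.clone` in contiguous format, so `view_copy` returns a fresh, contiguous tensor rather than an alias. Below is a minimal sketch of what the new test path exercises, assuming a build that includes this patch; the input values are illustrative, not taken from the test.

```python
import torch

# y.item() only stays in a single graph with scalar-output capture enabled
torch._dynamo.config.capture_scalar_outputs = True

def func(x, y):
    f = y.item()  # f is an unbacked SymInt (u0) under torch.compile
    t3 = torch._ops.ops.aten.view_copy(x, (f, f))
    torch._check_is_size(f)
    return t3

x = torch.arange(9, dtype=torch.int64)
y = torch.tensor([3])

out = torch.compile(func, fullgraph=True, dynamic=True)(x, y)
assert out.is_contiguous()             # the clone forces contiguous strides
assert torch.equal(out, x.view(3, 3))  # same values as a plain view
```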
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index 7348c3b1c0b4..4d33280f7ac8 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -601,6 +601,18 @@ def _view_meta(fake_mode, func, a, *shape):
     return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
 
 
+@register_op_impl(aten.view_copy.default)
+def _view_meta_copy(fake_mode, func, a, *shape, out=None):
+    result = _view_meta(fake_mode, func, a, *shape)
+    if out is not None:
+        return result
+
+    return pytree.tree_map(
+        lambda x: x.clone(memory_format=torch.contiguous_format),
+        result,
+    )
+
+
 @register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
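Note, not part of the diff: the fake-tensor impl reuses `_view_meta` for the unbacked-friendly shape inference and then clones with `memory_format=torch.contiguous_format`, which is where the `clone` node and the `[Max(1, u0), 1]` strides in the expected graphs above come from; `pytree.tree_map` simply applies the clone uniformly to whatever `_view_meta` returns. A small sketch of the resulting behavior under `FakeTensorMode`, assuming this patch is applied:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    a = torch.empty(4, 4)  # a FakeTensor: metadata only, no real storage
    b = torch.ops.aten.view_copy.default(a, [16])
    # The registered impl views, then clones contiguously, so the fake
    # output has the requested shape and does not alias `a`.
    assert b.shape == (16,)
    assert b.is_contiguous()
    assert b._base is None  # a copy, not a view
```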