mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
unbacked handling for view_copy (#159244)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159244 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
222fa451a2
commit
2523e58781
@ -3187,9 +3187,10 @@ class TestUbackedOps(TestCase):
|
||||
f = y.item()
|
||||
t1 = x.view((f, f))
|
||||
t2 = x.reshape((f, f))
|
||||
t3 = torch._ops.ops.aten.view_copy(x, (f, f))
|
||||
# TODO avoid _check_is_size here.
|
||||
torch._check_is_size(f)
|
||||
return t1 * 10, t2 * 10
|
||||
return t1 * 10, t2 * 10, t3
|
||||
|
||||
compiled_func = torch.compile(
|
||||
fullgraph=True,
|
||||
@ -3229,10 +3230,12 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
|
||||
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
|
||||
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
|
||||
view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
|
||||
view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
|
||||
mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
||||
mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
|
||||
return (mul_9, mul_12)""", # noqa: B950
|
||||
view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
|
||||
view_2: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
|
||||
clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None
|
||||
mul_11: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
||||
mul_14: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
|
||||
return (mul_11, mul_14, clone)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
@ -3268,10 +3271,12 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
|
||||
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
|
||||
view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
|
||||
view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
|
||||
mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
||||
mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
|
||||
return (mul_4, mul_7)""", # noqa: B950
|
||||
view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
|
||||
view_2: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
|
||||
clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None
|
||||
mul_6: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
||||
mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
|
||||
return (mul_6, mul_9, clone)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
@ -601,6 +601,18 @@ def _view_meta(fake_mode, func, a, *shape):
|
||||
return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
|
||||
|
||||
|
||||
@register_op_impl(aten.view_copy.default)
|
||||
def _view_meta_copy(fake_mode, func, a, *shape, out=None):
|
||||
result = _view_meta(fake_mode, func, a, *shape)
|
||||
if out is not None:
|
||||
return result
|
||||
|
||||
return pytree.tree_map(
|
||||
lambda x: x.clone(memory_format=torch.contiguous_format),
|
||||
result,
|
||||
)
|
||||
|
||||
|
||||
@register_op_impl(aten.repeat_interleave.Tensor)
|
||||
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
|
||||
if output_size is None:
|
||||
|
Reference in New Issue
Block a user