Fix fake kernel for the out=... variant of unbind_copy (#156643)

`unbind_copy(..., out=...)` returns None rather than the `out` argument
(see https://github.com/pytorch/pytorch/issues/130829#issuecomment-2283936222),
but the old fake kernel didn't account for that: it returned a non-None value,
which tripped an assertion in `pushPyOutToStack`. This patch makes the fake
kernel (the `_make_copy_from_view`-generated decomposition) return None for
the out= variant as well, matching eager.
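For context, here is a minimal eager-mode sketch of the behavior described above (it assumes a PyTorch build where `torch.unbind_copy` is available):

import torch

eye = torch.eye(3)
outs = (torch.zeros(3), torch.zeros(3), torch.zeros(3))

# The functional variant returns the unbound slices as fresh tensors...
rows = torch.unbind_copy(eye)
assert rows is not None and torch.equal(rows[0], eye[0])

# ...while the out= variant writes into `outs` and returns None, which is
# exactly what the fake kernel must mirror.
assert torch.unbind_copy(eye, out=outs) is None
assert torch.equal(outs[1], eye[1])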

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156643
Approved by: https://github.com/zou3519, https://github.com/jansel, https://github.com/bdhirsh
ghstack dependencies: #156642
Authored by: Ryan Guo, 2025-06-26 10:27:21 -07:00
Committed by: PyTorch MergeBot
Parent: 89aa708b39
Commit: a4b59498c5
6 changed files with 81 additions and 80 deletions


@@ -7027,6 +7027,18 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
        torch.compile(f, backend="eager", fullgraph=True)(x, out_res)
        self.assertEqual(out_ref, out_res)

    def test_unbind_copy_out(self):
        def f(eye, out):
            torch.unbind_copy(eye, out=out)

        eye = torch.eye(3)
        out_ref = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
        out_res = (torch.zeros(3), torch.zeros(3), torch.zeros(3))

        f(eye, out_ref)
        torch.compile(f, backend="eager", fullgraph=True)(eye, out_res)
        self.assertEqual(out_ref, out_res)


class ReproTestsDevice(torch._dynamo.test_case.TestCase):
    def test_sub_alpha_scalar_repro(self, device):


@@ -992,6 +992,17 @@ class FakeTensorTest(TestCase):
        self.assertEqual(out.dtype, x.dtype)

    def test_unbind_copy_out(self):
        # Regression test to ensure we don't error out.
        with torch._subclasses.fake_tensor.FakeTensorMode() as mode:
            eye = torch.eye(3)
            out = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
            torch.unbind_copy(eye, out=out)

            self.assertEqual(out[0].dtype, eye.dtype)
            self.assertEqual(out[1].dtype, eye.dtype)
            self.assertEqual(out[2].dtype, eye.dtype)


instantiate_parametrized_tests(FakeTensorTest)


@@ -118,21 +118,17 @@ _ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]
aten = torch.ops.aten
meta_consistency_out_dtype_mismatch_xfails = {
xfail("alias_copy"),
xfail("all"),
xfail("amax"),
xfail("amin"),
xfail("aminmax"),
xfail("any"),
xfail("as_strided_copy"),
xfail("bucketize"),
xfail("conj_physical"),
xfail("cross"),
xfail("cummax"),
xfail("cummin"),
xfail("diag"),
xfail("diagonal_copy"),
xfail("expand_copy"),
xfail("fft.ihfft2"),
xfail("fft.ihfftn"),
xfail("frexp"),
@@ -167,8 +163,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("msort"),
xfail("multinomial"),
xfail("nan_to_num"),
xfail("nanmean"),
xfail("narrow_copy"),
xfail("native_batch_norm"),
xfail("neg"),
xfail("nn.functional.avg_pool3d"),
@@ -178,7 +172,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("nn.functional.softplus"),
xfail("nn.functional.softshrink"),
xfail("ormqr"),
xfail("permute_copy"),
xfail("qr"),
xfail("renorm"),
xfail("round"),
@@ -193,15 +186,10 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("softmax"),
xfail("sort"),
xfail("sparse.sampled_addmm"),
xfail("squeeze_copy"),
xfail("t_copy"),
xfail("take"),
xfail("transpose_copy"),
xfail("tril"),
xfail("triu"),
xfail("unfold_copy"),
xfail("unsqueeze_copy"),
xfail("view_copy"),
xfail("where"),
# Output has dynamic shape.
# Does not have a meta kernel implementation.
@@ -2498,7 +2486,6 @@ fake_skips = (
"mvlgamma.mvlgamma_p_1", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
"mvlgamma.mvlgamma_p_3", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
"mvlgamma.mvlgamma_p_5", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
"nanmean", # logical_not() got an unexpected keyword argument 'out'
"quantile", # quantile() q values must be in the range [0, 1]
"nanquantile", # quantile() q values must be in the range [0, 1]
"nn.functional.ctc_loss", # The tensor has a non-zero number of elements, but its data is not allocated yet


@@ -1325,15 +1325,31 @@ For now, dynamo will explicitly graph break when it encounters user code with th
# variant torch ops, the original function could come from a user
# defined `@allow_in_graph` function as well, which doesn't have the
# same semantics as the torch ops.
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
saved_out_shapes = None
out_kwarg_vt = None
if "out" in kwargs:
out_kwarg_vt = kwargs["out"]
# e.g., out=(t1, t2, ...)
if isinstance(out_kwarg_vt, (TupleVariable, ListVariable)):
saved_out_shapes = []
for vt in out_kwarg_vt.items:
if isinstance(vt, variables.TensorVariable):
shape = vt.proxy.node.meta["example_value"].shape
else:
shape = None
saved_out_shapes.append(shape)
# e.g., out=output_tensor
if isinstance(out_kwarg_vt, variables.TensorVariable):
saved_out_shapes = out_kwarg_vt.proxy.node.meta["example_value"].shape
tensor_variable = wrap_fx_proxy(
tx=tx,
@@ -1356,10 +1372,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
)
# Handle e.g., `torch.add(a, b, out=result)`
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
if saved_out_shapes is not None:
# out variants of torch operators like torch.sort and torch.sigmoid
# mutate the tensors in the out field.
#
@@ -1371,26 +1384,20 @@ Either create the tensor outside the compiled region, or do not set the tensor t
# Note that although these tensor variables would hold different
# proxies, the in-place mutation semantics is preserved in the FX
# graph, so we won't have correctness issues.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
if isinstance(saved_out_shapes, list):
for out_tensor_vt, saved_out_shape in zip(
out_kwarg_vt.items, # type: ignore[union-attr]
saved_out_shapes,
):
if (
isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor._size
!= result_tensor._size # we actually want to compare None values here
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if fake_out_shape != fake_tensor.shape:
if saved_out_shape is None:
# This should be extremely rare, but it's kept for now
# until we invest in enforcing the `out=` kwarg for only
# torch methods.
continue
assert isinstance(out_tensor_vt, TensorVariable)
fake_out = out_tensor_vt.proxy.node.meta["example_value"]
if saved_out_shape != fake_out.shape:
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
@@ -1400,32 +1407,20 @@ Either create the tensor outside the compiled region, or do not set the tensor t
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
assert isinstance(out_kwarg_vt, TensorVariable)
assert "example_value" in out_kwarg_vt.proxy.node.meta
fake_out = out_kwarg_vt.proxy.node.meta["example_value"]
if saved_out_shapes != fake_out.shape:
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
return tensor_variable
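The bookkeeping above can be illustrated with a small, self-contained sketch (the names below are illustrative, not dynamo internals): record the shapes of the out= tensors before running the op, then confirm the op did not resize any of them. In dynamo the comparison is done against the fake tensors' example_value shapes, and a mismatch triggers a graph break rather than returning False.

import torch

def out_shapes_unchanged(run_op, out_tensors):
    # Save the shapes before the op runs (the role of `saved_out_shapes` above),
    # then verify afterwards that the op did not resize any out= tensor.
    saved = [t.shape for t in out_tensors]
    run_op(out_tensors)
    return all(s == t.shape for s, t in zip(saved, out_tensors))

outs = tuple(torch.zeros(3) for _ in range(3))
# unbind_copy fills the out tensors in place without resizing them.
print(out_shapes_unchanged(lambda o: torch.unbind_copy(torch.eye(3), out=o), outs))  # True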


@@ -2238,17 +2238,21 @@ def _reduction(
    return result


def _make_copy_from_view(fn):
def _make_copy_from_view(fn, return_none_on_out_variant=False):
    """
    Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)
    """
    aten_fn = getattr(aten, fn.__name__)
    annotations = getattr(fn, "__annotations__", {})
    fn = out_wrapper()(aten_fn)
    # view ops should not change dtypes, this ensures that the decomp path has
    # the same error checks as eager.
    fn = out_wrapper(exact_dtype=True)(aten_fn)

    @wraps(fn)
    def _fn(*args, out=None, **kwargs):
        result = fn(*args, out=out, **kwargs)
        if return_none_on_out_variant and out is not None:
            return None

        if out is not None:
            return result
@@ -6491,7 +6495,7 @@ squeeze_copy = _make_copy_from_view(aten.squeeze)
permute_copy = _make_copy_from_view(aten.permute)
t_copy = _make_copy_from_view(aten.t)
transpose_copy = _make_copy_from_view(aten.transpose)
unbind_copy = _make_copy_from_view(aten.unbind)
unbind_copy = _make_copy_from_view(aten.unbind, return_none_on_out_variant=True)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)
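Conceptually, the decomposition generated for `unbind_copy` now behaves like the following simplified, hypothetical wrapper (the real helper goes through `out_wrapper` and handles more cases): it copies each slice into the corresponding out= tensor and deliberately returns None for the out= call, matching the eager behavior described in the commit message.

import torch

def unbind_copy_like(self, dim=0, *, out=None):
    # Hypothetical stand-in for _make_copy_from_view(aten.unbind,
    # return_none_on_out_variant=True); not the actual torch._refs code.
    views = torch.unbind(self, dim)
    if out is not None:
        for dst, src in zip(out, views):
            dst.copy_(src)  # materialize each view into its out= tensor
        return None  # the out= variant returns None, like eager unbind_copy
    return tuple(v.clone() for v in views)  # functional variant returns copies

outs = tuple(torch.empty(3) for _ in range(3))
assert unbind_copy_like(torch.eye(3), out=outs) is None
assert torch.equal(outs[2], torch.tensor([0.0, 0.0, 1.0]))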


@@ -19627,15 +19627,7 @@ op_db: list[OpInfo] = [
           supports_gradgrad=True,
           supports_out=True,
           check_batched_grad=False,
           skips=(
               # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None
               # but it returned something else instead.
               DecorateInfo(
                   unittest.expectedFailure,
                   'TestProxyTensorOpInfo',
                   'test_make_fx_symbolic_exhaustive_out'
               ),
           )),
    OpInfo('vstack',
           aliases=('row_stack',),
           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),