Fix fake kernel for the out=... variant of unbind_copy (#156643)
`unbind_copy(..., out=...)` returns None rather than the `out` argument (see https://github.com/pytorch/pytorch/issues/130829#issuecomment-2283936222), but the old fake kernel didn't account for that and caused an assertion failure in `pushPyOutToStack`. This patch fixes that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156643
Approved by: https://github.com/zou3519, https://github.com/jansel, https://github.com/bdhirsh
ghstack dependencies: #156642
Committed by: PyTorch MergeBot
Parent: 89aa708b39
Commit: a4b59498c5
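A minimal sketch of the behavior this change enforces, based on the commit message and the regression tests below (the `FakeTensorMode` block assumes a build that includes this patch):

import torch

# Eager behavior the fake kernel must match: the out= variant of
# unbind_copy writes into the provided tensors and returns None.
eye = torch.eye(3)
outs = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
assert torch.unbind_copy(eye, out=outs) is None
assert torch.equal(outs[0], eye[0])

# With the patch, the same call under FakeTensorMode also returns None
# instead of tripping the assertion in pushPyOutToStack.
with torch._subclasses.fake_tensor.FakeTensorMode():
    fake_eye = torch.eye(3)
    fake_outs = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
    assert torch.unbind_copy(fake_eye, out=fake_outs) is None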
@@ -7027,6 +7027,18 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
        torch.compile(f, backend="eager", fullgraph=True)(x, out_res)
        self.assertEqual(out_ref, out_res)

    def test_unbind_copy_out(self):
        def f(eye, out):
            torch.unbind_copy(eye, out=out)

        eye = torch.eye(3)
        out_ref = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
        out_res = (torch.zeros(3), torch.zeros(3), torch.zeros(3))

        f(eye, out_ref)
        torch.compile(f, backend="eager", fullgraph=True)(eye, out_res)
        self.assertEqual(out_ref, out_res)


class ReproTestsDevice(torch._dynamo.test_case.TestCase):
    def test_sub_alpha_scalar_repro(self, device):
@@ -992,6 +992,17 @@ class FakeTensorTest(TestCase):

        self.assertEqual(out.dtype, x.dtype)

    def test_unbind_copy_out(self):
        # Regression test to ensure we don't error out.
        with torch._subclasses.fake_tensor.FakeTensorMode() as mode:
            eye = torch.eye(3)
            out = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
            torch.unbind_copy(eye, out=out)

        self.assertEqual(out[0].dtype, eye.dtype)
        self.assertEqual(out[1].dtype, eye.dtype)
        self.assertEqual(out[2].dtype, eye.dtype)


instantiate_parametrized_tests(FakeTensorTest)
@@ -118,21 +118,17 @@ _ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]
aten = torch.ops.aten

meta_consistency_out_dtype_mismatch_xfails = {
    xfail("alias_copy"),
    xfail("all"),
    xfail("amax"),
    xfail("amin"),
    xfail("aminmax"),
    xfail("any"),
    xfail("as_strided_copy"),
    xfail("bucketize"),
    xfail("conj_physical"),
    xfail("cross"),
    xfail("cummax"),
    xfail("cummin"),
    xfail("diag"),
    xfail("diagonal_copy"),
    xfail("expand_copy"),
    xfail("fft.ihfft2"),
    xfail("fft.ihfftn"),
    xfail("frexp"),
@@ -167,8 +163,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
    xfail("msort"),
    xfail("multinomial"),
    xfail("nan_to_num"),
    xfail("nanmean"),
    xfail("narrow_copy"),
    xfail("native_batch_norm"),
    xfail("neg"),
    xfail("nn.functional.avg_pool3d"),
@@ -178,7 +172,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
    xfail("nn.functional.softplus"),
    xfail("nn.functional.softshrink"),
    xfail("ormqr"),
    xfail("permute_copy"),
    xfail("qr"),
    xfail("renorm"),
    xfail("round"),
@@ -193,15 +186,10 @@ meta_consistency_out_dtype_mismatch_xfails = {
    xfail("softmax"),
    xfail("sort"),
    xfail("sparse.sampled_addmm"),
    xfail("squeeze_copy"),
    xfail("t_copy"),
    xfail("take"),
    xfail("transpose_copy"),
    xfail("tril"),
    xfail("triu"),
    xfail("unfold_copy"),
    xfail("unsqueeze_copy"),
    xfail("view_copy"),
    xfail("where"),
    # Output has dynamic shape.
    # Does not have a meta kernel implementation.
@@ -2498,7 +2486,6 @@ fake_skips = (
    "mvlgamma.mvlgamma_p_1",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
    "mvlgamma.mvlgamma_p_3",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
    "mvlgamma.mvlgamma_p_5",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
    "nanmean",  # logical_not() got an unexpected keyword argument 'out'
    "quantile",  # quantile() q values must be in the range [0, 1]
    "nanquantile",  # quantile() q values must be in the range [0, 1]
    "nn.functional.ctc_loss",  # The tensor has a non-zero number of elements, but its data is not allocated yet
@@ -1325,15 +1325,31 @@ For now, dynamo will explicitly graph break when it encounters user code with th
        # variant torch ops, the original function could come from a user
        # defined `@allow_in_graph` function as well, which doesn't have the
        # same semantics as the torch ops.
        fake_out_shape = None
        if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):

            # Calling fake tensor propagation can mutate the out= tensor in
            # tx.output.tracked_fakes. tracked_fakes are used to apply
            # symbolic_shape guards. Mutating them destroys the information
            # prior to tracing, which is essential for creating right
            # guards. So save the shape now, and check later if it has
            # changed. If it has, graph break.
            fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
        saved_out_shapes = None
        out_kwarg_vt = None
        if "out" in kwargs:
            out_kwarg_vt = kwargs["out"]

            # e.g., out=(t1, t2, ...)
            if isinstance(out_kwarg_vt, (TupleVariable, ListVariable)):
                saved_out_shapes = []
                for vt in out_kwarg_vt.items:
                    if isinstance(vt, variables.TensorVariable):
                        shape = vt.proxy.node.meta["example_value"].shape
                    else:
                        shape = None
                    saved_out_shapes.append(shape)

            # e.g., out=output_tensor
            if isinstance(out_kwarg_vt, variables.TensorVariable):
                saved_out_shapes = out_kwarg_vt.proxy.node.meta["example_value"].shape

        tensor_variable = wrap_fx_proxy(
            tx=tx,
@@ -1356,10 +1372,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
        )

        # Handle e.g., `torch.add(a, b, out=result)`
        if "out" in kwargs and not (
            isinstance(kwargs["out"], variables.ConstantVariable)
            and kwargs["out"].as_python_constant() is None
        ):
        if saved_out_shapes is not None:
            # out variants of torch operators like torch.sort and torch.sigmoid
            # mutate the tensors in the out field.
            #
@@ -1371,26 +1384,20 @@ Either create the tensor outside the compiled region, or do not set the tensor t
            # Note that although these tensor variables would hold different
            # proxies, the in-place mutation semantics is preserved in the FX
            # graph, so we won't have correctness issues.
            if isinstance(tensor_variable, TupleVariable):
                assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
                for out_tensor, result_tensor in zip(
                    kwargs["out"].items, tensor_variable.items
            if isinstance(saved_out_shapes, list):
                for out_tensor_vt, saved_out_shape in zip(
                    out_kwarg_vt.items,  # type: ignore[union-attr]
                    saved_out_shapes,
                ):
                    if (
                        isinstance(out_tensor, variables.TensorVariable)
                        and isinstance(result_tensor, variables.TensorVariable)
                        and out_tensor._size
                        != result_tensor._size  # we actually want to compare None values here
                    ):
                        # It's hard to get out variants with resizing on graph inputs work
                        # properly across dynamo/aot/inductor, just fall back.
                        unimplemented("out variants with resizing on graph inputs")
            elif isinstance(tensor_variable, TensorVariable):
                assert isinstance(kwargs["out"], TensorVariable)
                assert "example_value" in kwargs["out"].proxy.node.meta
                fake_tensor = tensor_variable.proxy.node.meta["example_value"]
                fake_out = kwargs["out"].proxy.node.meta["example_value"]
                if fake_out_shape != fake_tensor.shape:
                    if saved_out_shape is None:
                        # This should be extremely rare, but it's kept for now
                        # until we invest in enforcing the `out=` kwarg for only
                        # torch methods.
                        continue

                    assert isinstance(out_tensor_vt, TensorVariable)
                    fake_out = out_tensor_vt.proxy.node.meta["example_value"]
                    if saved_out_shape != fake_out.shape:
                        # It's hard to get out variants with resizing on graph inputs work
                        # properly across dynamo/aot/inductor, just fall back.
                        unimplemented("out variants with resizing on graph inputs")
@@ -1400,32 +1407,20 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                    unimplemented(
                        "out= op was called where output tensor was non-contiguous"
                    )
            elif (
                isinstance(tensor_variable, ConstantVariable)
                and tensor_variable.value is None
            ):
                # Handle out-variant custom ops that return None.
                if isinstance(kwargs["out"], TensorVariable):
                    assert "example_value" in kwargs["out"].proxy.node.meta
                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
                    if not torch._prims_common.is_contiguous(fake_out):
                        # It's difficult to handle strides correctly in functionalization
                        # when calling an out= op with a non-contiguous out argument
                        unimplemented(
                            "out= op was called where output tensor was non-contiguous"
                        )
                elif isinstance(kwargs["out"], ListVariable):
                    for idx, x in enumerate(kwargs["out"].items):
                        assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
                        fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
                        if not torch._prims_common.is_contiguous(fake_out):
                            # It's difficult to handle strides correctly in functionalization
                            # when calling an out= op with a non-contiguous out argument
                            unimplemented(
                                "out= op was called where some of the output tensors were non-contiguous"
                            )
                else:
                    unimplemented(f"out variant of {type(kwargs['out'])}")
                assert isinstance(out_kwarg_vt, TensorVariable)
                assert "example_value" in out_kwarg_vt.proxy.node.meta
                fake_out = out_kwarg_vt.proxy.node.meta["example_value"]
                if saved_out_shapes != fake_out.shape:
                    # It's hard to get out variants with resizing on graph inputs work
                    # properly across dynamo/aot/inductor, just fall back.
                    unimplemented("out variants with resizing on graph inputs")
                if not torch._prims_common.is_contiguous(fake_out):
                    # It's difficult to handle strides correctly in functionalization
                    # when calling an out= op with a non-contiguous out argument
                    unimplemented(
                        "out= op was called where output tensor was non-contiguous"
                    )

        return tensor_variable
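The Dynamo change above replaces the single `fake_out_shape` snapshot with `saved_out_shapes`, which also covers tuple/list `out=` arguments whose ops (like `unbind_copy`) return None. A standalone sketch of that bookkeeping, using plain tensors and hypothetical helper names rather than Dynamo's `TensorVariable` machinery:

import torch

# Hypothetical helpers illustrating the idea: snapshot the shape of every
# out= tensor before fake-tensor propagation, then detect whether the op
# resized any of them (in which case Dynamo falls back / graph breaks).
def snapshot_out_shapes(out):
    if isinstance(out, (tuple, list)):
        return [t.shape if isinstance(t, torch.Tensor) else None for t in out]
    return out.shape

def out_was_resized(out, saved):
    if isinstance(saved, list):
        return any(s is not None and t.shape != s for t, s in zip(out, saved))
    return out.shape != saved

outs = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
saved = snapshot_out_shapes(outs)
torch.unbind_copy(torch.eye(3), out=outs)  # shapes unchanged, no resize
assert not out_was_resized(outs, saved)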
@@ -2238,17 +2238,21 @@ def _reduction(
    return result


def _make_copy_from_view(fn):
def _make_copy_from_view(fn, return_none_on_out_variant=False):
    """
    Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)
    """
    aten_fn = getattr(aten, fn.__name__)
    annotations = getattr(fn, "__annotations__", {})
    fn = out_wrapper()(aten_fn)
    # view ops should not change dtypes, this ensures that the decomp path has
    # the same error checks as eager.
    fn = out_wrapper(exact_dtype=True)(aten_fn)

    @wraps(fn)
    def _fn(*args, out=None, **kwargs):
        result = fn(*args, out=out, **kwargs)
        if return_none_on_out_variant and out is not None:
            return None
        if out is not None:
            return result
@@ -6491,7 +6495,7 @@ squeeze_copy = _make_copy_from_view(aten.squeeze)
permute_copy = _make_copy_from_view(aten.permute)
t_copy = _make_copy_from_view(aten.t)
transpose_copy = _make_copy_from_view(aten.transpose)
unbind_copy = _make_copy_from_view(aten.unbind)
unbind_copy = _make_copy_from_view(aten.unbind, return_none_on_out_variant=True)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)
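The two `_refs` hunks above thread a `return_none_on_out_variant` flag through `_make_copy_from_view` so the generated `unbind_copy` reference mirrors the eager convention. A simplified, self-contained illustration of that pattern (hypothetical `make_copy_from_view_sketch`, handling only multi-output view ops such as `unbind`, not the real `_refs`/`out_wrapper` code path):

import torch
from functools import wraps

# Sketch of the flag's effect: the copy variant still fills `out`, but
# returns None instead of the out tensors when the flag is set.
def make_copy_from_view_sketch(view_fn, return_none_on_out_variant=False):
    @wraps(view_fn)
    def copy_fn(*args, out=None, **kwargs):
        views = view_fn(*args, **kwargs)  # e.g. torch.unbind -> tuple of views
        if out is None:
            return tuple(v.clone() for v in views)
        for dst, src in zip(out, views):
            dst.copy_(src)  # materialize each view into the user buffer
        return None if return_none_on_out_variant else out
    return copy_fn

unbind_copy_sketch = make_copy_from_view_sketch(
    torch.unbind, return_none_on_out_variant=True
)
outs = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
assert unbind_copy_sketch(torch.eye(3), out=outs) is None
assert torch.equal(outs[1], torch.tensor([0.0, 1.0, 0.0]))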
@@ -19627,15 +19627,7 @@ op_db: list[OpInfo] = [
           supports_gradgrad=True,
           supports_out=True,
           check_batched_grad=False,
           skips=(
               # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None
               # but it returned something else instead.
               DecorateInfo(
                   unittest.expectedFailure,
                   'TestProxyTensorOpInfo',
                   'test_make_fx_symbolic_exhaustive_out'
               ),
           )),
    OpInfo('vstack',
           aliases=('row_stack',),
           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),