Revert "Fix standalone compile for op with multiple outputs (#96936)"

This reverts commit 37cde56658e20afae6d94b70d53e4131043e09e8.

Reverted https://github.com/pytorch/pytorch/pull/96936 on behalf of https://github.com/kit1980 due to Broke inductor tests on macos-12-py3-arm64 https://github.com/pytorch/pytorch/actions/runs/4458548491/jobs/7830566793
This commit is contained in:
PyTorch MergeBot
2023-03-19 20:32:13 +00:00
parent 90537a779c
commit 5d33f9cddb
2 changed files with 0 additions and 21 deletions

View File

@ -2,7 +2,6 @@
import torch import torch
from torch import _dynamo as dynamo, _inductor as inductor from torch import _dynamo as dynamo, _inductor as inductor
from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor.utils import gen_gm_and_inputs
from torch.fx import symbolic_trace from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.inductor_utils import HAS_CPU from torch.testing._internal.inductor_utils import HAS_CPU
@ -97,18 +96,6 @@ class TestStandaloneInductor(TestCase):
actual = mod_opt(inp) actual = mod_opt(inp)
self.assertEqual(actual, correct) self.assertEqual(actual, correct)
def test_inductor_via_op_with_multiple_outputs(self):
x1 = torch.randn((2, 512, 128))
x2 = [128]
x3 = torch.randn((128))
x4 = torch.randn((128,))
x5 = 1e-6
mod, inp = gen_gm_and_inputs(
torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {}
)
mod_opt = inductor.compile(mod, inp)
self.assertEqual(mod(*inp), mod_opt(*inp))
if __name__ == "__main__": if __name__ == "__main__":
if HAS_CPU: if HAS_CPU:

View File

@ -572,14 +572,6 @@ def graph_returns_tuple(gm: torch.fx.GraphModule):
(rv,) = output_node(gm).args (rv,) = output_node(gm).args
if isinstance(rv, (list, tuple)): if isinstance(rv, (list, tuple)):
return True return True
if (
isinstance(rv, torch.fx.node.Node)
and hasattr(rv.target, "_schema")
and len(rv.target._schema.returns) > 1
and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
):
# for graphs whose result is one node with multiple outputs
return True
return False return False