mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Fix standalone compile for op with multiple outputs (#96936)"
This reverts commit 37cde56658e20afae6d94b70d53e4131043e09e8. Reverted https://github.com/pytorch/pytorch/pull/96936 on behalf of https://github.com/kit1980 due to Broke inductor tests on macos-12-py3-arm64 https://github.com/pytorch/pytorch/actions/runs/4458548491/jobs/7830566793
This commit is contained in:
@ -2,7 +2,6 @@
|
||||
import torch
|
||||
from torch import _dynamo as dynamo, _inductor as inductor
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import gen_gm_and_inputs
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
@ -97,18 +96,6 @@ class TestStandaloneInductor(TestCase):
|
||||
actual = mod_opt(inp)
|
||||
self.assertEqual(actual, correct)
|
||||
|
||||
def test_inductor_via_op_with_multiple_outputs(self):
|
||||
x1 = torch.randn((2, 512, 128))
|
||||
x2 = [128]
|
||||
x3 = torch.randn((128))
|
||||
x4 = torch.randn((128,))
|
||||
x5 = 1e-6
|
||||
mod, inp = gen_gm_and_inputs(
|
||||
torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {}
|
||||
)
|
||||
mod_opt = inductor.compile(mod, inp)
|
||||
self.assertEqual(mod(*inp), mod_opt(*inp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if HAS_CPU:
|
||||
|
@ -572,14 +572,6 @@ def graph_returns_tuple(gm: torch.fx.GraphModule):
|
||||
(rv,) = output_node(gm).args
|
||||
if isinstance(rv, (list, tuple)):
|
||||
return True
|
||||
if (
|
||||
isinstance(rv, torch.fx.node.Node)
|
||||
and hasattr(rv.target, "_schema")
|
||||
and len(rv.target._schema.returns) > 1
|
||||
and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
|
||||
):
|
||||
# for graphs whose result is one node with multiple outputs
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user