Revert "[dynamo] Error when user nests FX with dynamo (#87797)"

This reverts commit 1da5aeb97b73664ff0fe2f4bb48379655cede969.

Reverted https://github.com/pytorch/pytorch/pull/87797 on behalf of https://github.com/ezyang because it breaks the nvfuser stack and needs more investigation
PyTorch MergeBot
2022-10-31 23:49:37 +00:00
parent caaf37a111
commit c0761a835b
4 changed files with 1 addition and 36 deletions
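For context, a minimal sketch of the behavior being reverted, using only the APIs that appear in the removed test below (torch._dynamo.optimize and torch.fx.symbolic_trace); this is an illustration, not part of the commit:

    import torch
    import torch._dynamo

    def f(x):
        return x + x

    # Compile f with the "eager" backend, as the removed test does.
    optimized = torch._dynamo.optimize("eager")(f)

    # With #87797 in the tree, symbolically tracing a dynamo-optimized
    # function raised a RuntimeError; this revert removes that explicit error.
    try:
        gm = torch.fx.symbolic_trace(optimized)
    except RuntimeError as e:
        print(e)  # "Detected that you are using FX to symbolically trace ..."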


@@ -2732,20 +2732,6 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         dynamo_result = graph(x)
         self.assertTrue(same(real, dynamo_result))
 
-    def test_error_on_nested_fx_trace(self):
-        input = torch.rand(2, 3)
-
-        def f(x):
-            x + x
-
-        real = f(input)
-
-        optimized = torch._dynamo.optimize("eager")(f)
-        self.assertTrue(same(optimized(input), real))
-
-        with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
-            gm = torch.fx.symbolic_trace(optimized)
-
 
 class CustomFunc(torch.autograd.Function):
     @staticmethod


@@ -8,14 +8,7 @@ import unittest
 
 import torch
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import (
-    parametrize,
-    run_tests,
-    TestCase,
-    TEST_SCIPY,
-    skipCUDAMemoryLeakCheckIf,
-    skipIfTorchDynamo,
-)
+from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCUDA,
@@ -394,7 +387,6 @@ class TestPrims(TestCase):
         actual = execute(gm, a.mT, executor="nvfuser")
         self.assertEqual(expected, actual)
 
-    @skipIfTorchDynamo
     def test_nvfuser_capability_context(self, device):
         # This test is to ensure that the torch calls are replaced with refs
         # based on the nvfuser+prims capability


@@ -153,10 +153,6 @@ dynamo_import = __name__.replace(".config", "")
 # How to import torchinductor, either torchinductor or torch.inductor
 inductor_import = dynamo_import.replace("dynamo", "inductor")
 
-# If true, error with a better message if we symbolically trace over a
-# dynamo-optimized function. If false, silently suppress dynamo.
-error_on_nested_fx_trace = True
-
 # root folder of the project
 if "torch." in dynamo_import:
     base_dir = dirname(dirname(dirname(abspath(__file__))))
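While the flag existed, callers could opt into the silent-suppression path instead of the error. A hypothetical usage sketch (only valid before this revert, since the flag is removed here):

    import torch._dynamo.config as dynamo_config

    # Fall back to running the original function un-optimized under FX
    # tracing, instead of raising a RuntimeError.
    dynamo_config.error_on_nested_fx_trace = False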


@@ -14,7 +14,6 @@ from unittest.mock import patch
 
 import torch
 import torch.utils._pytree as pytree
-from torch.fx._symbolic_trace import is_fx_tracing
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -150,14 +149,6 @@ class _TorchDynamoContext:
 
         @functools.wraps(fn)
        def _fn(*args, **kwargs):
-            if is_fx_tracing():
-                if config.error_on_nested_fx_trace:
-                    raise RuntimeError(
-                        "Detected that you are using FX to symbolically trace "
-                        "a dynamo-optimized function. This is not supported at the moment."
-                    )
-                else:
-                    return fn(*args, **kwargs)
             on_enter()
             prior = set_eval_frame(callback)
             backend_ctx = backend_ctx_ctor()
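The removed guard keyed off torch.fx._symbolic_trace.is_fx_tracing() (see the import dropped above). A minimal sketch of that predicate's behavior, as an illustration rather than part of the commit:

    import torch
    from torch.fx._symbolic_trace import is_fx_tracing

    def g(x):
        # Executes at trace time: prints True under torch.fx.symbolic_trace,
        # False in an ordinary call.
        print(is_fx_tracing())
        return x * 2

    g(torch.ones(2))            # prints False
    torch.fx.symbolic_trace(g)  # prints True while tracing g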