mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[dynamo] Error when user nests FX with dynamo (#87797)"
This reverts commit 1da5aeb97b73664ff0fe2f4bb48379655cede969. Reverted https://github.com/pytorch/pytorch/pull/87797 on behalf of https://github.com/ezyang due to breaks nvfuser stack, needs more investigation
This commit is contained in:
@ -2732,20 +2732,6 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
||||
dynamo_result = graph(x)
|
||||
self.assertTrue(same(real, dynamo_result))
|
||||
|
||||
def test_error_on_nested_fx_trace(self):
|
||||
input = torch.rand(2, 3)
|
||||
|
||||
def f(x):
|
||||
x + x
|
||||
|
||||
real = f(input)
|
||||
|
||||
optimized = torch._dynamo.optimize("eager")(f)
|
||||
self.assertTrue(same(optimized(input), real))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
|
||||
gm = torch.fx.symbolic_trace(optimized)
|
||||
|
||||
|
||||
class CustomFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
@ -8,14 +8,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
TEST_SCIPY,
|
||||
skipCUDAMemoryLeakCheckIf,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyCUDA,
|
||||
@ -394,7 +387,6 @@ class TestPrims(TestCase):
|
||||
actual = execute(gm, a.mT, executor="nvfuser")
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@skipIfTorchDynamo
|
||||
def test_nvfuser_capability_context(self, device):
|
||||
# This test is to ensure that the torch calls are replaced with refs
|
||||
# based on the nvfuser+prims capability
|
||||
|
@ -153,10 +153,6 @@ dynamo_import = __name__.replace(".config", "")
|
||||
# How to import torchinductor, either torchinductor or torch.inductor
|
||||
inductor_import = dynamo_import.replace("dynamo", "inductor")
|
||||
|
||||
# If true, error with a better message if we symbolically trace over a
|
||||
# dynamo-optimized function. If false, silently suppress dynamo.
|
||||
error_on_nested_fx_trace = True
|
||||
|
||||
# root folder of the project
|
||||
if "torch." in dynamo_import:
|
||||
base_dir = dirname(dirname(dirname(abspath(__file__))))
|
||||
|
@ -14,7 +14,6 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
@ -150,14 +149,6 @@ class _TorchDynamoContext:
|
||||
|
||||
@functools.wraps(fn)
|
||||
def _fn(*args, **kwargs):
|
||||
if is_fx_tracing():
|
||||
if config.error_on_nested_fx_trace:
|
||||
raise RuntimeError(
|
||||
"Detected that you are using FX to symbolically trace "
|
||||
"a dynamo-optimized function. This is not supported at the moment."
|
||||
)
|
||||
return fn
|
||||
|
||||
on_enter()
|
||||
prior = set_eval_frame(callback)
|
||||
backend_ctx = backend_ctx_ctor()
|
||||
|
Reference in New Issue
Block a user