Revert "[dynamo] Cache _dynamo.disable results (#134272)"

This reverts commit a699bd11551e9755bb9238c6b82c369880789397.

Reverted https://github.com/pytorch/pytorch/pull/134272 on behalf of https://github.com/ZainRizvi due to Fails internal tests ([comment](https://github.com/pytorch/pytorch/pull/134272#issuecomment-2310649115))
This commit is contained in:
PyTorch MergeBot
2024-08-26 16:57:53 +00:00
parent e94bdc7876
commit 42955e04f1
4 changed files with 1 additions and 24 deletions

View File

@ -184,20 +184,6 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
)
def test_disable_no_recompile(self):
def gn(x):
return torch.cos(x)
@torch.compile(backend="eager")
def fn(x):
x = torch.sin(x)
x = torch._dynamo.disable(gn, recursive=True)(x)
return torch.sin(x)
with torch._dynamo.config.patch(error_on_recompile=True):
for _ in range(5):
fn(torch.randn(4))
def test_allow_in_graph(self):
cnts = torch._dynamo.testing.CompileCounter()

View File

@ -107,4 +107,3 @@ def reset_code_caches() -> None:
if code:
reset_code(code)
code_context.clear()
convert_frame.disabled_codes.clear()

View File

@ -160,7 +160,6 @@ class Tracker:
input_codes = Tracker()
output_codes = Tracker()
disabled_codes: Dict[int, Callable[..., Any]] = {}
initial_global_state: Optional[GlobalStateGuard] = None

View File

@ -10,7 +10,6 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import trace_rules, variables
from .comptime import comptime
from .convert_frame import disabled_codes
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
from .external_utils import is_compiling
@ -56,15 +55,9 @@ def disable(fn=None, recursive=True):
"""
if recursive:
if fn is not None:
id_fn = id(fn)
if cached_fn := disabled_codes.get(id_fn):
return cached_fn
fn = innermost_fn(fn)
assert callable(fn)
out = DisableContext()(fn)
disabled_codes[id_fn] = out
return out
return DisableContext()(fn)
return DisableContext()
else:
return skip(fn)