From 42955e04f1c9a51391614965e6b3884a594de6df Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 16:57:53 +0000 Subject: [PATCH] Revert "[dynamo] Cache _dynamo.disable results (#134272)" This reverts commit a699bd11551e9755bb9238c6b82c369880789397. Reverted https://github.com/pytorch/pytorch/pull/134272 on behalf of https://github.com/ZainRizvi due to Fails internal tests ([comment](https://github.com/pytorch/pytorch/pull/134272#issuecomment-2310649115)) --- test/dynamo/test_decorators.py | 14 -------------- torch/_dynamo/__init__.py | 1 - torch/_dynamo/convert_frame.py | 1 - torch/_dynamo/decorators.py | 9 +-------- 4 files changed, 1 insertion(+), 24 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 463634e32163..b915e0633f7c 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -184,20 +184,6 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): all(node.target is not torch.sigmoid for node in gm1.graph.nodes) ) - def test_disable_no_recompile(self): - def gn(x): - return torch.cos(x) - - @torch.compile(backend="eager") - def fn(x): - x = torch.sin(x) - x = torch._dynamo.disable(gn, recursive=True)(x) - return torch.sin(x) - - with torch._dynamo.config.patch(error_on_recompile=True): - for _ in range(5): - fn(torch.randn(4)) - def test_allow_in_graph(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 00286a32750d..7f58ba7f7bf7 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -107,4 +107,3 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() - convert_frame.disabled_codes.clear() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 2f0da2be866d..82ddaca2fa2e 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -160,7 +160,6 @@ class Tracker: input_codes = Tracker() output_codes = Tracker() -disabled_codes: Dict[int, Callable[..., Any]] = {} initial_global_state: Optional[GlobalStateGuard] = None diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 6b557d278a65..c83d8d718a62 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -10,7 +10,6 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass from . import trace_rules, variables from .comptime import comptime -from .convert_frame import disabled_codes from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage from .external_utils import is_compiling @@ -56,15 +55,9 @@ def disable(fn=None, recursive=True): """ if recursive: if fn is not None: - id_fn = id(fn) - if cached_fn := disabled_codes.get(id_fn): - return cached_fn - fn = innermost_fn(fn) assert callable(fn) - out = DisableContext()(fn) - disabled_codes[id_fn] = out - return out + return DisableContext()(fn) return DisableContext() else: return skip(fn)