mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[dynamo] Cache _dynamo.disable results (#134272)"
This reverts commit a699bd11551e9755bb9238c6b82c369880789397. Reverted https://github.com/pytorch/pytorch/pull/134272 on behalf of https://github.com/ZainRizvi due to Fails internal tests ([comment](https://github.com/pytorch/pytorch/pull/134272#issuecomment-2310649115))
This commit is contained in:
@ -184,20 +184,6 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||||||
all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
|
all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_disable_no_recompile(self):
|
|
||||||
def gn(x):
|
|
||||||
return torch.cos(x)
|
|
||||||
|
|
||||||
@torch.compile(backend="eager")
|
|
||||||
def fn(x):
|
|
||||||
x = torch.sin(x)
|
|
||||||
x = torch._dynamo.disable(gn, recursive=True)(x)
|
|
||||||
return torch.sin(x)
|
|
||||||
|
|
||||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
|
||||||
for _ in range(5):
|
|
||||||
fn(torch.randn(4))
|
|
||||||
|
|
||||||
def test_allow_in_graph(self):
|
def test_allow_in_graph(self):
|
||||||
cnts = torch._dynamo.testing.CompileCounter()
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
|
|
||||||
|
@ -107,4 +107,3 @@ def reset_code_caches() -> None:
|
|||||||
if code:
|
if code:
|
||||||
reset_code(code)
|
reset_code(code)
|
||||||
code_context.clear()
|
code_context.clear()
|
||||||
convert_frame.disabled_codes.clear()
|
|
||||||
|
@ -160,7 +160,6 @@ class Tracker:
|
|||||||
|
|
||||||
input_codes = Tracker()
|
input_codes = Tracker()
|
||||||
output_codes = Tracker()
|
output_codes = Tracker()
|
||||||
disabled_codes: Dict[int, Callable[..., Any]] = {}
|
|
||||||
|
|
||||||
initial_global_state: Optional[GlobalStateGuard] = None
|
initial_global_state: Optional[GlobalStateGuard] = None
|
||||||
|
|
||||||
|
@ -10,7 +10,6 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
|||||||
|
|
||||||
from . import trace_rules, variables
|
from . import trace_rules, variables
|
||||||
from .comptime import comptime
|
from .comptime import comptime
|
||||||
from .convert_frame import disabled_codes
|
|
||||||
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
|
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
|
||||||
from .exc import IncorrectUsage
|
from .exc import IncorrectUsage
|
||||||
from .external_utils import is_compiling
|
from .external_utils import is_compiling
|
||||||
@ -56,15 +55,9 @@ def disable(fn=None, recursive=True):
|
|||||||
"""
|
"""
|
||||||
if recursive:
|
if recursive:
|
||||||
if fn is not None:
|
if fn is not None:
|
||||||
id_fn = id(fn)
|
|
||||||
if cached_fn := disabled_codes.get(id_fn):
|
|
||||||
return cached_fn
|
|
||||||
|
|
||||||
fn = innermost_fn(fn)
|
fn = innermost_fn(fn)
|
||||||
assert callable(fn)
|
assert callable(fn)
|
||||||
out = DisableContext()(fn)
|
return DisableContext()(fn)
|
||||||
disabled_codes[id_fn] = out
|
|
||||||
return out
|
|
||||||
return DisableContext()
|
return DisableContext()
|
||||||
else:
|
else:
|
||||||
return skip(fn)
|
return skip(fn)
|
||||||
|
Reference in New Issue
Block a user