Revert "[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. (#157566)"

This reverts commit 8e07c9870d07c5a318ab21bb16b3fa27576851e6.

Reverted https://github.com/pytorch/pytorch/pull/157566 on behalf of https://github.com/yangw-dev because it failed an odd internal test; please reach out to metamate to fix it. D79112610 ([comment](https://github.com/pytorch/pytorch/pull/157566#issuecomment-3141840110))
This commit is contained in:
PyTorch MergeBot
2025-08-01 01:27:44 +00:00
parent 690fc9cf88
commit cb4f41e125
5 changed files with 7 additions and 91 deletions

View File

@ -8556,64 +8556,15 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertEqual(seen_frames[1].name, "uwu_inline_me")
self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
def test_recompile_on_disable_1(self):
# fix https://github.com/pytorch/pytorch/issues/157399
def test_error_on_recompile(self):
@torch.compile(backend="eager")
def fn(x):
@torch._dynamo.disable
def inner(x):
return x + 10
return inner(x) + 1
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
try:
for i in range(5):
fn(torch.rand(2, 3))
except torch._dynamo.exc.RecompileError as e:
self.fail("RecompileError raised unexpectedly: " + str(e))
def test_recompile_on_disable_2(self):
def outer(x, cond):
@torch._dynamo.disable()
def fn0(y):
return y + 1
@torch._dynamo.disable()
def fn1(y):
return y + 2
if cond:
f = fn0
else:
f = fn1
torch._dynamo.graph_break()
# there will be a resume function here
return f(x)
def fn(a, b):
return a + b
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
with self.assertRaises(torch._dynamo.exc.RecompileError):
x = torch.rand(2, 3)
self.assertEqual(outer(x, True), torch.compile(outer)(x, True))
self.assertEqual(outer(x, False), torch.compile(outer)(x, False))
def test_create_nested_fn_cache_clear(self):
def outer(x):
@torch._dynamo.disable()
def f(y):
return y + 2
return f(x) + 1
outer = torch.compile(outer)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
with self.assertRaises(torch._dynamo.exc.RecompileError):
outer(torch.randn(3, 3))
from torch._dynamo.utils import create_nested_fn_cache
create_nested_fn_cache.clear()
outer(torch.randn(3, 3))
fn(torch.rand(2, 3), torch.rand(2, 3))
fn(torch.rand(2, 3), (1, 2, 3))
def test_guards_strip_function_call(self):
from torch._dynamo.guards import strip_function_call

View File

@ -610,8 +610,6 @@ class TestAutograd(TestCase):
with disable_gc():
unpack_hook_ref = scope()
if torch._dynamo.is_compiling():
torch._dynamo.reset()
self.assertIsNone(unpack_hook_ref())
def test_will_engine_execute_node(self):

View File

@ -51,7 +51,6 @@ from .mutation_guard import GenerationTracker
from .pgo import reset_code_state
from .symbolic_convert import TensorifyState
from .utils import (
create_nested_fn_cache,
graph_break_reasons,
guard_failures,
orig_code_map,
@ -145,7 +144,6 @@ def reset() -> None:
torch._dynamo.utils.warn_once_cache.clear()
torch._dynamo.utils.user_obj_id_to_weakref.clear()
torch._C._autograd._saved_tensors_hooks_set_tracing(False)
create_nested_fn_cache.clear()
def reset_code_caches() -> None:

View File

@ -4771,22 +4771,3 @@ def get_traced_code() -> Optional[list[CodeType]]:
from torch._guards import TracingContext
return TracingContext.get_traced_code()
class CreateNestedFnCache:
    """Process-wide cache of nested function objects created by Dynamo.

    Entries are keyed by a string and map to the concrete function object,
    so repeated lookups with the same key return the identical object.
    All state lives on the class, so every instance shares one cache.
    """

    # Shared storage; class-level so it is common to all instances.
    cache: dict[str, types.FunctionType] = {}

    @classmethod
    def get(cls, key):
        # dict.get already returns None on a miss.
        return cls.cache.get(key)

    @classmethod
    def set(cls, key, value):
        cls.cache[key] = value

    @classmethod
    def clear(cls):
        cls.cache.clear()


# Module-level singleton through which callers access the cache.
create_nested_fn_cache: CreateNestedFnCache = CreateNestedFnCache()

View File

@ -62,7 +62,6 @@ from ..utils import (
check_unspec_or_constant_args,
cmp_name_to_op_mapping,
counters,
create_nested_fn_cache,
identity,
is_function,
is_wrapper_or_member_descriptor,
@ -270,11 +269,6 @@ def _create_nested_fn(
):
from types import FunctionType
# Add caching for the actual IDs of user functions so that we can use them in the ID_MATCH guard.
cache_key = str(id(code)) + str(id(closure)) + str(id(f_globals))
if create_nested_fn_cache.get(cache_key):
return create_nested_fn_cache.get(cache_key)
func = FunctionType(code, f_globals, name, defaults, closure)
func.__kwdefaults__ = kwdefaults
@ -286,7 +280,7 @@ def _create_nested_fn(
# TypeError: __annotations__ must be set to a dict object
assert annotations is None or isinstance(annotations, dict)
func.__annotations__ = annotations
create_nested_fn_cache.set(cache_key, func)
return func
@ -1433,13 +1427,7 @@ class SkipFunctionVariable(VariableTracker):
@classmethod
def create_with_source(cls, value, source):
if inspect.getattr_static(value, "_torchdynamo_orig_callable", False):
install_guard(
AttrSource(source, "_torchdynamo_orig_callable").make_guard(
GuardBuilder.FUNCTION_MATCH
)
)
elif not is_wrapper_or_member_descriptor(value):
if not is_wrapper_or_member_descriptor(value):
# These descriptors are not guaranteed to return the same object on
# attribute lookup. They are unlikely to be changed, so we can skip
# guarding them.