[dynamo] [guard] Add caching for functions defined inside torch.compile.disable regions to avoid unnecessary recompilation. (#160934)

Fixes #157399
Cherry-pick of d6a5c03

@mlazos

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160934
Approved by: https://github.com/mlazos
Authored by thenumberouscode on 2025-08-19 06:01:22 +00:00, committed by PyTorch MergeBot.
Parent: 29afde2020
Commit: 8f31aa97a3
5 changed files with 91 additions and 7 deletions
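
Context for the change: a function defined under torch._dynamo.disable inside a compiled region is rebuilt as a fresh FunctionType on every execution of the enclosing frame, so the ID_MATCH guard installed on it can never hold and every call recompiles. A minimal sketch of the failure mode from #157399 (it mirrors test_recompile_on_disable_1 below; with this patch the loop triggers no recompilation):

    import torch

    @torch.compile(backend="eager")
    def fn(x):
        @torch._dynamo.disable
        def inner(x):
            return x + 10

        return inner(x) + 1

    # Before this patch each call re-created `inner`, failed the ID_MATCH
    # guard on it, and recompiled; with the cache the same function object
    # is returned every time and the guard holds.
    for _ in range(5):
        fn(torch.rand(2, 3))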

test/dynamo/test_misc.py

@@ -8628,15 +8628,64 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertEqual(seen_frames[1].name, "uwu_inline_me")
         self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
 
-    def test_error_on_recompile(self):
+    def test_recompile_on_disable_1(self):
+        # fix https://github.com/pytorch/pytorch/issues/157399
         @torch.compile(backend="eager")
-        def fn(a, b):
-            return a + b
+        def fn(x):
+            @torch._dynamo.disable
+            def inner(x):
+                return x + 10
+
+            return inner(x) + 1
 
         with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
+            try:
+                for i in range(5):
+                    fn(torch.rand(2, 3))
+            except torch._dynamo.exc.RecompileError as e:
+                self.fail("RecompileError raised unexpectedly: " + str(e))
+
+    def test_recompile_on_disable_2(self):
+        def outer(x, cond):
+            @torch._dynamo.disable()
+            def fn0(y):
+                return y + 1
+
+            @torch._dynamo.disable()
+            def fn1(y):
+                return y + 2
+
+            if cond:
+                f = fn0
+            else:
+                f = fn1
+
+            torch._dynamo.graph_break()
+            # there will be a resume function here
+            return f(x)
+
+        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
             with self.assertRaises(torch._dynamo.exc.RecompileError):
-                fn(torch.rand(2, 3), torch.rand(2, 3))
-                fn(torch.rand(2, 3), (1, 2, 3))
+                x = torch.rand(2, 3)
+                self.assertEqual(outer(x, True), torch.compile(outer)(x, True))
+                self.assertEqual(outer(x, False), torch.compile(outer)(x, False))
+
+    def test_create_nested_fn_cache_clear(self):
+        def outer(x):
+            @torch._dynamo.disable()
+            def f(y):
+                return y + 2
+
+            return f(x) + 1
+
+        outer = torch.compile(outer)
+        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
+            with self.assertRaises(torch._dynamo.exc.RecompileError):
+                outer(torch.randn(3, 3))
+                from torch._dynamo.utils import create_nested_fn_cache
+
+                create_nested_fn_cache.clear()
+                outer(torch.randn(3, 3))
+
+    def test_error_on_recompile(self):
+        @torch.compile(backend="eager")
+        def fn(a, b):
+            return a + b
+
+        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
+            with self.assertRaises(torch._dynamo.exc.RecompileError):
+                fn(torch.rand(2, 3), torch.rand(2, 3))
+                fn(torch.rand(2, 3), (1, 2, 3))
 
     def test_guards_strip_function_call(self):
        from torch._dynamo.guards import strip_function_call
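
Note on the tests above: test_recompile_on_disable_1 asserts that repeated calls no longer recompile, while test_recompile_on_disable_2 still expects a RecompileError, because two distinct disabled functions reach the resume function after the graph break, so its guard legitimately fails. To watch for these recompiles outside the suite, a sketch using the public logging API (equivalent to running under TORCH_LOGS="recompiles"):

    import torch

    # prints a guard-failure explanation whenever Dynamo recompiles a frame
    torch._logging.set_logs(recompiles=True)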

test/test_autograd.py

@@ -614,6 +614,8 @@ class TestAutograd(TestCase):
         with disable_gc():
             unpack_hook_ref = scope()
+            if torch._dynamo.is_compiling():
+                torch._dynamo.reset()
 
         self.assertIsNone(unpack_hook_ref())
 
     def test_will_engine_execute_node(self):

torch/_dynamo/__init__.py

@@ -51,6 +51,7 @@ from .mutation_guard import GenerationTracker
 from .pgo import reset_code_state
 from .symbolic_convert import TensorifyState
 from .utils import (
+    create_nested_fn_cache,
     graph_break_reasons,
     guard_failures,
     orig_code_map,
@@ -144,6 +145,7 @@ def reset() -> None:
         torch._dynamo.utils.warn_once_cache.clear()
         torch._dynamo.utils.user_obj_id_to_weakref.clear()
         torch._C._autograd._saved_tensors_hooks_set_tracing(False)
+        create_nested_fn_cache.clear()
 
 
 def reset_code_caches() -> None:
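
Since the cache holds strong references to the created FunctionType objects, it is wired into torch._dynamo.reset() above so test isolation keeps working. A minimal sketch of the effect (assuming only the public reset entry point plus the internals added in this commit):

    import torch
    from torch._dynamo.utils import create_nested_fn_cache

    torch._dynamo.reset()  # now also empties the nested-function cache
    assert len(create_nested_fn_cache.cache) == 0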

torch/_dynamo/utils.py

@@ -4836,3 +4836,22 @@ def get_traced_code() -> Optional[list[CodeType]]:
     from torch._guards import TracingContext
 
     return TracingContext.get_traced_code()
+
+
+class CreateNestedFnCache:
+    cache: dict[str, types.FunctionType] = {}
+
+    @classmethod
+    def get(cls, key: str) -> Optional[types.FunctionType]:
+        return cls.cache.get(key, None)
+
+    @classmethod
+    def set(cls, key: str, value: types.FunctionType) -> None:
+        cls.cache[key] = value
+
+    @classmethod
+    def clear(cls: type[CreateNestedFnCache]) -> None:
+        cls.cache.clear()
+
+
+create_nested_fn_cache: CreateNestedFnCache = CreateNestedFnCache()
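
For reference, a sketch of how this cache is driven by the call sites added in torch/_dynamo/variables/functions.py below. cached_nested_fn is a hypothetical stand-in for the real _create_nested_fn; the key concatenates the id()s of the code object, closure, and globals dict, so one source-level nested function maps back to a single FunctionType:

    from types import FunctionType
    from torch._dynamo.utils import create_nested_fn_cache

    def cached_nested_fn(code, f_globals, name, defaults, closure):
        key = str(id(code)) + str(id(closure)) + str(id(f_globals))
        cached = create_nested_fn_cache.get(key)
        if cached is not None:
            return cached  # reuse: keeps id(func) stable across frame runs
        func = FunctionType(code, f_globals, name, defaults, closure)
        create_nested_fn_cache.set(key, func)
        return func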

torch/_dynamo/variables/functions.py

@@ -69,6 +69,7 @@ from ..utils import (
     check_unspec_or_constant_args,
     cmp_name_to_op_mapping,
     counters,
+    create_nested_fn_cache,
     identity,
     is_function,
     is_wrapper_or_member_descriptor,
@@ -276,6 +277,11 @@ def _create_nested_fn(
 ):
     from types import FunctionType
 
+    # Add caching for the actual IDs of user functions so that we can use them in the ID_MATCH guard.
+    cache_key = str(id(code)) + str(id(closure)) + str(id(f_globals))
+    if create_nested_fn_cache.get(cache_key):
+        return create_nested_fn_cache.get(cache_key)
+
     func = FunctionType(code, f_globals, name, defaults, closure)
     func.__kwdefaults__ = kwdefaults
@@ -287,7 +293,7 @@ def _create_nested_fn(
     # TypeError: __annotations__ must be set to a dict object
     assert annotations is None or isinstance(annotations, dict)
     func.__annotations__ = annotations
-
+    create_nested_fn_cache.set(cache_key, func)
     return func
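
Why the cache is needed at all: in CPython, executing a nested def allocates a new function object every time the enclosing frame runs, even though the underlying code object is shared, which is exactly what an ID_MATCH guard trips over. A stand-alone illustration:

    def outer():
        def inner():
            return 1

        return inner

    print(outer() is outer())                    # False: fresh FunctionType per call
    print(outer().__code__ is outer().__code__)  # True: the code object is shared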
@@ -1466,7 +1472,13 @@ class SkipFunctionVariable(VariableTracker):
     @classmethod
     def create_with_source(cls, value, source):
-        if not is_wrapper_or_member_descriptor(value):
+        if inspect.getattr_static(value, "_torchdynamo_orig_callable", False):
+            install_guard(
+                AttrSource(source, "_torchdynamo_orig_callable").make_guard(
+                    GuardBuilder.FUNCTION_MATCH
+                )
+            )
+        elif not is_wrapper_or_member_descriptor(value):
             # These descriptors are not guaranteed to return the same object on
             # attribute lookup. They are unlikely to be changed, so we can skip
             # guarding them.
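
The branch above prefers guarding on the wrapped function rather than the wrapper itself. A small sketch of the attribute it keys on (behavior of torch._dynamo.disable as of this commit; the final line is expected to print True):

    import torch

    def inner(x):
        return x + 10

    disabled = torch._dynamo.disable(inner)

    # `disable` returns a wrapper that records the original callable;
    # FUNCTION_MATCH on this stable attribute avoids depending on the
    # wrapper object's identity, which can change across frame executions.
    print(disabled._torchdynamo_orig_callable is inner)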