[dynamo] Prevent unnecessary recompile on disabled functions in the compiled frame (#161883)

This is a re-implementation of https://github.com/pytorch/pytorch/pull/160934.

The earlier PR led to OOMs, most likely because its cache held on to a nested function (which would otherwise have been garbage collected), and that function in turn holds on to CUDA tensors through its closure.
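
For context, a minimal, hypothetical sketch of the pattern described above (not code from either PR; it assumes a CUDA device is available): a nested function defined inside the compiled frame closes over a CUDA tensor, so any cache entry that keeps the nested function alive also keeps that tensor alive.

import torch

_cache = []  # stand-in for the cache described above (hypothetical)

def compiled_frame():
    big = torch.empty(1024, 1024, device="cuda")  # captured by inner's closure

    @torch._dynamo.disable
    def inner(x):
        return x + big

    _cache.append(inner)  # holding `inner` keeps `big` resident on the GPU
    return inner(torch.zeros(1024, 1024, device="cuda"))

for _ in range(100):
    compiled_frame()  # each call pins another ~4 MB tensor; repeated enough, this OOMs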

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161883
Approved by: https://github.com/jansel
Animesh Jain
2025-09-01 12:18:39 -07:00
committed by PyTorch MergeBot
parent 1c1b28d5b6
commit e9481b6617
7 changed files with 59 additions and 7 deletions

@@ -106,7 +106,7 @@ dlrm,pass,0
-doctr_det_predictor,pass,4
+doctr_det_predictor,pass,3

(The same one-line update, dropping the expected graph_breaks count for doctr_det_predictor from 4 to 3, is repeated in five benchmark accuracy CSV files in this commit; the CSV columns are name, accuracy, graph_breaks.)

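
The graph_breaks column in these CSVs counts how many times Dynamo split the benchmarked model into separate graphs. As a hedged illustration (not from this commit; it assumes the torch._dynamo.explain API available in recent PyTorch releases), a break count for a toy function can be observed like this:

import torch

def toy(x):
    x = x + 1
    torch._dynamo.graph_break()  # deliberately force one graph break
    return x * 2

# explain() traces the function and reports how it was partitioned;
# graph_break_count corresponds to the CSV column above.
explanation = torch._dynamo.explain(toy)(torch.rand(4))
print(explanation.graph_break_count)  # expected: 1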

@@ -8647,6 +8647,42 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertEqual(seen_frames[1].name, "uwu_inline_me")
         self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
 
+    def test_recompile_on_disable_1(self):
+        # fix https://github.com/pytorch/pytorch/issues/157399
+        @torch.compile(backend="eager")
+        def fn(x):
+            @torch._dynamo.disable
+            def inner(x):
+                return x + 10
+
+            return inner(x) + 1
+
+        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
+            try:
+                for i in range(5):
+                    fn(torch.rand(2, 3))
+            except torch._dynamo.exc.RecompileError as e:
+                self.fail("RecompileError raised unexpectedly: " + str(e))
+
+    def test_recompile_on_disable_2(self):
+        def outer(x, cond):
+            @torch._dynamo.disable()
+            def fn0(y):
+                return y + 1
+
+            @torch._dynamo.disable()
+            def fn1(y):
+                return y + 2
+
+            if cond:
+                f = fn0
+            else:
+                f = fn1
+
+            torch._dynamo.graph_break()
+            # there will be a resume function here
+            return f(x)
+
     def test_error_on_recompile(self):
         @torch.compile(backend="eager")
         def fn(a, b):

@@ -1466,11 +1466,27 @@ class SkipFunctionVariable(VariableTracker):
     @classmethod
     def create_with_source(cls, value, source):
-        if not is_wrapper_or_member_descriptor(value):
+        # Use closure match guard (i.e. guard on __code__ object instead of
+        # function id) to avoid guarding on nested functions.
+        if inspect.getattr_static(value, "_torchdynamo_disable", False):
+            # For torch._dynamo.disable function, ensure that the original
+            # function is guarded. Otherwise, the else branch will guard on the
+            # _dynamo.disable.__code__
+            guard_on_source = source
+            guard_on_value = value
+            while getattr(guard_on_value, "_torchdynamo_orig_callable", False):
+                guard_on_value = guard_on_value._torchdynamo_orig_callable
+                guard_on_source = AttrSource(
+                    guard_on_source, "_torchdynamo_orig_callable"
+                )
+            install_guard(guard_on_source.make_guard(GuardBuilder.FUNCTION_MATCH))
+        elif not is_wrapper_or_member_descriptor(value):
             # These descriptors are not guaranteed to return the same object on
             # attribute lookup. They are unlikely to be changed, so we can skip
             # guarding them.
-            install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
+            install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
         return cls(value, source=source)
 
     def call_function(
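
To illustrate why the guard target matters, a short hedged sketch (plain Python, not part of the diff above): each execution of the enclosing frame creates a fresh nested function object, so a guard on the function's identity (FUNCTION_MATCH) fails on the next call and forces a recompile, while the shared __code__ object that the closure-match guard checks stays the same. For functions wrapped by torch._dynamo.disable, the new branch first unwraps _torchdynamo_orig_callable so the guard lands on the user's original function rather than on the wrapper.

def make():
    # every call to make() produces a brand-new function object for inner
    def inner(x):
        return x + 10
    return inner

f1, f2 = make(), make()
assert f1 is not f2                  # identity-based guard would fail here -> recompile
assert f1.__code__ is f2.__code__    # the code object is shared, so a __code__ guard still holds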