mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Prevent unnecessary recompile on disabled functions in the compiled frame (#161883)
Trying out a re-impl of https://github.com/pytorch/pytorch/pull/160934 The above PR led to OOM, most likely because of the cache holding to a nested function (which if not held in the cache would have been garbage collected), which holds on to cuda tensors in its closure. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161883 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c1b28d5b6
commit
e9481b6617
@ -106,7 +106,7 @@ dlrm,pass,0
|
||||
|
||||
|
||||
|
||||
doctr_det_predictor,pass,4
|
||||
doctr_det_predictor,pass,3
|
||||
|
||||
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ dlrm,pass,0
|
||||
|
||||
|
||||
|
||||
doctr_det_predictor,pass,4
|
||||
doctr_det_predictor,pass,3
|
||||
|
||||
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ dlrm,pass,0
|
||||
|
||||
|
||||
|
||||
doctr_det_predictor,pass,4
|
||||
doctr_det_predictor,pass,3
|
||||
|
||||
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ dlrm,pass,0
|
||||
|
||||
|
||||
|
||||
doctr_det_predictor,pass,4
|
||||
doctr_det_predictor,pass,3
|
||||
|
||||
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ dlrm,pass,0
|
||||
|
||||
|
||||
|
||||
doctr_det_predictor,pass,4
|
||||
doctr_det_predictor,pass,3
|
||||
|
||||
|
||||
|
||||
|
|
@ -8647,6 +8647,42 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||
self.assertEqual(seen_frames[1].name, "uwu_inline_me")
|
||||
self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
|
||||
|
||||
def test_recompile_on_disable_1(self):
|
||||
# fix https://github.com/pytorch/pytorch/issues/157399
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
@torch._dynamo.disable
|
||||
def inner(x):
|
||||
return x + 10
|
||||
|
||||
return inner(x) + 1
|
||||
|
||||
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
||||
try:
|
||||
for i in range(5):
|
||||
fn(torch.rand(2, 3))
|
||||
except torch._dynamo.exc.RecompileError as e:
|
||||
self.fail("RecompileError raised unexpectedly: " + str(e))
|
||||
|
||||
def test_recompile_on_disable_2(self):
|
||||
def outer(x, cond):
|
||||
@torch._dynamo.disable()
|
||||
def fn0(y):
|
||||
return y + 1
|
||||
|
||||
@torch._dynamo.disable()
|
||||
def fn1(y):
|
||||
return y + 2
|
||||
|
||||
if cond:
|
||||
f = fn0
|
||||
else:
|
||||
f = fn1
|
||||
|
||||
torch._dynamo.graph_break()
|
||||
# there will be a resume function here
|
||||
return f(x)
|
||||
|
||||
def test_error_on_recompile(self):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(a, b):
|
||||
|
@ -1466,11 +1466,27 @@ class SkipFunctionVariable(VariableTracker):
|
||||
|
||||
@classmethod
|
||||
def create_with_source(cls, value, source):
|
||||
if not is_wrapper_or_member_descriptor(value):
|
||||
# Use closure match guard (i.e. guard on __code__ object instead of
|
||||
# function id) to avoid guarding on nested functions.
|
||||
if inspect.getattr_static(value, "_torchdynamo_disable", False):
|
||||
# For torch._dynamo.disable function, ensure that the original
|
||||
# function is guarded. Otherwise, the else branch will guard on the
|
||||
# _dynamo.disable.__code__
|
||||
guard_on_source = source
|
||||
guard_on_value = value
|
||||
|
||||
while getattr(guard_on_value, "_torchdynamo_orig_callable", False):
|
||||
guard_on_value = guard_on_value._torchdynamo_orig_callable
|
||||
guard_on_source = AttrSource(
|
||||
guard_on_source, "_torchdynamo_orig_callable"
|
||||
)
|
||||
|
||||
guard_on_source.make_guard(GuardBuilder.FUNCTION_MATCH)
|
||||
elif not is_wrapper_or_member_descriptor(value):
|
||||
# These descriptors are not guaranteed to return the same object on
|
||||
# attribute lookup. They are unlikely to be changed, so we can skip
|
||||
# guarding them.
|
||||
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
|
||||
return cls(value, source=source)
|
||||
|
||||
def call_function(
|
||||
|
Reference in New Issue
Block a user