diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 01762c5f5f29..1d199fe8ea66 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -106,7 +106,7 @@ dlrm,pass,0 -doctr_det_predictor,pass,4 +doctr_det_predictor,pass,3 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 3e4c3caa1ca9..20cad351b127 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -106,7 +106,7 @@ dlrm,pass,0 -doctr_det_predictor,pass,4 +doctr_det_predictor,pass,3 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 63d0efa38f63..2b2c1a504647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -106,7 +106,7 @@ dlrm,pass,0 -doctr_det_predictor,pass,4 +doctr_det_predictor,pass,3 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 01762c5f5f29..1d199fe8ea66 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -106,7 +106,7 @@ dlrm,pass,0 -doctr_det_predictor,pass,4 +doctr_det_predictor,pass,3 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index fbd169539ab7..e41018657c0e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -106,7 +106,7 @@ dlrm,pass,0 -doctr_det_predictor,pass,4 +doctr_det_predictor,pass,3 diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 62802522767d..1d746a093dc4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -8647,6 +8647,42 @@ utils_device.CURRENT_DEVICE == None""".split("\n"): self.assertEqual(seen_frames[1].name, "uwu_inline_me") self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)") + def test_recompile_on_disable_1(self): + # fix https://github.com/pytorch/pytorch/issues/157399 + @torch.compile(backend="eager") + def fn(x): + @torch._dynamo.disable + def inner(x): + return x + 10 + + return inner(x) + 1 + + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + try: + for i in range(5): + fn(torch.rand(2, 3)) + except torch._dynamo.exc.RecompileError as e: + self.fail("RecompileError raised unexpectedly: " + str(e)) + + def test_recompile_on_disable_2(self): + def outer(x, cond): + @torch._dynamo.disable() + def fn0(y): + return y + 1 + + @torch._dynamo.disable() + def fn1(y): + return y + 2 + + if cond: + f = fn0 + else: + f = fn1 + + torch._dynamo.graph_break() + # there will be a resume function here + return f(x) + def test_error_on_recompile(self): @torch.compile(backend="eager") def fn(a, b): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 6eb7d0666cd8..9a643fb81922 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1466,11 +1466,27 @@ class SkipFunctionVariable(VariableTracker): @classmethod def create_with_source(cls, value, source): - if not is_wrapper_or_member_descriptor(value): + # Use closure match guard (i.e. guard on __code__ object instead of + # function id) to avoid guarding on nested functions. + if inspect.getattr_static(value, "_torchdynamo_disable", False): + # For torch._dynamo.disable function, ensure that the original + # function is guarded. Otherwise, the else branch will guard on the + # _dynamo.disable.__code__ + guard_on_source = source + guard_on_value = value + + while getattr(guard_on_value, "_torchdynamo_orig_callable", False): + guard_on_value = guard_on_value._torchdynamo_orig_callable + guard_on_source = AttrSource( + guard_on_source, "_torchdynamo_orig_callable" + ) + + guard_on_source.make_guard(GuardBuilder.FUNCTION_MATCH) + elif not is_wrapper_or_member_descriptor(value): # These descriptors are not guaranteed to return the same object on # attribute lookup. They are unlikely to be changed, so we can skip # guarding them. - install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) def call_function(