From 839f6facdba92f8fe90cbd50721ff9a025474969 Mon Sep 17 00:00:00 2001
From: Zhengxu Chen
Date: Wed, 15 Oct 2025 02:01:46 +0000
Subject: [PATCH] [precompile] Fix frame construction for wrapped model.
 (#165454)

Summary: If a function is wrapped with functools.wraps, we should not take
the signature of the wrapped (inner) function but rather that of the
wrapper, since the frame we construct here is for the top-level wrapper
function.

Test Plan: test_decorated_function_with_functools_wrap_aot

Differential Revision: D84626752

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165454
Approved by: https://github.com/yiming0416
---
 test/dynamo/test_aot_compile.py | 34 +++++++++++++++++++++++++++++++++
 torch/_dynamo/aot_compile.py    |  2 +-
 torch/_dynamo/convert_frame.py  |  8 ++++++--
 3 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py
index c5ff7dd70cb7..d543fe76d65c 100644
--- a/test/dynamo/test_aot_compile.py
+++ b/test/dynamo/test_aot_compile.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 
+import functools
 import inspect
 import os
 import pickle
@@ -203,6 +204,39 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
         actual = compiled_fn(*example_inputs)
         self.assertEqual(expected, actual)
 
+    def test_decorated_function_with_functools_wrap_aot(self):
+        def check_inputs(fn):
+            @functools.wraps(fn)
+            def _fn(*args, **kwargs):
+                for arg in args:
+                    assert arg.shape[0] > 1
+
+                return fn(*args, **kwargs)
+
+            return _fn
+
+        @check_inputs
+        def foo(x, y):
+            a = x + x
+            b = y + y
+            c = a + b
+            return c
+
+        example_inputs = (torch.ones(3), torch.ones(3))
+        expected = foo(*example_inputs)
+
+        def backend(gm, example_inputs):
+            return CustomCompiledFunction(gm, example_inputs)
+
+        with torch.compiler.set_stance("fail_on_recompile"):
+            compiled_fn = torch.compile(
+                foo,
+                fullgraph=True,
+                backend=backend,
+            ).aot_compile((example_inputs, {}))
+            actual = compiled_fn(*example_inputs)
+            self.assertEqual(expected, actual)
+
     def test_aot_compile_disable_guard_check(self):
         def fn(x, y):
             return x + y
diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py
index 142e244067ba..c49f54edfd3f 100644
--- a/torch/_dynamo/aot_compile.py
+++ b/torch/_dynamo/aot_compile.py
@@ -279,7 +279,7 @@ def aot_compile_fullgraph(
         source_info.add_code(traced_code)
 
     artifacts = CompileArtifacts(
-        signature=inspect.signature(fn),
+        signature=convert_frame._get_signature(fn),
         bytecode=graph_capture_output.bytecode,
         guard_manager=check_fn.guard_manager,
         guards_state=check_fn.guards_state,
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 0e73948f50b8..cf7392763e6c 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -29,6 +29,7 @@ import cProfile
 import dis
 import functools
 import gc
+import inspect
 import itertools
 import logging
 import os
@@ -975,6 +976,10 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
     raise RuntimeError(f"Unsupported model code type {mod}")
 
 
+def _get_signature(fn: Any) -> inspect.Signature:
+    return inspect.signature(fn, follow_wrapped=False)
+
+
 def _get_frame(
     mod: Any,
     args: tuple[Any, ...],
@@ -984,7 +989,6 @@
     """
     Create a frame to trace, given a model, args, and optional kwargs.
""" import builtins - import inspect fn, self_opt = get_traced_fn(mod) if self_opt is not None: @@ -992,7 +996,7 @@ def _get_frame( if kwargs is None: kwargs = {} - signature = inspect.signature(fn) + signature = _get_signature(fn) bound_arguments = signature.bind(*args, **kwargs) bound_arguments.apply_defaults() f_locals = bound_arguments.arguments