[precompile] Fix frame construction for wrapped model. (#165454)

Summary: If a function is wrapped with functools, we should not look at the wrapped function signature but rather the wrapper, since we need to construct the frame for the top level function here.

Test Plan: test_decorated_function_with_functools_wrap_aot

Differential Revision: D84626752

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165454
Approved by: https://github.com/yiming0416
This commit is contained in:
Zhengxu Chen
2025-10-15 02:01:46 +00:00
committed by PyTorch MergeBot
parent ca65023b90
commit 839f6facdb
3 changed files with 41 additions and 3 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import functools
import inspect
import os
import pickle
@ -203,6 +204,39 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
actual = compiled_fn(*example_inputs)
self.assertEqual(expected, actual)
def test_decorated_function_with_functools_wrap_aot(self):
def check_inputs(fn):
@functools.wraps(fn)
def _fn(*args, **kwargs):
for arg in args:
assert arg.shape[0] > 1
return fn(*args, **kwargs)
return _fn
@check_inputs
def foo(x, y):
a = x + x
b = y + y
c = a + b
return c
example_inputs = (torch.ones(3), torch.ones(3))
expected = foo(*example_inputs)
def backend(gm, example_inputs):
return CustomCompiledFunction(gm, example_inputs)
with torch.compiler.set_stance("fail_on_recompile"):
compiled_fn = torch.compile(
foo,
fullgraph=True,
backend=backend,
).aot_compile((example_inputs, {}))
actual = compiled_fn(*example_inputs)
self.assertEqual(expected, actual)
def test_aot_compile_disable_guard_check(self):
def fn(x, y):
return x + y

View File

@ -279,7 +279,7 @@ def aot_compile_fullgraph(
source_info.add_code(traced_code)
artifacts = CompileArtifacts(
signature=inspect.signature(fn),
signature=convert_frame._get_signature(fn),
bytecode=graph_capture_output.bytecode,
guard_manager=check_fn.guard_manager,
guards_state=check_fn.guards_state,

View File

@ -29,6 +29,7 @@ import cProfile
import dis
import functools
import gc
import inspect
import itertools
import logging
import os
@ -975,6 +976,10 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
raise RuntimeError(f"Unsupported model code type {mod}")
def _get_signature(fn: Any) -> inspect.Signature:
return inspect.signature(fn, follow_wrapped=False)
def _get_frame(
mod: Any,
args: tuple[Any, ...],
@ -984,7 +989,6 @@ def _get_frame(
Create a frame to trace, given a model, args, and optional kwargs.
"""
import builtins
import inspect
fn, self_opt = get_traced_fn(mod)
if self_opt is not None:
@ -992,7 +996,7 @@ def _get_frame(
if kwargs is None:
kwargs = {}
signature = inspect.signature(fn)
signature = _get_signature(fn)
bound_arguments = signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
f_locals = bound_arguments.arguments