mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[precompile] Fix frame construction for wrapped model. (#165454)
Summary: If a function is wrapped with functools, we should not look at the wrapped function signature but rather the wrapper, since we need to construct the frame for the top level function here. Test Plan: test_decorated_function_with_functools_wrap_aot Differential Revision: D84626752 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165454 Approved by: https://github.com/yiming0416
This commit is contained in:
committed by
PyTorch MergeBot
parent
ca65023b90
commit
839f6facdb
@ -1,5 +1,6 @@
|
|||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@ -203,6 +204,39 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
|
|||||||
actual = compiled_fn(*example_inputs)
|
actual = compiled_fn(*example_inputs)
|
||||||
self.assertEqual(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
|
def test_decorated_function_with_functools_wrap_aot(self):
|
||||||
|
def check_inputs(fn):
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def _fn(*args, **kwargs):
|
||||||
|
for arg in args:
|
||||||
|
assert arg.shape[0] > 1
|
||||||
|
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return _fn
|
||||||
|
|
||||||
|
@check_inputs
|
||||||
|
def foo(x, y):
|
||||||
|
a = x + x
|
||||||
|
b = y + y
|
||||||
|
c = a + b
|
||||||
|
return c
|
||||||
|
|
||||||
|
example_inputs = (torch.ones(3), torch.ones(3))
|
||||||
|
expected = foo(*example_inputs)
|
||||||
|
|
||||||
|
def backend(gm, example_inputs):
|
||||||
|
return CustomCompiledFunction(gm, example_inputs)
|
||||||
|
|
||||||
|
with torch.compiler.set_stance("fail_on_recompile"):
|
||||||
|
compiled_fn = torch.compile(
|
||||||
|
foo,
|
||||||
|
fullgraph=True,
|
||||||
|
backend=backend,
|
||||||
|
).aot_compile((example_inputs, {}))
|
||||||
|
actual = compiled_fn(*example_inputs)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
def test_aot_compile_disable_guard_check(self):
|
def test_aot_compile_disable_guard_check(self):
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
return x + y
|
return x + y
|
||||||
|
@ -279,7 +279,7 @@ def aot_compile_fullgraph(
|
|||||||
source_info.add_code(traced_code)
|
source_info.add_code(traced_code)
|
||||||
|
|
||||||
artifacts = CompileArtifacts(
|
artifacts = CompileArtifacts(
|
||||||
signature=inspect.signature(fn),
|
signature=convert_frame._get_signature(fn),
|
||||||
bytecode=graph_capture_output.bytecode,
|
bytecode=graph_capture_output.bytecode,
|
||||||
guard_manager=check_fn.guard_manager,
|
guard_manager=check_fn.guard_manager,
|
||||||
guards_state=check_fn.guards_state,
|
guards_state=check_fn.guards_state,
|
||||||
|
@ -29,6 +29,7 @@ import cProfile
|
|||||||
import dis
|
import dis
|
||||||
import functools
|
import functools
|
||||||
import gc
|
import gc
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -975,6 +976,10 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
|||||||
raise RuntimeError(f"Unsupported model code type {mod}")
|
raise RuntimeError(f"Unsupported model code type {mod}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_signature(fn: Any) -> inspect.Signature:
|
||||||
|
return inspect.signature(fn, follow_wrapped=False)
|
||||||
|
|
||||||
|
|
||||||
def _get_frame(
|
def _get_frame(
|
||||||
mod: Any,
|
mod: Any,
|
||||||
args: tuple[Any, ...],
|
args: tuple[Any, ...],
|
||||||
@ -984,7 +989,6 @@ def _get_frame(
|
|||||||
Create a frame to trace, given a model, args, and optional kwargs.
|
Create a frame to trace, given a model, args, and optional kwargs.
|
||||||
"""
|
"""
|
||||||
import builtins
|
import builtins
|
||||||
import inspect
|
|
||||||
|
|
||||||
fn, self_opt = get_traced_fn(mod)
|
fn, self_opt = get_traced_fn(mod)
|
||||||
if self_opt is not None:
|
if self_opt is not None:
|
||||||
@ -992,7 +996,7 @@ def _get_frame(
|
|||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
signature = inspect.signature(fn)
|
signature = _get_signature(fn)
|
||||||
bound_arguments = signature.bind(*args, **kwargs)
|
bound_arguments = signature.bind(*args, **kwargs)
|
||||||
bound_arguments.apply_defaults()
|
bound_arguments.apply_defaults()
|
||||||
f_locals = bound_arguments.arguments
|
f_locals = bound_arguments.arguments
|
||||||
|
Reference in New Issue
Block a user