[dynamo][compile-time] Cache the function signature to speedup inlining (#153396)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153396
Approved by: https://github.com/jansel, https://github.com/StrongerXi
ghstack dependencies: #153333
Author: Animesh Jain
Date: 2025-05-13 23:00:29 -07:00
Committed by: PyTorch MergeBot
Parent: 2344eca5eb
Commit: 8f3d7972ad
2 changed files with 109 additions and 41 deletions
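
In short, the change caches a small, precomputed description of each function's signature on the function object itself (via a WeakKeyDictionary, so entries are dropped when the function is garbage collected) and reuses it every time Dynamo inlines a call to that function, instead of rebuilding a fake function and calling inspect.signature(...).bind(...) per call. The cached FunctionSpec relies on CPython's co_varnames layout: positional parameters first, then keyword-only names, then the *args/**kwargs slots. A standalone illustration of the code-object fields it reads (not PR code, just the relevant CPython attributes):

# Standalone illustration (not PR code): the code-object fields that the new
# FunctionSpec below reads, and the order CPython stores parameter names in.
CO_VARARGS = 0x04
CO_VARKEYWORDS = 0x08

def example(a, b, /, c, d=1, *rest, e, f=2, **extra):
    pass

code = example.__code__
print(code.co_varnames)         # ('a', 'b', 'c', 'd', 'e', 'f', 'rest', 'extra')
print(code.co_posonlyargcount)  # 2 -> positional-only: a, b
print(code.co_argcount)         # 4 -> all positional: a, b, c, d
print(code.co_kwonlyargcount)   # 2 -> keyword-only: e, f
print(bool(code.co_flags & CO_VARARGS),      # True -> has *rest
      bool(code.co_flags & CO_VARKEYWORDS))  # True -> has **extra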


@@ -6,7 +6,7 @@ add_loop_eager_dynamic,compile_time_instruction_count,5928000000,0.025
-add_loop_inductor,compile_time_instruction_count,29400000000,0.015
+add_loop_inductor,compile_time_instruction_count,29380000000,0.015
@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,954500000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,945000000,0.015
@@ -54,11 +54,11 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2079000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5940000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
@@ -74,4 +74,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10270000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10290000000,0.015



@@ -30,9 +30,11 @@ import itertools
 import sys
 import types
 from collections.abc import Sequence
+from types import FunctionType
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
 from typing_extensions import Never
 from unittest.mock import patch
+from weakref import WeakKeyDictionary
 import torch
@@ -88,6 +90,104 @@ if TYPE_CHECKING:
 _F = TypeVar("_F", bound=Callable)
 CO_VARARGS = 0x04
 CO_VARKEYWORDS = 0x08
+
+# Module-level cache keyed by the function object
+_spec_cache = WeakKeyDictionary()
+
+
+class FunctionSpec:
+    def __init__(self, func: FunctionType):
+        code = func.__code__
+        vn = code.co_varnames
+
+        self.posonly_count = code.co_posonlyargcount
+        self.arg_count = code.co_argcount
+        self.kwonly_count = code.co_kwonlyargcount
+
+        self.posonly_names = vn[: self.posonly_count]
+        self.pos_or_kw_names = vn[self.posonly_count : self.arg_count]
+        self.all_pos_names = self.posonly_names + self.pos_or_kw_names
+        self.kwonly_names = vn[self.arg_count : self.arg_count + self.kwonly_count]
+
+        off = self.arg_count + self.kwonly_count
+        self.varargs_name = vn[off] if code.co_flags & CO_VARARGS else None
+        off += 1 if self.varargs_name else 0
+        self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None
+
+    def update_defaults(self, func: FunctionType):
+        # Defaults can change from function call to function call. So re-update
+        # them on every call.
+        self.defaults = func.__defaults__ or ()
+        self.kwdefaults = func.__kwdefaults__ or {}
+
+        # Map positional default names → their index in self.defaults
+        self.pos_default_map = dict(
+            zip(self.all_pos_names[-len(self.defaults) :], range(len(self.defaults)))
+        )
+
+
+def _get_spec(func: FunctionType) -> FunctionSpec:
+    spec = _spec_cache.get(func)
+    if spec is None:
+        spec = FunctionSpec(func)
+        _spec_cache[func] = spec
+    return spec
+
+
+def bind_args_cached(func, tx, fn_source, args, kwargs):
+    spec = _get_spec(func)
+    spec.update_defaults(func)
+
+    ba = {}
+    rem_kw = dict(kwargs)
+
+    # 1) Bind all positional (pos-only + pos-or-kw)
+    for i, name in enumerate(spec.all_pos_names):
+        if i < len(args):
+            ba[name] = wrap_bound_arg(tx, args[i])
+        elif name in rem_kw:
+            if name in spec.posonly_names:
+                raise TypeError(f"{name} is positional-only")
+            ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
+        elif name in spec.pos_default_map:
+            idx = spec.pos_default_map[name]
+            default_source = None
+            if fn_source:
+                default_source = DefaultsSource(fn_source, idx)
+            ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source)
+        else:
+            raise TypeError(f"Missing required positional argument: {name}")
+
+    # 2) *args
+    extra = args[len(spec.all_pos_names) :]
+    if spec.varargs_name:
+        ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra))
+    elif extra:
+        raise TypeError(
+            f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}"
+        )
+
+    # 3) Keyword-only
+    for name in spec.kwonly_names:
+        if name in rem_kw:
+            ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
+        elif name in spec.kwdefaults:
+            kwdefault_source = None
+            if fn_source:
+                kwdefault_source = DefaultsSource(fn_source, name, is_kw=True)
+            ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source)
+        else:
+            raise TypeError(f"Missing required keyword-only argument: {name}")
+
+    # 4) **kwargs
+    if spec.varkw_name:
+        ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw)
+    elif rem_kw:
+        raise TypeError(f"Unexpected keyword arguments: {list(rem_kw)}")
+
+    return ba
+
+
 def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
@@ -278,46 +378,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
         this function, create new bindings for initial locals.
         """
         assert not self.is_constant
-        root_tx = parent.output.root_tx
-        wrap = functools.partial(wrap_bound_arg, tx=root_tx)
         fn: types.FunctionType = self.fn
-        defaults = fn.__defaults__ or []
-        defaults_sources = [
-            None if self.source is None else DefaultsSource(self.source, idx)
-            for idx, _ in enumerate(defaults)
-        ]
-        fake_func = types.FunctionType(
-            fn.__code__,
-            fn.__globals__,
-            fn.__name__,
-            tuple(
-                [
-                    wrap(val=arg, source=source)
-                    for arg, source in zip(defaults, defaults_sources)
-                ]
-            ),
-            fn.__closure__,
-        )
-        if fn.__kwdefaults__:
-            kwdefaults_sources = {
-                k: (
-                    None
-                    if self.source is None
-                    else DefaultsSource(self.source, k, is_kw=True)
-                )
-                for k in fn.__kwdefaults__
-            }
-            fake_func.__kwdefaults__ = {
-                k: wrap(val=v, source=kwdefaults_sources[k])
-                for k, v in fn.__kwdefaults__.items()
-            }
-        bound = inspect.signature(fake_func).bind(*args, **kwargs)
-        bound.apply_defaults()
-        result = dict(bound.arguments.items())
+        if not isinstance(fn, FunctionType):
+            raise TypeError("Only supports regular Python functions.")
+        root_tx = parent.output.root_tx
+        result = bind_args_cached(fn, root_tx, self.source, args, kwargs)
         wrap_args_kwargs(root_tx, result)
         init_cellvars(parent, result, fn.__code__)
         closure = self.fn.__closure__ or ()
         assert len(closure) == len(self.fn.__code__.co_freevars)
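
For a rough, standalone sense of where the compile-time win comes from (not part of the PR; toy_bind below is a deliberately simplified, hypothetical stand-in for spec-based binding, and absolute timings vary by machine): parsing the signature once and reusing it avoids re-running inspect.signature() and rebuilding default bindings on every inlined call.

# Standalone timing sketch (illustrative only). toy_bind is a hypothetical,
# simplified stand-in for spec-based binding; it is not the PR's bind_args_cached.
import inspect
import timeit
from weakref import WeakKeyDictionary

_cache = WeakKeyDictionary()

def toy_bind(func, args, kwargs):
    # Parse co_varnames once per function object, then bind with dict operations.
    names = _cache.get(func)
    if names is None:
        code = func.__code__
        names = code.co_varnames[: code.co_argcount]
        _cache[func] = names
    bound = dict(zip(names, args))
    bound.update(kwargs)
    return bound

def f(a, b, c=1, *, d=2):
    return a

args, kwargs = (1, 2), {"d": 4}
per_call = timeit.timeit(
    lambda: inspect.signature(f).bind(*args, **kwargs).arguments, number=10_000
)
cached = timeit.timeit(lambda: toy_bind(f, args, kwargs), number=10_000)
print(f"inspect.signature per call: {per_call:.3f}s  cached spec: {cached:.3f}s")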