mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo][compile-time] Cache the function signature to speedup inlining (#153396)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153396 Approved by: https://github.com/jansel, https://github.com/StrongerXi ghstack dependencies: #153333
This commit is contained in:
committed by
PyTorch MergeBot
parent
2344eca5eb
commit
8f3d7972ad
@ -6,7 +6,7 @@ add_loop_eager_dynamic,compile_time_instruction_count,5928000000,0.025
|
||||
|
||||
|
||||
|
||||
add_loop_inductor,compile_time_instruction_count,29400000000,0.015
|
||||
add_loop_inductor,compile_time_instruction_count,29380000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,954500000,0.015
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,945000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -54,11 +54,11 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2079000000,0.015
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5940000000,0.015
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -74,4 +74,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10270000000,0.015
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10290000000,0.015
|
||||
|
|
@ -30,9 +30,11 @@ import itertools
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Never
|
||||
from unittest.mock import patch
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import torch
|
||||
|
||||
@ -88,6 +90,104 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
_F = TypeVar("_F", bound=Callable)
|
||||
# Bit masks tested against ``code.co_flags`` to detect ``*args`` / ``**kwargs``
# parameters (same values as ``inspect.CO_VARARGS`` / ``inspect.CO_VARKEYWORDS``).
CO_VARARGS = 1 << 2  # 0x04
CO_VARKEYWORDS = 1 << 3  # 0x08


# Module-level cache keyed by the function object. Keys are held weakly, so an
# entry does not keep its function alive.
_spec_cache: "WeakKeyDictionary[FunctionType, FunctionSpec]" = WeakKeyDictionary()
|
||||
|
||||
|
||||
class FunctionSpec:
    """Parameter layout of a Python function, derived from its code object.

    The constructor records everything that depends only on ``func.__code__``
    (parameter counts and names). Default values live on the function object
    and may be reassigned, so they are captured separately by
    ``update_defaults``, which must be called before ``defaults``,
    ``kwdefaults`` or ``pos_default_map`` are read.
    """

    def __init__(self, func: FunctionType):
        code_obj = func.__code__
        local_names = code_obj.co_varnames
        flags = code_obj.co_flags

        n_posonly = code_obj.co_posonlyargcount
        n_args = code_obj.co_argcount
        n_kwonly = code_obj.co_kwonlyargcount

        self.posonly_count = n_posonly
        self.arg_count = n_args
        self.kwonly_count = n_kwonly

        # ``co_varnames`` is sliced as: positional-only, positional-or-keyword,
        # keyword-only, then the ``*args`` name and the ``**kwargs`` name when
        # the corresponding flag is set.
        self.posonly_names = local_names[:n_posonly]
        self.pos_or_kw_names = local_names[n_posonly:n_args]
        self.all_pos_names = self.posonly_names + self.pos_or_kw_names

        cursor = n_args + n_kwonly
        self.kwonly_names = local_names[n_args:cursor]

        if flags & CO_VARARGS:
            self.varargs_name = local_names[cursor]
        else:
            self.varargs_name = None
        if self.varargs_name:
            # The ``**kwargs`` name, if any, sits right after the ``*args`` name.
            cursor += 1

        if flags & CO_VARKEYWORDS:
            self.varkw_name = local_names[cursor]
        else:
            self.varkw_name = None

    def update_defaults(self, func: FunctionType):
        """Re-read ``func``'s default values.

        Defaults can change from function call to function call, so they are
        refreshed on every call instead of being cached with the layout.
        """
        positional_defaults = func.__defaults__ or ()
        self.defaults = positional_defaults
        self.kwdefaults = func.__kwdefaults__ or {}

        # Map positional-default names -> their index in self.defaults.
        # Defaults belong to the trailing positional parameters. Zipping
        # against ``range(count)`` keeps the map empty when there are no
        # defaults (the ``[-0:]`` slice alone would select every name).
        count = len(positional_defaults)
        trailing_names = self.all_pos_names[-count:]
        self.pos_default_map = {
            name: index for name, index in zip(trailing_names, range(count))
        }
|
||||
|
||||
|
||||
def _get_spec(func: FunctionType) -> FunctionSpec:
    """Return the ``FunctionSpec`` for ``func``, building and caching it on first use."""
    cached = _spec_cache.get(func)
    if cached is not None:
        return cached

    fresh = FunctionSpec(func)
    _spec_cache[func] = fresh
    return fresh
|
||||
|
||||
|
||||
def bind_args_cached(func, tx, fn_source, args, kwargs):
    """Bind ``args``/``kwargs`` to the parameters of ``func``.

    Uses the cached ``FunctionSpec`` for ``func`` instead of building an
    ``inspect.Signature`` on every call. Every bound value is passed through
    ``wrap_bound_arg``.

    Args:
        func: Plain Python function whose parameters are being bound.
        tx: Translator object forwarded to ``wrap_bound_arg``.
        fn_source: Source of ``func``, or a falsy value. When truthy, default
            values are wrapped together with a ``DefaultsSource`` built from it.
        args: Positional arguments of the call.
        kwargs: Mapping of keyword arguments of the call (not mutated).

    Returns:
        Dict mapping parameter name to wrapped value, in the order: positional
        parameters, ``*args``, keyword-only parameters, ``**kwargs``.

    Raises:
        TypeError: If the arguments do not match the signature (missing or
            surplus arguments, a positional-only parameter passed by keyword
            to a function without ``**kwargs``, or an argument supplied both
            positionally and by keyword).
    """
    spec = _get_spec(func)
    # Defaults live on the function object and may have been reassigned since
    # the spec was cached, so refresh them for this call.
    spec.update_defaults(func)
    ba = {}
    rem_kw = dict(kwargs)
    num_args = len(args)
    posonly_count = spec.posonly_count
    accepts_varkw = bool(spec.varkw_name)

    # 1) Bind all positional (pos-only + pos-or-kw)
    for i, name in enumerate(spec.all_pos_names):
        is_posonly = i < posonly_count

        if i < num_args:
            # A positional-or-keyword parameter must not be supplied twice.
            # (For a positional-only parameter the keyword is unrelated and is
            # handled with the remaining keywords below.)
            if not is_posonly and name in rem_kw:
                raise TypeError(f"Multiple values for argument: {name}")
            ba[name] = wrap_bound_arg(tx, args[i])
            continue

        if name in rem_kw:
            if not is_posonly:
                ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
                continue
            if not accepts_varkw:
                raise TypeError(f"{name} is positional-only")
            # Positional-only parameter plus **kwargs: the keyword belongs to
            # **kwargs, and the parameter itself falls back to its default
            # (or is reported as missing) below.

        if name in spec.pos_default_map:
            idx = spec.pos_default_map[name]
            default_source = None
            if fn_source:
                default_source = DefaultsSource(fn_source, idx)
            ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source)
        else:
            raise TypeError(f"Missing required positional argument: {name}")

    # 2) *args
    extra = args[len(spec.all_pos_names) :]
    if spec.varargs_name:
        ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra))
    elif extra:
        raise TypeError(
            f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}"
        )

    # 3) Keyword-only
    for name in spec.kwonly_names:
        if name in rem_kw:
            ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
        elif name in spec.kwdefaults:
            kwdefault_source = None
            if fn_source:
                kwdefault_source = DefaultsSource(fn_source, name, is_kw=True)
            ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source)
        else:
            raise TypeError(f"Missing required keyword-only argument: {name}")

    # 4) **kwargs
    if spec.varkw_name:
        ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw)
    elif rem_kw:
        raise TypeError(f"Unexpected keyword arguments: {list(rem_kw)}")

    return ba
|
||||
|
||||
|
||||
def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
|
||||
@ -278,46 +378,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
||||
this function, create new bindings for initial locals.
|
||||
"""
|
||||
assert not self.is_constant
|
||||
root_tx = parent.output.root_tx
|
||||
wrap = functools.partial(wrap_bound_arg, tx=root_tx)
|
||||
|
||||
fn: types.FunctionType = self.fn
|
||||
defaults = fn.__defaults__ or []
|
||||
defaults_sources = [
|
||||
None if self.source is None else DefaultsSource(self.source, idx)
|
||||
for idx, _ in enumerate(defaults)
|
||||
]
|
||||
fake_func = types.FunctionType(
|
||||
fn.__code__,
|
||||
fn.__globals__,
|
||||
fn.__name__,
|
||||
tuple(
|
||||
[
|
||||
wrap(val=arg, source=source)
|
||||
for arg, source in zip(defaults, defaults_sources)
|
||||
]
|
||||
),
|
||||
fn.__closure__,
|
||||
)
|
||||
if fn.__kwdefaults__:
|
||||
kwdefaults_sources = {
|
||||
k: (
|
||||
None
|
||||
if self.source is None
|
||||
else DefaultsSource(self.source, k, is_kw=True)
|
||||
)
|
||||
for k in fn.__kwdefaults__
|
||||
}
|
||||
fake_func.__kwdefaults__ = {
|
||||
k: wrap(val=v, source=kwdefaults_sources[k])
|
||||
for k, v in fn.__kwdefaults__.items()
|
||||
}
|
||||
|
||||
bound = inspect.signature(fake_func).bind(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
result = dict(bound.arguments.items())
|
||||
if not isinstance(fn, FunctionType):
|
||||
raise TypeError("Only supports regular Python functions.")
|
||||
root_tx = parent.output.root_tx
|
||||
result = bind_args_cached(fn, root_tx, self.source, args, kwargs)
|
||||
|
||||
wrap_args_kwargs(root_tx, result)
|
||||
init_cellvars(parent, result, fn.__code__)
|
||||
closure = self.fn.__closure__ or ()
|
||||
assert len(closure) == len(self.fn.__code__.co_freevars)
|
||||
|
Reference in New Issue
Block a user