[dynamo] ensure polyfill function has the same signature as the original function in substitute_in_graph (#133813)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133813
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-08-21 16:49:23 +08:00
committed by PyTorch MergeBot
parent 240467adfe
commit c95ddd4bf2
3 changed files with 85 additions and 6 deletions

View File

@ -264,13 +264,22 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.reset()
counters.clear()
with self.assertRaisesRegex(TypeError, "Signature mismatch"):
@torch._dynamo.substitute_in_graph(operator.indexOf)
def polyfill(sequence, x):
def _(sequence, x):
for i, item in enumerate(sequence):
if item is x or item == x:
return i
raise ValueError("sequence.index(x): x not in sequence")
@torch._dynamo.substitute_in_graph(operator.indexOf)
def polyfill(a, b):
for i, item in enumerate(a):
if item is b or item == b:
return i
raise ValueError("sequence.index(x): x not in sequence")
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# ruff: noqa: TCH004
import functools
import inspect
from dataclasses import dataclass
from typing import Any, Callable, TYPE_CHECKING, TypeVar
@ -164,7 +165,11 @@ def forbid_in_graph(fn):
return fn
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
def substitute_in_graph(
original_fn: _F,
*,
skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph.
@ -182,6 +187,8 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
Returns:
A decorator that registers the polyfill handler for the original function.
@ -213,6 +220,60 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
)
def wrapper(traceable_fn: _F) -> _F:
if not is_function(traceable_fn):
raise TypeError(
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
)
if not skip_signature_check:
try:
original_sig = inspect.signature(original_fn)
except ValueError:
pass
else:
traceable_sig = inspect.signature(traceable_fn)
def sig_ident(sig):
# Ignore annotations for parameters and return type
return (
tuple(
p.name
for p in sig.parameters.values()
if (
p.kind
not in {
p.KEYWORD_ONLY,
# the name of *args and **kwargs is not important
p.VAR_POSITIONAL,
p.VAR_KEYWORD,
}
)
),
{
p.name
for p in sig.parameters.values()
if p.kind == p.KEYWORD_ONLY
},
{
p.name: p.default
for p in sig.parameters.values()
# the name of *args and **kwargs is not important
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
},
)
wildcard_sig = inspect.signature(lambda *args, **kwargs: None)
if (
sig_ident(original_sig) != sig_ident(traceable_sig)
and sig_ident(original_sig) != sig_ident(wildcard_sig)
and sig_ident(traceable_sig) != sig_ident(wildcard_sig)
):
raise TypeError(
f"Signature mismatch between {original_fn} and {traceable_fn}: "
f"{original_sig} != {traceable_sig}"
)
from torch._dynamo.guards import GuardBuilder
from torch._dynamo.trace_rules import get_torch_obj_rule_map
from torch._dynamo.variables import PolyfilledFunctionVariable

View File

@ -120,7 +120,11 @@ def allow_in_graph(fn):
return torch._dynamo.allow_in_graph(fn)
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
def substitute_in_graph(
original_fn: _F,
*,
skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph.
@ -138,6 +142,8 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
Returns:
A decorator that registers the polyfill handler for the original function.
@ -165,7 +171,10 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
"""
import torch._dynamo
return torch._dynamo.substitute_in_graph(original_fn)
return torch._dynamo.substitute_in_graph(
original_fn,
skip_signature_check=skip_signature_check,
)
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: