mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 10:34:54 +08:00
[dynamo] ensure polyfill function has the same signature as the original function in substitute_in_graph (#133813)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133813 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
240467adfe
commit
c95ddd4bf2
@ -264,13 +264,22 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
torch._dynamo.reset()
|
||||
counters.clear()
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Signature mismatch"):
|
||||
|
||||
@torch._dynamo.substitute_in_graph(operator.indexOf)
|
||||
def polyfill(sequence, x):
|
||||
def _(sequence, x):
|
||||
for i, item in enumerate(sequence):
|
||||
if item is x or item == x:
|
||||
return i
|
||||
raise ValueError("sequence.index(x): x not in sequence")
|
||||
|
||||
@torch._dynamo.substitute_in_graph(operator.indexOf)
|
||||
def polyfill(a, b):
|
||||
for i, item in enumerate(a):
|
||||
if item is b or item == b:
|
||||
return i
|
||||
raise ValueError("sequence.index(x): x not in sequence")
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = operator.indexOf
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# ruff: noqa: TCH004
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
|
||||
@ -164,7 +165,11 @@ def forbid_in_graph(fn):
|
||||
return fn
|
||||
|
||||
|
||||
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
|
||||
def substitute_in_graph(
|
||||
original_fn: _F,
|
||||
*,
|
||||
skip_signature_check: bool = False,
|
||||
) -> Callable[[_F], _F]:
|
||||
"""
|
||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||
used in place of the original function when inlining the original function in the graph.
|
||||
@ -182,6 +187,8 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
|
||||
Args:
|
||||
original_fn (callable): The original function, usually a C function, to register a polyfill
|
||||
handler for.
|
||||
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
||||
original function and the polyfill handler. Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
A decorator that registers the polyfill handler for the original function.
|
||||
@ -213,6 +220,60 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
|
||||
)
|
||||
|
||||
def wrapper(traceable_fn: _F) -> _F:
|
||||
if not is_function(traceable_fn):
|
||||
raise TypeError(
|
||||
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
|
||||
)
|
||||
|
||||
if not skip_signature_check:
|
||||
try:
|
||||
original_sig = inspect.signature(original_fn)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
traceable_sig = inspect.signature(traceable_fn)
|
||||
|
||||
def sig_ident(sig):
|
||||
# Ignore annotations for parameters and return type
|
||||
return (
|
||||
tuple(
|
||||
p.name
|
||||
for p in sig.parameters.values()
|
||||
if (
|
||||
p.kind
|
||||
not in {
|
||||
p.KEYWORD_ONLY,
|
||||
# the name of *args and **kwargs is not important
|
||||
p.VAR_POSITIONAL,
|
||||
p.VAR_KEYWORD,
|
||||
}
|
||||
)
|
||||
),
|
||||
{
|
||||
p.name
|
||||
for p in sig.parameters.values()
|
||||
if p.kind == p.KEYWORD_ONLY
|
||||
},
|
||||
{
|
||||
p.name: p.default
|
||||
for p in sig.parameters.values()
|
||||
# the name of *args and **kwargs is not important
|
||||
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
|
||||
},
|
||||
)
|
||||
|
||||
wildcard_sig = inspect.signature(lambda *args, **kwargs: None)
|
||||
|
||||
if (
|
||||
sig_ident(original_sig) != sig_ident(traceable_sig)
|
||||
and sig_ident(original_sig) != sig_ident(wildcard_sig)
|
||||
and sig_ident(traceable_sig) != sig_ident(wildcard_sig)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Signature mismatch between {original_fn} and {traceable_fn}: "
|
||||
f"{original_sig} != {traceable_sig}"
|
||||
)
|
||||
|
||||
from torch._dynamo.guards import GuardBuilder
|
||||
from torch._dynamo.trace_rules import get_torch_obj_rule_map
|
||||
from torch._dynamo.variables import PolyfilledFunctionVariable
|
||||
|
||||
@ -120,7 +120,11 @@ def allow_in_graph(fn):
|
||||
return torch._dynamo.allow_in_graph(fn)
|
||||
|
||||
|
||||
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
|
||||
def substitute_in_graph(
|
||||
original_fn: _F,
|
||||
*,
|
||||
skip_signature_check: bool = False,
|
||||
) -> Callable[[_F], _F]:
|
||||
"""
|
||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||
used in place of the original function when inlining the original function in the graph.
|
||||
@ -138,6 +142,8 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
|
||||
Args:
|
||||
original_fn (callable): The original function, usually a C function, to register a polyfill
|
||||
handler for.
|
||||
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
||||
original function and the polyfill handler. Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
A decorator that registers the polyfill handler for the original function.
|
||||
@ -165,7 +171,10 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
|
||||
"""
|
||||
import torch._dynamo
|
||||
|
||||
return torch._dynamo.substitute_in_graph(original_fn)
|
||||
return torch._dynamo.substitute_in_graph(
|
||||
original_fn,
|
||||
skip_signature_check=skip_signature_check,
|
||||
)
|
||||
|
||||
|
||||
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
|
||||
|
||||
Reference in New Issue
Block a user