[dynamo] ensure polyfill function has the same signature as the original function in substitute_in_graph (#133813)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133813
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-08-21 16:49:23 +08:00
committed by PyTorch MergeBot
parent 240467adfe
commit c95ddd4bf2
3 changed files with 85 additions and 6 deletions

View File

@ -120,7 +120,11 @@ def allow_in_graph(fn):
return torch._dynamo.allow_in_graph(fn)
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
def substitute_in_graph(
original_fn: _F,
*,
skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph.
@ -138,6 +142,8 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
Returns:
A decorator that registers the polyfill handler for the original function.
@ -165,7 +171,10 @@ def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
"""
import torch._dynamo
return torch._dynamo.substitute_in_graph(original_fn)
return torch._dynamo.substitute_in_graph(
original_fn,
skip_signature_check=skip_signature_check,
)
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: