[dynamo] simplify polyfill registration for builtins.all and builtins.any (#133769)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133769
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-08-30 00:47:02 +08:00
committed by PyTorch MergeBot
parent b977abd5de
commit e09324e7da
8 changed files with 51 additions and 33 deletions

View File

@ -168,6 +168,7 @@ def forbid_in_graph(fn):
def substitute_in_graph(
original_fn: _F,
*,
can_constant_fold_through: bool = False,
skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
"""
@ -187,6 +188,10 @@ def substitute_in_graph(
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
folded through. That is, if the polyfill handler is a pure function and its arguments
are constant, the result of the polyfill handler can be constant folded during the
compilation. Defaults to ``False``.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
@ -319,6 +324,7 @@ def substitute_in_graph(
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined]
return wrapped # type: ignore[return-value]

View File

@ -13,20 +13,6 @@ from typing import Any, Callable, Sequence
import torch
def all(iterator):
for elem in iterator:
if not elem:
return False
return True
def any(iterator):
for elem in iterator:
if elem:
return True
return False
def index(iterator, item, start=0, end=None):
for i, elem in islice(enumerate(iterator), start, end):
if item == elem:

View File

@ -2,5 +2,29 @@
Python polyfills for builtins
"""
import builtins
from typing import Iterable
__all__ = [] # type: ignore[var-annotated]
from ..decorators import substitute_in_graph
__all__ = [
"all",
"any",
]
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
for elem in iterable:
if not elem:
return False
return True
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
def any(iterable: Iterable[object], /) -> bool:
for elem in iterable:
if elem:
return True
return False

View File

@ -14,7 +14,7 @@ __all__ = ["fspath"]
# Copied from os.py in the standard library
@substitute_in_graph(os.fspath)
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
if isinstance(path, (str, bytes)):
return path

View File

@ -2980,6 +2980,7 @@ def _disallowed_callable_ids() -> Dict[int, str]:
@FunctionIdSet
def _builtin_function_ids() -> Dict[int, str]:
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
rv = {
id(v): f"builtins.{k}"
for k, v in builtins.__dict__.items()
@ -3072,6 +3073,7 @@ def is_forbidden(obj) -> bool:
def is_builtin_callable(obj) -> bool:
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
return id(obj) in _builtin_function_ids

View File

@ -16,7 +16,7 @@ import torch
from torch import sym_float, sym_int
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, polyfills, variables
from .. import config, variables
from ..exc import (
AttributeMutationError,
unimplemented,
@ -94,19 +94,6 @@ IN_PLACE_DESUGARING_MAP = {
}
def _polyfill_call_impl(name):
"""Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}"""
def call_fn(self, tx: "InstructionTranslator", *args, **kwargs):
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn), args, kwargs
)
fn = getattr(polyfills, name)
call_fn.__name__ = f"call_{name}"
return call_fn
class BuiltinVariable(VariableTracker):
_SENTINEL = object()
_nonvar_fields = {
@ -2124,9 +2111,6 @@ class BuiltinVariable(VariableTracker):
):
return a.call_method(tx, "__contains__", [b], {})
call_all = _polyfill_call_impl("all")
call_any = _polyfill_call_impl("any")
@contextlib.contextmanager
def dynamo_disable_grad(tx):

View File

@ -16,6 +16,7 @@ from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import (
check_constant_args,
check_unspec_or_constant_args,
identity,
is_wrapper_or_member_descriptor,
istype,
@ -965,6 +966,15 @@ class PolyfilledFunctionVariable(VariableTracker):
handler = self._get_polyfill_handlers().get(self.fn)
if handler:
assert callable(handler)
if getattr(
handler, "__torch_dynamo_can_constant_fold_through__", False
) and check_unspec_or_constant_args(args, kwargs):
return ConstantVariable.create(
self.fn( # use the original function which is faster than the polyfill
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
)
return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):

View File

@ -123,6 +123,7 @@ def allow_in_graph(fn):
def substitute_in_graph(
original_fn: _F,
*,
can_constant_fold_through: bool = False,
skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
"""
@ -142,6 +143,10 @@ def substitute_in_graph(
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
folded through. That is, if the polyfill handler is a pure function and its arguments
are constant, the result of the polyfill handler can be constant folded during the
compilation. Defaults to ``False``.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
@ -173,6 +178,7 @@ def substitute_in_graph(
return torch._dynamo.substitute_in_graph(
original_fn,
can_constant_fold_through=can_constant_fold_through,
skip_signature_check=skip_signature_check,
)