mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo] simplify polyfill registration for builtins.all
and builtins.any
(#133769)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133769 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
b977abd5de
commit
e09324e7da
@ -168,6 +168,7 @@ def forbid_in_graph(fn):
|
||||
def substitute_in_graph(
|
||||
original_fn: _F,
|
||||
*,
|
||||
can_constant_fold_through: bool = False,
|
||||
skip_signature_check: bool = False,
|
||||
) -> Callable[[_F], _F]:
|
||||
"""
|
||||
@ -187,6 +188,10 @@ def substitute_in_graph(
|
||||
Args:
|
||||
original_fn (callable): The original function, usually a C function, to register a polyfill
|
||||
handler for.
|
||||
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
|
||||
folded through. That is, if the polyfill handler is a pure function and its arguments
|
||||
are constant, the result of the polyfill handler can be constant folded during the
|
||||
compilation. Defaults to ``False``.
|
||||
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
||||
original function and the polyfill handler. Defaults to ``False``.
|
||||
|
||||
@ -319,6 +324,7 @@ def substitute_in_graph(
|
||||
|
||||
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
|
||||
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
|
||||
wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined]
|
||||
|
||||
return wrapped # type: ignore[return-value]
|
||||
|
||||
|
@ -13,20 +13,6 @@ from typing import Any, Callable, Sequence
|
||||
import torch
|
||||
|
||||
|
||||
def all(iterator):
|
||||
for elem in iterator:
|
||||
if not elem:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def any(iterator):
|
||||
for elem in iterator:
|
||||
if elem:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def index(iterator, item, start=0, end=None):
|
||||
for i, elem in islice(enumerate(iterator), start, end):
|
||||
if item == elem:
|
||||
|
@ -2,5 +2,29 @@
|
||||
Python polyfills for builtins
|
||||
"""
|
||||
|
||||
import builtins
|
||||
from typing import Iterable
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
__all__ = [
|
||||
"all",
|
||||
"any",
|
||||
]
|
||||
|
||||
|
||||
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
|
||||
def all(iterable: Iterable[object], /) -> bool:
|
||||
for elem in iterable:
|
||||
if not elem:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
|
||||
def any(iterable: Iterable[object], /) -> bool:
|
||||
for elem in iterable:
|
||||
if elem:
|
||||
return True
|
||||
return False
|
||||
|
@ -14,7 +14,7 @@ __all__ = ["fspath"]
|
||||
|
||||
|
||||
# Copied from os.py in the standard library
|
||||
@substitute_in_graph(os.fspath)
|
||||
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
|
||||
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
|
||||
if isinstance(path, (str, bytes)):
|
||||
return path
|
||||
|
@ -2980,6 +2980,7 @@ def _disallowed_callable_ids() -> Dict[int, str]:
|
||||
|
||||
@FunctionIdSet
|
||||
def _builtin_function_ids() -> Dict[int, str]:
|
||||
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
|
||||
rv = {
|
||||
id(v): f"builtins.{k}"
|
||||
for k, v in builtins.__dict__.items()
|
||||
@ -3072,6 +3073,7 @@ def is_forbidden(obj) -> bool:
|
||||
|
||||
|
||||
def is_builtin_callable(obj) -> bool:
|
||||
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
|
||||
return id(obj) in _builtin_function_ids
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ import torch
|
||||
from torch import sym_float, sym_int
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from .. import config, polyfills, variables
|
||||
from .. import config, variables
|
||||
from ..exc import (
|
||||
AttributeMutationError,
|
||||
unimplemented,
|
||||
@ -94,19 +94,6 @@ IN_PLACE_DESUGARING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _polyfill_call_impl(name):
|
||||
"""Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}"""
|
||||
|
||||
def call_fn(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(fn), args, kwargs
|
||||
)
|
||||
|
||||
fn = getattr(polyfills, name)
|
||||
call_fn.__name__ = f"call_{name}"
|
||||
return call_fn
|
||||
|
||||
|
||||
class BuiltinVariable(VariableTracker):
|
||||
_SENTINEL = object()
|
||||
_nonvar_fields = {
|
||||
@ -2124,9 +2111,6 @@ class BuiltinVariable(VariableTracker):
|
||||
):
|
||||
return a.call_method(tx, "__contains__", [b], {})
|
||||
|
||||
call_all = _polyfill_call_impl("all")
|
||||
call_any = _polyfill_call_impl("any")
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def dynamo_disable_grad(tx):
|
||||
|
@ -16,6 +16,7 @@ from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
||||
from ..utils import (
|
||||
check_constant_args,
|
||||
check_unspec_or_constant_args,
|
||||
identity,
|
||||
is_wrapper_or_member_descriptor,
|
||||
istype,
|
||||
@ -965,6 +966,15 @@ class PolyfilledFunctionVariable(VariableTracker):
|
||||
handler = self._get_polyfill_handlers().get(self.fn)
|
||||
if handler:
|
||||
assert callable(handler)
|
||||
if getattr(
|
||||
handler, "__torch_dynamo_can_constant_fold_through__", False
|
||||
) and check_unspec_or_constant_args(args, kwargs):
|
||||
return ConstantVariable.create(
|
||||
self.fn( # use the original function which is faster than the polyfill
|
||||
*[x.as_python_constant() for x in args],
|
||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||
)
|
||||
)
|
||||
return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
|
||||
|
||||
for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
|
||||
|
@ -123,6 +123,7 @@ def allow_in_graph(fn):
|
||||
def substitute_in_graph(
|
||||
original_fn: _F,
|
||||
*,
|
||||
can_constant_fold_through: bool = False,
|
||||
skip_signature_check: bool = False,
|
||||
) -> Callable[[_F], _F]:
|
||||
"""
|
||||
@ -142,6 +143,10 @@ def substitute_in_graph(
|
||||
Args:
|
||||
original_fn (callable): The original function, usually a C function, to register a polyfill
|
||||
handler for.
|
||||
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
|
||||
folded through. That is, if the polyfill handler is a pure function and its arguments
|
||||
are constant, the result of the polyfill handler can be constant folded during the
|
||||
compilation. Defaults to ``False``.
|
||||
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
||||
original function and the polyfill handler. Defaults to ``False``.
|
||||
|
||||
@ -173,6 +178,7 @@ def substitute_in_graph(
|
||||
|
||||
return torch._dynamo.substitute_in_graph(
|
||||
original_fn,
|
||||
can_constant_fold_through=can_constant_fold_through,
|
||||
skip_signature_check=skip_signature_check,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user