Files
pytorch/torch/_dynamo/polyfills/builtins.py

123 lines
3.5 KiB
Python

"""
Python polyfills for builtins
"""
from __future__ import annotations
import builtins
import functools
import operator
from typing import Callable, TYPE_CHECKING, TypeVar
from ..decorators import substitute_in_graph
if TYPE_CHECKING:
from collections.abc import Iterable
# Names exported by ``from ... import *``. Each one is a module-level
# definition in this file that shadows the builtin of the same name here;
# ``iter_`` is not listed.
__all__ = ["all", "any", "enumerate", "sum"]

# Element type variable used by the generic polyfill signatures.
_T = TypeVar("_T")
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
    """Polyfill for ``builtins.all``.

    Stop consuming *iterable* at the first falsy element and return
    ``False``; return ``True`` if no falsy element is found (including when
    *iterable* is empty).
    """
    for item in iterable:
        if not item:
            break
    else:
        # The loop ran to completion: every element was truthy.
        return True
    return False
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
def any(iterable: Iterable[object], /) -> bool:
    """Polyfill for ``builtins.any``.

    Stop consuming *iterable* at the first truthy element and return
    ``True``; return ``False`` if no truthy element is found (including when
    *iterable* is empty).
    """
    for item in iterable:
        if item:
            break
    else:
        # The loop ran to completion: no element was truthy.
        return False
    return True
@substitute_in_graph(builtins.enumerate, is_embedded_type=True)  # type: ignore[arg-type]
def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]:
    """Polyfill for ``builtins.enumerate``.

    Returns an iterator yielding ``(start, x0), (start + 1, x1), ...`` for
    the elements ``x0, x1, ...`` of *iterable*.

    Raises:
        TypeError: immediately (not on the first ``next()``) if *start* is
            not an ``int`` or if *iterable* is not iterable, as the builtin
            does.
    """
    if not isinstance(start, int):
        raise TypeError(
            f"{type(start).__name__!r} object cannot be interpreted as an integer"
        )
    # Obtain the iterator here, outside the generator below: a generator
    # body does not run until the first next(), which would defer both the
    # check above and a "not iterable" error.
    iterator = builtins.iter(iterable)

    def _enumerate(it: Iterable[_T], count: int) -> Iterable[tuple[int, _T]]:
        for x in it:
            yield count, x
            count += 1

    return _enumerate(iterator, start)
@substitute_in_graph(builtins.sum, can_constant_fold_through=True)  # type: ignore[arg-type]
def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T:  # type: ignore[assignment]
    """Polyfill for ``builtins.sum``.

    Left-folds ``operator.add`` over *iterable* beginning with *start*,
    i.e. ``((start + x0) + x1) + ...``; returns *start* for an empty
    iterable.

    Raises:
        TypeError: if *iterable* is not iterable, or if *start* is a
            ``str``, ``bytes`` or ``bytearray``. The builtin refuses those
            and points at ``join``; the messages below match it.
    """
    # The builtin obtains the iterator before validating ``start``, so a
    # non-iterable argument is reported first.
    iterator = builtins.iter(iterable)
    if isinstance(start, str):
        raise TypeError("sum() can't sum strings [use ''.join(seq) instead]")
    if isinstance(start, bytes):
        raise TypeError("sum() can't sum bytes [use b''.join(seq) instead]")
    if isinstance(start, bytearray):
        raise TypeError("sum() can't sum bytearray [use b''.join(seq) instead]")
    return functools.reduce(operator.add, iterator, start)
class _CallableIterator:
def __init__(self, fn, sentinel): # type: ignore[no-untyped-def]
self.fn = fn
self.sentinel = sentinel
def __iter__(self): # type: ignore[no-untyped-def]
return self
def __next__(self): # type: ignore[no-untyped-def]
# The iterator created in this case will call object with no arguments
# for each call to its __next__() method;
r = self.fn()
# If the value returned is equal to sentinel, StopIteration will be raised
if r == self.sentinel:
raise StopIteration
# otherwise the value will be returned.
return r
class _SENTINEL_MISSING:
pass
# TODO(guilhermeleobas): use substitute_in_graph for iter()
def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /):  # type: ignore[no-untyped-def]
    """Python reimplementation of ``builtins.iter``.

    ``iter_(iterable)`` returns ``iterable.__iter__()`` when available.
    Otherwise it falls back to the sequence protocol: ``__getitem__`` with
    indices 0, 1, 2, ... until ``IndexError`` or ``StopIteration``.

    ``iter_(fn, sentinel)`` returns a ``_CallableIterator`` that calls *fn*
    until it returns *sentinel*.

    Raises:
        TypeError: if the object is not iterable, if ``__iter__`` returns
            an object without ``__next__``, or if *sentinel* is given and
            *fn_or_iterable* is not callable. Messages follow CPython's.
    """
    if sentinel is not _SENTINEL_MISSING:
        # Two-argument form: the first argument must be a callable.
        fn = fn_or_iterable
        if not isinstance(fn, Callable):  # type: ignore[arg-type]
            raise TypeError("iter(v, w): v must be callable")
        return _CallableIterator(fn, sentinel)

    iterable = fn_or_iterable
    # One-argument form: prefer the iterator protocol (__iter__).
    if hasattr(iterable, "__iter__"):
        iterator = iterable.__iter__()
        if not hasattr(iterator, "__next__"):
            raise TypeError(
                "iter() returned non-iterator of type "
                f"'{type(iterator).__name__}'"
            )
        return iterator

    # Fall back to the sequence protocol (__getitem__ with integer indices
    # starting at 0).
    if hasattr(iterable, "__getitem__"):
        # Needs to be a new function to avoid iter becoming a generator
        def sequence_protocol(seq):  # type: ignore[no-untyped-def]
            i = 0
            while True:
                try:
                    item = seq.__getitem__(i)
                except (IndexError, StopIteration):
                    # Both end a sequence iterator in CPython. Returning
                    # also keeps a StopIteration from escaping the generator,
                    # which PEP 479 would turn into RuntimeError.
                    return
                # Yield outside the try so an exception thrown into this
                # generator is not mistaken for end-of-sequence.
                yield item
                i += 1

        return sequence_protocol(iterable)

    raise TypeError(f"'{type(iterable).__name__}' object is not iterable")