mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158774 Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
123 lines
3.5 KiB
Python
123 lines
3.5 KiB
Python
"""
|
|
Python polyfills for builtins
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import builtins
|
|
import functools
|
|
import operator
|
|
from typing import Callable, TYPE_CHECKING, TypeVar
|
|
|
|
from ..decorators import substitute_in_graph
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable
|
|
|
|
|
|
# Names exported by ``from ... import *``; ``iter_`` below is not among them.
__all__ = ["all", "any", "enumerate", "sum"]


# Element type variable shared by the generic polyfill signatures below.
_T = TypeVar("_T")
|
|
|
|
|
|
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
    """Traceable polyfill for :func:`builtins.all`.

    Return ``True`` when every element of *iterable* is truthy (vacuously
    ``True`` for an empty iterable). Iteration stops at the first falsy
    element, leaving the remainder unconsumed.
    """
    for item in iterable:
        if item:
            continue
        return False
    return True
|
|
|
|
|
|
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
def any(iterable: Iterable[object], /) -> bool:
    """Traceable polyfill for :func:`builtins.any`.

    Return ``True`` as soon as an element of *iterable* is truthy, without
    consuming the rest; return ``False`` when none is (including when
    *iterable* is empty).
    """
    found = False
    for item in iterable:
        if item:
            found = True
            break
    return found
|
|
|
|
|
|
@substitute_in_graph(builtins.enumerate, is_embedded_type=True)  # type: ignore[arg-type]
def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]:
    """Traceable polyfill for :func:`builtins.enumerate`.

    Produces ``(index, item)`` pairs for each item of *iterable*, with the
    index starting at *start* and increasing by one per item.

    Argument errors are raised when ``enumerate`` is called, not on the first
    ``next()``, just like the builtin. To achieve that, this is a plain
    function that validates its arguments and then returns a generator,
    rather than a generator function whose body only runs once iteration
    begins.

    Raises:
        TypeError: if *start* is not an ``int``, or if *iterable* is not
            iterable.
    """
    # Checked before touching the iterable, the same order the builtin uses.
    if not isinstance(start, int):
        raise TypeError(
            f"{type(start).__name__!r} object cannot be interpreted as an integer"
        )

    # Obtain the iterator eagerly so a non-iterable argument fails here.
    iterator = iter(iterable)

    def _enumerate_gen(it, index):  # type: ignore[no-untyped-def]
        # Generator half of the polyfill: pairs each item with a running index.
        for x in it:
            yield index, x
            index += 1

    return _enumerate_gen(iterator, start)
|
|
|
|
|
|
@substitute_in_graph(builtins.sum, can_constant_fold_through=True)  # type: ignore[arg-type]
def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T:  # type: ignore[assignment]
    """Traceable polyfill for :func:`builtins.sum`.

    Left-folds ``operator.add`` over *iterable* starting from *start*, i.e.
    computes ``((start + x0) + x1) + ...``. An empty *iterable* returns
    *start* unchanged.

    Raises:
        TypeError: if *start* is a ``str``, ``bytes`` or ``bytearray``.
            ``builtins.sum`` refuses these, and its messages point callers
            at ``join`` instead.
    """
    # Mirror the builtin's up-front rejection of text/binary start values
    # instead of silently concatenating them.
    if isinstance(start, str):
        raise TypeError("sum() can't sum strings [use ''.join(seq) instead]")
    if isinstance(start, bytes):
        raise TypeError("sum() can't sum bytes [use b''.join(seq) instead]")
    if isinstance(start, bytearray):
        raise TypeError("sum() can't sum bytearray [use b''.join(seq) instead]")
    return functools.reduce(operator.add, iterable, start)
|
|
|
|
|
|
class _CallableIterator:
|
|
def __init__(self, fn, sentinel): # type: ignore[no-untyped-def]
|
|
self.fn = fn
|
|
self.sentinel = sentinel
|
|
|
|
def __iter__(self): # type: ignore[no-untyped-def]
|
|
return self
|
|
|
|
def __next__(self): # type: ignore[no-untyped-def]
|
|
# The iterator created in this case will call object with no arguments
|
|
# for each call to its __next__() method;
|
|
r = self.fn()
|
|
|
|
# If the value returned is equal to sentinel, StopIteration will be raised
|
|
if r == self.sentinel:
|
|
raise StopIteration
|
|
|
|
# otherwise the value will be returned.
|
|
return r
|
|
|
|
|
|
class _SENTINEL_MISSING:
|
|
pass
|
|
|
|
|
|
# TODO(guilhermeleobas): use substitute_in_graph for iter()
def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /):  # type: ignore[no-untyped-def]
    """Traceable polyfill for :func:`builtins.iter`.

    ``iter_(iterable)`` returns an iterator over *iterable*. It uses the
    object's ``__iter__`` method or, failing that, the sequence protocol
    (``__getitem__`` with integer indices starting at 0).

    ``iter_(fn, sentinel)`` returns an iterator that calls ``fn()`` on each
    step and stops once the result equals *sentinel*.

    Raises:
        TypeError: if *fn_or_iterable* is not iterable (one-argument form),
            if its ``__iter__`` returns an object without ``__next__``, or
            if it is not callable (two-argument form).
    """
    # Without a second argument, object must be a collection object which supports
    # the iterable (__iter__) or the sequence protocol (__getitem__ with an integer
    # starting at 0)
    if sentinel is _SENTINEL_MISSING:
        iterable = fn_or_iterable
        if hasattr(iterable, "__iter__"):
            iterator = iterable.__iter__()
            if hasattr(iterator, "__next__"):
                return iterator
            # Report the type's name (not its repr), as the builtin does.
            raise TypeError(
                f"iter() returned non-iterator of type '{type(iterator).__name__}'"
            )

        if hasattr(iterable, "__getitem__"):
            # Needs to be a new function to avoid iter becoming a generator
            def sequence_protocol(iterable):  # type: ignore[no-untyped-def]
                i = 0
                while True:
                    try:
                        item = iterable.__getitem__(i)
                    except (IndexError, StopIteration):
                        # Either exception ends the builtin sequence iterator.
                        # A StopIteration escaping this generator would be
                        # turned into RuntimeError (PEP 479), so catch it here.
                        return
                    yield item
                    i += 1

            return sequence_protocol(iterable)

        raise TypeError(f"'{type(iterable).__name__}' object is not iterable")
    else:
        # If the second argument, sentinel, is given, then object must be a
        # callable object.
        fn = fn_or_iterable

        if not isinstance(fn, Callable):  # type: ignore[arg-type]
            raise TypeError("iter(v, w): v must be a callable")

        return _CallableIterator(fn, sentinel)
|