mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[iter] support iter(callable, sentinel)
(#156416)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156416 Approved by: https://github.com/XuehaiPan, https://github.com/zou3519 ghstack dependencies: #156371
This commit is contained in:
committed by
PyTorch MergeBot
parent
fcf59df2b6
commit
4e3e3dc0a7
@ -7,7 +7,7 @@ from __future__ import annotations
|
|||||||
import builtins
|
import builtins
|
||||||
import functools
|
import functools
|
||||||
import operator
|
import operator
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import Callable, TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
from ..decorators import substitute_in_graph
|
from ..decorators import substitute_in_graph
|
||||||
|
|
||||||
@ -60,19 +60,58 @@ def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignm
|
|||||||
return functools.reduce(operator.add, iterable, start)
|
return functools.reduce(operator.add, iterable, start)
|
||||||
|
|
||||||
|
|
||||||
# TODO(guilhermeleobas): use substitute_in_graph for iter()
|
class _CallableIterator:
|
||||||
def iter(iterable): # type: ignore[no-untyped-def]
|
def __init__(self, fn, sentinel): # type: ignore[no-untyped-def]
|
||||||
if hasattr(iterable, "__iter__"):
|
self.fn = fn
|
||||||
return iterable.__iter__()
|
self.sentinel = sentinel
|
||||||
if hasattr(iterable, "__getitem__"):
|
|
||||||
# Needs to be a new function to avoid iter_protocol becoming a generator
|
|
||||||
def sequence_protocol(iterable): # type: ignore[no-untyped-def]
|
|
||||||
i = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
yield iterable.__getitem__(i)
|
|
||||||
i += 1
|
|
||||||
except IndexError:
|
|
||||||
break
|
|
||||||
|
|
||||||
return sequence_protocol(iterable)
|
def __iter__(self): # type: ignore[no-untyped-def]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self): # type: ignore[no-untyped-def]
|
||||||
|
# The iterator created in this case will call object with no arguments
|
||||||
|
# for each call to its __next__() method;
|
||||||
|
r = self.fn()
|
||||||
|
|
||||||
|
# If the value returned is equal to sentinel, StopIteration will be raised
|
||||||
|
if r == self.sentinel:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
# otherwise the value will be returned.
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class _SENTINEL_MISSING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(guilhermeleobas): use substitute_in_graph for iter()
|
||||||
|
def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /): # type: ignore[no-untyped-def]
|
||||||
|
# Without a second argument, object must be a collection object which supports
|
||||||
|
# the iterable (__iter__) or the sequence protocol (__getitem__ with an integer
|
||||||
|
# starting at 0)
|
||||||
|
if sentinel is _SENTINEL_MISSING:
|
||||||
|
iterable = fn_or_iterable
|
||||||
|
if hasattr(iterable, "__iter__"):
|
||||||
|
return iterable.__iter__()
|
||||||
|
if hasattr(iterable, "__getitem__"):
|
||||||
|
# Needs to be a new function to avoid iter becoming a generator
|
||||||
|
def sequence_protocol(iterable): # type: ignore[no-untyped-def]
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield iterable.__getitem__(i)
|
||||||
|
i += 1
|
||||||
|
except IndexError:
|
||||||
|
break
|
||||||
|
|
||||||
|
return sequence_protocol(iterable)
|
||||||
|
else:
|
||||||
|
# If the second argument, sentinel, is given, then object must be a
|
||||||
|
# callable object.
|
||||||
|
fn = fn_or_iterable
|
||||||
|
|
||||||
|
if not isinstance(fn, Callable): # type: ignore[arg-type]
|
||||||
|
raise TypeError("iter(v, w): v must be a callable")
|
||||||
|
|
||||||
|
return _CallableIterator(fn, sentinel)
|
||||||
|
@ -1810,13 +1810,9 @@ class BuiltinVariable(VariableTracker):
|
|||||||
# (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator.
|
# (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator.
|
||||||
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
|
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
|
||||||
# with an integer argument starting at 0, until __getitem__ raises IndexError
|
# with an integer argument starting at 0, until __getitem__ raises IndexError
|
||||||
return variables.UserFunctionVariable(
|
ret = variables.UserFunctionVariable(
|
||||||
polyfills.builtins.iter
|
polyfills.builtins.iter_
|
||||||
).call_function(
|
).call_function(tx, [obj, *args], {})
|
||||||
tx,
|
|
||||||
[obj],
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user