[iter] support iter(callable, sentinel) (#156416)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156416
Approved by: https://github.com/XuehaiPan, https://github.com/zou3519
ghstack dependencies: #156371
This commit is contained in:
Guilherme Leobas
2025-07-29 16:11:37 -03:00
committed by PyTorch MergeBot
parent fcf59df2b6
commit 4e3e3dc0a7
6 changed files with 58 additions and 23 deletions

View File

@ -7,7 +7,7 @@ from __future__ import annotations
import builtins import builtins
import functools import functools
import operator import operator
from typing import TYPE_CHECKING, TypeVar from typing import Callable, TYPE_CHECKING, TypeVar
from ..decorators import substitute_in_graph from ..decorators import substitute_in_graph
@ -60,19 +60,58 @@ def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignm
return functools.reduce(operator.add, iterable, start) return functools.reduce(operator.add, iterable, start)
# TODO(guilhermeleobas): use substitute_in_graph for iter() class _CallableIterator:
def iter(iterable): # type: ignore[no-untyped-def] def __init__(self, fn, sentinel): # type: ignore[no-untyped-def]
if hasattr(iterable, "__iter__"): self.fn = fn
return iterable.__iter__() self.sentinel = sentinel
if hasattr(iterable, "__getitem__"):
# Needs to be a new function to avoid iter_protocol becoming a generator
def sequence_protocol(iterable): # type: ignore[no-untyped-def]
i = 0
while True:
try:
yield iterable.__getitem__(i)
i += 1
except IndexError:
break
return sequence_protocol(iterable) def __iter__(self): # type: ignore[no-untyped-def]
return self
def __next__(self): # type: ignore[no-untyped-def]
# The iterator created in this case will call object with no arguments
# for each call to its __next__() method;
r = self.fn()
# If the value returned is equal to sentinel, StopIteration will be raised
if r == self.sentinel:
raise StopIteration
# otherwise the value will be returned.
return r
class _SENTINEL_MISSING:
pass
# TODO(guilhermeleobas): use substitute_in_graph for iter()
def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /): # type: ignore[no-untyped-def]
# Without a second argument, object must be a collection object which supports
# the iterable (__iter__) or the sequence protocol (__getitem__ with an integer
# starting at 0)
if sentinel is _SENTINEL_MISSING:
iterable = fn_or_iterable
if hasattr(iterable, "__iter__"):
return iterable.__iter__()
if hasattr(iterable, "__getitem__"):
# Needs to be a new function to avoid iter becoming a generator
def sequence_protocol(iterable): # type: ignore[no-untyped-def]
i = 0
while True:
try:
yield iterable.__getitem__(i)
i += 1
except IndexError:
break
return sequence_protocol(iterable)
else:
# If the second argument, sentinel, is given, then object must be a
# callable object.
fn = fn_or_iterable
if not isinstance(fn, Callable): # type: ignore[arg-type]
raise TypeError("iter(v, w): v must be a callable")
return _CallableIterator(fn, sentinel)

View File

@ -1810,13 +1810,9 @@ class BuiltinVariable(VariableTracker):
# (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator.
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
# with an integer argument starting at 0, until __getitem__ raises IndexError # with an integer argument starting at 0, until __getitem__ raises IndexError
return variables.UserFunctionVariable( ret = variables.UserFunctionVariable(
polyfills.builtins.iter polyfills.builtins.iter_
).call_function( ).call_function(tx, [obj, *args], {})
tx,
[obj],
{},
)
return ret return ret