[itertools] Implement itertools.cycle with a polyfill (#159102)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159102
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
ghstack dependencies: #158774
This commit is contained in:
Rob Timpe
2025-07-30 19:14:38 +00:00
committed by PyTorch MergeBot
parent 25ef3d315d
commit fa68216ca1
8 changed files with 31 additions and 57 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
index 7d5ba727389..637fbe545cd 100644
index 7d5ba727389..7c439cb420b 100644
--- a/test/dynamo/cpython/3_13/test_itertools.py
+++ b/test/dynamo/cpython/3_13/test_itertools.py
@@ -1,3 +1,25 @@
@ -50,6 +50,15 @@ index 7d5ba727389..637fbe545cd 100644
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
"""Test that an iterator is the same after pickling, also when part-consumed"""
@@ -756,7 +778,7 @@ class TestBasicOps(unittest.TestCase):
def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
- self.assertRaises(TypeError, cycle)
+ # self.assertRaises(TypeError, cycle)
self.assertRaises(TypeError, cycle, 5)
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
@@ -888,7 +910,7 @@ class TestBasicOps(unittest.TestCase):
# Check normal pickled
for proto in range(pickle.HIGHEST_PROTOCOL + 1):

View File

@ -778,7 +778,7 @@ class TestBasicOps(__TestCase):
def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
self.assertRaises(TypeError, cycle)
# self.assertRaises(TypeError, cycle)
self.assertRaises(TypeError, cycle, 5)
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])

View File

@ -22,6 +22,7 @@ __all__ = [
"chain",
"chain_from_iterable",
"compress",
"cycle",
"dropwhile",
"islice",
"tee",
@ -90,6 +91,24 @@ def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]:
return (datum for datum, selector in zip(data, selectors) if selector)
# Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle
@substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type]
def cycle(iterable: Iterable[_T]) -> Iterator[_T]:
iterator = iter(iterable)
def _cycle(iterator: Iterator[_T]) -> Iterator[_T]:
saved = []
for element in iterable:
yield element
saved.append(element)
while saved:
for element in saved:
yield element
return _cycle(iterator)
# Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile
@substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type]
def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:

View File

@ -80,7 +80,6 @@ from .higher_order_ops import (
)
from .iter import (
CountIteratorVariable,
CycleIteratorVariable,
FilterVariable,
IteratorVariable,
ItertoolsVariable,
@ -169,7 +168,6 @@ __all__ = [
"CreateTMADescriptorExperimentalVariable",
"CreateTMADescriptorStableVariable",
"CUDADeviceVariable",
"CycleIteratorVariable",
"DataPtrVariable",
"DefaultDictVariable",
"DeletedVariable",

View File

@ -17,7 +17,7 @@ handling of iterator operations during code transformation and optimization.
import itertools
import sys
from typing import Optional, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
@ -180,10 +180,6 @@ class ItertoolsVariable(VariableTracker):
return variables.CountIteratorVariable(
*args, mutation_type=ValueMutationNew()
)
elif self.value is itertools.cycle:
return variables.CycleIteratorVariable(
*args, mutation_type=ValueMutationNew()
)
else:
return super().call_function(tx, args, kwargs)
@ -308,54 +304,6 @@ class CountIteratorVariable(IteratorVariable):
codegen.extend_output(create_call_function(2, False))
class CycleIteratorVariable(IteratorVariable):
def __init__(
self,
iterator: IteratorVariable,
saved: Optional[list[VariableTracker]] = None,
saved_index: int = 0,
item: Optional[VariableTracker] = None,
**kwargs,
) -> None:
if saved is None:
saved = []
super().__init__(**kwargs)
self.iterator = iterator
self.saved = saved
self.saved_index = saved_index
self.item = item
def next_variable(self, tx):
assert self.is_mutable()
if self.iterator is not None:
try:
new_item = self.iterator.next_variable(tx)
if len(self.saved) > MAX_ITERATOR_LIMIT:
unimplemented_v2(
gb_type="input iterator to itertools.cycle has too many items",
context=f"next({self})",
explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}",
hints=[],
)
tx.output.side_effects.mutation(self)
self.saved.append(new_item)
self.item = new_item
if self.item is None:
return self.next_variable(tx)
return self.item
except ObservedUserStopIteration:
handle_observed_exception(tx)
self.iterator = None
return self.next_variable(tx)
elif len(self.saved) > 0:
tx.output.side_effects.mutation(self)
self.saved_index = (self.saved_index + 1) % len(self.saved)
return self.item
else:
raise_observed_exception(StopIteration, tx)
class ZipVariable(IteratorVariable):
"""
Represents zip(*iterables)