[itertools] Implement itertools.cycle with a polyfill (#159102)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159102
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
ghstack dependencies: #158774
This commit is contained in:
Rob Timpe
2025-07-30 19:14:38 +00:00
committed by PyTorch MergeBot
parent 25ef3d315d
commit fa68216ca1
8 changed files with 31 additions and 57 deletions

View File

@ -17,7 +17,7 @@ handling of iterator operations during code transformation and optimization.
import itertools
import sys
from typing import Optional, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
@ -180,10 +180,6 @@ class ItertoolsVariable(VariableTracker):
return variables.CountIteratorVariable(
*args, mutation_type=ValueMutationNew()
)
elif self.value is itertools.cycle:
return variables.CycleIteratorVariable(
*args, mutation_type=ValueMutationNew()
)
else:
return super().call_function(tx, args, kwargs)
@ -308,54 +304,6 @@ class CountIteratorVariable(IteratorVariable):
codegen.extend_output(create_call_function(2, False))
class CycleIteratorVariable(IteratorVariable):
def __init__(
self,
iterator: IteratorVariable,
saved: Optional[list[VariableTracker]] = None,
saved_index: int = 0,
item: Optional[VariableTracker] = None,
**kwargs,
) -> None:
if saved is None:
saved = []
super().__init__(**kwargs)
self.iterator = iterator
self.saved = saved
self.saved_index = saved_index
self.item = item
def next_variable(self, tx):
assert self.is_mutable()
if self.iterator is not None:
try:
new_item = self.iterator.next_variable(tx)
if len(self.saved) > MAX_ITERATOR_LIMIT:
unimplemented_v2(
gb_type="input iterator to itertools.cycle has too many items",
context=f"next({self})",
explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}",
hints=[],
)
tx.output.side_effects.mutation(self)
self.saved.append(new_item)
self.item = new_item
if self.item is None:
return self.next_variable(tx)
return self.item
except ObservedUserStopIteration:
handle_observed_exception(tx)
self.iterator = None
return self.next_variable(tx)
elif len(self.saved) > 0:
tx.output.side_effects.mutation(self)
self.saved_index = (self.saved_index + 1) % len(self.saved)
return self.item
else:
raise_observed_exception(StopIteration, tx)
class ZipVariable(IteratorVariable):
"""
Represents zip(*iterables)