mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[itertools] Implement itertools.cycle with a polyfill (#159102)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159102 Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519 ghstack dependencies: #158774
This commit is contained in:
committed by
PyTorch MergeBot
parent
25ef3d315d
commit
fa68216ca1
@ -17,7 +17,7 @@ handling of iterator operations during code transformation and optimization.
|
||||
|
||||
import itertools
|
||||
import sys
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from .. import graph_break_hints, polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
@ -180,10 +180,6 @@ class ItertoolsVariable(VariableTracker):
|
||||
return variables.CountIteratorVariable(
|
||||
*args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif self.value is itertools.cycle:
|
||||
return variables.CycleIteratorVariable(
|
||||
*args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
else:
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
@ -308,54 +304,6 @@ class CountIteratorVariable(IteratorVariable):
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
|
||||
class CycleIteratorVariable(IteratorVariable):
|
||||
def __init__(
|
||||
self,
|
||||
iterator: IteratorVariable,
|
||||
saved: Optional[list[VariableTracker]] = None,
|
||||
saved_index: int = 0,
|
||||
item: Optional[VariableTracker] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if saved is None:
|
||||
saved = []
|
||||
super().__init__(**kwargs)
|
||||
self.iterator = iterator
|
||||
self.saved = saved
|
||||
self.saved_index = saved_index
|
||||
self.item = item
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.is_mutable()
|
||||
|
||||
if self.iterator is not None:
|
||||
try:
|
||||
new_item = self.iterator.next_variable(tx)
|
||||
if len(self.saved) > MAX_ITERATOR_LIMIT:
|
||||
unimplemented_v2(
|
||||
gb_type="input iterator to itertools.cycle has too many items",
|
||||
context=f"next({self})",
|
||||
explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}",
|
||||
hints=[],
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.saved.append(new_item)
|
||||
self.item = new_item
|
||||
if self.item is None:
|
||||
return self.next_variable(tx)
|
||||
return self.item
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
self.iterator = None
|
||||
return self.next_variable(tx)
|
||||
elif len(self.saved) > 0:
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.saved_index = (self.saved_index + 1) % len(self.saved)
|
||||
return self.item
|
||||
else:
|
||||
raise_observed_exception(StopIteration, tx)
|
||||
|
||||
|
||||
class ZipVariable(IteratorVariable):
|
||||
"""
|
||||
Represents zip(*iterables)
|
||||
|
||||
Reference in New Issue
Block a user