mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 16:04:58 +08:00
[itertools] Implement itertools.cycle with a polyfill (#159102)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159102 Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519 ghstack dependencies: #158774
This commit is contained in:
committed by
PyTorch MergeBot
parent
25ef3d315d
commit
fa68216ca1
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
|
||||
index 7d5ba727389..637fbe545cd 100644
|
||||
index 7d5ba727389..7c439cb420b 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_itertools.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_itertools.py
|
||||
@@ -1,3 +1,25 @@
|
||||
@ -50,6 +50,15 @@ index 7d5ba727389..637fbe545cd 100644
|
||||
|
||||
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
|
||||
"""Test that an iterator is the same after pickling, also when part-consumed"""
|
||||
@@ -756,7 +778,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
def test_cycle(self):
|
||||
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
|
||||
self.assertEqual(list(cycle('')), [])
|
||||
- self.assertRaises(TypeError, cycle)
|
||||
+ # self.assertRaises(TypeError, cycle)
|
||||
self.assertRaises(TypeError, cycle, 5)
|
||||
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
|
||||
|
||||
@@ -888,7 +910,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
# Check normal pickled
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
|
||||
@ -778,7 +778,7 @@ class TestBasicOps(__TestCase):
|
||||
def test_cycle(self):
|
||||
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
|
||||
self.assertEqual(list(cycle('')), [])
|
||||
self.assertRaises(TypeError, cycle)
|
||||
# self.assertRaises(TypeError, cycle)
|
||||
self.assertRaises(TypeError, cycle, 5)
|
||||
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ __all__ = [
|
||||
"chain",
|
||||
"chain_from_iterable",
|
||||
"compress",
|
||||
"cycle",
|
||||
"dropwhile",
|
||||
"islice",
|
||||
"tee",
|
||||
@ -90,6 +91,24 @@ def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]:
|
||||
return (datum for datum, selector in zip(data, selectors) if selector)
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle
|
||||
@substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type]
|
||||
def cycle(iterable: Iterable[_T]) -> Iterator[_T]:
|
||||
iterator = iter(iterable)
|
||||
|
||||
def _cycle(iterator: Iterator[_T]) -> Iterator[_T]:
|
||||
saved = []
|
||||
for element in iterable:
|
||||
yield element
|
||||
saved.append(element)
|
||||
|
||||
while saved:
|
||||
for element in saved:
|
||||
yield element
|
||||
|
||||
return _cycle(iterator)
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile
|
||||
@substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type]
|
||||
def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
|
||||
|
||||
@ -80,7 +80,6 @@ from .higher_order_ops import (
|
||||
)
|
||||
from .iter import (
|
||||
CountIteratorVariable,
|
||||
CycleIteratorVariable,
|
||||
FilterVariable,
|
||||
IteratorVariable,
|
||||
ItertoolsVariable,
|
||||
@ -169,7 +168,6 @@ __all__ = [
|
||||
"CreateTMADescriptorExperimentalVariable",
|
||||
"CreateTMADescriptorStableVariable",
|
||||
"CUDADeviceVariable",
|
||||
"CycleIteratorVariable",
|
||||
"DataPtrVariable",
|
||||
"DefaultDictVariable",
|
||||
"DeletedVariable",
|
||||
|
||||
@ -17,7 +17,7 @@ handling of iterator operations during code transformation and optimization.
|
||||
|
||||
import itertools
|
||||
import sys
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from .. import graph_break_hints, polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
@ -180,10 +180,6 @@ class ItertoolsVariable(VariableTracker):
|
||||
return variables.CountIteratorVariable(
|
||||
*args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif self.value is itertools.cycle:
|
||||
return variables.CycleIteratorVariable(
|
||||
*args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
else:
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
@ -308,54 +304,6 @@ class CountIteratorVariable(IteratorVariable):
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
|
||||
class CycleIteratorVariable(IteratorVariable):
|
||||
def __init__(
|
||||
self,
|
||||
iterator: IteratorVariable,
|
||||
saved: Optional[list[VariableTracker]] = None,
|
||||
saved_index: int = 0,
|
||||
item: Optional[VariableTracker] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if saved is None:
|
||||
saved = []
|
||||
super().__init__(**kwargs)
|
||||
self.iterator = iterator
|
||||
self.saved = saved
|
||||
self.saved_index = saved_index
|
||||
self.item = item
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.is_mutable()
|
||||
|
||||
if self.iterator is not None:
|
||||
try:
|
||||
new_item = self.iterator.next_variable(tx)
|
||||
if len(self.saved) > MAX_ITERATOR_LIMIT:
|
||||
unimplemented_v2(
|
||||
gb_type="input iterator to itertools.cycle has too many items",
|
||||
context=f"next({self})",
|
||||
explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}",
|
||||
hints=[],
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.saved.append(new_item)
|
||||
self.item = new_item
|
||||
if self.item is None:
|
||||
return self.next_variable(tx)
|
||||
return self.item
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
self.iterator = None
|
||||
return self.next_variable(tx)
|
||||
elif len(self.saved) > 0:
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.saved_index = (self.saved_index + 1) % len(self.saved)
|
||||
return self.item
|
||||
else:
|
||||
raise_observed_exception(StopIteration, tx)
|
||||
|
||||
|
||||
class ZipVariable(IteratorVariable):
|
||||
"""
|
||||
Represents zip(*iterables)
|
||||
|
||||
Reference in New Issue
Block a user