[dynamo][BE] move dropwhile polyfill to submodule polyfills.itertools (#144066)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144066
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2025-01-02 18:22:36 +08:00
committed by PyTorch MergeBot
parent 00df63f09f
commit fb1beb31d2
3 changed files with 25 additions and 24 deletions

View File

@ -122,16 +122,6 @@ def set_difference(set1, set2):
return difference_set
def dropwhile(predicate, iterable):
# dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
iterable = iter(iterable)
for x in iterable:
if not predicate(x):
yield x
break
yield from iterable
def zip_longest(*iterables, fillvalue=None):
# Create a list of iterators from the input iterables
iterators = [iter(it) for it in iterables]

View File

@ -6,7 +6,8 @@ from __future__ import annotations
import itertools
import sys
from typing import Generator, Iterable, Iterator, TypeVar
from typing import Callable, Iterable, Iterator, TypeVar
from typing_extensions import TypeAlias
from ..decorators import substitute_in_graph
@ -14,14 +15,16 @@ from ..decorators import substitute_in_graph
__all__ = [
"chain",
"chain_from_iterable",
"compress",
"dropwhile",
"islice",
"tee",
"compress",
]
_T = TypeVar("_T")
_U = TypeVar("_U")
_Predicate: TypeAlias = Callable[[_T], object]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
@ -39,6 +42,26 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
@substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type]
def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]:
return (datum for datum, selector in zip(data, selectors) if selector)
# Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile
@substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type]
def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
# dropwhile(lambda x: x < 5, [1, 4, 6, 3, 8]) -> 6 3 8
iterator = iter(iterable)
for x in iterator:
if not predicate(x):
yield x
break
yield from iterator
# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
@ -103,11 +126,3 @@ def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
return
return tuple(_tee(shared_link) for _ in range(n))
# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
@substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type]
def compress(
data: Iterable[_T], selectors: Iterable[_U], /
) -> Generator[_T, None, None]:
return (datum for datum, selector in zip(data, selectors) if selector)

View File

@ -191,10 +191,6 @@ class ItertoolsVariable(VariableTracker):
return variables.CycleIteratorVariable(
*args, mutation_type=ValueMutationNew()
)
elif self.value is itertools.dropwhile:
return variables.UserFunctionVariable(polyfills.dropwhile).call_function(
tx, args, kwargs
)
elif self.value is itertools.zip_longest:
return variables.UserFunctionVariable(polyfills.zip_longest).call_function(
tx, args, kwargs