mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] reland map/zip iterator related changes (#135074)
Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074 Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
22e1fb6faa
commit
a4030e37be
@ -2,14 +2,17 @@
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
import sys
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from .. import polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import (
|
||||
handle_observed_exception,
|
||||
ObservedUserStopIteration,
|
||||
raise_observed_exception,
|
||||
unimplemented,
|
||||
UserError,
|
||||
)
|
||||
from .base import MutableLocal, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
@ -197,6 +200,25 @@ class IteratorVariable(VariableTracker):
|
||||
def next_variable(self, tx):
|
||||
unimplemented("abstract method, must implement")
|
||||
|
||||
# NOTE: only call when unpacking this iterator safely done eagerly!
|
||||
# Normally, iterators are accessed lazily.
|
||||
# Example of safe eager unpacking: list(map(f, seq))
|
||||
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
|
||||
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
|
||||
result = []
|
||||
while True:
|
||||
try:
|
||||
result.append(self.next_variable(tx))
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
break
|
||||
return result
|
||||
|
||||
# don't call force_unpack_var_sequence since it can mutate
|
||||
# IteratorVariable state!
|
||||
def has_force_unpack_var_sequence(self, tx) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class RepeatIteratorVariable(IteratorVariable):
|
||||
def __init__(self, item: VariableTracker, **kwargs) -> None:
|
||||
@ -207,6 +229,18 @@ class RepeatIteratorVariable(IteratorVariable):
|
||||
def next_variable(self, tx):
|
||||
return self.item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(itertools),
|
||||
codegen.create_load_attr("repeat"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.item)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
|
||||
class CountIteratorVariable(IteratorVariable):
|
||||
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
|
||||
@ -220,10 +254,23 @@ class CountIteratorVariable(IteratorVariable):
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.mutable_local
|
||||
old_item = self.item
|
||||
tx.output.side_effects.mutation(self)
|
||||
next_item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
self.item = next_item
|
||||
return self.item
|
||||
self.item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
return old_item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(itertools),
|
||||
codegen.create_load_attr("count"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.item)
|
||||
codegen(self.step)
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
|
||||
class CycleIteratorVariable(IteratorVariable):
|
||||
@ -269,3 +316,160 @@ class CycleIteratorVariable(IteratorVariable):
|
||||
return self.item
|
||||
else:
|
||||
raise_observed_exception(StopIteration, tx, self)
|
||||
|
||||
|
||||
class ZipVariable(IteratorVariable):
|
||||
"""
|
||||
Represents zip(*iterables)
|
||||
"""
|
||||
|
||||
_nonvar_fields = {
|
||||
"index",
|
||||
"strict",
|
||||
*IteratorVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterables: List[Union[List[VariableTracker], VariableTracker]],
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(iterables, list)
|
||||
# can be list[Variable] or VariableTracker (with next_variable implemented)
|
||||
self.iterables = iterables
|
||||
self.index = 0
|
||||
self.strict = strict
|
||||
|
||||
def python_type(self):
|
||||
return zip
|
||||
|
||||
def has_unpack_var_sequence(self, tx) -> bool:
|
||||
return all(
|
||||
isinstance(it, list) or it.has_unpack_var_sequence(tx)
|
||||
for it in self.iterables
|
||||
)
|
||||
|
||||
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||
assert self.has_unpack_var_sequence(tx)
|
||||
iterables = []
|
||||
for it in self.iterables:
|
||||
if isinstance(it, list):
|
||||
iterables.append(it[self.index :])
|
||||
else:
|
||||
iterables.append(it.unpack_var_sequence(tx))
|
||||
kwargs = {"strict": self.strict} if self.strict else {}
|
||||
zipped = zip(*iterables, **kwargs)
|
||||
return [variables.TupleVariable(list(var)) for var in zipped]
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.mutable_local
|
||||
old_index = self.index
|
||||
args = []
|
||||
|
||||
def get_item(it):
|
||||
if isinstance(it, list):
|
||||
if old_index >= len(it):
|
||||
raise_observed_exception(StopIteration, tx, self)
|
||||
return it[old_index]
|
||||
else:
|
||||
return it.next_variable(tx)
|
||||
|
||||
try:
|
||||
for idx, it in enumerate(self.iterables):
|
||||
args.append(get_item(it))
|
||||
except ObservedUserStopIteration:
|
||||
if self.strict:
|
||||
if idx == 0:
|
||||
# all other iterables should be exhausted
|
||||
for it in self.iterables:
|
||||
try:
|
||||
get_item(it)
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
continue
|
||||
# no ObservedUserStopIteration - fall through to UserError
|
||||
break
|
||||
else:
|
||||
# all iterables exhausted, raise original error
|
||||
raise
|
||||
handle_observed_exception(tx)
|
||||
raise UserError(
|
||||
ValueError,
|
||||
"zip() has one argument of len differing from others",
|
||||
) from None
|
||||
raise
|
||||
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.index += 1
|
||||
return variables.TupleVariable(args)
|
||||
|
||||
def reconstruct_items(self, codegen):
|
||||
for it in self.iterables:
|
||||
if isinstance(it, list):
|
||||
remaining_items = it[self.index :]
|
||||
codegen.foreach(remaining_items)
|
||||
codegen.append_output(
|
||||
create_instruction("BUILD_TUPLE", arg=len(remaining_items))
|
||||
)
|
||||
else:
|
||||
codegen(it)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
|
||||
)
|
||||
self.reconstruct_items(codegen)
|
||||
codegen.append_output(
|
||||
create_instruction("BUILD_TUPLE", arg=len(self.iterables))
|
||||
)
|
||||
if sys.version_info >= (3, 10):
|
||||
codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_const("strict"),
|
||||
codegen.create_load_const(self.strict),
|
||||
create_instruction("BUILD_MAP", arg=1),
|
||||
create_instruction("CALL_FUNCTION_EX", arg=1),
|
||||
]
|
||||
)
|
||||
else:
|
||||
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
|
||||
|
||||
|
||||
class MapVariable(ZipVariable):
|
||||
"""
|
||||
Represents map(fn, *iterables)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn: VariableTracker,
|
||||
iterables: List[Union[List[VariableTracker], VariableTracker]],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(iterables, **kwargs)
|
||||
self.fn = fn
|
||||
|
||||
def python_type(self):
|
||||
return map
|
||||
|
||||
def has_unpack_var_sequence(self, tx) -> bool:
|
||||
return False
|
||||
|
||||
def next_variable(self, tx):
|
||||
args = super().next_variable(tx)
|
||||
return self.fn.call_function(tx, args.items, {})
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
|
||||
)
|
||||
codegen(self.fn)
|
||||
self.reconstruct_items(codegen)
|
||||
codegen.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
|
||||
create_instruction("CALL_FUNCTION_EX", arg=0),
|
||||
]
|
||||
)
|
||||
|
Reference in New Issue
Block a user