[dynamo] reland map/zip iterator related changes (#135074)

Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074
Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
William Wen
2024-09-03 16:54:04 -07:00
committed by PyTorch MergeBot
parent 22e1fb6faa
commit a4030e37be
12 changed files with 554 additions and 78 deletions

View File

@ -2,14 +2,17 @@
import itertools
import operator
from typing import Dict, List, Optional, TYPE_CHECKING
import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
handle_observed_exception,
ObservedUserStopIteration,
raise_observed_exception,
unimplemented,
UserError,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
@ -197,6 +200,25 @@ class IteratorVariable(VariableTracker):
def next_variable(self, tx):
unimplemented("abstract method, must implement")
# NOTE: only call when unpacking this iterator safely done eagerly!
# Normally, iterators are accessed lazily.
# Example of safe eager unpacking: list(map(f, seq))
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
result = []
while True:
try:
result.append(self.next_variable(tx))
except ObservedUserStopIteration:
handle_observed_exception(tx)
break
return result
# don't call force_unpack_var_sequence since it can mutate
# IteratorVariable state!
def has_force_unpack_var_sequence(self, tx) -> bool:
return True
class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs) -> None:
@ -207,6 +229,18 @@ class RepeatIteratorVariable(IteratorVariable):
def next_variable(self, tx):
return self.item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("repeat"),
]
)
)
codegen(self.item)
codegen.extend_output(create_call_function(1, False))
class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
@ -220,10 +254,23 @@ class CountIteratorVariable(IteratorVariable):
def next_variable(self, tx):
assert self.mutable_local
old_item = self.item
tx.output.side_effects.mutation(self)
next_item = self.item.call_method(tx, "__add__", [self.step], {})
self.item = next_item
return self.item
self.item = self.item.call_method(tx, "__add__", [self.step], {})
return old_item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("count"),
]
)
)
codegen(self.item)
codegen(self.step)
codegen.extend_output(create_call_function(2, False))
class CycleIteratorVariable(IteratorVariable):
@ -269,3 +316,160 @@ class CycleIteratorVariable(IteratorVariable):
return self.item
else:
raise_observed_exception(StopIteration, tx, self)
class ZipVariable(IteratorVariable):
"""
Represents zip(*iterables)
"""
_nonvar_fields = {
"index",
"strict",
*IteratorVariable._nonvar_fields,
}
def __init__(
self,
iterables: List[Union[List[VariableTracker], VariableTracker]],
strict: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert isinstance(iterables, list)
# can be list[Variable] or VariableTracker (with next_variable implemented)
self.iterables = iterables
self.index = 0
self.strict = strict
def python_type(self):
return zip
def has_unpack_var_sequence(self, tx) -> bool:
return all(
isinstance(it, list) or it.has_unpack_var_sequence(tx)
for it in self.iterables
)
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
assert self.has_unpack_var_sequence(tx)
iterables = []
for it in self.iterables:
if isinstance(it, list):
iterables.append(it[self.index :])
else:
iterables.append(it.unpack_var_sequence(tx))
kwargs = {"strict": self.strict} if self.strict else {}
zipped = zip(*iterables, **kwargs)
return [variables.TupleVariable(list(var)) for var in zipped]
def next_variable(self, tx):
assert self.mutable_local
old_index = self.index
args = []
def get_item(it):
if isinstance(it, list):
if old_index >= len(it):
raise_observed_exception(StopIteration, tx, self)
return it[old_index]
else:
return it.next_variable(tx)
try:
for idx, it in enumerate(self.iterables):
args.append(get_item(it))
except ObservedUserStopIteration:
if self.strict:
if idx == 0:
# all other iterables should be exhausted
for it in self.iterables:
try:
get_item(it)
except ObservedUserStopIteration:
handle_observed_exception(tx)
continue
# no ObservedUserStopIteration - fall through to UserError
break
else:
# all iterables exhausted, raise original error
raise
handle_observed_exception(tx)
raise UserError(
ValueError,
"zip() has one argument of len differing from others",
) from None
raise
tx.output.side_effects.mutation(self)
self.index += 1
return variables.TupleVariable(args)
def reconstruct_items(self, codegen):
for it in self.iterables:
if isinstance(it, list):
remaining_items = it[self.index :]
codegen.foreach(remaining_items)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(remaining_items))
)
else:
codegen(it)
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
)
self.reconstruct_items(codegen)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(self.iterables))
)
if sys.version_info >= (3, 10):
codegen.extend_output(
[
codegen.create_load_const("strict"),
codegen.create_load_const(self.strict),
create_instruction("BUILD_MAP", arg=1),
create_instruction("CALL_FUNCTION_EX", arg=1),
]
)
else:
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
class MapVariable(ZipVariable):
"""
Represents map(fn, *iterables)
"""
def __init__(
self,
fn: VariableTracker,
iterables: List[Union[List[VariableTracker], VariableTracker]],
**kwargs,
) -> None:
super().__init__(iterables, **kwargs)
self.fn = fn
def python_type(self):
return map
def has_unpack_var_sequence(self, tx) -> bool:
return False
def next_variable(self, tx):
args = super().next_variable(tx)
return self.fn.call_function(tx, args.items, {})
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
)
codegen(self.fn)
self.reconstruct_items(codegen)
codegen.extend_output(
[
create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
create_instruction("CALL_FUNCTION_EX", arg=0),
]
)