[dynamo] reland map/zip iterator related changes (#135074)

Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074
Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
William Wen
2024-09-03 16:54:04 -07:00
committed by PyTorch MergeBot
parent 22e1fb6faa
commit a4030e37be
12 changed files with 554 additions and 78 deletions

View File

@ -239,6 +239,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
v = v + x
return v
def test_itertools_reconstruct(self):
    # itertools.repeat / itertools.count objects that were partially consumed
    # inside a compiled function must be returned to the caller in the same
    # state as in eager mode.
    def fn(a):
        it1 = itertools.repeat(1)
        it2 = itertools.count(2)
        for _ in range(3):
            a += next(it1)
            a += next(it2)
        return it1, it2, a

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    i1, i2, a = fn(torch.ones(3, 3))
    it1, it2, b = opt_fn(torch.ones(3, 3))
    # The next element drawn from each returned iterator must match eager.
    self.assertEqual(next(i1), next(it1))
    self.assertEqual(next(i2), next(it2))
    self.assertEqual(a, b)
@make_test
def test_obj_eq(a, b):
v = a + b
@ -507,8 +523,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
empty = collections.deque()
d.extend(empty)
# dynamo same() util doesn't support deque so just return a list
return list(d)
return d
@make_test
def test_slice1(a):
@ -3115,6 +3130,199 @@ class GraphModule(torch.nn.Module):
fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
)
def test_map_return(self):
    # A function compiled with fullgraph=True that returns a bare map(...)
    # must still hand back a real `map` object.
    def fn(a, b):
        return map(lambda x: x + 1, [a, b])

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    m = opt_fn(torch.randn(3, 3), torch.randn(3, 3))
    self.assertIsInstance(m, map)
@make_test
def test_map_max(a, b):
    # max() consuming a map that yields the sum of each input
    return max(map(lambda x: x.sum(), [a, b]))
# max(map(...)) graph breaks
@unittest.expectedFailure
@make_test
def test_map_max_const(a):
    # max() over a map of plain Python ints, returned alongside a tensor op;
    # currently marked as an expected failure (see the note above).
    return max(map(lambda x: x, [1, 2, 3])), a + 1
@make_test
def test_map_list(a, b):
    # list() eagerly consumes the map
    return list(map(lambda x: x + 1, [a, b]))
@make_test
def test_map_tuple(a, b):
    # tuple() eagerly consumes the map
    return tuple(map(lambda x: x + 1, [a, b]))
@make_test
def test_map_iter(a, b):
    # iter() applied to a map, then a single next() on the result
    it = iter(map(lambda x: x + 1, [a, b]))
    return next(it)
@make_test
def test_map_zip_dict(a):
    # dict() built from a zip whose keys come from a map (1, 2, 3) and whose
    # values are themselves single-element map objects.
    d = dict(
        zip(
            map(lambda x: x + 1, [0, 1, 2]),
            [map(lambda x: x - 1, [y]) for y in [3, 4, 5]],
        )
    )
    # d[3] is the map over [5]; materialize it and take its only element
    return list(d[3])[0], a + 1  # noqa: RUF015
@make_test
def test_map_dict_fromkeys(a):
    # dict.fromkeys() consuming a map of constant keys
    return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1
@make_test
def test_map_set(a):
    # set() consuming a map of constants
    return set(map(lambda x: x + 1, [0, 1])), a + 1
# test_map_sum defined earlier
@make_test
def test_map_reduce(a, b):
    # functools.reduce() folding over the elements produced by a map
    return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b]))
@make_test
def test_map_sorted(a):
    # sorted() consuming a map of unordered constants
    return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1
@make_test
def test_map_list_extend(a, b, c):
    # list.extend() fed directly with a map object
    l = [a]
    l.extend(map(lambda x: x + 1, [b, c]))
    return l
@make_test
def test_map_list_slice_assign(a, b, c, d, e):
    # Slice assignment from a map: the one-element slice l[1:2] is replaced by
    # the two mapped values, so the list grows to four elements.
    l = [a, b, c]
    l[1:2] = map(lambda x: x + 1, [d, e])
    return l
@make_test
def test_map_deque_extendleft(a, b, c):
    # deque.extendleft() fed with a map; extendleft prepends its items one by
    # one, so they end up in reverse order at the front of the deque.
    d = collections.deque([a])
    d.extendleft(map(lambda x: x + 1, [b, c]))
    return d
@make_test
def test_map_str_join(a):
    # str.join() consuming a map of constant strings
    return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1
def test_map_with_graph_break(self):
    # A map whose function mutates a closed-over tensor is advanced once
    # before and once after an explicit graph break; the compiled result must
    # match eager and the expected number of frames must be compiled.
    def f(a):
        a += 1

        def g(x):
            nonlocal a
            a += 1
            return x + 1

        m = map(g, [1, 2, 3, 4, 5])
        a += next(m)  # won't graph break
        torch._dynamo.graph_break()
        a += next(m)  # will graph break
        return a

    cnts = torch._dynamo.testing.CompileCounter()
    opt_f = torch.compile(f, backend=cnts)
    self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3)))
    self.assertEqual(cnts.frame_count, 3)
def test_map_reconstruct(self):
    # An unconsumed map over a zip is returned from a fullgraph-compiled
    # function; it must come back as a real `map` yielding the eager values.
    def fn(a):
        return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    m = opt_fn(torch.ones(3, 3))[0]
    self.assertIsInstance(m, map)
    self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
def test_zip_reconstruct(self):
    # An unconsumed zip (with a map as one of its inputs) is returned from a
    # fullgraph-compiled function; it must come back as a real `zip` yielding
    # the eager values.
    def fn(a):
        return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    m = opt_fn(torch.ones(3, 3))[0]
    self.assertIsInstance(m, zip)
    self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
@make_test
def test_map_partial_unpack(a, b):
    # The map is longer than the list it is zipped with. zip stops as soon as
    # [a, b] is exhausted, so in eager Python f runs only twice and y ends at 3;
    # the side effect on y must reflect that partial consumption.
    y = 1

    def f(x):
        nonlocal y
        y += 1
        return x

    l = list(zip([a, b], map(f, [1, 2, 3, 4])))
    return a + y
@make_test
def test_map_call_function_ex(a, b):
    # Star-unpacking a map directly into a function call
    def f(x, y):
        return x + y

    return f(*map(lambda x: x + 1, [a, b]))
@make_test
def test_map_unpack_twice(a, b):
    # A map is single-pass: the first list() drains it, so in eager Python the
    # second list() over the same object is empty.
    m = map(lambda x: x + 1, [a, b])
    l1 = list(m)
    l2 = list(m)
    return l1, l2
@make_test
def test_enumerate(a, b):
    # enumerate() over a list with a non-default start, materialized by list()
    return list(enumerate([a, b], start=1)), a + 1
@make_test
def test_map_enumerate(a, b):
    # enumerate() wrapped around a map, with a non-default start
    return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
@make_test
def test_map_infinite(a, b):
    # map over a finite list and the infinite itertools.count(3); map stops at
    # the shortest input, so only two elements are produced.
    return list(map(lambda x, y: x + y, [a, b], itertools.count(3)))
@make_test
def test_map_unpack_vars(a, b):
    # Tuple-unpacking assignment directly from a map object
    x, y = map(lambda x: x + 1, [a, b])
    return x + y
def test_enumerate_custom(self):
    # enumerate() over a user-defined iterator class, compiled with
    # fullgraph=True, must give the same result as eager.
    class MyClass:
        def __iter__(self):
            # (re)start the counter each time iteration begins
            self.a = 1
            return self

        def __next__(self):
            # increments before returning, so iteration yields 2, 3, 4
            if self.a > 3:
                raise StopIteration
            self.a += 1
            return self.a

    def fn(x):
        for i, it in enumerate(MyClass()):
            x += i + it
        return x

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
def test_enumerate_reconstruct(self):
    # An unconsumed enumerate object returned from a fullgraph-compiled
    # function must come back as a real `enumerate` yielding the eager values.
    def fn(a, b):
        return enumerate([a, b], start=1)

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    inps = (torch.randn(3, 3), torch.randn(3, 3))
    it1 = fn(*inps)
    it2 = opt_fn(*inps)
    self.assertIsInstance(it2, enumerate)
    self.assertEqual(list(it1), list(it2))
def udf_mul(x, y):
    """Return the product ``x * y``."""
    return x * y
@ -3670,10 +3878,16 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
nopython_fn(x, ys[:1], zs)
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
nopython_fn(x, ys, zs[:1])
# Should cause fallback if allow graph break
with self.assertRaisesRegex(ValueError, "zip()"):
opt_fn(x, ys[:1], zs)
with self.assertRaisesRegex(ValueError, "zip()"):
opt_fn(x, ys, zs[:1])
def test_fn_with_attr(self):
def fn(x):
if fn.pred:

View File

@ -5476,15 +5476,17 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
return x, y
def g(x, y):
return tuple(map(f, x, y))
return map(f, x, y)
opt_g = torch.compile(g, fullgraph=True, backend="eager")
inps = gen_inps(3, 3)
self.assertEqual(g(*inps), opt_g(*inps))
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
inps = gen_inps(3, 5)
self.assertEqual(g(*inps), opt_g(*inps))
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
def test_staticmethod_allow_in_graph(self):
class MyClass:

View File

@ -1663,8 +1663,8 @@ class InstructionTranslatorBase(
if not isinstance(
argsvars, BaseListVariable
) and argsvars.has_unpack_var_sequence(self):
argsvars = TupleVariable(argsvars.unpack_var_sequence(self))
) and argsvars.has_force_unpack_var_sequence(self):
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
# Unpack for cases like fn(**obj) where obj is a map
if isinstance(kwargsvars, UserDefinedObjectVariable):
@ -1833,7 +1833,7 @@ class InstructionTranslatorBase(
items = []
for seq in seqs:
try:
items.extend(seq.unpack_var_sequence(self))
items.extend(seq.force_unpack_var_sequence(self))
except NotImplementedError:
unimplemented(f"BUILD_LIST_UNPACK {seq}")
self.push(cls(items, mutable_local=MutableLocal()))
@ -1871,7 +1871,7 @@ class InstructionTranslatorBase(
assert isinstance(keys, TupleVariable)
assert keys.is_python_constant()
keys = keys.unpack_var_sequence(self)
keys = keys.force_unpack_var_sequence(self)
assert len(keys) == len(values)
self.push(
@ -1961,8 +1961,8 @@ class InstructionTranslatorBase(
# x, y = a.shape
proxy = getattr(seq.obj.as_proxy(), seq.name)
val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
elif seq.has_unpack_var_sequence(self):
val = seq.unpack_var_sequence(self)
elif seq.has_force_unpack_var_sequence(self):
val = seq.force_unpack_var_sequence(self)
else:
unimplemented(f"UNPACK_SEQUENCE {seq}")
if len(val) != inst.argval:
@ -1975,8 +1975,8 @@ class InstructionTranslatorBase(
prefix = inst.argval & 0xFF # low byte
suffix = inst.argval >> 8 # high byte
seq = self.pop()
if seq.has_unpack_var_sequence(self):
vals = list(seq.unpack_var_sequence(self))
if seq.has_force_unpack_var_sequence(self):
vals = list(seq.force_unpack_var_sequence(self))
assert len(vals) >= prefix + suffix
vals_prefix = vals[:prefix]
vals_list = vals[prefix : len(vals) - suffix]
@ -2400,7 +2400,7 @@ class InstructionTranslatorBase(
self.UNARY_POSITIVE(inst)
elif inst.argval == 6:
# INTRINSIC_LIST_TO_TUPLE
self.push(TupleVariable(self.pop().unpack_var_sequence(self)))
self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
else:
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")

View File

@ -1608,8 +1608,12 @@ def same(
"""Check correctness to see if ref and res match"""
if fp64_ref is None:
fp64_ref = ref
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
if isinstance(
ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size)
):
assert isinstance(
res, (list, tuple, collections.deque)
), f"type mismatch {type(ref)} {type(res)}"
if len(ref) != len(res):
log_error("Length mismatch")
return False

View File

@ -46,7 +46,9 @@ from .iter import (
CycleIteratorVariable,
IteratorVariable,
ItertoolsVariable,
MapVariable,
RepeatIteratorVariable,
ZipVariable,
)
from .lazy import LazyVariableTracker
from .lists import (

View File

@ -289,6 +289,15 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
raise NotImplementedError
def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
    # Like unpack_var_sequence, but should only be used when it is
    # safe to eagerly (vs. lazily) unpack this variable.
    # For example, map(f, x) is normally evaluated lazily, but sometimes
    # we want to force eager unpacking, e.g. when converting to a list.
    # NOTE: this method is allowed to mutate the VariableTracker, so
    # it should only be called once.
    # This default implementation simply delegates to unpack_var_sequence.
    return self.unpack_var_sequence(tx)
def has_unpack_var_sequence(self, tx) -> bool:
try:
self.unpack_var_sequence(tx)
@ -296,6 +305,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
except NotImplementedError:
return False
# NB: don't call force_unpack_var_sequence, especially if it mutates!
def has_force_unpack_var_sequence(self, tx) -> bool:
    # Reports whether force_unpack_var_sequence may be used. The default
    # answer is whatever has_unpack_var_sequence reports, so
    # force_unpack_var_sequence itself is never invoked just to probe.
    return self.has_unpack_var_sequence(tx)
def inspect_parameter_names(self) -> List[str]:
unimplemented(f"inspect_parameter_names: {self}")

View File

@ -1058,9 +1058,8 @@ class BuiltinVariable(VariableTracker):
return tx.inline_user_function_return(user_func_variable, [arg], {})
def _call_min_max(self, tx: "InstructionTranslator", *args):
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
# expand iterable
items = args[0].unpack_var_sequence(tx)
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
return self._call_min_max_seq(tx, items)
elif len(args) == 2:
return self._call_min_max_binary(tx, args[0], args[1])
@ -1075,6 +1074,10 @@ class BuiltinVariable(VariableTracker):
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
if a is None or b is None:
# a or b could be none if we reduce and _call_min_max_binary failed
# to return something
return
if self.tensor_args(a, b):
if not isinstance(a, variables.TensorVariable):
a, b = b, a
@ -1223,17 +1226,15 @@ class BuiltinVariable(VariableTracker):
),
)
# NOTE must handle IteratorVariable separately!
def _call_iter_tuple_list(
self, tx: "InstructionTranslator", obj=None, *args, **kwargs
):
assert not isinstance(obj, variables.IteratorVariable)
if self._dynamic_args(*args, **kwargs):
return self._dyn_proxy(tx, *args, **kwargs)
if isinstance(obj, variables.IteratorVariable):
# For non-list iterators, we will guard on vars that
# determine the control flow
return obj
cls = variables.BaseListVariable.cls_for(self.fn)
if obj is None:
return cls(
@ -1261,9 +1262,22 @@ class BuiltinVariable(VariableTracker):
mutable_local=MutableLocal(),
)
def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
    """Handle ``list(obj)`` / ``tuple(obj)``.

    Iterator variables are drained eagerly via ``force_unpack_var_sequence``
    and wrapped in the container class selected by
    ``BaseListVariable.cls_for(self.fn)``. Every other argument is forwarded
    to ``_call_iter_tuple_list`` unchanged.
    """
    if not isinstance(obj, variables.IteratorVariable):
        return self._call_iter_tuple_list(tx, obj, *args, **kwargs)

    # Resolve the container class first, then drain the iterator.
    container_cls = variables.BaseListVariable.cls_for(self.fn)
    drained = list(obj.force_unpack_var_sequence(tx))
    return container_cls(drained, mutable_local=MutableLocal())
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
# Handle the case where we are iterating over a tuple, list or iterator
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
if isinstance(obj, variables.IteratorVariable):
ret = obj
else:
# Handle the case where we are iterating over a tuple, list or iterator
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
if ret is None:
# If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway.
@ -1272,8 +1286,8 @@ class BuiltinVariable(VariableTracker):
return obj.call_method(tx, "__iter__", args, kwargs)
return ret
call_tuple = _call_iter_tuple_list
call_list = _call_iter_tuple_list
call_tuple = _call_tuple_list
call_list = _call_tuple_list
def call_callable(self, tx: "InstructionTranslator", arg):
from .functions import BaseUserFunctionVariable
@ -1331,10 +1345,12 @@ class BuiltinVariable(VariableTracker):
ListVariable,
TupleVariable,
ListIteratorVariable,
variables.IteratorVariable,
),
):
items = dict(
x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx)
x.force_unpack_var_sequence(tx)
for x in arg.force_unpack_var_sequence(tx)
)
return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
elif isinstance(arg, variables.MutableMappingVariable):
@ -1391,13 +1407,12 @@ class BuiltinVariable(VariableTracker):
return DictVariableType(
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
)
elif arg.has_unpack_var_sequence(tx) and all(
is_hashable(v) for v in arg.unpack_var_sequence(tx)
):
keys = arg.unpack_var_sequence(tx)
return DictVariableType(
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
)
elif arg.has_force_unpack_var_sequence(tx):
keys = arg.force_unpack_var_sequence(tx)
if all(is_hashable(v) for v in keys):
return DictVariableType(
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
)
unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")
def call_set(self, tx: "InstructionTranslator", *args, **kwargs):
@ -1409,8 +1424,8 @@ class BuiltinVariable(VariableTracker):
arg = args[0]
if isinstance(arg, variables.SetVariable):
return arg.clone(mutable_local=MutableLocal())
elif arg.has_unpack_var_sequence(tx):
items = arg.unpack_var_sequence(tx)
elif arg.has_force_unpack_var_sequence(tx):
items = arg.force_unpack_var_sequence(tx)
return SetVariable(items, mutable_local=MutableLocal())
elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
arg.value, KeysView
@ -1443,16 +1458,12 @@ class BuiltinVariable(VariableTracker):
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
if kwargs:
assert len(kwargs) == 1 and "strict" in kwargs
if all(x.has_unpack_var_sequence(tx) for x in args):
unpacked = [arg.unpack_var_sequence(tx) for arg in args]
if kwargs.pop("strict", False) and len(unpacked) > 0:
if not all(len(u) == len(unpacked[0]) for u in unpacked):
raise UserError(
ValueError,
"zip() has one argument of len differing from others",
)
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
return variables.TupleVariable(items)
strict = kwargs.pop("strict", False)
args = [
arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
for arg in args
]
return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal())
def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__len__", args[1:], kwargs)
@ -1553,10 +1564,11 @@ class BuiltinVariable(VariableTracker):
return obj.call_hasattr(tx, name)
def call_map(self, tx: "InstructionTranslator", fn, *seqs):
if all(seq.has_unpack_var_sequence(tx) for seq in seqs):
unpacked = [seq.unpack_var_sequence(tx) for seq in seqs]
items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)]
return variables.TupleVariable(items)
seqs = [
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
for seq in seqs
]
return variables.MapVariable(fn, seqs, mutable_local=MutableLocal())
def call_filter(self, tx: "InstructionTranslator", fn, seq):
if seq.has_unpack_var_sequence(tx):
@ -1589,10 +1601,10 @@ class BuiltinVariable(VariableTracker):
return variables.ConstantVariable.create(
sum((x.value for x in seq.items), start=start.value),
)
if seq.has_unpack_var_sequence(tx):
if seq.has_force_unpack_var_sequence(tx):
if start is self._SENTINEL:
start = variables.ConstantVariable.create(0)
items = seq.unpack_var_sequence(tx)
items = seq.force_unpack_var_sequence(tx)
return BuiltinVariable(functools.reduce).call_function(
tx,
[
@ -1606,8 +1618,8 @@ class BuiltinVariable(VariableTracker):
def call_reduce(
self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
):
if iterable.has_unpack_var_sequence(tx):
items = iterable.unpack_var_sequence(tx)
if iterable.has_force_unpack_var_sequence(tx):
items = iterable.force_unpack_var_sequence(tx)
if initial is self._SENTINEL:
value, items = items[0], items[1:]
else:
@ -1903,11 +1915,12 @@ class BuiltinVariable(VariableTracker):
return variables.TupleVariable(items)
def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
if (
obj.has_unpack_var_sequence(tx)
and not isinstance(obj, variables.TensorVariable)
and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
if obj.has_force_unpack_var_sequence(tx) and not isinstance(
obj, variables.TensorVariable
):
unpacked = obj.force_unpack_var_sequence(tx)
if not all(x.is_python_constant() for x in unpacked):
return
function = kwargs.pop("key", None)
reverse = kwargs.pop(
"reverse", ConstantVariable.create(False)
@ -1915,7 +1928,7 @@ class BuiltinVariable(VariableTracker):
assert len(kwargs) == 0
if function:
items = sorted(
obj.unpack_var_sequence(tx),
unpacked,
key=lambda x: function.call_function(
tx, [x], {}
).as_python_constant(),
@ -1923,7 +1936,7 @@ class BuiltinVariable(VariableTracker):
)
else:
items = sorted(
obj.unpack_var_sequence(tx),
unpacked,
key=lambda x: x.as_python_constant(),
reverse=reverse,
)

View File

@ -145,6 +145,14 @@ class ConstantVariable(VariableTracker):
return variables.BuiltinVariable(str.format).call_function(
tx, [self, *args], kwargs
)
elif name == "join" and istype(self.value, str):
assert len(args) == 1 and len(kwargs) == 0
arg_unpacked = args[0].force_unpack_var_sequence(tx)
try:
arg_const = [x.as_python_constant() for x in arg_unpacked]
return ConstantVariable.create(self.value.join(arg_const))
except NotImplementedError:
return super().call_method(tx, name, args, kwargs)
if any(isinstance(x, SymNodeVariable) for x in args):
# Promote to SymNodeVariable for operations involving dynamic shapes.

View File

@ -314,6 +314,7 @@ class ConstDictVariable(VariableTracker):
ListVariable,
TupleVariable,
ListIteratorVariable,
variables.IteratorVariable,
UserDefinedObjectVariable,
),
)

View File

@ -2,14 +2,17 @@
import itertools
import operator
from typing import Dict, List, Optional, TYPE_CHECKING
import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
handle_observed_exception,
ObservedUserStopIteration,
raise_observed_exception,
unimplemented,
UserError,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
@ -197,6 +200,25 @@ class IteratorVariable(VariableTracker):
def next_variable(self, tx):
unimplemented("abstract method, must implement")
# NOTE: only call this when unpacking the iterator can safely be done eagerly!
# Normally, iterators are accessed lazily.
# Example of safe eager unpacking: list(map(f, seq))
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
    """Drain this iterator eagerly and return everything it produced.

    Calls ``next_variable`` repeatedly until it raises
    ``ObservedUserStopIteration``, which is then cleared with
    ``handle_observed_exception``. The iterator is advanced as a side effect.
    """
    collected = []
    try:
        while True:
            collected.append(self.next_variable(tx))
    except ObservedUserStopIteration:
        # Exhaustion is the normal way out of the loop above.
        handle_observed_exception(tx)
    return collected
# don't call force_unpack_var_sequence since it can mutate
# IteratorVariable state!
def has_force_unpack_var_sequence(self, tx) -> bool:
    # Always answers True without probing, so the iterator is never advanced
    # merely to find out whether it can be force-unpacked.
    return True
class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs) -> None:
@ -207,6 +229,18 @@ class RepeatIteratorVariable(IteratorVariable):
def next_variable(self, tx):
return self.item
def reconstruct(self, codegen):
    """Emit bytecode that rebuilds this iterator as ``itertools.repeat(item)``."""

    def load_repeat():
        # Pushes the callable: the itertools module, then its `repeat` attribute.
        return codegen.extend_output(
            [
                codegen.create_load_python_module(itertools),
                codegen.create_load_attr("repeat"),
            ]
        )

    codegen.add_push_null(load_repeat)
    # The single positional argument, followed by the call itself.
    codegen(self.item)
    call_instructions = create_call_function(1, False)
    codegen.extend_output(call_instructions)
class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
@ -220,10 +254,23 @@ class CountIteratorVariable(IteratorVariable):
def next_variable(self, tx):
assert self.mutable_local
old_item = self.item
tx.output.side_effects.mutation(self)
next_item = self.item.call_method(tx, "__add__", [self.step], {})
self.item = next_item
return self.item
self.item = self.item.call_method(tx, "__add__", [self.step], {})
return old_item
def reconstruct(self, codegen):
    """Emit bytecode that rebuilds this iterator as ``itertools.count(item, step)``."""

    def load_count():
        # Pushes the callable: the itertools module, then its `count` attribute.
        return codegen.extend_output(
            [
                codegen.create_load_python_module(itertools),
                codegen.create_load_attr("count"),
            ]
        )

    codegen.add_push_null(load_count)
    # Two positional arguments: the current item, then the step.
    codegen(self.item)
    codegen(self.step)
    call_instructions = create_call_function(2, False)
    codegen.extend_output(call_instructions)
class CycleIteratorVariable(IteratorVariable):
@ -269,3 +316,160 @@ class CycleIteratorVariable(IteratorVariable):
return self.item
else:
raise_observed_exception(StopIteration, tx, self)
class ZipVariable(IteratorVariable):
    """
    Represents zip(*iterables)

    Each entry of ``iterables`` is either a plain Python list of
    VariableTrackers (read by position through ``self.index``) or a
    VariableTracker that is advanced with ``next_variable``.
    """

    _nonvar_fields = {
        "index",
        "strict",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        strict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
        # can be list[Variable] or VariableTracker (with next_variable implemented)
        self.iterables = iterables
        # number of tuples already produced; used as the offset into list inputs
        self.index = 0
        self.strict = strict

    def python_type(self):
        return zip

    def has_unpack_var_sequence(self, tx) -> bool:
        # Unpackable only if every input is a plain list or is itself unpackable.
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return the remaining zipped tuples without advancing ``self.index``."""
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
            if isinstance(it, list):
                # skip the elements already handed out by next_variable
                iterables.append(it[self.index :])
            else:
                iterables.append(it.unpack_var_sequence(tx))
        # Only forward `strict` when it is set.
        # NOTE(review): with strict=True a length mismatch here is raised by the
        # builtin zip as a plain ValueError, unlike the UserError raised in
        # next_variable - confirm this is intended.
        kwargs = {"strict": self.strict} if self.strict else {}
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

    def next_variable(self, tx):
        """Produce the next zipped tuple, advancing every input by one element.

        Re-raises ObservedUserStopIteration when an input is exhausted. With
        ``strict`` set, raises UserError(ValueError) if the inputs turn out to
        have different lengths.
        """
        assert self.mutable_local
        old_index = self.index
        args = []

        def get_item(it):
            # List inputs are read by position; other inputs are advanced.
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx, self)
                return it[old_index]
            else:
                return it.next_variable(tx)

        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
        except ObservedUserStopIteration:
            if self.strict:
                if idx == 0:
                    # all other iterables should be exhausted
                    for it in self.iterables:
                        try:
                            get_item(it)
                        except ObservedUserStopIteration:
                            handle_observed_exception(tx)
                            continue
                        # no ObservedUserStopIteration - fall through to UserError
                        break
                    else:
                        # all iterables exhausted, raise original error
                        raise
                # Reached when an input other than the first ran out, or when
                # the first ran out while another input still had an element.
                handle_observed_exception(tx)
                raise UserError(
                    ValueError,
                    "zip() has one argument of len differing from others",
                ) from None
            # non-strict: the shortest input ends the iteration
            raise

        # Only a fully built tuple counts as an advance of this iterator.
        tx.output.side_effects.mutation(self)
        self.index += 1
        return variables.TupleVariable(args)

    def reconstruct_items(self, codegen):
        """Emit one value per input: a tuple of leftovers for lists, else the input itself."""
        for it in self.iterables:
            if isinstance(it, list):
                # only the elements not yet consumed are rebuilt
                remaining_items = it[self.index :]
                codegen.foreach(remaining_items)
                codegen.append_output(
                    create_instruction("BUILD_TUPLE", arg=len(remaining_items))
                )
            else:
                codegen(it)

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this object as a builtin ``zip(...)`` call."""
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
        self.reconstruct_items(codegen)
        # pack the per-input values into the positional-args tuple
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.iterables))
        )
        # zip() only accepts the `strict` keyword from Python 3.10 onwards
        if sys.version_info >= (3, 10):
            codegen.extend_output(
                [
                    codegen.create_load_const("strict"),
                    codegen.create_load_const(self.strict),
                    create_instruction("BUILD_MAP", arg=1),
                    create_instruction("CALL_FUNCTION_EX", arg=1),
                ]
            )
        else:
            codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
class MapVariable(ZipVariable):
    """
    Represents map(fn, *iterables)

    Built on ZipVariable: each step takes the next zipped tuple of arguments
    and applies ``fn`` to it.
    """

    def __init__(
        self,
        fn: VariableTracker,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        **kwargs,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

    def python_type(self):
        return map

    def has_unpack_var_sequence(self, tx) -> bool:
        # A map is never reported as unpackable through the non-forcing path.
        return False

    def next_variable(self, tx):
        """Advance the underlying zip by one step and call ``fn`` on its items."""
        zipped_args = super().next_variable(tx)
        return self.fn.call_function(tx, zipped_args.items, {})

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this object as a builtin ``map(...)`` call."""

        def load_map():
            return codegen.load_import_from("builtins", "map")

        codegen.add_push_null(load_map, call_function_ex=True)
        # Positional arguments: the function first, then one value per iterable.
        codegen(self.fn)
        self.reconstruct_items(codegen)
        positional_count = len(self.iterables) + 1
        call_sequence = [
            create_instruction("BUILD_TUPLE", arg=positional_count),
            create_instruction("CALL_FUNCTION_EX", arg=0),
        ]
        codegen.extend_output(call_sequence)

View File

@ -29,6 +29,7 @@ from ..utils import (
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable
if TYPE_CHECKING:
@ -334,11 +335,11 @@ class CommonListMethodsVariable(BaseListVariable):
name == "extend"
and self.mutable_local
and args
and args[0].has_unpack_var_sequence(tx)
and args[0].has_force_unpack_var_sequence(tx)
):
assert not kwargs
(arg,) = args
seq = arg.unpack_var_sequence(tx)
seq = arg.force_unpack_var_sequence(tx)
tx.output.side_effects.mutation(self)
self.items.extend(seq)
return ConstantVariable.create(None)
@ -422,11 +423,13 @@ class ListVariable(CommonListMethodsVariable):
key, value = args
tx.output.side_effects.mutation(self)
if isinstance(key, SliceVariable):
if not value.has_unpack_var_sequence(tx):
if not value.has_force_unpack_var_sequence(tx):
unimplemented(
f"Missing dynamo support for expanding {value} into a list for slice assignment."
)
self.items[key.as_python_constant()] = value.unpack_var_sequence(tx)
self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
tx
)
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
@ -464,7 +467,12 @@ class DequeVariable(CommonListMethodsVariable):
)
)
codegen.foreach(self.items)
codegen.extend_output(create_call_function(len(self.items), False))
codegen.extend_output(
[
create_instruction("BUILD_LIST", arg=len(self.items)),
*create_call_function(1, False),
]
)
def call_method(
self,
@ -487,11 +495,15 @@ class DequeVariable(CommonListMethodsVariable):
tx.output.side_effects.mutation(self)
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
elif name == "extendleft" and self.mutable_local:
elif (
name == "extendleft"
and self.mutable_local
and args[0].has_force_unpack_var_sequence(tx)
):
assert not kwargs
(arg,) = args
prefix = arg.unpack_var_sequence(tx)
prefix = arg.force_unpack_var_sequence(tx)
prefix.reverse()
tx.output.side_effects.mutation(self)
self.items = prefix + list(self.items)
@ -802,10 +814,10 @@ class SliceVariable(BaseListVariable):
return self.items[fields.index(name)]
class ListIteratorVariable(VariableTracker):
class ListIteratorVariable(IteratorVariable):
_nonvar_fields = {
"index",
*VariableTracker._nonvar_fields,
*IteratorVariable._nonvar_fields,
}
def __init__(self, items, index: int = 0, **kwargs) -> None:
@ -856,6 +868,9 @@ class ListIteratorVariable(VariableTracker):
def unpack_var_sequence(self, tx):
return list(self.items[self.index :])
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
    # Delegates to this class's own unpack_var_sequence instead of inheriting
    # the force-unpack implementation from the IteratorVariable base class.
    return self.unpack_var_sequence(tx)
def reconstruct(self, codegen):
remaining_items = self.items[self.index :]
codegen.foreach(remaining_items)

View File

@ -379,8 +379,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
elif self.value is collections.deque and not kwargs:
if len(args) == 0:
items = []
elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
items = args[0].unpack_var_sequence(tx)
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
else:
unimplemented("deque() with more than 1 arg not supported")
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
@ -749,7 +749,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
assert not (args or kwargs)
items = []
keys = self.call_method(tx, "keys", [], {})
for key in keys.unpack_var_sequence(tx):
for key in keys.force_unpack_var_sequence(tx):
items.append(
TupleVariable(
[key, self.odict_getitem(tx, key)],