mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] revert map/zip iterator related changes (#132528)
Need to revert due to internal hangs: S437700 This reverts commit b6c1490cc02316ffe85e5ae74651d80f0158ba64. Revert "[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (#131725)" This reverts commit 2576dbbc35d66e8e9ed6cb12216ccc424cb87ec3. Revert "[dynamo] add itertools repeat/count bytecode reconstruction (#131716)" This reverts commit 35b4de32fafc5ad024c20ef1275711bffc557ae9. Revert "[dynamo] add lazy IteratorVariable implementations for map and zip (#131413)" This reverts commit 7d282d87550787d8269593093519c2ad7c5032cd. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/132528 Approved by: https://github.com/ZainRizvi
This commit is contained in:
committed by
PyTorch MergeBot
parent
b71cd149ce
commit
e81e74ca6c
@ -181,22 +181,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
v = v + x
|
||||
return v
|
||||
|
||||
def test_itertools_reconstruct(self):
|
||||
def fn(a):
|
||||
it1 = itertools.repeat(1)
|
||||
it2 = itertools.count(2)
|
||||
for _ in range(3):
|
||||
a += next(it1)
|
||||
a += next(it2)
|
||||
return it1, it2, a
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
i1, i2, a = fn(torch.ones(3, 3))
|
||||
it1, it2, b = opt_fn(torch.ones(3, 3))
|
||||
self.assertEqual(next(i1), next(it1))
|
||||
self.assertEqual(next(i2), next(it2))
|
||||
self.assertEqual(a, b)
|
||||
|
||||
@make_test
|
||||
def test_obj_eq(a, b):
|
||||
v = a + b
|
||||
@ -449,7 +433,8 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
empty = collections.deque()
|
||||
d.extend(empty)
|
||||
|
||||
return d
|
||||
# dynamo same() util doesn't support deque so just return a list
|
||||
return list(d)
|
||||
|
||||
@make_test
|
||||
def test_slice1(a):
|
||||
@ -2886,199 +2871,6 @@ class GraphModule(torch.nn.Module):
|
||||
fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
|
||||
)
|
||||
|
||||
def test_map_return(self):
|
||||
def fn(a, b):
|
||||
return map(lambda x: x + 1, [a, b])
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
m = opt_fn(torch.randn(3, 3), torch.randn(3, 3))
|
||||
self.assertIsInstance(m, map)
|
||||
|
||||
@make_test
|
||||
def test_map_max(a, b):
|
||||
return max(map(lambda x: x.sum(), [a, b]))
|
||||
|
||||
# max(map(...)) graph breaks
|
||||
@unittest.expectedFailure
|
||||
@make_test
|
||||
def test_map_max_const(a):
|
||||
return max(map(lambda x: x, [1, 2, 3])), a + 1
|
||||
|
||||
@make_test
|
||||
def test_map_list(a, b):
|
||||
return list(map(lambda x: x + 1, [a, b]))
|
||||
|
||||
@make_test
|
||||
def test_map_tuple(a, b):
|
||||
return tuple(map(lambda x: x + 1, [a, b]))
|
||||
|
||||
@make_test
|
||||
def test_map_iter(a, b):
|
||||
it = iter(map(lambda x: x + 1, [a, b]))
|
||||
return next(it)
|
||||
|
||||
@make_test
|
||||
def test_map_zip_dict(a):
|
||||
d = dict(
|
||||
zip(
|
||||
map(lambda x: x + 1, [0, 1, 2]),
|
||||
[map(lambda x: x - 1, [y]) for y in [3, 4, 5]],
|
||||
)
|
||||
)
|
||||
return list(d[3])[0], a + 1 # noqa: RUF015
|
||||
|
||||
@make_test
|
||||
def test_map_dict_fromkeys(a):
|
||||
return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1
|
||||
|
||||
@make_test
|
||||
def test_map_set(a):
|
||||
return set(map(lambda x: x + 1, [0, 1])), a + 1
|
||||
|
||||
# test_map_sum defined earlier
|
||||
|
||||
@make_test
|
||||
def test_map_reduce(a, b):
|
||||
return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b]))
|
||||
|
||||
@make_test
|
||||
def test_map_sorted(a):
|
||||
return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1
|
||||
|
||||
@make_test
|
||||
def test_map_list_extend(a, b, c):
|
||||
l = [a]
|
||||
l.extend(map(lambda x: x + 1, [b, c]))
|
||||
return l
|
||||
|
||||
@make_test
|
||||
def test_map_list_slice_assign(a, b, c, d, e):
|
||||
l = [a, b, c]
|
||||
l[1:2] = map(lambda x: x + 1, [d, e])
|
||||
return l
|
||||
|
||||
@make_test
|
||||
def test_map_deque_extendleft(a, b, c):
|
||||
d = collections.deque([a])
|
||||
d.extendleft(map(lambda x: x + 1, [b, c]))
|
||||
return d
|
||||
|
||||
@make_test
|
||||
def test_map_str_join(a):
|
||||
return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1
|
||||
|
||||
def test_map_with_graph_break(self):
|
||||
def f(a):
|
||||
a += 1
|
||||
|
||||
def g(x):
|
||||
nonlocal a
|
||||
a += 1
|
||||
return x + 1
|
||||
|
||||
m = map(g, [1, 2, 3, 4, 5])
|
||||
a += next(m) # won't graph break
|
||||
torch._dynamo.graph_break()
|
||||
a += next(m) # will graph break
|
||||
return a
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_f = torch.compile(f, backend=cnts)
|
||||
self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3)))
|
||||
self.assertEqual(cnts.frame_count, 3)
|
||||
|
||||
def test_map_reconstruct(self):
|
||||
def fn(a):
|
||||
return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
m = opt_fn(torch.ones(3, 3))[0]
|
||||
self.assertIsInstance(m, map)
|
||||
self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
|
||||
|
||||
def test_zip_reconstruct(self):
|
||||
def fn(a):
|
||||
return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
m = opt_fn(torch.ones(3, 3))[0]
|
||||
self.assertIsInstance(m, zip)
|
||||
self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
|
||||
|
||||
@make_test
|
||||
def test_map_partial_unpack(a, b):
|
||||
y = 1
|
||||
|
||||
def f(x):
|
||||
nonlocal y
|
||||
y += 1
|
||||
return x
|
||||
|
||||
l = list(zip([a, b], map(f, [1, 2, 3, 4])))
|
||||
return a + y
|
||||
|
||||
@make_test
|
||||
def test_map_call_function_ex(a, b):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
return f(*map(lambda x: x + 1, [a, b]))
|
||||
|
||||
@make_test
|
||||
def test_map_unpack_twice(a, b):
|
||||
m = map(lambda x: x + 1, [a, b])
|
||||
l1 = list(m)
|
||||
l2 = list(m)
|
||||
return l1, l2
|
||||
|
||||
@make_test
|
||||
def test_enumerate(a, b):
|
||||
return list(enumerate([a, b], start=1)), a + 1
|
||||
|
||||
@make_test
|
||||
def test_map_enumerate(a, b):
|
||||
return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
|
||||
|
||||
@make_test
|
||||
def test_map_infinite(a, b):
|
||||
return list(map(lambda x, y: x + y, [a, b], itertools.count(3)))
|
||||
|
||||
@make_test
|
||||
def test_map_unpack_vars(a, b):
|
||||
x, y = map(lambda x: x + 1, [a, b])
|
||||
return x + y
|
||||
|
||||
def test_enumerate_custom(self):
|
||||
class MyClass:
|
||||
def __iter__(self):
|
||||
self.a = 1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.a > 3:
|
||||
raise StopIteration
|
||||
self.a += 1
|
||||
return self.a
|
||||
|
||||
def fn(x):
|
||||
for i, it in enumerate(MyClass()):
|
||||
x += i + it
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
|
||||
|
||||
def test_enumerate_reconstruct(self):
|
||||
def fn(a, b):
|
||||
return enumerate([a, b], start=1)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
inps = (torch.randn(3, 3), torch.randn(3, 3))
|
||||
it1 = fn(*inps)
|
||||
it2 = opt_fn(*inps)
|
||||
self.assertIsInstance(it2, enumerate)
|
||||
self.assertEqual(list(it1), list(it2))
|
||||
|
||||
|
||||
def udf_mul(x, y):
|
||||
return x * y
|
||||
@ -3569,16 +3361,10 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
|
||||
nopython_fn(x, ys[:1], zs)
|
||||
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
|
||||
nopython_fn(x, ys, zs[:1])
|
||||
|
||||
# Should cause fallback if allow graph break
|
||||
with self.assertRaisesRegex(ValueError, "zip()"):
|
||||
opt_fn(x, ys[:1], zs)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "zip()"):
|
||||
opt_fn(x, ys, zs[:1])
|
||||
|
||||
def test_fn_with_attr(self):
|
||||
def fn(x):
|
||||
if fn.pred:
|
||||
|
@ -5400,17 +5400,15 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
return x, y
|
||||
|
||||
def g(x, y):
|
||||
return map(f, x, y)
|
||||
return tuple(map(f, x, y))
|
||||
|
||||
opt_g = torch.compile(g, fullgraph=True, backend="eager")
|
||||
|
||||
inps = gen_inps(3, 3)
|
||||
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
||||
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
||||
self.assertEqual(g(*inps), opt_g(*inps))
|
||||
|
||||
inps = gen_inps(3, 5)
|
||||
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
||||
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
||||
self.assertEqual(g(*inps), opt_g(*inps))
|
||||
|
||||
def test_staticmethod_allow_in_graph(self):
|
||||
class MyClass:
|
||||
|
@ -24,7 +24,7 @@ def any(iterator):
|
||||
|
||||
|
||||
def index(iterator, item, start=0, end=None):
|
||||
for i, elem in list(enumerate(list(iterator)))[start:end]:
|
||||
for i, elem in list(enumerate(iterator))[start:end]:
|
||||
if item == elem:
|
||||
return i
|
||||
# This will not run in dynamo
|
||||
@ -126,13 +126,6 @@ def getattr_and_trace(*args, **kwargs):
|
||||
return fn(*args[2:], **kwargs)
|
||||
|
||||
|
||||
def enumerate(iterable, start=0):
|
||||
n = start
|
||||
for elem in iterable:
|
||||
yield n, elem
|
||||
n += 1
|
||||
|
||||
|
||||
def mapping_get(obj, key, value=None):
|
||||
try:
|
||||
return obj.__getitem__(key)
|
||||
|
@ -1625,8 +1625,8 @@ class InstructionTranslatorBase(
|
||||
|
||||
if not isinstance(
|
||||
argsvars, BaseListVariable
|
||||
) and argsvars.has_force_unpack_var_sequence(self):
|
||||
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
|
||||
) and argsvars.has_unpack_var_sequence(self):
|
||||
argsvars = TupleVariable(argsvars.unpack_var_sequence(self))
|
||||
|
||||
# Unpack for cases like fn(**obj) where obj is a map
|
||||
if isinstance(kwargsvars, UserDefinedObjectVariable):
|
||||
@ -1795,7 +1795,7 @@ class InstructionTranslatorBase(
|
||||
items = []
|
||||
for seq in seqs:
|
||||
try:
|
||||
items.extend(seq.force_unpack_var_sequence(self))
|
||||
items.extend(seq.unpack_var_sequence(self))
|
||||
except NotImplementedError:
|
||||
unimplemented(f"BUILD_LIST_UNPACK {seq}")
|
||||
self.push(cls(items, mutable_local=MutableLocal()))
|
||||
@ -1833,7 +1833,7 @@ class InstructionTranslatorBase(
|
||||
assert isinstance(keys, TupleVariable)
|
||||
assert keys.is_python_constant()
|
||||
|
||||
keys = keys.force_unpack_var_sequence(self)
|
||||
keys = keys.unpack_var_sequence(self)
|
||||
assert len(keys) == len(values)
|
||||
|
||||
self.push(
|
||||
@ -1923,8 +1923,8 @@ class InstructionTranslatorBase(
|
||||
# x, y = a.shape
|
||||
proxy = getattr(seq.obj.as_proxy(), seq.name)
|
||||
val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
|
||||
elif seq.has_force_unpack_var_sequence(self):
|
||||
val = seq.force_unpack_var_sequence(self)
|
||||
elif seq.has_unpack_var_sequence(self):
|
||||
val = seq.unpack_var_sequence(self)
|
||||
else:
|
||||
unimplemented(f"UNPACK_SEQUENCE {seq}")
|
||||
if len(val) != inst.argval:
|
||||
@ -1937,8 +1937,8 @@ class InstructionTranslatorBase(
|
||||
prefix = inst.argval & 0xFF # low byte
|
||||
suffix = inst.argval >> 8 # high byte
|
||||
seq = self.pop()
|
||||
if seq.has_force_unpack_var_sequence(self):
|
||||
vals = list(seq.force_unpack_var_sequence(self))
|
||||
if seq.has_unpack_var_sequence(self):
|
||||
vals = list(seq.unpack_var_sequence(self))
|
||||
assert len(vals) >= prefix + suffix
|
||||
vals_prefix = vals[:prefix]
|
||||
vals_list = vals[prefix : len(vals) - suffix]
|
||||
@ -2362,7 +2362,7 @@ class InstructionTranslatorBase(
|
||||
self.UNARY_POSITIVE(inst)
|
||||
elif inst.argval == 6:
|
||||
# INTRINSIC_LIST_TO_TUPLE
|
||||
self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
|
||||
self.push(TupleVariable(self.pop().unpack_var_sequence(self)))
|
||||
else:
|
||||
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
|
||||
|
||||
|
@ -1387,12 +1387,8 @@ def same(
|
||||
"""Check correctness to see if ref and res match"""
|
||||
if fp64_ref is None:
|
||||
fp64_ref = ref
|
||||
if isinstance(
|
||||
ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size)
|
||||
):
|
||||
assert isinstance(
|
||||
res, (list, tuple, collections.deque)
|
||||
), f"type mismatch {type(ref)} {type(res)}"
|
||||
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
|
||||
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
|
||||
if len(ref) != len(res):
|
||||
log_error("Length mismatch")
|
||||
return False
|
||||
|
@ -41,12 +41,9 @@ from .higher_order_ops import (
|
||||
from .iter import (
|
||||
CountIteratorVariable,
|
||||
CycleIteratorVariable,
|
||||
EnumerateVariable,
|
||||
IteratorVariable,
|
||||
ItertoolsVariable,
|
||||
MapVariable,
|
||||
RepeatIteratorVariable,
|
||||
ZipVariable,
|
||||
)
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import (
|
||||
|
@ -286,15 +286,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||
raise NotImplementedError
|
||||
|
||||
def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||
# like unpack_var_sequence, but should only be used when it is
|
||||
# safe to eagerly (vs. lazily) unpack this variable.
|
||||
# e.g. map(f, x) is normally evaluated lazily but sometimes
|
||||
# we want to force eager unpacking, e.g. when converting to a list.
|
||||
# NOTE: this method is allowed to mutate the VariableTracker, so
|
||||
# it should only be called once.
|
||||
return self.unpack_var_sequence(tx)
|
||||
|
||||
def has_unpack_var_sequence(self, tx) -> bool:
|
||||
try:
|
||||
self.unpack_var_sequence(tx)
|
||||
@ -302,10 +293,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
except NotImplementedError:
|
||||
return False
|
||||
|
||||
# NB: don't call force_unpack_var_sequence, especially if it mutates!
|
||||
def has_force_unpack_var_sequence(self, tx) -> bool:
|
||||
return self.has_unpack_var_sequence(tx)
|
||||
|
||||
def inspect_parameter_names(self) -> List[str]:
|
||||
unimplemented(f"inspect_parameter_names: {self}")
|
||||
|
||||
|
@ -1075,8 +1075,9 @@ class BuiltinVariable(VariableTracker):
|
||||
return tx.inline_user_function_return(user_func_variable, [arg], {})
|
||||
|
||||
def _call_min_max(self, tx: "InstructionTranslator", *args):
|
||||
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
||||
# expand iterable
|
||||
items = args[0].unpack_var_sequence(tx)
|
||||
return self._call_min_max_seq(tx, items)
|
||||
elif len(args) == 2:
|
||||
return self._call_min_max_binary(tx, args[0], args[1])
|
||||
@ -1091,10 +1092,6 @@ class BuiltinVariable(VariableTracker):
|
||||
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
|
||||
|
||||
def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
|
||||
if a is None or b is None:
|
||||
# a or b could be none if we reduce and _call_min_max_binary failed
|
||||
# to return something
|
||||
return
|
||||
if self.tensor_args(a, b):
|
||||
if not isinstance(a, variables.TensorVariable):
|
||||
a, b = b, a
|
||||
@ -1243,15 +1240,17 @@ class BuiltinVariable(VariableTracker):
|
||||
),
|
||||
)
|
||||
|
||||
# NOTE must handle IteratorVariable separately!
|
||||
def _call_iter_tuple_list(
|
||||
self, tx: "InstructionTranslator", obj=None, *args, **kwargs
|
||||
):
|
||||
assert not isinstance(obj, variables.IteratorVariable)
|
||||
|
||||
if self._dynamic_args(*args, **kwargs):
|
||||
return self._dyn_proxy(tx, *args, **kwargs)
|
||||
|
||||
if isinstance(obj, variables.IteratorVariable):
|
||||
# For non-list iterators, we will guard on vars that
|
||||
# determine the control flow
|
||||
return obj
|
||||
|
||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||
if obj is None:
|
||||
return cls(
|
||||
@ -1279,22 +1278,9 @@ class BuiltinVariable(VariableTracker):
|
||||
mutable_local=MutableLocal(),
|
||||
)
|
||||
|
||||
def _call_tuple_list(self, tx, obj=None, *args, **kwargs):
|
||||
if isinstance(obj, variables.IteratorVariable):
|
||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||
return cls(
|
||||
list(obj.force_unpack_var_sequence(tx)),
|
||||
mutable_local=MutableLocal(),
|
||||
)
|
||||
else:
|
||||
return self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||
|
||||
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
|
||||
if isinstance(obj, variables.IteratorVariable):
|
||||
ret = obj
|
||||
else:
|
||||
# Handle the case where we are iterating over a tuple, list or iterator
|
||||
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||
# Handle the case where we are iterating over a tuple, list or iterator
|
||||
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||
|
||||
if ret is None:
|
||||
# If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway.
|
||||
@ -1303,8 +1289,8 @@ class BuiltinVariable(VariableTracker):
|
||||
return obj.call_method(tx, "__iter__", args, kwargs)
|
||||
return ret
|
||||
|
||||
call_tuple = _call_tuple_list
|
||||
call_list = _call_tuple_list
|
||||
call_tuple = _call_iter_tuple_list
|
||||
call_list = _call_iter_tuple_list
|
||||
|
||||
def call_callable(self, tx: "InstructionTranslator", arg):
|
||||
from .functions import BaseUserFunctionVariable
|
||||
@ -1360,12 +1346,10 @@ class BuiltinVariable(VariableTracker):
|
||||
ListVariable,
|
||||
TupleVariable,
|
||||
ListIteratorVariable,
|
||||
variables.IteratorVariable,
|
||||
),
|
||||
):
|
||||
items = dict(
|
||||
x.force_unpack_var_sequence(tx)
|
||||
for x in arg.force_unpack_var_sequence(tx)
|
||||
x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx)
|
||||
)
|
||||
return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
|
||||
elif isinstance(arg, variables.MutableMappingVariable):
|
||||
@ -1411,12 +1395,13 @@ class BuiltinVariable(VariableTracker):
|
||||
return DictVariableType(
|
||||
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
|
||||
)
|
||||
elif arg.has_force_unpack_var_sequence(tx):
|
||||
keys = arg.force_unpack_var_sequence(tx)
|
||||
if all(is_hashable(v) for v in keys):
|
||||
return DictVariableType(
|
||||
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
|
||||
)
|
||||
elif arg.has_unpack_var_sequence(tx) and all(
|
||||
is_hashable(v) for v in arg.unpack_var_sequence(tx)
|
||||
):
|
||||
keys = arg.unpack_var_sequence(tx)
|
||||
return DictVariableType(
|
||||
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
|
||||
)
|
||||
unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")
|
||||
|
||||
def call_set(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
@ -1428,8 +1413,8 @@ class BuiltinVariable(VariableTracker):
|
||||
arg = args[0]
|
||||
if isinstance(arg, variables.SetVariable):
|
||||
return arg.clone(mutable_local=MutableLocal())
|
||||
elif arg.has_force_unpack_var_sequence(tx):
|
||||
items = arg.force_unpack_var_sequence(tx)
|
||||
elif arg.has_unpack_var_sequence(tx):
|
||||
items = arg.unpack_var_sequence(tx)
|
||||
return SetVariable(items, mutable_local=MutableLocal())
|
||||
elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
|
||||
arg.value, KeysView
|
||||
@ -1448,36 +1433,32 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
if kwargs:
|
||||
assert len(kwargs) == 1 and "strict" in kwargs
|
||||
strict = kwargs.pop("strict", False)
|
||||
args = [
|
||||
arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg
|
||||
for arg in args
|
||||
]
|
||||
return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal())
|
||||
if all(x.has_unpack_var_sequence(tx) for x in args):
|
||||
unpacked = [arg.unpack_var_sequence(tx) for arg in args]
|
||||
if kwargs.pop("strict", False) and len(unpacked) > 0:
|
||||
if not all(len(u) == len(unpacked[0]) for u in unpacked):
|
||||
raise UserError(
|
||||
ValueError,
|
||||
"zip() has one argument of len differing from others",
|
||||
)
|
||||
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
|
||||
return variables.TupleVariable(items)
|
||||
|
||||
_call_enumerate_polyfill = _polyfill_call_impl("enumerate")
|
||||
|
||||
def call_enumerate(self, tx: "InstructionTranslator", iterable, start=_SENTINEL):
|
||||
if start is self._SENTINEL:
|
||||
def call_enumerate(self, tx: "InstructionTranslator", *args):
|
||||
if len(args) == 1:
|
||||
start = 0
|
||||
else:
|
||||
assert isinstance(start, variables.ConstantVariable)
|
||||
start = start.as_python_constant()
|
||||
|
||||
if iterable.has_unpack_var_sequence(tx):
|
||||
return variables.EnumerateVariable(
|
||||
iterable.unpack_var_sequence(tx),
|
||||
start,
|
||||
mutable_local=MutableLocal(),
|
||||
)
|
||||
elif isinstance(iterable, variables.IteratorVariable):
|
||||
return variables.EnumerateVariable(
|
||||
iterable, start, mutable_local=MutableLocal()
|
||||
)
|
||||
|
||||
return self._call_enumerate_polyfill(
|
||||
tx, iterable, variables.ConstantVariable.create(start)
|
||||
)
|
||||
assert len(args) == 2
|
||||
assert isinstance(args[1], variables.ConstantVariable)
|
||||
start = args[1].as_python_constant()
|
||||
if args[0].has_unpack_var_sequence(tx):
|
||||
items = [
|
||||
variables.TupleVariable(
|
||||
[variables.ConstantVariable.create(idx), var],
|
||||
)
|
||||
for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
|
||||
]
|
||||
return variables.TupleVariable(items)
|
||||
|
||||
def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
return args[0].call_method(tx, "__len__", args[1:], kwargs)
|
||||
@ -1578,11 +1559,10 @@ class BuiltinVariable(VariableTracker):
|
||||
return obj.call_hasattr(tx, name)
|
||||
|
||||
def call_map(self, tx: "InstructionTranslator", fn, *seqs):
|
||||
seqs = [
|
||||
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
|
||||
for seq in seqs
|
||||
]
|
||||
return variables.MapVariable(fn, seqs, mutable_local=MutableLocal())
|
||||
if all(seq.has_unpack_var_sequence(tx) for seq in seqs):
|
||||
unpacked = [seq.unpack_var_sequence(tx) for seq in seqs]
|
||||
items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)]
|
||||
return variables.TupleVariable(items)
|
||||
|
||||
def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL):
|
||||
# Special case for sum on tuple of floats and ints
|
||||
@ -1601,10 +1581,10 @@ class BuiltinVariable(VariableTracker):
|
||||
return variables.ConstantVariable.create(
|
||||
sum((x.value for x in seq.items), start=start.value),
|
||||
)
|
||||
if seq.has_force_unpack_var_sequence(tx):
|
||||
if seq.has_unpack_var_sequence(tx):
|
||||
if start is self._SENTINEL:
|
||||
start = variables.ConstantVariable.create(0)
|
||||
items = seq.force_unpack_var_sequence(tx)
|
||||
items = seq.unpack_var_sequence(tx)
|
||||
return BuiltinVariable(functools.reduce).call_function(
|
||||
tx,
|
||||
[
|
||||
@ -1618,8 +1598,8 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_reduce(
|
||||
self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
|
||||
):
|
||||
if iterable.has_force_unpack_var_sequence(tx):
|
||||
items = iterable.force_unpack_var_sequence(tx)
|
||||
if iterable.has_unpack_var_sequence(tx):
|
||||
items = iterable.unpack_var_sequence(tx)
|
||||
if initial is self._SENTINEL:
|
||||
value, items = items[0], items[1:]
|
||||
else:
|
||||
@ -1906,12 +1886,11 @@ class BuiltinVariable(VariableTracker):
|
||||
return variables.TupleVariable(items)
|
||||
|
||||
def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
|
||||
if obj.has_force_unpack_var_sequence(tx) and not isinstance(
|
||||
obj, variables.TensorVariable
|
||||
if (
|
||||
obj.has_unpack_var_sequence(tx)
|
||||
and not isinstance(obj, variables.TensorVariable)
|
||||
and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
|
||||
):
|
||||
unpacked = obj.force_unpack_var_sequence(tx)
|
||||
if not all(x.is_python_constant() for x in unpacked):
|
||||
return
|
||||
function = kwargs.pop("key", None)
|
||||
reverse = kwargs.pop(
|
||||
"reverse", ConstantVariable.create(False)
|
||||
@ -1919,7 +1898,7 @@ class BuiltinVariable(VariableTracker):
|
||||
assert len(kwargs) == 0
|
||||
if function:
|
||||
items = sorted(
|
||||
unpacked,
|
||||
obj.unpack_var_sequence(tx),
|
||||
key=lambda x: function.call_function(
|
||||
tx, [x], {}
|
||||
).as_python_constant(),
|
||||
@ -1927,7 +1906,7 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
else:
|
||||
items = sorted(
|
||||
unpacked,
|
||||
obj.unpack_var_sequence(tx),
|
||||
key=lambda x: x.as_python_constant(),
|
||||
reverse=reverse,
|
||||
)
|
||||
|
@ -147,14 +147,6 @@ class ConstantVariable(VariableTracker):
|
||||
return variables.BuiltinVariable(str.format).call_function(
|
||||
tx, [self, *args], kwargs
|
||||
)
|
||||
elif name == "join" and istype(self.value, str):
|
||||
assert len(args) == 1 and len(kwargs) == 0
|
||||
arg_unpacked = args[0].force_unpack_var_sequence(tx)
|
||||
try:
|
||||
arg_const = [x.as_python_constant() for x in arg_unpacked]
|
||||
return ConstantVariable.create(self.value.join(arg_const))
|
||||
except NotImplementedError:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
if any(isinstance(x, SymNodeVariable) for x in args):
|
||||
# Promote to SymNodeVariable for operations involving dynamic shapes.
|
||||
|
@ -314,7 +314,6 @@ class ConstDictVariable(VariableTracker):
|
||||
ListVariable,
|
||||
TupleVariable,
|
||||
ListIteratorVariable,
|
||||
variables.IteratorVariable,
|
||||
UserDefinedObjectVariable,
|
||||
),
|
||||
)
|
||||
|
@ -2,17 +2,14 @@
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
import sys
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import (
|
||||
handle_observed_exception,
|
||||
ObservedUserStopIteration,
|
||||
raise_observed_exception,
|
||||
unimplemented,
|
||||
UserError,
|
||||
)
|
||||
from .base import MutableLocal, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
@ -60,7 +57,6 @@ class ItertoolsVariable(VariableTracker):
|
||||
and not kwargs
|
||||
and all(arg.has_unpack_var_sequence(tx) for arg in args)
|
||||
):
|
||||
# TODO support itertools.chain with arbitrary iterables
|
||||
seqs = [arg.unpack_var_sequence(tx) for arg in args]
|
||||
items = list(itertools.chain.from_iterable(seqs))
|
||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||||
@ -212,25 +208,6 @@ class IteratorVariable(VariableTracker):
|
||||
def next_variable(self, tx):
|
||||
unimplemented("abstract method, must implement")
|
||||
|
||||
# NOTE: only call when unpacking this iterator safely done eagerly!
|
||||
# Normally, iterators are accessed lazily.
|
||||
# Example of safe eager unpacking: list(map(f, seq))
|
||||
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
|
||||
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
|
||||
result = []
|
||||
while True:
|
||||
try:
|
||||
result.append(self.next_variable(tx))
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
break
|
||||
return result
|
||||
|
||||
# don't call force_unpack_var_sequence since it can mutate
|
||||
# IteratorVariable state!
|
||||
def has_force_unpack_var_sequence(self, tx) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class RepeatIteratorVariable(IteratorVariable):
|
||||
def __init__(self, item: VariableTracker, **kwargs) -> None:
|
||||
@ -241,18 +218,6 @@ class RepeatIteratorVariable(IteratorVariable):
|
||||
def next_variable(self, tx):
|
||||
return self.item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(itertools),
|
||||
codegen.create_load_attr("repeat"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.item)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
|
||||
class CountIteratorVariable(IteratorVariable):
|
||||
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
|
||||
@ -266,23 +231,10 @@ class CountIteratorVariable(IteratorVariable):
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.mutable_local
|
||||
old_item = self.item
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
return old_item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(itertools),
|
||||
codegen.create_load_attr("count"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.item)
|
||||
codegen(self.step)
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
next_item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
self.item = next_item
|
||||
return self.item
|
||||
|
||||
|
||||
class CycleIteratorVariable(IteratorVariable):
|
||||
@ -328,180 +280,3 @@ class CycleIteratorVariable(IteratorVariable):
|
||||
return self.item
|
||||
else:
|
||||
raise_observed_exception(StopIteration, tx, self)
|
||||
|
||||
|
||||
class ZipVariable(IteratorVariable):
|
||||
"""
|
||||
Represents zip(*iterables)
|
||||
"""
|
||||
|
||||
_nonvar_fields = {
|
||||
"index",
|
||||
"strict",
|
||||
*IteratorVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterables: List[Union[List[VariableTracker], VariableTracker]],
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(iterables, list)
|
||||
# can be list[Variable] or VariableTracker (with next_variable implemented)
|
||||
self.iterables = iterables
|
||||
self.index = 0
|
||||
self.strict = strict
|
||||
|
||||
def python_type(self):
|
||||
return zip
|
||||
|
||||
def has_unpack_var_sequence(self, tx) -> bool:
|
||||
return all(
|
||||
isinstance(it, list) or it.has_unpack_var_sequence(tx)
|
||||
for it in self.iterables
|
||||
)
|
||||
|
||||
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
|
||||
assert self.has_unpack_var_sequence(tx)
|
||||
iterables = []
|
||||
for it in self.iterables:
|
||||
if isinstance(it, list):
|
||||
iterables.append(it[self.index :])
|
||||
else:
|
||||
iterables.append(it.unpack_var_sequence(tx))
|
||||
kwargs = {"strict": self.strict} if self.strict else {}
|
||||
zipped = zip(*iterables, **kwargs)
|
||||
return [variables.TupleVariable(list(var)) for var in zipped]
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.mutable_local
|
||||
old_index = self.index
|
||||
args = []
|
||||
|
||||
def get_item(it):
|
||||
if isinstance(it, list):
|
||||
if old_index >= len(it):
|
||||
raise_observed_exception(StopIteration, tx, self)
|
||||
return it[old_index]
|
||||
else:
|
||||
return it.next_variable(tx)
|
||||
|
||||
try:
|
||||
for idx, it in enumerate(self.iterables):
|
||||
args.append(get_item(it))
|
||||
except ObservedUserStopIteration:
|
||||
if self.strict:
|
||||
if idx == 0:
|
||||
# all other iterables should be exhausted
|
||||
for it in self.iterables:
|
||||
try:
|
||||
get_item(it)
|
||||
except ObservedUserStopIteration:
|
||||
handle_observed_exception(tx)
|
||||
continue
|
||||
# no ObservedUserStopIteration - fall through to UserError
|
||||
break
|
||||
else:
|
||||
# all iterables exhausted, raise original error
|
||||
raise
|
||||
handle_observed_exception(tx)
|
||||
raise UserError(
|
||||
ValueError,
|
||||
"zip() has one argument of len differing from others",
|
||||
) from None
|
||||
raise
|
||||
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.index += 1
|
||||
return variables.TupleVariable(args)
|
||||
|
||||
def reconstruct_items(self, codegen):
|
||||
for it in self.iterables:
|
||||
if isinstance(it, list):
|
||||
remaining_items = it[self.index :]
|
||||
codegen.foreach(remaining_items)
|
||||
codegen.append_output(
|
||||
create_instruction("BUILD_TUPLE", arg=len(remaining_items))
|
||||
)
|
||||
else:
|
||||
codegen(it)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
|
||||
)
|
||||
self.reconstruct_items(codegen)
|
||||
codegen.append_output(
|
||||
create_instruction("BUILD_TUPLE", arg=len(self.iterables))
|
||||
)
|
||||
if sys.version_info >= (3, 10):
|
||||
codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_const("strict"),
|
||||
codegen.create_load_const(self.strict),
|
||||
create_instruction("BUILD_MAP", arg=1),
|
||||
create_instruction("CALL_FUNCTION_EX", arg=1),
|
||||
]
|
||||
)
|
||||
else:
|
||||
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
|
||||
|
||||
|
||||
class MapVariable(ZipVariable):
|
||||
"""
|
||||
Represents map(fn, *iterables)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn: VariableTracker,
|
||||
iterables: List[Union[List[VariableTracker], VariableTracker]],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(iterables, **kwargs)
|
||||
self.fn = fn
|
||||
|
||||
def python_type(self):
|
||||
return map
|
||||
|
||||
def has_unpack_var_sequence(self, tx) -> bool:
|
||||
return False
|
||||
|
||||
def next_variable(self, tx):
|
||||
args = super().next_variable(tx)
|
||||
return self.fn.call_function(tx, args.items, {})
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
|
||||
)
|
||||
codegen(self.fn)
|
||||
self.reconstruct_items(codegen)
|
||||
codegen.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
|
||||
create_instruction("CALL_FUNCTION_EX", arg=0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class EnumerateVariable(ZipVariable):
|
||||
def __init__(
|
||||
self,
|
||||
iterable: Union[List[VariableTracker], VariableTracker],
|
||||
start: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
[CountIteratorVariable(start, mutable_local=MutableLocal()), iterable],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "enumerate"))
|
||||
codegen(self.iterables[1])
|
||||
assert isinstance(self.iterables[0], CountIteratorVariable)
|
||||
codegen(self.iterables[0].item)
|
||||
codegen.extend_output(codegen.create_call_function_kw(2, ("start",), False))
|
||||
|
@ -29,7 +29,6 @@ from ..utils import (
|
||||
from .base import MutableLocal, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .functions import UserFunctionVariable, UserMethodVariable
|
||||
from .iter import IteratorVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -340,11 +339,11 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
name == "extend"
|
||||
and self.mutable_local
|
||||
and args
|
||||
and args[0].has_force_unpack_var_sequence(tx)
|
||||
and args[0].has_unpack_var_sequence(tx)
|
||||
):
|
||||
assert not kwargs
|
||||
(arg,) = args
|
||||
seq = arg.force_unpack_var_sequence(tx)
|
||||
seq = arg.unpack_var_sequence(tx)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.extend(seq)
|
||||
return ConstantVariable.create(None)
|
||||
@ -428,13 +427,11 @@ class ListVariable(CommonListMethodsVariable):
|
||||
key, value = args
|
||||
tx.output.side_effects.mutation(self)
|
||||
if isinstance(key, SliceVariable):
|
||||
if not value.has_force_unpack_var_sequence(tx):
|
||||
if not value.has_unpack_var_sequence(tx):
|
||||
unimplemented(
|
||||
f"Missing dynamo support for expanding {value} into a list for slice assignment."
|
||||
)
|
||||
self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
|
||||
tx
|
||||
)
|
||||
self.items[key.as_python_constant()] = value.unpack_var_sequence(tx)
|
||||
else:
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
@ -462,12 +459,7 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
)
|
||||
)
|
||||
codegen.foreach(self.items)
|
||||
codegen.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=len(self.items)),
|
||||
*create_call_function(1, False),
|
||||
]
|
||||
)
|
||||
codegen.extend_output(create_call_function(len(self.items), False))
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
@ -490,15 +482,11 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
elif (
|
||||
name == "extendleft"
|
||||
and self.mutable_local
|
||||
and args[0].has_force_unpack_var_sequence(tx)
|
||||
):
|
||||
elif name == "extendleft" and self.mutable_local:
|
||||
assert not kwargs
|
||||
|
||||
(arg,) = args
|
||||
prefix = arg.force_unpack_var_sequence(tx)
|
||||
prefix = arg.unpack_var_sequence(tx)
|
||||
prefix.reverse()
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items = prefix + list(self.items)
|
||||
@ -796,10 +784,10 @@ class SliceVariable(BaseListVariable):
|
||||
return self.items[fields.index(name)]
|
||||
|
||||
|
||||
class ListIteratorVariable(IteratorVariable):
|
||||
class ListIteratorVariable(VariableTracker):
|
||||
_nonvar_fields = {
|
||||
"index",
|
||||
*IteratorVariable._nonvar_fields,
|
||||
*VariableTracker._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(self, items, index: int = 0, **kwargs) -> None:
|
||||
@ -850,9 +838,6 @@ class ListIteratorVariable(IteratorVariable):
|
||||
def unpack_var_sequence(self, tx):
|
||||
return list(self.items[self.index :])
|
||||
|
||||
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
|
||||
return self.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
remaining_items = self.items[self.index :]
|
||||
codegen.foreach(remaining_items)
|
||||
|
@ -352,8 +352,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
elif self.value is collections.deque and not kwargs:
|
||||
if len(args) == 0:
|
||||
items = []
|
||||
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
||||
items = args[0].unpack_var_sequence(tx)
|
||||
else:
|
||||
unimplemented("deque() with more than 1 arg not supported")
|
||||
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
|
||||
@ -653,7 +653,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
assert not (args or kwargs)
|
||||
items = []
|
||||
keys = self.call_method(tx, "keys", [], {})
|
||||
for key in keys.force_unpack_var_sequence(tx):
|
||||
for key in keys.unpack_var_sequence(tx):
|
||||
items.append(
|
||||
TupleVariable(
|
||||
[key, self.odict_getitem(tx, key)],
|
||||
|
Reference in New Issue
Block a user