diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index a0e3369ed73c..e936e66f23f8 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -239,6 +239,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase): v = v + x return v + def test_itertools_reconstruct(self): + def fn(a): + it1 = itertools.repeat(1) + it2 = itertools.count(2) + for _ in range(3): + a += next(it1) + a += next(it2) + return it1, it2, a + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + i1, i2, a = fn(torch.ones(3, 3)) + it1, it2, b = opt_fn(torch.ones(3, 3)) + self.assertEqual(next(i1), next(it1)) + self.assertEqual(next(i2), next(it2)) + self.assertEqual(a, b) + @make_test def test_obj_eq(a, b): v = a + b @@ -507,8 +523,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): empty = collections.deque() d.extend(empty) - # dynamo same() util doesn't support deque so just return a list - return list(d) + return d @make_test def test_slice1(a): @@ -3115,6 +3130,199 @@ class GraphModule(torch.nn.Module): fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) ) + def test_map_return(self): + def fn(a, b): + return map(lambda x: x + 1, [a, b]) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.randn(3, 3), torch.randn(3, 3)) + self.assertIsInstance(m, map) + + @make_test + def test_map_max(a, b): + return max(map(lambda x: x.sum(), [a, b])) + + # max(map(...)) graph breaks + @unittest.expectedFailure + @make_test + def test_map_max_const(a): + return max(map(lambda x: x, [1, 2, 3])), a + 1 + + @make_test + def test_map_list(a, b): + return list(map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_tuple(a, b): + return tuple(map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_iter(a, b): + it = iter(map(lambda x: x + 1, [a, b])) + return next(it) + + @make_test + def test_map_zip_dict(a): + d = dict( + zip( + map(lambda x: x + 1, [0, 1, 2]), + [map(lambda x: x - 1, [y]) for y in [3, 4, 5]], + ) + ) + return list(d[3])[0], a + 1 # noqa: RUF015 + + @make_test + def test_map_dict_fromkeys(a): + return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1 + + @make_test + def test_map_set(a): + return set(map(lambda x: x + 1, [0, 1])), a + 1 + + # test_map_sum defined earlier + + @make_test + def test_map_reduce(a, b): + return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_sorted(a): + return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1 + + @make_test + def test_map_list_extend(a, b, c): + l = [a] + l.extend(map(lambda x: x + 1, [b, c])) + return l + + @make_test + def test_map_list_slice_assign(a, b, c, d, e): + l = [a, b, c] + l[1:2] = map(lambda x: x + 1, [d, e]) + return l + + @make_test + def test_map_deque_extendleft(a, b, c): + d = collections.deque([a]) + d.extendleft(map(lambda x: x + 1, [b, c])) + return d + + @make_test + def test_map_str_join(a): + return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1 + + def test_map_with_graph_break(self): + def f(a): + a += 1 + + def g(x): + nonlocal a + a += 1 + return x + 1 + + m = map(g, [1, 2, 3, 4, 5]) + a += next(m) # won't graph break + torch._dynamo.graph_break() + a += next(m) # will graph break + return a + + cnts = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnts) + self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) + self.assertEqual(cnts.frame_count, 3) + + def test_map_reconstruct(self): + def fn(a): + return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.ones(3, 3))[0] + self.assertIsInstance(m, map) + self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) + + def test_zip_reconstruct(self): + def fn(a): + return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.ones(3, 3))[0] + self.assertIsInstance(m, zip) + self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) + + @make_test + def test_map_partial_unpack(a, b): + y = 1 + + def f(x): + nonlocal y + y += 1 + return x + + l = list(zip([a, b], map(f, [1, 2, 3, 4]))) + return a + y + + @make_test + def test_map_call_function_ex(a, b): + def f(x, y): + return x + y + + return f(*map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_unpack_twice(a, b): + m = map(lambda x: x + 1, [a, b]) + l1 = list(m) + l2 = list(m) + return l1, l2 + + @make_test + def test_enumerate(a, b): + return list(enumerate([a, b], start=1)), a + 1 + + @make_test + def test_map_enumerate(a, b): + return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1 + + @make_test + def test_map_infinite(a, b): + return list(map(lambda x, y: x + y, [a, b], itertools.count(3))) + + @make_test + def test_map_unpack_vars(a, b): + x, y = map(lambda x: x + 1, [a, b]) + return x + y + + def test_enumerate_custom(self): + class MyClass: + def __iter__(self): + self.a = 1 + return self + + def __next__(self): + if self.a > 3: + raise StopIteration + self.a += 1 + return self.a + + def fn(x): + for i, it in enumerate(MyClass()): + x += i + it + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3))) + + def test_enumerate_reconstruct(self): + def fn(a, b): + return enumerate([a, b], start=1) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + inps = (torch.randn(3, 3), torch.randn(3, 3)) + it1 = fn(*inps) + it2 = opt_fn(*inps) + self.assertIsInstance(it2, enumerate) + self.assertEqual(list(it1), list(it2)) + def udf_mul(x, y): return x * y @@ -3670,10 +3878,16 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): nopython_fn(x, ys[:1], zs) + with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): + nopython_fn(x, ys, zs[:1]) + # Should cause fallback if allow graph break with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys[:1], zs) + with self.assertRaisesRegex(ValueError, "zip()"): + opt_fn(x, ys, zs[:1]) + def test_fn_with_attr(self): def fn(x): if fn.pred: diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 19e9f7d96172..a3c4226ed6a9 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5476,15 +5476,17 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): return x, y def g(x, y): - return tuple(map(f, x, y)) + return map(f, x, y) opt_g = torch.compile(g, fullgraph=True, backend="eager") inps = gen_inps(3, 3) - self.assertEqual(g(*inps), opt_g(*inps)) + self.assertEqual(type(g(*inps)), type(opt_g(*inps))) + self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) inps = gen_inps(3, 5) - self.assertEqual(g(*inps), opt_g(*inps)) + self.assertEqual(type(g(*inps)), type(opt_g(*inps))) + self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) def test_staticmethod_allow_in_graph(self): class MyClass: diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e605e57b6aa4..ab92f82aa0f6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1663,8 +1663,8 @@ class InstructionTranslatorBase( if not isinstance( argsvars, BaseListVariable - ) and argsvars.has_unpack_var_sequence(self): - argsvars = TupleVariable(argsvars.unpack_var_sequence(self)) + ) and argsvars.has_force_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) # Unpack for cases like fn(**obj) where obj is a map if isinstance(kwargsvars, UserDefinedObjectVariable): @@ -1833,7 +1833,7 @@ class InstructionTranslatorBase( items = [] for seq in seqs: try: - items.extend(seq.unpack_var_sequence(self)) + items.extend(seq.force_unpack_var_sequence(self)) except NotImplementedError: unimplemented(f"BUILD_LIST_UNPACK {seq}") self.push(cls(items, mutable_local=MutableLocal())) @@ -1871,7 +1871,7 @@ class InstructionTranslatorBase( assert isinstance(keys, TupleVariable) assert keys.is_python_constant() - keys = keys.unpack_var_sequence(self) + keys = keys.force_unpack_var_sequence(self) assert len(keys) == len(values) self.push( @@ -1961,8 +1961,8 @@ class InstructionTranslatorBase( # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] - elif seq.has_unpack_var_sequence(self): - val = seq.unpack_var_sequence(self) + elif seq.has_force_unpack_var_sequence(self): + val = seq.force_unpack_var_sequence(self) else: unimplemented(f"UNPACK_SEQUENCE {seq}") if len(val) != inst.argval: @@ -1975,8 +1975,8 @@ class InstructionTranslatorBase( prefix = inst.argval & 0xFF # low byte suffix = inst.argval >> 8 # high byte seq = self.pop() - if seq.has_unpack_var_sequence(self): - vals = list(seq.unpack_var_sequence(self)) + if seq.has_force_unpack_var_sequence(self): + vals = list(seq.force_unpack_var_sequence(self)) assert len(vals) >= prefix + suffix vals_prefix = vals[:prefix] vals_list = vals[prefix : len(vals) - suffix] @@ -2400,7 +2400,7 @@ class InstructionTranslatorBase( self.UNARY_POSITIVE(inst) elif inst.argval == 6: # INTRINSIC_LIST_TO_TUPLE - self.push(TupleVariable(self.pop().unpack_var_sequence(self))) + self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) else: unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 627ada1ff88a..0ec548d788f8 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1608,8 +1608,12 @@ def same( """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref - if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): - assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" + if isinstance( + ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size) + ): + assert isinstance( + res, (list, tuple, collections.deque) + ), f"type mismatch {type(ref)} {type(res)}" if len(ref) != len(res): log_error("Length mismatch") return False diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 021233cd6e23..d522d773e6e8 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -46,7 +46,9 @@ from .iter import ( CycleIteratorVariable, IteratorVariable, ItertoolsVariable, + MapVariable, RepeatIteratorVariable, + ZipVariable, ) from .lazy import LazyVariableTracker from .lists import ( diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 25b4ffad99e7..723c5a90c66a 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -289,6 +289,15 @@ class VariableTracker(metaclass=VariableTrackerMeta): def unpack_var_sequence(self, tx) -> List["VariableTracker"]: raise NotImplementedError + def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]: + # like unpack_var_sequence, but should only be used when it is + # safe to eagerly (vs. lazily) unpack this variable. + # e.g. map(f, x) is normally evaluated lazily but sometimes + # we want to force eager unpacking, e.g. when converting to a list. + # NOTE: this method is allowed to mutate the VariableTracker, so + # it should only be called once. + return self.unpack_var_sequence(tx) + def has_unpack_var_sequence(self, tx) -> bool: try: self.unpack_var_sequence(tx) @@ -296,6 +305,10 @@ class VariableTracker(metaclass=VariableTrackerMeta): except NotImplementedError: return False + # NB: don't call force_unpack_var_sequence, especially if it mutates! + def has_force_unpack_var_sequence(self, tx) -> bool: + return self.has_unpack_var_sequence(tx) + def inspect_parameter_names(self) -> List[str]: unimplemented(f"inspect_parameter_names: {self}") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 6ffd820c3f13..b6ff05e429d1 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1058,9 +1058,8 @@ class BuiltinVariable(VariableTracker): return tx.inline_user_function_return(user_func_variable, [arg], {}) def _call_min_max(self, tx: "InstructionTranslator", *args): - if len(args) == 1 and args[0].has_unpack_var_sequence(tx): - # expand iterable - items = args[0].unpack_var_sequence(tx) + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) return self._call_min_max_seq(tx, items) elif len(args) == 2: return self._call_min_max_binary(tx, args[0], args[1]) @@ -1075,6 +1074,10 @@ class BuiltinVariable(VariableTracker): return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) def _call_min_max_binary(self, tx: "InstructionTranslator", a, b): + if a is None or b is None: + # a or b could be none if we reduce and _call_min_max_binary failed + # to return something + return if self.tensor_args(a, b): if not isinstance(a, variables.TensorVariable): a, b = b, a @@ -1223,17 +1226,15 @@ class BuiltinVariable(VariableTracker): ), ) + # NOTE must handle IteratorVariable separately! def _call_iter_tuple_list( self, tx: "InstructionTranslator", obj=None, *args, **kwargs ): + assert not isinstance(obj, variables.IteratorVariable) + if self._dynamic_args(*args, **kwargs): return self._dyn_proxy(tx, *args, **kwargs) - if isinstance(obj, variables.IteratorVariable): - # For non-list iterators, we will guard on vars that - # determine the control flow - return obj - cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -1261,9 +1262,22 @@ class BuiltinVariable(VariableTracker): mutable_local=MutableLocal(), ) + def _call_tuple_list(self, tx, obj=None, *args, **kwargs): + if isinstance(obj, variables.IteratorVariable): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), + mutable_local=MutableLocal(), + ) + else: + return self._call_iter_tuple_list(tx, obj, *args, **kwargs) + def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): - # Handle the case where we are iterating over a tuple, list or iterator - ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) + if isinstance(obj, variables.IteratorVariable): + ret = obj + else: + # Handle the case where we are iterating over a tuple, list or iterator + ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) if ret is None: # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. @@ -1272,8 +1286,8 @@ class BuiltinVariable(VariableTracker): return obj.call_method(tx, "__iter__", args, kwargs) return ret - call_tuple = _call_iter_tuple_list - call_list = _call_iter_tuple_list + call_tuple = _call_tuple_list + call_list = _call_tuple_list def call_callable(self, tx: "InstructionTranslator", arg): from .functions import BaseUserFunctionVariable @@ -1331,10 +1345,12 @@ class BuiltinVariable(VariableTracker): ListVariable, TupleVariable, ListIteratorVariable, + variables.IteratorVariable, ), ): items = dict( - x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) + x.force_unpack_var_sequence(tx) + for x in arg.force_unpack_var_sequence(tx) ) return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) elif isinstance(arg, variables.MutableMappingVariable): @@ -1391,13 +1407,12 @@ class BuiltinVariable(VariableTracker): return DictVariableType( dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() ) - elif arg.has_unpack_var_sequence(tx) and all( - is_hashable(v) for v in arg.unpack_var_sequence(tx) - ): - keys = arg.unpack_var_sequence(tx) - return DictVariableType( - dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() - ) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return DictVariableType( + dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() + ) unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") def call_set(self, tx: "InstructionTranslator", *args, **kwargs): @@ -1409,8 +1424,8 @@ class BuiltinVariable(VariableTracker): arg = args[0] if isinstance(arg, variables.SetVariable): return arg.clone(mutable_local=MutableLocal()) - elif arg.has_unpack_var_sequence(tx): - items = arg.unpack_var_sequence(tx) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) return SetVariable(items, mutable_local=MutableLocal()) elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( arg.value, KeysView @@ -1443,16 +1458,12 @@ class BuiltinVariable(VariableTracker): def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): if kwargs: assert len(kwargs) == 1 and "strict" in kwargs - if all(x.has_unpack_var_sequence(tx) for x in args): - unpacked = [arg.unpack_var_sequence(tx) for arg in args] - if kwargs.pop("strict", False) and len(unpacked) > 0: - if not all(len(u) == len(unpacked[0]) for u in unpacked): - raise UserError( - ValueError, - "zip() has one argument of len differing from others", - ) - items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] - return variables.TupleVariable(items) + strict = kwargs.pop("strict", False) + args = [ + arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg + for arg in args + ] + return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal()) def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) @@ -1553,10 +1564,11 @@ class BuiltinVariable(VariableTracker): return obj.call_hasattr(tx, name) def call_map(self, tx: "InstructionTranslator", fn, *seqs): - if all(seq.has_unpack_var_sequence(tx) for seq in seqs): - unpacked = [seq.unpack_var_sequence(tx) for seq in seqs] - items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)] - return variables.TupleVariable(items) + seqs = [ + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + for seq in seqs + ] + return variables.MapVariable(fn, seqs, mutable_local=MutableLocal()) def call_filter(self, tx: "InstructionTranslator", fn, seq): if seq.has_unpack_var_sequence(tx): @@ -1589,10 +1601,10 @@ class BuiltinVariable(VariableTracker): return variables.ConstantVariable.create( sum((x.value for x in seq.items), start=start.value), ) - if seq.has_unpack_var_sequence(tx): + if seq.has_force_unpack_var_sequence(tx): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) - items = seq.unpack_var_sequence(tx) + items = seq.force_unpack_var_sequence(tx) return BuiltinVariable(functools.reduce).call_function( tx, [ @@ -1606,8 +1618,8 @@ class BuiltinVariable(VariableTracker): def call_reduce( self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL ): - if iterable.has_unpack_var_sequence(tx): - items = iterable.unpack_var_sequence(tx) + if iterable.has_force_unpack_var_sequence(tx): + items = iterable.force_unpack_var_sequence(tx) if initial is self._SENTINEL: value, items = items[0], items[1:] else: @@ -1903,11 +1915,12 @@ class BuiltinVariable(VariableTracker): return variables.TupleVariable(items) def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs): - if ( - obj.has_unpack_var_sequence(tx) - and not isinstance(obj, variables.TensorVariable) - and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) + if obj.has_force_unpack_var_sequence(tx) and not isinstance( + obj, variables.TensorVariable ): + unpacked = obj.force_unpack_var_sequence(tx) + if not all(x.is_python_constant() for x in unpacked): + return function = kwargs.pop("key", None) reverse = kwargs.pop( "reverse", ConstantVariable.create(False) @@ -1915,7 +1928,7 @@ class BuiltinVariable(VariableTracker): assert len(kwargs) == 0 if function: items = sorted( - obj.unpack_var_sequence(tx), + unpacked, key=lambda x: function.call_function( tx, [x], {} ).as_python_constant(), @@ -1923,7 +1936,7 @@ class BuiltinVariable(VariableTracker): ) else: items = sorted( - obj.unpack_var_sequence(tx), + unpacked, key=lambda x: x.as_python_constant(), reverse=reverse, ) diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 0028ecd81dec..65de0aab6ef3 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -145,6 +145,14 @@ class ConstantVariable(VariableTracker): return variables.BuiltinVariable(str.format).call_function( tx, [self, *args], kwargs ) + elif name == "join" and istype(self.value, str): + assert len(args) == 1 and len(kwargs) == 0 + arg_unpacked = args[0].force_unpack_var_sequence(tx) + try: + arg_const = [x.as_python_constant() for x in arg_unpacked] + return ConstantVariable.create(self.value.join(arg_const)) + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) if any(isinstance(x, SymNodeVariable) for x in args): # Promote to SymNodeVariable for operations involving dynamic shapes. diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 6c69364ca57e..10b55c1b7e96 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -314,6 +314,7 @@ class ConstDictVariable(VariableTracker): ListVariable, TupleVariable, ListIteratorVariable, + variables.IteratorVariable, UserDefinedObjectVariable, ), ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 192f25a0c6b6..1f8dac8811f5 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -2,14 +2,17 @@ import itertools import operator -from typing import Dict, List, Optional, TYPE_CHECKING +import sys +from typing import Dict, List, Optional, TYPE_CHECKING, Union from .. import polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction from ..exc import ( handle_observed_exception, ObservedUserStopIteration, raise_observed_exception, unimplemented, + UserError, ) from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -197,6 +200,25 @@ class IteratorVariable(VariableTracker): def next_variable(self, tx): unimplemented("abstract method, must implement") + # NOTE: only call when unpacking this iterator safely done eagerly! + # Normally, iterators are accessed lazily. + # Example of safe eager unpacking: list(map(f, seq)) + # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + result = [] + while True: + try: + result.append(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + # don't call force_unpack_var_sequence since it can mutate + # IteratorVariable state! + def has_force_unpack_var_sequence(self, tx) -> bool: + return True + class RepeatIteratorVariable(IteratorVariable): def __init__(self, item: VariableTracker, **kwargs) -> None: @@ -207,6 +229,18 @@ class RepeatIteratorVariable(IteratorVariable): def next_variable(self, tx): return self.item + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("repeat"), + ] + ) + ) + codegen(self.item) + codegen.extend_output(create_call_function(1, False)) + class CountIteratorVariable(IteratorVariable): def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: @@ -220,10 +254,23 @@ class CountIteratorVariable(IteratorVariable): def next_variable(self, tx): assert self.mutable_local + old_item = self.item tx.output.side_effects.mutation(self) - next_item = self.item.call_method(tx, "__add__", [self.step], {}) - self.item = next_item - return self.item + self.item = self.item.call_method(tx, "__add__", [self.step], {}) + return old_item + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("count"), + ] + ) + ) + codegen(self.item) + codegen(self.step) + codegen.extend_output(create_call_function(2, False)) class CycleIteratorVariable(IteratorVariable): @@ -269,3 +316,160 @@ class CycleIteratorVariable(IteratorVariable): return self.item else: raise_observed_exception(StopIteration, tx, self) + + +class ZipVariable(IteratorVariable): + """ + Represents zip(*iterables) + """ + + _nonvar_fields = { + "index", + "strict", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + iterables: List[Union[List[VariableTracker], VariableTracker]], + strict: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + assert isinstance(iterables, list) + # can be list[Variable] or VariableTracker (with next_variable implemented) + self.iterables = iterables + self.index = 0 + self.strict = strict + + def python_type(self): + return zip + + def has_unpack_var_sequence(self, tx) -> bool: + return all( + isinstance(it, list) or it.has_unpack_var_sequence(tx) + for it in self.iterables + ) + + def unpack_var_sequence(self, tx) -> List["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + iterables = [] + for it in self.iterables: + if isinstance(it, list): + iterables.append(it[self.index :]) + else: + iterables.append(it.unpack_var_sequence(tx)) + kwargs = {"strict": self.strict} if self.strict else {} + zipped = zip(*iterables, **kwargs) + return [variables.TupleVariable(list(var)) for var in zipped] + + def next_variable(self, tx): + assert self.mutable_local + old_index = self.index + args = [] + + def get_item(it): + if isinstance(it, list): + if old_index >= len(it): + raise_observed_exception(StopIteration, tx, self) + return it[old_index] + else: + return it.next_variable(tx) + + try: + for idx, it in enumerate(self.iterables): + args.append(get_item(it)) + except ObservedUserStopIteration: + if self.strict: + if idx == 0: + # all other iterables should be exhausted + for it in self.iterables: + try: + get_item(it) + except ObservedUserStopIteration: + handle_observed_exception(tx) + continue + # no ObservedUserStopIteration - fall through to UserError + break + else: + # all iterables exhausted, raise original error + raise + handle_observed_exception(tx) + raise UserError( + ValueError, + "zip() has one argument of len differing from others", + ) from None + raise + + tx.output.side_effects.mutation(self) + self.index += 1 + return variables.TupleVariable(args) + + def reconstruct_items(self, codegen): + for it in self.iterables: + if isinstance(it, list): + remaining_items = it[self.index :] + codegen.foreach(remaining_items) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(remaining_items)) + ) + else: + codegen(it) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True + ) + self.reconstruct_items(codegen) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(self.iterables)) + ) + if sys.version_info >= (3, 10): + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + create_instruction("CALL_FUNCTION_EX", arg=1), + ] + ) + else: + codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0)) + + +class MapVariable(ZipVariable): + """ + Represents map(fn, *iterables) + """ + + def __init__( + self, + fn: VariableTracker, + iterables: List[Union[List[VariableTracker], VariableTracker]], + **kwargs, + ) -> None: + super().__init__(iterables, **kwargs) + self.fn = fn + + def python_type(self): + return map + + def has_unpack_var_sequence(self, tx) -> bool: + return False + + def next_variable(self, tx): + args = super().next_variable(tx) + return self.fn.call_function(tx, args.items, {}) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True + ) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1), + create_instruction("CALL_FUNCTION_EX", arg=0), + ] + ) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 1fed456a8f7d..30916e0b6996 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -29,6 +29,7 @@ from ..utils import ( from .base import MutableLocal, VariableTracker from .constant import ConstantVariable from .functions import UserFunctionVariable, UserMethodVariable +from .iter import IteratorVariable if TYPE_CHECKING: @@ -334,11 +335,11 @@ class CommonListMethodsVariable(BaseListVariable): name == "extend" and self.mutable_local and args - and args[0].has_unpack_var_sequence(tx) + and args[0].has_force_unpack_var_sequence(tx) ): assert not kwargs (arg,) = args - seq = arg.unpack_var_sequence(tx) + seq = arg.force_unpack_var_sequence(tx) tx.output.side_effects.mutation(self) self.items.extend(seq) return ConstantVariable.create(None) @@ -422,11 +423,13 @@ class ListVariable(CommonListMethodsVariable): key, value = args tx.output.side_effects.mutation(self) if isinstance(key, SliceVariable): - if not value.has_unpack_var_sequence(tx): + if not value.has_force_unpack_var_sequence(tx): unimplemented( f"Missing dynamo support for expanding {value} into a list for slice assignment." ) - self.items[key.as_python_constant()] = value.unpack_var_sequence(tx) + self.items[key.as_python_constant()] = value.force_unpack_var_sequence( + tx + ) else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) @@ -464,7 +467,12 @@ class DequeVariable(CommonListMethodsVariable): ) ) codegen.foreach(self.items) - codegen.extend_output(create_call_function(len(self.items), False)) + codegen.extend_output( + [ + create_instruction("BUILD_LIST", arg=len(self.items)), + *create_call_function(1, False), + ] + ) def call_method( self, @@ -487,11 +495,15 @@ class DequeVariable(CommonListMethodsVariable): tx.output.side_effects.mutation(self) self.items[key.as_python_constant()] = value return ConstantVariable.create(None) - elif name == "extendleft" and self.mutable_local: + elif ( + name == "extendleft" + and self.mutable_local + and args[0].has_force_unpack_var_sequence(tx) + ): assert not kwargs (arg,) = args - prefix = arg.unpack_var_sequence(tx) + prefix = arg.force_unpack_var_sequence(tx) prefix.reverse() tx.output.side_effects.mutation(self) self.items = prefix + list(self.items) @@ -802,10 +814,10 @@ class SliceVariable(BaseListVariable): return self.items[fields.index(name)] -class ListIteratorVariable(VariableTracker): +class ListIteratorVariable(IteratorVariable): _nonvar_fields = { "index", - *VariableTracker._nonvar_fields, + *IteratorVariable._nonvar_fields, } def __init__(self, items, index: int = 0, **kwargs) -> None: @@ -856,6 +868,9 @@ class ListIteratorVariable(VariableTracker): def unpack_var_sequence(self, tx): return list(self.items[self.index :]) + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + return self.unpack_var_sequence(tx) + def reconstruct(self, codegen): remaining_items = self.items[self.index :] codegen.foreach(remaining_items) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0d53d6d21f55..34e1d0d10c9f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -379,8 +379,8 @@ class UserDefinedClassVariable(UserDefinedVariable): elif self.value is collections.deque and not kwargs: if len(args) == 0: items = [] - elif len(args) == 1 and args[0].has_unpack_var_sequence(tx): - items = args[0].unpack_var_sequence(tx) + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) else: unimplemented("deque() with more than 1 arg not supported") return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) @@ -749,7 +749,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): assert not (args or kwargs) items = [] keys = self.call_method(tx, "keys", [], {}) - for key in keys.unpack_var_sequence(tx): + for key in keys.force_unpack_var_sequence(tx): items.append( TupleVariable( [key, self.odict_getitem(tx, key)],