# mypy: ignore-errors

import collections
import functools
import inspect
import operator
import types
from typing import Dict, List, Optional, TYPE_CHECKING

import torch
import torch.fx
from torch._guards import Source

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import raise_observed_exception, unimplemented
from ..source import AttrSource
from ..utils import (
    get_fake_value,
    guard_if_dyn,
    is_namedtuple,
    istype,
    iter_contains,
    Lit,
    namedtuple_fields,
    odict_values,
    set_example_value,
)
from .base import ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable


if TYPE_CHECKING:
    from torch._dynamo.codegen import PyCodegen
    from torch._dynamo.symbolic_convert import InstructionTranslator


class BaseListVariable(VariableTracker):
    """Common base for trackers of list-like Python values.

    The tracked elements are stored in ``self.items`` as a plain Python list
    of VariableTrackers; subclasses choose the concrete Python type through
    ``python_type()``.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return the tracker class (or factory) to use for the instance ``obj``.

        Namedtuples get a ``NamedTupleVariable`` factory bound to their
        concrete class; everything else is looked up by type via ``cls_for``.
        """
        if is_namedtuple(obj):
            return functools.partial(NamedTupleVariable, tuple_cls=type(obj))
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Map a Python type to its tracker class (KeyError if unsupported)."""
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return the proxy of every element, in order."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Build a new tracker of the same class holding ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Render the elements' debug reprs, comma-joined, between prefix/suffix."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # python_type() returns a Python type, so the meaningful guard is
        # against torch.Size: it cannot be built from proxies directly, which
        # is why SizeVariable overrides this method.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the tracked items with a constant (or SymNode) key.

        Slicing returns a fresh, source-less tracker of the same kind;
        integer indexing returns the stored element tracker itself.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutation_type=ValueMutationNew() if self.mutation_type else None,
            )
        else:
            assert isinstance(index, (int, torch.SymInt))
            return self.items[index]

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__``, ``__contains__`` and ``index``; defer the rest."""
        if name == "__getitem__":
            from .tensor import TensorVariable

            assert not kwargs and len(args) == 1
            if isinstance(args[0], TensorVariable):
                # Only a single-element tensor with a known constant value can
                # be turned into a Python index at trace time.
                value = get_fake_value(args[0].as_proxy().node, tx)
                if value.constant is not None and value.constant.numel() == 1:
                    value = variables.ConstantVariable.create(value.constant.item())
                else:
                    unimplemented("__getitem__ with non-constant tensor")
            else:
                value = args[0]
            return self.getitem_const(tx, value)
        elif name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
        elif name == "index":
            return tx.inline_user_function_return(
                VariableTracker.build(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )

        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx: "InstructionTranslator", op, left, right):
        """Compare two list-likes with ``op`` by inlining ``polyfills.list_cmp``."""
        return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
            tx, [variables.BuiltinVariable(op), left, right], {}
        )


class RangeVariable(BaseListVariable):
    """Tracks a ``range`` object as three item trackers: start, stop, step."""

    def __init__(self, items, **kwargs) -> None:
        items_to_map = items
        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        # Accept the same argument shapes as range(stop) / range(start, stop)
        # / range(start, stop, step).
        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        return self.items[0].as_python_constant()

    def stop(self):
        return self.items[1].as_python_constant()

    def step(self):
        return self.items[2].as_python_constant()

    def range_length(self):
        """Return ``len(range(start, stop, step))`` from the constant bounds."""
        lo = self.start()
        hi = self.stop()
        step = self.step()

        assert step != 0
        if step > 0 and lo < hi:
            return 1 + (hi - 1 - lo) // step
        elif step < 0 and lo > hi:
            return 1 + (lo - 1 - hi) // (0 - step)
        else:
            return 0

    def _get_slice_indices(self, length, slice):
        """Normalize ``slice`` against a sequence of ``length`` elements.

        Returns ``[start, stop, step]`` with missing bounds filled in and
        out-of-range bounds clamped, following CPython's range slicing rules.

        Raises:
            ValueError: if the slice step is zero, as CPython does.
        """
        if slice.step is None:
            step = 1
            step_is_negative = False
        else:
            step = slice.step
            if step == 0:
                # CPython rejects zero-step slices up front; without this the
                # resulting range would get a zero step and fail later.
                raise ValueError("slice step cannot be zero")
            step_is_negative = step < 0

        # Find lower and upper bounds for start and stop.
        if step_is_negative:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        # Compute start
        if slice.start is None:
            start = upper if step_is_negative else lower
        else:
            start = slice.start

        if start < 0:
            start += length
            if start < lower:
                start = lower
        else:
            if start > upper:
                start = upper

        # Compute stop.
        if slice.stop is None:
            stop = lower if step_is_negative else upper
        else:
            stop = slice.stop

            if stop < 0:
                stop += length
                if stop < lower:
                    stop = lower
            else:
                if stop > upper:
                    stop = upper

        return [start, stop, step]

    def apply_index(self, index):
        """Return the element at ``index`` (negative allowed) as a constant.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if index < 0 or index >= length:
            raise IndexError(f"index {index} is out of range")

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return a new RangeVariable equal to ``range(...)[slice]``."""
        (slice_start, slice_stop, slice_step) = self._get_slice_indices(
            self.range_length(), slice
        )

        def compute_item(index):
            return self.start() + (index * self.step())

        sub_step = self.step() * slice_step
        sub_start = compute_item(slice_start)
        sub_stop = compute_item(slice_stop)

        result = RangeVariable(
            [
                variables.ConstantVariable.create(x)
                for x in [sub_start, sub_stop, sub_step]
            ],
            mutation_type=ValueMutationNew() if self.mutation_type else None,
        )
        return result

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()

        if isinstance(index, slice):
            return self.apply_slice(index)
        else:
            return self.apply_index(index)

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode equivalent to ``range(start, stop, step)``."""
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)]
class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Model the mutating list methods on ``self.items``.

        Each mutating branch records the side effect via
        ``tx.output.side_effects.mutation(self)`` before changing the items.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.is_mutable():
            assert not kwargs
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)
        elif (
            name == "extend"
            and self.is_mutable()
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            seq = arg.force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            self.items.extend(seq)
            return ConstantVariable.create(None)
        elif name == "insert" and self.is_mutable():
            assert not kwargs
            idx, value = args
            if isinstance(idx, SymNodeVariable):
                const_idx = idx.evaluate_expr()
            else:
                const_idx = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(const_idx, value)
            return ConstantVariable.create(None)
        elif name == "pop" and self.is_mutable():
            assert not kwargs
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])
        elif name == "clear" and self.is_mutable():
            assert not kwargs and not args
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                self.items[key.as_python_constant()] = list(value.items)
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif name == "copy":
            # List copy() doesn't have args and kwargs
            assert not kwargs
            assert not args
            items = list(self.items)
            return self.modified(items, mutation_type=ValueMutationNew())
        elif name == "reverse" and self.is_mutable():
            assert not kwargs
            assert not args
            # Record the mutation before touching self.items, like every
            # other mutating branch above.
            tx.output.side_effects.mutation(self)
            self.items.reverse()
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)


class ListVariable(CommonListMethodsVariable):
    """Tracks a Python ``list`` (and list-like containers mapped onto it)."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__setitem__`` (including slice assignment); defer the rest."""
        if (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            if isinstance(key, SliceVariable):
                # Validate the assigned value before recording the mutation.
                if not value.has_force_unpack_var_sequence(tx):
                    unimplemented(
                        f"Missing dynamo support for expanding {value} into a list for slice assignment."
                    )
                tx.output.side_effects.mutation(self)
                self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
                    tx
                )
            else:
                tx.output.side_effects.mutation(self)
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "__class__":
            source = AttrSource(self.source, name) if self.source else None
            class_type = self.python_type()
            if class_type is list:
                return variables.BuiltinVariable(class_type, source=source)
            else:
                return variables.UserDefinedClassVariable(class_type, source=source)
        return super().var_getattr(tx, name)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        if self.python_type() is not list:
            return super().call_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr([], name))


class DequeVariable(CommonListMethodsVariable):
    """Tracks a ``collections.deque``, including its constant ``maxlen``."""

    def __init__(self, items, maxlen=None, **kwargs) -> None:
        if maxlen is None:
            maxlen = ConstantVariable.create(None)
        assert (
            maxlen.is_python_constant()
        ), f"maxlen must be a constant, got: {maxlen.debug_repr()}"
        self.maxlen = maxlen
        items = list(items)
        maxlen_value = self.maxlen.as_python_constant()
        if maxlen_value is not None:
            # Keep only the rightmost `maxlen` items, as deque(iterable,
            # maxlen) does. An explicit start index is needed because
            # items[-0:] would keep everything when maxlen == 0.
            items = items[max(len(items) - maxlen_value, 0) :]
        super().__init__(items, **kwargs)

    def modified(self, items, **kwargs):
        """Build a new deque tracker, keeping this deque's ``maxlen``.

        ``deque.copy()`` preserves the bound, so it must be carried over
        unless the caller passes ``maxlen`` explicitly.
        """
        kwargs.setdefault("maxlen", self.maxlen)
        return type(self)(items, **kwargs)

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        # Only bounded deques show maxlen, mirroring repr(collections.deque).
        if self.maxlen.as_python_constant() is not None:
            return self.debug_repr_helper(
                "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
            )
        return self.debug_repr_helper("deque([", "])")

    def as_python_constant(self):
        return self.python_type()(
            [x.as_python_constant() for x in self.items],
            maxlen=self.maxlen.as_python_constant(),
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode equivalent to ``collections.deque([...], maxlen=...)``."""
        assert "deque" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
        codegen(self.maxlen)
        codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "maxlen":
            return self.maxlen
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle deque-specific methods and enforce ``maxlen`` afterwards.

        Methods that add on the left (``appendleft``, ``extendleft``) drop
        overflow from the right end; everything else drops it from the left.
        """
        if (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert len(args) == 2
            assert not kwargs
            key, value = args
            assert key.is_python_constant()
            assert isinstance(key.as_python_constant(), int)
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)

        maxlen = self.maxlen.as_python_constant()
        added_on_left = False

        if (
            name == "extendleft"
            and self.is_mutable()
            and len(args) > 0
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert len(args) == 1
            assert not kwargs
            prefix = args[0].force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            self.items[:] = [*reversed(prefix), *self.items]
            added_on_left = True
            result = ConstantVariable.create(None)
        elif name == "popleft" and self.is_mutable():
            assert not args
            assert not kwargs
            tx.output.side_effects.mutation(self)
            result, *self.items[:] = self.items
        elif name == "appendleft" and len(args) > 0 and self.is_mutable():
            assert len(args) == 1
            assert not kwargs
            tx.output.side_effects.mutation(self)
            self.items[:] = [args[0], *self.items]
            added_on_left = True
            result = ConstantVariable.create(None)
        else:
            result = super().call_method(tx, name, args, kwargs)

        if maxlen is not None and len(self.items) > maxlen:
            # Explicit, non-negative indices keep maxlen == 0 correct
            # (a negative-slice form like items[-0:] would keep everything).
            if added_on_left:
                self.items[:] = self.items[:maxlen]
            else:
                self.items[:] = self.items[len(self.items) - maxlen :]
        return result


class TupleVariable(BaseListVariable):
    """Tracks a Python ``tuple``."""

    def python_type(self):
        return tuple

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_TUPLE", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "__class__":
            source = AttrSource(self.source, name) if self.source else None
            class_type = self.python_type()
            if class_type is tuple:
                return variables.BuiltinVariable(class_type, source=source)
            else:
                return variables.UserDefinedClassVariable(class_type, source=source)
        return super().var_getattr(tx, name)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        if self.python_type() is not tuple:
            return super().call_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr((), name))


class SizeVariable(TupleVariable):
    """torch.Size(...)"""

    _nonvar_fields = {
        "proxy",
        *TupleVariable._nonvar_fields,
    }

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        if self.proxy is not None:
            return self.proxy

        # torch.Size needs special handling.  Normally, we pun a list-like
        # container to directly contain Proxy/Node objects from FX, and FX
        # knows to look inside containers (via map_aggregate).  But torch.Size
        # is weird; although it subclasses from tuple, it doesn't allow
        # members which aren't int-like (rejecting Proxy and Node).  This
        # means we can't use the normal representation trick
        # torch.Size([proxy0, proxy1]).  I looked into seeing if I could
        # relax torch.Size in PyTorch proper, but if torch.Size constructor
        # sees a type that it doesn't recognize, it will try to call
        # __index__() on it, so there is no BC way to actually change this
        # behavior (though it occurs to me that I could have just added a
        # YOLO no checking alternate constructor.)
        #
        # To work around this problem, I represent a torch.Size proxy as
        # a straight up proxy, that would have been constructed by taking
        # the constituent proxies as arguments.  This trick can be generally
        # used for any construct that we need a proxy for but we can't
        # directly represent as an aggregate; I don't see very many examples
        # of this in torchdynamo though!

        # Look for a proxy.  If there are none, do the legacy behavior
        tracer = None
        proxies = self._as_proxy()
        for proxy in proxies:
            if isinstance(proxy, torch.fx.Proxy):
                tracer = proxy.tracer
                break

        if tracer is None:
            return torch.Size(proxies)

        proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        set_example_value(
            proxy.node,
            torch.Size(
                [
                    p.node.meta["example_value"] if not isinstance(p, int) else p
                    for p in proxies
                ]
            ),
        )
        return proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode equivalent to ``torch.Size((item0, item1, ...))``."""
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        build_torch_size = [
            create_instruction("BUILD_TUPLE", arg=len(self.items)),
        ] + create_call_function(1, False)
        codegen.extend_output(build_torch_size)

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def numel(self, tx):
        """Return the product of all sizes as a constant or symbolic tracker.

        Constant sizes are folded eagerly; symbolic sizes are multiplied in
        only when needed (not at all if a constant factor is 0).
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        const_result = 1
        sym_sizes = []

        for v in self.items:
            if isinstance(v, ConstantVariable):
                const_result *= v.value
            else:
                assert isinstance(v, SymNodeVariable), type(v)
                # Delay proxy calls until we know it will be necessary
                sym_sizes.append(v)

        result = ConstantVariable.create(const_result)
        if sym_sizes and const_result == 1:
            # Skip multiplying by 1
            result, *sym_sizes = sym_sizes

        if not sym_sizes or const_result == 0:
            return result

        mul = BuiltinVariable(operator.mul)
        for v in sym_sizes:
            result = mul.call_function(tx, [result, v], {})
        return result

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            out = self.get_item_dyn(tx, args[0])
            return out
        elif name == "numel":
            assert not args and not kwargs
            return self.numel(tx)

        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index with a constant/SymNode key; slicing yields a new SizeVariable."""
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()
        if isinstance(index, slice):
            return SizeVariable(self.items[index])
        else:
            assert isinstance(index, (int, torch.SymInt))
            return self.items[index]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(torch.Size, name))


class NamedTupleVariable(TupleVariable):
    """Tracks a namedtuple or a structseq (e.g. ``torch.return_types.*``)."""

    _nonvar_fields = {
        "tuple_cls",
        "dynamic_attributes",
        *TupleVariable._nonvar_fields,
    }

    def __init__(self, items, tuple_cls, **kwargs) -> None:
        super().__init__(items, **kwargs)
        self.tuple_cls = tuple_cls
        # Extra attributes set via __setattr__ on namedtuple subclasses.
        self.dynamic_attributes = {}

    def is_namedtuple(self):
        return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
            getattr(self.tuple_cls, "_make", None)
        )

    def is_structseq(self):
        return not self.is_namedtuple()

    def fields(self):
        return namedtuple_fields(self.tuple_cls)

    def debug_repr(self):
        if self.is_structseq():
            # StructSequenceType(iterable)
            return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
        # NamedTupleType(*iterable)
        return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))

    def python_type(self):
        return self.tuple_cls

    def as_python_constant(self):
        if self.is_structseq():
            # StructSequenceType(iterable)
            return self.python_type()([x.as_python_constant() for x in self.items])
        # NamedTupleType(*iterable)
        return self.python_type()(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # python_type() is a Python class; torch.Size must never reach here
        # because it cannot hold proxies (SizeVariable handles it).
        assert self.python_type() is not torch.Size
        if self.is_structseq():
            # StructSequenceType(iterable)
            return self.python_type()(self._as_proxy())
        # NamedTupleType(*iterable)
        return self.python_type()(*self._as_proxy())

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Constructors:
        #   StructSequenceType(iterable)
        #   NamedTupleType(*iterable)
        #   NamedTupleType._make(iterable)
        create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
        codegen.add_push_null(
            lambda: codegen.append_output(codegen._create_load_const(create_fn))
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.items)),
            ]
            + create_call_function(1, False)
        )

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        """Support ``__setattr__`` of non-field attributes on namedtuple subclasses."""
        if name == "__setattr__":
            assert len(args) == 2
            assert len(kwargs) == 0
            attr, value = args
            attr = attr.as_python_constant()
            if (
                # structseq is immutable
                self.is_structseq()
                # namedtuple directly created by `collections.namedtuple` is immutable
                or self.tuple_cls.__bases__ == (tuple,)
                # fields are immutable
                or attr in self.fields()
            ):
                raise_observed_exception(AttributeError, tx)
            # Subclass of namedtuple type can have dynamic attributes
            tx.output.side_effects.mutation(self)
            self.dynamic_attributes[attr] = value
            return ConstantVariable.create(None)
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve dynamic attributes, then fields, then methods on ``tuple_cls``."""

        def check_and_create_method():
            method = inspect.getattr_static(self.tuple_cls, name, None)
            if isinstance(method, classmethod):
                # We need the unbounded cls method to avoid the inline __self__
                return UserMethodVariable(
                    method.__func__,
                    variables.UserDefinedClassVariable(self.tuple_cls),
                )
            elif isinstance(method, staticmethod):
                return UserFunctionVariable(method.__func__)
            elif inspect.isfunction(method):
                return UserMethodVariable(method, self)
            else:
                return None

        if name in self.dynamic_attributes:
            return self.dynamic_attributes[name]

        fields = self.fields()
        if name not in fields:
            method = check_and_create_method()
            if method is None:
                return super().var_getattr(tx, name)
            return method
        return self.items[fields.index(name)]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(
            name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
        )


class SliceVariable(BaseListVariable):
    """Tracks a ``slice`` object as three item trackers: start, stop, step."""

    def __init__(self, items, **kwargs) -> None:
        items_to_map = items
        start, stop, step = [variables.ConstantVariable.create(None)] * 3

        # Accept the same argument shapes as slice(stop) / slice(start, stop)
        # / slice(start, stop, step).
        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError

        if isinstance(start, variables.TensorVariable) or isinstance(
            stop, variables.TensorVariable
        ):
            unimplemented("Dynamic slicing on data-dependent value is not supported")

        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        return slice(*self._as_proxy())

    def python_type(self):
        return slice

    def as_python_constant(self):
        return slice(*[guard_if_dyn(x) for x in self.items])

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))

    def var_getattr(self, tx: "InstructionTranslator", name):
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"slice.{name}")
        return self.items[fields.index(name)]


class ListIteratorVariable(IteratorVariable):
    """Tracks an iterator over a fixed list of items plus a cursor ``index``."""

    _nonvar_fields = {
        "index",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # Removing this check as it slows things down too much
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492

        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"

    def next_variable(self, tx):
        """Advance the cursor; raise an observed StopIteration when exhausted."""
        assert self.is_mutable()
        old_index = self.index
        if old_index >= len(self.items):
            raise_observed_exception(StopIteration, tx)

        tx.output.side_effects.mutation(self)
        self.index += 1
        return self.items[old_index]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        if name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            # Only the not-yet-consumed items are searched.
            return iter_contains(self.items[self.index :], args[0], tx)

        return super().call_method(tx, name, args, kwargs)

    def python_type(self):
        return type(iter([]))

    def as_python_constant(self):
        if self.index > 0:
            raise NotImplementedError
        return iter([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        return list(self.items[self.index :])

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Rebuild the remaining items as a tuple and call GET_ITER on it."""
        remaining_items = self.items[self.index :]
        codegen.foreach(remaining_items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(remaining_items)),
                create_instruction("GET_ITER"),
            ]
        )


class TupleIteratorVariable(ListIteratorVariable):
    pass


class RestrictedListSubclassVariable(ListVariable):
    """
    This is a special case of UserDefinedObjectVariable where:
        1) The user subclasses list
        2) None of the list methods are overridden, merely some new methods are added

    In these cases, we can prevent graph breaks by not using the general
    UserDefinedObjectVariable machinery and instead treating it like
    a ListVariable.
    """

    _nonvar_fields = {"user_cls", "user_cls_source", *ListVariable._nonvar_fields}
    _allowed_names = {
        "__call__",
        "__module__",
        "__dict__",
        "__doc__",
        "__name__",
        "__qualname__",
    }
    _disallowed_names = {
        "__getattribute__",
        "__getattr__",
        "__setattr__",
    }

    @classmethod
    def _is_non_conflicting_subclass(
        cls,
        user_cls: type,
        python_cls: type,
    ):
        """Ensures user_cls inherits from python_cls (e.g. list) and does not override any methods on python_cls"""
        if (
            not istype(user_cls, type)
            or user_cls.__bases__ != (python_cls,)
            or user_cls.__mro__ != (user_cls, python_cls, object)
        ):
            return False  # not subclass
        return not any(
            hasattr(python_cls, name) or name in cls._disallowed_names
            for name in set(user_cls.__dict__.keys()) - cls._allowed_names
        )

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        return cls._is_non_conflicting_subclass(user_cls, list)

    def __init__(
        self, items, *, user_cls: type, user_cls_source: Source, **kwargs
    ) -> None:
        super().__init__(items=items, **kwargs)
        self.user_cls = user_cls
        self.user_cls_source = user_cls_source
        assert istype(user_cls, type)
        assert isinstance(user_cls_source, Source)

    def debug_repr(self):
        # The constructor is safe as no methods, including __init__, are
        # allowed to be overridden
        # NB: This is guaranteed to print like a list, as __repr__ cannot be
        # overridden, this is... well, it's OK I guess (consistent with
        # eager), but it could be misleading.  You will have to query type
        # instead for details.
        return repr(self.user_cls([Lit(x.debug_repr()) for x in self.items]))

    def python_type(self):
        return self.user_cls

    def as_proxy(self):
        return [x.as_proxy() for x in self.items]

    def as_python_constant(self):
        raise NotImplementedError

    def is_python_constant(self):
        return False

    @property
    def value(self):
        raise AttributeError("value")

    def modified(self, items, **kwargs):
        return type(self)(
            items,
            user_cls=self.user_cls,
            user_cls_source=self.user_cls_source,
            **kwargs,
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode equivalent to ``user_cls([...])``."""
        codegen.add_push_null(lambda: codegen(self.user_cls_source))
        super().reconstruct(codegen)
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Inline plain-function methods defined on ``user_cls``; defer the rest."""
        if name in self.user_cls.__dict__:
            method = self.user_cls.__dict__[name]
            if isinstance(method, types.FunctionType):
                # inline the method
                source = AttrSource(self.user_cls_source, name)
                return UserMethodVariable(method, self, source=source).call_function(
                    tx, args, kwargs
                )
            unimplemented(
                f"RestrictedListSubclassVariable method {self.user_cls.__name__}.{name}"
            )
        return super().call_method(tx, name, args, kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.call_method(tx, "__call__", args, kwargs)