mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow subclassing namedtuple types. Allow assigning attributes to instances of these subtypes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140534 Approved by: https://github.com/jansel
1111 lines
36 KiB
Python
1111 lines
36 KiB
Python
# mypy: ignore-errors
|
|
|
|
import collections
|
|
import functools
|
|
import inspect
|
|
import operator
|
|
import types
|
|
from typing import Dict, List, Optional, TYPE_CHECKING
|
|
|
|
import torch
|
|
import torch.fx
|
|
from torch._guards import Source
|
|
|
|
from .. import polyfills, variables
|
|
from ..bytecode_transformation import create_call_function, create_instruction
|
|
from ..exc import raise_observed_exception, unimplemented
|
|
from ..source import AttrSource
|
|
from ..utils import (
|
|
get_fake_value,
|
|
guard_if_dyn,
|
|
is_namedtuple,
|
|
istype,
|
|
iter_contains,
|
|
Lit,
|
|
namedtuple_fields,
|
|
odict_values,
|
|
set_example_value,
|
|
)
|
|
from .base import ValueMutationNew, VariableTracker
|
|
from .constant import ConstantVariable
|
|
from .functions import UserFunctionVariable, UserMethodVariable
|
|
from .iter import IteratorVariable
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._dynamo.codegen import PyCodegen
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator
|
|
|
|
|
|
class BaseListVariable(VariableTracker):
    """Base tracker for ordered, list-like containers.

    The elements are kept in ``self.items`` as a Python list of
    ``VariableTracker`` objects.  Subclasses provide ``python_type()`` and
    whatever container-specific behavior they need.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return a constructor for the tracker class that models ``obj``.

        Namedtuple instances get ``NamedTupleVariable`` with ``tuple_cls``
        pre-bound to ``type(obj)``; everything else is resolved by exact type
        through ``cls_for``.
        """
        if is_namedtuple(obj):
            return functools.partial(NamedTupleVariable, tuple_cls=type(obj))
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Look up the tracker class registered for the key ``obj``.

        ``obj`` is a container type (or the ``iter`` builtin).  A key that is
        not in the table raises ``KeyError``.
        """
        tracker_by_key = {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }
        return tracker_by_key[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        """Store ``items``; it must be a ``list`` of ``VariableTracker``."""
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return ``as_proxy()`` of every element, in order."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Create a new tracker of the same class around ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        """The container as a plain Python constant."""
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Join the elements' debug reprs with ", " between the delimiters."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        """Build ``python_type()`` from the elements' Python constants."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Build ``python_type()`` from the elements' proxies."""
        # python_type() returns a Python type, never a tracker class, so the
        # previous comparison against SizeVariable could not fail.  The type
        # being guarded against is torch.Size; SizeVariable supplies its own
        # as_proxy() for it.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the container with a constant or SymNode ``arg``.

        A slice yields a clone holding the selected elements; an integer (or
        ``torch.SymInt``) yields the element tracker itself.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutation_type=ValueMutationNew() if self.mutation_type else None,
            )

        assert isinstance(index, (int, torch.SymInt))
        return self.items[index]

    def unpack_var_sequence(self, tx):
        """Return a shallow copy of the element list."""
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__``, ``__contains__`` and ``index``.

        Any other method name is delegated to the parent class.
        """
        if name == "__getitem__":
            from .tensor import TensorVariable

            assert not kwargs and len(args) == 1
            key = args[0]
            if isinstance(key, TensorVariable):
                # Only a tensor that carries a single-element constant can be
                # turned into a plain index.
                fake = get_fake_value(key.as_proxy().node, tx)
                if fake.constant is not None and fake.constant.numel() == 1:
                    key = variables.ConstantVariable.create(fake.constant.item())
                else:
                    unimplemented("__getitem__ with non-constant tensor")
            return self.getitem_const(tx, key)

        if name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)

        if name == "index":
            # Inline the polyfill with this container as its first argument.
            return tx.inline_user_function_return(
                VariableTracker.build(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )

        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx: "InstructionTranslator", op, left, right):
        """Compare ``left`` and ``right`` with ``op`` via ``polyfills.list_cmp``."""
        return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
            tx, [variables.BuiltinVariable(op), left, right], {}
        )
|
|
|
|
|
|
class RangeVariable(BaseListVariable):
    """Tracker for a ``range`` object.

    ``self.items`` always holds exactly three trackers, in the order
    start, stop, step.
    """

    def __init__(self, items, **kwargs) -> None:
        """Accept one, two or three trackers, like ``range()`` arguments.

        Missing values default to start=0 and step=1.  Any other number of
        arguments raises ``AssertionError``.
        """
        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        arg_count = len(items)
        if arg_count == 1:
            (stop,) = items
        elif arg_count == 2:
            start, stop = items
        elif arg_count == 3:
            start, stop, step = items
        else:
            raise AssertionError

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        """Python constant of the start item."""
        return self.items[0].as_python_constant()

    def stop(self):
        """Python constant of the stop item."""
        return self.items[1].as_python_constant()

    def step(self):
        """Python constant of the step item."""
        return self.items[2].as_python_constant()

    def range_length(self):
        """Number of values the range produces (0 when it is empty)."""
        first = self.start()
        last = self.stop()
        stride = self.step()

        assert stride != 0
        if stride > 0 and first < last:
            return 1 + (last - 1 - first) // stride
        if stride < 0 and first > last:
            return 1 + (first - 1 - last) // (-stride)
        return 0

    def _get_slice_indices(self, length, slice):
        """Normalize ``slice`` against a sequence of ``length`` elements.

        Returns ``[start, stop, step]`` with ``None`` fields filled in,
        negative positions shifted by ``length`` and both ends clamped.
        """
        if slice.step is None:
            step = 1
            descending = False
        else:
            step = slice.step
            descending = slice.step < 0

        # Find lower and upper bounds for start and stop.
        if descending:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        # Compute start.
        if slice.start is None:
            start = upper if descending else lower
        elif slice.start < 0:
            start = max(slice.start + length, lower)
        else:
            start = min(slice.start, upper)

        # Compute stop.
        if slice.stop is None:
            stop = lower if descending else upper
        elif slice.stop < 0:
            stop = max(slice.stop + length, lower)
        else:
            stop = min(slice.stop, upper)

        return [start, stop, step]

    def apply_index(self, index):
        """Return a constant tracker for element ``index`` of the range.

        Negative indices count from the end; an index outside the range
        raises ``IndexError``.
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if not (0 <= index < length):
            raise IndexError(f"index {index} is out of range")

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return a new ``RangeVariable`` describing ``range[slice]``."""
        first_pos, last_pos, pos_step = self._get_slice_indices(
            self.range_length(), slice
        )

        # Translate positions inside this range into actual range values.
        new_step = self.step() * pos_step
        new_start = self.start() + (first_pos * self.step())
        new_stop = self.start() + (last_pos * self.step())

        bounds = [
            variables.ConstantVariable.create(x)
            for x in [new_start, new_stop, new_step]
        ]
        return RangeVariable(
            bounds,
            mutation_type=ValueMutationNew() if self.mutation_type else None,
        )

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the range with the constant ``arg``."""
        # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()
        if isinstance(index, slice):
            return self.apply_slice(index)
        return self.apply_index(index)

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        """Materialize every value of the range as a constant tracker."""
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that calls ``range`` on the three items."""
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``start``, ``stop`` and ``step``; anything else is unimplemented."""
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)]
|
|
|
|
|
|
class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle the mutating list methods plus ``copy``.

        Covers ``append``, ``extend``, ``insert``, ``pop``, ``clear``,
        ``__setitem__``, ``copy`` and ``reverse``.  Each mutating branch
        applies only when ``self.is_mutable()`` and records the mutation
        through ``tx.output.side_effects``.  Everything else falls through
        to the parent class.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.is_mutable():
            assert not kwargs
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)

        if (
            name == "extend"
            and self.is_mutable()
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            new_items = arg.force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            self.items.extend(new_items)
            return ConstantVariable.create(None)

        if name == "insert" and self.is_mutable():
            assert not kwargs
            idx, value = args
            # A symbolic position is resolved with evaluate_expr().
            if isinstance(idx, SymNodeVariable):
                position = idx.evaluate_expr()
            else:
                position = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(position, value)
            return ConstantVariable.create(None)

        if name == "pop" and self.is_mutable():
            assert not kwargs
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])

        if name == "clear" and self.is_mutable():
            assert not kwargs and not args
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)

        if (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                # Slice assignment splices in the elements of ``value``.
                self.items[key.as_python_constant()] = list(value.items)
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)

        if name == "copy":
            # List copy() doesn't have args and kwargs
            assert not kwargs
            assert not args
            return self.modified(list(self.items), mutation_type=ValueMutationNew())

        if name == "reverse" and self.is_mutable():
            assert not kwargs
            assert not args
            self.items.reverse()
            tx.output.side_effects.mutation(self)
            return ConstantVariable.create(None)

        return super().call_method(tx, name, args, kwargs)
|
|
|
|
|
|
class ListVariable(CommonListMethodsVariable):
    """Tracker for a Python ``list``."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the items followed by a ``BUILD_LIST`` of matching size."""
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__setitem__`` with a constant key; delegate the rest.

        For a slice key the assigned value is expanded with
        ``force_unpack_var_sequence``; a value that cannot be expanded is
        reported as unimplemented.
        """
        handles_setitem = (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        )
        if not handles_setitem:
            return super().call_method(tx, name, args, kwargs)

        assert not kwargs
        key, value = args
        tx.output.side_effects.mutation(self)
        if isinstance(key, SliceVariable):
            if not value.has_force_unpack_var_sequence(tx):
                unimplemented(
                    f"Missing dynamo support for expanding {value} into a list for slice assignment."
                )
            self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
                tx
            )
        else:
            self.items[key.as_python_constant()] = value
        return ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` here; other attributes go to the parent."""
        if name != "__class__":
            return super().var_getattr(tx, name)

        source = AttrSource(self.source, name) if self.source else None
        class_type = self.python_type()
        # The builtin ``list`` is a BuiltinVariable; any other type is
        # treated as a user-defined class.
        if class_type is list:
            return variables.BuiltinVariable(class_type, source=source)
        return variables.UserDefinedClassVariable(class_type, source=source)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """For plain lists, answer ``hasattr`` against an empty list."""
        if self.python_type() is not list:
            return super().call_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr([], name))
|
|
|
|
|
|
class DequeVariable(CommonListMethodsVariable):
    """Tracker for a ``collections.deque``.

    ``self.maxlen`` is a constant tracker holding either ``None`` (unbounded)
    or the maximum length.  ``self.items`` never holds more than ``maxlen``
    elements after construction or after a method handled by ``call_method``.
    """

    def __init__(self, items, maxlen=None, **kwargs) -> None:
        """Store ``items``, keeping only the last ``maxlen`` of them.

        ``maxlen`` must be a constant tracker, or ``None`` for unbounded.
        """
        if maxlen is None:
            maxlen = ConstantVariable.create(None)
        assert (
            maxlen.is_python_constant()
        ), f"maxlen must be a constant, got: {maxlen.debug_repr()}"
        self.maxlen = maxlen
        items = list(items)
        limit = self.maxlen.as_python_constant()
        if limit is not None:
            # ``items[-0:]`` is the whole list, so a zero maxlen has to be
            # handled explicitly: a deque with maxlen=0 holds nothing.
            items = items[-limit:] if limit != 0 else []
        super().__init__(items, **kwargs)

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        """Render like ``repr(deque)``: ``maxlen`` is shown only when bounded."""
        if self.maxlen.as_python_constant() is None:
            return self.debug_repr_helper("deque([", "])")
        return self.debug_repr_helper(
            "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
        )

    def as_python_constant(self):
        return self.python_type()(
            [x.as_python_constant() for x in self.items],
            maxlen=self.maxlen.as_python_constant(),
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode calling ``deque(<list of items>, maxlen=<maxlen>)``."""
        assert "deque" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
        codegen(self.maxlen)
        codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``maxlen``; other attributes go to the parent."""
        if name == "maxlen":
            return self.maxlen
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle deque-specific methods, then enforce ``maxlen``.

        ``__setitem__`` (integer key only), ``extendleft``, ``popleft`` and
        ``appendleft`` are implemented here; everything else goes to the
        parent.  After any call other than ``__setitem__`` the item list is
        trimmed to ``maxlen``: elements are dropped from the right after a
        left-side insertion, and from the left otherwise.
        """
        if (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert len(args) == 2
            assert not kwargs
            key, value = args
            assert key.is_python_constant()
            assert isinstance(key.as_python_constant(), int)
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)

        maxlen = self.maxlen.as_python_constant()
        # Which end loses elements on overflow.  Insertions on the left push
        # elements out on the right; everything else drops from the left.
        drop_from_right = False

        if (
            name == "extendleft"
            and self.is_mutable()
            and len(args) > 0
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert len(args) == 1
            assert not kwargs
            prefix = args[0].force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            # extendleft inserts one element at a time, which reverses the
            # order of the incoming sequence.
            self.items[:] = [*reversed(prefix), *self.items]
            drop_from_right = True
            result = ConstantVariable.create(None)
        elif name == "popleft" and self.is_mutable():
            assert not args
            assert not kwargs
            tx.output.side_effects.mutation(self)
            result, *self.items[:] = self.items
        elif name == "appendleft" and len(args) > 0 and self.is_mutable():
            assert len(args) == 1
            assert not kwargs
            tx.output.side_effects.mutation(self)
            self.items[:] = [args[0], *self.items]
            drop_from_right = True
            result = ConstantVariable.create(None)
        else:
            result = super().call_method(tx, name, args, kwargs)

        if maxlen is not None and len(self.items) > maxlen:
            if drop_from_right:
                self.items[:] = self.items[:maxlen]
            else:
                # An explicit start index is used instead of ``-maxlen`` so
                # that maxlen == 0 empties the deque rather than keeping
                # everything.
                self.items[:] = self.items[len(self.items) - maxlen :]
        return result
|
|
|
|
|
|
class TupleVariable(BaseListVariable):
    """Tracker for a Python ``tuple``."""

    def python_type(self):
        return tuple

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the items followed by a ``BUILD_TUPLE`` of matching size."""
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_TUPLE", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Forward every method call unchanged to the parent class."""
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` here; other attributes go to the parent."""
        if name != "__class__":
            return super().var_getattr(tx, name)

        source = AttrSource(self.source, name) if self.source else None
        class_type = self.python_type()
        # The builtin ``tuple`` is a BuiltinVariable; any other type is
        # treated as a user-defined class.
        if class_type is tuple:
            return variables.BuiltinVariable(class_type, source=source)
        return variables.UserDefinedClassVariable(class_type, source=source)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """For plain tuples, answer ``hasattr`` against an empty tuple."""
        if self.python_type() is not tuple:
            return super().call_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr((), name))
|
|
|
|
|
|
class SizeVariable(TupleVariable):
    """torch.Size(...)"""

    _nonvar_fields = {
        "proxy",
        *TupleVariable._nonvar_fields,
    }

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        """Store ``items`` and an optional ready-made ``proxy`` for the size."""
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        """Return a proxy (or a concrete ``torch.Size``) for this size.

        A proxy given at construction time is returned as is.  Otherwise,
        when no element is an FX proxy, a concrete ``torch.Size`` is built;
        when at least one is, a ``call_function`` node for ``torch.Size`` is
        created and given an example value.
        """
        if self.proxy is not None:
            return self.proxy

        # torch.Size needs special handling.  List-like containers normally
        # hold Proxy/Node objects directly and FX looks inside them (via
        # map_aggregate).  torch.Size subclasses tuple but refuses members
        # that are not int-like (rejecting Proxy and Node), so the usual
        # torch.Size([proxy0, proxy1]) representation is unavailable.
        # Relaxing torch.Size itself is not backward compatible: when the
        # constructor sees a type it does not recognize, it tries to call
        # __index__() on it (an alternate constructor without checks would
        # have been another option).
        #
        # The workaround is to represent a torch.Size proxy as a plain proxy
        # constructed from the constituent proxies as arguments.  The same
        # trick works for any construct that needs a proxy but cannot be
        # represented directly as an aggregate.

        # Look for a proxy. If there are none, do the legacy behavior
        proxies = self._as_proxy()
        tracer = next(
            (p.tracer for p in proxies if isinstance(p, torch.fx.Proxy)),
            None,
        )
        if tracer is None:
            return torch.Size(proxies)

        size_proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        # Plain ints stand for themselves; proxies contribute the example
        # value recorded on their node.
        example_dims = [
            p.node.meta["example_value"] if not isinstance(p, int) else p
            for p in proxies
        ]
        set_example_value(size_proxy.node, torch.Size(example_dims))
        return size_proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode calling ``torch.Size`` on a tuple of the items."""
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        codegen.extend_output(
            [create_instruction("BUILD_TUPLE", arg=len(self.items))]
            + create_call_function(1, False)
        )

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def numel(self, tx):
        """Return a tracker for the product of all dimensions.

        Constant dimensions are folded eagerly.  Symbolic dimensions are
        multiplied in through ``operator.mul`` only when needed: not at all
        when the constant part is 0, and skipping a leading multiply by 1.
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        const_product = 1
        symbolic_dims = []
        for dim in self.items:
            if isinstance(dim, ConstantVariable):
                const_product *= dim.value
            else:
                assert isinstance(dim, SymNodeVariable), type(dim)
                # Delay proxy calls until we know it will be necessary
                symbolic_dims.append(dim)

        result = ConstantVariable.create(const_product)
        if symbolic_dims and const_product == 1:
            # Skip multiplying by 1
            result, *symbolic_dims = symbolic_dims

        if not symbolic_dims or const_product == 0:
            return result

        mul = BuiltinVariable(operator.mul)
        for dim in symbolic_dims:
            result = mul.call_function(tx, [result, dim], {})
        return result

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__`` and ``numel``; delegate everything else."""
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.get_item_dyn(tx, args[0])

        if name == "numel":
            assert not args and not kwargs
            return self.numel(tx)

        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the size with a constant or SymNode ``arg``.

        A slice yields a new ``SizeVariable``; an integer (or
        ``torch.SymInt``) yields the element tracker itself.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            return SizeVariable(self.items[index])

        assert isinstance(index, (int, torch.SymInt))
        return self.items[index]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Answer ``hasattr`` against the ``torch.Size`` type."""
        return variables.ConstantVariable.create(hasattr(torch.Size, name))
|
|
|
|
|
|
class NamedTupleVariable(TupleVariable):
    """Tracker for namedtuple instances and struct sequences.

    ``tuple_cls`` is the concrete Python type being modeled.  Attributes
    assigned on instances of namedtuple subclasses are kept in
    ``dynamic_attributes``.
    """

    _nonvar_fields = {
        "tuple_cls",
        "dynamic_attributes",
        *TupleVariable._nonvar_fields,
    }

    def __init__(self, items, tuple_cls, **kwargs) -> None:
        super().__init__(items, **kwargs)
        self.tuple_cls = tuple_cls
        self.dynamic_attributes = {}

    def is_namedtuple(self):
        """True when ``tuple_cls`` has a tuple ``_fields`` and a callable ``_make``."""
        return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
            getattr(self.tuple_cls, "_make", None)
        )

    def is_structseq(self):
        """True when ``tuple_cls`` does not look like a namedtuple."""
        return not self.is_namedtuple()

    def fields(self):
        """Field names of ``tuple_cls``, as reported by ``namedtuple_fields``."""
        return namedtuple_fields(self.tuple_cls)

    def debug_repr(self):
        if self.is_structseq():
            # StructSequenceType(iterable)
            return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
        # NamedTupleType(*iterable)
        return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))

    def python_type(self):
        return self.tuple_cls

    def as_python_constant(self):
        """Build a ``tuple_cls`` instance from the elements' constants."""
        constants = [x.as_python_constant() for x in self.items]
        if self.is_structseq():
            # StructSequenceType(iterable)
            return self.python_type()(constants)
        # NamedTupleType(*iterable)
        return self.python_type()(*constants)

    def as_proxy(self):
        """Build a ``tuple_cls`` instance from the elements' proxies."""
        # python_type() returns a Python type, never a tracker class, so the
        # previous comparison against SizeVariable could not fail; the type
        # being guarded against is torch.Size.
        assert self.python_type() is not torch.Size
        if self.is_structseq():
            # StructSequenceType(iterable)
            return self.python_type()(self._as_proxy())
        # NamedTupleType(*iterable)
        return self.python_type()(*self._as_proxy())

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that rebuilds the value from a tuple of its items."""
        # Constructors:
        #   StructSequenceType(iterable)
        #   NamedTupleType(*iterable)
        #   NamedTupleType._make(iterable)
        # Both chosen callables take a single iterable argument.
        create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
        codegen.add_push_null(
            lambda: codegen.append_output(codegen._create_load_const(create_fn))
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.items)),
            ]
            + create_call_function(1, False)
        )

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        """Handle ``__setattr__``; delegate every other method.

        Assignment raises an observed ``AttributeError`` for struct
        sequences, for types whose only base is ``tuple`` and for field
        names.  Otherwise the value is recorded in ``dynamic_attributes``.
        """
        if name != "__setattr__":
            return super().call_method(tx, name, args, kwargs)

        assert len(args) == 2
        assert len(kwargs) == 0
        attr, value = args
        attr = attr.as_python_constant()
        if (
            # structseq is immutable
            self.is_structseq()
            # namedtuple directly created by `collections.namedtuple` is immutable
            or self.tuple_cls.__bases__ == (tuple,)
            # fields are immutable
            or attr in self.fields()
        ):
            raise_observed_exception(AttributeError, tx)
        # Subclass of namedtuple type can have dynamic attributes
        tx.output.side_effects.mutation(self)
        self.dynamic_attributes[attr] = value
        return ConstantVariable.create(None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve an attribute on the tracked tuple.

        Lookup order: dynamic attributes, then fields, then class-level
        callables (classmethod, staticmethod, plain function), then the
        parent class.
        """
        if name in self.dynamic_attributes:
            return self.dynamic_attributes[name]

        fields = self.fields()
        if name in fields:
            return self.items[fields.index(name)]

        static_attr = inspect.getattr_static(self.tuple_cls, name, None)
        if isinstance(static_attr, classmethod):
            # We need the unbounded cls method to avoid the inline __self__
            method = UserMethodVariable(
                static_attr.__func__,
                variables.UserDefinedClassVariable(self.tuple_cls),
            )
        elif isinstance(static_attr, staticmethod):
            method = UserFunctionVariable(static_attr.__func__)
        elif inspect.isfunction(static_attr):
            method = UserMethodVariable(static_attr, self)
        else:
            method = None

        if not method:
            return super().var_getattr(tx, name)
        return method

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """True for a dynamic attribute or anything ``tuple_cls`` defines."""
        return variables.ConstantVariable.create(
            name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
        )
|
|
|
|
|
|
class SliceVariable(BaseListVariable):
    """Tracker for a ``slice`` object.

    ``self.items`` always holds exactly three trackers, in the order
    start, stop, step.
    """

    def __init__(self, items, **kwargs) -> None:
        """Accept one, two or three trackers, like ``slice()`` arguments.

        Missing values default to a constant ``None``.  Any other number of
        arguments raises ``AssertionError``.  A tensor start or stop is
        reported as unimplemented.
        """
        # One shared constant-None tracker serves as the default everywhere.
        none_var = variables.ConstantVariable.create(None)
        start = stop = step = none_var

        arg_count = len(items)
        if arg_count == 1:
            (stop,) = items
        elif arg_count == 2:
            start, stop = items
        elif arg_count == 3:
            start, stop, step = items
        else:
            raise AssertionError

        if isinstance(start, variables.TensorVariable) or isinstance(
            stop, variables.TensorVariable
        ):
            unimplemented("Dynamic slicing on data-dependent value is not supported")

        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        return slice(*self._as_proxy())

    def python_type(self):
        return slice

    def as_python_constant(self):
        """Build a ``slice`` from ``guard_if_dyn`` applied to each item."""
        return slice(*[guard_if_dyn(x) for x in self.items])

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the items followed by a ``BUILD_SLICE`` of matching size."""
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``start``, ``stop`` and ``step``; anything else is unimplemented."""
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"slice.{name}")
        return self.items[fields.index(name)]
|
|
|
|
|
|
class ListIteratorVariable(IteratorVariable):
    """Tracker for an iterator over a list of element trackers.

    ``index`` is the position of the next element to yield; elements before
    it count as already consumed.
    """

    _nonvar_fields = {
        "index",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # Removing this check as it slows things down too much
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492

        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"

    def next_variable(self, tx):
        """Yield the next element and advance the position.

        Raises an observed ``StopIteration`` when nothing is left.
        """
        assert self.is_mutable()
        current = self.index
        if current >= len(self.items):
            raise_observed_exception(StopIteration, tx)

        tx.output.side_effects.mutation(self)
        self.index += 1
        return self.items[current]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        """Handle ``__contains__`` over the remaining elements; delegate the rest."""
        if name != "__contains__":
            return super().call_method(tx, name, args, kwargs)

        assert len(args) == 1
        assert not kwargs
        return iter_contains(self.items[self.index :], args[0], tx)

    def python_type(self):
        return type(iter([]))

    def as_python_constant(self):
        """Return a fresh iterator over the elements' Python constants.

        Raises ``NotImplementedError`` once any element has been consumed.
        """
        if self.index > 0:
            raise NotImplementedError
        return iter([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        """Return the elements that have not been consumed yet."""
        return list(self.items[self.index :])

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode building an iterator over the remaining elements."""
        pending = self.items[self.index :]
        codegen.foreach(pending)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(pending)),
                create_instruction("GET_ITER"),
            ]
        )
|
|
|
|
|
|
class TupleIteratorVariable(ListIteratorVariable):
    """Iterator tracker that inherits all behavior from ``ListIteratorVariable``."""
|
|
|
|
|
|
class RestrictedListSubclassVariable(ListVariable):
    """
    This is a special case of UserDefinedObjectVariable where:
        1) The user subclasses list
        2) None of the list methods are overridden, merely some new methods are added

    In these cases, we can prevent graph breaks by not using the general
    UserDefinedObjectVariable machinery and instead treating it like
    a ListVariable.
    """

    _nonvar_fields = {"user_cls", "user_cls_source", *ListVariable._nonvar_fields}
    _allowed_names = {
        "__call__",
        "__module__",
        "__dict__",
        "__doc__",
        "__name__",
        "__qualname__",
    }
    _disallowed_names = {
        "__getattribute__",
        "__getattr__",
        "__setattr__",
    }

    @classmethod
    def _is_non_conflicting_subclass(
        cls,
        user_cls: type,
        python_cls: type,
    ):
        """Ensures user_cls inherits from python_cls (e.g. list) and does not override any methods on python_cls"""
        is_direct_subclass = (
            istype(user_cls, type)
            and user_cls.__bases__ == (python_cls,)
            and user_cls.__mro__ == (user_cls, python_cls, object)
        )
        if not is_direct_subclass:
            return False  # not subclass

        # Names in the class body, minus the ones that are always tolerated.
        candidate_names = set(user_cls.__dict__.keys()) - cls._allowed_names
        for attr_name in candidate_names:
            if hasattr(python_cls, attr_name) or attr_name in cls._disallowed_names:
                return False
        return True

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        """True when ``user_cls`` is a non-conflicting direct subclass of ``list``."""
        return cls._is_non_conflicting_subclass(user_cls, list)

    def __init__(
        self, items, *, user_cls: type, user_cls_source: Source, **kwargs
    ) -> None:
        """Store ``items`` together with the user class and its source."""
        super().__init__(items=items, **kwargs)
        self.user_cls = user_cls
        self.user_cls_source = user_cls_source
        assert istype(user_cls, type)
        assert isinstance(user_cls_source, Source)

    def debug_repr(self):
        # The constructor is safe as no methods, including __init__, are
        # allowed to be overridden
        # NB: This is guaranteed to print like a list, as __repr__ cannot be
        # overridden, this is... well, it's OK I guess (consistent with
        # eager), but it could be misleading.  You will have to query type
        # instead for details.
        return repr(self.user_cls([Lit(x.debug_repr()) for x in self.items]))

    def python_type(self):
        return self.user_cls

    def as_proxy(self):
        """Return a plain list holding the elements' proxies."""
        return [x.as_proxy() for x in self.items]

    def as_python_constant(self):
        raise NotImplementedError

    def is_python_constant(self):
        return False

    @property
    def value(self):
        raise AttributeError("value")

    def modified(self, items, **kwargs):
        """Create a new tracker around ``items`` with the same user class."""
        return type(self)(
            items,
            user_cls=self.user_cls,
            user_cls_source=self.user_cls_source,
            **kwargs,
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode calling the user class on a freshly built list."""
        codegen.add_push_null(lambda: codegen(self.user_cls_source))
        super().reconstruct(codegen)
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Dispatch a method call on the tracked object.

        A name defined in the user class body is inlined when it is a plain
        function and reported as unimplemented otherwise.  Any other name is
        handled as a regular list method.
        """
        user_namespace = self.user_cls.__dict__
        if name not in user_namespace:
            return super().call_method(tx, name, args, kwargs)

        method = user_namespace[name]
        if isinstance(method, types.FunctionType):
            # inline the method
            source = AttrSource(self.user_cls_source, name)
            return UserMethodVariable(method, self, source=source).call_function(
                tx, args, kwargs
            )
        unimplemented(
            f"RestrictedListSubclassVariable method {self.user_cls.__name__}.{name}"
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Calling the object is routed through its ``__call__`` method."""
        return self.call_method(tx, "__call__", args, kwargs)
|