mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
It is generally recommended to use `is`/`is not` to compare types. This series of changes therefore applies that suggestion across the code base, with the aim of eventually enabling the related linter checks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037 Approved by: https://github.com/mlazos
1562 lines
55 KiB
Python
1562 lines
55 KiB
Python
# mypy: ignore-errors
|
|
|
|
"""
|
|
Variable tracking implementations for list-like data structures in Dynamo.
|
|
|
|
This module provides specialized variable tracking for various collection types:
|
|
- Lists and list subclasses (including torch.nn.ModuleList, ParameterList)
|
|
- Tuples and named tuples
|
|
- Ranges and slices
|
|
- collections.deque
|
|
- torch.Size with special proxy handling
|
|
|
|
The implementations support both mutable and immutable collections, iteration,
|
|
and common sequence operations. Each collection type has a dedicated Variable
|
|
class that handles its unique behaviors while integrating with Dynamo's
|
|
variable tracking system.
|
|
"""
|
|
|
|
import collections
|
|
import inspect
|
|
import operator
|
|
import sys
|
|
from typing import Optional, TYPE_CHECKING
|
|
|
|
import torch
|
|
import torch.fx
|
|
|
|
from .. import graph_break_hints, polyfills, variables
|
|
from ..bytecode_transformation import (
|
|
create_build_tuple,
|
|
create_call_function,
|
|
create_instruction,
|
|
create_rot_n,
|
|
)
|
|
from ..exc import raise_observed_exception, unimplemented_v2
|
|
from ..source import AttrSource, NamedTupleFieldsSource
|
|
from ..utils import (
|
|
cmp_name_to_op_mapping,
|
|
cmp_name_to_op_str_mapping,
|
|
get_fake_value,
|
|
guard_if_dyn,
|
|
iter_contains,
|
|
Lit,
|
|
namedtuple_fields,
|
|
odict_values,
|
|
raise_args_mismatch,
|
|
range_iterator,
|
|
set_example_value,
|
|
)
|
|
from .base import ValueMutationNew, VariableTracker
|
|
from .constant import ConstantVariable
|
|
from .functions import UserFunctionVariable, UserMethodVariable
|
|
from .iter import IteratorVariable
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._dynamo.codegen import PyCodegen
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator
|
|
|
|
|
|
class BaseListVariable(VariableTracker):
    """Base tracker for sequence-like values stored as a list of item trackers.

    ``self.items`` holds one ``VariableTracker`` per element; subclasses
    provide ``python_type()`` to name the concrete Python type they model.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return the tracker class registered for ``type(obj)``."""
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Look up the tracker class for a Python type (or the ``iter`` builtin).

        Raises ``KeyError`` when ``obj`` is not one of the registered keys.
        """
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[obj]

    def __init__(
        self,
        items: list[VariableTracker],
        **kwargs,
    ) -> None:
        """Store ``items``; it must be a ``list`` of ``VariableTracker`` objects."""
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: list[VariableTracker] = items

    def _as_proxy(self):
        """Return a plain list with ``as_proxy()`` applied to every item."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Build a new tracker of the same class around ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        """Alias for ``as_python_constant()``."""
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Join the items' ``debug_repr()`` with ``", "`` between ``prefix`` and ``suffix``."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        """Materialise ``python_type()`` from the items' Python constants."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Wrap the item proxies in ``python_type()``.

        ``torch.Size`` must not go through this path (``SizeVariable`` has its
        own ``as_proxy``). The previous assertion compared ``python_type()``
        with the tracker class ``SizeVariable`` and therefore could never fail.
        """
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice ``self.items`` with a constant or symbolic ``arg``.

        A slice yields a clone with no source; an integer yields the stored
        item tracker. A zero slice step and an out-of-range index are reported
        through ``raise_observed_exception``.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            if index.step == 0:
                msg = ConstantVariable.create("slice step cannot be zero")
                raise_observed_exception(ValueError, tx, args=[msg])
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutation_type=ValueMutationNew() if self.mutation_type else None,
            )
        else:
            assert isinstance(index, (int, torch.SymInt))
            try:
                return self.items[index]
            except IndexError:
                # Wrapped in a ConstantVariable like the other messages in this file.
                msg = ConstantVariable.create("list index out of range")
                raise_observed_exception(IndexError, tx, args=[msg])

    def unpack_var_sequence(self, tx):
        """Return a shallow copy of the item trackers."""
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle the sequence methods shared by every list-like tracker.

        Covers ``__getitem__``, ``__contains__``, ``index``, ``count``,
        ``__add__``/``__iadd__``, ``__mul__``/``__imul__`` and the rich
        comparisons; any other name is forwarded to the parent class.
        """
        if name == "__getitem__":
            from .tensor import TensorVariable

            if len(args) != 1:
                msg = ConstantVariable.create(
                    f"{name} takes exactly one argument ({len(args)} given)"
                )
                raise_observed_exception(TypeError, tx, args=[msg])

            assert not kwargs and len(args) == 1
            if isinstance(args[0], TensorVariable):
                # A tensor index is only accepted when its fake value carries
                # a single-element constant.
                value = get_fake_value(args[0].as_proxy().node, tx)
                if value.constant is not None and value.constant.numel() == 1:
                    value = variables.ConstantVariable.create(value.constant.item())
                else:
                    unimplemented_v2(
                        gb_type="Indexing list with non-scalar tensor",
                        context=f"call_method {self} {name} {args} {kwargs}",
                        explanation=(
                            "Attempted to index list-like object with tensor with > 1 element."
                        ),
                        hints=[*graph_break_hints.USER_ERROR],
                    )
            else:
                value = args[0]

            if value.python_type() not in (int, slice):
                msg = f"indices must be integers or slices, not {value.python_type()}"
                raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])

            return self.getitem_const(tx, value)
        elif name == "__contains__":
            if len(args) != 1 or kwargs:
                raise_args_mismatch(tx, name)
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
        elif name == "index":
            if not len(args):
                raise_args_mismatch(tx, name)

            return tx.inline_user_function_return(
                VariableTracker.build(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )
        elif name == "count":
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            return VariableTracker.build(tx, operator.countOf).call_function(
                tx,
                [self, args[0]],
                kwargs,
            )
        elif name in ("__add__", "__iadd__"):
            if kwargs or len(args) != 1:
                raise_args_mismatch(tx, name)

            if type(self) is not type(args[0]):
                tp_name = self.python_type_name()
                other = args[0].python_type_name()
                msg = ConstantVariable.create(
                    f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
                )
                raise_observed_exception(TypeError, tx, args=[msg])

            if name == "__add__":
                return type(self)(self.items + args[0].items, source=self.source)
            else:
                self.items += args[0].items
                return self
        elif name in ("__mul__", "__imul__"):
            if kwargs or len(args) != 1:
                raise_args_mismatch(tx, name)

            if not (args[0].is_python_constant() and args[0].python_type() is int):
                msg = ConstantVariable.create(
                    f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
                )
                raise_observed_exception(TypeError, tx, args=[msg])

            val = args[0].as_python_constant()

            if name == "__mul__":
                return type(self)(self.items * val, source=self.source)
            else:
                self.items *= val
                return self
        elif name in cmp_name_to_op_mapping:
            if len(args) != 1:
                raise_args_mismatch(tx, name)

            left = self
            right = args[0]
            # TODO this type check logic mirrors the following
            # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007
            # But we should probably move it up the stack to so that we don't
            # need to duplicate it for different VTs.
            if not isinstance(left, BaseListVariable) or not isinstance(
                right, BaseListVariable
            ):
                if name == "__eq__":
                    return variables.BuiltinVariable(operator.is_).call_function(
                        tx, (left, right), {}
                    )
                elif name == "__ne__":
                    return variables.BuiltinVariable(operator.is_not).call_function(
                        tx, (left, right), {}
                    )
                else:
                    op_str = cmp_name_to_op_str_mapping[name]
                    left_ty = left.python_type_name()
                    right_ty = right.python_type_name()
                    msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'"
                    # Wrapped in a ConstantVariable like the other messages in this file.
                    raise_observed_exception(
                        TypeError, tx, args=[ConstantVariable.create(msg)]
                    )

            return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
                tx,
                [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
                {},
            )

        return super().call_method(tx, name, args, kwargs)
|
|
|
|
|
|
class RangeVariable(BaseListVariable):
    """Tracker for ``range`` objects; ``items`` is always ``[start, stop, step]``."""

    def __init__(self, items, **kwargs) -> None:
        """Accept 1, 2 or 3 trackers, laid out like the arguments of ``range()``.

        Missing ``start``/``step`` default to constants 0 and 1. Constant
        arguments are re-wrapped after an ``int()`` conversion.
        """
        items_to_map = items
        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError

        def maybe_as_int(x):
            return (
                ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
            )

        # cast each argument to an integer
        start = maybe_as_int(start)
        step = maybe_as_int(step)
        stop = maybe_as_int(stop)

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        """Constant value of the start bound."""
        return self.items[0].as_python_constant()

    def stop(self):
        """Constant value of the stop bound."""
        return self.items[1].as_python_constant()

    def step(self):
        """Constant value of the step."""
        return self.items[2].as_python_constant()

    def range_length(self):
        """Number of elements the range produces (requires constant bounds)."""
        lo = self.start()
        hi = self.stop()
        step = self.step()

        assert step != 0
        if step > 0 and lo < hi:
            return 1 + (hi - 1 - lo) // step
        elif step < 0 and lo > hi:
            return 1 + (lo - 1 - hi) // (0 - step)
        else:
            return 0

    def _get_slice_indices(self, length, slice):
        """Clamp ``slice`` against a sequence of ``length`` elements.

        Returns ``[start, stop, step]`` as plain ints. A zero step is not
        checked here. NOTE: the ``slice`` parameter shadows the builtin; the
        name is kept so existing callers are unaffected.
        """
        step = 1 if slice.step is None else slice.step
        step_is_negative = step < 0

        # Find lower and upper bounds for start and stop.
        if step_is_negative:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        # Compute start
        if slice.start is None:
            start = upper if step_is_negative else lower
        else:
            start = slice.start
            if start < 0:
                start = max(start + length, lower)
            else:
                start = min(start, upper)

        # Compute stop.
        if slice.stop is None:
            stop = lower if step_is_negative else upper
        else:
            stop = slice.stop
            if stop < 0:
                stop = max(stop + length, lower)
            else:
                stop = min(stop, upper)

        return [start, stop, step]

    def apply_index(self, index):
        """Return the element at integer ``index`` as a constant tracker.

        Negative indices count from the end; an out-of-range index is reported
        as ``IndexError`` through ``raise_observed_exception``.
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if index < 0 or index >= length:
            tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx()
            raise_observed_exception(
                IndexError,
                tx,
                args=[ConstantVariable("range object index out of range")],
            )

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return a new ``RangeVariable`` describing ``self[slice]``."""
        (slice_start, slice_stop, slice_step) = self._get_slice_indices(
            self.range_length(), slice
        )

        def compute_item(index):
            return self.start() + (index * self.step())

        sub_step = self.step() * slice_step
        sub_start = compute_item(slice_start)
        sub_stop = compute_item(slice_stop)

        result = RangeVariable(
            [
                variables.ConstantVariable.create(x)
                for x in [sub_start, sub_stop, sub_step]
            ],
            mutation_type=ValueMutationNew() if self.mutation_type else None,
        )
        return result

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Dispatch a constant index to ``apply_slice`` or ``apply_index``."""
        # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()

        if isinstance(index, slice):
            # Same guard as BaseListVariable.getitem_const; without it a zero
            # step would later trip the assertion in range_length().
            if index.step == 0:
                msg = ConstantVariable.create("slice step cannot be zero")
                raise_observed_exception(ValueError, tx, args=[msg])
            return self.apply_slice(index)
        elif isinstance(index, int):
            return self.apply_index(index)
        else:
            msg = ConstantVariable("range indices must be integers or slices")
            raise_observed_exception(TypeError, tx, args=[msg])

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        """Expand the range into one constant tracker per element."""
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that calls ``range(start, stop, step)``."""
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        if self.python_type() is not range:
            return super().call_obj_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr(range(0), name))

    def range_equals(self, other: "RangeVariable"):
        """Compare two ranges the way CPython's ``range.__eq__`` does.

        Ranges are equal when they produce the same sequence: equal lengths
        and, unless both are empty, equal starts and (for more than one
        element) equal steps.
        """
        r0, r1 = self, other
        length = r0.range_length()
        if length != r1.range_length():
            return False
        # Two empty ranges are equal whatever their bounds, e.g.
        # range(0) == range(2, 1, 3).
        if length == 0:
            return True
        if r0.start() != r1.start():
            return False
        # A single-element range is fully determined by its start.
        if length == 1:
            return True
        return r0.step() == r1.step()

    def range_count(self, x: VariableTracker):
        """Return 1 if constant ``x`` is an element of the range, otherwise 0."""
        # Based on CPython
        # https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
        x = x.as_python_constant()
        if type(x) not in (bool, int, float):
            return 0

        start, stop, step = self.start(), self.stop(), self.step()

        if step == 0:
            return 0

        in_range = (start <= x < stop) if step > 0 else (stop < x <= start)

        if in_range:
            re = ((x - start) % step) == 0
            return int(re)
        return 0

    def call_method(self, tx, name, args, kwargs):
        """Handle ``__iter__``, ``__len__``, ``count``, ``__contains__``,
        ``__getitem__`` and comparisons; defer everything else to the parent.
        """
        if name == "__iter__":
            if not all(var.is_python_constant() for var in self.items):
                # Can't represent a `range_iterator` without well defined bounds
                return variables.misc.DelayGraphBreakVariable(
                    msg="Cannot create range_iterator: bounds (start, stop, step) must be fully defined as concrete constants.",
                )
            return RangeIteratorVariable(
                self.start(), self.stop(), self.step(), self.range_length()
            )
        elif name == "__len__":
            length = self.range_length()
            if length > sys.maxsize:
                raise_observed_exception(OverflowError, tx)
            return ConstantVariable.create(length)
        elif name == "count":
            return ConstantVariable(self.range_count(*args))
        elif name == "__contains__":
            # Membership must yield a bool, not the 0/1 count.
            return ConstantVariable(bool(self.range_count(*args)))
        elif name == "__getitem__":
            return self.getitem_const(tx, *args)
        elif name in cmp_name_to_op_mapping:
            other = args[0]
            pt = other.python_type()
            if name not in ("__eq__", "__ne__"):
                # range defines no ordering; only == and != are supported
                op_str = cmp_name_to_op_str_mapping[name]
                msg = f"{op_str} not supported between instances of 'range' and '{pt.__name__}'"
                raise_observed_exception(
                    TypeError,
                    tx,
                    args=[ConstantVariable.create(msg)],
                )

            if pt is not range:
                return ConstantVariable.create(NotImplemented)

            cmp = self.range_equals(other)

            # Two ranges are equal if they produce the same sequence of values
            if name == "__eq__":
                return ConstantVariable(cmp)
            else:
                return ConstantVariable(not cmp)
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``start``/``stop``/``step`` as the stored item trackers."""
        fields = ["start", "stop", "step"]
        if name in fields:
            return self.items[fields.index(name)]
        return super().var_getattr(tx, name)
|
|
|
|
|
|
class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle the mutating list methods plus ``copy``.

        The mutating branches are taken only when ``self.is_mutable()`` is
        true, and each one calls ``tx.output.side_effects.mutation(self)``.
        Names not handled here fall through to ``BaseListVariable``.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.is_mutable():
            assert not kwargs
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)
        elif name == "extend" and self.is_mutable():
            if len(args) != 1 or kwargs:
                raise_args_mismatch(tx, name)

            if not args[0].has_force_unpack_var_sequence(tx):
                msg = ConstantVariable.create(f"{type(args[0])} object is not iterable")
                raise_observed_exception(TypeError, tx, args=[msg])

            (arg,) = args
            arg.force_apply_to_var_sequence(
                tx, lambda item: self.call_method(tx, "append", [item], {})
            )
            return ConstantVariable.create(None)
        elif name == "insert" and self.is_mutable():
            if kwargs or len(args) != 2:
                raise_args_mismatch(tx, name)
            idx, value = args
            if isinstance(idx, SymNodeVariable):
                const_idx = idx.evaluate_expr()
            else:
                const_idx = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(const_idx, value)
            return ConstantVariable.create(None)
        elif name == "pop" and self.is_mutable():
            # The former `assert not kwargs` made this check unreachable for
            # keyword arguments; report them as an argument mismatch instead.
            if kwargs or len(args) > 1:
                raise_args_mismatch(tx, name)

            if len(self.items) == 0:
                msg = ConstantVariable.create("pop from empty list")
                raise_observed_exception(IndexError, tx, args=[msg])

            if len(args):
                idx = args[0].as_python_constant()
                # Valid indices are -len .. len-1. The old `idx > len` test let
                # idx == len and out-of-range negatives through to list.pop.
                if not -len(self.items) <= idx < len(self.items):
                    msg = ConstantVariable.create("pop index out of range")
                    raise_observed_exception(IndexError, tx, args=[msg])
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])
        elif name == "clear" and self.is_mutable():
            if args or kwargs:
                raise_observed_exception(TypeError, tx)
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and (
                args[0].is_python_constant()
                or isinstance(args[0], SymNodeVariable)
                or (
                    isinstance(args[0], SliceVariable)
                    and all(
                        s.is_python_constant() or isinstance(s, SymNodeVariable)
                        for s in args[0].items
                    )
                )
            )
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SymNodeVariable):
                self.items[key.evaluate_expr()] = value
            elif isinstance(key, SliceVariable):
                if key.is_python_constant():
                    self.items[key.as_python_constant()] = list(value.items)
                else:
                    # Build a concrete slice from the mix of symbolic and
                    # constant slice components.
                    items = slice(
                        *[
                            (
                                s.evaluate_expr()
                                if isinstance(s, SymNodeVariable)
                                else s.as_python_constant()
                            )
                            for s in key.items
                        ]
                    )
                    self.items[items] = list(value.items)
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif name == "__delitem__" and self.is_mutable():
            if kwargs or len(args) != 1:
                raise_args_mismatch(tx, name)

            tx.output.side_effects.mutation(self)
            if args[0].is_python_constant() and isinstance(
                args[0].as_python_constant(), (int, slice)
            ):
                if isinstance(args[0], SymNodeVariable):
                    idx = args[0].evaluate_expr()
                else:
                    idx = args[0].as_python_constant()

                try:
                    self.items.__delitem__(idx)
                except (IndexError, ValueError) as exc:
                    raise_observed_exception(
                        type(exc),
                        tx,
                        args=list(map(ConstantVariable.create, exc.args)),
                    )
            else:
                msg = ConstantVariable.create(
                    f"list indices must be integers or slices, not {args[0].python_type_name()}"
                )
                raise_observed_exception(TypeError, tx, args=[msg])
            return ConstantVariable.create(None)
        elif name == "copy":
            # List copy() doesn't have args and kwargs
            if args or kwargs:
                raise_args_mismatch(tx, name)
            items = list(self.items)
            return self.modified(items, mutation_type=ValueMutationNew())
        elif name == "reverse" and self.is_mutable():
            if args or kwargs:
                raise_args_mismatch(tx, name)
            self.items.reverse()
            tx.output.side_effects.mutation(self)
            return ConstantVariable.create(None)
        elif name == "remove" and self.is_mutable():
            if len(args) != 1 or kwargs:
                raise_args_mismatch(tx, name)

            # Implemented as index() followed by pop() of that position.
            idx = self.call_method(tx, "index", args, kwargs)
            self.call_method(tx, "pop", [idx], {})
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)
|
|
|
|
|
|
class ListVariable(CommonListMethodsVariable):
    """Tracker for ``list`` values."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the items followed by a ``BUILD_LIST`` instruction."""
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__setitem__``, ``sort`` and ``__init__`` for mutable lists.

        Everything else is delegated to ``CommonListMethodsVariable``.
        """
        from .tensor import SymNodeVariable

        if name == "__setitem__" and self.is_mutable():
            if kwargs or len(args) != 2:
                raise_args_mismatch(tx, name)
            key, value = args

            if not key.is_python_constant():
                # probably will graph-break
                # Return the parent's result. Previously control fell through
                # and the assignment below ran a second time whenever the
                # parent had handled the key.
                return super().call_method(tx, name, args, kwargs)

            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                if not value.has_force_unpack_var_sequence(tx):
                    msg = ConstantVariable.create("can only assign an iterable")
                    raise_observed_exception(TypeError, tx, args=[msg])

                key = key.as_python_constant()
                if key.step == 0:
                    msg = ConstantVariable.create("slice step cannot be zero")
                    raise_observed_exception(ValueError, tx, args=[msg])

                value = value.force_unpack_var_sequence(tx)
                try:
                    self.items[key] = value
                except Exception as exc:
                    raise_observed_exception(
                        type(exc),
                        tx,
                        args=list(map(ConstantVariable.create, exc.args)),
                    )
            else:
                if isinstance(key, SymNodeVariable):
                    key = key.evaluate_expr()
                else:
                    key = key.as_python_constant()

                try:
                    self.items[key] = value
                except (IndexError, TypeError) as e:
                    raise_observed_exception(
                        type(e), tx, args=list(map(ConstantVariable.create, e.args))
                    )
            return ConstantVariable.create(None)

        if name == "sort" and self.is_mutable():
            assert len(args) == 0
            key_fn_var = kwargs.pop("key", ConstantVariable.create(None))
            reverse = kwargs.pop(
                "reverse", ConstantVariable.create(False)
            ).as_python_constant()
            assert len(kwargs) == 0

            if (
                key_fn_var.is_python_constant()
                and key_fn_var.as_python_constant() is None
            ):
                keys = self.items.copy()
            else:
                keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items]

            # Report the FIRST non-constant key. The previous loop had no
            # `break` and therefore reported the last one.
            first_non_constant_key = next(
                (k for k in keys if not k.is_python_constant()), None
            )
            if first_non_constant_key is not None:
                try:
                    python_type = first_non_constant_key.python_type()
                except NotImplementedError:
                    python_type = "unknown"

                unimplemented_v2(
                    gb_type="sort with non-constant keys",
                    context=str(first_non_constant_key),
                    explanation=(
                        f"Cannot perform sort with non-constant key. "
                        f"First non-constant key type: {python_type}. "
                        f"Most notably, we cannot sort with Tensor or SymInt keys, but we can "
                        f"sort ints."
                    ),
                    hints=["Use something else as the key."],
                )

            tx.output.side_effects.mutation(self)
            sorted_items_with_keys = sorted(
                (
                    (
                        x,
                        k.as_python_constant(),
                        -i if reverse else i,  # extra key to ensure stable sort
                    )
                    for i, (k, x) in enumerate(zip(keys, self.items))
                ),
                key=operator.itemgetter(1, 2),
                reverse=reverse,
            )
            self.items[:] = [x for x, *_ in sorted_items_with_keys]
            return ConstantVariable.create(None)

        if name == "__init__" and self.is_mutable():
            assert not kwargs
            if len(args) == 0:
                return ConstantVariable.create(None)
            elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
                (arg,) = args
                tx.output.side_effects.mutation(self)
                self.items[:] = arg.force_unpack_var_sequence(tx)
                return ConstantVariable.create(None)

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` to a class tracker; defer other attributes."""
        if name == "__class__":
            source = AttrSource(self.source, name) if self.source else None
            class_type = self.python_type()
            if class_type is list:
                return variables.BuiltinVariable(class_type, source=source)
            else:
                return variables.UserDefinedClassVariable(class_type, source=source)
        return super().var_getattr(tx, name)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        if self.python_type() is not list:
            return super().call_obj_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr([], name))
|
|
|
|
|
|
class DequeVariable(CommonListMethodsVariable):
    """Tracker for ``collections.deque``; items are kept in a plain list.

    ``self.maxlen`` is a constant tracker holding either ``None`` or the
    deque's maximum length.
    """

    def __init__(self, items, maxlen=None, **kwargs) -> None:
        """Store ``items``, keeping only the last ``maxlen`` of them if bounded."""
        if maxlen is None:
            maxlen = ConstantVariable.create(None)
        assert maxlen.is_python_constant(), (
            f"maxlen must be a constant, got: {maxlen.debug_repr()}"
        )
        self.maxlen = maxlen
        items = list(items)
        limit = self.maxlen.as_python_constant()
        if limit is not None:
            # `items[-limit:]` keeps everything when limit == 0, so compute
            # the start offset explicitly.
            items = items[max(len(items) - limit, 0) :]
        super().__init__(items, **kwargs)

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        """Format like ``repr(deque)``: show ``maxlen=`` only for bounded deques."""
        if self.maxlen.as_python_constant() is not None:
            return self.debug_repr_helper(
                "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
            )
        return self.debug_repr_helper("deque([", "])")

    def as_python_constant(self):
        return self.python_type()(
            [x.as_python_constant() for x in self.items],
            maxlen=self.maxlen.as_python_constant(),
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode for ``collections.deque([items...], maxlen=maxlen)``."""
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
        codegen(self.maxlen)
        codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "maxlen":
            return self.maxlen
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle deque-specific methods, then enforce ``maxlen``.

        After the call, a bounded deque that grew past ``maxlen`` is trimmed.
        Left-side insertions keep the leftmost ``maxlen`` items; every other
        operation keeps the rightmost ``maxlen`` items.
        """
        if (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert len(args) == 2
            assert not kwargs
            key, value = args
            assert key.is_python_constant()
            assert isinstance(key.as_python_constant(), int)
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)

        maxlen = self.maxlen.as_python_constant()
        # Which end survives when the deque overflows its maxlen.
        keep_left = False

        if (
            name == "extendleft"
            and self.is_mutable()
            and len(args) > 0
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert len(args) == 1
            assert not kwargs
            # NOTE this is inefficient, but the alternative is to represent self.items
            # as a deque, which is a more intrusive change.
            args[0].force_apply_to_var_sequence(
                tx, lambda item: self.call_method(tx, "appendleft", [item], {})
            )
            keep_left = True
            result = ConstantVariable.create(None)
        elif name == "popleft" and self.is_mutable():
            assert not args
            assert not kwargs
            if not self.items:
                # Unpacking an empty list below would raise a raw ValueError;
                # report the IndexError that deque.popleft() raises instead.
                msg = ConstantVariable.create("pop from an empty deque")
                raise_observed_exception(IndexError, tx, args=[msg])
            tx.output.side_effects.mutation(self)
            result, *self.items[:] = self.items
        elif name == "appendleft" and len(args) > 0 and self.is_mutable():
            assert len(args) == 1
            assert not kwargs
            tx.output.side_effects.mutation(self)
            self.items[:] = [args[0], *self.items]
            keep_left = True
            result = ConstantVariable.create(None)
        elif name == "insert" and len(args) > 0 and self.is_mutable():
            assert len(args) == 2
            assert not kwargs
            if maxlen is not None and len(self.items) == maxlen:
                # Wrapped in a ConstantVariable like the other messages in this file.
                msg = ConstantVariable.create("deque already at its maximum size")
                raise_observed_exception(IndexError, tx, args=[msg])
            result = super().call_method(tx, name, args, kwargs)
        else:
            result = super().call_method(tx, name, args, kwargs)

        if maxlen is not None and len(self.items) > maxlen:
            if keep_left:
                self.items[:] = self.items[:maxlen]
            else:
                # Explicit offset: `self.items[-maxlen:]` would keep every
                # item when maxlen == 0.
                self.items[:] = self.items[len(self.items) - maxlen :]
        return result
|
|
|
|
|
|
class TupleVariable(BaseListVariable):
    """Tracker for ``tuple`` values."""

    def python_type(self):
        """Python type modelled by this tracker."""
        return tuple

    def __repr__(self) -> str:
        count = len(self.items)
        return f"{type(self).__name__}(length={count})"

    def debug_repr(self):
        """Render the items between parentheses."""
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit every item, then a build-tuple instruction of matching arity."""
        codegen.foreach(self.items)
        build_tuple = create_build_tuple(len(self.items))
        codegen.append_output(build_tuple)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Forward unchanged to the parent implementation."""
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` to a class tracker; defer other attributes."""
        if name != "__class__":
            return super().var_getattr(tx, name)

        source = AttrSource(self.source, name) if self.source else None
        class_type = self.python_type()
        # The exact builtin gets a BuiltinVariable; any other type is treated
        # as a user-defined class.
        if class_type is tuple:
            return variables.BuiltinVariable(class_type, source=source)
        return variables.UserDefinedClassVariable(class_type, source=source)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        """Answer ``hasattr`` from an empty tuple when modelling exactly ``tuple``."""
        if self.python_type() is tuple:
            return variables.ConstantVariable.create(hasattr((), name))
        return super().call_obj_hasattr(tx, name)
|
|
|
|
|
|
class SizeVariable(TupleVariable):
    """Tracker for a ``torch.Size(...)`` value.

    Besides the items it may hold a pre-built FX ``proxy``; when present,
    ``as_proxy`` returns it directly.
    """

    _nonvar_fields = {"proxy", *TupleVariable._nonvar_fields}

    def __init__(
        self,
        items: list[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        """Store ``proxy`` and hand ``items`` plus ``kwargs`` to the parent."""
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        """Delegate to ``debug_repr_helper`` with ``torch.Size([`` / ``])``."""
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        """Return the Python type this tracker stands for (``torch.Size``)."""
        return torch.Size

    def as_proxy(self):
        """Return a proxy (or a concrete ``torch.Size``) for this value.

        Order of preference: the stored ``self.proxy``; a plain
        ``torch.Size`` when none of the element proxies is a
        ``torch.fx.Proxy``; otherwise a new ``call_function`` proxy for
        ``torch.Size`` whose example value is built from the elements.
        """
        if self.proxy is not None:
            return self.proxy

        # torch.Size cannot be handled like other list-like containers.
        # Usually a container simply holds FX Proxy/Node objects and FX looks
        # inside it via map_aggregate. torch.Size, although a tuple subclass,
        # rejects members that are not int-like (Proxy and Node included), so
        # torch.Size([proxy0, proxy1]) cannot be built. Relaxing torch.Size
        # itself is not backward compatible: when its constructor meets a
        # type it does not recognize it tries __index__() on it (an unchecked
        # alternate constructor would have been another option).
        #
        # Workaround: represent the size as one straight proxy, as if built
        # by passing the constituent proxies as arguments. The same trick
        # works for any construct that needs a proxy but cannot be expressed
        # as an aggregate, though there are few such cases in torchdynamo.

        # Find a tracer from the first real Proxy; without one, fall back to
        # the legacy behavior of returning a concrete torch.Size.
        proxies = self._as_proxy()
        tracer = next(
            (p.tracer for p in proxies if isinstance(p, torch.fx.Proxy)),
            None,
        )
        if tracer is None:
            return torch.Size(proxies)

        size_proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        # Plain ints are used as-is; everything else contributes the example
        # value recorded on its node.
        example_dims = []
        for p in proxies:
            if isinstance(p, int):
                example_dims.append(p)
            else:
                example_dims.append(p.node.meta["example_value"])
        set_example_value(size_proxy.node, torch.Size(example_dims))
        return size_proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit code that loads ``torch.Size``, builds a tuple of the items
        and calls ``torch.Size`` with that single tuple argument."""
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_build_tuple(len(self.items)),
                *create_call_function(1, False),
            ]
        )

    def unpack_var_sequence(self, tx):
        """Return a shallow copy of the items as a new list."""
        return [*self.items]

    def numel(self, tx):
        """Return a tracker for the product of all dimensions.

        ``ConstantVariable`` dimensions are multiplied eagerly in Python.
        Every other dimension must be a ``SymNodeVariable`` and is multiplied
        in through ``BuiltinVariable(operator.mul)`` only when needed: not at
        all if the constant part is ``0``, and without a leading ``1 *`` if
        the constant part is ``1``.
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        static_product = 1
        dynamic_dims = []
        for dim in self.items:
            if not isinstance(dim, ConstantVariable):
                assert isinstance(dim, SymNodeVariable), type(dim)
                # Delay proxy calls until we know they will be necessary
                dynamic_dims.append(dim)
                continue
            static_product *= dim.value

        total = ConstantVariable.create(static_product)
        if dynamic_dims and static_product == 1:
            # Skip multiplying by 1: start from the first symbolic dimension.
            total = dynamic_dims[0]
            dynamic_dims = dynamic_dims[1:]

        if not dynamic_dims or static_product == 0:
            return total

        multiply = BuiltinVariable(operator.mul)
        for dim in dynamic_dims:
            total = multiply.call_function(tx, [total, dim], {})
        return total

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__`` and ``numel`` here; defer the rest upward."""
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.get_item_dyn(tx, args[0])
        if name == "numel":
            assert not args and not kwargs
            return self.numel(tx)
        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index the size with ``arg``.

        The index is ``arg.sym_num`` for a ``SymNodeVariable`` and
        ``arg.as_python_constant()`` otherwise. A slice yields a new
        ``SizeVariable``; an ``int``/``torch.SymInt`` yields the stored item.
        """
        from .tensor import SymNodeVariable

        index = (
            arg.sym_num if isinstance(arg, SymNodeVariable) else arg.as_python_constant()
        )

        if isinstance(index, slice):
            return SizeVariable(self.items[index])

        assert isinstance(index, (int, torch.SymInt))
        return self.items[index]

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        """Answer ``hasattr`` by checking ``name`` on the ``torch.Size`` class."""
        has_it = hasattr(torch.Size, name)
        return variables.ConstantVariable.create(has_it)
|
|
|
|
|
|
class NamedTupleVariable(TupleVariable):
    """Tracker for instances of a named-tuple-like class ``tuple_cls``.

    Two constructor conventions are distinguished throughout the class:
    ``NamedTupleType(*iterable)`` when ``is_namedtuple()`` holds, and
    ``StructSequenceType(iterable)`` otherwise. Extra attributes set on the
    instance are kept in ``dynamic_attributes`` (name -> value).
    """

    _nonvar_fields = {
        "tuple_cls",
        "dynamic_attributes",
        *TupleVariable._nonvar_fields,
    }

    def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
        """Initialise the parent with ``items`` and record ``tuple_cls``.

        A falsy ``dynamic_attributes`` (``None`` or empty) is replaced by a
        fresh empty dict.
        """
        super().__init__(items, **kwargs)
        self.tuple_cls = tuple_cls
        self.dynamic_attributes = dynamic_attributes or {}

    def is_namedtuple(self):
        """Return True when ``tuple_cls`` has a tuple ``_fields`` and a
        callable ``_make``."""
        field_names = getattr(self.tuple_cls, "_fields", None)
        if not isinstance(field_names, tuple):
            return False
        return callable(getattr(self.tuple_cls, "_make", None))

    def is_structseq(self):
        """Return True whenever ``is_namedtuple()`` is False."""
        return not self.is_namedtuple()

    def fields(self):
        """Return ``namedtuple_fields(self.tuple_cls)``."""
        return namedtuple_fields(self.tuple_cls)

    def debug_repr(self):
        """Return ``repr`` of a ``tuple_cls`` instance built from the items'
        debug representations wrapped in ``Lit``."""
        rendered = (Lit(item.debug_repr()) for item in self.items)
        if self.is_structseq():
            # StructSequenceType(iterable)
            return repr(self.tuple_cls(list(rendered)))
        # NamedTupleType(*iterable)
        return repr(self.tuple_cls(*rendered))

    def python_type(self):
        """Return the tracked class (``tuple_cls``)."""
        return self.tuple_cls

    def as_python_constant(self):
        """Build a real ``tuple_cls`` instance from the items' constants.

        Each entry of ``dynamic_attributes`` is converted with
        ``as_python_constant()`` and set on the result with ``setattr``.

        Raises:
            NotImplementedError: if a dynamic attribute value has no
                ``as_python_constant`` attribute.
        """
        cls = self.python_type()
        constants = [item.as_python_constant() for item in self.items]
        if self.is_structseq():
            # StructSequenceType(iterable)
            result = cls(constants)
        else:
            # NamedTupleType(*iterable)
            result = cls(*constants)

        # Apply dynamic attributes if any were set
        for attr_name, attr_value in self.dynamic_attributes.items():
            # Convert VariableTracker to Python constant if needed
            if not hasattr(attr_value, "as_python_constant"):
                raise NotImplementedError(
                    "Can not convert dynamic attribute without python constant value to python constant."
                )
            setattr(result, attr_name, attr_value.as_python_constant())

        return result

    def as_proxy(self):
        """Return a ``tuple_cls`` instance holding the items' proxies."""
        # python_type() returns a Python class, so the guard has to compare
        # against torch.Size itself; comparing against the SizeVariable
        # tracker class could never fail.
        assert self.python_type() is not torch.Size
        if self.is_structseq():
            # StructSequenceType(iterable)
            return self.python_type()(self._as_proxy())
        # NamedTupleType(*iterable)
        return self.python_type()(*self._as_proxy())

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit code that rebuilds the instance and re-applies each dynamic
        attribute with a store-attr on a duplicate of the new object."""
        # Always reconstruct the NamedTuple normally first
        # Constructors:
        #   StructSequenceType(iterable)
        #   NamedTupleType(*iterable)
        #   NamedTupleType._make(iterable)
        # Both chosen callables take one iterable, hence the single
        # build-tuple argument below.
        if self.is_structseq():
            create_fn = self.tuple_cls
        else:
            create_fn = self.tuple_cls._make

        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_const_unchecked(create_fn)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_build_tuple(len(self.items)),
                *create_call_function(1, False),
            ]
        )

        for attr_name, attr_value in self.dynamic_attributes.items():
            codegen.dup_top()
            codegen(attr_value)
            codegen.extend_output(create_rot_n(2))
            codegen.store_attr(attr_name)

    def _is_method_overridden(self, method_name: str) -> bool:
        """Checks if a method is overridden in the NamedTuple subclass.

        The attribute found on ``tuple_cls`` is compared with the one found
        on ``tuple_cls.__mro__[-3]``.

        Args:
            method_name (str): The name of the method to check.

        Returns:
            bool: False when both lookups compare equal, True otherwise.

        Raises:
            ValueError: If ``tuple_cls.__mro__`` has fewer than three entries.
        """
        mro = self.tuple_cls.__mro__
        if len(mro) < 3:
            raise ValueError("NamedTuple should inherit from Tuple and Object.")
        own_impl = getattr(self.tuple_cls, method_name, None)
        base_impl = getattr(mro[-3], method_name, None)
        return not (own_impl == base_impl)

    def call_method(
        self,
        tx,
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Handle ``__setattr__`` and ``_replace``; defer the rest upward.

        ``__setattr__`` raises an observed ``AttributeError`` for immutable
        targets and otherwise records the value in ``dynamic_attributes``.
        ``_replace`` returns a new ``NamedTupleVariable`` with the keyword
        fields swapped in.
        """
        if name == "__setattr__":
            assert len(args) == 2
            assert len(kwargs) == 0
            attr_var, value = args
            attr = attr_var.as_python_constant()
            immutable = (
                # structseq is immutable
                self.is_structseq()
                # namedtuple directly created by `collections.namedtuple` is immutable
                or self.tuple_cls.__bases__ == (tuple,)
                # fields are immutable
                or attr in self.fields()
            )
            if immutable:
                raise_observed_exception(AttributeError, tx)
            # Subclass of namedtuple type can have dynamic attributes
            tx.output.side_effects.mutation(self)
            if self.source:
                tx.output.side_effects.store_attr(self, attr, value)
            self.dynamic_attributes[attr] = value
            return ConstantVariable.create(None)

        if name == "_replace":
            # NamedTuple._replace should create a new instance with replaced fields
            if args:
                raise_observed_exception(
                    TypeError,
                    tx,
                    args=[
                        ConstantVariable.create(
                            "_replace() takes no positional arguments"
                        )
                    ],
                )

            # Field names are needed both for validation and for positions.
            fields = self.fields()

            # Work on a copy so the current variable is left untouched.
            new_items = list(self.items)

            for field_name, new_value in kwargs.items():
                if field_name not in fields:
                    raise_observed_exception(
                        ValueError,
                        tx,
                        args=[
                            ConstantVariable.create(
                                f"Got unexpected field name: '{field_name}'"
                            )
                        ],
                    )

                # Replace the item at the field's index
                new_items[fields.index(field_name)] = new_value

            return NamedTupleVariable(new_items, self.tuple_cls)

        return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index with a constant; a ``SliceVariable`` argument yields a plain
        ``TupleVariable`` with no source."""
        if not isinstance(arg, SliceVariable):
            return super().getitem_const(tx, arg)
        # slicing a namedtuple produces a tuple
        sliced = self.items[arg.as_python_constant()]
        return TupleVariable(sliced, source=None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve attribute ``name`` on the tracked instance.

        Checked in order: a non-overridden ``_replace``, ``_fields``, dynamic
        attributes, field names, then class-level methods found with
        ``inspect.getattr_static``. Anything else goes to the parent class.
        """

        def check_and_create_method():
            # Static lookup, so descriptors are seen unbound.
            method = inspect.getattr_static(self.tuple_cls, name, None)
            if isinstance(method, classmethod):
                # We need the unbounded cls method to avoid the inline __self__
                return UserMethodVariable(
                    method.__func__,
                    variables.UserDefinedClassVariable(self.tuple_cls),
                )
            if isinstance(method, staticmethod):
                return UserFunctionVariable(method.__func__)
            if inspect.isfunction(method):
                return UserMethodVariable(method, self)
            return None

        # Avoid the UserMethodVariable fallback precisely when the NamedTuple
        # methods have not been overridden.
        if (
            name == "_replace"
            and not self._is_method_overridden("_replace")
            and not self._is_method_overridden("__getattr__")
        ):
            # Hand back a GetAttrVariable for _replace, provided the class
            # really exposes such a method.
            actual_replace_method = getattr(self.tuple_cls, "_replace", None)
            if actual_replace_method:
                source = AttrSource(self.source, name) if self.source else None
                return variables.GetAttrVariable(self, name, source=source)
            # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples)
            return super().var_getattr(tx, name)

        if name == "_fields":
            source = NamedTupleFieldsSource(self.source) if self.source else None
            return VariableTracker.build(tx, self.fields(), source=source)

        if name in self.dynamic_attributes:
            return self.dynamic_attributes[name]

        fields = self.fields()
        if name in fields:
            return self.items[fields.index(name)]

        method = check_and_create_method()
        if not method:
            return super().var_getattr(tx, name)
        return method

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        """Answer ``hasattr``: True for a dynamic attribute or any attribute
        present on ``tuple_cls``."""
        found = name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
        return variables.ConstantVariable.create(found)
|
|
|
|
|
|
class SliceVariable(VariableTracker):
    """Tracker for a ``slice`` object, stored as a ``(start, stop, step)``
    tuple of variables in ``self.items``."""

    def __init__(self, items, **kwargs) -> None:
        """Normalise one to three ``items`` into ``(start, stop, step)``.

        Missing positions default to a constant ``None`` variable, matching
        ``slice(stop)``, ``slice(start, stop)`` and
        ``slice(start, stop, step)``. A ``TensorVariable`` as start or stop
        is reported through ``unimplemented_v2``.

        Raises:
            AssertionError: if ``items`` has any other length.
        """
        # One shared constant-None variable fills every unspecified slot.
        none_var = variables.ConstantVariable.create(None)
        start = stop = step = none_var

        count = len(items)
        if count == 1:
            (stop,) = items
        elif count == 2:
            start, stop = items
        elif count == 3:
            start, stop, step = items
        else:
            raise AssertionError

        if any(
            isinstance(bound, variables.TensorVariable) for bound in (start, stop)
        ):
            unimplemented_v2(
                gb_type="Dynamic slicing with Tensor arguments",
                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
                explanation="Creating slices with Tensor arguments is not supported. "
                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )
        self.items = (start, stop, step)

        super().__init__(**kwargs)

    def debug_repr(self):
        """Delegate to ``debug_repr_helper`` with ``slice(`` / ``)``."""
        # NOTE(review): debug_repr_helper is not defined in this class;
        # confirm that the base class provides it.
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        """Return a ``slice`` built from each component's ``as_proxy()``."""
        parts = [component.as_proxy() for component in self.items]
        return slice(*parts)

    def python_type(self):
        """Return the Python type this tracker stands for (``slice``)."""
        return slice

    def as_python_constant(self):
        """Return a ``slice`` built by passing each component through
        ``guard_if_dyn``."""
        parts = [guard_if_dyn(component) for component in self.items]
        return slice(*parts)

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the components followed by a ``BUILD_SLICE`` instruction
        whose argument is the component count."""
        codegen.foreach(self.items)
        build_slice = create_instruction("BUILD_SLICE", arg=len(self.items))
        codegen.append_output(build_slice)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve ``start`` / ``stop`` / ``step`` to the stored component.

        Names found in ``cmp_name_to_op_mapping`` yield a
        ``GetAttrVariable``; any other name is reported through
        ``unimplemented_v2``.
        """
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)

        fields = ("start", "stop", "step")
        if name not in fields:
            unimplemented_v2(
                gb_type="Unsupported attribute for slice() object",
                context=f"var_getattr {self} {name}",
                explanation=f"Expected attribute to be one of {','.join(fields)} "
                f"but got {name}",
                hints=[*graph_break_hints.USER_ERROR],
            )
        position = fields.index(name)
        return self.items[position]
|
|
|
|
|
|
class ListIteratorVariable(IteratorVariable):
    """Tracker for an iterator over a list of variables.

    ``items`` is the underlying list and ``index`` is the position of the
    next element to hand out.
    """

    _nonvar_fields = {"index", *IteratorVariable._nonvar_fields}

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        """Store ``items`` (must be a ``list``) and the starting ``index``."""
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # The per-element check below was removed because it slows things
        # down too much:
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492
        #
        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        """Return ``<ClassName>(length=<len(items)>, index=<repr(index)>)``."""
        cls_name = self.__class__.__name__
        return f"{cls_name}(length={len(self.items)}, index={self.index!r})"

    def next_variable(self, tx):
        """Return the item at the current index and advance by one.

        Raises an observed ``StopIteration`` once the index reaches the end.
        The advance is registered with ``tx.output.side_effects.mutation``.
        """
        assert self.is_mutable()
        position = self.index
        if position >= len(self.items):
            raise_observed_exception(StopIteration, tx)

        tx.output.side_effects.mutation(self)
        self.index = self.index + 1
        return self.items[position]

    def call_obj_hasattr(self, tx, name):
        """Answer ``hasattr`` by checking ``name`` on a real list iterator."""
        probe = iter([])
        return variables.ConstantVariable.create(hasattr(probe, name))

    def python_type(self):
        """Return the type of ``iter([])``."""
        return type(iter([]))

    def as_python_constant(self):
        """Return a real iterator over the items' Python constants.

        Raises:
            NotImplementedError: if the iterator has already been advanced
                (``index > 0``).
        """
        if self.index > 0:
            raise NotImplementedError
        constants = [item.as_python_constant() for item in self.items]
        return iter(constants)

    def has_unpack_var_sequence(self, tx):
        """Always True: the remaining items can be unpacked."""
        return True

    def unpack_var_sequence(self, tx):
        """Return the not-yet-consumed items as a new list and mark the
        iterator exhausted by moving ``index`` to the end."""
        remaining = list(self.items[self.index :])
        self.index = len(self.items)
        return remaining

    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
        """Same as ``unpack_var_sequence``."""
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the remaining items, build a tuple of them and apply
        ``GET_ITER`` to it."""
        pending = self.items[self.index :]
        codegen.foreach(pending)
        codegen.extend_output(
            [
                create_build_tuple(len(pending)),
                create_instruction("GET_ITER"),
            ]
        )
|
|
|
|
|
|
class TupleIteratorVariable(ListIteratorVariable):
    """Subclass of ``ListIteratorVariable`` that adds and overrides nothing."""
|
|
|
|
|
|
class RangeIteratorVariable(IteratorVariable):
    """Tracker for an iterator over a ``range``.

    State is kept as plain values: ``start`` is the next value to yield,
    ``step`` the increment, ``len`` the number of values left, and ``stop``
    the original bound (only used when reconstructing).
    """

    # only needed for isinstance(..., range_iterator) to work
    _nonvar_fields = {
        "iter_obj",
    }

    def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
        """Record the range parameters and the remaining length ``len_``."""
        super().__init__(**kwargs)
        self.start = start
        self.stop = stop
        self.step = step
        self.len = len_

    def call_method(self, tx, name, args, kwargs):
        """Handle ``__next__`` and ``__iter__``; defer the rest upward."""
        if name == "__next__":
            return self.next_variable(tx)
        if name == "__iter__":
            return self
        return super().call_method(tx, name, args, kwargs)

    def call_obj_hasattr(self, tx, name):
        """Answer ``hasattr`` by checking ``name`` on a real range iterator
        when ``python_type()`` is ``range_iterator``; otherwise defer upward."""
        if self.python_type() is range_iterator:
            ri = iter(range(0))
            # Use the create() factory, as every other hasattr answer in this
            # file does, rather than calling the constructor directly.
            return ConstantVariable.create(hasattr(ri, name))
        return super().call_obj_hasattr(tx, name)

    def next_variable(self, tx):
        """Return the current value as a constant and advance the state.

        Raises an observed ``StopIteration`` when no values remain
        (``len <= 0``).
        """
        if self.len <= 0:
            raise_observed_exception(StopIteration, tx)

        self.len -= 1
        current = self.start
        self.start += self.step
        return ConstantVariable.create(current)

    def python_type(self):
        """Return ``range_iterator``."""
        return range_iterator

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit code equivalent to ``iter(range(start, stop, step))`` using
        the current (possibly already advanced) ``start``."""
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        for bound in (self.start, self.stop, self.step):
            codegen.append_output(codegen.create_load_const(bound))
        codegen.extend_output(create_call_function(3, False))
        codegen.append_output(create_instruction("GET_ITER"))
|