Files
pytorch/torch/_dynamo/variables/lists.py
morrison-turnansky 86d34a43f5 NamedTuple: Allow side effects for dynamic attributes (#161645)
I confirmed that the tracing was correct, i.e., NamedTupleVariable had the correct dynamic attribute added to it.

The problem was that NamedTupleVariable was always marked as immutable. This does not reflect the behavior of namedtuple.

Subclasses of namedtuple may be mutable, so when a NamedTupleVariable is derived from a subclass that is mutable, I made NamedTupleVariable mutable as well. Then side_effects correctly updates the returned object.

Fixes #161610

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161645
Approved by: https://github.com/anijain2305, https://github.com/StrongerXi
2025-09-09 19:42:02 +00:00

1475 lines
52 KiB
Python

# mypy: ignore-errors
"""
Variable tracking implementations for list-like data structures in Dynamo.
This module provides specialized variable tracking for various collection types:
- Lists and list subclasses (including torch.nn.ModuleList, ParameterList)
- Tuples and named tuples
- Ranges and slices
- collections.deque
- torch.Size with special proxy handling
The implementations support both mutable and immutable collections, iteration,
and common sequence operations. Each collection type has a dedicated Variable
class that handles its unique behaviors while integrating with Dynamo's
variable tracking system.
"""
import collections
import inspect
import operator
import sys
from typing import Optional, TYPE_CHECKING
import torch
import torch.fx
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_rot_n,
)
from ..exc import raise_observed_exception, unimplemented_v2
from ..source import AttrSource, NamedTupleFieldsSource
from ..utils import (
cmp_name_to_op_mapping,
cmp_name_to_op_str_mapping,
get_fake_value,
guard_if_dyn,
iter_contains,
Lit,
namedtuple_fields,
odict_values,
raise_args_mismatch,
range_iterator,
set_example_value,
)
from .base import ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable
if TYPE_CHECKING:
from torch._dynamo.codegen import PyCodegen
from torch._dynamo.symbolic_convert import InstructionTranslator
class BaseListVariable(VariableTracker):
    """Shared base for Dynamo's list-like variable trackers.

    The tracked elements live in ``self.items`` as a Python ``list`` of
    ``VariableTracker`` objects. This class implements the sequence operations
    the subclasses have in common: indexing, containment, ``index``/``count``,
    concatenation, repetition and rich comparisons.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return the tracker class registered for ``type(obj)``."""
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Return the tracker class registered for the Python type ``obj``.

        Raises ``KeyError`` when ``obj`` is not one of the types in the table.
        """
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[obj]

    def __init__(
        self,
        items: list[VariableTracker],
        **kwargs,
    ) -> None:
        """Store ``items``; it must be a ``list`` containing only trackers."""
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: list[VariableTracker] = items

    def _as_proxy(self):
        """Return a plain list holding each element's ``as_proxy()`` result."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Build a new tracker of the same class around ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        """Alias for ``as_python_constant()``."""
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Join the elements' debug reprs with ``", "`` between ``prefix``/``suffix``."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        """Rebuild the concrete container from the elements' constant values."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Return ``python_type()`` constructed from the element proxies."""
        # ``python_type()`` returns a Python type, so the guard has to compare
        # against ``torch.Size`` itself; the previous comparison against the
        # tracker class ``SizeVariable`` could never fail.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice ``self.items`` with ``arg``.

        ``arg`` is either a ``SymNodeVariable`` (its ``sym_num`` is used) or a
        tracker convertible with ``as_python_constant()``. A slice yields a
        source-less clone; an integer yields the stored element tracker. A zero
        slice step raises an observed ``ValueError`` and an out-of-range index
        raises an observed ``IndexError``.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            if index.step == 0:
                msg = ConstantVariable.create("slice step cannot be zero")
                raise_observed_exception(ValueError, tx, args=[msg])
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutation_type=ValueMutationNew() if self.mutation_type else None,
            )
        else:
            assert isinstance(index, (int, torch.SymInt))
            try:
                return self.items[index]
            except IndexError:
                # Wrapped in a ConstantVariable like the other raise sites in
                # this class.
                msg = ConstantVariable.create("list index out of range")
                raise_observed_exception(IndexError, tx, args=[msg])

    def unpack_var_sequence(self, tx):
        """Return a shallow copy of the element trackers."""
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Dispatch the sequence methods shared by all list-like trackers.

        Handles ``__getitem__``, ``__contains__``, ``index``, ``count``,
        ``__add__``/``__iadd__``, ``__mul__``/``__imul__`` and the rich
        comparison methods; anything else is forwarded to the parent class.
        """
        if name == "__getitem__":
            from .tensor import TensorVariable

            if len(args) != 1:
                msg = ConstantVariable.create(
                    f"{name} takes exactly one argument ({len(args)} given)"
                )
                raise_observed_exception(TypeError, tx, args=[msg])

            assert not kwargs and len(args) == 1
            if isinstance(args[0], TensorVariable):
                # A tensor index is only usable when its fake value carries a
                # single-element constant that can be turned into a scalar.
                value = get_fake_value(args[0].as_proxy().node, tx)
                if value.constant is not None and value.constant.numel() == 1:
                    value = variables.ConstantVariable.create(value.constant.item())
                else:
                    unimplemented_v2(
                        gb_type="Indexing list with non-scalar tensor",
                        context=f"call_method {self} {name} {args} {kwargs}",
                        explanation=(
                            "Attempted to index list-like object with tensor with > 1 element."
                        ),
                        hints=[*graph_break_hints.USER_ERROR],
                    )
            else:
                value = args[0]

            if value.python_type() not in (int, slice):
                msg = f"indices must be integers or slices, not {value.python_type()}"
                raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])

            return self.getitem_const(tx, value)
        elif name == "__contains__":
            if len(args) != 1 or kwargs:
                raise_args_mismatch(tx, name)
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
        elif name == "index":
            if not len(args):
                raise_args_mismatch(tx, name)

            return tx.inline_user_function_return(
                VariableTracker.build(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )
        elif name == "count":
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            return VariableTracker.build(tx, operator.countOf).call_function(
                tx,
                [self, args[0]],
                kwargs,
            )
        elif name in ("__add__", "__iadd__"):
            if kwargs or len(args) != 1:
                raise_args_mismatch(tx, name)

            if type(self) != type(args[0]):
                tp_name = self.python_type_name()
                other = args[0].python_type_name()
                msg = ConstantVariable.create(
                    f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
                )
                raise_observed_exception(TypeError, tx, args=[msg])

            if name == "__add__":
                return type(self)(self.items + args[0].items, source=self.source)
            else:
                self.items += args[0].items
                return self
        elif name in ("__mul__", "__imul__"):
            if kwargs or len(args) != 1:
                raise_args_mismatch(tx, name)

            if not (args[0].is_python_constant() and args[0].python_type() is int):
                msg = ConstantVariable.create(
                    f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
                )
                raise_observed_exception(TypeError, tx, args=[msg])

            val = args[0].as_python_constant()

            if name == "__mul__":
                return type(self)(self.items * val, source=self.source)
            else:
                self.items *= val
                return self
        elif name in cmp_name_to_op_mapping:
            if len(args) != 1:
                raise_args_mismatch(tx, name)

            left = self
            right = args[0]
            # TODO this type check logic mirrors the following
            # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007
            # But we should probably move it up the stack so that we don't
            # need to duplicate it for different VTs.
            if not isinstance(left, BaseListVariable) or not isinstance(
                right, BaseListVariable
            ):
                if name == "__eq__":
                    return variables.BuiltinVariable(operator.is_).call_function(
                        tx, (left, right), {}
                    )
                elif name == "__ne__":
                    return variables.BuiltinVariable(operator.is_not).call_function(
                        tx, (left, right), {}
                    )
                else:
                    op_str = cmp_name_to_op_str_mapping[name]
                    left_ty = left.python_type_name()
                    right_ty = right.python_type_name()
                    msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'"
                    raise_observed_exception(
                        TypeError, tx, args=[ConstantVariable.create(msg)]
                    )

            return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
                tx,
                [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
                {},
            )

        return super().call_method(tx, name, args, kwargs)
class RangeVariable(BaseListVariable):
    """Variable tracker for ``range`` objects.

    ``self.items`` always holds exactly three trackers, ``[start, stop, step]``.
    """

    def __init__(self, items, **kwargs) -> None:
        """Build from one to three trackers, following ``range()``'s argument order.

        Missing ``start``/``step`` default to constants 0 and 1. Arguments that
        are ``ConstantVariable`` are re-wrapped after ``int()`` conversion.
        """
        items_to_map = items

        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError

        def maybe_as_int(x):
            return (
                ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
            )

        # cast each argument to an integer
        start = maybe_as_int(start)
        step = maybe_as_int(step)
        stop = maybe_as_int(stop)

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        """Return ``start`` as a Python constant."""
        return self.items[0].as_python_constant()

    def stop(self):
        """Return ``stop`` as a Python constant."""
        return self.items[1].as_python_constant()

    def step(self):
        """Return ``step`` as a Python constant."""
        return self.items[2].as_python_constant()

    def range_length(self):
        """Return the number of values the range produces (0 when empty)."""
        lo = self.start()
        hi = self.stop()
        step = self.step()

        assert step != 0
        if step > 0 and lo < hi:
            return 1 + (hi - 1 - lo) // step
        elif step < 0 and lo > hi:
            return 1 + (lo - 1 - hi) // (0 - step)
        else:
            return 0

    def _get_slice_indices(self, length, slice):
        """Clamp ``slice`` against a sequence of ``length`` elements.

        Returns ``[start, stop, step]`` as positions into the range. A zero
        step is not handled here; ``getitem_const`` rejects it first.
        """
        if slice.step is None:
            step = 1
            step_is_negative = False
        else:
            step = slice.step
            step_is_negative = slice.step < 0

        # Find lower and upper bounds for start and stop.
        if step_is_negative:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        # Compute start
        if slice.start is None:
            start = upper if step_is_negative else lower
        else:
            start = slice.start

        if start < 0:
            start += length
            if start < lower:
                start = lower
        else:
            if start > upper:
                start = upper

        # Compute stop.
        if slice.stop is None:
            stop = lower if step_is_negative else upper
        else:
            stop = slice.stop

        if stop < 0:
            stop += length
            if stop < lower:
                stop = lower
        else:
            if stop > upper:
                stop = upper

        return [start, stop, step]

    def apply_index(self, index):
        """Return the constant at position ``index`` (negative counts from the end).

        Raises an observed ``IndexError`` when the position is out of range.
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if index < 0 or index >= length:
            tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx()
            raise_observed_exception(
                IndexError,
                tx,
                args=[ConstantVariable("range object index out of range")],
            )

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return a new ``RangeVariable`` describing ``self[slice]``."""
        (slice_start, slice_stop, slice_step) = self._get_slice_indices(
            self.range_length(), slice
        )

        def compute_item(index):
            return self.start() + (index * self.step())

        sub_step = self.step() * slice_step
        sub_start = compute_item(slice_start)
        sub_stop = compute_item(slice_stop)

        result = RangeVariable(
            [
                variables.ConstantVariable.create(x)
                for x in [sub_start, sub_stop, sub_step]
            ],
            mutation_type=ValueMutationNew() if self.mutation_type else None,
        )
        return result

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the range with the constant held by ``arg``.

        Raises an observed ``ValueError`` for a zero slice step and an observed
        ``TypeError`` for indices that are neither ``int`` nor ``slice``.
        """
        # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()

        if isinstance(index, slice):
            # ``range(...)[::0]`` raises ValueError in Python; reject it here,
            # the same way BaseListVariable.getitem_const does, instead of
            # building a range with step 0.
            if index.step == 0:
                msg = ConstantVariable.create("slice step cannot be zero")
                raise_observed_exception(ValueError, tx, args=[msg])
            return self.apply_slice(index)
        elif isinstance(index, int):
            return self.apply_index(index)
        else:
            msg = ConstantVariable("range indices must be integers or slices")
            raise_observed_exception(TypeError, tx, args=[msg])

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        """Return one constant tracker per value produced by the range."""
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that calls ``range(start, stop, step)``."""
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        if self.python_type() is not range:
            return super().call_obj_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr(range(0), name))

    def range_equals(self, other: "RangeVariable"):
        """Return True when both ranges produce the same sequence of values.

        Follows CPython's ``range_equals``: lengths must match; two empty
        ranges are equal; otherwise the starts must match, and the steps only
        matter when there is more than one element.
        """
        length = self.range_length()
        if length != other.range_length():
            return False
        if length == 0:
            # e.g. range(0) == range(5, 5)
            return True
        if self.start() != other.start():
            return False
        if length == 1:
            return True
        return self.step() == other.step()

    def range_count(self, x: VariableTracker):
        """Return 1 if the constant held by ``x`` is produced by the range, else 0."""
        # Based on CPython
        # https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
        x = x.as_python_constant()
        if type(x) not in (bool, int, float):
            return 0

        start, stop, step = self.start(), self.stop(), self.step()

        if step == 0:
            return 0

        in_range = (start <= x < stop) if step > 0 else (stop < x <= start)

        if in_range:
            re = ((x - start) % step) == 0
            return int(re)
        return 0

    def call_method(self, tx, name, args, kwargs):
        """Dispatch ``range`` methods; unhandled names go to the parent class."""
        if name == "__iter__":
            if not all(var.is_python_constant() for var in self.items):
                # Can't represent a `range_iterator` without well defined bounds
                return variables.misc.DelayGraphBreakVariable(
                    msg="Cannot create range_iterator: bounds (start, stop, step) must be fully defined as concrete constants.",
                )
            return RangeIteratorVariable(
                self.start(), self.stop(), self.step(), self.range_length()
            )
        elif name == "__len__":
            length = self.range_length()
            if length > sys.maxsize:
                raise_observed_exception(OverflowError, tx)
            return ConstantVariable.create(length)
        elif name == "count":
            return ConstantVariable(self.range_count(*args))
        elif name == "__contains__":
            # ``range.__contains__`` returns a bool, not the 0/1 count.
            return ConstantVariable(bool(self.range_count(*args)))
        elif name == "__getitem__":
            return self.getitem_const(tx, *args)
        elif name in cmp_name_to_op_mapping:
            other = args[0]
            pt = other.python_type()
            if name not in ("__eq__", "__ne__"):
                # Only equality comparisons are handled for ranges; every
                # ordering comparison raises TypeError.
                msg = f"{name} not supported between instances of 'range' and '{pt}'"
                raise_observed_exception(
                    TypeError,
                    tx,
                    args=[ConstantVariable.create(msg)],
                )

            if pt is not range:
                return ConstantVariable.create(NotImplemented)

            cmp = self.range_equals(other)

            # Two ranges are equal if they produce the same sequence of values
            if name == "__eq__":
                return ConstantVariable(cmp)
            else:
                return ConstantVariable(not cmp)
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``start``/``stop``/``step`` as the stored trackers."""
        fields = ["start", "stop", "step"]
        if name in fields:
            return self.items[fields.index(name)]
        return super().var_getattr(tx, name)
class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Dispatch the mutating list methods plus ``copy``.

        Handles ``append``, ``extend``, ``insert``, ``pop``, ``clear``,
        ``__setitem__``, ``__delitem__``, ``reverse`` and ``remove`` when the
        tracker is mutable, and ``copy`` unconditionally. Each in-place change
        to ``self.items`` is recorded with ``tx.output.side_effects.mutation``.
        Everything else is forwarded to the parent class.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.is_mutable():
            assert not kwargs
            if len(args) != 1:
                raise_args_mismatch(tx, name)
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)
        elif name == "extend" and self.is_mutable():
            if len(args) != 1 or kwargs:
                raise_args_mismatch(tx, name)

            if not args[0].has_force_unpack_var_sequence(tx):
                msg = ConstantVariable.create(f"{type(args[0])} object is not iterable")
                raise_observed_exception(TypeError, tx, args=[msg])

            (arg,) = args
            arg.force_apply_to_var_sequence(
                tx, lambda item: self.call_method(tx, "append", [item], {})
            )
            return ConstantVariable.create(None)
        elif name == "insert" and self.is_mutable():
            if kwargs or len(args) != 2:
                raise_args_mismatch(tx, name)
            idx, value = args
            if isinstance(idx, SymNodeVariable):
                const_idx = idx.evaluate_expr()
            else:
                const_idx = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(const_idx, value)
            return ConstantVariable.create(None)
        elif name == "pop" and self.is_mutable():
            # No ``assert not kwargs`` here: it made the kwargs half of this
            # check unreachable.
            if kwargs or len(args) > 1:
                raise_args_mismatch(tx, name)

            if len(self.items) == 0:
                msg = ConstantVariable.create("pop from empty list")
                raise_observed_exception(IndexError, tx, args=[msg])

            if len(args):
                idx = args[0].as_python_constant()
                size = len(self.items)
                # Valid positions are -size .. size-1. Checking both ends here
                # keeps ``list.pop`` below from raising a raw IndexError.
                if isinstance(idx, int) and not -size <= idx < size:
                    msg = ConstantVariable.create("pop index out of range")
                    raise_observed_exception(IndexError, tx, args=[msg])
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])
        elif name == "clear" and self.is_mutable():
            if args or kwargs:
                raise_observed_exception(TypeError, tx)
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and (
                args[0].is_python_constant()
                or isinstance(args[0], SymNodeVariable)
                or (
                    isinstance(args[0], SliceVariable)
                    and all(
                        s.is_python_constant() or isinstance(s, SymNodeVariable)
                        for s in args[0].items
                    )
                )
            )
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SymNodeVariable):
                self.items[key.evaluate_expr()] = value
            elif isinstance(key, SliceVariable):
                if key.is_python_constant():
                    self.items[key.as_python_constant()] = list(value.items)
                else:
                    # Resolve each slice component: symbolic parts are
                    # evaluated, the rest are read as constants.
                    items = slice(
                        *[
                            s.evaluate_expr()
                            if isinstance(s, SymNodeVariable)
                            else s.as_python_constant()
                            for s in key.items
                        ]
                    )
                    self.items[items] = list(value.items)
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif name == "__delitem__" and self.is_mutable():
            if kwargs or len(args) != 1:
                raise_args_mismatch(tx, name)

            tx.output.side_effects.mutation(self)
            if args[0].is_python_constant() and isinstance(
                args[0].as_python_constant(), (int, slice)
            ):
                if isinstance(args[0], SymNodeVariable):
                    idx = args[0].evaluate_expr()
                else:
                    idx = args[0].as_python_constant()

                try:
                    self.items.__delitem__(idx)
                except (IndexError, ValueError) as exc:
                    raise_observed_exception(
                        type(exc),
                        tx,
                        args=list(map(ConstantVariable.create, exc.args)),
                    )
            else:
                msg = ConstantVariable.create(
                    f"list indices must be integers or slices, not {args[0].python_type_name()}"
                )
                raise_observed_exception(TypeError, tx, args=[msg])
            return ConstantVariable.create(None)
        elif name == "copy":
            # List copy() doesn't have args and kwargs
            if args or kwargs:
                raise_args_mismatch(tx, name)

            items = list(self.items)
            return self.modified(items, mutation_type=ValueMutationNew())
        elif name == "reverse" and self.is_mutable():
            if args or kwargs:
                raise_args_mismatch(tx, name)
            self.items.reverse()
            tx.output.side_effects.mutation(self)
            return ConstantVariable.create(None)
        elif name == "remove" and self.is_mutable():
            if len(args) != 1 or kwargs:
                raise_args_mismatch(tx, name)

            # Implemented as index() followed by pop() of that position.
            idx = self.call_method(tx, "index", args, kwargs)
            self.call_method(tx, "pop", [idx], {})
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)
class ListVariable(CommonListMethodsVariable):
    """Variable tracker for ``list`` objects."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the elements followed by a ``BUILD_LIST`` of matching length."""
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle list-specific ``__setitem__``, ``sort`` and ``__init__``.

        All three require the tracker to be mutable; every other call is
        forwarded to ``CommonListMethodsVariable``.
        """
        from .tensor import SymNodeVariable

        if name == "__setitem__" and self.is_mutable():
            if kwargs or len(args) != 2:
                raise_args_mismatch(tx, name)
            key, value = args

            if not key.is_python_constant():
                # probably will graph-break
                # Return the parent's result: without the return, a key the
                # parent handled would fall through and be handled a second
                # time below.
                return super().call_method(tx, name, args, kwargs)

            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                if not value.has_force_unpack_var_sequence(tx):
                    msg = ConstantVariable.create("can only assign an iterable")
                    raise_observed_exception(TypeError, tx, args=[msg])

                key = key.as_python_constant()
                if key.step == 0:
                    msg = ConstantVariable.create("slice step cannot be zero")
                    raise_observed_exception(ValueError, tx, args=[msg])

                value = value.force_unpack_var_sequence(tx)
                try:
                    self.items[key] = value
                except Exception as exc:
                    raise_observed_exception(
                        type(exc),
                        tx,
                        args=list(map(ConstantVariable.create, exc.args)),
                    )
            else:
                if isinstance(key, SymNodeVariable):
                    key = key.evaluate_expr()
                else:
                    key = key.as_python_constant()

                try:
                    self.items[key] = value
                except (IndexError, TypeError) as e:
                    raise_observed_exception(
                        type(e), tx, args=list(map(ConstantVariable.create, e.args))
                    )
            return ConstantVariable.create(None)

        if name == "sort" and self.is_mutable():
            assert len(args) == 0
            key_fn_var = kwargs.pop("key", ConstantVariable.create(None))
            reverse = kwargs.pop(
                "reverse", ConstantVariable.create(False)
            ).as_python_constant()
            assert len(kwargs) == 0

            if (
                key_fn_var.is_python_constant()
                and key_fn_var.as_python_constant() is None
            ):
                keys = self.items.copy()
            else:
                keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items]

            if not all(k.is_python_constant() for k in keys):
                # Report the first offending key, as the message below says;
                # the previous loop never stopped and so picked the last one.
                first_non_constant_key = next(
                    k for k in keys if not k.is_python_constant()
                )

                try:
                    python_type = first_non_constant_key.python_type()
                except NotImplementedError:
                    python_type = "unknown"

                unimplemented_v2(
                    gb_type="sort with non-constant keys",
                    context=str(first_non_constant_key),
                    explanation=(
                        f"Cannot perform sort with non-constant key. "
                        f"First non-constant key type: {python_type}. "
                        f"Most notably, we cannot sort with Tensor or SymInt keys, but we can "
                        f"sort ints."
                    ),
                    hints=["Use something else as the key."],
                )

            tx.output.side_effects.mutation(self)
            sorted_items_with_keys = sorted(
                (
                    (
                        x,
                        k.as_python_constant(),
                        -i if reverse else i,  # extra key to ensure stable sort
                    )
                    for i, (k, x) in enumerate(zip(keys, self.items))
                ),
                key=operator.itemgetter(1, 2),
                reverse=reverse,
            )
            self.items[:] = [x for x, *_ in sorted_items_with_keys]
            return ConstantVariable.create(None)

        if name == "__init__" and self.is_mutable():
            assert not kwargs
            if len(args) == 0:
                return ConstantVariable.create(None)
            elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
                (arg,) = args
                tx.output.side_effects.mutation(self)
                self.items[:] = arg.force_unpack_var_sequence(tx)
                return ConstantVariable.create(None)

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` here; defer every other attribute to the parent."""
        if name == "__class__":
            source = AttrSource(self.source, name) if self.source else None
            class_type = self.python_type()
            if class_type is list:
                return variables.BuiltinVariable(class_type, source=source)
            else:
                return variables.UserDefinedClassVariable(class_type, source=source)
        return super().var_getattr(tx, name)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        if self.python_type() is not list:
            return super().call_obj_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr([], name))
class DequeVariable(CommonListMethodsVariable):
    """Variable tracker for ``collections.deque``.

    Elements are kept in a plain list (``self.items``); ``self.maxlen`` is a
    constant tracker holding the bound, or ``None`` for an unbounded deque.
    """

    def __init__(self, items, maxlen=None, **kwargs) -> None:
        """Store ``items``, keeping only the rightmost ``maxlen`` when bounded."""
        if maxlen is None:
            maxlen = ConstantVariable.create(None)
        assert maxlen.is_python_constant(), (
            f"maxlen must be a constant, got: {maxlen.debug_repr()}"
        )
        self.maxlen = maxlen
        items = list(items)
        limit = self.maxlen.as_python_constant()
        if limit is not None:
            # ``items[-0:]`` is the whole list, so a zero bound needs its own
            # case: deque(iterable, maxlen=0) is always empty.
            items = items[-limit:] if limit else []
        super().__init__(items, **kwargs)

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        """Render like ``repr(deque)``: ``maxlen=`` is shown only when bounded."""
        if self.maxlen.as_python_constant() is not None:
            return self.debug_repr_helper(
                "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
            )
        return self.debug_repr_helper("deque([", "])")

    def as_python_constant(self):
        return self.python_type()(
            [x.as_python_constant() for x in self.items],
            maxlen=self.maxlen.as_python_constant(),
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode for ``collections.deque([*items], maxlen=maxlen)``."""
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
        codegen(self.maxlen)
        codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "maxlen":
            return self.maxlen
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle the deque-specific methods and enforce ``maxlen`` afterwards.

        ``__setitem__`` (constant int key), ``extendleft``, ``popleft``,
        ``appendleft`` and ``insert`` are handled here; other names go to
        ``CommonListMethodsVariable``. After the call, a bounded deque that
        grew past ``maxlen`` is trimmed from the end opposite the insertion.
        """
        if (
            name == "__setitem__"
            and self.is_mutable()
            and args
            and args[0].is_python_constant()
        ):
            assert len(args) == 2
            assert not kwargs
            key, value = args
            assert key.is_python_constant()
            assert isinstance(key.as_python_constant(), int)
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)

        maxlen = self.maxlen.as_python_constant()
        # Inserting on the left evicts from the right; every other growth path
        # evicts from the left.
        evict_from_right = False

        if (
            name == "extendleft"
            and self.is_mutable()
            and len(args) > 0
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert len(args) == 1
            assert not kwargs
            # NOTE this is inefficient, but the alternative is to represent self.items
            # as a deque, which is a more intrusive change.
            args[0].force_apply_to_var_sequence(
                tx, lambda item: self.call_method(tx, "appendleft", [item], {})
            )
            evict_from_right = True
            result = ConstantVariable.create(None)
        elif name == "popleft" and self.is_mutable():
            assert not args
            assert not kwargs
            if not self.items:
                # Python raises IndexError here; unpacking an empty list would
                # surface as a raw ValueError instead.
                msg = ConstantVariable.create("pop from an empty deque")
                raise_observed_exception(IndexError, tx, args=[msg])
            tx.output.side_effects.mutation(self)
            result = self.items.pop(0)
        elif name == "appendleft" and len(args) > 0 and self.is_mutable():
            assert len(args) == 1
            assert not kwargs
            tx.output.side_effects.mutation(self)
            self.items.insert(0, args[0])
            evict_from_right = True
            result = ConstantVariable.create(None)
        elif name == "insert" and len(args) > 0 and self.is_mutable():
            assert len(args) == 2
            assert not kwargs
            if maxlen is not None and len(self.items) == maxlen:
                msg = ConstantVariable.create("deque already at its maximum size")
                raise_observed_exception(IndexError, tx, args=[msg])
            result = super().call_method(tx, name, args, kwargs)
        else:
            result = super().call_method(tx, name, args, kwargs)

        if maxlen is not None and len(self.items) > maxlen:
            if evict_from_right:
                self.items[:] = self.items[:maxlen]
            else:
                # ``self.items[-0:]`` would keep everything, so handle the
                # zero bound explicitly.
                self.items[:] = self.items[-maxlen:] if maxlen else []
        return result
class TupleVariable(BaseListVariable):
    """Variable tracker for ``tuple`` objects."""

    def python_type(self):
        return tuple

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        count = len(self.items)
        return f"{cls_name}(length={count})"

    def debug_repr(self):
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the elements followed by a ``BUILD_TUPLE`` of matching length."""
        codegen.foreach(self.items)
        build_tuple = create_instruction("BUILD_TUPLE", arg=len(self.items))
        codegen.append_output(build_tuple)

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Forward every method call to ``BaseListVariable`` unchanged."""
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` here; defer every other attribute to the parent."""
        if name != "__class__":
            return super().var_getattr(tx, name)

        source = AttrSource(self.source, name) if self.source else None
        class_type = self.python_type()
        # A plain tuple is a builtin; any other type is wrapped as a
        # user-defined class.
        if class_type is not tuple:
            return variables.UserDefinedClassVariable(class_type, source=source)
        return variables.BuiltinVariable(class_type, source=source)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        """Answer ``hasattr`` from an empty tuple when the type is exactly ``tuple``."""
        if self.python_type() is tuple:
            return variables.ConstantVariable.create(hasattr((), name))
        return super().call_obj_hasattr(tx, name)
class SizeVariable(TupleVariable):
    """torch.Size(...)"""

    _nonvar_fields = {"proxy", *TupleVariable._nonvar_fields}

    def __init__(
        self,
        items: list[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        """Store an optional pre-built ``proxy`` alongside the dimension trackers."""
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        """Return a proxy (or concrete ``torch.Size``) for this size.

        A stored ``self.proxy`` wins. Otherwise, if none of the element proxies
        is an FX ``Proxy``, a concrete ``torch.Size`` is returned; if at least
        one is, a ``call_function`` node for ``torch.Size`` is created and given
        an example value.
        """
        if self.proxy is not None:
            return self.proxy

        # torch.Size cannot be handled like the other containers. Usually a
        # list-like container is represented by putting the FX Proxy/Node
        # objects directly inside it, because FX looks inside containers (via
        # map_aggregate). torch.Size subclasses tuple but rejects members that
        # are not int-like, Proxy and Node included, so
        # torch.Size([proxy0, proxy1]) is not an option. Relaxing the
        # constructor is not backward compatible either: for an unrecognized
        # type it tries to call __index__() on it.
        #
        # The workaround is to represent the size as a single proxy built by
        # calling torch.Size on the constituent proxies. The same trick works
        # for any construct that needs a proxy but cannot be expressed as an
        # aggregate.

        # Use the tracer of the first real Proxy; with none present, fall back
        # to the legacy behavior of returning a concrete torch.Size.
        element_proxies = self._as_proxy()
        tracer = next(
            (p.tracer for p in element_proxies if isinstance(p, torch.fx.Proxy)),
            None,
        )
        if tracer is None:
            return torch.Size(element_proxies)

        size_proxy = tracer.create_proxy(
            "call_function", torch.Size, (element_proxies,), {}
        )
        example_dims = [
            p if isinstance(p, int) else p.node.meta["example_value"]
            for p in element_proxies
        ]
        set_example_value(size_proxy.node, torch.Size(example_dims))
        return size_proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that calls ``torch.Size`` on a tuple of the elements."""
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.items)),
                *create_call_function(1, False),
            ]
        )

    def unpack_var_sequence(self, tx):
        return self.items[:]

    def numel(self, tx):
        """Return a tracker for the product of all dimensions.

        Constant dimensions are multiplied eagerly; symbolic ones are multiplied
        through ``operator.mul`` only when they can affect the result.
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        const_product = 1
        symbolic_dims = []
        for dim in self.items:
            if not isinstance(dim, ConstantVariable):
                assert isinstance(dim, SymNodeVariable), type(dim)
                # Delay proxy calls until we know they will be necessary
                symbolic_dims.append(dim)
                continue
            const_product *= dim.value

        total = ConstantVariable.create(const_product)
        if symbolic_dims and const_product == 1:
            # Start from the first symbolic dimension instead of multiplying by 1
            total = symbolic_dims[0]
            symbolic_dims = symbolic_dims[1:]

        # Nothing symbolic left to fold in, or the product is already zero.
        if const_product == 0 or not symbolic_dims:
            return total

        multiply = BuiltinVariable(operator.mul)
        for dim in symbolic_dims:
            total = multiply.call_function(tx, [total, dim], {})
        return total

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``numel`` and ``__getitem__``; defer the rest to ``TupleVariable``."""
        if name == "numel":
            assert not args and not kwargs
            return self.numel(tx)
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.get_item_dyn(tx, args[0])
        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the size with a constant or symbolic index.

        A slice produces a new ``SizeVariable``; an integer (or ``SymInt``)
        returns the stored element tracker.
        """
        from .tensor import SymNodeVariable

        index = (
            arg.sym_num if isinstance(arg, SymNodeVariable) else arg.as_python_constant()
        )
        if isinstance(index, slice):
            return SizeVariable(self.items[index])

        assert isinstance(index, (int, torch.SymInt))
        return self.items[index]

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(torch.Size, name))
class NamedTupleVariable(TupleVariable):
_nonvar_fields = {
"tuple_cls",
"dynamic_attributes",
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
    """Track ``items`` as an instance of ``tuple_cls``.

    ``dynamic_attributes`` maps attribute names to trackers; a missing or
    empty mapping is replaced by a fresh dict.
    """
    super().__init__(items, **kwargs)
    self.tuple_cls = tuple_cls
    self.dynamic_attributes = dynamic_attributes or {}
def is_namedtuple(self):
    """Return True when ``tuple_cls`` has a tuple ``_fields`` and a callable ``_make``."""
    field_names = getattr(self.tuple_cls, "_fields", None)
    if not isinstance(field_names, tuple):
        return False
    maker = getattr(self.tuple_cls, "_make", None)
    return callable(maker)
def is_structseq(self):
    """Return True when ``tuple_cls`` is not recognised by ``is_namedtuple()``."""
    return not self.is_namedtuple()
def fields(self):
    """Return ``namedtuple_fields(self.tuple_cls)``."""
    return namedtuple_fields(self.tuple_cls)
def debug_repr(self):
    """Return ``repr`` of a ``tuple_cls`` instance built from the items' debug reprs."""
    rendered = [Lit(item.debug_repr()) for item in self.items]
    if not self.is_structseq():
        # NamedTupleType(*iterable)
        return repr(self.tuple_cls(*rendered))
    # StructSequenceType(iterable)
    return repr(self.tuple_cls(rendered))
def python_type(self):
    """Return the tracked class, ``self.tuple_cls``."""
    return self.tuple_cls
def as_python_constant(self):
if self.is_structseq():
# StructSequenceType(iterable)
result = self.python_type()([x.as_python_constant() for x in self.items])
else:
# NamedTupleType(*iterable)
result = self.python_type()(*[x.as_python_constant() for x in self.items])
# Apply dynamic attributes if any were set
if self.dynamic_attributes:
for attr_name, attr_value in self.dynamic_attributes.items():
# Convert VariableTracker to Python constant if needed
if hasattr(attr_value, "as_python_constant"):
python_value = attr_value.as_python_constant()
else:
raise NotImplementedError(
"Can not convert dynamic attribute without python constant value to python constant."
)
setattr(result, attr_name, python_value)
return result
def as_proxy(self):
assert self.python_type() is not SizeVariable
if self.is_structseq():
# StructSequenceType(iterable)
return self.python_type()(self._as_proxy())
# NamedTupleType(*iterable)
return self.python_type()(*self._as_proxy())
def reconstruct(self, codegen: "PyCodegen") -> None:
# Always reconstruct the NamedTuple normally first
# Constructors:
# StructSequenceType(iterable)
# NamedTupleType(*iterable)
# NamedTupleType._make(iterable)
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_const_unchecked(create_fn)
)
)
codegen.foreach(self.items)
codegen.extend_output(
[
create_instruction("BUILD_TUPLE", arg=len(self.items)),
]
+ create_call_function(1, False)
)
for name, value in self.dynamic_attributes.items():
codegen.dup_top()
codegen(value)
codegen.extend_output(create_rot_n(2))
codegen.store_attr(name)
def call_method(
self,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__setattr__":
assert len(args) == 2
assert len(kwargs) == 0
attr, value = args
attr = attr.as_python_constant()
if (
# structseq is immutable
self.is_structseq()
# namedtuple directly created by `collections.namedtuple` is immutable
or self.tuple_cls.__bases__ == (tuple,)
# fields are immutable
or attr in self.fields()
):
raise_observed_exception(AttributeError, tx)
# Subclass of namedtuple type can have dynamic attributes
tx.output.side_effects.mutation(self)
if self.source:
tx.output.side_effects.store_attr(self, attr, value)
self.dynamic_attributes[attr] = value
return ConstantVariable.create(None)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
method = inspect.getattr_static(self.tuple_cls, name, None)
if isinstance(method, classmethod):
# We need the unbounded cls method to avoid the inline __self__
return UserMethodVariable(
method.__func__,
variables.UserDefinedClassVariable(self.tuple_cls),
)
elif isinstance(method, staticmethod):
return UserFunctionVariable(method.__func__)
elif inspect.isfunction(method):
return UserMethodVariable(method, self)
else:
return None
if name == "_fields":
source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=source)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]
fields = self.fields()
if name not in fields:
method = check_and_create_method()
if not method:
return super().var_getattr(tx, name)
return method
return self.items[fields.index(name)]
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
return variables.ConstantVariable.create(
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
)
class SliceVariable(VariableTracker):
    """Tracks a ``slice`` object as a ``(start, stop, step)`` triple of variables."""

    def __init__(self, items, **kwargs) -> None:
        """Accept one to three items, following ``slice``'s argument conventions.

        Missing parts default to a constant ``None`` variable. Any other
        number of items raises ``AssertionError``. A tensor variable as
        start or stop is reported through ``unimplemented_v2``.
        """
        # One shared None constant fills whichever parts were not supplied.
        none_var = variables.ConstantVariable.create(None)
        count = len(items)
        if count == 1:
            (stop,) = items
            start = step = none_var
        elif count == 2:
            start, stop = items
            step = none_var
        elif count == 3:
            start, stop, step = items
        else:
            raise AssertionError

        if any(isinstance(bound, variables.TensorVariable) for bound in (start, stop)):
            unimplemented_v2(
                gb_type="Dynamic slicing with Tensor arguments",
                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
                explanation="Creating slices with Tensor arguments is not supported. "
                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )
        self.items = (start, stop, step)
        super().__init__(**kwargs)

    def debug_repr(self):
        """Return the debug repr produced by ``debug_repr_helper`` with ``slice(`` / ``)``."""
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        """Return a ``slice`` built from each part's proxy."""
        proxies = [part.as_proxy() for part in self.items]
        return slice(*proxies)

    def python_type(self):
        """Return ``slice``."""
        return slice

    def as_python_constant(self):
        """Return a ``slice`` built by passing each part through ``guard_if_dyn``."""
        return slice(*map(guard_if_dyn, self.items))

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the parts followed by a ``BUILD_SLICE`` over all of them."""
        codegen.foreach(self.items)
        build_slice = create_instruction("BUILD_SLICE", arg=len(self.items))
        codegen.append_output(build_slice)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve ``start``/``stop``/``step`` or a comparison method name.

        Names found in ``cmp_name_to_op_mapping`` become a
        ``GetAttrVariable``; any other unknown name is reported through
        ``unimplemented_v2``.
        """
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)

        slice_fields = ("start", "stop", "step")
        if name not in slice_fields:
            unimplemented_v2(
                gb_type="Unsupported attribute for slice() object",
                context=f"var_getattr {self} {name}",
                explanation=f"Expected attribute to be one of {','.join(slice_fields)} "
                f"but got {name}",
                hints=[*graph_break_hints.USER_ERROR],
            )
        # The field order matches the order of self.items.
        return self.items[slice_fields.index(name)]
class ListIteratorVariable(IteratorVariable):
    """Iterator over a list of variables, tracked as ``items`` plus a cursor ``index``."""

    _nonvar_fields = {
        *IteratorVariable._nonvar_fields,
        "index",
    }

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # Removing this check as it slows things down too much
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492
        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(length={len(self.items)}, index={self.index!r})"

    def next_variable(self, tx):
        """Return the item at the cursor and advance it.

        Raises an observed ``StopIteration`` once the cursor has reached
        the end of ``items``. The mutation is registered with the side
        effects tracker before the cursor moves.
        """
        assert self.is_mutable()
        position = self.index
        if position >= len(self.items):
            raise_observed_exception(StopIteration, tx)
        tx.output.side_effects.mutation(self)
        self.index += 1
        return self.items[position]

    def call_obj_hasattr(self, tx, name):
        """Return a constant telling whether a real list iterator has ``name``."""
        probe = iter([])
        return variables.ConstantVariable.create(hasattr(probe, name))

    def python_type(self):
        """Return the type of ``iter([])``."""
        return type(iter([]))

    def as_python_constant(self):
        """Return a real iterator over the items' constants.

        Only supported while nothing has been consumed; a cursor greater
        than zero raises ``NotImplementedError``.
        """
        if self.index > 0:
            raise NotImplementedError
        constants = [item.as_python_constant() for item in self.items]
        return iter(constants)

    def has_unpack_var_sequence(self, tx):
        """Unpacking is always supported."""
        return True

    def unpack_var_sequence(self, tx):
        """Return the not-yet-consumed items and move the cursor to the end."""
        remaining = list(self.items[self.index :])
        self.index = len(self.items)
        return remaining

    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
        """Same as ``unpack_var_sequence``."""
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit the remaining items, then ``BUILD_TUPLE`` and ``GET_ITER``."""
        pending = self.items[self.index :]
        codegen.foreach(pending)
        instructions = [
            create_instruction("BUILD_TUPLE", arg=len(pending)),
            create_instruction("GET_ITER"),
        ]
        codegen.extend_output(instructions)
class TupleIteratorVariable(ListIteratorVariable):
    """Distinct class name that adds nothing to ``ListIteratorVariable``."""
class RangeIteratorVariable(IteratorVariable):
    """Tracks an iterator over a ``range`` as plain integer state.

    ``start`` is the next value to be produced, ``step`` the increment and
    ``len`` the number of values still to be produced. ``stop`` is stored
    and used only when the iterator is reconstructed.
    """

    # only needed for isinstance(..., range_iterator) to work
    _nonvar_fields = {
        "iter_obj",
        # Keep the inherited non-variable fields, as the sibling variable
        # classes in this file do.
        *IteratorVariable._nonvar_fields,
    }

    def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
        super().__init__(**kwargs)
        self.start = start
        self.stop = stop
        self.step = step
        self.len = len_

    def call_method(self, tx, name, args, kwargs):
        """Handle ``__next__`` and ``__iter__``; defer everything else to the parent."""
        if name == "__next__":
            return self.next_variable(tx)
        elif name == "__iter__":
            return self
        return super().call_method(tx, name, args, kwargs)

    def call_obj_hasattr(self, tx, name):
        """Return a constant telling whether a real range iterator has ``name``."""
        if self.python_type() is range_iterator:
            ri = iter(range(0))
            # Use the `create` factory like the rest of this file.
            return ConstantVariable.create(hasattr(ri, name))
        return super().call_obj_hasattr(tx, name)

    def next_variable(self, tx):
        """Return the next value as a constant and advance the iterator.

        Raises an observed ``StopIteration`` when no values remain.
        """
        if self.len <= 0:
            raise_observed_exception(StopIteration, tx)
        self.len -= 1
        current = self.start
        self.start += self.step
        return ConstantVariable.create(current)

    def python_type(self):
        """Return ``range_iterator``."""
        return range_iterator

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode for ``iter(range(start, stop, step))``.

        ``start`` has already been advanced by ``next_variable``, so the
        rebuilt iterator begins at the current position.
        """
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.append_output(codegen.create_load_const(self.start))
        codegen.append_output(codegen.create_load_const(self.stop))
        codegen.append_output(codegen.create_load_const(self.step))
        codegen.extend_output(create_call_function(3, False))
        codegen.append_output(create_instruction("GET_ITER"))