Compare commits

...

1 Commits

Author SHA1 Message Date
f170a2f521 Type lists.py 2025-11-06 12:01:33 -08:00
4 changed files with 214 additions and 156 deletions

View File

@ -3228,7 +3228,7 @@ class InstructionTranslatorBase(
def BUILD_SLICE(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
self.push(SliceVariable(items, tx=self))
self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type]
def BUILD_LIST(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
@ -3607,7 +3607,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, ListVariable)
assert obj.is_mutable()
obj.call_method(self, "extend", [v], {})
obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type]
def LIST_TO_TUPLE(self, inst: Instruction) -> None:
self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
@ -3673,7 +3673,7 @@ class InstructionTranslatorBase(
def MATCH_KEYS(self, inst: Instruction) -> None:
tos = self.stack[-1]
assert isinstance(tos, TupleVariable)
keys = tos.unpack_var_sequence(self)
keys = tos.unpack_var_sequence(self) # type: ignore[arg-type]
tos1 = self.stack[-2]
assert isinstance(tos1, ConstDictVariable)

View File

@ -1513,7 +1513,7 @@ class WithExitFunctionVariable(VariableTracker):
# Note here we reconstruct the context manager rather than the
# exit function. The handler generated by BlockStackEntry
# will re-enter the context in the resume function.
self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined]
self.ctx.reconstruct_type(codegen) # type: ignore[union-attr]
if codegen.tx.output.partial_convert:
if sys.version_info >= (3, 11):
codegen.append_output(create_instruction("PUSH_NULL"))
@ -1522,10 +1522,10 @@ class WithExitFunctionVariable(VariableTracker):
# We rely on classes subtyping `GenericContextWrappingVariable`
# to implement these fns and have these attributes
codegen.extend_output(
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type]
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr]
)
codegen.extend_output(
create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type]
create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr]
)
codegen.append_output(create_setup_with(self.target))
codegen.append_output(create_instruction("POP_TOP"))

View File

@ -82,7 +82,8 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.product(*seqs, repeat=r)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif (
self.value is itertools.combinations
@ -98,7 +99,8 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.combinations(iterable, r):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
@ -181,7 +183,8 @@ class ItertoolsVariable(VariableTracker):
from_exc=e,
)
return variables.ListIteratorVariable(
result, mutation_type=ValueMutationNew()
result, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif self.value is itertools.repeat:
if len(args) < 2:
@ -212,7 +215,8 @@ class ItertoolsVariable(VariableTracker):
)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
else:
return super().call_function(tx, args, kwargs)

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Variable tracking implementations for list-like data structures in Dynamo.
@ -20,7 +18,7 @@ import collections
import inspect
import operator
import sys
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, Sequence, TYPE_CHECKING
import torch
import torch.fx
@ -60,11 +58,11 @@ if TYPE_CHECKING:
class BaseListVariable(VariableTracker):
@staticmethod
def cls_for_instance(obj):
def cls_for_instance(obj: Any) -> type["BaseListVariable"]:
return BaseListVariable.cls_for(type(obj))
@staticmethod
def cls_for(obj):
def cls_for(obj: Any) -> type:
return {
iter: ListIteratorVariable,
list: ListVariable,
@ -80,34 +78,38 @@ class BaseListVariable(VariableTracker):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
self.items: list[VariableTracker] = items
def _as_proxy(self):
def _as_proxy(self) -> list[Any]:
return [x.as_proxy() for x in self.items]
def modified(self, items, **kwargs):
def modified(
self, items: list[VariableTracker], **kwargs: Any
) -> "BaseListVariable":
return type(self)(items, **kwargs)
@property
def value(self):
def value(self) -> Any:
return self.as_python_constant()
def debug_repr_helper(self, prefix, suffix):
def debug_repr_helper(self, prefix: str, suffix: str) -> str:
return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix
def as_python_constant(self):
def as_python_constant(self) -> Any:
return self.python_type()([x.as_python_constant() for x in self.items])
def as_proxy(self):
def as_proxy(self) -> Any:
assert self.python_type() is not SizeVariable
return self.python_type()(self._as_proxy())
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@ -134,16 +136,16 @@ class BaseListVariable(VariableTracker):
IndexError, tx, args=["list index out of range"]
)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return list(self.items)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
from .tensor import TensorVariable
@ -224,15 +226,15 @@ class BaseListVariable(VariableTracker):
if type(self) is not type(args[0]):
tp_name = self.python_type_name()
other = args[0].python_type_name()
msg = ConstantVariable.create(
msg_vt = ConstantVariable.create(
f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[msg_vt])
if name == "__add__":
return type(self)(self.items + args[0].items, source=self.source)
return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined]
else:
self.items += args[0].items
self.items += args[0].items # type: ignore[attr-defined]
return self
elif name in ("__mul__", "__imul__"):
if kwargs or len(args) != 1:
@ -244,10 +246,10 @@ class BaseListVariable(VariableTracker):
)
if not (args[0].is_python_constant() and args[0].python_type() is int):
msg = ConstantVariable.create(
msg_vt = ConstantVariable.create(
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[msg_vt])
val = args[0].as_python_constant()
@ -301,7 +303,7 @@ class BaseListVariable(VariableTracker):
class RangeVariable(BaseListVariable):
def __init__(self, items, **kwargs) -> None:
def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None:
items_to_map = items
start = variables.ConstantVariable.create(0)
stop = None
@ -316,7 +318,7 @@ class RangeVariable(BaseListVariable):
else:
raise AssertionError
def maybe_as_int(x):
def maybe_as_int(x: VariableTracker) -> VariableTracker:
return (
ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
)
@ -329,22 +331,22 @@ class RangeVariable(BaseListVariable):
assert stop is not None
super().__init__([start, stop, step], **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("range(", ")")
def python_type(self):
def python_type(self) -> type:
return range
def start(self):
def start(self) -> Any:
return self.items[0].as_python_constant()
def stop(self):
def stop(self) -> Any:
return self.items[1].as_python_constant()
def step(self):
def step(self) -> Any:
return self.items[2].as_python_constant()
def range_length(self):
def range_length(self) -> int:
lo = self.start()
hi = self.stop()
step = self.step()
@ -357,7 +359,7 @@ class RangeVariable(BaseListVariable):
else:
return 0
def _get_slice_indices(self, length, slice):
def _get_slice_indices(self, length: int, slice: slice) -> list[int]:
step_is_negative = 0
if slice.step is None:
@ -406,7 +408,7 @@ class RangeVariable(BaseListVariable):
return [start, stop, step]
def apply_index(self, index):
def apply_index(self, index: int) -> VariableTracker:
length = self.range_length()
if index < 0:
index = length + index
@ -421,12 +423,12 @@ class RangeVariable(BaseListVariable):
return variables.ConstantVariable.create(self.start() + (index * self.step()))
def apply_slice(self, slice):
def apply_slice(self, slice: slice) -> "RangeVariable":
(slice_start, slice_stop, slice_step) = self._get_slice_indices(
self.range_length(), slice
)
def compute_item(index):
def compute_item(index: int) -> int:
return self.start() + (index * self.step())
sub_step = self.step() * slice_step
@ -442,10 +444,12 @@ class RangeVariable(BaseListVariable):
)
return result
def as_python_constant(self):
def as_python_constant(self) -> range:
return range(*[x.as_python_constant() for x in self.items])
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
index = arg.as_python_constant()
@ -457,28 +461,30 @@ class RangeVariable(BaseListVariable):
msg = ConstantVariable("range indices must be integers or slices")
raise_observed_exception(TypeError, tx, args=[msg])
def as_proxy(self):
def as_proxy(self) -> range:
return self.python_type()(*self._as_proxy())
def unpack_var_sequence(self, tx=None):
def unpack_var_sequence(
self, tx: Optional["InstructionTranslator"] = None
) -> list[VariableTracker]:
return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]
def reconstruct(self, codegen: "PyCodegen") -> None:
assert "range" not in codegen.tx.f_globals
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range))
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
)
codegen.foreach(self.items)
codegen.extend_output(create_call_function(3, False))
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is range:
return variables.ConstantVariable.create(name in range.__dict__)
return super().call_obj_hasattr(tx, name)
def range_equals(self, other: "RangeVariable"):
def range_equals(self, other: "RangeVariable") -> bool:
r0, r1 = self, other
if (
self.range_length() != r1.range_length()
@ -487,12 +493,12 @@ class RangeVariable(BaseListVariable):
):
return False
if len(r0) == 1:
if self.range_length() == 1:
return True
return r0.step() == r1.step()
def range_count(self, x: VariableTracker):
def range_count(self, x: VariableTracker) -> int:
# Based on CPython
# https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
x = x.as_python_constant()
@ -511,7 +517,13 @@ class RangeVariable(BaseListVariable):
return int(re)
return 0
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__iter__":
if not all(var.is_python_constant() for var in self.items):
# Can't represent a `range_iterator` without well defined bounds
@ -545,7 +557,10 @@ class RangeVariable(BaseListVariable):
if pt is not range:
return ConstantVariable.create(NotImplemented)
cmp = self.range_equals(other)
if isinstance(other, RangeVariable):
cmp = self.range_equals(other)
else:
cmp = False
# Two ranges are equal if they produce the same sequence of values
if name == "__eq__":
@ -554,7 +569,7 @@ class RangeVariable(BaseListVariable):
return ConstantVariable(not cmp)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
fields = ["start", "stop", "step"]
if name in fields:
return self.items[fields.index(name)]
@ -568,11 +583,11 @@ class CommonListMethodsVariable(BaseListVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .tensor import SymNodeVariable
if name == "append" and self.is_mutable():
@ -676,9 +691,9 @@ class CommonListMethodsVariable(BaseListVariable):
self.items[key.evaluate_expr()] = value
elif isinstance(key, SliceVariable):
if key.is_python_constant():
self.items[key.as_python_constant()] = list(value.items)
self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined]
else:
items = slice(
items_slice = slice(
*[
(
s.evaluate_expr()
@ -688,7 +703,7 @@ class CommonListMethodsVariable(BaseListVariable):
for s in key.items
]
)
self.items[items] = list(value.items)
self.items[items_slice] = list(value.items) # type: ignore[attr-defined]
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
@ -733,8 +748,8 @@ class CommonListMethodsVariable(BaseListVariable):
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
items = list(self.items)
return self.modified(items, mutation_type=ValueMutationNew())
items_lst: list[VariableTracker] = list(self.items)
return self.modified(items_lst, mutation_type=ValueMutationNew())
elif name == "reverse" and self.is_mutable():
if args or kwargs:
raise_args_mismatch(
@ -763,13 +778,13 @@ class CommonListMethodsVariable(BaseListVariable):
class ListVariable(CommonListMethodsVariable):
def python_type(self):
def python_type(self) -> type:
return list
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("[", "]")
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -778,11 +793,11 @@ class ListVariable(CommonListMethodsVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .tensor import SymNodeVariable
if name == "__setitem__" and self.is_mutable():
@ -805,14 +820,14 @@ class ListVariable(CommonListMethodsVariable):
msg = ConstantVariable.create("can only assign an iterable")
raise_observed_exception(TypeError, tx, args=[msg])
key = key.as_python_constant()
if key.step == 0:
key_as_const = key.as_python_constant()
if key_as_const.step == 0:
msg = ConstantVariable.create("slice step cannot be zero")
raise_observed_exception(ValueError, tx, args=[msg])
value = value.force_unpack_var_sequence(tx)
value_unpack = value.force_unpack_var_sequence(tx)
try:
self.items[key] = value
self.items[key_as_const] = value_unpack
except Exception as exc:
raise_observed_exception(
type(exc),
@ -859,7 +874,7 @@ class ListVariable(CommonListMethodsVariable):
assert first_non_constant_key is not None
try:
python_type = first_non_constant_key.python_type()
python_type = str(first_non_constant_key.python_type())
except NotImplementedError:
python_type = "unknown"
@ -904,7 +919,7 @@ class ListVariable(CommonListMethodsVariable):
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@ -916,14 +931,19 @@ class ListVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is not list:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr([], name))
class DequeVariable(CommonListMethodsVariable):
def __init__(self, items, maxlen=None, **kwargs) -> None:
def __init__(
self,
items: list[VariableTracker],
maxlen: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
if maxlen is None:
maxlen = ConstantVariable.create(None)
assert maxlen.is_python_constant(), (
@ -935,17 +955,17 @@ class DequeVariable(CommonListMethodsVariable):
items = items[-maxlen.as_python_constant() :]
super().__init__(items, **kwargs)
def python_type(self):
def python_type(self) -> type:
return collections.deque
def debug_repr(self):
def debug_repr(self) -> str:
if self.maxlen.as_python_constant() is None:
return self.debug_repr_helper(
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
)
return self.debug_repr_helper("deque([", "])")
def as_python_constant(self):
def as_python_constant(self) -> collections.deque[Any]:
return self.python_type()(
[x.as_python_constant() for x in self.items],
maxlen=self.maxlen.as_python_constant(),
@ -954,7 +974,7 @@ class DequeVariable(CommonListMethodsVariable):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_python_module(collections.deque)
codegen.create_load_python_module(collections.deque) # type: ignore[arg-type]
)
)
codegen.foreach(self.items)
@ -962,18 +982,18 @@ class DequeVariable(CommonListMethodsVariable):
codegen(self.maxlen)
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "maxlen":
return self.maxlen
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if (
name == "__setitem__"
and self.is_mutable()
@ -1068,20 +1088,20 @@ class DequeVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is collections.deque:
return variables.ConstantVariable.create(name in collections.deque.__dict__)
return super().call_obj_hasattr(tx, name)
class TupleVariable(BaseListVariable):
def python_type(self):
def python_type(self) -> type[tuple]: # type: ignore[type-arg]
return tuple
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("(", ")")
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -1090,14 +1110,14 @@ class TupleVariable(BaseListVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@ -1109,7 +1129,7 @@ class TupleVariable(BaseListVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is not tuple:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr((), name))
@ -1127,18 +1147,18 @@ class SizeVariable(TupleVariable):
self,
items: list[VariableTracker],
proxy: Optional[torch.fx.Proxy] = None,
**kwargs,
**kwargs: Any,
) -> None:
self.proxy = proxy
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("torch.Size([", "])")
def python_type(self):
def python_type(self) -> type:
return torch.Size
def as_proxy(self):
def as_proxy(self) -> Any:
if self.proxy is not None:
return self.proxy
@ -1193,10 +1213,10 @@ class SizeVariable(TupleVariable):
] + create_call_function(1, False)
codegen.extend_output(build_torch_size)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return list(self.items)
def numel(self, tx):
def numel(self, tx: "InstructionTranslator") -> VariableTracker:
from .builtin import BuiltinVariable
from .tensor import SymNodeVariable
@ -1226,11 +1246,11 @@ class SizeVariable(TupleVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if kwargs or len(args) != 1:
raise_args_mismatch(
@ -1253,7 +1273,9 @@ class SizeVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
def get_item_dyn(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@ -1269,7 +1291,7 @@ class SizeVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
return variables.ConstantVariable.create(hasattr(torch.Size, name))
@ -1280,33 +1302,39 @@ class NamedTupleVariable(TupleVariable):
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
def __init__(
self,
items: list[VariableTracker],
tuple_cls: type,
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
def is_namedtuple(self):
def is_namedtuple(self) -> bool:
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
getattr(self.tuple_cls, "_make", None)
)
def is_structseq(self):
def is_structseq(self) -> bool:
return not self.is_namedtuple()
def fields(self):
def fields(self) -> tuple[str, ...]:
return namedtuple_fields(self.tuple_cls)
def debug_repr(self):
def debug_repr(self) -> str:
if self.is_structseq():
# StructSequenceType(iterable)
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
# NamedTupleType(*iterable)
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
def python_type(self):
def python_type(self) -> type:
return self.tuple_cls
def as_python_constant(self):
def as_python_constant(self) -> Any:
if self.is_structseq():
# StructSequenceType(iterable)
result = self.python_type()([x.as_python_constant() for x in self.items])
@ -1328,7 +1356,7 @@ class NamedTupleVariable(TupleVariable):
return result
def as_proxy(self):
def as_proxy(self) -> Any:
assert self.python_type() is not SizeVariable
if self.is_structseq():
# StructSequenceType(iterable)
@ -1342,7 +1370,10 @@ class NamedTupleVariable(TupleVariable):
# StructSequenceType(iterable)
# NamedTupleType(*iterable)
# NamedTupleType._make(iterable)
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
if self.is_structseq():
create_fn = self.tuple_cls
else:
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_const_unchecked(create_fn)
@ -1384,8 +1415,8 @@ class NamedTupleVariable(TupleVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
@ -1446,7 +1477,9 @@ class NamedTupleVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
if isinstance(arg, SliceVariable):
# slicing a namedtuple produces a tuple
return TupleVariable(
@ -1455,8 +1488,8 @@ class NamedTupleVariable(TupleVariable):
)
return super().getitem_const(tx, arg)
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def check_and_create_method() -> Optional[VariableTracker]:
method = inspect.getattr_static(self.tuple_cls, name, None)
if isinstance(method, classmethod):
# We need the unbounded cls method to avoid the inline __self__
@ -1489,8 +1522,8 @@ class NamedTupleVariable(TupleVariable):
return super().var_getattr(tx, name)
if name == "_fields":
source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=source)
result_source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=result_source)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]
@ -1505,14 +1538,19 @@ class NamedTupleVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
return variables.ConstantVariable.create(
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
)
class SliceVariable(VariableTracker):
def __init__(self, items, tx=None, **kwargs) -> None:
def __init__(
self,
items: Sequence[VariableTracker],
tx: Optional["InstructionTranslator"] = None,
**kwargs: Any,
) -> None:
items_to_map = items
start, stop, step = [variables.ConstantVariable.create(None)] * 3
@ -1547,23 +1585,23 @@ class SliceVariable(VariableTracker):
super().__init__(**kwargs)
def debug_repr(self):
return self.debug_repr_helper("slice(", ")")
def debug_repr(self) -> str:
return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")"
def as_proxy(self):
def as_proxy(self) -> slice:
return slice(*[x.as_proxy() for x in self.items])
def python_type(self):
def python_type(self) -> type:
return slice
def as_python_constant(self):
def as_python_constant(self) -> slice:
return slice(*[guard_if_dyn(x) for x in self.items])
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach(self.items)
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name in cmp_name_to_op_mapping:
return variables.GetAttrVariable(self, name)
fields = ["start", "stop", "step"]
@ -1584,7 +1622,9 @@ class ListIteratorVariable(IteratorVariable):
*IteratorVariable._nonvar_fields,
}
def __init__(self, items, index: int = 0, **kwargs) -> None:
def __init__(
self, items: list[VariableTracker], index: int = 0, **kwargs: Any
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
# Removing this check as it slows things down too much
@ -1598,7 +1638,7 @@ class ListIteratorVariable(IteratorVariable):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
def next_variable(self, tx):
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
assert self.is_mutable()
old_index = self.index
if old_index >= len(self.items) or self.is_exhausted:
@ -1609,27 +1649,31 @@ class ListIteratorVariable(IteratorVariable):
self.index += 1
return self.items[old_index]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
return variables.ConstantVariable.create(hasattr(iter([]), name))
def python_type(self):
def python_type(self) -> type:
return type(iter([]))
def as_python_constant(self):
def as_python_constant(self) -> Any:
if self.index > 0:
raise NotImplementedError
return iter([x.as_python_constant() for x in self.items])
def has_unpack_var_sequence(self, tx):
def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
return True
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
if self.is_exhausted:
return []
self.is_exhausted = True
return list(self.items[self.index :])
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
def force_unpack_var_sequence(
self, tx: "InstructionTranslator"
) -> list[VariableTracker]:
return self.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -1656,27 +1700,37 @@ class RangeIteratorVariable(IteratorVariable):
"iter_obj",
}
def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
def __init__(
self, start: int, stop: int, step: int, len_: int, **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.start = start
self.stop = stop
self.step = step
self.len = len_
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__next__":
return self.next_variable(tx)
elif name == "__iter__":
return self
return super().call_method(tx, name, args, kwargs)
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
if self.python_type() is range_iterator:
ri = iter(range(0))
return ConstantVariable(hasattr(ri, name))
return super().call_obj_hasattr(tx, name)
def next_variable(self, tx):
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
if self.len <= 0:
raise_observed_exception(StopIteration, tx)
@ -1685,12 +1739,12 @@ class RangeIteratorVariable(IteratorVariable):
self.start += self.step
return ConstantVariable.create(current)
def python_type(self):
def python_type(self) -> type:
return range_iterator
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range))
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
)
codegen.append_output(codegen.create_load_const(self.start))
codegen.append_output(codegen.create_load_const(self.stop))