Improve typing

This commit is contained in:
Lucas Kabela
2025-10-31 14:06:01 -07:00
parent 3e3cb93f07
commit cb8b1096e3
2 changed files with 143 additions and 89 deletions

View File

@ -57,6 +57,7 @@ from ..source import (
GetItemSource,
GlobalSource,
is_constant_source,
Source,
TypeSource,
)
from ..utils import (
@ -141,7 +142,7 @@ IN_PLACE_DESUGARING_MAP = {
_HandlerCallback = Callable[
["InstructionTranslator", typing.Any, typing.Any], VariableTracker
["InstructionTranslator", typing.Any, typing.Any], VariableTracker | None
]
_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]]
polyfill_fn_mapping = {
@ -230,7 +231,11 @@ def populate_builtin_to_tensor_fn_map() -> None:
"""
def __torch_function__(
self, func: Any, types: Any, args: Any = (), kwargs: Any | None = None
self,
func: Callable[..., Any],
types: Any,
args: Sequence[Any] = (),
kwargs: dict[str, Any] | None = None,
) -> Any:
kwargs = kwargs or {}
nonlocal most_recent_func
@ -294,14 +299,14 @@ class BuiltinVariable(VariableTracker):
}
@classmethod
def create_with_source(cls, value: Any, source: Any) -> "BuiltinVariable":
def create_with_source(cls, value: Any, source: Source) -> "BuiltinVariable":
install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
return cls(value, source=source)
@staticmethod
@functools.cache
def _constant_fold_functions() -> set[Any]:
fns = {
def _constant_fold_functions() -> set[object]:
fns: set[object] = {
abs,
all,
any,
@ -461,8 +466,8 @@ class BuiltinVariable(VariableTracker):
list[
tuple[
tuple[
type[VariableTracker] | tuple[type[VariableTracker], ...],
type[VariableTracker] | tuple[type[VariableTracker], ...],
type[VariableTracker],
_TrackersType,
],
_HandlerCallback,
]
@ -550,7 +555,7 @@ class BuiltinVariable(VariableTracker):
a: VariableTracker,
b: VariableTracker,
*,
fn: Any = op,
fn: Callable[..., Any] = op,
) -> VariableTracker:
from .builder import wrap_fx_proxy
@ -580,12 +585,12 @@ class BuiltinVariable(VariableTracker):
# List-like addition (e.g. [1, 2] + [3, 4])
def tuple_add_handler(
tx: "InstructionTranslator", a: Any, b: Any
tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
) -> VariableTracker:
return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
def size_add_handler(
tx: "InstructionTranslator", a: Any, b: Any
tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
) -> VariableTracker:
return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
@ -652,7 +657,9 @@ class BuiltinVariable(VariableTracker):
]
op_handlers[operator.add].extend(list_like_addition_handlers)
def list_iadd_handler(tx: "InstructionTranslator", a: Any, b: Any) -> Any:
def list_iadd_handler(
tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
) -> Any:
if a.is_immutable() or not b.has_unpack_var_sequence(tx):
# Handler doesn't apply
return None
@ -680,11 +687,12 @@ class BuiltinVariable(VariableTracker):
# List-like expansion (e.g. [1, 2, 3] * 3)
def expand_list_like(
tx: "InstructionTranslator", lst: Any, const: Any
tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker
) -> VariableTracker:
if isinstance(lst, ConstantVariable):
lst, const = const, lst
try:
assert isinstance(lst, BaseListVariable)
return lst.__class__(
items=lst.items * const.as_python_constant(),
mutation_type=ValueMutationNew(),
@ -710,7 +718,7 @@ class BuiltinVariable(VariableTracker):
op_handlers[operator.mul].extend(list_like_expansion_handlers)
def create_cmp_op_handlers(
op: Any,
op: Callable[..., Any],
) -> list[tuple[tuple[_TrackersType, _TrackersType], _HandlerCallback]]:
def compare_by_value(
tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
@ -892,7 +900,7 @@ class BuiltinVariable(VariableTracker):
@staticmethod
def _find_binop_handler(
op: Any, a_type: type, b_type: type
op: Callable[..., Any], a_type: type[VariableTracker], b_type: type
) -> list[_HandlerCallback] | None:
handlers = BuiltinVariable._binop_handlers().get(op)
if handlers is None:
@ -938,10 +946,10 @@ class BuiltinVariable(VariableTracker):
assert name not in codegen.tx.f_globals, "shadowed global"
codegen.append_output(codegen.create_load_global(name, add=True))
def constant_args(self, *args: Any, **kwargs: Any) -> bool:
def constant_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool:
return check_constant_args(args, kwargs)
def tensor_args(self, *args: Any) -> bool:
def tensor_args(self, *args: VariableTracker) -> bool:
any_tensor = False
for arg in args:
if isinstance(arg, variables.GetAttrVariable):
@ -957,7 +965,9 @@ class BuiltinVariable(VariableTracker):
any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable)
return any_tensor
def python_and_tensor_constant_only(self, *args: Any, **kwargs: Any) -> bool:
def python_and_tensor_constant_only(
self, *args: VariableTracker, **kwargs: VariableTracker
) -> bool:
tensor_args = []
non_tensor_args = []
for i in itertools.chain(args, kwargs.values()):
@ -999,7 +1009,7 @@ class BuiltinVariable(VariableTracker):
from .lazy import LazyVariableTracker
obj = BuiltinVariable(fn)
handlers: list[Any] = []
handlers: list[_HandlerCallback] = []
if any(issubclass(t, LazyVariableTracker) for t in arg_types):
return lambda tx, args, kwargs: obj.call_function(
@ -1180,7 +1190,7 @@ class BuiltinVariable(VariableTracker):
handlers.append(constant_fold_handler)
def call_unimplemented_v2(args: Any) -> None:
def call_unimplemented_v2(args: Sequence[VariableTracker]) -> None:
real_arg_types = [arg.python_type_name() for arg in args]
unimplemented_v2(
gb_type="Failed to trace builtin operator",
@ -1202,7 +1212,9 @@ class BuiltinVariable(VariableTracker):
(handler,) = handlers
def builtin_dispatch(
tx: "InstructionTranslator", args: Any, kwargs: Any
tx: "InstructionTranslator",
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker | None:
rv = handler(tx, args, kwargs)
if rv:
@ -1213,7 +1225,9 @@ class BuiltinVariable(VariableTracker):
else:
def builtin_dispatch(
tx: "InstructionTranslator", args: Any, kwargs: Any
tx: "InstructionTranslator",
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker | None:
rv = None
for fn in handlers:
@ -1271,9 +1285,8 @@ class BuiltinVariable(VariableTracker):
tmp = args[0]
# swap args and call reverse version of func
args = list(args) # type: ignore[assignment]
args[0] = args[1]
args[1] = tmp
args[0] = args[1] # type: ignore[index]
args[1] = tmp # type: ignore[index]
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
else:
@ -1338,7 +1351,7 @@ class BuiltinVariable(VariableTracker):
# Dynamo expects `__eq__` str while operator.eq gives just `eq`
# TODO - supporting all comparison operators could also work but
# it fails lots of tests because graph str changes.
return args[0].call_method(tx, "__eq__", list(args[1:]), kwargs)
return args[0].call_method(tx, "__eq__", args[1:], kwargs)
proxy = tx.output.create_proxy(
"call_function",
fn,
@ -1375,7 +1388,7 @@ class BuiltinVariable(VariableTracker):
if fn is operator.truediv and isinstance(
args[0], variables.UnspecializedPythonVariable
):
args = list(args) # type: ignore[assignment]
args = list(args)
args[0] = args[0].as_python_constant()
return wrap_fx_proxy(tx, proxy)
@ -1402,9 +1415,9 @@ class BuiltinVariable(VariableTracker):
def call_function(
self,
tx: "InstructionTranslator",
args: Sequence["VariableTracker"],
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
key: tuple[object, ...]
if kwargs:
kwargs = {k: v.realize() for k, v in kwargs.items()}
@ -1424,9 +1437,9 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
name: str,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if self.fn is object and name == "__setattr__":
assert len(args) == 3
assert len(kwargs) == 0
@ -1568,7 +1581,7 @@ class BuiltinVariable(VariableTracker):
{},
),
)
return None # type: ignore[return-value]
return None
call_int = _call_int_float
call_float = _call_int_float
@ -1592,7 +1605,7 @@ class BuiltinVariable(VariableTracker):
return SymNodeVariable.create(tx, arg.as_proxy() != 0)
# TODO handle more cases and merge this with this with `generic_jump`.
return None # type: ignore[return-value]
return None
def call_str(
self, tx: "InstructionTranslator", arg: VariableTracker
@ -1652,20 +1665,22 @@ class BuiltinVariable(VariableTracker):
else:
value = ", ".join(a.as_python_constant() for a in arg.args)
return variables.ConstantVariable.create(value=value)
return None # type: ignore[return-value]
return None
def _call_min_max(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
def _call_min_max(
self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker | None:
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
return self._call_min_max_seq(tx, items) # type: ignore[return-value]
return self._call_min_max_seq(tx, items)
elif len(args) == 2:
return self._call_min_max_binary(tx, args[0], args[1]) # type: ignore[return-value]
return self._call_min_max_binary(tx, args[0], args[1])
elif len(args) > 2:
return self._call_min_max_seq(tx, args) # type: ignore[arg-type,return-value]
return None # type: ignore[return-value]
return self._call_min_max_seq(tx, args)
return None
def _call_min_max_seq(
self, tx: "InstructionTranslator", items: Any
self, tx: "InstructionTranslator", items: Sequence[VariableTracker]
) -> VariableTracker:
assert len(items) > 0
if len(items) == 1:
@ -1782,7 +1797,7 @@ class BuiltinVariable(VariableTracker):
call_max = _call_min_max
def call_abs(
self, tx: "InstructionTranslator", arg: "VariableTracker"
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
# Call arg.__abs__()
abs_method = BuiltinVariable(getattr).call_function(
@ -1791,7 +1806,7 @@ class BuiltinVariable(VariableTracker):
return abs_method.call_function(tx, [], {})
def call_pos(
self, tx: "InstructionTranslator", arg: "VariableTracker"
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
# Call arg.__pos__()
pos_method = BuiltinVariable(getattr).call_function(
@ -1800,7 +1815,7 @@ class BuiltinVariable(VariableTracker):
return pos_method.call_function(tx, [], {})
def call_index(
self, tx: "InstructionTranslator", arg: "VariableTracker"
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
if isinstance(arg, variables.TensorVariable):
unimplemented_v2(
@ -1818,8 +1833,8 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
arg: VariableTracker,
*args: Any,
**kwargs: Any,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
# Call arg.__round__()
round_method = BuiltinVariable(getattr).call_function(
@ -1828,7 +1843,7 @@ class BuiltinVariable(VariableTracker):
return round_method.call_function(tx, args, kwargs)
def call_range(
self, tx: "InstructionTranslator", *args: Any
self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker | None:
if check_unspec_or_constant_args(args, {}):
return variables.RangeVariable(args)
@ -1840,12 +1855,14 @@ class BuiltinVariable(VariableTracker):
# None no-ops this handler and lets the driving function proceed
return None
def _dynamic_args(self, *args: Any, **kwargs: Any) -> bool:
def _dynamic_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool:
return any(isinstance(x, SymNodeVariable) for x in args) or any(
isinstance(x, SymNodeVariable) for x in kwargs.values()
)
def call_slice(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
def call_slice(
self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker:
return variables.SliceVariable(args, tx)
def _dyn_proxy(
@ -1865,8 +1882,8 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
obj: VariableTracker | None = None,
*args: Any,
**kwargs: Any,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker | None:
assert not isinstance(obj, variables.IteratorVariable)
@ -1916,8 +1933,8 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
obj: VariableTracker,
*args: Any,
**kwargs: Any,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
cls = variables.BaseListVariable.cls_for(self.fn)
return cls(
@ -1929,8 +1946,8 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
obj: VariableTracker | None = None,
*args: Any,
**kwargs: Any,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker | None:
if isinstance(obj, variables.IteratorVariable):
cls = variables.BaseListVariable.cls_for(self.fn)
@ -1950,8 +1967,8 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
obj: VariableTracker,
*args: Any,
**kwargs: Any,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
# avoid the overhead of tracing the polyfill if we already know the class implemented __iter__
if isinstance(
@ -2020,7 +2037,9 @@ class BuiltinVariable(VariableTracker):
else:
return None
def call_cast(self, _: Any, *args: Any, **kwargs: Any) -> VariableTracker | None:
def call_cast(
self, _: Any, *args: VariableTracker, **kwargs: VariableTracker
) -> VariableTracker | None:
if len(args) == 2:
return args[1]
@ -2047,28 +2066,34 @@ class BuiltinVariable(VariableTracker):
@staticmethod
def call_custom_dict(
tx: "InstructionTranslator", user_cls: type, *args: Any, **kwargs: Any
tx: "InstructionTranslator",
user_cls: type,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
args = list(args)
args_list = list(args)
if (
len(args) == 1
and isinstance(args[0], variables.GetAttrVariable)
and isinstance(args[0].obj, variables.UserDefinedClassVariable)
and not tx.output.side_effects.has_pending_mutation(args[0].obj)
len(args_list) == 1
and isinstance(args_list[0], variables.GetAttrVariable)
and isinstance(args_list[0].obj, variables.UserDefinedClassVariable)
and not tx.output.side_effects.has_pending_mutation(args_list[0].obj)
):
# Forward the GetAttrVariable(foo, "__dict__") to a realized vt of
# VT(foo.__dict__). This simplifies the construction of the new
# dict.
args[0] = args[0].get_forwarded_dict(tx)
args_list[0] = args_list[0].get_forwarded_dict(tx)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
[VariableTracker.build(tx, user_cls), *args],
[VariableTracker.build(tx, user_cls), *args_list],
kwargs,
)
@staticmethod
def call_custom_dict_fromkeys(
tx: "InstructionTranslator", user_cls: type, *args: Any, **kwargs: Any
tx: "InstructionTranslator",
user_cls: type,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
if user_cls not in {dict, OrderedDict, defaultdict}:
unimplemented_v2(
@ -2111,16 +2136,17 @@ class BuiltinVariable(VariableTracker):
"2 args",
f"{len(args)} args",
)
assert len(args) >= 2
arg, value = args
DictVariableType = (
ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
)
if isinstance(arg, dict):
arg = [ConstantVariable.create(k) for k in arg.keys()]
arg_list = [ConstantVariable.create(k) for k in arg.keys()]
return DictVariableType(
# pyrefly: ignore [bad-argument-type]
dict.fromkeys(arg, value),
dict.fromkeys(arg_list, value),
user_cls,
mutation_type=ValueMutationNew(),
)
@ -2147,7 +2173,10 @@ class BuiltinVariable(VariableTracker):
)
def call_set(
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
self,
tx: "InstructionTranslator",
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
# Can we merge this implementation and call_dict's one?
assert not kwargs
@ -2185,7 +2214,10 @@ class BuiltinVariable(VariableTracker):
)
def call_frozenset(
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
self,
tx: "InstructionTranslator",
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
assert not kwargs
if not args:
@ -2213,7 +2245,10 @@ class BuiltinVariable(VariableTracker):
)
def call_zip(
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
self,
tx: "InstructionTranslator",
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
if kwargs:
if not (len(kwargs) == 1 and "strict" in kwargs):
@ -2226,21 +2261,29 @@ class BuiltinVariable(VariableTracker):
strict = kwargs.pop("strict", False)
iter_args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
return variables.ZipVariable(
iter_args, strict=strict, mutation_type=ValueMutationNew()
iter_args,
strict=strict, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
def call_len(
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
self,
tx: "InstructionTranslator",
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
try:
return args[0].call_method(tx, "__len__", args[1:], kwargs)
return args[0].call_method(tx, "__len__", list(args[1:]), kwargs)
except AttributeError as e:
raise_observed_exception(type(e), tx, args=list(e.args))
def call_getitem(
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
self,
tx: "InstructionTranslator",
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
return args[0].call_method(tx, "__getitem__", list(args[1:]), kwargs)
def call_isinstance(
self,
@ -2260,7 +2303,9 @@ class BuiltinVariable(VariableTracker):
if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
def _tensor_isinstance(tensor_var: Any, tensor_type: Any) -> bool:
def _tensor_isinstance(
tensor_var: VariableTracker, tensor_type: Any
) -> bool:
def check_type(ty: Any) -> bool:
if ty not in tensortype_to_dtype:
example_val = arg.as_proxy().node.meta["example_value"]
@ -2379,7 +2424,9 @@ class BuiltinVariable(VariableTracker):
) -> VariableTracker:
return variables.SuperVariable(a, b)
def call_next(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
def call_next(
self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker:
arg = args[0]
try:
return arg.next_variable(tx)
@ -2404,19 +2451,23 @@ class BuiltinVariable(VariableTracker):
return None
def call_map(
self, tx: "InstructionTranslator", fn: VariableTracker, *seqs: Any
self, tx: "InstructionTranslator", fn: VariableTracker, *seqs: VariableTracker
) -> VariableTracker:
seqs = [
seq_list = [
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
for seq in seqs
]
return variables.MapVariable(fn, seqs, mutation_type=ValueMutationNew())
return variables.MapVariable(fn, seq_list, mutation_type=ValueMutationNew())
def call_filter(
self, tx: "InstructionTranslator", fn: VariableTracker, seq: VariableTracker
) -> VariableTracker:
seq = seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
return variables.FilterVariable(fn, seq, mutation_type=ValueMutationNew())
seq_or_list = (
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
)
return variables.FilterVariable(
fn, seq_or_list, mutation_type=ValueMutationNew()
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
source = self.source and AttrSource(self.source, name)
@ -2683,7 +2734,7 @@ class BuiltinVariable(VariableTracker):
# .data setting to play correctly with the autograd engine.
# Essentially, dynamo is trying to faithfully preserve the (absurd)
# behavior of .data= from eager mode
def _lower_version_count_by_1(x: Any) -> Any:
def _lower_version_count_by_1(x: torch.Tensor) -> torch.Tensor:
version = x._version
if version > 0:
version = version - 1
@ -2836,14 +2887,16 @@ class BuiltinVariable(VariableTracker):
self,
tx: "InstructionTranslator",
_format_string: VariableTracker,
*args: Any,
**kwargs: Any,
*args: VariableTracker,
**kwargs: VariableTracker,
) -> VariableTracker:
format_string = _format_string.as_python_constant()
format_string = str(format_string)
return variables.StringFormatVariable.create(format_string, args, kwargs)
def call_id(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
def call_id(
self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker:
if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
nn_mod_variable = args[0]
mod = tx.output.get_submodule(nn_mod_variable.module_key)

View File

@ -23,6 +23,7 @@ import operator
import textwrap
import traceback
import types
from collections.abc import Sequence
from contextlib import nullcontext
from typing import TYPE_CHECKING
@ -631,7 +632,7 @@ class TensorVariable(VariableTracker):
self,
tx,
name,
args: "list[VariableTracker]",
args: Sequence[VariableTracker],
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import SourcelessBuilder, VariableBuilder