mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Improve typing
This commit is contained in:
@ -57,6 +57,7 @@ from ..source import (
|
||||
GetItemSource,
|
||||
GlobalSource,
|
||||
is_constant_source,
|
||||
Source,
|
||||
TypeSource,
|
||||
)
|
||||
from ..utils import (
|
||||
@ -141,7 +142,7 @@ IN_PLACE_DESUGARING_MAP = {
|
||||
|
||||
|
||||
_HandlerCallback = Callable[
|
||||
["InstructionTranslator", typing.Any, typing.Any], VariableTracker
|
||||
["InstructionTranslator", typing.Any, typing.Any], VariableTracker | None
|
||||
]
|
||||
_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]]
|
||||
polyfill_fn_mapping = {
|
||||
@ -230,7 +231,11 @@ def populate_builtin_to_tensor_fn_map() -> None:
|
||||
"""
|
||||
|
||||
def __torch_function__(
|
||||
self, func: Any, types: Any, args: Any = (), kwargs: Any | None = None
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
types: Any,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
kwargs = kwargs or {}
|
||||
nonlocal most_recent_func
|
||||
@ -294,14 +299,14 @@ class BuiltinVariable(VariableTracker):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_with_source(cls, value: Any, source: Any) -> "BuiltinVariable":
|
||||
def create_with_source(cls, value: Any, source: Source) -> "BuiltinVariable":
|
||||
install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
|
||||
return cls(value, source=source)
|
||||
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def _constant_fold_functions() -> set[Any]:
|
||||
fns = {
|
||||
def _constant_fold_functions() -> set[object]:
|
||||
fns: set[object] = {
|
||||
abs,
|
||||
all,
|
||||
any,
|
||||
@ -461,8 +466,8 @@ class BuiltinVariable(VariableTracker):
|
||||
list[
|
||||
tuple[
|
||||
tuple[
|
||||
type[VariableTracker] | tuple[type[VariableTracker], ...],
|
||||
type[VariableTracker] | tuple[type[VariableTracker], ...],
|
||||
type[VariableTracker],
|
||||
_TrackersType,
|
||||
],
|
||||
_HandlerCallback,
|
||||
]
|
||||
@ -550,7 +555,7 @@ class BuiltinVariable(VariableTracker):
|
||||
a: VariableTracker,
|
||||
b: VariableTracker,
|
||||
*,
|
||||
fn: Any = op,
|
||||
fn: Callable[..., Any] = op,
|
||||
) -> VariableTracker:
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
@ -580,12 +585,12 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
# List-like addition (e.g. [1, 2] + [3, 4])
|
||||
def tuple_add_handler(
|
||||
tx: "InstructionTranslator", a: Any, b: Any
|
||||
tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
|
||||
) -> VariableTracker:
|
||||
return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
|
||||
|
||||
def size_add_handler(
|
||||
tx: "InstructionTranslator", a: Any, b: Any
|
||||
tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
|
||||
) -> VariableTracker:
|
||||
return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
|
||||
|
||||
@ -652,7 +657,9 @@ class BuiltinVariable(VariableTracker):
|
||||
]
|
||||
op_handlers[operator.add].extend(list_like_addition_handlers)
|
||||
|
||||
def list_iadd_handler(tx: "InstructionTranslator", a: Any, b: Any) -> Any:
|
||||
def list_iadd_handler(
|
||||
tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
|
||||
) -> Any:
|
||||
if a.is_immutable() or not b.has_unpack_var_sequence(tx):
|
||||
# Handler doesn't apply
|
||||
return None
|
||||
@ -680,11 +687,12 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
# List-like expansion (e.g. [1, 2, 3] * 3)
|
||||
def expand_list_like(
|
||||
tx: "InstructionTranslator", lst: Any, const: Any
|
||||
tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker
|
||||
) -> VariableTracker:
|
||||
if isinstance(lst, ConstantVariable):
|
||||
lst, const = const, lst
|
||||
try:
|
||||
assert isinstance(lst, BaseListVariable)
|
||||
return lst.__class__(
|
||||
items=lst.items * const.as_python_constant(),
|
||||
mutation_type=ValueMutationNew(),
|
||||
@ -710,7 +718,7 @@ class BuiltinVariable(VariableTracker):
|
||||
op_handlers[operator.mul].extend(list_like_expansion_handlers)
|
||||
|
||||
def create_cmp_op_handlers(
|
||||
op: Any,
|
||||
op: Callable[..., Any],
|
||||
) -> list[tuple[tuple[_TrackersType, _TrackersType], _HandlerCallback]]:
|
||||
def compare_by_value(
|
||||
tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
|
||||
@ -892,7 +900,7 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
@staticmethod
|
||||
def _find_binop_handler(
|
||||
op: Any, a_type: type, b_type: type
|
||||
op: Callable[..., Any], a_type: type[VariableTracker], b_type: type
|
||||
) -> list[_HandlerCallback] | None:
|
||||
handlers = BuiltinVariable._binop_handlers().get(op)
|
||||
if handlers is None:
|
||||
@ -938,10 +946,10 @@ class BuiltinVariable(VariableTracker):
|
||||
assert name not in codegen.tx.f_globals, "shadowed global"
|
||||
codegen.append_output(codegen.create_load_global(name, add=True))
|
||||
|
||||
def constant_args(self, *args: Any, **kwargs: Any) -> bool:
|
||||
def constant_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool:
|
||||
return check_constant_args(args, kwargs)
|
||||
|
||||
def tensor_args(self, *args: Any) -> bool:
|
||||
def tensor_args(self, *args: VariableTracker) -> bool:
|
||||
any_tensor = False
|
||||
for arg in args:
|
||||
if isinstance(arg, variables.GetAttrVariable):
|
||||
@ -957,7 +965,9 @@ class BuiltinVariable(VariableTracker):
|
||||
any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable)
|
||||
return any_tensor
|
||||
|
||||
def python_and_tensor_constant_only(self, *args: Any, **kwargs: Any) -> bool:
|
||||
def python_and_tensor_constant_only(
|
||||
self, *args: VariableTracker, **kwargs: VariableTracker
|
||||
) -> bool:
|
||||
tensor_args = []
|
||||
non_tensor_args = []
|
||||
for i in itertools.chain(args, kwargs.values()):
|
||||
@ -999,7 +1009,7 @@ class BuiltinVariable(VariableTracker):
|
||||
from .lazy import LazyVariableTracker
|
||||
|
||||
obj = BuiltinVariable(fn)
|
||||
handlers: list[Any] = []
|
||||
handlers: list[_HandlerCallback] = []
|
||||
|
||||
if any(issubclass(t, LazyVariableTracker) for t in arg_types):
|
||||
return lambda tx, args, kwargs: obj.call_function(
|
||||
@ -1180,7 +1190,7 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
handlers.append(constant_fold_handler)
|
||||
|
||||
def call_unimplemented_v2(args: Any) -> None:
|
||||
def call_unimplemented_v2(args: Sequence[VariableTracker]) -> None:
|
||||
real_arg_types = [arg.python_type_name() for arg in args]
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to trace builtin operator",
|
||||
@ -1202,7 +1212,9 @@ class BuiltinVariable(VariableTracker):
|
||||
(handler,) = handlers
|
||||
|
||||
def builtin_dispatch(
|
||||
tx: "InstructionTranslator", args: Any, kwargs: Any
|
||||
tx: "InstructionTranslator",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker | None:
|
||||
rv = handler(tx, args, kwargs)
|
||||
if rv:
|
||||
@ -1213,7 +1225,9 @@ class BuiltinVariable(VariableTracker):
|
||||
else:
|
||||
|
||||
def builtin_dispatch(
|
||||
tx: "InstructionTranslator", args: Any, kwargs: Any
|
||||
tx: "InstructionTranslator",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker | None:
|
||||
rv = None
|
||||
for fn in handlers:
|
||||
@ -1271,9 +1285,8 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
tmp = args[0]
|
||||
# swap args and call reverse version of func
|
||||
args = list(args) # type: ignore[assignment]
|
||||
args[0] = args[1]
|
||||
args[1] = tmp
|
||||
args[0] = args[1] # type: ignore[index]
|
||||
args[1] = tmp # type: ignore[index]
|
||||
else:
|
||||
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
|
||||
else:
|
||||
@ -1338,7 +1351,7 @@ class BuiltinVariable(VariableTracker):
|
||||
# Dynamo expects `__eq__` str while operator.eq gives just `eq`
|
||||
# TODO - supporting all comparison operators could also work but
|
||||
# it fails lots of tests because graph str changes.
|
||||
return args[0].call_method(tx, "__eq__", list(args[1:]), kwargs)
|
||||
return args[0].call_method(tx, "__eq__", args[1:], kwargs)
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function",
|
||||
fn,
|
||||
@ -1375,7 +1388,7 @@ class BuiltinVariable(VariableTracker):
|
||||
if fn is operator.truediv and isinstance(
|
||||
args[0], variables.UnspecializedPythonVariable
|
||||
):
|
||||
args = list(args) # type: ignore[assignment]
|
||||
args = list(args)
|
||||
args[0] = args[0].as_python_constant()
|
||||
return wrap_fx_proxy(tx, proxy)
|
||||
|
||||
@ -1402,9 +1415,9 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: Sequence["VariableTracker"],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
key: tuple[object, ...]
|
||||
if kwargs:
|
||||
kwargs = {k: v.realize() for k, v in kwargs.items()}
|
||||
@ -1424,9 +1437,9 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if self.fn is object and name == "__setattr__":
|
||||
assert len(args) == 3
|
||||
assert len(kwargs) == 0
|
||||
@ -1568,7 +1581,7 @@ class BuiltinVariable(VariableTracker):
|
||||
{},
|
||||
),
|
||||
)
|
||||
return None # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
call_int = _call_int_float
|
||||
call_float = _call_int_float
|
||||
@ -1592,7 +1605,7 @@ class BuiltinVariable(VariableTracker):
|
||||
return SymNodeVariable.create(tx, arg.as_proxy() != 0)
|
||||
|
||||
# TODO handle more cases and merge this with this with `generic_jump`.
|
||||
return None # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def call_str(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
@ -1652,20 +1665,22 @@ class BuiltinVariable(VariableTracker):
|
||||
else:
|
||||
value = ", ".join(a.as_python_constant() for a in arg.args)
|
||||
return variables.ConstantVariable.create(value=value)
|
||||
return None # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def _call_min_max(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
|
||||
def _call_min_max(
|
||||
self, tx: "InstructionTranslator", *args: VariableTracker
|
||||
) -> VariableTracker | None:
|
||||
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
return self._call_min_max_seq(tx, items) # type: ignore[return-value]
|
||||
return self._call_min_max_seq(tx, items)
|
||||
elif len(args) == 2:
|
||||
return self._call_min_max_binary(tx, args[0], args[1]) # type: ignore[return-value]
|
||||
return self._call_min_max_binary(tx, args[0], args[1])
|
||||
elif len(args) > 2:
|
||||
return self._call_min_max_seq(tx, args) # type: ignore[arg-type,return-value]
|
||||
return None # type: ignore[return-value]
|
||||
return self._call_min_max_seq(tx, args)
|
||||
return None
|
||||
|
||||
def _call_min_max_seq(
|
||||
self, tx: "InstructionTranslator", items: Any
|
||||
self, tx: "InstructionTranslator", items: Sequence[VariableTracker]
|
||||
) -> VariableTracker:
|
||||
assert len(items) > 0
|
||||
if len(items) == 1:
|
||||
@ -1782,7 +1797,7 @@ class BuiltinVariable(VariableTracker):
|
||||
call_max = _call_min_max
|
||||
|
||||
def call_abs(
|
||||
self, tx: "InstructionTranslator", arg: "VariableTracker"
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
# Call arg.__abs__()
|
||||
abs_method = BuiltinVariable(getattr).call_function(
|
||||
@ -1791,7 +1806,7 @@ class BuiltinVariable(VariableTracker):
|
||||
return abs_method.call_function(tx, [], {})
|
||||
|
||||
def call_pos(
|
||||
self, tx: "InstructionTranslator", arg: "VariableTracker"
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
# Call arg.__pos__()
|
||||
pos_method = BuiltinVariable(getattr).call_function(
|
||||
@ -1800,7 +1815,7 @@ class BuiltinVariable(VariableTracker):
|
||||
return pos_method.call_function(tx, [], {})
|
||||
|
||||
def call_index(
|
||||
self, tx: "InstructionTranslator", arg: "VariableTracker"
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
if isinstance(arg, variables.TensorVariable):
|
||||
unimplemented_v2(
|
||||
@ -1818,8 +1833,8 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
arg: VariableTracker,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
# Call arg.__round__()
|
||||
round_method = BuiltinVariable(getattr).call_function(
|
||||
@ -1828,7 +1843,7 @@ class BuiltinVariable(VariableTracker):
|
||||
return round_method.call_function(tx, args, kwargs)
|
||||
|
||||
def call_range(
|
||||
self, tx: "InstructionTranslator", *args: Any
|
||||
self, tx: "InstructionTranslator", *args: VariableTracker
|
||||
) -> VariableTracker | None:
|
||||
if check_unspec_or_constant_args(args, {}):
|
||||
return variables.RangeVariable(args)
|
||||
@ -1840,12 +1855,14 @@ class BuiltinVariable(VariableTracker):
|
||||
# None no-ops this handler and lets the driving function proceed
|
||||
return None
|
||||
|
||||
def _dynamic_args(self, *args: Any, **kwargs: Any) -> bool:
|
||||
def _dynamic_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool:
|
||||
return any(isinstance(x, SymNodeVariable) for x in args) or any(
|
||||
isinstance(x, SymNodeVariable) for x in kwargs.values()
|
||||
)
|
||||
|
||||
def call_slice(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
|
||||
def call_slice(
|
||||
self, tx: "InstructionTranslator", *args: VariableTracker
|
||||
) -> VariableTracker:
|
||||
return variables.SliceVariable(args, tx)
|
||||
|
||||
def _dyn_proxy(
|
||||
@ -1865,8 +1882,8 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
obj: VariableTracker | None = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker | None:
|
||||
assert not isinstance(obj, variables.IteratorVariable)
|
||||
|
||||
@ -1916,8 +1933,8 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
obj: VariableTracker,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||
return cls(
|
||||
@ -1929,8 +1946,8 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
obj: VariableTracker | None = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker | None:
|
||||
if isinstance(obj, variables.IteratorVariable):
|
||||
cls = variables.BaseListVariable.cls_for(self.fn)
|
||||
@ -1950,8 +1967,8 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
obj: VariableTracker,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
# avoid the overhead of tracing the polyfill if we already know the class implemented __iter__
|
||||
if isinstance(
|
||||
@ -2020,7 +2037,9 @@ class BuiltinVariable(VariableTracker):
|
||||
else:
|
||||
return None
|
||||
|
||||
def call_cast(self, _: Any, *args: Any, **kwargs: Any) -> VariableTracker | None:
|
||||
def call_cast(
|
||||
self, _: Any, *args: VariableTracker, **kwargs: VariableTracker
|
||||
) -> VariableTracker | None:
|
||||
if len(args) == 2:
|
||||
return args[1]
|
||||
|
||||
@ -2047,28 +2066,34 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
@staticmethod
|
||||
def call_custom_dict(
|
||||
tx: "InstructionTranslator", user_cls: type, *args: Any, **kwargs: Any
|
||||
tx: "InstructionTranslator",
|
||||
user_cls: type,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
args = list(args)
|
||||
args_list = list(args)
|
||||
if (
|
||||
len(args) == 1
|
||||
and isinstance(args[0], variables.GetAttrVariable)
|
||||
and isinstance(args[0].obj, variables.UserDefinedClassVariable)
|
||||
and not tx.output.side_effects.has_pending_mutation(args[0].obj)
|
||||
len(args_list) == 1
|
||||
and isinstance(args_list[0], variables.GetAttrVariable)
|
||||
and isinstance(args_list[0].obj, variables.UserDefinedClassVariable)
|
||||
and not tx.output.side_effects.has_pending_mutation(args_list[0].obj)
|
||||
):
|
||||
# Forward the GetAttrVariable(foo, "__dict__") to a realized vt of
|
||||
# VT(foo.__dict__). This simplifies the construction of the new
|
||||
# dict.
|
||||
args[0] = args[0].get_forwarded_dict(tx)
|
||||
args_list[0] = args_list[0].get_forwarded_dict(tx)
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.construct_dict),
|
||||
[VariableTracker.build(tx, user_cls), *args],
|
||||
[VariableTracker.build(tx, user_cls), *args_list],
|
||||
kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def call_custom_dict_fromkeys(
|
||||
tx: "InstructionTranslator", user_cls: type, *args: Any, **kwargs: Any
|
||||
tx: "InstructionTranslator",
|
||||
user_cls: type,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
if user_cls not in {dict, OrderedDict, defaultdict}:
|
||||
unimplemented_v2(
|
||||
@ -2111,16 +2136,17 @@ class BuiltinVariable(VariableTracker):
|
||||
"2 args",
|
||||
f"{len(args)} args",
|
||||
)
|
||||
assert len(args) >= 2
|
||||
arg, value = args
|
||||
DictVariableType = (
|
||||
ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
|
||||
)
|
||||
|
||||
if isinstance(arg, dict):
|
||||
arg = [ConstantVariable.create(k) for k in arg.keys()]
|
||||
arg_list = [ConstantVariable.create(k) for k in arg.keys()]
|
||||
return DictVariableType(
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
dict.fromkeys(arg, value),
|
||||
dict.fromkeys(arg_list, value),
|
||||
user_cls,
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
@ -2147,7 +2173,10 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
def call_set(
|
||||
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
# Can we merge this implementation and call_dict's one?
|
||||
assert not kwargs
|
||||
@ -2185,7 +2214,10 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
def call_frozenset(
|
||||
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
assert not kwargs
|
||||
if not args:
|
||||
@ -2213,7 +2245,10 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
|
||||
def call_zip(
|
||||
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
if kwargs:
|
||||
if not (len(kwargs) == 1 and "strict" in kwargs):
|
||||
@ -2226,21 +2261,29 @@ class BuiltinVariable(VariableTracker):
|
||||
strict = kwargs.pop("strict", False)
|
||||
iter_args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
|
||||
return variables.ZipVariable(
|
||||
iter_args, strict=strict, mutation_type=ValueMutationNew()
|
||||
iter_args,
|
||||
strict=strict, # type: ignore[arg-type]
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
|
||||
def call_len(
|
||||
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
try:
|
||||
return args[0].call_method(tx, "__len__", args[1:], kwargs)
|
||||
return args[0].call_method(tx, "__len__", list(args[1:]), kwargs)
|
||||
except AttributeError as e:
|
||||
raise_observed_exception(type(e), tx, args=list(e.args))
|
||||
|
||||
def call_getitem(
|
||||
self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
|
||||
return args[0].call_method(tx, "__getitem__", list(args[1:]), kwargs)
|
||||
|
||||
def call_isinstance(
|
||||
self,
|
||||
@ -2260,7 +2303,9 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
|
||||
|
||||
def _tensor_isinstance(tensor_var: Any, tensor_type: Any) -> bool:
|
||||
def _tensor_isinstance(
|
||||
tensor_var: VariableTracker, tensor_type: Any
|
||||
) -> bool:
|
||||
def check_type(ty: Any) -> bool:
|
||||
if ty not in tensortype_to_dtype:
|
||||
example_val = arg.as_proxy().node.meta["example_value"]
|
||||
@ -2379,7 +2424,9 @@ class BuiltinVariable(VariableTracker):
|
||||
) -> VariableTracker:
|
||||
return variables.SuperVariable(a, b)
|
||||
|
||||
def call_next(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
|
||||
def call_next(
|
||||
self, tx: "InstructionTranslator", *args: VariableTracker
|
||||
) -> VariableTracker:
|
||||
arg = args[0]
|
||||
try:
|
||||
return arg.next_variable(tx)
|
||||
@ -2404,19 +2451,23 @@ class BuiltinVariable(VariableTracker):
|
||||
return None
|
||||
|
||||
def call_map(
|
||||
self, tx: "InstructionTranslator", fn: VariableTracker, *seqs: Any
|
||||
self, tx: "InstructionTranslator", fn: VariableTracker, *seqs: VariableTracker
|
||||
) -> VariableTracker:
|
||||
seqs = [
|
||||
seq_list = [
|
||||
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
|
||||
for seq in seqs
|
||||
]
|
||||
return variables.MapVariable(fn, seqs, mutation_type=ValueMutationNew())
|
||||
return variables.MapVariable(fn, seq_list, mutation_type=ValueMutationNew())
|
||||
|
||||
def call_filter(
|
||||
self, tx: "InstructionTranslator", fn: VariableTracker, seq: VariableTracker
|
||||
) -> VariableTracker:
|
||||
seq = seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
|
||||
return variables.FilterVariable(fn, seq, mutation_type=ValueMutationNew())
|
||||
seq_or_list = (
|
||||
seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq
|
||||
)
|
||||
return variables.FilterVariable(
|
||||
fn, seq_or_list, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
source = self.source and AttrSource(self.source, name)
|
||||
@ -2683,7 +2734,7 @@ class BuiltinVariable(VariableTracker):
|
||||
# .data setting to play correctly with the autograd engine.
|
||||
# Essentially, dynamo is trying to faithfully preserve the (absurd)
|
||||
# behavior of .data= from eager mode
|
||||
def _lower_version_count_by_1(x: Any) -> Any:
|
||||
def _lower_version_count_by_1(x: torch.Tensor) -> torch.Tensor:
|
||||
version = x._version
|
||||
if version > 0:
|
||||
version = version - 1
|
||||
@ -2836,14 +2887,16 @@ class BuiltinVariable(VariableTracker):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
_format_string: VariableTracker,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: VariableTracker,
|
||||
**kwargs: VariableTracker,
|
||||
) -> VariableTracker:
|
||||
format_string = _format_string.as_python_constant()
|
||||
format_string = str(format_string)
|
||||
return variables.StringFormatVariable.create(format_string, args, kwargs)
|
||||
|
||||
def call_id(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
|
||||
def call_id(
|
||||
self, tx: "InstructionTranslator", *args: VariableTracker
|
||||
) -> VariableTracker:
|
||||
if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
|
||||
nn_mod_variable = args[0]
|
||||
mod = tx.output.get_submodule(nn_mod_variable.module_key)
|
||||
|
||||
@ -23,6 +23,7 @@ import operator
|
||||
import textwrap
|
||||
import traceback
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -631,7 +632,7 @@ class TensorVariable(VariableTracker):
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from .builder import SourcelessBuilder, VariableBuilder
|
||||
|
||||
Reference in New Issue
Block a user