Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-30 11:44:59 +08:00)

Compare commits

8 commits: ciflow/tru... (lucaskabel)

| SHA1 |
|---|
| 8905692c19 |
| 7e6924b854 |
| e3e1536be4 |
| 3399492b27 |
| a5216cdcee |
| 3aa489879c |
| 31d8056b6f |
| fe4a9aae0e |
@@ -153,7 +153,7 @@ class PyCodegen:
        self.clear_tos()

    def __call__(
-        self, value: Union[VariableTracker, Source], allow_cache: bool = True
+        self, value: Union[VariableTracker, Source, None], allow_cache: bool = True
    ) -> None:
        """
        Generate code such that top-of-stack (TOS) is set to value.
@@ -188,7 +188,7 @@ class PyCodegen:
           value to handle aliasing (check side_effects.py and search for
           allow_cache=False).

-           b) If value.source is None, this is not allowed. TODO - assert this.
+           b) If value.source is None, this is not allowed

        Notable effects:
        1. `self.top_of_stack` will be set to `value`, if we don't codegen
@@ -197,6 +197,7 @@ class PyCodegen:
           `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
           types like `NNModuleVariable`, etc.
        """
+        assert value is not None
        if isinstance(value, Source):
            # If the source needs to be overridden, use the new one.
            source = self.overridden_sources.get(value, value)
@@ -289,7 +290,8 @@ class PyCodegen:
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
-                    value.global_mangled_class_name(self.tx), add=True
+                    value.global_mangled_class_name(self.tx),  # type: ignore[arg-type]
+                    add=True,
                )
            )
            output.extend(create_call_function(2, False))
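The hunks above widen `value` to `Union[VariableTracker, Source, None]` and then immediately narrow it back with `assert value is not None`. Below is a minimal sketch of that mypy narrowing pattern; the `Source` class here is a hypothetical stand-in for illustration, not `torch._guards.Source` or the real `PyCodegen`.

```python
from typing import Union


class Source:  # hypothetical stand-in for the example
    def name(self) -> str:
        return "L['x']"


def codegen(value: Union[Source, str, None]) -> str:
    # Callers are now allowed to pass None, so the parameter type admits it...
    assert value is not None
    # ...and after the assert, mypy narrows the type to Union[Source, str],
    # so the attribute access below type-checks without ignores.
    if isinstance(value, Source):
        return value.name()
    return value


print(codegen("top_of_stack"))  # -> top_of_stack
```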
@@ -1303,6 +1303,7 @@ class OutputGraph(OutputGraphCommon):

                # A small codegen optimization because we might have different
                # VariableTrackers that share the same source.
+                assert x.source is not None
                list_idx = x.source.index  # type: ignore[attr-defined]
                if list_idx not in visited:
                    alias_name = self.new_var(
@@ -1321,6 +1322,7 @@ class OutputGraph(OutputGraphCommon):
                    )

                # operate on alias, handled by suffix codegen
+                assert x.source is not None
                old_source = x.source
                overridden_sources[old_source] = LocalSource(visited[list_idx])

@@ -1864,7 +1866,6 @@ class OutputGraph(OutputGraphCommon):
                and isinstance(var.value, _ExportModuleSpecTrackerDict)
            ):
                potential_side_effects.append(var)

        side_effect_refs = [
            _get_source_debug_name(var.source) for var in potential_side_effects
        ]
@@ -258,6 +258,7 @@ class SideEffects:
                    "Dynamo needs to fully exhaust the generator, which may cause "
                    "unintended variable modifications."
                )
+            assert item.mutation_type is not None
            if not is_side_effect_safe(item.mutation_type):
                # TODO plumb HOP information here
                unimplemented_v2(
@@ -373,7 +374,7 @@ class SideEffects:

        if self.is_attribute_mutation(item):
            return item in self.store_attr_mutations

        assert item.mutation_type is not None
        return item.mutation_type.is_modified  # type: ignore[attr-defined]

    def _track_obj(
@@ -111,9 +111,9 @@ def is_constant_source(source: Source) -> bool:
    return False


-def _get_source_debug_name(source: Source) -> str:
+def _get_source_debug_name(source: Optional[Source]) -> str:
    try:
-        return source.name()
+        return source.name()  # type: ignore[union-attr]
    except NotImplementedError:
        return "<unknown source>"
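The change above lets `_get_source_debug_name` accept `Optional[Source]` and silences the resulting union-attr complaint rather than adding a None branch. A small self-contained sketch of the same "best-effort debug name" pattern; `FakeSource` is a made-up stand-in used only for illustration.

```python
from typing import Optional


class FakeSource:
    """Illustrative stand-in for a Dynamo Source object."""

    def __init__(self, name: Optional[str]) -> None:
        self._name = name

    def name(self) -> str:
        if self._name is None:
            raise NotImplementedError
        return self._name


def get_source_debug_name(source: Optional[FakeSource]) -> str:
    try:
        # The signature admits None for convenience at call sites; only the
        # NotImplementedError fallback is handled here, as in the diff.
        return source.name()  # type: ignore[union-attr]
    except NotImplementedError:
        return "<unknown source>"


print(get_source_debug_name(FakeSource("L['model'].weight")))
print(get_source_debug_name(FakeSource(None)))  # -> <unknown source>
```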
@@ -5024,7 +5024,7 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
        ):
            if isinstance(val, ConstantVariable) and val.value is None:
                try:
-                    val = tos.next_variable(self)
+                    val = tos.next_variable(self)  # type: ignore[arg-type]
                except (StopIteration, exc.ObservedUserStopIteration) as ex:
                    # To implement SEND, we have to look at the implementation
                    # when the iterator returns StopIteration. This translates to this code
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 Core variable tracking functionality for Dynamo. This module defines the fundamental
 classes and systems used to track and manage variables during Dynamo's operation.
@@ -20,6 +18,9 @@ from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView
 from enum import Enum
 from typing import Any, Optional, TYPE_CHECKING

+from torch._guards import Guard
+from torch.fx.proxy import Node
+
 from .. import graph_break_hints, variables
 from ..current_scope_id import current_scope_id
 from ..exc import raise_observed_exception, unimplemented_v2
@@ -30,7 +31,6 @@ from ..utils import cmp_name_to_op_mapping, istype

 if TYPE_CHECKING:
     from ..codegen import PyCodegen
-    from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase


 class SourceType(Enum):
@@ -115,10 +115,10 @@ class ValueMutationNew(MutationType):
    def __init__(self) -> None:
        super().__init__(SourceType.New)

-    def __hash__(self):
+    def __hash__(self) -> int:
        return id(self)

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
        return self is other


@@ -139,7 +139,7 @@ class ValueMutationExisting(MutationType):
    # filter out which pre-existing values it needs to generate mutation for.
    is_modified: bool

-    def __init__(self, is_modified: bool = False):
+    def __init__(self, is_modified: bool = False) -> None:
        super().__init__(SourceType.Existing)
        self.is_modified = is_modified

@@ -150,7 +150,7 @@ class AttributeMutation(MutationType):
    allows mutation on the value's attributes.
    """

-    def __init__(self, typ: SourceType):
+    def __init__(self, typ: SourceType) -> None:
        super().__init__(typ)


@@ -166,7 +166,7 @@ class AttributeMutationExisting(AttributeMutation):
    be used afterwards in Python.
    """

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(SourceType.Existing)


@@ -182,16 +182,16 @@ class AttributeMutationNew(AttributeMutation):
    the Python world.
    """

-    def __init__(self, cls_source: Optional[Source] = None):
+    def __init__(self, cls_source: Optional[Source] = None) -> None:
        super().__init__(SourceType.New)
        self.cls_source = cls_source


-def _is_top_level_scope(scope_id):
+def _is_top_level_scope(scope_id: int) -> bool:
    return scope_id == 1


-def is_side_effect_safe(m: MutationType):
+def is_side_effect_safe(m: MutationType) -> bool:
    scope_id = current_scope_id()

    # In the top-level scope (if no HigherOrderOperators are involved),
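The `ValueMutationNew` annotations above follow the standard object-protocol signatures: `__hash__` returns `int` and `__eq__` takes `object`. A minimal sketch of the same identity-based equality, using a throwaway class rather than Dynamo's mutation types:

```python
class MutationToken:
    """Each instance is equal only to itself, so it can key a dict safely."""

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: object) -> bool:
        return self is other


a, b = MutationToken(), MutationToken()
assert a == a and a != b
assert len({a: 1, b: 2}) == 2
```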
@@ -209,15 +209,15 @@ def is_side_effect_safe(m: MutationType):
 class AsPythonConstantNotImplementedError(NotImplementedError):
    vt: "VariableTracker"

-    def __init__(self, vt: "VariableTracker"):
+    def __init__(self, vt: "VariableTracker") -> None:
        super().__init__(f"{vt} is not a constant")
        self.vt = vt


 class VariableTrackerMeta(type):
-    all_subclasses = []
+    all_subclasses: list[type] = []

-    def __instancecheck__(cls, instance) -> bool:
+    def __instancecheck__(cls: type, instance: object) -> bool:
        """Make isinstance work with LazyVariableTracker"""
        # This is super expensive - just having it costs over 4% of tracing
        # time!
@@ -227,8 +227,10 @@ class VariableTrackerMeta(type):
            instance = instance.realize()
        return type.__instancecheck__(cls, instance)

-    def __init__(cls, name, bases, attrs) -> None:
-        super().__init__(name, bases, attrs)
+    def __init__(
+        cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
+    ) -> None:
+        super().__init__(name, bases, attrs)  # type: ignore[misc]
        VariableTrackerMeta.all_subclasses.append(cls)
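The `VariableTrackerMeta.__init__` hunk annotates the conventional metaclass signature `(cls, name, bases, attrs)` while keeping the subclass registry. A standalone sketch of the same registry idea, under made-up names:

```python
from typing import Any


class RegistryMeta(type):
    all_subclasses: list[type] = []

    def __init__(
        cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
    ) -> None:
        super().__init__(name, bases, attrs)
        # Every class created with this metaclass records itself here.
        RegistryMeta.all_subclasses.append(cls)


class Base(metaclass=RegistryMeta):
    pass


class Child(Base):
    pass


print([c.__name__ for c in RegistryMeta.all_subclasses])  # ['Base', 'Child']
```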
@@ -252,7 +254,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
        "user_code_variable_name",
    }

-    def clone(self, **kwargs):
+    def clone(self, **kwargs: Any) -> "VariableTracker":
        """Shallow copy with some (optional) changes"""
        args = dict(self.__dict__)
        args.update(kwargs)
@@ -295,14 +297,14 @@ class VariableTracker(metaclass=VariableTrackerMeta):
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

-    def debug_repr(self):
+    def debug_repr(self) -> str:
        # Intended to be overridden to provide more info
        try:
            return repr(self.as_python_constant())
        except NotImplementedError:
            return repr(self)

-    def python_type(self):
+    def python_type(self) -> type:
        """
        Abstract method to be implemented by subclasses of VariableTracker.

@@ -331,17 +333,17 @@ class VariableTracker(metaclass=VariableTrackerMeta):
        except NotImplementedError:
            raise NotImplementedError(f"{self} has no type") from None

-    def python_type_name(self):
+    def python_type_name(self) -> str:
        try:
            return self.python_type().__name__
        except NotImplementedError:
            return "<unknown type>"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        """For constants"""
        raise AsPythonConstantNotImplementedError(self)

-    def guard_as_python_constant(self):
+    def guard_as_python_constant(self) -> Any:
        """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
        try:
            return self.as_python_constant()
@@ -353,23 +355,25 @@ class VariableTracker(metaclass=VariableTrackerMeta):
            hints=[],
        )

-    def is_python_constant(self):
+    def is_python_constant(self) -> bool:
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

-    def make_guard(self, fn):
+    def make_guard(self, fn: Callable[..., Any]) -> Guard:
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError

-    def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
+    # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase`
+    # and cascade that (large blast radius)
+    def const_getattr(self, tx: Any, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError

-    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+    def var_getattr(self, tx: Any, name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable"""
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
@@ -381,17 +385,17 @@ class VariableTracker(metaclass=VariableTrackerMeta):
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
        return variables.ConstantVariable.create(value, source=source)

-    def is_proxy(self):
+    def is_proxy(self) -> bool:
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

-    def as_proxy(self):
+    def as_proxy(self) -> Any:
        raise NotImplementedError(str(self))

-    def maybe_fx_node(self):
+    def maybe_fx_node(self) -> Optional[Node]:
        try:
            proxy = self.as_proxy()
            import torch.fx
@@ -402,13 +406,13 @@ class VariableTracker(metaclass=VariableTrackerMeta):
        except NotImplementedError:
            return None

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        raise NotImplementedError

-    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
        raise NotImplementedError

-    def force_unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
        # like unpack_var_sequence, but should only be used when it is
        # safe to eagerly (vs. lazily) unpack this variable.
        # e.g. map(f, x) is normally evaluated lazily but sometimes
@@ -417,7 +421,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
        # it should only be called once.
        return self.unpack_var_sequence(tx)

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: Any) -> bool:
        try:
            self.unpack_var_sequence(tx)
            return True
@@ -425,13 +429,15 @@ class VariableTracker(metaclass=VariableTrackerMeta):
            return False

    # NB: don't call force_unpack_var_sequence, especially if it mutates!
-    def has_force_unpack_var_sequence(self, tx) -> bool:
+    def has_force_unpack_var_sequence(self, tx: Any) -> bool:
        return self.has_unpack_var_sequence(tx)

    # Forces unpacking the var sequence while also applying a function to each element.
    # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence).
    # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True!
-    def force_apply_to_var_sequence(self, tx, fn) -> None:
+    def force_apply_to_var_sequence(
+        self, tx: Any, fn: Callable[["VariableTracker"], Any]
+    ) -> None:
        assert self.has_force_unpack_var_sequence(tx)
        for v in self.unpack_var_sequence(tx):
            fn(v)
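Several of the `-> bool` annotations above (`is_python_constant`, `is_proxy`, `has_unpack_var_sequence`) wrap the same probe pattern: attempt the operation and translate `NotImplementedError` into `False`. A generic, self-contained sketch of that pattern:

```python
from typing import Any


class Node:
    def as_constant(self) -> Any:
        raise NotImplementedError

    def is_constant(self) -> bool:
        # Probe the abstract operation; a NotImplementedError means "no".
        try:
            self.as_constant()
            return True
        except NotImplementedError:
            return False


class IntNode(Node):
    def as_constant(self) -> int:
        return 42


print(Node().is_constant(), IntNode().is_constant())  # False True
```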
@@ -444,9 +450,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
            hints=[],
        )

-    def call_obj_hasattr(
-        self, tx: "InstructionTranslator", name: str
-    ) -> "VariableTracker":
+    def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker":
        unimplemented_v2(
            gb_type="Unsupported hasattr call",
            context=f"call_obj_hasattr {self} {name}",
@@ -459,9 +463,9 @@ class VariableTracker(metaclass=VariableTrackerMeta):

    def call_function(
        self,
-        tx: "InstructionTranslator",
+        tx: Any,
        args: Sequence["VariableTracker"],
-        kwargs: "dict[str, VariableTracker]",
+        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        unimplemented_v2(
            gb_type="Unsupported function call",
@@ -475,10 +479,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):

    def call_method(
        self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
+        tx: Any,
+        name: str,
+        args: list["VariableTracker"],
+        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
@@ -562,7 +566,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
            hints=hints,
        )

-    def set_name_hint(self, name):
+    def set_name_hint(self, name: str) -> None:
        pass

    def realize(self) -> "VariableTracker":
@@ -573,11 +577,11 @@ class VariableTracker(metaclass=VariableTrackerMeta):
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

-    def is_realized(self):
+    def is_realized(self) -> bool:
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

-    def next_variable(self, tx):
+    def next_variable(self, tx: Any) -> "VariableTracker":
        unimplemented_v2(
            gb_type="Unsupported next() call",
            context=f"next({self})",
@@ -585,20 +589,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
            hints=[*graph_break_hints.USER_ERROR],
        )

-    def is_strict_mode(self, tx):
-        return tx.strict_checks_fn and tx.strict_checks_fn(self)
+    def is_strict_mode(self, tx: Any) -> bool:
+        return bool(tx.strict_checks_fn and tx.strict_checks_fn(self))

-    def is_mutable(self):
+    def is_mutable(self) -> bool:
        """Whether Dynamo allows mutation on this variable."""
        return not self.is_immutable()

-    def is_immutable(self):
+    def is_immutable(self) -> bool:
        """Whether Dynamo bans mutation on this variable."""
        return self.mutation_type is None

    @staticmethod
    def build(
-        tx: "InstructionTranslatorBase",
+        tx: Any,
        value: Any,
        source: Optional[Source] = None,
    ) -> Any:
@@ -611,8 +615,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
    def __init__(
        self,
        *,
-        source: Source = None,
-        mutation_type: MutationType = None,
+        source: Optional[Source] = None,
+        mutation_type: Optional[MutationType] = None,
    ) -> None:
        super().__init__()
        self.source = source
@@ -636,12 +640,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
        assert source is not None


-def raise_type_error_exc(tx: "InstructionTranslator", msg_str: str) -> None:
+def raise_type_error_exc(tx: Any, msg_str: str) -> None:
    msg = variables.ConstantVariable.create(msg_str)
    raise_observed_exception(TypeError, tx, args=[msg])


-def typestr(*objs):
+def typestr(*objs: object) -> str:
    if len(objs) == 1:
        (obj,) = objs
        if isinstance(obj, VariableTracker):
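The `__init__` hunk above replaces `source: Source = None` with `source: Optional[Source] = None`. Under PEP 484, the implicit-Optional form is rejected by current mypy, so a `None` default needs an explicit `Optional` (or `X | None`). A tiny illustration with a placeholder class:

```python
from typing import Optional


class Source:  # placeholder class, only for the example
    pass


# def f(source: Source = None) -> str: ...   # mypy: incompatible default for "source"
def f(source: Optional[Source] = None) -> str:
    return "no source" if source is None else "has source"


print(f(), f(Source()))
```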
@@ -303,7 +303,7 @@ except ModuleNotFoundError:

 if TYPE_CHECKING:
    from torch._dynamo.codegen import PyCodegen
-    from torch._dynamo.symbolic_convert import InstructionTranslator
+    from torch._dynamo.symbolic_convert import InstructionTranslatorBase


 log = logging.getLogger(__name__)
@@ -2616,7 +2616,7 @@ class VariableBuilder:
            fake_tensor_value = example_value
            assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
                f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
-                "({self.tx.fake_mode}) from InstructionTranslator"
+                "({self.tx.fake_mode}) from InstructionTranslatorBase"
            )

        # There's something a bit incoherent about pass_arg_as_tensor,
@@ -2700,7 +2700,7 @@ class VariableBuilder:
            fake_tensor_value = example_value
            assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
                f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
-                "({self.tx.fake_mode}) from InstructionTranslator"
+                "({self.tx.fake_mode}) from InstructionTranslatorBase"
            )

        proxy.node.meta["grapharg"] = GraphArg(
@@ -3268,7 +3268,7 @@ def is_dynamic_source(source_name: str) -> bool:


 def record_automatic_dynamic(
-    tx: "InstructionTranslator", name: str, e: torch.Tensor
+    tx: "InstructionTranslatorBase", name: str, e: torch.Tensor
 ) -> FrameStateSizeEntry:
    # This mimics stride inference algorithm in _create_symbolic_sizes_strides_storage_offset
    ex_size = e.size()
@@ -3722,7 +3722,7 @@ class SourcelessBuilder:
        raise AssertionError("Use SourcelessBuilder.create()")

    @staticmethod
-    def create(tx: "InstructionTranslator", value) -> VariableTracker:
+    def create(tx: "InstructionTranslatorBase", value) -> VariableTracker:
        value_type = type(value)
        fast_handler = SourcelessBuilder._type_handlers.get(value_type)
        if fast_handler:
@@ -3868,7 +3868,7 @@ class SourcelessBuilder:
            )
        )

-    def passthrough(tx: "InstructionTranslator", value):
+    def passthrough(tx: "InstructionTranslatorBase", value):
        return value

    for cls in VariableTrackerMeta.all_subclasses:
@@ -3890,7 +3890,7 @@ class SourcelessUserDefinedObjectBuilder:
        raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()")

    @staticmethod
-    def create(tx: "InstructionTranslator", value) -> VariableTracker:
+    def create(tx: "InstructionTranslatorBase", value) -> VariableTracker:
        value_type = type(value)
        if issubclass(value_type, MutableMapping):
            return MutableMappingVariable(value, mutation_type=ValueMutationNew())
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 Constant and enum variable tracking in Dynamo.

@@ -8,8 +6,9 @@ values during compilation, ensuring proper handling of Python literals and
 maintaining type safety through the compilation process.
 """

+import enum
 import operator
-from typing import TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING, Union

 import torch
 from torch._dynamo.source import AttrSource, GetItemSource
@@ -40,7 +39,7 @@ class ConstantVariable(VariableTracker):
    """

    @staticmethod
-    def create(value, **kwargs) -> VariableTracker:
+    def create(value: Any, **kwargs: Any) -> VariableTracker:
        """
        Create a `ConstantVariable` based on the given value, and supports
        automatic routing for collection types like `tuple` (in which case we'd
@@ -76,7 +75,7 @@ class ConstantVariable(VariableTracker):

        return ConstantVariable(value, **kwargs)

-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Any, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        assert ConstantVariable.is_base_literal(value), f"""
Cannot construct `ConstantVariable` for value of type {type(value)}.
@@ -92,48 +91,52 @@ its type to `common_constant_types`.
        else:
            self.value = value

-    def as_proxy(self):
+    def as_proxy(self) -> Any:
        return self.value

    def __repr__(self) -> str:
        return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        return self.value

-    def is_python_constant(self):
+    def is_python_constant(self) -> bool:
        return True

    @property
-    def items(self):
+    def items(self) -> list[VariableTracker]:
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

-    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
+    def getitem_const(
+        self, tx: "InstructionTranslator", arg: VariableTracker
+    ) -> VariableTracker:
        return ConstantVariable.create(
            self.value[arg.as_python_constant()],
        )

    @staticmethod
-    def is_base_literal(obj):
+    def is_base_literal(obj: object) -> bool:
        return type(obj) in common_constant_types

    @staticmethod
-    def is_literal(obj):
+    def is_literal(obj: object) -> bool:
        if type(obj) in (list, tuple, set, frozenset, torch.Size):
-            return all(ConstantVariable.is_literal(x) for x in obj)
+            return all(ConstantVariable.is_literal(x) for x in obj)  # type: ignore[attr-defined]
        return ConstantVariable.is_base_literal(obj)

-    def unpack_var_sequence(self, tx):
+    def unpack_var_sequence(
+        self, tx: Optional["InstructionTranslator"]
+    ) -> list[VariableTracker]:
        try:
            return [ConstantVariable.create(x) for x in self.as_python_constant()]
        except TypeError as e:
            raise NotImplementedError from e

-    def const_getattr(self, tx: "InstructionTranslator", name):
+    def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if not hasattr(self.value, name):
            raise_observed_exception(AttributeError, tx, args=[name])
        member = getattr(self.value, name)
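`ConstantVariable.is_literal` above recurses through containers and bottoms out at a fixed set of base literal types. A simplified, self-contained version of the same check; the set of base types below is illustrative, not Dynamo's actual `common_constant_types`.

```python
BASE_LITERAL_TYPES = {int, float, bool, str, bytes, type(None)}


def is_base_literal(obj: object) -> bool:
    return type(obj) in BASE_LITERAL_TYPES


def is_literal(obj: object) -> bool:
    # Containers count as literals only if every element does.
    if type(obj) in (list, tuple, set, frozenset):
        return all(is_literal(x) for x in obj)
    return is_base_literal(obj)


print(is_literal((1, "a", [2.0, None])))  # True
print(is_literal((1, object())))          # False
```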
@@ -144,10 +147,10 @@ its type to `common_constant_types`.
    def call_method(
        self,
        tx: "InstructionTranslator",
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
        from .tensor import SymNodeVariable

        if name == "format" and istype(self.value, str):
@@ -254,7 +257,7 @@ its type to `common_constant_types`.

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
-    ) -> "VariableTracker":
+    ) -> VariableTracker:
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

@@ -266,12 +269,14 @@ class EnumVariable(VariableTracker):
    both standard Enum and IntEnum with proper value tracking and comparison.
    """

-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
-    def create(cls, cls_type, value_vt, options):
+    def create(
+        cls, cls_type: Any, value_vt: VariableTracker, options: Any
+    ) -> "EnumVariable":
        if isinstance(value_vt, variables.ConstantVariable):
            for member in list(cls_type):
                if member.value == value_vt.as_python_constant():
@@ -285,7 +290,7 @@ class EnumVariable(VariableTracker):
            hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE],
        )

-    def as_proxy(self):
+    def as_proxy(self) -> Union[enum.Enum, int]:
        if isinstance(self.value, int):
            return int(self.value)  # convert IntEnum to a normal int
        return self.value
@@ -293,10 +298,10 @@ class EnumVariable(VariableTracker):
    def __repr__(self) -> str:
        return f"EnumVariable({type(self.value)})"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]:
        return self.value

-    def var_getattr(self, tx: "InstructionTranslator", name):
+    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if not hasattr(self.value, name):
            raise NotImplementedError
        if name in cmp_name_to_op_mapping:
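`EnumVariable.create` above reconstructs an enum member by scanning `list(cls_type)` for a matching `.value`. The core of that lookup, outside of Dynamo:

```python
import enum


class Color(enum.IntEnum):
    RED = 1
    GREEN = 2


def member_from_value(cls_type: type[enum.Enum], value: object) -> enum.Enum:
    # Walk the members and match on the underlying value, as the diff does.
    for member in list(cls_type):
        if member.value == value:
            return member
    raise ValueError(f"{value!r} is not a valid {cls_type.__name__}")


print(member_from_value(Color, 2))  # Color.GREEN
```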
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 Distributed computing variable tracking classes for PyTorch Dynamo.

@@ -22,7 +20,7 @@ checks and proper tracking of distributed state and operations across processes.

 import functools
 import inspect
-from typing import TYPE_CHECKING
+from typing import Any, Sequence, TYPE_CHECKING

 import torch
 from torch.fx.experimental._backward_state import BackwardState
@@ -40,6 +38,7 @@ from .constant import ConstantVariable, EnumVariable


 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
    from torch._dynamo.symbolic_convert import InstructionTranslator


@@ -54,7 +53,7 @@ class DistributedVariable(VariableTracker):
    and hold the tracking value for the corresponding distributed object.
    """

-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Any, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented_v2(
@@ -67,16 +66,16 @@ class DistributedVariable(VariableTracker):
            )
        self.value = value

-    def python_type(self):
+    def python_type(self) -> type:
        return type(self.value)

    @staticmethod
-    def is_available():
+    def is_available() -> bool:
        # check if the distributed package is available or not
        return torch.distributed.is_available()


-def is_from_local(value):
+def is_from_local(value: Any) -> bool:
    if not DistributedVariable.is_available():
        return False
    from torch.distributed.tensor import DTensor
@@ -84,7 +83,7 @@ def is_from_local(value):
    return inspect.isfunction(value) and value is DTensor.from_local


-def is_constant_pg_functions(value):
+def is_constant_pg_functions(value: Any) -> bool:
    if not DistributedVariable.is_available():
        return False

@@ -114,7 +113,7 @@ class WorldMetaClassVariable(DistributedVariable):
    """

    @classmethod
-    def is_group_member_type(cls, value):
+    def is_group_member_type(cls, value: Any) -> bool:
        if not cls.is_available():
            return False

@@ -124,10 +123,12 @@ class WorldMetaClassVariable(DistributedVariable):

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "WORLD":
+            assert self.source
            source = AttrSource(base=self.source, member="WORLD")
            install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            return ProcessGroupVariable(self.value.WORLD)
        elif name == "NON_GROUP_MEMBER":
+            assert self.source
            source = AttrSource(base=self.source, member="NON_GROUP_MEMBER")
            install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            return EnumVariable(self.value.NON_GROUP_MEMBER)
@@ -136,7 +137,7 @@ class WorldMetaClassVariable(DistributedVariable):

 class PlacementClassVariable(DistributedVariable):
    @staticmethod
-    def is_placement_type(value):
+    def is_placement_type(value: Any) -> bool:
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
@@ -145,15 +146,15 @@ class PlacementClassVariable(DistributedVariable):

        return isinstance(value, type) and issubclass(value, Placement)

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        args: Sequence[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
        if self.source:
            # NOTE: we don't need to track mutations to the placement class as they
            # are supposed to be immutable.
@@ -168,16 +169,15 @@ class PlacementClassVariable(DistributedVariable):

 class PlacementVariable(DistributedVariable):
    @staticmethod
-    def is_placement(value):
+    def is_placement(value: Any) -> bool:
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.tensor.placement_types import Placement

        return isinstance(value, Placement)

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
@@ -187,11 +187,11 @@ class PlacementVariable(DistributedVariable):

    def call_method(
        self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        tx: "InstructionTranslator",
+        name: str,
+        args: Sequence[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
        from . import ConstantVariable

        # Placement types dynamo tracking only allows following methods
@@ -221,15 +221,16 @@ class PlacementVariable(DistributedVariable):

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
+            assert method is not None
            if name == "__setattr__":
                method(self.value, *args, **kwargs)
                return self
            constant_val = method(self.value, *args, **kwargs)
            return ConstantVariable.create(constant_val)

-        return super().call_method(tx, name, args, kwargs)
+        return super().call_method(tx, name, args, kwargs)  # type: ignore[arg-type]

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Reconstruct the Placement object by calling its constructor
        # e.g., Shard(0), Replicate(), Partial()
        from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
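The `PlacementVariable.call_method` hunk above keeps the existing constant-folding strategy: fetch the real method, call it on the tracked placement object, and wrap the result back up (with `__setattr__` returning the variable itself). A rough sketch of that fold, with a toy value class standing in for a DTensor `Placement`:

```python
from typing import Any


class ToyPlacement:  # hypothetical stand-in for a Placement object
    dim = 0

    def is_shard(self) -> bool:
        return True


def fold_method(value: Any, name: str, *args: Any, **kwargs: Any) -> Any:
    method = getattr(type(value), name, None)
    assert method is not None
    if name == "__setattr__":
        method(value, *args, **kwargs)
        return value  # mutation: hand the same tracked object back
    # Otherwise the call is treated as side-effect free and the result
    # can be wrapped as a constant.
    return method(value, *args, **kwargs)


p = ToyPlacement()
print(fold_method(p, "is_shard"))                    # True
print(fold_method(p, "__setattr__", "dim", 1).dim)   # 1
```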
@@ -256,7 +257,7 @@ class PlacementVariable(DistributedVariable):

 class DeviceMeshVariable(DistributedVariable):
    @staticmethod
-    def is_device_mesh(value):
+    def is_device_mesh(value: Any) -> bool:
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
@@ -265,7 +266,7 @@ class DeviceMeshVariable(DistributedVariable):

        return istype(value, DeviceMesh)

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
@@ -282,11 +283,11 @@ class DeviceMeshVariable(DistributedVariable):

    def call_method(
        self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        tx: "InstructionTranslator",
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
        if name == "size":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
@@ -331,16 +332,16 @@ class ProcessGroupVariable(DistributedVariable):
    or just graph-break whenever one of our special cases is not hit?
    """

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        return self.value

    def call_method(
        self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        tx: "InstructionTranslator",
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
@@ -350,7 +351,7 @@ class ProcessGroupVariable(DistributedVariable):

        return super().call_method(tx, name, args, kwargs)

-    def var_getattr(self, tx: "InstructionTranslator", name):
+    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name in ["rank", "size"]:
@@ -361,7 +362,7 @@ class ProcessGroupVariable(DistributedVariable):
        return super().var_getattr(tx, name)

    @staticmethod
-    def is_process_group(value):
+    def is_process_group(value: Any) -> bool:
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
@@ -379,11 +380,11 @@ class BackwardHookVariable(VariableTracker):

    @staticmethod
    def create(
-        tx,
+        tx: "InstructionTranslator",
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
-    ):
+    ) -> "BackwardHookVariable":
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented_v2(
                gb_type="Module-level backwards hooks require compiled autograd.",
@@ -394,7 +395,9 @@ class BackwardHookVariable(VariableTracker):
                ],
            )

-        def _in_graph_bw_hooks(bw_state: BackwardState):
+        def _in_graph_bw_hooks(
+            bw_state: BackwardState,
+        ) -> torch.utils.hooks.BackwardHook:
            """
            Rather than installing the user hooks in the graph (which
            don't survive AotAutograd), we install hooks that will call
@@ -441,7 +444,7 @@ class BackwardHookVariable(VariableTracker):
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
-        **options,
+        **options: Any,
    ) -> None:
        super().__init__(**options)
        self.proxy = proxy
@@ -449,13 +452,13 @@ class BackwardHookVariable(VariableTracker):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

-    def as_proxy(self):
+    def as_proxy(self) -> torch.fx.Proxy:
        return self.proxy

    def call_method(
        self,
-        tx,
-        name,
+        tx: "InstructionTranslator",
+        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
@@ -463,7 +466,9 @@ class BackwardHookVariable(VariableTracker):
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

-    def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
+    def _setup_hook(
+        self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker
+    ) -> VariableTracker:
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides iterator-related variable tracking functionality for Dynamo.
 It implements variable classes for handling Python iterators and itertools functions
@@ -16,7 +14,8 @@ handling of iterator operations during code transformation and optimization.
 """

 import itertools
-from typing import TYPE_CHECKING, Union
+from collections.abc import Callable
+from typing import Any, Sequence, TYPE_CHECKING, Union

 from .. import graph_break_hints, polyfills, variables
 from ..bytecode_transformation import (
@@ -45,20 +44,20 @@ MAX_ITERATOR_LIMIT = 100 * 1024  # 100k


 class ItertoolsVariable(VariableTracker):
-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Any, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self) -> str:
        return f"ItertoolsVariable({self.value})"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
+        args: Sequence["VariableTracker"],
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # See also: module `torch._dynamo.polyfills.itertools`
@@ -111,7 +110,7 @@ class ItertoolsVariable(VariableTracker):
                hints=[*graph_break_hints.USER_ERROR],
            )

-            def retrieve_const_key(key):
+            def retrieve_const_key(key: VariableTracker) -> Any:
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
@@ -144,14 +143,14 @@ class ItertoolsVariable(VariableTracker):

            if "key" in kwargs:

-                def keyfunc(x):
+                def keyfunc(x: VariableTracker) -> Any:
                    return retrieve_const_key(
-                        kwargs.get("key").call_function(tx, [x], {})
+                        kwargs.get("key").call_function(tx, [x], {})  # type: ignore[union-attr]
                    )

            else:

-                def keyfunc(x):
+                def keyfunc(x: VariableTracker) -> Any:
                    return retrieve_const_key(x)

            result = []
@@ -219,10 +218,10 @@ class ItertoolsVariable(VariableTracker):


 class IteratorVariable(VariableTracker):
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
        unimplemented_v2(
            gb_type="Unimplemented next() call",
            context=f"next({self})",
@@ -234,12 +233,16 @@ class IteratorVariable(VariableTracker):
    # Normally, iterators are accessed lazily.
    # Example of safe eager unpacking: list(map(f, seq))
    # Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
-    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
-        result = []
+    def force_unpack_var_sequence(
+        self, tx: "InstructionTranslator"
+    ) -> list[VariableTracker]:
+        result: list[VariableTracker] = []
        self.force_apply_to_var_sequence(tx, result.append)
        return result

-    def force_apply_to_var_sequence(self, tx, fn) -> None:
+    def force_apply_to_var_sequence(
+        self, tx: "InstructionTranslator", fn: Callable[[Any], Any]
+    ) -> None:
        while True:
            try:
                fn(self.next_variable(tx))
@@ -249,7 +252,7 @@ class IteratorVariable(VariableTracker):

    # don't call force_unpack_var_sequence since it can mutate
    # IteratorVariable state!
-    def has_force_unpack_var_sequence(self, tx) -> bool:
+    def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
        return True
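The comments above distinguish safe from unsafe eager unpacking of iterators (`list(map(f, seq))` versus `list(islice(map(f, seq), 5))`). The sketch below shows why forcing the second case is a problem: eagerly draining the underlying `map` applies `f` to elements that `islice` would never have requested.

```python
from itertools import islice

calls = []


def f(x: int) -> int:
    calls.append(x)
    return x * 10


seq = range(100)

# Safe to force: list() would consume everything map() produces anyway.
list(map(f, seq))
assert len(calls) == 100

calls.clear()
# Unsafe to force: only 5 elements should ever be touched, so draining
# the inner map() eagerly would over-apply f (and never finish if the
# underlying iterator were infinite).
list(islice(map(f, seq), 5))
assert len(calls) == 5
```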
@@ -267,12 +270,12 @@ class ObjectIteratorVariable(IteratorVariable):
    > list(b)  # empty list
    """

-    def __init__(self, obj: VariableTracker, **kwargs):
+    def __init__(self, obj: VariableTracker, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.obj = obj
        self.generator_exhausted = False

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
        if self.generator_exhausted:
            raise_observed_exception(StopIteration, tx)

@@ -286,15 +289,15 @@ class ObjectIteratorVariable(IteratorVariable):


 class RepeatIteratorVariable(IteratorVariable):
-    def __init__(self, item: VariableTracker, **kwargs) -> None:
+    def __init__(self, item: VariableTracker, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.item = item

    # Repeat needs no mutation, clone self
-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
        return self.item

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
@@ -308,7 +311,12 @@ class RepeatIteratorVariable(IteratorVariable):


 class CountIteratorVariable(IteratorVariable):
-    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
+    def __init__(
+        self,
+        item: Union[int, VariableTracker] = 0,
+        step: Union[int, VariableTracker] = 1,
+        **kwargs: Any,
+    ) -> None:
        super().__init__(**kwargs)
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
@@ -317,14 +325,14 @@ class CountIteratorVariable(IteratorVariable):
        self.item = item
        self.step = step

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
        assert self.is_mutable()
        old_item = self.item
        tx.output.side_effects.mutation(self)
        self.item = self.item.call_method(tx, "__add__", [self.step], {})
        return old_item

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
@@ -353,7 +361,7 @@ class ZipVariable(IteratorVariable):
        self,
        iterables: list[VariableTracker],
        strict: bool = False,
-        **kwargs,
+        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
@@ -362,16 +370,18 @@ class ZipVariable(IteratorVariable):
        self.index = 0
        self.strict = strict

-    def python_type(self):
+    def python_type(self) -> type:
        return zip

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

-    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def unpack_var_sequence(
+        self, tx: "InstructionTranslator"
+    ) -> list["VariableTracker"]:
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
@@ -383,7 +393,7 @@ class ZipVariable(IteratorVariable):
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
        assert self.is_mutable()

        if len(self.iterables) == 0:
@@ -392,7 +402,9 @@ class ZipVariable(IteratorVariable):
        old_index = self.index
        args = []

-        def get_item(it):
+        def get_item(
+            it: Union[list[VariableTracker], VariableTracker],
+        ) -> VariableTracker:
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx)
@@ -400,6 +412,7 @@ class ZipVariable(IteratorVariable):
            else:
                return it.next_variable(tx)

+        idx = 0
        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
@@ -420,7 +433,7 @@ class ZipVariable(IteratorVariable):
                    raise
                handle_observed_exception(tx)
                raise UserError(
-                    ValueError,
+                    ValueError,  # type: ignore[arg-type]
                    "zip() has one argument of len differing from others",
                ) from None
            raise
@@ -429,7 +442,7 @@ class ZipVariable(IteratorVariable):
        self.index += 1
        return variables.TupleVariable(args)

-    def reconstruct_items(self, codegen: "PyCodegen"):
+    def reconstruct_items(self, codegen: "PyCodegen") -> None:
        for it in self.iterables:
            if isinstance(it, list):
                remaining_items = it[self.index :]
@@ -438,7 +451,7 @@ class ZipVariable(IteratorVariable):
            else:
                codegen(it)

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
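`ZipVariable.next_variable` above converts an observed `StopIteration` into the same `ValueError` that CPython raises for `zip(..., strict=True)` when inputs have different lengths. For reference, the eager behavior being emulated (requires Python 3.10+ for `strict`):

```python
a = [1, 2, 3]
b = [10, 20]

print(list(zip(a, b)))  # [(1, 10), (2, 20)] - silently truncates

try:
    list(zip(a, b, strict=True))
except ValueError as e:
    print(e)  # e.g. "zip() argument 2 is shorter than argument 1"
```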
@@ -462,23 +475,23 @@ class MapVariable(ZipVariable):
    def __init__(
        self,
        fn: VariableTracker,
-        iterables: list[Union[list[VariableTracker], VariableTracker]],
-        **kwargs,
+        iterables: list[VariableTracker],
+        **kwargs: Any,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

-    def python_type(self):
+    def python_type(self) -> type:
        return map

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
        return False

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
        args = super().next_variable(tx)
-        return self.fn.call_function(tx, args.items, {})
+        return self.fn.call_function(tx, args.items, {})  # type: ignore[attr-defined]

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
        )
@@ -505,23 +518,25 @@ class FilterVariable(IteratorVariable):
    def __init__(
        self,
        fn: VariableTracker,
-        iterable: Union[list[VariableTracker], VariableTracker],
-        **kwargs,
+        iterable: list[VariableTracker],
+        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.fn = fn
        self.iterable = iterable
        self.index = 0

-    def python_type(self):
+    def python_type(self) -> type:
        return filter

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
        return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence(
            tx
        )

-    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def unpack_var_sequence(
+        self, tx: "InstructionTranslator"
+    ) -> list["VariableTracker"]:
        assert self.has_unpack_var_sequence(tx)
        it = None
        if isinstance(self.iterable, list):
@@ -531,8 +546,8 @@ class FilterVariable(IteratorVariable):
        filtered = self.fn.call_function(tx, it, {})
        return [variables.TupleVariable([filtered])]

-    def next_variable(self, tx):
-        def _next():
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
+        def _next() -> VariableTracker:
            old_index = self.index
            if isinstance(self.iterable, list):
                if old_index >= len(self.iterable):
@@ -555,7 +570,7 @@ class FilterVariable(IteratorVariable):
            if pred_res.as_python_constant():
                return item

-    def reconstruct_items(self, codegen: "PyCodegen"):
+    def reconstruct_items(self, codegen: "PyCodegen") -> None:
        if isinstance(self.iterable, list):
            remaining_items = self.iterable[self.index :]
            codegen.foreach(remaining_items)
@@ -563,7 +578,7 @@ class FilterVariable(IteratorVariable):
        else:
            codegen(self.iterable)

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter"))
        codegen(self.fn)
        self.reconstruct_items(codegen)
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements variable tracking for PyTorch optimizers during Dynamo tracing.

@@ -24,9 +22,11 @@ optimizer-specific optimizations and safety guarantees.

 import logging
 import weakref
-from typing import TYPE_CHECKING
+from typing import Any, Iterable, Optional, TYPE_CHECKING

 import torch
+from torch._dynamo.variables.tensor import TensorVariable
+from torch._guards import Source
 from torch._logging import getArtifactLogger
 from torch.utils._pytree import tree_map_only

@@ -63,13 +63,14 @@ class GuardInstallException(Exception):
 perf_hint_log = getArtifactLogger(__name__, "perf_hints")


-def _is_static_for_cudagraphs(x):
+def _is_static_for_cudagraphs(x: Any) -> bool:
    from torch._inductor.cudagraph_trees import get_manager

    if x.is_cuda:
        manager = get_manager(x.device.index, False)
        is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
        if manager:
+            assert manager.current_node is not None
            return (
                is_static_address
                or manager.current_node._is_cuda_graph_recorded_tensor(x)
@@ -91,26 +92,30 @@ class OptimizerVariable(UserDefinedObjectVariable):

    def __init__(
        self,
-        value,
-        grad_to_source=None,
-        static_tensor_names=None,
-        tensor_to_source=None,
-        **kwargs,
+        value: torch.optim.Optimizer,
+        grad_to_source: Optional[dict[Any, GradSource]] = None,
+        static_tensor_names: Optional[set[str]] = None,
+        tensor_to_source: Optional[dict[torch.Tensor, Source]] = None,
+        **kwargs: Any,
    ) -> None:
        super().__init__(value, **kwargs)
        self.value: torch.optim.Optimizer = value
        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
+        tx: "InstructionTranslator",
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            if not hasattr(self.value, "_init_group"):
                # Fallback: if the optimizer does not have _init_group, trace normally
                return super().call_method(tx, name, args, kwargs)
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
@@ -135,11 +140,12 @@ class OptimizerVariable(UserDefinedObjectVariable):

        return super().call_method(tx, name, args, kwargs)

-    def var_getattr(self, tx: "InstructionTranslator", name):
+    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        # Note: this allows us to intercept the call in call_method
        # in the typical case, we return a UserMethodVariable
        # which will directly inline
        if name in ("_init_group", "step"):
+            assert self.source
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
@@ -153,7 +159,7 @@ class OptimizerVariable(UserDefinedObjectVariable):

        return super().var_getattr(tx, name)

-    def graph_break_if_pending_mutation(self, tx):
+    def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None:
        # If there are pending mutations on a parameter (due to using closure)
        # then we need to graph break to allow the python version of the parameter
        # to update, so that running _init_group will initialize the states with
@@ -167,12 +173,12 @@ class OptimizerVariable(UserDefinedObjectVariable):

        raise Unsupported("Pending mutation on parameter")

-    def _set_capturable(self, tx):
+    def _set_capturable(self, tx: "InstructionTranslator") -> None:
        from . import LazyVariableTracker

        # We only set capturable if params are on cuda
        # and the state is not initialized
-        def safe_to_set_capturable(group):
+        def safe_to_set_capturable(group: dict[str, Any]) -> bool:
            all_uninitialized = True
            all_gpu = True

@@ -199,10 +205,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
            )
            param_group_vt.items[key] = ConstantVariable.create(True)

-    def get_python_args(self, *args, **kwargs):
+    def get_python_args(
+        self, *args: Any, **kwargs: Any
+    ) -> tuple[list[Any], dict[str, Any]]:
        """Get python values equivalent to the variable tracker args"""

-        def map_arg(arg):
+        def map_arg(arg: Any) -> Any:
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
@@ -227,19 +235,19 @@ class OptimizerVariable(UserDefinedObjectVariable):
    # if this is the case, move it to the GPU
    # corresponding to the parameter
    # in most cases this is a no-op because the state is empty
-    def move_step_if_cpu(self):
+    def move_step_if_cpu(self) -> None:
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)

-    def map_sources_and_install_guards(self, tx):
+    def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None:
        from ..decorators import mark_static_address
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

-        def mark_static(x):
+        def mark_static(x: Any) -> None:
            mark_static_address(x, guard=True)

        tree_map_only(torch.Tensor, mark_static, self.value.state)
@ -252,12 +260,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
||||
)
|
||||
|
||||
state_source = self.source and AttrSource(self.source, "state")
|
||||
|
||||
state_vt = VariableTracker.build(tx, self.value.state, state_source)
|
||||
|
||||
# We need to realize the top level state dict to populate
|
||||
# the guard locals
|
||||
state_vt.realize()
|
||||
assert state_source is not None
|
||||
tx.output.guard_on_key_order.add(state_source)
|
||||
|
||||
# Populate self.grad_to_source and self.tensor_to_source so that we can
|
||||
@ -310,14 +318,14 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
||||
# Note: to avoid spam logs only warn if perf hint artifact is enabled
|
||||
# (NB: artifacts are only enabled at the debug or warning level)
|
||||
if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
|
||||
non_static_grads = [src.name() for src in non_static_grads]
|
||||
non_static_grad_names = [src.name() for src in non_static_grads]
|
||||
perf_hint_log.warning(
|
||||
(
|
||||
"Grad tensors %s will be copied during cudagraphs execution."
|
||||
"If using cudagraphs and the grad tensor addresses will be the same across runs,"
|
||||
" use torch._dynamo.decorators.mark_static_address to elide this copy.",
|
||||
),
|
||||
non_static_grads,
|
||||
non_static_grad_names,
|
||||
)
|
||||
|
||||
# We have to again iterate over the state dict to collect the
|
||||
@ -337,7 +345,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
||||
p_state_source, ConstDictKeySource(p_state_source, inner_idx)
|
||||
)
|
||||
|
||||
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
|
||||
def wrap_tensor(
|
||||
self, tx: "InstructionTranslator", tensor_value: torch.Tensor
|
||||
) -> TensorVariable:
|
||||
"""Wrap state tensor in a TensorVariable"""
|
||||
from ..decorators import mark_static_address
|
||||
|
||||
@ -364,8 +374,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
||||
return VariableTracker.build(tx, tensor_value, source)
|
||||
|
||||
def update_list_args(
|
||||
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
|
||||
):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: Iterable[VariableTracker],
|
||||
kwargs: Any,
|
||||
py_args: Iterable[Any],
|
||||
py_kwargs: Any,
|
||||
) -> None:
|
||||
"""Update the args and kwargs to the traced optimizer call"""
|
||||
for arg, py_arg in zip(args, py_args):
|
||||
if isinstance(arg, ListVariable):
|
||||
@ -380,13 +395,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
||||
source = arg.source and GetItemSource(arg.source, i)
|
||||
arg.items.append(VariableTracker.build(tx, val, source))
|
||||
|
||||
def create_finalizer(self, tx):
|
||||
def create_finalizer(self, tx: "InstructionTranslator") -> None:
|
||||
names_to_delete = self.static_tensor_names
|
||||
value = self.value
|
||||
tc = tx.output.tracing_context
|
||||
|
||||
def init_finalizer(gm):
|
||||
def clear_static_tensor_refs():
|
||||
def init_finalizer(gm: torch.fx.GraphModule) -> None:
|
||||
def clear_static_tensor_refs() -> None:
|
||||
for name in names_to_delete:
|
||||
gm._buffers.pop(name, None)
|
||||
gm._parameters.pop(name, None)
|
||||
|
||||
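Illustrative note (not part of the diff): the optimizer hunks above pass the typed `mark_static` callback to `tree_map_only`, which visits every `torch.Tensor` leaf of the nested optimizer state. A minimal, self-contained sketch of that pytree call, using a made-up `state` dict:

```python
# Illustrative only: tree_map_only applies the callback to every leaf whose
# type matches the first argument (torch.Tensor here), leaving other leaves alone.
import torch
from torch.utils._pytree import tree_map_only

state = {"exp_avg": torch.zeros(3), "step": torch.tensor(0.0)}  # made-up state dict
shapes: list[torch.Size] = []
tree_map_only(torch.Tensor, lambda t: shapes.append(t.shape), state)
print(shapes)  # [torch.Size([3]), torch.Size([])]
```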
@ -1,6 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs

"""
This module implements variable tracking for TorchScript objects during Dynamo tracing.

@ -22,8 +19,13 @@ by limiting operations to known-safe patterns and failing fast for unsafe usage.
"""

import functools
from collections.abc import Callable
from typing import Any, Iterable, TYPE_CHECKING, TypeGuard, TypeVar
from typing_extensions import ParamSpec

import torch
from torch._guards import Source
from torch.fx.proxy import Proxy

from .. import graph_break_hints
from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported
@ -31,10 +33,19 @@ from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable


def _raise_hard_error_if_graph_break(reason):
def deco(fn):
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator

_P = ParamSpec("_P")
_T = TypeVar("_T")


def _raise_hard_error_if_graph_break(
reason: str,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(fn)
def graph_break_as_hard_error(*args, **kwargs):
def graph_break_as_hard_error(*args: _P.args, **kwargs: _P.kwargs) -> _T:
try:
return fn(*args, **kwargs)
except Unsupported as e:
@ -49,26 +60,26 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
_fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {}

@classmethod
def is_matching_cls(cls, user_cls: type):
def is_matching_cls(cls, user_cls: type) -> TypeGuard[torch.ScriptObject]:
return issubclass(user_cls, torch.ScriptObject)

@staticmethod
def create(proxy, value, **options):
def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
return TorchScriptObjectVariable(proxy, value, **options)

def __init__(self, proxy, value, source, **kwargs) -> None:
def __init__(self, proxy: Proxy, value: Any, source: Source, **kwargs: Any) -> None:
super().__init__(value, **kwargs)
self.proxy = proxy
self.proxy.node.meta["example_value"] = value
self.source = source

def as_proxy(self):
def as_proxy(self) -> Proxy:
return self.proxy

@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx, name: str) -> VariableTracker:
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
from torch._higher_order_ops.torchbind import call_torchbind

from ..source import AttrSource
@ -95,7 +106,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
"Use method calls instead of attribute access.",
],
)

assert self.source is not None
return TorchHigherOrderOperatorVariable.make(
call_torchbind,
source=AttrSource(self.source, name),
@ -110,7 +121,13 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> VariableTracker:
unimplemented_v2(
gb_type="Weird method call on TorchScript object",
context=f"value={self.value}, method={name}",
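Illustrative note (not part of the diff): the `_raise_hard_error_if_graph_break` hunk above applies the standard `ParamSpec`/`TypeVar` recipe for typing a decorator factory so the wrapped function keeps its exact signature. A minimal sketch using the same imports; the decorator and function names here are invented for the example:

```python
import functools
from collections.abc import Callable
from typing import TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")


def reraise_as_runtime_error(reason: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # Decorator factory: takes a reason string and returns a decorator that
    # preserves the wrapped callable's parameter types (_P) and return type (_T).
    def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        @functools.wraps(fn)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            try:
                return fn(*args, **kwargs)
            except ValueError as e:
                raise RuntimeError(reason) from e

        return wrapper

    return deco


@reraise_as_runtime_error("parsing failed")
def parse_int(s: str) -> int:  # type checkers still see (s: str) -> int
    return int(s)
```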
@ -1,7 +1,9 @@
# mypy: ignore-errors

from inspect import getattr_static
from typing import TYPE_CHECKING
from typing import Any, Sequence, TYPE_CHECKING, TypeGuard

from torch._guards import Source
from torch.backends.cuda import SDPAParams
from torch.fx.proxy import Proxy

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
@ -29,9 +31,9 @@ class SDPAParamsVariable(VariableTracker):
This is a read-only container."""

@staticmethod
def create(tx: "InstructionTranslator", value, source):
from torch.backends.cuda import SDPAParams

def create(
tx: "InstructionTranslator", value: Any, source: Source
) -> VariableTracker:
from .torch import TorchInGraphFunctionVariable

params = [
@ -40,12 +42,14 @@ class SDPAParamsVariable(VariableTracker):
]
return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})

def __init__(self, proxy, param_vars, **kwargs) -> None:
def __init__(
self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any
) -> None:
self.proxy = proxy
self.param_vars = param_vars
super().__init__(**kwargs)

def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.source is None
assert self.param_vars is not None
codegen.add_push_null(
@ -54,7 +58,7 @@ class SDPAParamsVariable(VariableTracker):
codegen.foreach(self.param_vars)
codegen.extend_output(create_call_function(len(self.param_vars), False))

def as_proxy(self):
def as_proxy(self) -> Proxy:
return self.proxy

def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
@ -80,7 +84,5 @@ class SDPAParamsVariable(VariableTracker):
return wrap_fx_proxy(tx=tx, proxy=proxy)

@staticmethod
def is_sdpa_params(value):
from torch.backends.cuda import SDPAParams

def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]:
return value is SDPAParams
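Illustrative note (not part of the diff): `is_matching_cls` and `is_sdpa_params` above now return `TypeGuard[...]`, which lets a type checker narrow the argument's type in the branch where the predicate is true. A minimal sketch of the general pattern; `Config` is a made-up class for the example:

```python
from typing import Any, TypeGuard


class Config:
    def __init__(self, name: str) -> None:
        self.name = name


def is_config(value: Any) -> TypeGuard[Config]:
    # Returning True tells the type checker that `value` is a Config in that branch.
    return isinstance(value, Config)


def describe(value: Any) -> str:
    if is_config(value):
        return value.name  # `value` is narrowed to Config here
    return repr(value)
```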
@ -128,6 +128,7 @@ class StreamVariable(VariableTracker):
return ConstantVariable.create(NotImplemented)

if other.source:
assert self.source is not None
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type]

@ -1451,6 +1451,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
# constant fold functions need to be guarded.
if self.value in constant_fold_functions_need_guards:
assert self.source is not None
source = CallFunctionNoArgsSource(self.source)
install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH))
# constant fold

@ -1,5 +1,3 @@
# mypy: ignore-errors

"""TorchDynamo support for __torch_function__ tensor subclasses.

This module implements support for tensor subclasses with __torch_function__ overrides.
@ -31,7 +29,8 @@ import contextlib
import functools
import inspect
import operator
from typing import TYPE_CHECKING
from types import TracebackType
from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
@ -125,34 +124,134 @@ un_ops = [


banned_attrs = [
fn.__self__.__name__
fn.__self__.__name__ # type: ignore[attr-defined]
for fn in get_default_nowrap_functions()
if is_tensor_base_attr_getter(fn)
]


@functools.cache
def get_prev_stack_var_name():
def get_prev_stack_var_name() -> str:
from ..bytecode_transformation import unique_id

return unique_id("___prev_torch_function_mode_stack")


class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty: type[TorchFunctionMode]) -> bool:
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the function across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)

def __init__(
self,
value: Optional[TorchFunctionMode],
source: Optional[Source] = None,
**kwargs: Any,
):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source # type: ignore[assignment]

def reconstruct(self, codegen: "PyCodegen") -> None:
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)

def module_name(self) -> str:
return self.value.__module__

def fn_name(self) -> str:
return type(self.value).__name__

def python_type(self) -> type:
return type(self.value)

def call_torch_function(
self,
tx: "InstructionTranslator",
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> VariableTracker:
return call_torch_function(
tx,
get_torch_function_fn(tx, self), # type: ignore[arg-type]
fn,
types,
args,
kwargs,
)

def enter(self, tx: "InstructionTranslator") -> VariableTracker:
from .torch import TorchInGraphFunctionVariable

if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)

TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)

def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
from .torch import TorchInGraphFunctionVariable

TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)

def reconstruct_type(self, codegen: "PyCodegen") -> None:
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)

def supports_graph_breaks(self) -> bool:
return True

def exit_on_graph_break(self) -> bool:
return False


# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __init__(self) -> None:
self.stack: list[Any] = []

def __enter__(self):
def __enter__(self) -> None:
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
set_torch_function_mode_stack(self.stack)
self.stack = []

@contextlib.contextmanager
def temp_restore_stack(self):
def temp_restore_stack(self) -> Generator[None, None, None]:
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
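Illustrative note (not part of the diff): `TorchFunctionModeStackStateManager` above gains a fully typed `__exit__` and a `Generator[None, None, None]` annotation on its `contextlib.contextmanager` method. A small self-contained sketch of those two annotations on an invented class:

```python
import contextlib
from types import TracebackType
from typing import Generator, Optional


class SavedStack:
    def __init__(self) -> None:
        self.stack: list[int] = []

    def __enter__(self) -> None:
        self.stack = [1, 2, 3]

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # The three-argument exception signature type checkers expect for __exit__.
        self.stack = []

    @contextlib.contextmanager
    def temporarily_empty(self) -> Generator[None, None, None]:
        # Yields nothing, receives nothing, returns nothing -> Generator[None, None, None].
        prev = self.stack
        self.stack = []
        try:
            yield
        finally:
            self.stack = prev
```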
@ -165,7 +264,7 @@ torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()


class SymbolicTorchFunctionState:
def __init__(self, py_stack):
def __init__(self, py_stack: Iterable[Any]) -> None:
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
# These are their definitions:
@ -199,32 +298,41 @@ class SymbolicTorchFunctionState:

for i, val in enumerate(py_stack):
self.mode_stack.append(
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) # type: ignore[arg-type]
)

def in_torch_function_mode(self):
def in_torch_function_mode(self) -> bool:
return len(self.mode_stack) > 0

def pop_torch_function_mode(self):
def pop_torch_function_mode(self) -> TorchFunctionModeVariable:
return self.mode_stack.pop()

def push_torch_function_mode(self, mode_var):
def push_torch_function_mode(self, mode_var: TorchFunctionModeVariable) -> None:
self.mode_stack.append(mode_var)

def call_torch_function_mode(self, tx, fn, types, args, kwargs):
def call_torch_function_mode(
self,
tx: "InstructionTranslator",
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
with self._pop_mode_for_inlining() as cur_mode:
return cur_mode.call_torch_function(tx, fn, types, args, kwargs)

@contextlib.contextmanager
def _pop_mode_for_inlining(self):
def _pop_mode_for_inlining(
self,
) -> Generator[TorchFunctionModeVariable, None, None]:
old_mode = self.cur_mode
self.cur_mode = self.pop_torch_function_mode()
self.cur_mode = self.pop_torch_function_mode() # type: ignore[assignment]
try:
yield self.cur_mode
yield self.cur_mode # type: ignore[misc]
finally:
mode = self.cur_mode
self.cur_mode = old_mode
self.push_torch_function_mode(mode)
self.push_torch_function_mode(mode) # type: ignore[arg-type]


class TorchFunctionModeStackVariable(VariableTracker):
@ -244,16 +352,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
# each of the indices of other modes should be shifted left by 1 (-1)
offset = 0

def __init__(self, source, symbolic_stack):
def __init__(
self,
source: Source,
symbolic_stack: collections.deque[TorchFunctionModeVariable],
) -> None:
self.source = source
self.symbolic_stack = symbolic_stack

@classmethod
def reset(cls):
def reset(cls) -> None:
cls.offset = 0

@classmethod
def register_mutation(cls, tx: "InstructionTranslator"):
def register_mutation(cls, tx: "InstructionTranslator") -> None:
if cls.stack_value_singleton not in tx.output.side_effects:
var = cls(
source=Source(),
@ -263,7 +375,7 @@ class TorchFunctionModeStackVariable(VariableTracker):
tx.output.side_effects.mutation(var)

@classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
def register_device_context_insertion(cls, tx: "InstructionTranslator") -> None:
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
return
@ -277,109 +389,28 @@ class TorchFunctionModeStackVariable(VariableTracker):
)

@classmethod
def clear_default_device(cls, tx: "InstructionTranslator"):
def clear_default_device(cls, tx: "InstructionTranslator") -> None:
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
stack.popleft()
cls.offset -= 1

@staticmethod
def is_device_context(var):
def is_device_context(var: TorchFunctionModeVariable) -> bool:
return isinstance(var.value, DeviceContext) or var.value is None

@classmethod
def get_mode_index(cls, ind):
def get_mode_index(cls, ind: int) -> int:
return ind + cls.offset


class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the function across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)

def __init__(self, value, source=None, **kwargs):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source

def reconstruct(self, codegen: "PyCodegen"):
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)

def module_name(self):
return self.value.__module__

def fn_name(self):
return type(self.value).__name__

def python_type(self):
return type(self.value)

def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
get_torch_function_fn(tx, self),
fn,
types,
args,
kwargs,
)

def enter(self, tx):
from .torch import TorchInGraphFunctionVariable

if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)

TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)

def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable

TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)

def reconstruct_type(self, codegen: "PyCodegen"):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)

def supports_graph_breaks(self):
return True

def exit_on_graph_break(self):
return False


def _get_all_args(args, kwargs):
def _get_all_args(
args: Iterable[Any], kwargs: dict[str, Any]
) -> Iterable[VariableTracker]:
return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))


def _flatten_vts(vts):
def _flatten_vts(vts: Iterable[VariableTracker]) -> list[VariableTracker]:
from collections import deque

from .dicts import ConstDictVariable
@ -391,7 +422,7 @@ def _flatten_vts(vts):
while vts:
vt = vts.popleft()

if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): # type: ignore[attr-defined]
vt.realize()

if vt.is_realized():
@ -407,21 +438,28 @@ def _flatten_vts(vts):
return output


def _get_subclass_type(var):
def _get_subclass_type(var: VariableTracker) -> type:
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
return var.python_type()


def _get_subclass_type_var(tx: "InstructionTranslator", var):
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
def _get_subclass_type_var(
tx: "InstructionTranslator", var: VariableTracker
) -> VariableTracker:
if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable):
source = var.source and TypeSource(var.source)
return VariableTracker.build(tx, var.python_type(), source)
else:
raise AssertionError(f"Unexpected type {type(var)}")


def _is_attr_overridden(tx: "InstructionTranslator", var, name):
def _is_attr_overridden(
tx: "InstructionTranslator", var: VariableTracker, name: str
) -> bool:
if not isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)):
return False
import torch

overridden = False
@ -434,7 +472,14 @@ def _is_attr_overridden(tx: "InstructionTranslator", var, name):
return overridden


def call_torch_function(tx, torch_function_var, fn, types, args, kwargs):
def call_torch_function(
tx: "InstructionTranslator",
torch_function_var: VariableTracker,
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
# This emulates calling __torch_function__, which has a signature
# def __torch_function__(cls, func, types, args=(), kwargs=None):
#
@ -451,7 +496,9 @@ def call_torch_function(tx, torch_function_var, fn, types, args, kwargs):
return torch_function_var.call_function(tx, tf_args, {})
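Illustrative note (not part of the diff): `call_torch_function` above emulates the `__torch_function__(cls, func, types, args=(), kwargs=None)` protocol that tensor subclasses implement. A minimal eager-mode subclass showing that calling convention; `LoggingTensor` is invented for the example and is unrelated to Dynamo's tracing path:

```python
import torch


class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted {getattr(func, '__name__', func)}")
        # Run the real op with subclass dispatch disabled to avoid recursion.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


x = torch.ones(2).as_subclass(LoggingTensor)
y = x + 1  # dispatches through LoggingTensor.__torch_function__ before computing
```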
def get_torch_function_fn(tx: "InstructionTranslator", vt):
def get_torch_function_fn(
tx: "InstructionTranslator", vt: VariableTracker
) -> VariableTracker:
# The underlying function could be a classmethod, staticmethod, regular
# function or a function with C-implementation. It doesn't matter as long as
# they satisfy the calling convention in `call_torch_function`.
@ -462,7 +509,9 @@ def get_torch_function_fn(tx: "InstructionTranslator", vt):
return func_vt


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
def can_dispatch_torch_function(
tx: "InstructionTranslator", args: Iterable[Any], kwargs: dict[str, Any]
) -> bool:
has_overridden_args = any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
)
@ -472,7 +521,12 @@ def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
)


def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
def dispatch_torch_function(
tx: "InstructionTranslator",
fn: VariableTracker,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
"""Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""

all_args = _get_all_args(args, kwargs)
@ -518,7 +572,13 @@ class TensorWithTFOverrideVariable(TensorVariable):
"""

@classmethod
def from_tensor_var(cls, tx, tensor_var, class_type, cls_source):
def from_tensor_var(
cls,
tx: "InstructionTranslator",
tensor_var: VariableTracker,
class_type: type,
cls_source: Source,
) -> "TensorWithTFOverrideVariable":
# [Note: __torch_function__] coerce `tensor_var` into a
# TensorWithTFOverrideVariable. In eager, this is just a type change.
import torch
@ -533,7 +593,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
var.install_global(tx)
return var

def install_global(self, tx):
def install_global(self, tx: "InstructionTranslator") -> None:
# stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor.
@ -543,20 +603,20 @@ class TensorWithTFOverrideVariable(TensorVariable):
self.global_mangled_class_name(tx), self.class_type
)

def python_type(self):
def python_type(self) -> type:
return self.class_type

def class_type_var(self, tx):
def class_type_var(self, tx: "InstructionTranslator") -> VariableTracker:
return TensorSubclassVariable(
self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
)

def global_mangled_class_name(self, tx):
def global_mangled_class_name(self, tx: "InstructionTranslator") -> str:
return get_safe_global_name(
tx, f"__subclass_{self.class_type.__name__}", self.class_type
)

def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
@ -581,7 +641,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
and not attr_is_overridden
and not inspect.ismethoddescriptor(getattr(torch.Tensor, name))
):
args, kwargs = [self], {}
args = [self]
kwargs: dict[Any, Any] = {}
if can_dispatch_torch_function(tx, args, kwargs):
if self.source:
install_guard(
@ -642,7 +703,14 @@ class TensorWithTFOverrideVariable(TensorVariable):

return super().var_getattr(tx, name)

def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
def call_torch_function(
self,
tx: "InstructionTranslator",
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
# NOTE this assumes `__torch_function__` isn't modified during tracing.
if not hasattr(self, "torch_function_fn"):
self.torch_function_fn = get_torch_function_fn(tx, self)
@ -658,8 +726,8 @@ class TensorWithTFOverrideVariable(TensorVariable):

def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":