# mypy: ignore-errors

import collections
import contextlib
import functools
import inspect
import operator
from typing import TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import (
    _get_overloaded_args,
    BaseTorchFunctionMode,
    get_default_nowrap_functions,
    TorchFunctionMode,
)
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import (
    class_has_getattribute,
    clear_torch_function_mode_stack,
    get_safe_global_name,
    has_torch_function,
    is_tensor_base_attr_getter,
    set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches
# __torch_function__ on attribute accesses, method calls, and torch API calls.
# The following is not supported:
# - triggering __torch_function__ on tensor subclass non-tensor custom attributes
# - graph breaking on mutating guardable tensor properties within a __torch_function__ context; this can cause
#   excessive recompiles in certain degenerate cases
# - matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of torch API calls
# The following is supported:
# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as
#   any argument
# - triggering __torch_function__ on torch API calls with tensor subclass arguments
# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances
# - matching the dispatch ordering behavior of eager __torch_function__ with subclass/object arguments in any argument position
# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
# for more information on the design.
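
# For orientation, a minimal sketch (names hypothetical, not part of this
# module) of the kind of subclass this machinery traces:
#
#   class MyTensor(torch.Tensor):
#       @classmethod
#       def __torch_function__(cls, func, types, args=(), kwargs=None):
#           kwargs = kwargs or {}
#           with torch._C.DisableTorchFunctionSubclass():
#               return func(*args, **kwargs)
#
#   x = torch.ones(2).as_subclass(MyTensor)
#   torch.add(x, 1)  # dispatches through MyTensor.__torch_function__
#
# During tracing, dynamo inlines that override rather than calling it eagerly.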
# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py

bin_ops = [
    operator.pow,
    operator.mul,
    operator.matmul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    operator.ne,
    operator.eq,
    operator.sub,
    operator.ipow,
    operator.imul,
    operator.imatmul,
    operator.ifloordiv,
    operator.itruediv,
    operator.imod,
    operator.iadd,
    operator.isub,
]

bin_int_ops = [
    operator.and_,
    operator.or_,
    operator.xor,
    operator.iand,
    operator.ixor,
    operator.ior,
]

un_int_ops = [operator.invert]

tensor_and_int_ops = [
    operator.lshift,
    operator.rshift,
    operator.ilshift,
    operator.irshift,
    operator.getitem,
]

un_ops = [
    operator.abs,
    operator.pos,
    operator.neg,
    operator.not_,  # Note: this has a local scalar dense call
    operator.length_hint,
]

BUILTIN_TO_TENSOR_FN_MAP = {}

# These functions represent the r* versions of the above ops.
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position;
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}


def populate_builtin_to_tensor_fn_map():
    global BUILTIN_TO_TENSOR_FN_MAP

    most_recent_func = None

    class GetMethodMode(BaseTorchFunctionMode):
        """
        Mode to extract the correct methods from torch function invocations
        (Used to get the correct torch.Tensor methods from builtins)
        """

        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            nonlocal most_recent_func
            most_recent_func = func
            return func(*args, **kwargs)

    inp0 = torch.ones(1)
    inp1 = torch.ones(1)
    inp0_int = torch.ones(1, dtype=torch.int32)
    inp1_int = torch.ones(1, dtype=torch.int32)
    with GetMethodMode():
        setups_and_oplists = [
            (lambda o: o(inp0), un_ops),
            (lambda o: o(inp0_int), un_int_ops),
            (lambda o: o(inp0, inp1), bin_ops),
            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
        ]
        for setup_fn, op_list in setups_and_oplists:
            for op in op_list:
                setup_fn(op)
                assert most_recent_func is not None
                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
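        # At this point the forward map is populated; assuming the standard
        # torch.Tensor bindings, it holds entries along the lines of
        #   BUILTIN_TO_TENSOR_FN_MAP[operator.add] -> torch.Tensor.add
        #   BUILTIN_TO_TENSOR_FN_MAP[operator.neg] -> torch.Tensor.neg
        # (illustrative; the exact callables are whatever GetMethodMode observed)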
        # gather the reverse functions
        rsetups_and_oplists = [
            (
                lambda o: o(1, inp1),
                bin_ops,
            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
            (lambda o: o(1, inp1_int), bin_int_ops),
            (lambda o: o(0, inp0_int), tensor_and_int_ops),
        ]

        rskips = {operator.matmul, operator.imatmul, operator.getitem}
        for setup_fn, op_list in rsetups_and_oplists:
            for op in op_list:
                if op in rskips:
                    continue
                setup_fn(op)
                assert most_recent_func is not None
                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func


populate_builtin_to_tensor_fn_map()

banned_attrs = [
    fn.__self__.__name__
    for fn in get_default_nowrap_functions()
    if is_tensor_base_attr_getter(fn)
]


@functools.lru_cache(None)
def get_prev_stack_var_name():
    from ..bytecode_transformation import unique_id

    return unique_id("___prev_torch_function_mode_stack")


# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
    def __init__(self):
        self.stack = []

    def __enter__(self):
        self.stack = torch.overrides._get_current_function_mode_stack()
        clear_torch_function_mode_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        set_torch_function_mode_stack(self.stack)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        prev = torch.overrides._get_current_function_mode_stack()
        set_torch_function_mode_stack(self.stack)
        try:
            yield
        finally:
            set_torch_function_mode_stack(prev)


torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
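
# Illustrative usage of the manager above (a sketch; not executed here):
#
#   with torch_function_mode_stack_state_mgr:
#       # the eager mode stack has been saved and cleared for tracing
#       ...
#       with torch_function_mode_stack_state_mgr.temp_restore_stack():
#           # the saved stack is temporarily reinstated, e.g. to run eager code
#           ...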
class SymbolicTorchFunctionState:
    def __init__(self, py_stack):
        # This is annoyingly complicated because of how the torch function subclass + mode C API was designed.
        # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass.
        # These are their definitions:
        # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
        #    (if either are entered, this will be False)
        # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
        #    torch._C.DisableTorchFunction has been entered
        # To disambiguate these and keep myself sane I added a C API to check whether all torch function
        # concepts (modes and subclasses) are enabled.
        # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
        # the stack length from the enablement state of torch function modes.
        # This is important because now if a mode is pushed while dynamo is tracing, we know whether
        # or not torch function modes are enabled and whether we should trace it.
        self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()

        # This differs from the C API of the same name:
        # this will only be false iff we have entered torch._C.DisableTorchFunction
        # and does not take into account the mode stack length, while the C API bundles these
        # two concepts
        self.torch_function_mode_enabled = (
            not torch._C._is_torch_function_all_disabled()
        )

        self.cur_mode = None

        TorchFunctionModeStackVariable.reset()

        self.mode_stack: collections.deque[TorchFunctionModeVariable] = (
            collections.deque()
        )

        for i, val in enumerate(py_stack):
            self.mode_stack.append(
                LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
            )

    def in_torch_function_mode(self):
        return len(self.mode_stack) > 0

    def pop_torch_function_mode(self):
        return self.mode_stack.pop()

    def push_torch_function_mode(self, mode_var):
        self.mode_stack.append(mode_var)

    def call_torch_function_mode(self, tx, fn, types, args, kwargs):
        with self._pop_mode_for_inlining() as cur_mode:
            return cur_mode.call_torch_function(tx, fn, types, args, kwargs)

    @contextlib.contextmanager
    def _pop_mode_for_inlining(self):
        old_mode = self.cur_mode
        self.cur_mode = self.pop_torch_function_mode()
        try:
            yield self.cur_mode
        finally:
            mode = self.cur_mode
            self.cur_mode = old_mode
            self.push_torch_function_mode(mode)


class TorchFunctionModeStackVariable(VariableTracker):
    """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""

    # singleton value representing the global torch function mode stack
    # singleton (it exists in C++)
    stack_value_singleton = object()

    # offset is used to track if we have inserted/removed a
    # device context which is always placed at the bottom of the stack
    # if a device context is inserted, the graph will run this mutation
    # so when we want to reconstruct any other modes on the stack
    # their indices should be shifted right by 1 (+1)
    # Conversely, if there was a device context on the stack, and the graph
    # mutates the stack to remove that context (set default device to None)
    # each of the indices of other modes should be shifted left by 1 (-1)
    offset = 0

    def __init__(self, source, symbolic_stack):
        self.source = source
        self.symbolic_stack = symbolic_stack

    @classmethod
    def reset(cls):
        cls.offset = 0

    @classmethod
    def register_mutation(cls, tx: "InstructionTranslator"):
        if cls.stack_value_singleton not in tx.output.side_effects:
            var = cls(
                source=Source(),
                symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
            )
            tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
            tx.output.side_effects.mutation(var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            return
        else:
            cls.offset += 1
            stack.insert(
                0,
                TorchFunctionModeVariable(
                    None, source=TorchFunctionModeStackSource(-cls.offset)
                ),
            )

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            stack.popleft()
            cls.offset -= 1

    @staticmethod
    def is_device_context(var):
        return isinstance(var.value, DeviceContext) or var.value is None

    @classmethod
    def get_mode_index(cls, ind):
        return ind + cls.offset
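
# Worked example of the offset bookkeeping above (illustrative): if eager code
# pushed mode M (index 0) and the traced graph then inserts a DeviceContext at
# the bottom of the stack, M must be reconstructed from index
# TorchFunctionModeStackVariable.get_mode_index(0) == 0 + offset == 1.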
class TorchFunctionModeVariable(GenericContextWrappingVariable):
    @staticmethod
    def is_supported_torch_function_mode(ty):
        # Supported in this sense means we can support graph breaks under the
        # context.
        # We are able to trace custom modes, but if there are graph breaks under them
        # and they have a custom __enter__/__exit__, we don't handle this for the
        # same reason we don't handle generic context managers: there may be side effects
        # that are now affected by executing the function across two frames instead of one.
        # Today we support the enter/exit of the default TorchFunctionMode as well as
        # DeviceContext (which is used for set_default_device)
        return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
            not class_has_getattribute(ty)
            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
        )

    def __init__(self, value, source=None, **kwargs):
        if value is not None:
            super().__init__(value, **kwargs)
        self.value = value
        self.cm_obj = value  # needed for BC with calling enter from CM code
        self.source = source

    def reconstruct(self, codegen):
        # This shouldn't be called unless we have a source
        assert self.source
        self.source.reconstruct(codegen)

    def module_name(self):
        return self.value.__module__

    def fn_name(self):
        return type(self.value).__name__

    def python_type(self):
        return type(self.value)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        return call_torch_function(
            tx,
            self,
            build_torch_function_fn(tx, self.value, self.source),
            fn,
            types,
            args,
            kwargs,
        )

    def enter(self, tx):
        from .torch import TorchInGraphFunctionVariable

        if isinstance(self.value, NoEnterTorchFunctionMode):
            return ConstantVariable.create(None)

        TorchInGraphFunctionVariable(
            torch._C._push_on_torch_function_stack
        ).call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        from .torch import TorchInGraphFunctionVariable

        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
            tx, [], {}
        )
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        ty = NoEnterTorchFunctionMode
        codegen(
            AttrSource(
                codegen.tx.import_source(ty.__module__),
                ty.__name__,
            )
        )

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return False


def _get_all_args(args, kwargs):
    return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))


def _flatten_vts(vts):
    from collections import deque

    from .dicts import ConstDictVariable
    from .lists import ListVariable

    vts = deque(vts)
    output = []

    while vts:
        vt = vts.pop()

        if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
            vt.realize()

        if vt.is_realized():
            if isinstance(vt, ListVariable):
                vts.extend(vt.items)
            elif isinstance(vt, ConstDictVariable):
                vts.extend(vt.items.values())

        output.append(vt)

    return output


def _get_subclass_type(var):
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    return var.python_type()


def _get_subclass_type_var(tx: "InstructionTranslator", var):
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    if isinstance(var, TensorWithTFOverrideVariable):
        return var.class_type_var(tx)
    elif isinstance(var, UserDefinedObjectVariable):
        source = var.source and TypeSource(var.source)
        return VariableTracker.build(tx, var.python_type(), source)


def _is_attr_overidden(tx: "InstructionTranslator", var, name):
    import torch

    overridden = False
    try:
        attr_val = inspect.getattr_static(var.python_type(), name)
        overridden |= attr_val != getattr(torch.Tensor, name)
    except AttributeError:
        pass

    return overridden
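
# Illustratively (a sketch; `MySubclass` is a hypothetical name): given
#
#   class MySubclass(torch.Tensor):
#       def sum(self, *args, **kwargs): ...
#
# _is_attr_overidden(tx, var, "sum") is True because
# inspect.getattr_static(MySubclass, "sum") no longer matches torch.Tensor.sum,
# so callers graph break instead of tracing the base implementation.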
def call_torch_function(
    tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
    # signature:
    # def __torch_function__(cls, func, types, args=(), kwargs=None):
    tf_args = (
        torch_function_type,
        fn,
        types,
        VariableTracker.build(tx, tuple(args)),
        VariableTracker.build(tx, kwargs),
    )
    return tx.inline_user_function_return(torch_function_var, tf_args, {})


def build_torch_function_fn(tx: "InstructionTranslator", value, source):
    from types import FunctionType

    func = value.__torch_function__.__func__

    if not isinstance(func, FunctionType):
        unimplemented("Builtin/C++ torch function implementations NYI")

    source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
    return VariableTracker.build(tx, func, source)


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    has_overridden_args = any(
        has_torch_function(arg) for arg in _get_all_args(args, kwargs)
    )
    tf_state = tx.symbolic_torch_function_state
    return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
        tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
    )


def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
    """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""

    all_args = _get_all_args(args, kwargs)
    overloaded_args = _get_overloaded_args(
        [arg for arg in all_args if has_torch_function(arg)],
        _get_subclass_type,
    )

    types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])

    if tx.symbolic_torch_function_state.in_torch_function_mode():
        res = tx.symbolic_torch_function_state.call_torch_function_mode(
            tx, fn, types, args, kwargs
        )
        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
            return res

    for arg in overloaded_args:
        res = arg.call_torch_function(
            tx,
            fn,
            types,
            args,
            kwargs,
        )

        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
            return res

    unimplemented(
        f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
    )


class TensorWithTFOverrideVariable(TensorVariable):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(self, *args, **kwargs) -> None:
        self.torch_function_fn = kwargs.pop("torch_function_fn")
        super().__init__(*args, **kwargs)

    @classmethod
    def from_tensor_var(cls, tx, tensor_var, class_type, torch_function_fn):
        import torch

        kwargs = dict(tensor_var.__dict__)
        assert (
            kwargs.pop("class_type") is torch.Tensor
        ), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
        var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
        var.install_global(tx)
        return var
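    # Illustrative rewrap flow (a sketch; `MySubclass` is a hypothetical user
    # type): after dynamo builds a plain TensorVariable for a traced tensor
    # whose eager type has a __torch_function__ override, it is rewrapped via
    #   TensorWithTFOverrideVariable.from_tensor_var(
    #       tx, tensor_var, MySubclass, torch_function_fn
    #   )
    # which also installs MySubclass into the output graph's globals below.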
    def install_global(self, tx):
        # stash the subclass type to rewrap an output tensor if needed
        # this is needed because the actual type needs to be available
        # each time the compiled artifact is run and outputs a wrapped tensor.
        if self.global_mangled_class_name(tx) not in tx.output.global_scope:
            # Safe because global_mangled_class_name figures it out
            tx.output.install_global_unsafe(
                self.global_mangled_class_name(tx), self.class_type
            )

    def python_type(self):
        return self.class_type

    def class_type_var(self, tx):
        return TensorSubclassVariable(
            self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
        )

    def global_mangled_class_name(self, tx):
        return get_safe_global_name(
            tx, f"__subclass_{self.class_type.__name__}", self.class_type
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        # [Note: __torch_function__] We currently only support attributes that are defined on
        # base tensors; custom attribute accesses will graph break.
        import torch

        if name in banned_attrs:
            unimplemented(
                f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
            )

        if _is_attr_overidden(tx, self, name):
            unimplemented(
                f"Accessing overridden method/attribute {name} on a tensor"
                " subclass with a __torch_function__ override is not supported"
            )

        if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
            if self.source:
                install_guard(
                    AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
                        GuardBuilder.FUNCTION_MATCH
                    )
                )
            get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)

            return self.call_torch_function(
                tx,
                get_fn,
                TupleVariable([self.class_type_var(tx)]),
                [self],
                {},
            )
        else:
            return super().var_getattr(tx, name)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        return call_torch_function(
            tx,
            self.class_type_var(tx),
            self.torch_function_fn,
            fn,
            types,
            args,
            kwargs,
        )

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        if tx.output.torch_function_enabled:
            import torch

            if _is_attr_overidden(tx, self, name):
                unimplemented(
                    f"Calling overridden method {name} on a tensor"
                    " subclass with a __torch_function__ override is not supported"
                )

            # [Note: __torch_function__] Currently we only support methods that are defined on tensor;
            # we will graph break in other cases. Supporting them will need a bigger overhaul of
            # extracting methods/comparing them for equality.
            # We've established with the above check that the method is not overridden, so we guard
            # that the method is the same as the impl defined on tensor and retrieve it
            if self.source:
                source = AttrSource(AttrSource(self.source, "__class__"), name)
                value = inspect.getattr_static(self.python_type(), name)
            else:
                source = None
                value = getattr(torch.Tensor, name)
            func_var = VariableTracker.build(tx, value, source)
            return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
        else:
            return super().call_method(tx, name, args, kwargs)
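
# End-to-end sketch (illustrative; `MySubclass` is a hypothetical name): for
#
#   @torch.compile
#   def f(x):           # x is an instance of MySubclass
#       return x.sin()
#
# x is wrapped as a TensorWithTFOverrideVariable, x.sin() enters call_method
# above, and dispatch_torch_function ultimately inlines
# MySubclass.__torch_function__ with func=torch.Tensor.sin.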