Files
pytorch/torch/_dynamo/variables/torch_function.py
Ryan Guo 238109ad32 [dynamo] Always trace into tensor subclass __torch_function__ (#149792)
This patch effectively ignores traceable_tensor_subclasses, allowing
Dynamo to always try tracing into the `__torch_function__` of tensor
subclass. This helps us with 2 things:
1. allowing users to directly benefit from better compilation of tensor
   subclass, by just upgrading pytorch, without having to change legacy
   library code (see earlier patches in the stack for examples).
2. potentially exposing more issues in compiling tensor subclass, so we
   can get signals and improve them.

As a consequence, it exposed and fixed 2 subtle bugs:
1. In `build_torch_function_fn`, we could get
   `torch._C._disabled_torch_function_impl` because we have a
   `Parameter` subclass without `__torch_function__` override or if we
   have a tensor subclass with `__torch_dispatch__` override. We graph
   break on this for now, and plan to add support -- the logic for
   simulating `torch._C._disabled_torch_function_impl` is already in
   `SuperVariable`, we just need to reuse it.
2. Sometimes we create a `SyntheticLocalSource` and need to remove all the
   guards installed on it. We only removed the guards whose source _is_
   the created synthetic source `s`, and forgot about chained sources
   like `s.foo`; this showed up as
   `SYNTHETIC_LOCAL['tmp_0'].__torch_function__.__func__`.

Differential Revision: [D71906141](https://our.internmc.facebook.com/intern/diff/D71906141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149792
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483, #149484
2025-04-02 17:05:25 +00:00

733 lines
26 KiB
Python

# mypy: ignore-errors
"""TorchDynamo support for __torch_function__ tensor subclasses.
This module implements support for tensor subclasses with __torch_function__ overrides.
A tensor subclass instance is represented as a TensorWithTFOverrideVariable, which handles
dispatching __torch_function__ on attribute accesses, method calls, and torch API calls.
Unsupported features:
- Triggering __torch_function__ on tensor subclass non-tensor custom attributes
- Graph breaking on mutating guardable tensor properties within a __torch_function__ context
(can cause excessive recompiles in certain cases)
- Matching exact eager behavior of ignoring __torch_function__ objects in non-tensor
argument positions of Torch API calls
Supported features:
- Static method implementations of __torch_function__ on custom objects (triggers on torch
API calls with the object as any argument)
- Triggering __torch_function__ on torch API calls with tensor subclass arguments
- __torch_function__ calls on base tensor attribute access and method calls for tensor
subclass instances
- Matches dispatch ordering behavior of eager __torch_function__ with subclass/object
arguments in any position
See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
for more information on the design.
"""
import collections
import contextlib
import functools
import inspect
import operator
from typing import TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import (
_get_overloaded_args,
BaseTorchFunctionMode,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .functions import UserMethodVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
# Operator tables consumed by `populate_builtin_to_tensor_fn_map`. Each table
# groups the `operator` builtins that are probed with the same kind of sample
# arguments, so that the torch function they dispatch to can be recorded.

# Binary operators (plain and in-place variants), probed with two tensors
# created by `torch.ones(1)`.
bin_ops = [
    operator.pow,
    operator.mul,
    operator.matmul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    operator.ne,
    operator.eq,
    operator.sub,
    operator.ipow,
    operator.imul,
    operator.imatmul,
    operator.ifloordiv,
    operator.itruediv,
    operator.imod,
    operator.iadd,
    operator.isub,
]
# Bitwise binary operators (plain and in-place), probed with two int32 tensors.
bin_int_ops = [
    operator.and_,
    operator.or_,
    operator.xor,
    operator.iand,
    operator.ixor,
    operator.ior,
]
# Unary operators probed with a single int32 tensor.
un_int_ops = [operator.invert]
# Operators probed with an int32 tensor and the Python int `0`.
tensor_and_int_ops = [
    operator.lshift,
    operator.rshift,
    operator.ilshift,
    operator.irshift,
    operator.getitem,
]
# Unary operators probed with a single tensor created by `torch.ones(1)`.
un_ops = [
    operator.abs,
    operator.pos,
    operator.neg,
    operator.not_,  # Note: this has a local scalar dense call
    operator.length_hint,
]
# Maps each builtin operator above to the torch function observed when the
# operator is applied with a tensor as the first argument. Filled in by
# `populate_builtin_to_tensor_fn_map`.
BUILTIN_TO_TENSOR_FN_MAP = {}
# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position,
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}
def populate_builtin_to_tensor_fn_map():
    """Fill BUILTIN_TO_TENSOR_FN_MAP and BUILTIN_TO_TENSOR_RFN_MAP.

    Every operator in the module-level operator tables is executed once on
    small sample tensors while a recording torch function mode is active. The
    torch function seen for the call is stored as the mapping for that
    operator. A second pass repeats the binary operators with a Python number
    as the first operand to discover the reflected (r*) functions; a reflected
    entry is stored only when it differs from the forward entry.
    """
    global BUILTIN_TO_TENSOR_FN_MAP

    # Updated by the recording mode on every torch function invocation.
    last_seen_fn = None

    class GetMethodMode(BaseTorchFunctionMode):
        """
        Mode to extract the correct methods from torch function invocations
        (Used to get the correct torch.Tensor methods from builtins)
        """

        def __torch_function__(self, func, types, args=(), kwargs=None):
            nonlocal last_seen_fn
            last_seen_fn = func
            return func(*args, **(kwargs or {}))

    sample = torch.ones(1)
    other_sample = torch.ones(1)
    int_sample = torch.ones(1, dtype=torch.int32)
    other_int_sample = torch.ones(1, dtype=torch.int32)

    with GetMethodMode():
        # (operators, how to invoke each of them) for the forward direction.
        forward_probes = [
            (un_ops, lambda o: o(sample)),
            (un_int_ops, lambda o: o(int_sample)),
            (bin_ops, lambda o: o(sample, other_sample)),
            (bin_int_ops, lambda o: o(int_sample, other_int_sample)),
            (tensor_and_int_ops, lambda o: o(int_sample, 0)),
        ]
        for operators, invoke in forward_probes:
            for op in operators:
                invoke(op)
                assert last_seen_fn is not None
                BUILTIN_TO_TENSOR_FN_MAP[op] = last_seen_fn

        # gather the reverse functions
        # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
        reverse_probes = [
            (bin_ops, lambda o: o(1, other_sample)),
            (bin_int_ops, lambda o: o(1, other_int_sample)),
            (tensor_and_int_ops, lambda o: o(0, int_sample)),
        ]
        # These operators are not probed with a non-tensor first operand.
        not_reversible = {operator.matmul, operator.imatmul, operator.getitem}
        for operators, invoke in reverse_probes:
            for op in operators:
                if op not in not_reversible:
                    invoke(op)
                    assert last_seen_fn is not None
                    # Only record a reflected entry when it is a different
                    # function from the forward one.
                    if last_seen_fn != BUILTIN_TO_TENSOR_FN_MAP[op]:
                        BUILTIN_TO_TENSOR_RFN_MAP[op] = last_seen_fn
# Build the builtin -> tensor function tables once, at import time.
populate_builtin_to_tensor_fn_map()

# Names of the attributes whose getters appear in the default no-wrap function
# set and are recognized as base tensor attribute getters. Accessing these on
# a tensor subclass with a __torch_function__ override is rejected in
# `TensorWithTFOverrideVariable.var_getattr`.
banned_attrs = [
    getter.__self__.__name__
    for getter in filter(is_tensor_base_attr_getter, get_default_nowrap_functions())
]
@functools.lru_cache(maxsize=None)
def get_prev_stack_var_name():
    """Return the unique name used for the previous torch function mode stack.

    The name is generated on the first call and cached, so every later call
    returns the same string.
    """
    from ..bytecode_transformation import unique_id

    var_name = unique_id("___prev_torch_function_mode_stack")
    return var_name
class TorchFunctionModeStackStateManager:
    """Clears the Python torch function mode stack and restores it later.

    Entering the manager saves the current mode stack and empties it; exiting
    puts the saved stack back. `temp_restore_stack` temporarily reinstates the
    saved stack while the manager is active.
    """

    def __init__(self):
        # Mode stack captured by the most recent `__enter__`.
        self.stack = []

    def __enter__(self):
        """Save the current mode stack, then clear it."""
        saved_stack = torch.overrides._get_current_function_mode_stack()
        self.stack = saved_stack
        clear_torch_function_mode_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        """Reinstall the saved mode stack and drop the saved copy."""
        set_torch_function_mode_stack(self.stack)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        """Context manager that temporarily reinstalls the saved mode stack.

        The stack that was active on entry is put back on exit, even when the
        body raises.
        """
        stack_on_entry = torch.overrides._get_current_function_mode_stack()
        set_torch_function_mode_stack(self.stack)
        try:
            yield
        finally:
            set_torch_function_mode_stack(stack_on_entry)


# Module-level shared instance.
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState:
    """Symbolic view of the torch function enablement flags and mode stack."""

    def __init__(self, py_stack):
        """Snapshot the enablement flags and wrap each mode in `py_stack`.

        Args:
            py_stack: iterable of the torch function modes currently on the
                Python mode stack, bottom first.
        """
        # The torch function subclass + mode C API exposes two context knobs:
        # torch._C.DisableTorchFunction and
        # torch._C.DisableTorchFunctionSubclass. Their definitions:
        # 1) torch._C._is_torch_function_enabled indicates that neither knob
        #    has been entered (if either is entered, this will be False)
        # 2) torch._C._is_torch_function_mode_enabled indicates that either
        #    the torch mode stack is empty OR torch._C.DisableTorchFunction
        #    has been entered
        # To disambiguate these, a separate C API checks whether all torch
        # function concepts (modes and subclasses) are enabled. It returns
        # true iff torch._C.DisableTorchFunction has not been entered, which
        # separates the stack length from the enablement state of torch
        # function modes. This matters because if a mode is pushed while
        # dynamo is tracing, we know whether torch function modes are enabled
        # and whether we should trace it.
        self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()

        # This differs from the C API of the same name: it is False iff
        # torch._C.DisableTorchFunction has been entered, and it ignores the
        # mode stack length, whereas the C API bundles the two concepts.
        all_disabled = torch._C._is_torch_function_all_disabled()
        self.torch_function_mode_enabled = not all_disabled

        # Mode whose __torch_function__ is currently being inlined, if any.
        self.cur_mode = None

        TorchFunctionModeStackVariable.reset()

        self.mode_stack: collections.deque[TorchFunctionModeVariable] = (
            collections.deque()
        )
        # Each mode is wrapped lazily, with a source pointing at its index.
        self.mode_stack.extend(
            LazyVariableTracker.create(mode, source=TorchFunctionModeStackSource(idx))
            for idx, mode in enumerate(py_stack)
        )

    def in_torch_function_mode(self):
        """Return True when at least one mode is on the symbolic stack."""
        return len(self.mode_stack) > 0

    def pop_torch_function_mode(self):
        """Remove and return the mode on top of the symbolic stack."""
        return self.mode_stack.pop()

    def push_torch_function_mode(self, mode_var):
        """Place `mode_var` on top of the symbolic stack."""
        self.mode_stack.append(mode_var)

    def call_torch_function_mode(self, tx, fn, types, args, kwargs):
        """Call the top mode's torch function with that mode popped.

        The mode is pushed back once the call finishes.
        """
        with self._pop_mode_for_inlining() as active_mode:
            return active_mode.call_torch_function(tx, fn, types, args, kwargs)

    @contextlib.contextmanager
    def _pop_mode_for_inlining(self):
        """Pop the top mode into `cur_mode` for the duration of the context.

        On exit the previous `cur_mode` is restored and the mode that is in
        `cur_mode` at that point is pushed back onto the stack.
        """
        previous_mode = self.cur_mode
        self.cur_mode = self.pop_torch_function_mode()
        try:
            yield self.cur_mode
        finally:
            popped_mode = self.cur_mode
            self.cur_mode = previous_mode
            self.push_torch_function_mode(popped_mode)
class TorchFunctionModeStackVariable(VariableTracker):
    """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""

    # Singleton key standing in for the global torch function mode stack
    # (the real stack lives in C++).
    stack_value_singleton = object()

    # Tracks whether a device context has been inserted or removed. A device
    # context always sits at the bottom of the stack.
    # - If one is inserted, the graph will run that mutation, so the indices
    #   of the other modes must be shifted right by 1 (+1) when they are
    #   reconstructed.
    # - If one was on the stack and the graph removes it (default device set
    #   to None), the indices of the other modes shift left by 1 (-1).
    offset = 0

    def __init__(self, source, symbolic_stack):
        self.source = source
        self.symbolic_stack = symbolic_stack

    @classmethod
    def reset(cls):
        """Forget any device-context index shift."""
        cls.offset = 0

    @classmethod
    def register_mutation(cls, tx: "InstructionTranslator"):
        """Record, once per output, that the mode stack has been mutated."""
        side_effects = tx.output.side_effects
        if cls.stack_value_singleton in side_effects:
            return

        stack_var = cls(
            source=Source(),
            symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
        )
        tx.output.side_effects.track_mutable(cls.stack_value_singleton, stack_var)
        tx.output.side_effects.mutation(stack_var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        """Insert a placeholder device context at the bottom of the stack.

        Nothing happens when the bottom entry already is a device context.
        """
        mode_stack = tx.symbolic_torch_function_state.mode_stack
        if mode_stack and cls.is_device_context(mode_stack[0]):
            return

        cls.offset += 1
        placeholder = TorchFunctionModeVariable(
            None, source=TorchFunctionModeStackSource(-cls.offset)
        )
        mode_stack.appendleft(placeholder)

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        """Remove the device context at the bottom of the stack, if present."""
        mode_stack = tx.symbolic_torch_function_state.mode_stack
        if not mode_stack:
            return
        if cls.is_device_context(mode_stack[0]):
            mode_stack.popleft()
            cls.offset -= 1

    @staticmethod
    def is_device_context(var):
        """Return True if `var` wraps a DeviceContext or the None placeholder."""
        return isinstance(var.value, DeviceContext) or var.value is None

    @classmethod
    def get_mode_index(cls, ind):
        """Translate a mode index by the current device-context offset."""
        return cls.offset + ind
class TorchFunctionModeVariable(GenericContextWrappingVariable):
    """Variable tracker wrapping a torch function mode instance."""

    @staticmethod
    def is_supported_torch_function_mode(ty):
        """Return whether graph breaks are supported under a mode of type `ty`.

        Custom modes can be traced, but if they define their own
        __enter__/__exit__ a graph break under them is not handled, for the
        same reason generic context managers are not: there may be side
        effects that are affected by executing the function across two frames
        instead of one. Today the enter/exit of the default TorchFunctionMode
        is supported, as well as DeviceContext (which is used for
        set_default_device).
        """
        if issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)):
            return True
        if class_has_getattribute(ty):
            return False
        if inspect.getattr_static(ty, "__enter__") != TorchFunctionMode.__enter__:
            return False
        return inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__

    def __init__(self, value, source=None, **kwargs):
        # A None value is the device-context placeholder; the parent
        # constructor is skipped for it.
        if value is not None:
            super().__init__(value, **kwargs)
        self.value = value
        self.cm_obj = value  # needed for BC with calling enter from CM code
        self.source = source

    def reconstruct(self, codegen):
        """Emit code that loads this mode from its source."""
        # This shouldn't be called unless we have a source
        assert self.source
        self.source.reconstruct(codegen)

    def module_name(self):
        return self.value.__module__

    def fn_name(self):
        return type(self.value).__name__

    def python_type(self):
        return type(self.value)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        """Inline this mode's __torch_function__ for the call to `fn`."""
        torch_function_var = build_torch_function_fn(tx, self.value, self.source)
        return call_torch_function(
            tx, self, torch_function_var, fn, types, args, kwargs
        )

    def enter(self, tx):
        """Push this mode onto the torch function stack unless it is a no-enter mode."""
        from .torch import TorchInGraphFunctionVariable

        if not isinstance(self.value, NoEnterTorchFunctionMode):
            push_fn = TorchInGraphFunctionVariable(
                torch._C._push_on_torch_function_stack
            )
            push_fn.call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        """Pop the top of the torch function stack."""
        from .torch import TorchInGraphFunctionVariable

        pop_fn = TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack)
        pop_fn.call_function(tx, [], {})
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        """Emit code that loads the NoEnterTorchFunctionMode class."""
        ty = NoEnterTorchFunctionMode
        module_source = codegen.tx.import_source(ty.__module__)
        codegen(AttrSource(module_source, ty.__name__))

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return False
def _get_all_args(args, kwargs):
    """Return the flattened variable trackers found in `args` and `kwargs`."""
    leaves = pytree.arg_tree_leaves(*args, **kwargs)
    return _flatten_vts(leaves)
def _flatten_vts(vts):
    """Collect `vts` together with everything nested in list/dict trackers.

    Unrealized trackers whose peeked type is dict, list or tuple are realized
    first. The elements of a realized `ListVariable` and the values of a
    realized `ConstDictVariable` are queued for visiting. Every visited
    tracker, containers included, is part of the returned list.
    """
    from .dicts import ConstDictVariable
    from .lists import ListVariable

    pending = collections.deque(vts)
    flattened = []

    while pending:
        # Entries are taken from the right end of the queue.
        current = pending.pop()

        if not current.is_realized() and current.peek_type() in (dict, list, tuple):
            current.realize()

        if current.is_realized():
            if isinstance(current, ListVariable):
                pending.extend(current.items)
            elif isinstance(current, ConstDictVariable):
                pending.extend(current.items.values())

        flattened.append(current)

    return flattened
def _get_subclass_type(var):
    """Return the Python type that `var` represents.

    `var` must be a TensorWithTFOverrideVariable or a UserDefinedObjectVariable.
    """
    supported = (TensorWithTFOverrideVariable, UserDefinedObjectVariable)
    assert isinstance(var, supported)
    return var.python_type()
def _get_subclass_type_var(tx: "InstructionTranslator", var):
    """Return a variable tracker for the type of `var`.

    Tensor subclass variables supply their own class variable. For
    user-defined objects the type is built with a `TypeSource` derived from
    the object's source when it has one.
    """
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    if isinstance(var, TensorWithTFOverrideVariable):
        return var.class_type_var(tx)
    if isinstance(var, UserDefinedObjectVariable):
        obj_source = var.source
        # Stays falsy when the object has no source.
        type_source = obj_source and TypeSource(obj_source)
        return VariableTracker.build(tx, var.python_type(), type_source)
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
import torch
overridden = False
try:
attr_val = inspect.getattr_static(var.python_type(), name)
overridden |= attr_val != getattr(torch.Tensor, name)
except AttributeError:
pass
return overridden
def call_torch_function(
    tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
    """Inline a call to a `__torch_function__` implementation.

    The target is called with the signature
    `__torch_function__(cls, func, types, args=(), kwargs=None)`:
    `torch_function_type` is passed first, followed by `fn`, `types`, and
    `args`/`kwargs` wrapped as variable trackers.
    """
    wrapped_args = VariableTracker.build(tx, tuple(args))
    wrapped_kwargs = VariableTracker.build(tx, kwargs)
    tf_args = (torch_function_type, fn, types, wrapped_args, wrapped_kwargs)
    return tx.inline_user_function_return(torch_function_var, tf_args, {})
def build_torch_function_fn(tx: "InstructionTranslator", cls_or_obj, source):
    """Build a variable tracker for the function behind `__torch_function__`.

    Args:
        tx: the current instruction translator.
        cls_or_obj: class or instance whose `__torch_function__` is wanted.
        source: source of `cls_or_obj`, or a falsy value when there is none.

    Returns:
        A variable tracker for the plain Python function underlying
        `cls_or_obj.__torch_function__`. Its source is
        `<source>.__torch_function__.__func__` when `source` is given.

    A graph break is requested through `unimplemented` when the
    implementation is not a plain Python function.
    """
    from types import FunctionType

    # If we reach here, the target `__torch_function__` should have been
    # annotated with `@classmethod`, so accessing it yields a bound method and
    # the actual `__torch_function__` impl is inside the bound `__func__`.
    #
    # Builtin/C++ implementations (e.g. `torch._C._disabled_torch_function_impl`)
    # are not bound Python methods and have no `__func__`. Look it up with a
    # default so they reach the graph break below instead of raising an
    # AttributeError.
    func = getattr(cls_or_obj.__torch_function__, "__func__", None)

    if not isinstance(func, FunctionType):
        unimplemented("Builtin/C++ torch function implementations NYI")

    func_source = None
    if source:
        func_source = AttrSource(AttrSource(source, "__torch_function__"), "__func__")
    return VariableTracker.build(tx, func, func_source)
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    """Return whether a call with these arguments can go through torch function.

    That is the case when an argument has a torch function override and
    subclass dispatch is enabled, or when torch function modes are enabled
    and at least one mode is on the symbolic stack.
    """
    flattened_args = _get_all_args(args, kwargs)
    has_overridden_args = any(has_torch_function(arg) for arg in flattened_args)

    tf_state = tx.symbolic_torch_function_state
    via_subclass = has_overridden_args and tf_state.torch_function_subclass_enabled
    return via_subclass or (
        tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
    )
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
    """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""

    def returned_not_implemented(result):
        # True when a handler answered with the constant NotImplemented.
        return isinstance(result, ConstantVariable) and result.value is NotImplemented

    candidates = [arg for arg in _get_all_args(args, kwargs) if has_torch_function(arg)]
    overloaded_args = _get_overloaded_args(candidates, _get_subclass_type)

    types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])

    # An active torch function mode is tried before any argument override.
    tf_state = tx.symbolic_torch_function_state
    if tf_state.in_torch_function_mode():
        res = tx.symbolic_torch_function_state.call_torch_function_mode(
            tx, fn, types, args, kwargs
        )
        if not returned_not_implemented(res):
            return res

    # Then each overloaded argument, in dispatch order, until one handles it.
    for arg in overloaded_args:
        res = arg.call_torch_function(tx, fn, types, args, kwargs)
        if not returned_not_implemented(res):
            return res

    unimplemented(
        f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
    )
class TensorWithTFOverrideVariable(TensorVariable):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(self, *args, **kwargs) -> None:
        # `torch_function_fn` is a required keyword; every other argument is
        # forwarded to TensorVariable.
        self.torch_function_fn = kwargs.pop("torch_function_fn")
        super().__init__(*args, **kwargs)

    @classmethod
    def from_tensor_var(cls, tx, tensor_var, class_type, cls_source):
        """Coerce `tensor_var` into a variable for the subclass `class_type`.

        [Note: __torch_function__] In eager, this is just a type change.
        """
        import torch

        # This simulates shallow-copying the tensor object.
        kwargs = dict(tensor_var.__dict__)
        input_tensor_type = kwargs.pop("class_type")
        assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
            f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
        )
        subclass_var = cls(
            torch_function_fn=build_torch_function_fn(tx, class_type, cls_source),
            class_type=class_type,
            **kwargs,
        )
        subclass_var.install_global(tx)
        return subclass_var

    def install_global(self, tx):
        """Install the subclass type as a global of the output graph.

        The type is stashed so an output tensor can be rewrapped if needed:
        the actual type has to be available each time the compiled artifact
        is run and outputs a wrapped tensor.
        """
        if self.global_mangled_class_name(tx) in tx.output.global_scope:
            return
        # Safe because global_mangled_class_name figures it out
        tx.output.install_global_unsafe(
            self.global_mangled_class_name(tx), self.class_type
        )

    def python_type(self):
        return self.class_type

    def class_type_var(self, tx):
        """Return a variable for the subclass type, sourced from its global."""
        global_source = GlobalSource(self.global_mangled_class_name(tx))
        return TensorSubclassVariable(self.class_type, source=global_source)

    def global_mangled_class_name(self, tx):
        """Return the global name under which the subclass type is stored."""
        root_name = f"__subclass_{self.class_type.__name__}"
        return get_safe_global_name(tx, root_name, self.class_type)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Trace the attribute access `name` on the tensor subclass instance."""
        # [Note: __torch_function__] We currently only support attributes that are defined on
        # base tensors, custom attribute accesses will graph break.
        import torch

        if name in banned_attrs:
            unimplemented(
                f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
            )

        attr_is_overriden = _is_attr_overidden(tx, self, name)
        inherited_unchanged = hasattr(torch.Tensor, name) and not attr_is_overriden

        if inherited_unchanged:
            # Non-overridden attribute inherited from `torch.Tensor`: route
            # the getter through __torch_function__ when it is enabled.
            if tx.output.torch_function_enabled:
                if self.source:
                    class_attr_source = AttrSource(
                        AttrSource(self.source, "__class__"), name
                    )
                    install_guard(
                        class_attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)
                    )
                getter_var = VariableTracker.build(
                    tx, getattr(torch.Tensor, name).__get__
                )
                return self.call_torch_function(
                    tx,
                    getter_var,
                    TupleVariable([self.class_type_var(tx)]),
                    [self],
                    {},
                )
        else:
            # `TensorVariable.var_getattr` doesn't handle user-defined
            # function/attribute well, so we explicitly handle them here.
            #
            # TODO move this logic into `TensorVariable`, or try to merge it
            # with similar logic in `UserDefinedObjectVariable`.
            try:
                attr = inspect.getattr_static(self.class_type, name)
            except AttributeError:
                # Not defined on the class; fall through to the base handling.
                pass
            else:
                import types

                attr_source = AttrSource(
                    GlobalSource(self.global_mangled_class_name(tx)), name
                )
                if isinstance(attr, types.FunctionType):
                    install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH))
                    return UserMethodVariable(attr, self)
                elif isinstance(attr, property):
                    # Trace the property getter as a method call on `self`.
                    fget_source = AttrSource(attr_source, "fget")
                    fget_var = UserMethodVariable(attr.fget, self, source=fget_source)
                    return fget_var.call_function(tx, [], {})
                elif attr_is_overriden:
                    unimplemented(
                        f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}"  # noqa: B950
                    )

        return super().var_getattr(tx, name)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        """Inline this subclass's __torch_function__ for the call to `fn`."""
        subclass_type_var = self.class_type_var(tx)
        return call_torch_function(
            tx, subclass_type_var, self.torch_function_fn, fn, types, args, kwargs
        )

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace the method call `name`, dispatching through __torch_function__ when enabled."""
        if not tx.output.torch_function_enabled:
            return super().call_method(tx, name, args, kwargs)

        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        import torch

        if _is_attr_overidden(tx, self, name):
            unimplemented(
                f"Calling overridden method {name} on a tensor"
                " subclass with a __torch_function__ override is not supported"
            )

        # [Note: __torch_function__] Currently we only support methods that are defined on tensor
        # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality
        # The check above established that the method is not overridden, so
        # guard that it is the same as the impl defined on tensor and
        # retrieve it.
        if self.source:
            method_source = AttrSource(AttrSource(self.source, "__class__"), name)
            method = inspect.getattr_static(self.python_type(), name)
        else:
            method_source = None
            method = getattr(torch.Tensor, name)
        func_var = VariableTracker.build(tx, method, method_source)
        return dispatch_torch_function(tx, func_var, [self] + args, kwargs)