Files
pytorch/torch/_dynamo/variables/torch_function.py
Michael Lazos 51bc839b94 [Dynamo] Trace enter/exit of TorchFunctionModes (#135422) (#137114)
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)

Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call

resume fn structure:
1. enter context
2. jump
...
3. exit context

The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).

So for torch function modes the structure of our output code is this:

1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function

Then our resume fn looks like this:

1. no-op enter torch function mode
2. jump
3.  exit tf mode

To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).

Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137114
Approved by: https://github.com/yanboliang
2024-10-07 18:55:26 +00:00

579 lines
21 KiB
Python

# mypy: ignore-errors
import collections
import contextlib
import functools
import inspect
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import (
_get_overloaded_args,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches
# __torch_function__ on attribute accesses, method calls, and torch API calls.
# The following is not supported:
# - triggering __torch_function__ on tensor subclass non-tensor custom attributes
# - graph breaking on mutating guardable tensor properties within a __torch_function__ context, this can cause
# excessive recompiles in certain degenerate cases
# - Matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of Torch API calls
# The following is supported:
# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as
# any argument
# - triggering __torch_function__ on torch API calls with tensor subclass arguments
# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances
# - matches the dispatch ordering behavior of eager __torch_function__ with subclass/object argumnents in any argument position
# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
# for more information on the design.
# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py
banned_attrs = [
fn.__self__.__name__
for fn in get_default_nowrap_functions()
if is_tensor_base_attr_getter(fn)
]
# Today set default device is placed in the graph and guarded on separately
# so we should not trace through it. In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}
@functools.lru_cache(None)
def get_prev_stack_var_name():
from ..bytecode_transformation import unique_id
return unique_id("___prev_torch_function_mode_stack")
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []
@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState:
def __init__(self, py_stack):
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
# These are their definitions:
# 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
# (if either are entered, this will be False)
# 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
# torch._C.DisableTorchFunction has been entered
# To disambiguate these and keep myself sane I added a C API to check whether all torch function
# concepts (modes and subclasses) are enabled.
# This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
# the stack length from the enablement state of torch function modes.
# This is important because now if a mode is pushed while dynamo is tracing, we know whether
# or not torch function modes are enabled and whether we should trace it.
self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()
# This differs from the C API of the same name
# this will only be false iff we have entered torch._C.DisableTorchFunction
# and does not take into account the mode stack length, while the C API bundles these
# two concepts
self.torch_function_mode_enabled = (
not torch._C._is_torch_function_all_disabled()
)
self.cur_mode = None
TorchFunctionModeStackVariable.reset()
self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()
for i, val in enumerate(py_stack):
self.mode_stack.append(
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
)
def in_torch_function_mode(self):
return len(self.mode_stack) > 0
def pop_torch_function_mode(self):
return self.mode_stack.pop()
def push_torch_function_mode(self, mode_var):
self.mode_stack.append(mode_var)
def call_torch_function_mode(self, tx, fn, types, args, kwargs):
with self._pop_mode_for_inlining() as cur_mode:
return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
@contextlib.contextmanager
def _pop_mode_for_inlining(self):
old_mode = self.cur_mode
self.cur_mode = self.pop_torch_function_mode()
try:
yield self.cur_mode
finally:
mode = self.cur_mode
self.cur_mode = old_mode
self.push_torch_function_mode(mode)
class TorchFunctionModeStackVariable(VariableTracker):
"""Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""
# singleton value representing the global torch function mode stack
# singleton (it exists in C++)
stack_value_singleton = object()
# offset is used to track if we have inserted/removed a
# device context which is always placed at the bottom of the stack
# if a device context is inserted, the graph will run this mutation
# so when we want to reconstruct any other modes on the stack
# their indices should be shifted right by 1 (+1)
# Conversely, if there was a device context on the stack, and the graph
# mutates the stack to remove that context (set default device to None)
# each of the indices of other modes should be shifted left by 1 (-1)
offset = 0
def __init__(self, source, symbolic_stack):
self.source = source
self.symbolic_stack = symbolic_stack
@classmethod
def reset(cls):
cls.offset = 0
@classmethod
def register_mutation(cls, tx: "InstructionTranslator"):
if cls.stack_value_singleton not in tx.output.side_effects:
var = cls(
source=Source(),
symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
)
tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
tx.output.side_effects.mutation(var)
@classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
return
else:
cls.offset += 1
stack.insert(
0,
TorchFunctionModeVariable(
None, source=TorchFunctionModeStackSource(-cls.offset)
),
)
@classmethod
def clear_default_device(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
stack.popleft()
cls.offset -= 1
@staticmethod
def is_device_context(var):
return isinstance(var.value, DeviceContext) or var.value is None
@classmethod
def get_mode_index(cls, ind):
return ind + cls.offset
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the funtion across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
def __init__(self, value, source=None, **kwargs):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
def reconstruct(self, codegen):
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)
def module_name(self):
return self.value.__module__
def fn_name(self):
return type(self.value).__name__
def python_type(self):
return type(self.value)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
self,
build_torch_function_fn(tx, self.value, self.source),
fn,
types,
args,
kwargs,
)
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _get_all_args(args, kwargs):
return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))
def _flatten_vts(vts):
from collections import deque
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable
vts = deque(vts)
output = []
while vts:
vt = vts.pop()
LazyVariableTracker.realize_all(vt)
if isinstance(vt, ListVariable):
vts.extend(vt.items)
elif isinstance(vt, ConstDictVariable):
vts.extend(vt.items.values())
else:
output.append(vt)
return output
def _get_subclass_type(var):
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
return var.python_type()
def _get_subclass_type_var(tx: "InstructionTranslator", var):
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable):
from .builder import SourcelessBuilder, VariableBuilder
if var.source:
return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
else:
return SourcelessBuilder.create(tx, var.python_type())
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
import torch
overridden = False
try:
attr_val = inspect.getattr_static(var.python_type(), name)
overridden |= attr_val != getattr(torch.Tensor, name)
except AttributeError:
pass
return overridden
def call_torch_function(
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
from .builder import SourcelessBuilder
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
torch_function_type,
fn,
types,
SourcelessBuilder.create(tx, tuple(args)),
SourcelessBuilder.create(tx, kwargs),
)
return tx.inline_user_function_return(torch_function_var, tf_args, {})
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from .builder import SourcelessBuilder, VariableBuilder
if source:
return VariableBuilder(
tx,
AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
)(value.__torch_function__.__func__)
else:
return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
has_overridden_args = any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
)
tf_state = tx.symbolic_torch_function_state
return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
)
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
"""Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""
all_args = _get_all_args(args, kwargs)
overloaded_args = _get_overloaded_args(
[arg for arg in all_args if has_torch_function(arg)],
_get_subclass_type,
)
types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])
if tx.symbolic_torch_function_state.in_torch_function_mode():
res = tx.symbolic_torch_function_state.call_torch_function_mode(
tx, fn, types, args, kwargs
)
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
for arg in overloaded_args:
res = arg.call_torch_function(
tx,
fn,
types,
args,
kwargs,
)
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
unimplemented(
f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
)
class TensorWithTFOverrideVariable(TensorVariable):
"""
Represents a tensor subclass instance with a __torch_function__ override.
"""
def __init__(self, *args, **kwargs) -> None:
self.torch_function_fn = kwargs.pop("torch_function_fn")
super().__init__(*args, **kwargs)
@classmethod
def from_tensor_var(cls, tx, tensor_var, class_type, torch_function_fn):
import torch
kwargs = dict(tensor_var.__dict__)
assert (
kwargs.pop("class_type") is torch.Tensor
), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
var.install_global(tx)
return var
def install_global(self, tx):
# stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor.
if self.global_mangled_class_name(tx) not in tx.output.global_scope:
# Safe because global_mangled_class_name figures it out
tx.output.install_global_unsafe(
self.global_mangled_class_name(tx), self.class_type
)
def python_type(self):
return self.class_type
def class_type_var(self, tx):
return TensorSubclassVariable(
self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
)
def global_mangled_class_name(self, tx):
return get_safe_global_name(
tx, f"__subclass_{self.class_type.__name__}", self.class_type
)
def var_getattr(self, tx: "InstructionTranslator", name):
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder
if name in banned_attrs:
unimplemented(
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
)
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Accessing overridden method/attribute {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
if self.source:
install_guard(
AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
GuardBuilder.FUNCTION_MATCH
)
)
get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(
tx,
get_fn,
TupleVariable([self.class_type_var(tx)]),
[self],
{},
)
else:
return super().var_getattr(tx, name)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
self.class_type_var(tx),
self.torch_function_fn,
fn,
types,
args,
kwargs,
)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# This code block implements inlining the __torch_function__ override
# of `call_method`.
if tx.output.torch_function_enabled:
import torch
from .builder import SourcelessBuilder, VariableBuilder
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Calling overridden method {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
# [Note: __torch_function__] Currently we only support methods that are defined on tensor
# we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality
# We've established with the above check that the method is not overridden, so we guard that the method is the same
# as the impl defined on tensor and retrieve it
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(self.python_type(), name))
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
else:
return super().call_method(tx, name, args, kwargs)