Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422) (#137114)"

This reverts commit 51bc839b94829f176e3c1b7f62e3448d6028c480.

Reverted https://github.com/pytorch/pytorch/pull/137114 on behalf of https://github.com/huydhn due to The top of the stack has been reverted but it leaves trunk in a broken state, so I am reverting the rest of the stack ([comment](https://github.com/pytorch/pytorch/pull/137114#issuecomment-2400765603))
Author: PyTorch MergeBot
Date:   2024-10-08 20:33:17 +00:00
Parent: 8c937445ee
Commit: d34b617bb9

14 changed files with 52 additions and 321 deletions


@@ -2,35 +2,22 @@
 import collections
 import contextlib
 import functools
-import inspect
 from typing import Deque, Dict, List, TYPE_CHECKING

 import torch._C
 import torch.utils._pytree as pytree
 from torch._guards import Source
-from torch.overrides import (
-    _get_overloaded_args,
-    get_default_nowrap_functions,
-    TorchFunctionMode,
-)
+from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
 from torch.utils._device import DeviceContext

 from ..exc import unimplemented
 from ..guards import GuardBuilder, install_guard
-from ..polyfills import NoEnterTorchFunctionMode
 from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
-from ..utils import (
-    class_has_getattribute,
-    clear_torch_function_mode_stack,
-    get_safe_global_name,
-    has_torch_function,
-    is_tensor_base_attr_getter,
-    set_torch_function_mode_stack,
-)
+from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
 from .base import VariableTracker
 from .constant import ConstantVariable
-from .ctx_manager import GenericContextWrappingVariable
+from .ctx_manager import ContextWrappingVariable
 from .lazy import LazyVariableTracker
 from .lists import TupleVariable
 from .tensor import TensorSubclassVariable, TensorVariable
@@ -76,39 +63,6 @@ banned_attrs = [
 IGNORED_MODES = {DeviceContext}

-
-@functools.lru_cache(None)
-def get_prev_stack_var_name():
-    from ..bytecode_transformation import unique_id
-
-    return unique_id("___prev_torch_function_mode_stack")
-
-
-# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
-class TorchFunctionModeStackStateManager:
-    def __init__(self):
-        self.stack = []
-
-    def __enter__(self):
-        self.stack = torch.overrides._get_current_function_mode_stack()
-        clear_torch_function_mode_stack()
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        set_torch_function_mode_stack(self.stack)
-        self.stack = []
-
-    @contextlib.contextmanager
-    def temp_restore_stack(self):
-        prev = torch.overrides._get_current_function_mode_stack()
-        set_torch_function_mode_stack(self.stack)
-        try:
-            yield
-        finally:
-            set_torch_function_mode_stack(prev)
-
-
-torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()


 class SymbolicTorchFunctionState:
     def __init__(self, py_stack):
         # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
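
For context on the manager removed above: on entry it saves the live torch function mode stack and clears it so tracing starts from an empty stack; on exit it restores the saved modes; temp_restore_stack re-installs them briefly when user code must run under the original modes. Below is a minimal eager-mode sketch of that contract, built only on the private bindings this diff itself references (torch._C._push_on_torch_function_stack, torch._C._pop_torch_function_stack, torch.overrides._get_current_function_mode_stack); the class name is illustrative, not part of the diff:

import contextlib

import torch


class StackStateManagerSketch:
    """Hypothetical stand-in for the removed TorchFunctionModeStackStateManager."""

    def __init__(self):
        self.stack = []

    def __enter__(self):
        # Save the live mode stack, then pop every mode so tracing sees an
        # empty stack (the real code calls clear_torch_function_mode_stack).
        self.stack = torch.overrides._get_current_function_mode_stack()
        for _ in range(len(self.stack)):
            torch._C._pop_torch_function_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        # Push the saved modes back in their original (bottom-first) order.
        for mode in self.stack:
            torch._C._push_on_torch_function_stack(mode)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        # Temporarily swap the saved stack back in, then undo the swap.
        prev = torch.overrides._get_current_function_mode_stack()
        for _ in range(len(prev)):
            torch._C._pop_torch_function_stack()
        for mode in self.stack:
            torch._C._push_on_torch_function_stack(mode)
        try:
            yield
        finally:
            for _ in range(len(self.stack)):
                torch._C._pop_torch_function_stack()
            for mode in prev:
                torch._C._push_on_torch_function_stack(mode)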
@@ -235,26 +189,9 @@ class TorchFunctionModeStackVariable(VariableTracker):
         return ind + cls.offset


-class TorchFunctionModeVariable(GenericContextWrappingVariable):
-    @staticmethod
-    def is_supported_torch_function_mode(ty):
-        # Supported in this sense means we can support graph breaks under the
-        # context.
-        # We are able to trace custom modes but if there are graph breaks under them
-        # and they have a custom __enter__/__exit__ we don't handle this for the
-        # same reason we don't handle generic context managers: there may be side effects
-        # that are now affected by executing the function across two frames instead of one
-        # Today we support the enter/exit of the default TorchFunctionMode as well as
-        # DeviceContext (which is used for set_default_device)
-        return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
-            not class_has_getattribute(ty)
-            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
-            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
-        )
-
+class TorchFunctionModeVariable(ContextWrappingVariable):
     def __init__(self, value, source=None, **kwargs):
-        if value is not None:
-            super().__init__(value, **kwargs)
+        super().__init__(value, **kwargs)
         self.value = value
         self.cm_obj = value  # needed for BC with calling enter from CM code
         self.source = source
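
For context on the is_supported_torch_function_mode check removed above: a user-defined mode only supports graph breaks if it keeps TorchFunctionMode's stock __enter__/__exit__, since a custom exit can carry side effects that are unsafe to split across the two frames a graph break produces. A small sketch of the distinction; the mode classes here are illustrative, not from the diff:

import inspect

from torch.overrides import TorchFunctionMode


class PlainMode(TorchFunctionMode):
    # Keeps the inherited __enter__/__exit__, so the removed check would
    # accept it: Dynamo can replay enter/exit itself around a graph break.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))


class CustomExitMode(TorchFunctionMode):
    # Overrides __exit__ with a side effect, so the removed check would
    # reject it, for the same reason generic context managers are rejected.
    def __exit__(self, exc_type, exc_value, traceback):
        print("tearing down")  # side effect that must not move across frames
        return super().__exit__(exc_type, exc_value, traceback)


# Mirrors the inspect.getattr_static comparisons in the removed code.
assert inspect.getattr_static(PlainMode, "__exit__") is TorchFunctionMode.__exit__
assert inspect.getattr_static(CustomExitMode, "__exit__") is not TorchFunctionMode.__exit__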
@@ -284,39 +221,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
             kwargs,
         )

-    def enter(self, tx):
-        from .torch import TorchInGraphFunctionVariable
-
-        if isinstance(self.value, NoEnterTorchFunctionMode):
-            return ConstantVariable.create(None)
-
-        TorchInGraphFunctionVariable(
-            torch._C._push_on_torch_function_stack
-        ).call_function(tx, [self], {})
-        return ConstantVariable.create(None)
-
-    def exit(self, tx: "InstructionTranslator", *args):
-        from .torch import TorchInGraphFunctionVariable
-
-        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
-            tx, [], {}
-        )
-        return ConstantVariable.create(None)
-
-    def reconstruct_type(self, codegen):
-        ty = NoEnterTorchFunctionMode
-        codegen(
-            AttrSource(
-                codegen.tx.import_source(ty.__module__),
-                ty.__name__,
-            )
-        )
-
-    def supports_graph_breaks(self):
-        return True
-
-    def exit_on_graph_break(self):
-        return False
+    def _call_func(self, tx: "InstructionTranslator", values):
+        unimplemented("enter/exit for torch function mode NYI")


 def _get_all_args(args, kwargs):
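
For context on the removed enter/exit methods: instead of tracing a mode's __enter__/__exit__ bodies, Dynamo emitted the stack push/pop those bodies reduce to, via TorchInGraphFunctionVariable. A minimal eager-mode sketch of the same effect, using the private bindings named in the removed code; NoopMode is illustrative, not from the diff:

import torch
from torch.overrides import TorchFunctionMode


class NoopMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))


mode = NoopMode()

# What the removed enter() emitted into the graph: a push of the mode.
torch._C._push_on_torch_function_stack(mode)
assert torch.overrides._get_current_function_mode_stack()[-1] is mode

# What the removed exit() emitted: the matching pop.
torch._C._pop_torch_function_stack()
assert mode not in torch.overrides._get_current_function_mode_stack()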