mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
This reverts commit 51bc839b94829f176e3c1b7f62e3448d6028c480. Reverted https://github.com/pytorch/pytorch/pull/137114 on behalf of https://github.com/huydhn due to The top of the stack has been reverted but it leaves trunk in a broken state, so I try to revert the rest of the stack ([comment](https://github.com/pytorch/pytorch/pull/137114#issuecomment-2400765603))
This commit is contained in:
@ -2,35 +2,22 @@
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Deque, Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._guards import Source
|
||||
from torch.overrides import (
|
||||
_get_overloaded_args,
|
||||
get_default_nowrap_functions,
|
||||
TorchFunctionMode,
|
||||
)
|
||||
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
|
||||
from torch.utils._device import DeviceContext
|
||||
|
||||
from ..exc import unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..polyfills import NoEnterTorchFunctionMode
|
||||
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
|
||||
from ..utils import (
|
||||
class_has_getattribute,
|
||||
clear_torch_function_mode_stack,
|
||||
get_safe_global_name,
|
||||
has_torch_function,
|
||||
is_tensor_base_attr_getter,
|
||||
set_torch_function_mode_stack,
|
||||
)
|
||||
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .ctx_manager import ContextWrappingVariable
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import TupleVariable
|
||||
from .tensor import TensorSubclassVariable, TensorVariable
|
||||
@ -76,39 +63,6 @@ banned_attrs = [
|
||||
IGNORED_MODES = {DeviceContext}
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_prev_stack_var_name():
|
||||
from ..bytecode_transformation import unique_id
|
||||
|
||||
return unique_id("___prev_torch_function_mode_stack")
|
||||
|
||||
|
||||
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
|
||||
class TorchFunctionModeStackStateManager:
|
||||
def __init__(self):
|
||||
self.stack = []
|
||||
|
||||
def __enter__(self):
|
||||
self.stack = torch.overrides._get_current_function_mode_stack()
|
||||
clear_torch_function_mode_stack()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
set_torch_function_mode_stack(self.stack)
|
||||
self.stack = []
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_restore_stack(self):
|
||||
prev = torch.overrides._get_current_function_mode_stack()
|
||||
set_torch_function_mode_stack(self.stack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
set_torch_function_mode_stack(prev)
|
||||
|
||||
|
||||
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
|
||||
|
||||
|
||||
class SymbolicTorchFunctionState:
|
||||
def __init__(self, py_stack):
|
||||
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
|
||||
@ -235,26 +189,9 @@ class TorchFunctionModeStackVariable(VariableTracker):
|
||||
return ind + cls.offset
|
||||
|
||||
|
||||
class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
||||
@staticmethod
|
||||
def is_supported_torch_function_mode(ty):
|
||||
# Supported in this sense means we can support graph breaks under the
|
||||
# context.
|
||||
# We are able to trace custom modes but if there are graph breaks under them
|
||||
# and they have a custom __enter__/__exit__ we don't handle this for the
|
||||
# same reason we don't handle generic context managers: there may be side effects
|
||||
# that are now affected by executing the funtion across two frames instead of one
|
||||
# Today we support the enter/exit of the default TorchFunctionMode as well as
|
||||
# DeviceContext (which is used for set_default_device)
|
||||
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
|
||||
not class_has_getattribute(ty)
|
||||
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
|
||||
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
|
||||
)
|
||||
|
||||
class TorchFunctionModeVariable(ContextWrappingVariable):
|
||||
def __init__(self, value, source=None, **kwargs):
|
||||
if value is not None:
|
||||
super().__init__(value, **kwargs)
|
||||
super().__init__(value, **kwargs)
|
||||
self.value = value
|
||||
self.cm_obj = value # needed for BC with calling enter from CM code
|
||||
self.source = source
|
||||
@ -284,39 +221,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def enter(self, tx):
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
if isinstance(self.value, NoEnterTorchFunctionMode):
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
TorchInGraphFunctionVariable(
|
||||
torch._C._push_on_torch_function_stack
|
||||
).call_function(tx, [self], {})
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args):
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
|
||||
tx, [], {}
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def reconstruct_type(self, codegen):
|
||||
ty = NoEnterTorchFunctionMode
|
||||
codegen(
|
||||
AttrSource(
|
||||
codegen.tx.import_source(ty.__module__),
|
||||
ty.__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return True
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return False
|
||||
def _call_func(self, tx: "InstructionTranslator", values):
|
||||
unimplemented("enter/exit for torch function mode NYI")
|
||||
|
||||
|
||||
def _get_all_args(args, kwargs):
|
||||
|
Reference in New Issue
Block a user