Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422) (#137114)"

This reverts commit 51bc839b94829f176e3c1b7f62e3448d6028c480.

Reverted https://github.com/pytorch/pytorch/pull/137114 on behalf of https://github.com/huydhn due to The top of the stack has been reverted but it leaves trunk in a broken state, so I try to revert the rest of the stack ([comment](https://github.com/pytorch/pytorch/pull/137114#issuecomment-2400765603))
This commit is contained in:
PyTorch MergeBot
2024-10-08 20:33:17 +00:00
parent 8c937445ee
commit d34b617bb9
14 changed files with 52 additions and 321 deletions

View File

@ -461,94 +461,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(expected, actual)
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")
@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x
try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -120,7 +120,6 @@ from .utils import (
troubleshooting_url,
write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr
np: Optional[ModuleType]
@ -219,18 +218,15 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
exit_stack = contextlib.ExitStack()
exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)

View File

@ -78,6 +78,7 @@ from .utils import (
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@ -249,7 +250,6 @@ class OutputGraph:
local_scope: Scope,
global_scope: Scope,
f_code,
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]
@ -368,7 +368,7 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = torch_function_mode_stack
self.torch_function_mode_stack = get_torch_function_mode_stack()
# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager
@ -1021,7 +1021,7 @@ class OutputGraph:
prefix_insts.clear()
for block in reversed(tx.block_stack):
block.exit(tx, is_graph_break=reason.graph_break)
block.exit(tx)
self.cleanup_graph()
tx.prune_dead_locals()

View File

@ -25,26 +25,6 @@ if TYPE_CHECKING:
sys as sys,
)
from torch.overrides import BaseTorchFunctionMode
# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass
def index(iterator, item, start=0, end=None):
from itertools import islice

View File

@ -90,25 +90,27 @@ class ReenterWith:
stack_index: int
target_values: Optional[Tuple[Any, ...]] = None
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
"""
Codegen based off of:
try:
(rest)
except:
(restore previous tf mode stack)
raise
"""
from .variables.torch_function import get_prev_stack_var_name
# TODO(mlazos) - Uncomment with the reland of torch function mode support
# def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
# """
# Codegen based off of:
# try:
# (rest)
# except:
# (restore previous tf mode stack)
# raise
setup_try_except, epilogue = _bytecode_from_template_with_split(
_try_except_tf_mode_template,
self.stack_index,
varname_map={"stack_var_name": get_prev_stack_var_name()},
)
cleanup[:] = epilogue + cleanup
# """
# from .variables.torch_function import get_prev_stack_var_name
return setup_try_except
# setup_try_except, epilogue = _bytecode_from_template_with_split(
# _try_except_tf_mode_template,
# self.stack_index,
# varname_map={"stack_var_name": get_prev_stack_var_name()},
# )
# cleanup[:] = epilogue + cleanup
# return setup_try_except
# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol

View File

@ -629,22 +629,11 @@ class SideEffects:
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)
cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))

View File

@ -267,11 +267,12 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index)
def exit(self, tx, is_graph_break):
def exit(self, tx):
if hasattr(self, "graph_break") and isinstance(
self.with_context, TorchFunctionModeVariable
):
return
assert self.with_context is not None
if (
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)
@ -656,17 +657,10 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
cg.extend_output(
b.resume_fn().try_except_torch_function_mode(
cg.code_options, cleanup
)
)
continue
assert b.with_context is not None
assert isinstance(b.with_context, (ContextWrappingVariable))
assert isinstance(
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
)
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
@ -2320,10 +2314,7 @@ class InstructionTranslatorBase(
):
unimplemented(f"{inst.opname} {ctx}")
if (
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
if isinstance(ctx, GenericContextWrappingVariable):
self.generic_context_manager_depth += 1
# Need this redundant check for mypy
@ -2696,7 +2687,6 @@ class InstructionTranslator(InstructionTranslatorBase):
local_scope=f_locals,
global_scope=f_globals,
f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
),
instructions=instructions,
f_locals=f_locals,

View File

@ -187,7 +187,6 @@ def debug_insert_nops(
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))

View File

@ -304,7 +304,6 @@ manual_torch_name_rule_map = {
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
"torch.set_default_device": UserFunctionVariable,
"torch.sparse_bsc_tensor": SkipFunctionVariable,
"torch.sparse_bsr_tensor": SkipFunctionVariable,
"torch.sparse_csc_tensor": SkipFunctionVariable,
@ -2803,6 +2802,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.random.initial_seed",
"torch.random.seed",
"torch.return_types.pytree_register_structseq",
"torch.set_default_device",
"torch.set_default_dtype",
"torch.set_default_tensor_type",
"torch.set_deterministic_debug_mode",

View File

@ -204,7 +204,6 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
build_torch_function_fn,
TensorWithTFOverrideVariable,
torch_function_mode_stack_state_mgr,
TorchFunctionModeVariable,
)
from .user_defined import (
@ -1670,7 +1669,6 @@ class VariableBuilder:
# but warning is not the end of the world
assert isinstance(value.base, np.nditer)
with torch_function_mode_stack_state_mgr.temp_restore_stack():
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:

View File

@ -125,12 +125,6 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
@ -189,12 +183,6 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1
return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""

View File

@ -159,17 +159,7 @@ def get_overridable_functions():
from torch.overrides import get_overridable_functions as get_overridable_functions_
funcs = set(chain(*get_overridable_functions_().values()))
more = {
torch.ones,
torch.ones_like,
torch.zeros,
torch.zeros_like,
torch.empty,
torch.full,
}
funcs.update(more)
return funcs
return set(chain(*get_overridable_functions_().values()))
class BaseTorchVariable(VariableTracker):
@ -845,13 +835,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
len(tx.symbolic_torch_function_state.mode_stack)
)
@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]
@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs
@ -869,7 +852,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
TorchFunctionModeStackVariable.register_device_context_insertion(tx)
return ConstantVariable.create(None)
return None
return handlers

View File

@ -2,35 +2,22 @@
import collections
import contextlib
import functools
import inspect
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import (
_get_overloaded_args,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -76,39 +63,6 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}
@functools.lru_cache(None)
def get_prev_stack_var_name():
from ..bytecode_transformation import unique_id
return unique_id("___prev_torch_function_mode_stack")
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []
@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState:
def __init__(self, py_stack):
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
@ -235,25 +189,8 @@ class TorchFunctionModeStackVariable(VariableTracker):
return ind + cls.offset
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the funtion across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
class TorchFunctionModeVariable(ContextWrappingVariable):
def __init__(self, value, source=None, **kwargs):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
@ -284,39 +221,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
kwargs,
)
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _call_func(self, tx: "InstructionTranslator", values):
unimplemented("enter/exit for torch function mode NYI")
def _get_all_args(args, kwargs):

View File

@ -413,22 +413,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, var_cls, {}
self.source, self.value, GenericContextWrappingVariable, {}
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj