Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)"
This reverts commit 2af3b8ffd84e36b91279174e9106f84b2d2a11f2. Reverted https://github.com/pytorch/pytorch/pull/135422 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378))
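For context, the reverted change taught Dynamo to trace the __enter__/__exit__ of TorchFunctionMode context managers. A minimal sketch of the pattern the removed tests exercise (TestMode here is a stand-in for the BaseTorchFunctionMode subclass defined elsewhere in the test file; fullgraph compilation of this pattern is only expected to succeed on builds that include the reverted support):

import torch
from torch.overrides import BaseTorchFunctionMode


class TestMode(BaseTorchFunctionMode):
    # Default BaseTorchFunctionMode behavior: __torch_function__ just forwards the call,
    # and enter/exit only push/pop the torch function mode stack.
    pass


def fn(x, y):
    with TestMode():  # the enter/exit that the reverted PR traced
        o = torch.add(x, 3)
    return torch.add(o, y)


inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
expected = fn(*inp)
actual = torch.compile(fn, fullgraph=True)(*inp)
torch.testing.assert_close(actual, expected)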
@@ -461,94 +461,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(expected, actual)

    def test_torch_function_mode_enter_exit(self):
        def fn(x, y):
            with TestMode():
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn, fullgraph=True)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_graph_break(self):
        def fn(x, y):
            with TestMode():
                torch._dynamo.graph_break()
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_and_pop_graph_break(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_restore_on_exc(self):
        @torch._dynamo.disable()
        def err():
            raise RuntimeError("test")

        @torch.compile()
        def fn(x):
            with TestMode():
                x += 1
                err()
                x += 2
                return x

        try:
            fn(torch.ones(2, 2))
        except RuntimeError:
            pass
        self.assertEqual(_len_torch_function_stack(), 0)

    def test_torch_function_mode_and_pop_graph_break_mutation(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                z.y = 5
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)
                o = torch.mul(o, z.y)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -112,7 +112,6 @@ from .utils import (
    troubleshooting_url,
    write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr


np: Optional[ModuleType]
@@ -211,18 +210,15 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
        cleanup = setup_compile_debug()

        exit_stack = contextlib.ExitStack()
        exit_stack.enter_context(
            torch.fx._symbolic_trace._maybe_revert_all_patches()
        )
        exit_stack.enter_context(torch_function_mode_stack_state_mgr)
        try:
            return fn(*args, **kwargs)
        finally:
            cleanup.close()
            assert (
                torch._C._len_torch_function_stack() == 0
            ), "Torch function mode stack state changed while dynamo tracing, please report a bug"
            exit_stack.close()
            torch._C._set_grad_enabled(prior_grad_mode)
            torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
@@ -78,6 +78,7 @@ from .utils import (
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_torch_function_mode_stack,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
@@ -249,7 +250,6 @@ class OutputGraph:
        local_scope: Scope,
        global_scope: Scope,
        f_code,
        torch_function_mode_stack,
    ):
        super().__init__()
        self.tracers = [SubgraphTracer(self, export_root=export)]
@@ -368,7 +368,7 @@ class OutputGraph:
        # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
        # This records the initial torch function mode stack for guarding
        self.torch_function_mode_stack = torch_function_mode_stack
        self.torch_function_mode_stack = get_torch_function_mode_stack()

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fallback to eager
@@ -1020,7 +1020,7 @@ class OutputGraph:
            prefix_insts.clear()

        for block in reversed(tx.block_stack):
            block.exit(tx, is_graph_break=reason.graph_break)
            block.exit(tx)

        self.cleanup_graph()
        tx.prune_dead_locals()
@@ -25,26 +25,6 @@ if TYPE_CHECKING:
        sys as sys,
    )

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass


def index(iterator, item, start=0, end=None):
    from itertools import islice
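A minimal sketch of the mechanism that comment describes, assuming the torch function mode stack starts empty: the push of the mode is replayed by Dynamo's output bytecode, so NoEnterTorchFunctionMode skips the push on __enter__ while inheriting the normal pop on __exit__, keeping the stack balanced in the resume function. The torch._C helpers are the private stack APIs referenced elsewhere in this diff.

import torch
from torch.overrides import BaseTorchFunctionMode


class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    # Same shape as the polyfill above: entering is a no-op because the mode
    # is assumed to already be on the stack (its push was replayed).
    def __enter__(self):
        pass


mode = NoEnterTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)  # the replayed push
with mode:  # no second push on enter
    pass  # the resume-function body would run here
# TorchFunctionMode.__exit__ popped the mode, so the stack is balanced again
assert torch._C._len_torch_function_stack() == 0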
@@ -6,7 +6,6 @@ import types
from typing import Any, cast, Dict, List, Optional, Tuple

from .bytecode_transformation import (
    add_push_null,
    create_call_function,
    create_call_method,
    create_dup_top,
@@ -49,109 +48,6 @@ class ReenterWith:
    stack_index: int
    target_values: Optional[Tuple[Any, ...]] = None

    def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
        """
        Codegen based off of:
        try:
            (rest)
        except:
            (restore previous stack)

        """
        from .variables.torch_function import get_prev_stack_var_name

        except_jump_target = create_instruction(
            "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
        )
        cleanup_complete_jump_target = create_instruction("NOP")

        setup_finally: List[Instruction] = []

        if sys.version_info < (3, 11):
            setup_finally.append(
                create_instruction("SETUP_FINALLY", target=except_jump_target)
            )
        else:
            exn_tab_begin = create_instruction("NOP")
            exn_tab_end = create_instruction("NOP")
            exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
                exn_tab_begin,
                exn_tab_end,
                except_jump_target,
                self.stack_index + 1,
                False,
            )
            setup_finally.append(exn_tab_begin)

        def create_reset():
            insts = [
                create_instruction(
                    "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
                ),
                create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
            ]
            add_push_null(insts)
            return [
                *insts,
                create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()),
                *create_call_function(1, False),
                create_instruction("POP_TOP"),
            ]

        if sys.version_info < (3, 9):
            epilogue = [
                create_instruction("POP_BLOCK"),
                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
                except_jump_target,
                *create_reset(),
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                *create_reset(),
                create_instruction("RAISE_VARARGS", argval=0),
                create_instruction("POP_EXCEPT", argval=0),
                create_instruction("END_FINALLY"),
                cleanup_complete_jump_target,
            ]
        elif sys.version_info < (3, 11):
            epilogue = [
                create_instruction("POP_BLOCK"),
                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
                except_jump_target,
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                *create_reset(),
                create_instruction("RAISE_VARARGS", argval=0),
                create_instruction("POP_EXCEPT", argval=0),
                cleanup_complete_jump_target,
            ]
        else:
            finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
            finally_exn_tab_target = create_instruction("COPY", arg=3)
            except_jump_target.exn_tab_entry = InstructionExnTabEntry(
                except_jump_target,
                finally_exn_tab_end,
                finally_exn_tab_target,
                self.stack_index + 2,
                True,
            )
            epilogue = [
                exn_tab_end,
                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
                except_jump_target,  # PUSH_EXC_INFO
                create_instruction("POP_TOP"),
                *create_reset(),
                finally_exn_tab_end,
                finally_exn_tab_target,  # COPY 3
                create_instruction("POP_EXCEPT"),
                create_instruction("RERAISE", arg=1),  # RERAISE 1
                cleanup_complete_jump_target,
            ]

        cleanup[:] = epilogue + cleanup
        return setup_finally

    # If we do not want to destroy the stack, we can do the same thing as a
    # `SETUP_WITH` block, only that we store the context manager in a local_symbol
    def try_except(self, code_options, cleanup: List[Instruction]):
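In plain Python, the bytecode emitted by try_except_torch_function_mode corresponds roughly to the sketch below. The wrapper and its names are illustrative, not from the PR; set_torch_function_mode_stack is the torch._dynamo.utils helper targeted by the generated LOAD_GLOBAL/LOAD_ATTR sequence, and saved_stack plays the role of the ___prev_torch_function_mode_stack local stored by the side-effect codegen in the next hunk.

from torch._dynamo.utils import set_torch_function_mode_stack


def run_resume_body_with_tf_stack_protection(resume_body, saved_stack):
    # try: (rest)  /  except: (restore previous stack), then re-raise
    try:
        return resume_body()
    except BaseException:
        set_torch_function_mode_stack(saved_stack)
        raise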
@@ -593,22 +593,11 @@ class SideEffects:
            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                # Needed in the finally block for stack restoration
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "get_torch_function_mode_stack"
                    )
                )
                cg.call_function(0, False)
                name = variables.torch_function.get_prev_stack_var_name()
                cg.code_options["co_varnames"] += (name,)
                cg.append_output(create_instruction("STORE_FAST", argval=name))
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )

                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
@@ -267,12 +267,13 @@ class BlockStackEntry:
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx, is_graph_break):
    def exit(self, tx):
        if hasattr(self, "graph_break") and isinstance(
            self.with_context, TorchFunctionModeVariable
        ):
            return
        assert self.with_context is not None
        if (
            is_graph_break and self.with_context.exit_on_graph_break()
        ) or not is_graph_break:
            return self.with_context.exit(tx)
        return self.with_context.exit(tx)


class ReturnValueOp(Exception):
@@ -638,17 +639,10 @@ def break_graph_if_unsupported(*, push):
            cleanup: List[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                # Don't exit any modes we have entered,
                # output bytecode will mutate the tf mode stack accordingly
                if isinstance(b.with_context, TorchFunctionModeVariable):
                    cg.extend_output(
                        b.resume_fn().try_except_torch_function_mode(
                            cg.code_options, cleanup
                        )
                    )
                    continue
                assert b.with_context is not None
                assert isinstance(b.with_context, (ContextWrappingVariable))
                assert isinstance(
                    b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
                )
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
@@ -2301,10 +2295,7 @@ class InstructionTranslatorBase(
        ):
            unimplemented(f"{inst.opname} {ctx}")

        if (
            isinstance(ctx, GenericContextWrappingVariable)
            and not ctx.supports_graph_breaks()
        ):
        if isinstance(ctx, GenericContextWrappingVariable):
            self.generic_context_manager_depth += 1

        # Need this redundant check for mypy
@@ -2677,7 +2668,6 @@ class InstructionTranslator(InstructionTranslatorBase):
                local_scope=f_locals,
                global_scope=f_globals,
                f_code=f_code,
                torch_function_mode_stack=torch_function_mode_stack,
            ),
            instructions=instructions,
            f_locals=f_locals,
@@ -163,7 +163,6 @@ def debug_insert_nops(
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
        torch_function_mode_stack=[],
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
@@ -304,7 +304,6 @@ manual_torch_name_rule_map = {
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
    "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
    "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
    "torch.set_default_device": UserFunctionVariable,
    "torch.sparse_bsc_tensor": SkipFunctionVariable,
    "torch.sparse_bsr_tensor": SkipFunctionVariable,
    "torch.sparse_csc_tensor": SkipFunctionVariable,
@@ -2798,6 +2797,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
        "torch.random.initial_seed",
        "torch.random.seed",
        "torch.return_types.pytree_register_structseq",
        "torch.set_default_device",
        "torch.set_default_dtype",
        "torch.set_default_tensor_type",
        "torch.set_deterministic_debug_mode",
@@ -204,7 +204,6 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
    build_torch_function_fn,
    TensorWithTFOverrideVariable,
    torch_function_mode_stack_state_mgr,
    TorchFunctionModeVariable,
)
from .user_defined import (
@@ -1669,16 +1668,15 @@ class VariableBuilder:
                # but warning is not the end of the world
                assert isinstance(value.base, np.nditer)

        with torch_function_mode_stack_state_mgr.temp_restore_stack():
            try:
                tensor_value = _util._try_convert_to_tensor(value)
                if readonly:
                    from torch._prims_common import clone_preserve_strides
        try:
            tensor_value = _util._try_convert_to_tensor(value)
            if readonly:
                from torch._prims_common import clone_preserve_strides

                    tensor_value = clone_preserve_strides(tensor_value)
            except NotImplementedError as e:
                # failed to convert to tensor, graph break
                unimplemented(str(e))
                tensor_value = clone_preserve_strides(tensor_value)
        except NotImplementedError as e:
            # failed to convert to tensor, graph break
            unimplemented(str(e))

        # We do this because we want the full behavior of guarding the numpy ndarray as if it were
        # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
@@ -125,12 +125,6 @@ class ContextWrappingVariable(VariableTracker):
        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return True


class GenericContextWrappingVariable(UserDefinedObjectVariable):
    # Some methods in ContextWrappingVariable assumes the arguments are
@@ -189,12 +183,6 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
        tx.generic_context_manager_depth -= 1
        return x

    def supports_graph_breaks(self):
        return False

    def exit_on_graph_break(self):
        return True


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch grad requries grad"""
@@ -162,10 +162,7 @@ def get_overridable_functions():
    from torch.overrides import get_overridable_functions as get_overridable_functions_

    funcs = set(chain(*get_overridable_functions_().values()))
    more = {torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty}
    funcs.update(more)
    return funcs
    return set(chain(*get_overridable_functions_().values()))


class BaseTorchVariable(VariableTracker):
@@ -841,13 +838,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                len(tx.symbolic_torch_function_state.mode_stack)
            )

        @register(torch._C._get_function_stack_at)
        def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
            assert len(args) == 1 and not kwargs
            ind = args[0].as_python_constant()
            assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
            return tx.symbolic_torch_function_state.mode_stack[ind]

        @register(torch.set_default_device)
        def handle_set_default_device(
            self, tx: "InstructionTranslator", *args, **kwargs
@@ -865,7 +855,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            else:
                TorchFunctionModeStackVariable.register_device_context_insertion(tx)

            return ConstantVariable.create(None)
            return None

        return handlers
@@ -2,35 +2,22 @@
import collections
import contextlib
import functools
import inspect
from typing import Deque, Dict, List, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import (
    _get_overloaded_args,
    get_default_nowrap_functions,
    TorchFunctionMode,
)
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import (
    class_has_getattribute,
    clear_torch_function_mode_stack,
    get_safe_global_name,
    has_torch_function,
    is_tensor_base_attr_getter,
    set_torch_function_mode_stack,
)
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@@ -76,39 +63,6 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}


@functools.lru_cache(None)
def get_prev_stack_var_name():
    from ..bytecode_transformation import unique_id

    return unique_id("___prev_torch_function_mode_stack")


# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
    def __init__(self):
        self.stack = []

    def __enter__(self):
        self.stack = torch.overrides._get_current_function_mode_stack()
        clear_torch_function_mode_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        set_torch_function_mode_stack(self.stack)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        prev = torch.overrides._get_current_function_mode_stack()
        set_torch_function_mode_stack(self.stack)
        try:
            yield
        finally:
            set_torch_function_mode_stack(prev)


torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()


class SymbolicTorchFunctionState:
    def __init__(self, py_stack):
        # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
|
||||
return ind + cls.offset
|
||||
|
||||
|
||||
class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
||||
@staticmethod
|
||||
def is_supported_torch_function_mode(ty):
|
||||
# Supported in this sense means we can support graph breaks under the
|
||||
# context.
|
||||
# We are able to trace custom modes but if there are graph breaks under them
|
||||
# and they have a custom __enter__/__exit__ we don't handle this for the
|
||||
# same reason we don't handle generic context managers: there may be side effects
|
||||
# that are now affected by executing the funtion across two frames instead of one
|
||||
# Today we support the enter/exit of the default TorchFunctionMode as well as
|
||||
# DeviceContext (which is used for set_default_device)
|
||||
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
|
||||
not class_has_getattribute(ty)
|
||||
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
|
||||
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
|
||||
)
|
||||
|
||||
class TorchFunctionModeVariable(ContextWrappingVariable):
|
||||
def __init__(self, value, source=None, **kwargs):
|
||||
if value is not None:
|
||||
super().__init__(value, **kwargs)
|
||||
super().__init__(value, **kwargs)
|
||||
self.value = value
|
||||
self.cm_obj = value # needed for BC with calling enter from CM code
|
||||
self.source = source
|
||||
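To illustrate the criterion above (class names here are illustrative, not from the PR): a mode that inherits TorchFunctionMode.__enter__/__exit__ unchanged is considered supported, while one that overrides __enter__ with its own side effects is not, for the same reason generic context managers are not.

from torch.overrides import TorchFunctionMode


class PlainMode(TorchFunctionMode):
    # Inherits __enter__/__exit__ from TorchFunctionMode, which only push/pop
    # the mode stack -> is_supported_torch_function_mode would return True.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))


class LoggingEnterMode(TorchFunctionMode):
    # Custom __enter__ with an extra side effect -> would return False.
    def __enter__(self):
        print("entering LoggingEnterMode")
        return super().__enter__()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))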
@@ -284,39 +221,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
            kwargs,
        )

    def enter(self, tx):
        from .torch import TorchInGraphFunctionVariable

        if isinstance(self.value, NoEnterTorchFunctionMode):
            return ConstantVariable.create(None)

        TorchInGraphFunctionVariable(
            torch._C._push_on_torch_function_stack
        ).call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        from .torch import TorchInGraphFunctionVariable

        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
            tx, [], {}
        )
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        ty = NoEnterTorchFunctionMode
        codegen(
            AttrSource(
                codegen.tx.import_source(ty.__module__),
                ty.__name__,
            )
        )

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return False
    def _call_func(self, tx: "InstructionTranslator", values):
        unimplemented("enter/exit for torch function mode NYI")


def _get_all_args(args, kwargs):
@@ -409,22 +409,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            from torch.overrides import TorchFunctionMode

            from .ctx_manager import GenericContextWrappingVariable
            from .torch_function import TorchFunctionModeVariable

            if issubclass(
                self.value, TorchFunctionMode
            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
                self.value
            ):
                var_cls = TorchFunctionModeVariable
            else:
                var_cls = GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
                self.source, self.value, var_cls, {}
                self.source, self.value, GenericContextWrappingVariable, {}
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj