Compare commits

...

7 Commits

Author SHA1 Message Date
ac3dabf652 [Dynamo] Remove ignored modes from torch function mode stack guard
ghstack-source-id: c3398f28b58561ba6241279c6a7cf404aabfa8c7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135503
2024-09-11 14:02:23 -07:00
54ab06fc07 [Dynamo] Remove ignored modes workaround
ghstack-source-id: 1ec9a6b4c31d310659b4a116abf5bfb1de393b12
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135502
2024-09-11 14:02:22 -07:00
32542724be [Dynamo] Trace enter/exit of TorchFunctionModes
ghstack-source-id: 8f0811c156177e2b54b3aea97835e1b15044080b
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
2024-09-11 14:02:22 -07:00
dfbb990dc4 [Dynamo] Simplify torch function mode stack guard
ghstack-source-id: 5fad7e6481132b96b594e6755b0fdb394aa9d56f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135444
2024-09-09 16:02:11 -07:00
194d46e91c [Dynamo] Support thread local setattr
ghstack-source-id: d7ca565f27a57ba0aed030f74b90b3ce8faa59bd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443
2024-09-09 16:02:11 -07:00
9094fb5c7c [Dynamo] Trace torch function modes
ghstack-source-id: 188be474d3f4685d45141153c6425a0a2684715d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137
2024-09-09 16:02:11 -07:00
ec6b49eed9 [Dynamo] Disable metadata tf mode when tracing cond
ghstack-source-id: a8d524089ad362b08adb98b6851de5490815fd38
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
2024-09-07 23:50:38 -07:00
28 changed files with 977 additions and 403 deletions

View File

@ -3380,6 +3380,21 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertTrue(same(obj41.y, obj42.y))
self.assertEqual(cnts.frame_count, 1)
def test_thread_local_setattr(self):
from threading import local
loc = local()
@torch.compile(fullgraph=True)
def fn(x, l):
l.x = x
return x + 1
x = torch.ones(2, 2)
fn(x, loc)
self.assertTrue(loc.x is x)
def test_user_defined_class_name(self):
class MyClassFoo:
pass

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo.test_case
@ -14,6 +13,17 @@ from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -57,9 +67,11 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
def setUp(self):
torch.set_default_device(None)
torch._dynamo.reset()
def tearDown(self):
torch.set_default_device(None)
torch._dynamo.reset()
def _run_torch_function_mode_guard_test(self):
class TestMode1(BaseTorchFunctionMode):
@ -94,70 +106,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp)
self.assertEqual(cnt.frame_count, 4)
def _run_ignored_mode_types_test(self):
class IgnoredMode(BaseTorchFunctionMode):
pass
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt.__call__, fullgraph=True)
def fn(x):
return x + 1
inp = torch.ones(2, 2)
with patch(
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
):
# initial compile
fn(inp)
# no recompile, mode ignored
# note: the ref stack is length 0, and the stack we are checking against has length 2
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
with IgnoredMode(), IgnoredMode():
fn(inp)
self.assertEqual(cnt.frame_count, 1)
# recompile due to new mode on the stack
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 2)
# recompile
# tests both ref stack len > runtime stack len for the above guard check
# and ref stack len < runtime stack len for the initial zero mode case
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
# no recompile
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
# This is tricky: the ignored modes are baked into the guard,
# so IgnoredMode will be ignored forever by that guard.
# This is okay since we don't expect to be modifying IGNORED_MODES
# in the middle of execution except for the purposes of testing.
torch._dynamo.reset()
with IgnoredMode():
fn(inp)
self.assertEqual(cnt.frame_count, 4)
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_ignored_types_py(self):
self._run_ignored_mode_types_test()
def test_torch_function_mode_guards_ignored_types_cpp(self):
self._run_ignored_mode_types_test()
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_py(self):
self._run_torch_function_mode_guard_test()
@ -324,6 +272,218 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp)
self.assertEqual(cnt.frame_count, 2)
def test_nested_torch_function_mode(self):
mode_1_called = False
mode_2_called = False
def reset_state():
nonlocal mode_1_called
nonlocal mode_2_called
mode_1_called = False
mode_2_called = False
ones = torch.ones(2, 2)
zeros = torch.zeros(2, 2)
class TestMode1(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_1_called
mode_1_called = True
if func == torch.add:
return zeros
return super().__torch_function__(func, types, args, kwargs)
class TestMode2(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_2_called
mode_2_called = True
if func == torch.mul:
return ones
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
def fn_2(x):
return torch.mul(x, 3) + torch.add(x, 3)
inp = torch.ones(2, 2) + 1
for fn_i in [fn, fn_2]:
fn_opt = torch.compile(fn_i, fullgraph=True)
with TestMode1(), TestMode2():
expected = fn_i(inp), mode_1_called, mode_2_called
reset_state()
actual = fn_opt(inp), mode_1_called, mode_2_called
reset_state()
self.assertEqual(expected, actual)
def test_torch_function_mode_disable(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
with torch._C.DisableTorchFunctionSubclass():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
with torch._C.DisableTorchFunction():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_highest_priority(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")
@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x
try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
parametrize,
run_tests,
skipIfCrossRef,
skipIfTorchDynamo,
TEST_WITH_TORCHDYNAMO,
TestCase,
@ -1557,6 +1558,7 @@ class TestControlFlowTraced(TestCase):
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with crossref
def test_cond_simple_with_linear_compile_check_graph(self):
from torch._dynamo.testing import EagerAndRecordGraphs
@ -1819,6 +1821,7 @@ def forward(self, arg0_1):
self._check_compile(fn, inp, backend=backend)
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with cross ref
def test_while_loop_simple_with_linear_compile_check_graph(self):
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
from torch._dynamo.testing import EagerAndRecordGraphs

View File

@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
class TestHelperModules:
@ -139,6 +139,8 @@ class TestMetaDataPorting(QuantizationTestCase):
self.assertEqual(v, node_tags[k])
return m
@skipIfCrossRef # mlazos: retracing the FX graph with a torch function mode doesn't propagate metadata, because the stack
# trace of the mode's torch function impl doesn't match the lineno stored in the traced graph.
def test_simple_metadata_porting(self):
"""
Model under test

View File

@ -67,7 +67,7 @@ class GuardManager:
) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
def add_torch_function_mode_stack_guard(
self, initial_stack, ignored_types, verbose_code_parts: list[str]
self, initial_stack, verbose_code_parts: list[str]
) -> None: ...
class RootGuardManager(GuardManager):

View File

@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs):
return gm.forward
def make_eager_backend_with_torch_function_mode(mode):
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
def fn(gm, fake_tensor_inputs, **kwargs):
with mode:
return gm.forward
return fn
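# A minimal usage sketch (not part of this diff), assuming this factory lives in the same
# module as the eager backend shown above; LoggingMode is a hypothetical stand-in for the
# metadata torch function mode the docstring refers to.
import torch
from torch.overrides import BaseTorchFunctionMode

class LoggingMode(BaseTorchFunctionMode):  # hypothetical placeholder mode
    pass

backend = make_eager_backend_with_torch_function_mode(LoggingMode())

@torch.compile(backend=backend)
def f(x):
    return x + 1

f(torch.ones(2, 2))  # the mode is applied by the backend itself rather than traced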
@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
if kwargs:

View File

@ -112,6 +112,7 @@ from .utils import (
troubleshooting_url,
write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr
np: Optional[ModuleType]
@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
exit_stack = contextlib.ExitStack()
exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
@ -605,6 +609,10 @@ def _compile(
output: Optional[OutputGraph] = None
tracer: Optional[InstructionTranslator] = None
tf_mode_stack: List[
torch.overrides.TorchFunctionMode
] = torch.overrides._get_current_function_mode_stack()
@preserve_global_state
def transform(
instructions: List[Instruction], code_options: Dict[str, object]
@ -618,6 +626,7 @@ def _compile(
locals,
globals,
builtins,
tf_mode_stack,
code_options,
compiler_fn,
one_graph,

View File

@ -97,6 +97,7 @@ from .source import (
ScriptObjectQualifiedNameSource,
ShapeEnvSource,
SubclassAttrListSource,
TorchFunctionModeStackSource,
TupleIteratorGetItemSource,
TypeSource,
UnspecializedBuiltinNNModuleSource,
@ -110,6 +111,7 @@ from .utils import (
dict_keys_repr,
get_custom_getattr,
get_torch_function_mode_stack,
get_torch_function_mode_stack_at,
guard_failures,
istype,
key_is_id,
@ -313,6 +315,7 @@ CLOSURE_VARS = {
"___dict_contains": lambda a, b: a in b,
"___tuple_iterator_len": tuple_iterator_len,
"___tuple_iterator_getitem": tuple_iterator_getitem,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
"inf": float("inf"),
@ -900,6 +903,15 @@ class GuardBuilder(GuardBuilderBase):
):
assert base_guard_manager # to make mypy happy
out = base_guard_manager
elif istype(source, TorchFunctionModeStackSource):
out = root_guard_manager.lambda_manager(
python_lambda=lambda _: get_torch_function_mode_stack_at(
source._get_index()
),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, GradSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.grad_manager(
@ -2206,6 +2218,8 @@ class CheckFunctionManager:
self.output_graph = output_graph
w_builder = None
# NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
# in case a set default device call was made in the graph.
self.torch_function_mode_stack = (
output_graph.torch_function_mode_stack if output_graph else None
)
@ -2322,15 +2336,12 @@ class CheckFunctionManager:
)
if config.enable_cpp_guard_manager:
from .variables.torch_function import IGNORED_MODES
# Insert the global_state guard
assert self.guard_manager # to make mypy happy
self.guard_manager.root.add_global_state_guard(["___check_global_state()"])
self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack,
list(IGNORED_MODES),
["___check_torch_function_mode_stack()"],
)
# Clear references to torch_function modes held in the list
@ -2637,16 +2648,14 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
types = [type(x) for x in intial_stack]
from .variables.torch_function import IGNORED_MODES
def check_torch_function_mode_stack():
cur_stack = get_torch_function_mode_stack()
if len(cur_stack) != len(types):
return False
for ty, mode in zip(types, cur_stack):
if ty in IGNORED_MODES:
continue
if ty != type(mode):
return False
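# Hedged illustration (not from the diff) of what this Python-level guard now checks: an
# exact, type-by-type match against the stack recorded at compile time, with no ignored
# types. It assumes the factory returns the inner check function (the return statement is
# truncated in this hunk).
from torch.overrides import BaseTorchFunctionMode

guard = make_torch_function_mode_stack_guard([BaseTorchFunctionMode()])
with BaseTorchFunctionMode():
    print(guard())  # expected: True  -- one mode of the recorded type is on the stack
print(guard())      # expected: False -- the stack is empty, so the length check fails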

View File

@ -78,7 +78,6 @@ from .utils import (
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@ -250,6 +249,7 @@ class OutputGraph:
local_scope: Scope,
global_scope: Scope,
f_code,
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]
@ -368,7 +368,7 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = get_torch_function_mode_stack()
self.torch_function_mode_stack = torch_function_mode_stack
# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager
@ -1020,7 +1020,7 @@ class OutputGraph:
prefix_insts.clear()
for block in reversed(tx.block_stack):
block.exit(tx)
block.exit(tx, is_graph_break=reason.graph_break)
self.cleanup_graph()
tx.prune_dead_locals()

View File

@ -25,6 +25,26 @@ if TYPE_CHECKING:
sys as sys,
)
from torch.overrides import BaseTorchFunctionMode
# These classes handle support for TorchFunctionModes across
# graph breaks.
# Today, entering a TorchFunctionMode (for the classes we support)
# simply pushes the mode onto the stack. Since the stack is mutated by
# this and we replay those mutations, no cleanup logic needs to run
# when a graph break occurs: we simply replay the mutations so that the
# torch function mode stack is correct at the graph break, and
# reconstruct the torch function mode stack normally when we compile
# the resume function on the other side of the break.
# However, to ensure we exit properly in the resume function, we need
# to re-enter the contexts as we do for other contexts. These contexts
# do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass
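# Toy illustration (not from the PR) of the enter/exit asymmetry described above. The _Toy*
# classes are hypothetical stand-ins that only mimic the push/pop behavior of a mode stack:
# __enter__ is a no-op because the mode is already on the stack when the resume frame starts,
# while the inherited __exit__ still pops, keeping the stack balanced.
_toy_stack = []

class _ToyMode:
    def __enter__(self):
        _toy_stack.append(self)
        return self

    def __exit__(self, *exc):
        _toy_stack.pop()

class _ToyNoEnterMode(_ToyMode):
    def __enter__(self):
        return self  # no push: the real mode was pushed before the graph break

_toy_stack.append(_ToyMode())   # stack state carried over from the interrupted frame
with _ToyNoEnterMode():         # enter does nothing, exit pops
    pass
assert _toy_stack == []         # balanced, as the resume function expects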
def index(iterator, item, start=0, end=None):
from itertools import islice

View File

@ -48,6 +48,107 @@ class ReenterWith:
stack_index: int
target_values: Optional[Tuple[Any, ...]] = None
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
"""
Codegen based on:
try:
(rest)
finally:
"""
except_jump_target = create_instruction(
"NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
)
cleanup_complete_jump_target = create_instruction("NOP")
setup_finally: List[Instruction] = []
if sys.version_info < (3, 11):
setup_finally.append(
create_instruction("SETUP_FINALLY", target=except_jump_target)
)
else:
exn_tab_begin = create_instruction("NOP")
exn_tab_end = create_instruction("NOP")
exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
exn_tab_begin,
exn_tab_end,
except_jump_target,
self.stack_index + 1,
False,
)
setup_finally.append(exn_tab_begin)
def create_reset():
insts = [
create_instruction(
"LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
),
create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
]
return [
*insts,
create_instruction(
"LOAD_FAST", argval="___prev_torch_function_mode_stack"
),
*create_call_function(1, True),
create_instruction("POP_TOP"),
]
if sys.version_info < (3, 9):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
*create_reset(),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
*create_reset(),
create_instruction("RAISE_VARARGS", argval=0),
create_instruction("POP_EXCEPT", argval=0),
create_instruction("END_FINALLY"),
cleanup_complete_jump_target,
]
elif sys.version_info < (3, 11):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
*create_reset(),
create_instruction("RAISE_VARARGS", argval=0),
create_instruction("POP_EXCEPT", argval=0),
cleanup_complete_jump_target,
]
else:
finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
finally_exn_tab_target = create_instruction("COPY", arg=3)
except_jump_target.exn_tab_entry = InstructionExnTabEntry(
except_jump_target,
finally_exn_tab_end,
finally_exn_tab_target,
self.stack_index + 2,
True,
)
epilogue = [
exn_tab_end,
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target, # PUSH_EXC_INFO
create_instruction("POP_TOP"),
*create_reset(),
finally_exn_tab_end,
finally_exn_tab_target, # COPY 3
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", arg=1), # RERAISE 1
cleanup_complete_jump_target,
]
cleanup[:] = epilogue + cleanup
return setup_finally
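# Roughly, the bytecode assembled above implements the following Python-level shape
# (an approximation for illustration, not generated code). The saved local
# ___prev_torch_function_mode_stack is the one stored by the side-effects codegen
# elsewhere in this stack of PRs.
from torch._dynamo.utils import (
    get_torch_function_mode_stack,
    set_torch_function_mode_stack,
)

___prev_torch_function_mode_stack = get_torch_function_mode_stack()
try:
    pass  # "(rest)": the remainder of the interrupted frame runs here
except BaseException:
    # on an exception, restore the torch function mode stack the frame started with, then re-raise
    set_torch_function_mode_stack(___prev_torch_function_mode_stack)
    raise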
# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
def try_except(self, code_options, cleanup: List[Instruction]):

View File

@ -593,16 +593,19 @@ class SideEffects:
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)
# Needed in the finally block for stack restoration
cg.load_import_from(utils.__name__, "get_torch_function_mode_stack")
cg.call_function(0, True)
name = "___prev_torch_function_mode_stack"
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.load_import_from(utils.__name__, "set_torch_function_mode_stack")
cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
)
cg.call_function(1, False)
cg.call_function(1, True)
cg.append_output(create_instruction("POP_TOP"))
elif self.is_attribute_mutation(var):
# Applying mutations involves two steps: 1) Push all

View File

@ -608,7 +608,7 @@ class TorchFunctionModeStackSource(Source):
ind: int
def name(self):
return ""
return f"___get_torch_function_mode_stack_at({self._get_index()})"
def _get_index(self):
from .variables.torch_function import TorchFunctionModeStackVariable

View File

@ -19,20 +19,7 @@ import traceback
import types
import typing
import weakref
from typing import (
Any,
Callable,
cast,
Deque,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from unittest.mock import patch
import torch
@ -72,14 +59,12 @@ from .source import (
GlobalWeakRefSource,
LocalSource,
Source,
TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
counters,
get_fake_value,
get_instruction_source_311,
get_torch_function_mode_stack,
graph_break_dup_warning_checker,
istype,
LazyString,
@ -120,11 +105,10 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
if TYPE_CHECKING:
from .variables.torch_function import TorchFunctionModeVariable
from .variables.torch_function import (
SymbolicTorchFunctionState,
TorchFunctionModeVariable,
)
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
@ -283,9 +267,12 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index)
def exit(self, tx):
def exit(self, tx, is_graph_break):
assert self.with_context is not None
return self.with_context.exit(tx)
if (
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)
class ReturnValueOp(Exception):
@ -651,8 +638,17 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
cg.extend_output(
b.resume_fn().try_except_torch_function_mode(
cg.code_options, cleanup
)
)
continue
assert b.with_context is not None
assert isinstance(b.with_context, ContextWrappingVariable)
assert isinstance(b.with_context, (ContextWrappingVariable))
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
@ -728,7 +724,7 @@ class InstructionTranslatorBase(
output: OutputGraph
symbolic_locals: Dict[str, VariableTracker]
symbolic_globals: Dict[str, VariableTracker]
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
symbolic_torch_function_state: SymbolicTorchFunctionState
stack: List[VariableTracker]
instruction_pointer: Optional[int]
current_instruction: Instruction
@ -2305,7 +2301,10 @@ class InstructionTranslatorBase(
):
unimplemented(f"{inst.opname} {ctx}")
if isinstance(ctx, GenericContextWrappingVariable):
if (
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
self.generic_context_manager_depth += 1
# Need this redundant check for mypy
@ -2548,7 +2547,7 @@ class InstructionTranslatorBase(
code_options: Dict[str, Any],
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
symbolic_torch_function_state: SymbolicTorchFunctionState,
f_code: types.CodeType,
export: bool,
inline_depth: int,
@ -2563,7 +2562,7 @@ class InstructionTranslatorBase(
self.output = output
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
self.symbolic_torch_function_state = symbolic_torch_function_state
self.stack = []
# stack of variable names for tracking 3.13 closures
self.name_stack: list[Any] = []
@ -2652,6 +2651,7 @@ class InstructionTranslator(InstructionTranslatorBase):
f_locals,
f_globals,
f_builtins,
torch_function_mode_stack,
code_options,
compiler_fn,
one_graph,
@ -2677,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase):
local_scope=f_locals,
global_scope=f_globals,
f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
),
instructions=instructions,
f_locals=f_locals,
@ -2686,7 +2687,7 @@ class InstructionTranslator(InstructionTranslatorBase):
symbolic_locals={}, # set below
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals={},
symbolic_torch_function_mode_stack=collections.deque(),
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
f_code=f_code,
export=export,
inline_depth=0,
@ -2721,7 +2722,9 @@ class InstructionTranslator(InstructionTranslatorBase):
if k in f_locals
}
self._init_torch_function_mode_stack()
self.symbolic_torch_function_state = SymbolicTorchFunctionState(
torch_function_mode_stack
)
self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
if export:
@ -2762,29 +2765,6 @@ class InstructionTranslator(InstructionTranslatorBase):
)
unimplemented(msg)
def _init_torch_function_mode_stack(self):
from .variables.torch_function import TorchFunctionModeStackVariable
TorchFunctionModeStackVariable.reset()
self.symbolic_torch_function_mode_stack: Deque[
TorchFunctionModeVariable
] = collections.deque()
# We want to retrieve all modes to properly reconstruct the stack if needed
py_stack = get_torch_function_mode_stack(filter_ignored=False)
if py_stack:
has_device_context = isinstance(
py_stack[0], torch.utils._device.DeviceContext
)
for i, val in enumerate(py_stack):
self.symbolic_torch_function_mode_stack.append(
variables.LazyVariableTracker.create(
val, source=TorchFunctionModeStackSource(i)
)
)
def get_example_value(self, source: Source):
if isinstance(source, LocalSource):
return self.f_locals[source.local_name]
@ -3116,7 +3096,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code,
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_mode_stack,
parent.symbolic_torch_function_state,
closure_cells,
func,
)
@ -3126,7 +3106,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code,
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_mode_stack,
parent.symbolic_torch_function_state,
closure_cells,
func,
)
@ -3179,7 +3159,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code: types.CodeType,
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
symbolic_torch_function_state: SymbolicTorchFunctionState,
closure_cells: Dict[str, VariableTracker],
funcvar: BaseUserFunctionVariable,
) -> None:
@ -3196,7 +3176,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
f_builtins=f_builtins,
symbolic_locals=symbolic_locals,
symbolic_globals=symbolic_globals,
symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
symbolic_torch_function_state=symbolic_torch_function_state,
instructions=instructions,
code_options={k: getattr(code, k) for k in get_code_keys()},
f_code=code,

View File

@ -163,6 +163,7 @@ def debug_insert_nops(
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))

View File

@ -303,6 +303,7 @@ manual_torch_name_rule_map = {
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
"torch.set_default_device": UserFunctionVariable,
"torch.sparse_bsc_tensor": SkipFunctionVariable,
"torch.sparse_bsr_tensor": SkipFunctionVariable,
"torch.sparse_csc_tensor": SkipFunctionVariable,
@ -2795,7 +2796,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.random.initial_seed",
"torch.random.seed",
"torch.return_types.pytree_register_structseq",
"torch.set_default_device",
"torch.set_default_dtype",
"torch.set_default_tensor_type",
"torch.set_deterministic_debug_mode",
@ -3254,6 +3254,7 @@ MOD_INLINELIST = [
"torch.testing",
"torch.utils._content_store",
"torch.utils._contextlib",
"torch.utils._device",
"torch.utils._foreach_utils",
"torch.utils._python_dispatch",
"torch.utils._pytree",
@ -3588,7 +3589,9 @@ def lookup_inner(
if reasons is not None:
reasons.add("func name is patched_init")
return SkipFunctionVariable
elif name == "__torch_function__":
elif name == "__torch_function__" or (
obj and obj.__name__ == "__torch_function__"
):
if reasons is not None:
reasons.add("func name is __torch_function__")
return UserFunctionVariable

View File

@ -63,7 +63,6 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
_get_function_stack_at,
_instruction_counter,
_len_torch_function_stack,
_pop_torch_function_stack,
@ -3062,14 +3061,10 @@ def is_parameter_freezing():
return torch._inductor.config.freezing and not torch.is_grad_enabled()
def get_torch_function_mode_stack(filter_ignored=True):
from .variables.torch_function import IGNORED_MODES
stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
return stack
def get_torch_function_mode_stack():
return [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
]
def get_torch_function_mode_stack_at(ind):
@ -3085,6 +3080,11 @@ def set_torch_function_mode_stack(stack):
_push_on_torch_function_stack(mode)
def clear_torch_function_mode_stack():
for i in range(_len_torch_function_stack()):
_pop_torch_function_stack()
def verify_guard_fn_signature(value):
fn = value.__metadata_guard__
sig = inspect.signature(fn)

View File

@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
build_torch_function_fn,
TensorWithTFOverrideVariable,
torch_function_mode_stack_state_mgr,
TorchFunctionModeVariable,
)
from .user_defined import (
@ -1663,15 +1664,16 @@ class VariableBuilder:
# but warning is not the end of the world
assert isinstance(value.base, np.nditer)
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
with torch_function_mode_stack_state_mgr.temp_restore_stack():
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
# We do this because we want the full behavior of guarding the numpy ndarray as if it were
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here

View File

@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1
return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""
@ -637,6 +649,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
tx.output.set_torch_function_state(values[0])

View File

@ -149,6 +149,18 @@ tracing_state_functions = {
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])
@functools.lru_cache(None)
def get_overridable_functions():
from itertools import chain
from torch.overrides import get_overridable_functions as get_overridable_functions_
funcs = set(chain(*get_overridable_functions_().values()))
more = {torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty}
funcs.update(more)
return funcs
class BaseTorchVariable(VariableTracker):
"""common base for all torch.* functions, classes, modules and other things"""
@ -782,10 +794,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
if not tx.symbolic_torch_function_mode_stack:
if not tx.symbolic_torch_function_state.mode_stack:
raise unimplemented("Popping from an empty torch function mode stack")
TorchFunctionModeStackVariable.register_mutation(tx)
return tx.symbolic_torch_function_mode_stack.pop()
return tx.symbolic_torch_function_state.pop_torch_function_mode()
@register(torch._C._push_on_torch_function_stack)
def handle_push_torch_function(
@ -793,7 +805,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
assert len(args) == 1 and not kwargs
TorchFunctionModeStackVariable.register_mutation(tx)
tx.symbolic_torch_function_mode_stack.append(args[0])
tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
return ConstantVariable.create(None)
@register(torch._C._len_torch_function_stack)
@ -801,7 +813,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
return ConstantVariable.create(
len(tx.symbolic_torch_function_state.mode_stack)
)
@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]
@register(torch.set_default_device)
def handle_set_default_device(
@ -820,7 +841,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
TorchFunctionModeStackVariable.register_device_context_insertion(tx)
return None
return ConstantVariable.create(None)
return handlers
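# For reference, a minimal eager illustration (independent of dynamo) of the stack
# introspection APIs these handlers intercept; all four are private torch._C functions
# that already appear elsewhere in this diff.
import torch
from torch.overrides import BaseTorchFunctionMode

with BaseTorchFunctionMode():
    assert torch._C._len_torch_function_stack() == 1
    mode = torch._C._get_function_stack_at(0)       # the mode that was just entered
    popped = torch._C._pop_torch_function_stack()   # stack is now empty
    assert mode is popped
    torch._C._push_on_torch_function_stack(popped)  # restored, so __exit__ can pop it again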
@ -833,6 +854,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
from . import ConstantVariable, SymNodeVariable, TensorVariable
from .builder import wrap_fx_proxy
if self.torch_function_override_enabled(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
if self.can_constant_fold_through() and check_unspec_or_constant_args(
args, kwargs
):
@ -850,147 +874,144 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if result:
return result
if can_dispatch_torch_function(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
else:
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
log.warning(msg)
unimplemented(msg)
log.warning(msg)
unimplemented(msg)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
"out= op was called where some of the output tensors were non-contiguous"
)
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
return tensor_variable
return tensor_variable
def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
"""inline behavior of torch.nn.modules.utils._ntuple"""
@ -1118,3 +1139,12 @@ Either create the tensor outside the compiled region, or do not set the tensor t
source
)
return result
def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()
or isinstance(
self.get_function(),
(torch._ops.OpOverload, torch._ops.OpOverloadPacket),
)
) and can_dispatch_torch_function(tx, args, kwargs)

View File

@ -1,20 +1,36 @@
# mypy: ignore-errors
import collections
import contextlib
import inspect
from typing import Dict, List, TYPE_CHECKING
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
_get_overloaded_args,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
@ -52,11 +68,92 @@ banned_attrs = [
if is_tensor_base_attr_getter(fn)
]
# Today set default device is placed in the graph and guarded on separately
# so we should not trace through it. In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}
# Used to clear the Python torch function mode stack while dynamo is compiling, restore it afterwards, and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []
@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
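# A small usage sketch (illustrative): during compilation __enter__ has cleared the real
# stack, and temp_restore_stack puts the user's modes back around code that dynamo
# deliberately runs eagerly (mirroring the builder.py change later in this diff).
with torch_function_mode_stack_state_mgr.temp_restore_stack():
    # hypothetical eager call that should observe the user's torch function modes
    tensor_value = torch.as_tensor([1.0, 2.0, 3.0])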
class SymbolicTorchFunctionState:
def __init__(self, py_stack):
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
# These are their definitions:
# 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
# (if either are entered, this will be False)
# 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
# torch._C.DisableTorchFunction has been entered
# To disambiguate these and keep myself sane I added a C API to check whether all torch function
# concepts (modes and subclasses) are enabled.
# This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
# the stack length from the enablement state of torch function modes.
# This is important because now if a mode is pushed while dynamo is tracing, we know whether
# or not torch function modes are enabled and whether we should trace it.
self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()
# This differs from the C API of the same name
# this will only be false iff we have entered torch._C.DisableTorchFunction
# and does not take into account the mode stack length, while the C API bundles these
# two concepts
self.torch_function_mode_enabled = (
not torch._C._is_torch_function_all_disabled()
)
self.cur_mode = None
TorchFunctionModeStackVariable.reset()
self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()
for i, val in enumerate(py_stack):
self.mode_stack.append(
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
)
def in_torch_function_mode(self):
return len(self.mode_stack) > 0
def pop_torch_function_mode(self):
return self.mode_stack.pop()
def push_torch_function_mode(self, mode_var):
self.mode_stack.append(mode_var)
def call_torch_function_mode(self, tx, fn, types, args, kwargs):
with self._pop_mode_for_inlining() as cur_mode:
return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
@contextlib.contextmanager
def _pop_mode_for_inlining(self):
old_mode = self.cur_mode
self.cur_mode = self.pop_torch_function_mode()
try:
yield self.cur_mode
finally:
mode = self.cur_mode
self.cur_mode = old_mode
self.push_torch_function_mode(mode)
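# Hedged eager illustration (not dynamo code) of the two C knobs described in
# SymbolicTorchFunctionState.__init__ above; the expectations follow the comment there
# and are printed rather than asserted, since these are private APIs.
import torch
from torch.overrides import BaseTorchFunctionMode

with BaseTorchFunctionMode():
    print(torch._C._is_torch_function_enabled())        # expected: True  (no knob entered)
    print(torch._C._is_torch_function_mode_enabled())   # expected: True  (stack is non-empty)
    with torch._C.DisableTorchFunctionSubclass():
        print(torch._C._is_torch_function_enabled())      # expected: False (a knob is entered)
        print(torch._C._is_torch_function_mode_enabled()) # expected: True  (modes still active)
    with torch._C.DisableTorchFunction():
        print(torch._C._is_torch_function_enabled())      # expected: False
        print(torch._C._is_torch_function_mode_enabled()) # expected: False (everything disabled)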
class TorchFunctionModeStackVariable(VariableTracker):
@ -88,19 +185,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
def register_mutation(cls, tx: "InstructionTranslator"):
if cls.stack_value_singleton not in tx.output.side_effects:
var = cls(
source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
source=Source(),
symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
)
tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
tx.output.side_effects.mutation(var)
@classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_mode_stack
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
return
else:
cls.offset += 1
tx.symbolic_torch_function_mode_stack.insert(
stack.insert(
0,
TorchFunctionModeVariable(
None, source=TorchFunctionModeStackSource(-cls.offset)
@ -109,7 +207,7 @@ class TorchFunctionModeStackVariable(VariableTracker):
@classmethod
def clear_default_device(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_mode_stack
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
stack.popleft()
cls.offset -= 1
@ -123,24 +221,88 @@ class TorchFunctionModeStackVariable(VariableTracker):
return ind + cls.offset
class TorchFunctionModeVariable(ContextWrappingVariable):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.value = value
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def get_global_mangled_name(tx, val):
return get_safe_global_name(
tx, f"__torch_function_mode_{val.__class__.__name__}", val
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes, but if there are graph breaks under them
# and they have a custom __enter__/__exit__, we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the function across two frames instead of one.
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
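# Illustrative sketch (not from the diff) of the rule above, using hypothetical mode classes:
# PlainMode keeps the default __enter__/__exit__, so graph breaks under it can be handled;
# CustomEnterMode overrides __enter__, so it falls back to the generic context manager path.
from torch.overrides import TorchFunctionMode

class PlainMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)

class CustomEnterMode(TorchFunctionMode):
    def __enter__(self):
        print("side effect on enter")  # side effects may now straddle two frames
        return super().__enter__()

assert TorchFunctionModeVariable.is_supported_torch_function_mode(PlainMode)
assert not TorchFunctionModeVariable.is_supported_torch_function_mode(CustomEnterMode)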
def __init__(self, value, source=None, **kwargs):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
def reconstruct(self, codegen):
# We don't support locally created torch function modes yet
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)
def _call_func(self, tx, values):
unimplemented("torch function mode context manager is not supported yet")
def module_name(self):
return self.value.__module__
def fn_name(self):
return type(self.value).__name__
def python_type(self):
return type(self.value)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
self,
build_torch_function_fn(tx, self.value, self.source),
fn,
types,
args,
kwargs,
)
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _get_all_args(args, kwargs):
@ -231,9 +393,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
return tx.output.torch_function_enabled and any(
has_overridden_args = any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
)
tf_state = tx.symbolic_torch_function_state
return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
)
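
A small eager illustration of the two dispatch triggers checked above, overridden subclass arguments versus an active torch function mode (hedged sketch):

import torch
from torch.overrides import BaseTorchFunctionMode, has_torch_function

x = torch.ones(2)
assert not has_torch_function((x,))  # plain tensors do not override __torch_function__

with BaseTorchFunctionMode():
    # With a mode on the stack, torch.add dispatches through __torch_function__
    # even though neither argument overrides it.
    y = torch.add(x, x)
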
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
@ -245,11 +411,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
_get_subclass_type,
)
types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])
if tx.symbolic_torch_function_state.in_torch_function_mode():
res = tx.symbolic_torch_function_state.call_torch_function_mode(
tx, fn, types, args, kwargs
)
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
for arg in overloaded_args:
res = arg.call_torch_function(
tx,
fn,
TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
types,
args,
kwargs,
)

View File

@ -9,6 +9,7 @@ import inspect
import itertools
import random
import sys
import threading
import types
import warnings
from typing import Dict, Generic, List, TYPE_CHECKING
@ -82,11 +83,6 @@ def is_forbidden_context_manager(ctx):
from _pytest.python_api import RaisesContext
from _pytest.recwarn import WarningsChecker
# TODO mlazos: Temporary to get this stack to pass
# remove in subsequent PR
from torch.overrides import BaseTorchFunctionMode
f_ctxs.append(BaseTorchFunctionMode)
f_ctxs.append(RaisesContext)
f_ctxs.append(WarningsChecker)
except ImportError:
@ -413,15 +409,25 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
# import here to avoid an unfortunate circular dependency.
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, GenericContextWrappingVariable, {}
self.source, self.value, var_cls, {}
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj
elif is_namedtuple_cls(self.value):
fields = namedtuple_fields(self.value)
# check if this a quasi-namedtuple or a real one
@ -711,7 +717,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if method is object.__init__:
return ConstantVariable.create(None)
if is_standard_setattr(method):
if is_standard_setattr(method) or isinstance(self.value, threading.local):
return self.method_setattr_standard(tx, *args, **kwargs)
# [NOTE] OrderedDict, dict subtypes must always have source
@ -809,7 +815,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def needs_slow_setattr(self):
return not is_standard_setattr(
inspect.getattr_static(self.value, "__setattr__", None)
)
) and not isinstance(self.value, threading.local)
def unpack_var_sequence(self, tx):
if (

View File

@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
if (
not torch.compiler.is_dynamo_compiling()
and log.isEnabledFor(logging.DEBUG)
and config.extended_debug_current_loc
):
frame = _find_user_code_frame()
if frame is not None:
log.debug(

View File

@ -28,6 +28,7 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
@ -129,6 +130,10 @@ def cond(pred, true_fn, false_fn, operands):
if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
if isinstance(pred, (bool, int, float)):
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
@ -169,12 +174,15 @@ def cond(pred, true_fn, false_fn, operands):
def _cond_op_wrapper(*args, **kwargs):
return cond_op(*args, **kwargs)
with _set_compilation_env():
with torch._dynamo.utils.disable_cache_limit():
with _temp_remove_pre_dispatch_torch_function_mode():
return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
pred, true_fn, false_fn, operands
)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
else:
backend = "eager"
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
pred, true_fn, false_fn, operands
)
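
For context, a minimal torch.cond usage sketch that exercises this compile path (assumes the public torch.cond signature):

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
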
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):

View File

@ -15,7 +15,11 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
ProxyTorchDispatchMode,
track_tensor_tree,
)
class WhileLoopOp(HigherOrderOperator):
@ -113,6 +117,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
- 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
"""
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
# Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
# parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
@ -140,9 +147,15 @@ def while_loop(cond_fn, body_fn, carried_inputs):
return while_loop_op(*args, **kwargs)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)(
cond_fn, body_fn, carried_inputs, additional_inputs
)
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
else:
backend = "eager"
return torch.compile(
_while_loop_op_wrapper, backend=backend, fullgraph=True
)(cond_fn, body_fn, carried_inputs, additional_inputs)
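
A hedged usage sketch of while_loop matching the docstring above: cond_fn returns a boolean scalar tensor and body_fn returns carries with the same structure and metadata as its inputs.

import torch
from torch._higher_order_ops.while_loop import while_loop

def cond_fn(i, x):
    return i < 3

def body_fn(i, x):
    return i + 1, x * 2.0

i0 = torch.tensor(0)
x0 = torch.ones(2)
i_final, x_final = while_loop(cond_fn, body_fn, (i0, x0))
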
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)

View File

@ -2515,62 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
public:
TORCH_FUNCTION_MODE_STACK(
const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_ref_stack(),
_ignored_types() {
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
Py_ssize_t len = PyList_Size(initial_stack.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
this->_ref_stack.push_back(Py_TYPE(mode));
}
len = PyList_Size(ignored_types.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* type_obj =
PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
if (PyType_Check(type_obj) == 0) {
PyErr_SetString(
PyExc_TypeError, "ignored_types should contain a list of types");
return;
}
PyTypeObject* type = (PyTypeObject*)type_obj;
this->_ignored_types.insert(type);
auto type = Py_TYPE(mode);
this->_ref_stack.push_back(type);
}
}
bool check_nopybind(PyObject* value) override {
// Ignore value arg, only used to satisfy the interface
size_t ref_ind = 0;
int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
const size_t ref_stack_size = this->_ref_stack.size();
for (int64_t idx = 0; idx < len; idx++) {
if (len != ref_stack_size) {
return false;
}
for (int64_t idx = 0; (size_t)idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
// skip ignored types
if (this->_ignored_types.count(mode_type) > 0) {
continue;
}
// if we already have more non-ignored modes than the ref stack
// or if the mode doesn't match at the current index, return false
else if (
(ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) ||
mode_type != _ref_stack[ref_ind]) {
if (mode_type != _ref_stack.at(idx)) {
return false;
}
ref_ind++;
}
return ref_ind == this->_ref_stack.size();
return true;
}
private:
std::vector<PyTypeObject*> _ref_stack;
std::set<PyTypeObject*> _ignored_types;
};
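
The simplified guard now requires an exact, position-by-position type match between the recorded stack and the runtime stack; in Python terms, roughly (illustrative sketch, not the real implementation):

def torch_function_mode_stack_guard(ref_stack_types, runtime_modes):
    # Fail fast on a length mismatch, then compare mode types index by index.
    if len(runtime_modes) != len(ref_stack_types):
        return False
    return all(type(mode) is ty for mode, ty in zip(runtime_modes, ref_stack_types))
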
class TENSOR_MATCH : public LeafGuard {
@ -3672,7 +3650,7 @@ PyObject* torch_c_dynamo_guards_init() {
LeafGuard,
std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
py_m, "TORCH_FUNCTION_MODE_STACK")
.def(py::init<py::list, py::list, py::list>())
.def(py::init<py::list, py::list>())
.def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
py_m, "DATA_PTR_MATCH")
@ -3903,10 +3881,9 @@ PyObject* torch_c_dynamo_guards_init() {
"add_torch_function_mode_stack_guard",
[](GuardManager& self,
const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
initial_stack, ignored_types, std::move(verbose_code_parts)));
initial_stack, std::move(verbose_code_parts)));
})
.def(
"add_data_ptr_guard",

View File

@ -17,7 +17,7 @@ import typing_extensions
import warnings
import weakref
from collections import defaultdict
from contextlib import contextmanager, ExitStack, nullcontext
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import (
Any,
@ -1084,38 +1084,43 @@ class PythonKeyTracer(Tracer):
return e
@contextmanager
def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]:
from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
def _make_temp_remove_mode_context_manager(
mode_ty: Type[TorchFunctionMode],
) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
@contextmanager
def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
temp_elements = []
pre_dispatch_mode = None
temp_elements = []
removed_mode = None
while _len_torch_function_stack() > 0:
mode = _pop_mode()
if isinstance(mode, PreDispatchTorchFunctionMode):
pre_dispatch_mode = mode
break
else:
temp_elements.append(mode)
while _len_torch_function_stack() > 0:
mode = _pop_mode()
if isinstance(mode, mode_ty):
removed_mode = mode
break
else:
temp_elements.append(mode)
for mode in reversed(temp_elements):
_push_mode(mode)
for mode in reversed(temp_elements):
_push_mode(mode)
try:
yield
try:
yield removed_mode
finally:
if pre_dispatch_mode is not None:
count = len(temp_elements)
while count > 0:
mode = _pop_mode()
count -= 1
finally:
if removed_mode is not None:
count = len(temp_elements)
while count > 0:
mode = _pop_mode()
count -= 1
temp_elements.append(pre_dispatch_mode)
temp_elements.append(removed_mode)
for mode in reversed(temp_elements):
_push_mode(mode)
for mode in reversed(temp_elements):
_push_mode(mode)
return context_manager_fn
@torch._disable_dynamo
@ -1230,6 +1235,11 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
return func(*args, **kwargs)
_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
TorchFunctionMetadataMode
)
# This mode is **only** used for pre_dispatch tracing.
# In particular, we need to make sure that autograd/autocast API's
# that do not desugar into dispatcher operators stay in the graph.
@ -1258,6 +1268,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
return func(*args, **kwargs)
_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
PreDispatchTorchFunctionMode
)
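
Hedged usage sketch of the generated context managers: the removed mode, if any, is yielded so callers can hand it to an eager backend that re-enters it while tracing.

from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
)

with _temp_remove_metadata_torch_function_mode() as metadata_mode:
    # metadata_mode is the removed TorchFunctionMetadataMode instance,
    # or None if no such mode was on the torch function stack.
    if metadata_mode is not None:
        ...  # e.g. build a backend that re-installs it during tracing
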
class ProxyTorchDispatchMode(TorchDispatchMode):
# Ensure this is read-only; this exists only for legacy reasons
@property

View File

@ -19,6 +19,7 @@ from torch._higher_order_ops.flex_attention import (
)
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input
@ -1027,6 +1028,10 @@ def flex_attention(
if not torch._dynamo.is_dynamo_supported():
raise RuntimeError("flex_attention requires dynamo support")
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
# Dynamo is expecting a callable with "__code__" attribute.
# We cannot directly pass hop to it. So we wrap it in a dummy function.
def _flex_attention_hop_wrapper(*args, **kwargs):
@ -1035,18 +1040,25 @@ def flex_attention(
with _set_compilation_env():
with torch._dynamo.utils.disable_cache_limit():
with _temp_remove_pre_dispatch_torch_function_mode():
out, lse = torch.compile(
_flex_attention_hop_wrapper, backend="eager", fullgraph=True
)(
query,
key,
value,
score_mod,
block_mask.as_tuple(),
scale,
kernel_options,
)
if return_lse:
return out, lse * math.log(2)
else:
return out
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(
metadata_mode
)
else:
backend = "eager"
out, lse = torch.compile(
_flex_attention_hop_wrapper, backend=backend, fullgraph=True
)(
query,
key,
value,
score_mod,
block_mask.as_tuple(),
scale,
kernel_options,
)
if return_lse:
return out, lse * math.log(2)
else:
return out
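
Finally, a hedged flex_attention usage sketch that reaches the compile path above (assumes the public torch.nn.attention.flex_attention API):

import torch
from torch.nn.attention.flex_attention import flex_attention

def score_mod(score, batch, head, q_idx, kv_idx):
    # Identity score modification; real callers would mask or bias the score here.
    return score

q = torch.randn(1, 2, 128, 64)
k = torch.randn(1, 2, 128, 64)
v = torch.randn(1, 2, 128, 64)
out = flex_attention(q, k, v, score_mod=score_mod)
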