mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 15:44:58 +08:00
Compare commits
7 Commits
cslpull92
...
mlazos/tf-
Author | SHA1 | Date | |
---|---|---|---|
ac3dabf652 | |||
54ab06fc07 | |||
32542724be | |||
dfbb990dc4 | |||
194d46e91c | |||
9094fb5c7c | |||
ec6b49eed9 |
@ -3380,6 +3380,21 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||
self.assertTrue(same(obj41.y, obj42.y))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_thread_local_setattr(self):
|
||||
from threading import local
|
||||
|
||||
loc = local()
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def fn(x, l):
|
||||
l.x = x
|
||||
return x + 1
|
||||
|
||||
x = torch.ones(2, 2)
|
||||
fn(x, loc)
|
||||
|
||||
self.assertTrue(loc.x is x)
|
||||
|
||||
def test_user_defined_class_name(self):
|
||||
class MyClassFoo:
|
||||
pass
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
@ -14,6 +13,17 @@ from torch.utils._device import DeviceContext
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
|
||||
class TestMode(BaseTorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args, kwargs=None):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
|
||||
if func == torch.add:
|
||||
return torch.zeros(2, 2)
|
||||
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -57,9 +67,11 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
torch.set_default_device(None)
|
||||
torch._dynamo.reset()
|
||||
|
||||
def tearDown(self):
|
||||
torch.set_default_device(None)
|
||||
torch._dynamo.reset()
|
||||
|
||||
def _run_torch_function_mode_guard_test(self):
|
||||
class TestMode1(BaseTorchFunctionMode):
|
||||
@ -94,70 +106,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
|
||||
fn(inp)
|
||||
self.assertEqual(cnt.frame_count, 4)
|
||||
|
||||
def _run_ignored_mode_types_test(self):
|
||||
class IgnoredMode(BaseTorchFunctionMode):
|
||||
pass
|
||||
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnt.__call__, fullgraph=True)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
inp = torch.ones(2, 2)
|
||||
|
||||
with patch(
|
||||
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
|
||||
):
|
||||
# initial compile
|
||||
fn(inp)
|
||||
|
||||
# no recompile, mode ignored
|
||||
# note: the ref stack is length 0, and the stack we are checking against has length 2
|
||||
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
|
||||
with IgnoredMode(), IgnoredMode():
|
||||
fn(inp)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
# recompile due to new mode on the stack
|
||||
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
|
||||
fn(inp)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
# recompile
|
||||
# tests both ref stack len > runtime stack len for the above guard check
|
||||
# and ref stack len < runtime stack len for the initial zero mode case
|
||||
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
|
||||
fn(inp)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 3)
|
||||
|
||||
# no recompile
|
||||
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
|
||||
fn(inp)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 3)
|
||||
|
||||
# This is tricky, basically the ignored modes are baked into the guard
|
||||
# IgnoredMode will be ignored forever by that guard.
|
||||
# This is okay since we don't expect to be modifying IGNORED_MODES
|
||||
# in the middle of execution except for the purposes of testing.
|
||||
torch._dynamo.reset()
|
||||
|
||||
with IgnoredMode():
|
||||
fn(inp)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 4)
|
||||
|
||||
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
|
||||
def test_torch_function_mode_guards_ignored_types_py(self):
|
||||
self._run_ignored_mode_types_test()
|
||||
|
||||
def test_torch_function_mode_guards_ignored_types_cpp(self):
|
||||
self._run_ignored_mode_types_test()
|
||||
|
||||
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
|
||||
def test_torch_function_mode_guards_py(self):
|
||||
self._run_torch_function_mode_guard_test()
|
||||
@ -324,6 +272,218 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
|
||||
fn(inp)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_nested_torch_function_mode(self):
|
||||
mode_1_called = False
|
||||
mode_2_called = False
|
||||
|
||||
def reset_state():
|
||||
nonlocal mode_1_called
|
||||
nonlocal mode_2_called
|
||||
mode_1_called = False
|
||||
mode_2_called = False
|
||||
|
||||
ones = torch.ones(2, 2)
|
||||
zeros = torch.zeros(2, 2)
|
||||
|
||||
class TestMode1(BaseTorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args, kwargs=None):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
|
||||
nonlocal mode_1_called
|
||||
|
||||
mode_1_called = True
|
||||
|
||||
if func == torch.add:
|
||||
return zeros
|
||||
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
class TestMode2(BaseTorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args, kwargs=None):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
|
||||
nonlocal mode_2_called
|
||||
|
||||
mode_2_called = True
|
||||
|
||||
if func == torch.mul:
|
||||
return ones
|
||||
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
def fn(x):
|
||||
return torch.add(x, 3)
|
||||
|
||||
def fn_2(x):
|
||||
return torch.mul(x, 3) + torch.add(x, 3)
|
||||
|
||||
inp = torch.ones(2, 2) + 1
|
||||
|
||||
for fn_i in [fn, fn_2]:
|
||||
fn_opt = torch.compile(fn_i, fullgraph=True)
|
||||
with TestMode1(), TestMode2():
|
||||
expected = fn_i(inp), mode_1_called, mode_2_called
|
||||
reset_state()
|
||||
actual = fn_opt(inp), mode_1_called, mode_2_called
|
||||
reset_state()
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_disable(self):
|
||||
class TestSubclass(torch.Tensor):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
if func == torch.add:
|
||||
return torch.ones(2, 2)
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
class TestMode(BaseTorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args, kwargs=None):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
|
||||
if func == torch.add:
|
||||
return torch.zeros(2, 2)
|
||||
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
def fn(x):
|
||||
return torch.add(x, 3)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
|
||||
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
with TestMode(), torch._dynamo.config.patch(
|
||||
"traceable_tensor_subclasses", {TestSubclass}
|
||||
):
|
||||
with torch._C.DisableTorchFunctionSubclass():
|
||||
expected = fn(inp)
|
||||
actual = fn_opt(inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
expected = fn(inp)
|
||||
actual = fn_opt(inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_highest_priority(self):
|
||||
class TestSubclass(torch.Tensor):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
if func == torch.add:
|
||||
return torch.ones(2, 2)
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
def fn(x):
|
||||
return torch.add(x, 3)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
|
||||
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
with TestMode(), torch._dynamo.config.patch(
|
||||
"traceable_tensor_subclasses", {TestSubclass}
|
||||
):
|
||||
expected = fn(inp)
|
||||
actual = fn_opt(inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_enter_exit(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
o = torch.add(x, 3)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_graph_break(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
torch._dynamo.graph_break()
|
||||
o = torch.add(x, 3)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_and_pop_graph_break(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
z = _pop_torch_function_stack()
|
||||
torch._dynamo.graph_break()
|
||||
_push_on_torch_function_stack(z)
|
||||
o = torch.add(x, 3)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_restore_on_exc(self):
|
||||
@torch._dynamo.disable()
|
||||
def err():
|
||||
raise RuntimeError("test")
|
||||
|
||||
@torch.compile()
|
||||
def fn(x):
|
||||
with TestMode():
|
||||
x += 1
|
||||
err()
|
||||
x += 2
|
||||
return x
|
||||
|
||||
try:
|
||||
fn(torch.ones(2, 2))
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.assertEqual(_len_torch_function_stack(), 0)
|
||||
|
||||
def test_torch_function_mode_and_pop_graph_break_mutation(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
z = _pop_torch_function_stack()
|
||||
z.y = 5
|
||||
torch._dynamo.graph_break()
|
||||
_push_on_torch_function_stack(z)
|
||||
o = torch.add(x, 3)
|
||||
o = torch.mul(o, z.y)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skipIfCrossRef,
|
||||
skipIfTorchDynamo,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
TestCase,
|
||||
@ -1557,6 +1558,7 @@ class TestControlFlowTraced(TestCase):
|
||||
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
|
||||
|
||||
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
||||
@skipIfCrossRef # Arg order changes with crossref
|
||||
def test_cond_simple_with_linear_compile_check_graph(self):
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||
|
||||
@ -1819,6 +1821,7 @@ def forward(self, arg0_1):
|
||||
self._check_compile(fn, inp, backend=backend)
|
||||
|
||||
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
||||
@skipIfCrossRef # Arg order changes with cross ref
|
||||
def test_while_loop_simple_with_linear_compile_check_graph(self):
|
||||
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||
|
@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
|
||||
from torch.fx import Node
|
||||
from torch.testing._internal.common_quantization import QuantizationTestCase
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
|
||||
|
||||
|
||||
class TestHelperModules:
|
||||
@ -139,6 +139,8 @@ class TestMetaDataPorting(QuantizationTestCase):
|
||||
self.assertEqual(v, node_tags[k])
|
||||
return m
|
||||
|
||||
@skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
|
||||
# trace of the mode torch function impl doesn't match the traced graph stored lineno.
|
||||
def test_simple_metadata_porting(self):
|
||||
"""
|
||||
Model under test
|
||||
|
@ -67,7 +67,7 @@ class GuardManager:
|
||||
) -> None: ...
|
||||
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_torch_function_mode_stack_guard(
|
||||
self, initial_stack, ignored_types, verbose_code_parts: list[str]
|
||||
self, initial_stack, verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
|
||||
class RootGuardManager(GuardManager):
|
||||
|
@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs):
|
||||
return gm.forward
|
||||
|
||||
|
||||
def make_eager_backend_with_torch_function_mode(mode):
|
||||
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
|
||||
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
|
||||
in the HOP, so we need to externally run this mode and not trace it."""
|
||||
|
||||
def fn(gm, fake_tensor_inputs, **kwargs):
|
||||
with mode:
|
||||
return gm.forward
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
@register_backend
|
||||
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
|
||||
if kwargs:
|
||||
|
@ -112,6 +112,7 @@ from .utils import (
|
||||
troubleshooting_url,
|
||||
write_record_to_file,
|
||||
)
|
||||
from .variables.torch_function import torch_function_mode_stack_state_mgr
|
||||
|
||||
|
||||
np: Optional[ModuleType]
|
||||
@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
|
||||
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
|
||||
cleanup = setup_compile_debug()
|
||||
|
||||
exit_stack = contextlib.ExitStack()
|
||||
exit_stack.enter_context(
|
||||
torch.fx._symbolic_trace._maybe_revert_all_patches()
|
||||
)
|
||||
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
cleanup.close()
|
||||
assert (
|
||||
torch._C._len_torch_function_stack() == 0
|
||||
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
|
||||
exit_stack.close()
|
||||
torch._C._set_grad_enabled(prior_grad_mode)
|
||||
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
|
||||
@ -605,6 +609,10 @@ def _compile(
|
||||
output: Optional[OutputGraph] = None
|
||||
tracer: Optional[InstructionTranslator] = None
|
||||
|
||||
tf_mode_stack: List[
|
||||
torch.overrides.TorchFunctionMode
|
||||
] = torch.overrides._get_current_function_mode_stack()
|
||||
|
||||
@preserve_global_state
|
||||
def transform(
|
||||
instructions: List[Instruction], code_options: Dict[str, object]
|
||||
@ -618,6 +626,7 @@ def _compile(
|
||||
locals,
|
||||
globals,
|
||||
builtins,
|
||||
tf_mode_stack,
|
||||
code_options,
|
||||
compiler_fn,
|
||||
one_graph,
|
||||
|
@ -97,6 +97,7 @@ from .source import (
|
||||
ScriptObjectQualifiedNameSource,
|
||||
ShapeEnvSource,
|
||||
SubclassAttrListSource,
|
||||
TorchFunctionModeStackSource,
|
||||
TupleIteratorGetItemSource,
|
||||
TypeSource,
|
||||
UnspecializedBuiltinNNModuleSource,
|
||||
@ -110,6 +111,7 @@ from .utils import (
|
||||
dict_keys_repr,
|
||||
get_custom_getattr,
|
||||
get_torch_function_mode_stack,
|
||||
get_torch_function_mode_stack_at,
|
||||
guard_failures,
|
||||
istype,
|
||||
key_is_id,
|
||||
@ -313,6 +315,7 @@ CLOSURE_VARS = {
|
||||
"___dict_contains": lambda a, b: a in b,
|
||||
"___tuple_iterator_len": tuple_iterator_len,
|
||||
"___tuple_iterator_getitem": tuple_iterator_getitem,
|
||||
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
||||
"__math_isnan": math.isnan,
|
||||
"__numpy_isnan": None if np is None else np.isnan,
|
||||
"inf": float("inf"),
|
||||
@ -900,6 +903,15 @@ class GuardBuilder(GuardBuilderBase):
|
||||
):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
out = base_guard_manager
|
||||
elif istype(source, TorchFunctionModeStackSource):
|
||||
out = root_guard_manager.lambda_manager(
|
||||
python_lambda=lambda _: get_torch_function_mode_stack_at(
|
||||
source._get_index()
|
||||
),
|
||||
source=source_name,
|
||||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, GradSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
out = base_guard_manager.grad_manager(
|
||||
@ -2206,6 +2218,8 @@ class CheckFunctionManager:
|
||||
self.output_graph = output_graph
|
||||
w_builder = None
|
||||
|
||||
# NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
|
||||
# in case a set default device call was made in the graph.
|
||||
self.torch_function_mode_stack = (
|
||||
output_graph.torch_function_mode_stack if output_graph else None
|
||||
)
|
||||
@ -2322,15 +2336,12 @@ class CheckFunctionManager:
|
||||
)
|
||||
|
||||
if config.enable_cpp_guard_manager:
|
||||
from .variables.torch_function import IGNORED_MODES
|
||||
|
||||
# Insert the global_state guard
|
||||
assert self.guard_manager # to make mypy happy
|
||||
self.guard_manager.root.add_global_state_guard(["___check_global_state()"])
|
||||
|
||||
self.guard_manager.root.add_torch_function_mode_stack_guard(
|
||||
self.torch_function_mode_stack,
|
||||
list(IGNORED_MODES),
|
||||
["___check_torch_function_mode_stack()"],
|
||||
)
|
||||
# Clear references to torch_function modes held in the list
|
||||
@ -2637,16 +2648,14 @@ def is_recompiles_verbose_enabled():
|
||||
# this will only be used if cpp guards are disabled
|
||||
def make_torch_function_mode_stack_guard(intial_stack):
|
||||
types = [type(x) for x in intial_stack]
|
||||
from .variables.torch_function import IGNORED_MODES
|
||||
|
||||
def check_torch_function_mode_stack():
|
||||
cur_stack = get_torch_function_mode_stack()
|
||||
|
||||
if len(cur_stack) != len(types):
|
||||
return False
|
||||
|
||||
for ty, mode in zip(types, cur_stack):
|
||||
if ty in IGNORED_MODES:
|
||||
continue
|
||||
if ty != type(mode):
|
||||
return False
|
||||
|
||||
|
@ -78,7 +78,6 @@ from .utils import (
|
||||
get_instruction_source_311,
|
||||
get_locals_to_steal,
|
||||
get_static_address_type,
|
||||
get_torch_function_mode_stack,
|
||||
graph_break_reasons,
|
||||
increment_op_count,
|
||||
lazy_format_graph_code,
|
||||
@ -250,6 +249,7 @@ class OutputGraph:
|
||||
local_scope: Scope,
|
||||
global_scope: Scope,
|
||||
f_code,
|
||||
torch_function_mode_stack,
|
||||
):
|
||||
super().__init__()
|
||||
self.tracers = [SubgraphTracer(self, export_root=export)]
|
||||
@ -368,7 +368,7 @@ class OutputGraph:
|
||||
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
|
||||
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
|
||||
# This records the initial torch function mode stack for guarding
|
||||
self.torch_function_mode_stack = get_torch_function_mode_stack()
|
||||
self.torch_function_mode_stack = torch_function_mode_stack
|
||||
|
||||
# Tracks if the output graph has a user defined allowed function in the
|
||||
# graph. This is used later to determine if we should fallback to eager
|
||||
@ -1020,7 +1020,7 @@ class OutputGraph:
|
||||
prefix_insts.clear()
|
||||
|
||||
for block in reversed(tx.block_stack):
|
||||
block.exit(tx)
|
||||
block.exit(tx, is_graph_break=reason.graph_break)
|
||||
|
||||
self.cleanup_graph()
|
||||
tx.prune_dead_locals()
|
||||
|
@ -25,6 +25,26 @@ if TYPE_CHECKING:
|
||||
sys as sys,
|
||||
)
|
||||
|
||||
from torch.overrides import BaseTorchFunctionMode
|
||||
|
||||
|
||||
# These classes handle support for TorchFunctionModes across
|
||||
# graph breaks
|
||||
# Today the TorchFunctionMode enter (for the classes we support)
|
||||
# simply pushes the mode onto the stack. Since after this occurs
|
||||
# the stack is mutated, and we replay these mutations, we don't need
|
||||
# any cleanup logic to be run once the graph break occurs, we simply replay
|
||||
# these mutations to ensure at the graph break the torch function mode stack is correct
|
||||
# and reconstruct the torch function mode stack normally
|
||||
# when we compile the resume function on the other side of the break.
|
||||
# However, to ensure we exit properly
|
||||
# in the resume function, we need to re-enter the contexts as we do other contexts.
|
||||
# These contexts do nothing on enter, but provide the correct exit logic to ensure
|
||||
# the stack state is correct.
|
||||
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
|
||||
def index(iterator, item, start=0, end=None):
|
||||
from itertools import islice
|
||||
|
@ -48,6 +48,107 @@ class ReenterWith:
|
||||
stack_index: int
|
||||
target_values: Optional[Tuple[Any, ...]] = None
|
||||
|
||||
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
|
||||
"""
|
||||
Codegen based off of:
|
||||
try:
|
||||
(rest)
|
||||
finally:
|
||||
|
||||
"""
|
||||
except_jump_target = create_instruction(
|
||||
"NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
|
||||
)
|
||||
cleanup_complete_jump_target = create_instruction("NOP")
|
||||
|
||||
setup_finally: List[Instruction] = []
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
setup_finally.append(
|
||||
create_instruction("SETUP_FINALLY", target=except_jump_target)
|
||||
)
|
||||
else:
|
||||
exn_tab_begin = create_instruction("NOP")
|
||||
exn_tab_end = create_instruction("NOP")
|
||||
exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
|
||||
exn_tab_begin,
|
||||
exn_tab_end,
|
||||
except_jump_target,
|
||||
self.stack_index + 1,
|
||||
False,
|
||||
)
|
||||
setup_finally.append(exn_tab_begin)
|
||||
|
||||
def create_reset():
|
||||
insts = [
|
||||
create_instruction(
|
||||
"LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
|
||||
),
|
||||
create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
|
||||
]
|
||||
return [
|
||||
*insts,
|
||||
create_instruction(
|
||||
"LOAD_FAST", argval="___prev_torch_function_mode_stack"
|
||||
),
|
||||
*create_call_function(1, True),
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
epilogue = [
|
||||
create_instruction("POP_BLOCK"),
|
||||
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
|
||||
except_jump_target,
|
||||
*create_reset(),
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
*create_reset(),
|
||||
create_instruction("RAISE_VARARGS", argval=0),
|
||||
create_instruction("POP_EXCEPT", argval=0),
|
||||
create_instruction("END_FINALLY"),
|
||||
cleanup_complete_jump_target,
|
||||
]
|
||||
elif sys.version_info < (3, 11):
|
||||
epilogue = [
|
||||
create_instruction("POP_BLOCK"),
|
||||
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
|
||||
except_jump_target,
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
*create_reset(),
|
||||
create_instruction("RAISE_VARARGS", argval=0),
|
||||
create_instruction("POP_EXCEPT", argval=0),
|
||||
cleanup_complete_jump_target,
|
||||
]
|
||||
else:
|
||||
finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
|
||||
finally_exn_tab_target = create_instruction("COPY", arg=3)
|
||||
except_jump_target.exn_tab_entry = InstructionExnTabEntry(
|
||||
except_jump_target,
|
||||
finally_exn_tab_end,
|
||||
finally_exn_tab_target,
|
||||
self.stack_index + 2,
|
||||
True,
|
||||
)
|
||||
epilogue = [
|
||||
exn_tab_end,
|
||||
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
|
||||
except_jump_target, # PUSH_EXC_INFO
|
||||
create_instruction("POP_TOP"),
|
||||
*create_reset(),
|
||||
finally_exn_tab_end,
|
||||
finally_exn_tab_target, # COPY 3
|
||||
create_instruction("POP_EXCEPT"),
|
||||
create_instruction("RERAISE", arg=1), # RERAISE 1
|
||||
cleanup_complete_jump_target,
|
||||
]
|
||||
|
||||
cleanup[:] = epilogue + cleanup
|
||||
return setup_finally
|
||||
|
||||
# If we do not want to destroy the stack, we can do the same thing as a
|
||||
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
|
||||
def try_except(self, code_options, cleanup: List[Instruction]):
|
||||
|
@ -593,16 +593,19 @@ class SideEffects:
|
||||
elif isinstance(
|
||||
var, variables.torch_function.TorchFunctionModeStackVariable
|
||||
):
|
||||
cg.add_push_null(
|
||||
lambda: cg.load_import_from(
|
||||
utils.__name__, "set_torch_function_mode_stack"
|
||||
)
|
||||
)
|
||||
# Needed in the finally block for stack restoration
|
||||
cg.load_import_from(utils.__name__, "get_torch_function_mode_stack")
|
||||
cg.call_function(0, True)
|
||||
name = "___prev_torch_function_mode_stack"
|
||||
cg.code_options["co_varnames"] += (name,)
|
||||
cg.append_output(create_instruction("STORE_FAST", argval=name))
|
||||
cg.load_import_from(utils.__name__, "set_torch_function_mode_stack")
|
||||
|
||||
cg.foreach(var.symbolic_stack)
|
||||
cg.append_output(
|
||||
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
|
||||
)
|
||||
cg.call_function(1, False)
|
||||
cg.call_function(1, True)
|
||||
cg.append_output(create_instruction("POP_TOP"))
|
||||
elif self.is_attribute_mutation(var):
|
||||
# Applying mutations involves two steps: 1) Push all
|
||||
|
@ -608,7 +608,7 @@ class TorchFunctionModeStackSource(Source):
|
||||
ind: int
|
||||
|
||||
def name(self):
|
||||
return ""
|
||||
return f"___get_torch_function_mode_stack_at({self._get_index()})"
|
||||
|
||||
def _get_index(self):
|
||||
from .variables.torch_function import TorchFunctionModeStackVariable
|
||||
|
@ -19,20 +19,7 @@ import traceback
|
||||
import types
|
||||
import typing
|
||||
import weakref
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Deque,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -72,14 +59,12 @@ from .source import (
|
||||
GlobalWeakRefSource,
|
||||
LocalSource,
|
||||
Source,
|
||||
TorchFunctionModeStackSource,
|
||||
)
|
||||
from .trace_rules import is_builtin_constant, is_forbidden
|
||||
from .utils import (
|
||||
counters,
|
||||
get_fake_value,
|
||||
get_instruction_source_311,
|
||||
get_torch_function_mode_stack,
|
||||
graph_break_dup_warning_checker,
|
||||
istype,
|
||||
LazyString,
|
||||
@ -120,11 +105,10 @@ from .variables.misc import (
|
||||
)
|
||||
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
|
||||
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .variables.torch_function import TorchFunctionModeVariable
|
||||
|
||||
from .variables.torch_function import (
|
||||
SymbolicTorchFunctionState,
|
||||
TorchFunctionModeVariable,
|
||||
)
|
||||
from .variables.user_defined import (
|
||||
RemovableHandleVariable,
|
||||
UserDefinedClassVariable,
|
||||
@ -283,9 +267,12 @@ class BlockStackEntry:
|
||||
else:
|
||||
return ReenterWith(self.stack_index)
|
||||
|
||||
def exit(self, tx):
|
||||
def exit(self, tx, is_graph_break):
|
||||
assert self.with_context is not None
|
||||
return self.with_context.exit(tx)
|
||||
if (
|
||||
is_graph_break and self.with_context.exit_on_graph_break()
|
||||
) or not is_graph_break:
|
||||
return self.with_context.exit(tx)
|
||||
|
||||
|
||||
class ReturnValueOp(Exception):
|
||||
@ -651,8 +638,17 @@ def break_graph_if_unsupported(*, push):
|
||||
cleanup: List[Instruction] = []
|
||||
# Reconstruct the context variable CLASS in the block stack
|
||||
for b in self.block_stack:
|
||||
# Don't exit any modes we have entered,
|
||||
# output bytecode will mutate the tf mode stack accordingly
|
||||
if isinstance(b.with_context, TorchFunctionModeVariable):
|
||||
cg.extend_output(
|
||||
b.resume_fn().try_except_torch_function_mode(
|
||||
cg.code_options, cleanup
|
||||
)
|
||||
)
|
||||
continue
|
||||
assert b.with_context is not None
|
||||
assert isinstance(b.with_context, ContextWrappingVariable)
|
||||
assert isinstance(b.with_context, (ContextWrappingVariable))
|
||||
b.with_context.reconstruct_type(cg)
|
||||
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
|
||||
self.output.add_output_instructions(cg.get_instructions())
|
||||
@ -728,7 +724,7 @@ class InstructionTranslatorBase(
|
||||
output: OutputGraph
|
||||
symbolic_locals: Dict[str, VariableTracker]
|
||||
symbolic_globals: Dict[str, VariableTracker]
|
||||
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
|
||||
symbolic_torch_function_state: SymbolicTorchFunctionState
|
||||
stack: List[VariableTracker]
|
||||
instruction_pointer: Optional[int]
|
||||
current_instruction: Instruction
|
||||
@ -2305,7 +2301,10 @@ class InstructionTranslatorBase(
|
||||
):
|
||||
unimplemented(f"{inst.opname} {ctx}")
|
||||
|
||||
if isinstance(ctx, GenericContextWrappingVariable):
|
||||
if (
|
||||
isinstance(ctx, GenericContextWrappingVariable)
|
||||
and not ctx.supports_graph_breaks()
|
||||
):
|
||||
self.generic_context_manager_depth += 1
|
||||
|
||||
# Need this redundant check for mypy
|
||||
@ -2548,7 +2547,7 @@ class InstructionTranslatorBase(
|
||||
code_options: Dict[str, Any],
|
||||
symbolic_locals: Dict[str, VariableTracker],
|
||||
symbolic_globals: Dict[str, VariableTracker],
|
||||
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
|
||||
symbolic_torch_function_state: SymbolicTorchFunctionState,
|
||||
f_code: types.CodeType,
|
||||
export: bool,
|
||||
inline_depth: int,
|
||||
@ -2563,7 +2562,7 @@ class InstructionTranslatorBase(
|
||||
self.output = output
|
||||
self.symbolic_locals = symbolic_locals
|
||||
self.symbolic_globals = symbolic_globals
|
||||
self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
|
||||
self.symbolic_torch_function_state = symbolic_torch_function_state
|
||||
self.stack = []
|
||||
# stack of variable names for tracking 3.13 closures
|
||||
self.name_stack: list[Any] = []
|
||||
@ -2652,6 +2651,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
f_locals,
|
||||
f_globals,
|
||||
f_builtins,
|
||||
torch_function_mode_stack,
|
||||
code_options,
|
||||
compiler_fn,
|
||||
one_graph,
|
||||
@ -2677,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
local_scope=f_locals,
|
||||
global_scope=f_globals,
|
||||
f_code=f_code,
|
||||
torch_function_mode_stack=torch_function_mode_stack,
|
||||
),
|
||||
instructions=instructions,
|
||||
f_locals=f_locals,
|
||||
@ -2686,7 +2687,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
symbolic_locals={}, # set below
|
||||
# A global var is inserted only after a STORE_GLOBAL happens to it
|
||||
symbolic_globals={},
|
||||
symbolic_torch_function_mode_stack=collections.deque(),
|
||||
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
|
||||
f_code=f_code,
|
||||
export=export,
|
||||
inline_depth=0,
|
||||
@ -2721,7 +2722,9 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
if k in f_locals
|
||||
}
|
||||
|
||||
self._init_torch_function_mode_stack()
|
||||
self.symbolic_torch_function_state = SymbolicTorchFunctionState(
|
||||
torch_function_mode_stack
|
||||
)
|
||||
|
||||
self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
|
||||
if export:
|
||||
@ -2762,29 +2765,6 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
)
|
||||
unimplemented(msg)
|
||||
|
||||
def _init_torch_function_mode_stack(self):
|
||||
from .variables.torch_function import TorchFunctionModeStackVariable
|
||||
|
||||
TorchFunctionModeStackVariable.reset()
|
||||
|
||||
self.symbolic_torch_function_mode_stack: Deque[
|
||||
TorchFunctionModeVariable
|
||||
] = collections.deque()
|
||||
# We want to retrieve all modes to properly reconstruct the stack if needed
|
||||
py_stack = get_torch_function_mode_stack(filter_ignored=False)
|
||||
|
||||
if py_stack:
|
||||
has_device_context = isinstance(
|
||||
py_stack[0], torch.utils._device.DeviceContext
|
||||
)
|
||||
|
||||
for i, val in enumerate(py_stack):
|
||||
self.symbolic_torch_function_mode_stack.append(
|
||||
variables.LazyVariableTracker.create(
|
||||
val, source=TorchFunctionModeStackSource(i)
|
||||
)
|
||||
)
|
||||
|
||||
def get_example_value(self, source: Source):
|
||||
if isinstance(source, LocalSource):
|
||||
return self.f_locals[source.local_name]
|
||||
@ -3116,7 +3096,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
code,
|
||||
sub_locals,
|
||||
parent.symbolic_globals,
|
||||
parent.symbolic_torch_function_mode_stack,
|
||||
parent.symbolic_torch_function_state,
|
||||
closure_cells,
|
||||
func,
|
||||
)
|
||||
@ -3126,7 +3106,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
code,
|
||||
sub_locals,
|
||||
parent.symbolic_globals,
|
||||
parent.symbolic_torch_function_mode_stack,
|
||||
parent.symbolic_torch_function_state,
|
||||
closure_cells,
|
||||
func,
|
||||
)
|
||||
@ -3179,7 +3159,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
code: types.CodeType,
|
||||
symbolic_locals: Dict[str, VariableTracker],
|
||||
symbolic_globals: Dict[str, VariableTracker],
|
||||
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
|
||||
symbolic_torch_function_state: SymbolicTorchFunctionState,
|
||||
closure_cells: Dict[str, VariableTracker],
|
||||
funcvar: BaseUserFunctionVariable,
|
||||
) -> None:
|
||||
@ -3196,7 +3176,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
f_builtins=f_builtins,
|
||||
symbolic_locals=symbolic_locals,
|
||||
symbolic_globals=symbolic_globals,
|
||||
symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
|
||||
symbolic_torch_function_state=symbolic_torch_function_state,
|
||||
instructions=instructions,
|
||||
code_options={k: getattr(code, k) for k in get_code_keys()},
|
||||
f_code=code,
|
||||
|
@ -163,6 +163,7 @@ def debug_insert_nops(
|
||||
local_scope=locals(),
|
||||
global_scope=globals(),
|
||||
f_code=frame.f_code,
|
||||
torch_function_mode_stack=[],
|
||||
)
|
||||
|
||||
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
|
||||
|
@ -303,6 +303,7 @@ manual_torch_name_rule_map = {
|
||||
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
|
||||
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
|
||||
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
|
||||
"torch.set_default_device": UserFunctionVariable,
|
||||
"torch.sparse_bsc_tensor": SkipFunctionVariable,
|
||||
"torch.sparse_bsr_tensor": SkipFunctionVariable,
|
||||
"torch.sparse_csc_tensor": SkipFunctionVariable,
|
||||
@ -2795,7 +2796,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch.random.initial_seed",
|
||||
"torch.random.seed",
|
||||
"torch.return_types.pytree_register_structseq",
|
||||
"torch.set_default_device",
|
||||
"torch.set_default_dtype",
|
||||
"torch.set_default_tensor_type",
|
||||
"torch.set_deterministic_debug_mode",
|
||||
@ -3254,6 +3254,7 @@ MOD_INLINELIST = [
|
||||
"torch.testing",
|
||||
"torch.utils._content_store",
|
||||
"torch.utils._contextlib",
|
||||
"torch.utils._device",
|
||||
"torch.utils._foreach_utils",
|
||||
"torch.utils._python_dispatch",
|
||||
"torch.utils._pytree",
|
||||
@ -3588,7 +3589,9 @@ def lookup_inner(
|
||||
if reasons is not None:
|
||||
reasons.add("func name is patched_init")
|
||||
return SkipFunctionVariable
|
||||
elif name == "__torch_function__":
|
||||
elif name == "__torch_function__" or (
|
||||
obj and obj.__name__ == "__torch_function__"
|
||||
):
|
||||
if reasons is not None:
|
||||
reasons.add("func name is __torch_function__")
|
||||
return UserFunctionVariable
|
||||
|
@ -63,7 +63,6 @@ import torch.fx.experimental.symbolic_shapes
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import fx
|
||||
from torch._C import (
|
||||
_get_function_stack_at,
|
||||
_instruction_counter,
|
||||
_len_torch_function_stack,
|
||||
_pop_torch_function_stack,
|
||||
@ -3062,14 +3061,10 @@ def is_parameter_freezing():
|
||||
return torch._inductor.config.freezing and not torch.is_grad_enabled()
|
||||
|
||||
|
||||
def get_torch_function_mode_stack(filter_ignored=True):
|
||||
from .variables.torch_function import IGNORED_MODES
|
||||
|
||||
stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
|
||||
if filter_ignored:
|
||||
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
|
||||
|
||||
return stack
|
||||
def get_torch_function_mode_stack():
|
||||
return [
|
||||
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
|
||||
]
|
||||
|
||||
|
||||
def get_torch_function_mode_stack_at(ind):
|
||||
@ -3085,6 +3080,11 @@ def set_torch_function_mode_stack(stack):
|
||||
_push_on_torch_function_stack(mode)
|
||||
|
||||
|
||||
def clear_torch_function_mode_stack():
|
||||
for i in range(_len_torch_function_stack()):
|
||||
_pop_torch_function_stack()
|
||||
|
||||
|
||||
def verify_guard_fn_signature(value):
|
||||
fn = value.__metadata_guard__
|
||||
sig = inspect.signature(fn)
|
||||
|
@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
|
||||
from .torch_function import (
|
||||
build_torch_function_fn,
|
||||
TensorWithTFOverrideVariable,
|
||||
torch_function_mode_stack_state_mgr,
|
||||
TorchFunctionModeVariable,
|
||||
)
|
||||
from .user_defined import (
|
||||
@ -1663,15 +1664,16 @@ class VariableBuilder:
|
||||
# but warning is not the end of the world
|
||||
assert isinstance(value.base, np.nditer)
|
||||
|
||||
try:
|
||||
tensor_value = _util._try_convert_to_tensor(value)
|
||||
if readonly:
|
||||
from torch._prims_common import clone_preserve_strides
|
||||
with torch_function_mode_stack_state_mgr.temp_restore_stack():
|
||||
try:
|
||||
tensor_value = _util._try_convert_to_tensor(value)
|
||||
if readonly:
|
||||
from torch._prims_common import clone_preserve_strides
|
||||
|
||||
tensor_value = clone_preserve_strides(tensor_value)
|
||||
except NotImplementedError as e:
|
||||
# failed to convert to tensor, graph break
|
||||
unimplemented(str(e))
|
||||
tensor_value = clone_preserve_strides(tensor_value)
|
||||
except NotImplementedError as e:
|
||||
# failed to convert to tensor, graph break
|
||||
unimplemented(str(e))
|
||||
|
||||
# We do this because we want the full behavior of guarding the numpy ndarray as if it were
|
||||
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
|
||||
|
@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
|
||||
if isinstance(args[0], UserFunctionVariable):
|
||||
return WrappedUserFunctionVariable(args[0], self)
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return True
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return True
|
||||
|
||||
|
||||
class GenericContextWrappingVariable(UserDefinedObjectVariable):
|
||||
# Some methods in ContextWrappingVariable assumes the arguments are
|
||||
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
|
||||
tx.generic_context_manager_depth -= 1
|
||||
return x
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return False
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return True
|
||||
|
||||
|
||||
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
|
||||
"""represents torch grad requries grad"""
|
||||
@ -637,6 +649,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
|
||||
|
||||
def _call_func(self, tx: "InstructionTranslator", values):
|
||||
assert len(values) == 1
|
||||
tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
|
||||
tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
|
||||
tx.output.set_torch_function_state(values[0])
|
||||
|
||||
|
||||
|
@ -149,6 +149,18 @@ tracing_state_functions = {
|
||||
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_overridable_functions():
|
||||
from itertools import chain
|
||||
|
||||
from torch.overrides import get_overridable_functions as get_overridable_functions_
|
||||
|
||||
funcs = set(chain(*get_overridable_functions_().values()))
|
||||
more = {torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty}
|
||||
funcs.update(more)
|
||||
return funcs
|
||||
|
||||
|
||||
class BaseTorchVariable(VariableTracker):
|
||||
"""common base for all torch.* functions, classes, modules and other things"""
|
||||
|
||||
@ -782,10 +794,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
self, tx: "InstructionTranslator", *args, **kwargs
|
||||
):
|
||||
assert not args and not kwargs
|
||||
if not tx.symbolic_torch_function_mode_stack:
|
||||
if not tx.symbolic_torch_function_state.mode_stack:
|
||||
raise unimplemented("Popping from an empty torch function mode stack")
|
||||
TorchFunctionModeStackVariable.register_mutation(tx)
|
||||
return tx.symbolic_torch_function_mode_stack.pop()
|
||||
return tx.symbolic_torch_function_state.pop_torch_function_mode()
|
||||
|
||||
@register(torch._C._push_on_torch_function_stack)
|
||||
def handle_push_torch_function(
|
||||
@ -793,7 +805,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
):
|
||||
assert len(args) == 1 and not kwargs
|
||||
TorchFunctionModeStackVariable.register_mutation(tx)
|
||||
tx.symbolic_torch_function_mode_stack.append(args[0])
|
||||
tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
@register(torch._C._len_torch_function_stack)
|
||||
@ -801,7 +813,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
self, tx: "InstructionTranslator", *args, **kwargs
|
||||
):
|
||||
assert not args and not kwargs
|
||||
return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
|
||||
return ConstantVariable.create(
|
||||
len(tx.symbolic_torch_function_state.mode_stack)
|
||||
)
|
||||
|
||||
@register(torch._C._get_function_stack_at)
|
||||
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
assert len(args) == 1 and not kwargs
|
||||
ind = args[0].as_python_constant()
|
||||
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
|
||||
return tx.symbolic_torch_function_state.mode_stack[ind]
|
||||
|
||||
@register(torch.set_default_device)
|
||||
def handle_set_default_device(
|
||||
@ -820,7 +841,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
else:
|
||||
TorchFunctionModeStackVariable.register_device_context_insertion(tx)
|
||||
|
||||
return None
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
return handlers
|
||||
|
||||
@ -833,6 +854,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
from . import ConstantVariable, SymNodeVariable, TensorVariable
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
if self.torch_function_override_enabled(tx, args, kwargs):
|
||||
return dispatch_torch_function(tx, self, args, kwargs)
|
||||
|
||||
if self.can_constant_fold_through() and check_unspec_or_constant_args(
|
||||
args, kwargs
|
||||
):
|
||||
@ -850,147 +874,144 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
if result:
|
||||
return result
|
||||
|
||||
if can_dispatch_torch_function(tx, args, kwargs):
|
||||
return dispatch_torch_function(tx, self, args, kwargs)
|
||||
else:
|
||||
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
|
||||
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
|
||||
|
||||
all_ints_or_floats = all(
|
||||
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
|
||||
for x in args
|
||||
)
|
||||
if (
|
||||
getattr(self.value, "__module__", "") == "torch"
|
||||
and self.value.__name__ in bin_ops
|
||||
and any_symints_or_symfloats
|
||||
and all_ints_or_floats
|
||||
):
|
||||
msg = f"""\
|
||||
all_ints_or_floats = all(
|
||||
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
|
||||
for x in args
|
||||
)
|
||||
if (
|
||||
getattr(self.value, "__module__", "") == "torch"
|
||||
and self.value.__name__ in bin_ops
|
||||
and any_symints_or_symfloats
|
||||
and all_ints_or_floats
|
||||
):
|
||||
msg = f"""\
|
||||
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
|
||||
To support this behavior, we need to allow const-propping tensors that store symint data.
|
||||
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
|
||||
"""
|
||||
log.warning(msg)
|
||||
unimplemented(msg)
|
||||
log.warning(msg)
|
||||
unimplemented(msg)
|
||||
|
||||
# TODO(voz): Replace w/ dynamic shape rewrite table.
|
||||
# Ideally, we would be able to do this at ctor time, but alas we need a combination
|
||||
# of value + args to determine this.
|
||||
fn_ = self.value
|
||||
if any_symints_or_symfloats:
|
||||
torch_sym_op = f"_sym_{self.value.__name__}"
|
||||
if getattr(self.value, "__module__", None) == "math" and hasattr(
|
||||
torch, torch_sym_op
|
||||
):
|
||||
fn_ = getattr(torch, torch_sym_op)
|
||||
# TODO(voz): Replace w/ dynamic shape rewrite table.
|
||||
# Ideally, we would be able to do this at ctor time, but alas we need a combination
|
||||
# of value + args to determine this.
|
||||
fn_ = self.value
|
||||
if any_symints_or_symfloats:
|
||||
torch_sym_op = f"_sym_{self.value.__name__}"
|
||||
if getattr(self.value, "__module__", None) == "math" and hasattr(
|
||||
torch, torch_sym_op
|
||||
):
|
||||
fn_ = getattr(torch, torch_sym_op)
|
||||
|
||||
fake_out_shape = None
|
||||
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
|
||||
# Calling fake tensor propagation can mutate the out= tensor in
|
||||
# tx.output.tracked_fakes. tracked_fakes are used to apply
|
||||
# symbolic_shape guards. Mutating them destroys the information
|
||||
# prior to tracing, which is essential for creating right
|
||||
# guards. So save the shape now, and check later if it has
|
||||
# changed. If it has, graph break.
|
||||
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
|
||||
fake_out_shape = None
|
||||
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
|
||||
# Calling fake tensor propagation can mutate the out= tensor in
|
||||
# tx.output.tracked_fakes. tracked_fakes are used to apply
|
||||
# symbolic_shape guards. Mutating them destroys the information
|
||||
# prior to tracing, which is essential for creating right
|
||||
# guards. So save the shape now, and check later if it has
|
||||
# changed. If it has, graph break.
|
||||
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
|
||||
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
fn_,
|
||||
*proxy_args_kwargs(args, kwargs),
|
||||
),
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
fn_,
|
||||
*proxy_args_kwargs(args, kwargs),
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(tensor_variable, TensorVariable)
|
||||
and "requires_grad" in kwargs
|
||||
and kwargs["requires_grad"].as_python_constant()
|
||||
):
|
||||
unimplemented(
|
||||
"""factory functions that return tensors that require grad are not supported.
|
||||
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(tensor_variable, TensorVariable)
|
||||
and "requires_grad" in kwargs
|
||||
and kwargs["requires_grad"].as_python_constant()
|
||||
):
|
||||
unimplemented(
|
||||
"""factory functions that return tensors that require grad are not supported.
|
||||
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
|
||||
)
|
||||
|
||||
if "out" in kwargs and not (
|
||||
isinstance(kwargs["out"], variables.ConstantVariable)
|
||||
and kwargs["out"].as_python_constant() is None
|
||||
):
|
||||
# out variants of torch operators like torch.sort and
|
||||
# torch.sigmoid mutate the tensors in the out field. Track such
|
||||
# tensors and rewrite the symbolic locals.
|
||||
if isinstance(tensor_variable, TupleVariable):
|
||||
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
|
||||
output_tensor_names = [
|
||||
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
|
||||
]
|
||||
for idx, name in enumerate(output_tensor_names):
|
||||
if name in tx.symbolic_locals:
|
||||
tx.symbolic_locals[name] = tensor_variable.items[idx]
|
||||
for out_tensor, result_tensor in zip(
|
||||
kwargs["out"].items, tensor_variable.items
|
||||
):
|
||||
if (
|
||||
out_tensor.source
|
||||
and out_tensor in tx.output.graphargs
|
||||
and isinstance(out_tensor, variables.TensorVariable)
|
||||
and isinstance(result_tensor, variables.TensorVariable)
|
||||
and out_tensor.size != result_tensor.size
|
||||
):
|
||||
# It's hard to get out variants with resizing on graph inputs work
|
||||
# properly across dynamo/aot/inductor, just fall back.
|
||||
unimplemented("out variants with resizing on graph inputs")
|
||||
elif isinstance(tensor_variable, TensorVariable):
|
||||
assert isinstance(kwargs["out"], TensorVariable)
|
||||
assert "example_value" in kwargs["out"].proxy.node.meta
|
||||
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
|
||||
fake_out = kwargs["out"].proxy.node.meta["example_value"]
|
||||
if "out" in kwargs and not (
|
||||
isinstance(kwargs["out"], variables.ConstantVariable)
|
||||
and kwargs["out"].as_python_constant() is None
|
||||
):
|
||||
# out variants of torch operators like torch.sort and
|
||||
# torch.sigmoid mutate the tensors in the out field. Track such
|
||||
# tensors and rewrite the symbolic locals.
|
||||
if isinstance(tensor_variable, TupleVariable):
|
||||
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
|
||||
output_tensor_names = [
|
||||
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
|
||||
]
|
||||
for idx, name in enumerate(output_tensor_names):
|
||||
if name in tx.symbolic_locals:
|
||||
tx.symbolic_locals[name] = tensor_variable.items[idx]
|
||||
for out_tensor, result_tensor in zip(
|
||||
kwargs["out"].items, tensor_variable.items
|
||||
):
|
||||
if (
|
||||
kwargs["out"].source
|
||||
and kwargs["out"] in tx.output.graphargs
|
||||
and fake_out_shape != fake_tensor.shape
|
||||
out_tensor.source
|
||||
and out_tensor in tx.output.graphargs
|
||||
and isinstance(out_tensor, variables.TensorVariable)
|
||||
and isinstance(result_tensor, variables.TensorVariable)
|
||||
and out_tensor.size != result_tensor.size
|
||||
):
|
||||
# It's hard to get out variants with resizing on graph inputs work
|
||||
# properly across dynamo/aot/inductor, just fall back.
|
||||
unimplemented("out variants with resizing on graph inputs")
|
||||
elif isinstance(tensor_variable, TensorVariable):
|
||||
assert isinstance(kwargs["out"], TensorVariable)
|
||||
assert "example_value" in kwargs["out"].proxy.node.meta
|
||||
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
|
||||
fake_out = kwargs["out"].proxy.node.meta["example_value"]
|
||||
if (
|
||||
kwargs["out"].source
|
||||
and kwargs["out"] in tx.output.graphargs
|
||||
and fake_out_shape != fake_tensor.shape
|
||||
):
|
||||
# It's hard to get out variants with resizing on graph inputs work
|
||||
# properly across dynamo/aot/inductor, just fall back.
|
||||
unimplemented("out variants with resizing on graph inputs")
|
||||
if not torch._prims_common.is_contiguous(fake_out):
|
||||
# It's difficult to handle strides correctly in functionalization
|
||||
# when calling an out= op with a non-contiguous out argument
|
||||
unimplemented(
|
||||
"out= op was called where output tensor was non-contiguous"
|
||||
)
|
||||
name = tx.find_symbolic_locals_name(kwargs["out"])
|
||||
if name in tx.symbolic_locals:
|
||||
tx.symbolic_locals[name] = tensor_variable
|
||||
elif (
|
||||
isinstance(tensor_variable, ConstantVariable)
|
||||
and tensor_variable.value is None
|
||||
):
|
||||
# Handle out-variant custom ops that return None.
|
||||
if isinstance(kwargs["out"], TensorVariable):
|
||||
assert "example_value" in kwargs["out"].proxy.node.meta
|
||||
fake_out = kwargs["out"].proxy.node.meta["example_value"]
|
||||
if not torch._prims_common.is_contiguous(fake_out):
|
||||
# It's difficult to handle strides correctly in functionalization
|
||||
# when calling an out= op with a non-contiguous out argument
|
||||
unimplemented(
|
||||
"out= op was called where output tensor was non-contiguous"
|
||||
)
|
||||
name = tx.find_symbolic_locals_name(kwargs["out"])
|
||||
if name in tx.symbolic_locals:
|
||||
tx.symbolic_locals[name] = tensor_variable
|
||||
elif (
|
||||
isinstance(tensor_variable, ConstantVariable)
|
||||
and tensor_variable.value is None
|
||||
):
|
||||
# Handle out-variant custom ops that return None.
|
||||
if isinstance(kwargs["out"], TensorVariable):
|
||||
assert "example_value" in kwargs["out"].proxy.node.meta
|
||||
fake_out = kwargs["out"].proxy.node.meta["example_value"]
|
||||
elif isinstance(kwargs["out"], ListVariable):
|
||||
for idx, x in enumerate(kwargs["out"].items):
|
||||
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
|
||||
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
|
||||
if not torch._prims_common.is_contiguous(fake_out):
|
||||
# It's difficult to handle strides correctly in functionalization
|
||||
# when calling an out= op with a non-contiguous out argument
|
||||
unimplemented(
|
||||
"out= op was called where output tensor was non-contiguous"
|
||||
"out= op was called where some of the output tensors were non-contiguous"
|
||||
)
|
||||
elif isinstance(kwargs["out"], ListVariable):
|
||||
for idx, x in enumerate(kwargs["out"].items):
|
||||
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
|
||||
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
|
||||
if not torch._prims_common.is_contiguous(fake_out):
|
||||
# It's difficult to handle strides correctly in functionalization
|
||||
# when calling an out= op with a non-contiguous out argument
|
||||
unimplemented(
|
||||
"out= op was called where some of the output tensors were non-contiguous"
|
||||
)
|
||||
else:
|
||||
unimplemented(f"out variant of {type(kwargs['out'])}")
|
||||
else:
|
||||
unimplemented(f"out variant of {type(kwargs['out'])}")
|
||||
|
||||
return tensor_variable
|
||||
return tensor_variable
|
||||
|
||||
def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
|
||||
"""inline behavior of torch.nn.modules.utils._ntuple"""
|
||||
@ -1118,3 +1139,12 @@ Either create the tensor outside the compiled region, or do not set the tensor t
|
||||
source
|
||||
)
|
||||
return result
|
||||
|
||||
def torch_function_override_enabled(self, tx, args, kwargs):
|
||||
return (
|
||||
self.get_function() in get_overridable_functions()
|
||||
or isinstance(
|
||||
self.get_function(),
|
||||
(torch._ops.OpOverload, torch._ops.OpOverloadPacket),
|
||||
)
|
||||
) and can_dispatch_torch_function(tx, args, kwargs)
|
||||
|
@ -1,20 +1,36 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import inspect
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
from typing import Deque, Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._guards import Source
|
||||
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
|
||||
from torch.overrides import (
|
||||
_get_overloaded_args,
|
||||
get_default_nowrap_functions,
|
||||
TorchFunctionMode,
|
||||
)
|
||||
from torch.utils._device import DeviceContext
|
||||
|
||||
from ..exc import unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..polyfills import NoEnterTorchFunctionMode
|
||||
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
|
||||
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
|
||||
from ..utils import (
|
||||
class_has_getattribute,
|
||||
clear_torch_function_mode_stack,
|
||||
get_safe_global_name,
|
||||
has_torch_function,
|
||||
is_tensor_base_attr_getter,
|
||||
set_torch_function_mode_stack,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import ContextWrappingVariable
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import TupleVariable
|
||||
from .tensor import TensorSubclassVariable, TensorVariable
|
||||
from .user_defined import UserDefinedObjectVariable
|
||||
@ -52,11 +68,92 @@ banned_attrs = [
|
||||
if is_tensor_base_attr_getter(fn)
|
||||
]
|
||||
|
||||
# Today set default device is placed in the graph and guarded on separately
|
||||
# so we should not trace through it. In the future we can trace it once
|
||||
# mode tracing is implemented and not put in the graph, but this is more
|
||||
# of a BE project and can be evaluated later
|
||||
IGNORED_MODES = {DeviceContext}
|
||||
|
||||
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
|
||||
class TorchFunctionModeStackStateManager:
|
||||
def __init__(self):
|
||||
self.stack = []
|
||||
|
||||
def __enter__(self):
|
||||
self.stack = torch.overrides._get_current_function_mode_stack()
|
||||
clear_torch_function_mode_stack()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
set_torch_function_mode_stack(self.stack)
|
||||
self.stack = []
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_restore_stack(self):
|
||||
prev = torch.overrides._get_current_function_mode_stack()
|
||||
set_torch_function_mode_stack(self.stack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
set_torch_function_mode_stack(prev)
|
||||
|
||||
|
||||
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
|
||||
|
||||
|
||||
class SymbolicTorchFunctionState:
|
||||
def __init__(self, py_stack):
|
||||
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
|
||||
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
|
||||
# These are their definitions:
|
||||
# 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
|
||||
# (if either are entered, this will be False)
|
||||
# 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
|
||||
# torch._C.DisableTorchFunction has been entered
|
||||
# To disambiguate these and keep myself sane I added a C API to check whether all torch function
|
||||
# concepts (modes and subclasses) are enabled.
|
||||
# This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
|
||||
# the stack length from the enablement state of torch function modes.
|
||||
# This is important because now if a mode is pushed while dynamo is tracing, we know whether
|
||||
# or not torch function modes are enabled and whether we should trace it.
|
||||
self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()
|
||||
|
||||
# This differs from the C API of the same name
|
||||
# this will only be false iff we have entered torch._C.DisableTorchFunction
|
||||
# and does not take into account the mode stack length, while the C API bundles these
|
||||
# two concepts
|
||||
self.torch_function_mode_enabled = (
|
||||
not torch._C._is_torch_function_all_disabled()
|
||||
)
|
||||
|
||||
self.cur_mode = None
|
||||
|
||||
TorchFunctionModeStackVariable.reset()
|
||||
|
||||
self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()
|
||||
|
||||
for i, val in enumerate(py_stack):
|
||||
self.mode_stack.append(
|
||||
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
|
||||
)
|
||||
|
||||
def in_torch_function_mode(self):
|
||||
return len(self.mode_stack) > 0
|
||||
|
||||
def pop_torch_function_mode(self):
|
||||
return self.mode_stack.pop()
|
||||
|
||||
def push_torch_function_mode(self, mode_var):
|
||||
self.mode_stack.append(mode_var)
|
||||
|
||||
def call_torch_function_mode(self, tx, fn, types, args, kwargs):
|
||||
with self._pop_mode_for_inlining() as cur_mode:
|
||||
return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _pop_mode_for_inlining(self):
|
||||
old_mode = self.cur_mode
|
||||
self.cur_mode = self.pop_torch_function_mode()
|
||||
try:
|
||||
yield self.cur_mode
|
||||
finally:
|
||||
mode = self.cur_mode
|
||||
self.cur_mode = old_mode
|
||||
self.push_torch_function_mode(mode)
|
||||
|
||||
|
||||
class TorchFunctionModeStackVariable(VariableTracker):
|
||||
@ -88,19 +185,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
|
||||
def register_mutation(cls, tx: "InstructionTranslator"):
|
||||
if cls.stack_value_singleton not in tx.output.side_effects:
|
||||
var = cls(
|
||||
source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
|
||||
source=Source(),
|
||||
symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
|
||||
)
|
||||
tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
|
||||
tx.output.side_effects.mutation(var)
|
||||
|
||||
@classmethod
|
||||
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
|
||||
stack = tx.symbolic_torch_function_mode_stack
|
||||
stack = tx.symbolic_torch_function_state.mode_stack
|
||||
if stack and cls.is_device_context(stack[0]):
|
||||
return
|
||||
else:
|
||||
cls.offset += 1
|
||||
tx.symbolic_torch_function_mode_stack.insert(
|
||||
stack.insert(
|
||||
0,
|
||||
TorchFunctionModeVariable(
|
||||
None, source=TorchFunctionModeStackSource(-cls.offset)
|
||||
@ -109,7 +207,7 @@ class TorchFunctionModeStackVariable(VariableTracker):
|
||||
|
||||
@classmethod
|
||||
def clear_default_device(cls, tx: "InstructionTranslator"):
|
||||
stack = tx.symbolic_torch_function_mode_stack
|
||||
stack = tx.symbolic_torch_function_state.mode_stack
|
||||
if stack and cls.is_device_context(stack[0]):
|
||||
stack.popleft()
|
||||
cls.offset -= 1
|
||||
@ -123,24 +221,88 @@ class TorchFunctionModeStackVariable(VariableTracker):
|
||||
return ind + cls.offset
|
||||
|
||||
|
||||
class TorchFunctionModeVariable(ContextWrappingVariable):
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self.value = value
|
||||
|
||||
class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
||||
@staticmethod
|
||||
def get_global_mangled_name(tx, val):
|
||||
return get_safe_global_name(
|
||||
tx, f"__torch_function_mode_{val.__class__.__name__}", val
|
||||
def is_supported_torch_function_mode(ty):
|
||||
# Supported in this sense means we can support graph breaks under the
|
||||
# context.
|
||||
# We are able to trace custom modes but if there are graph breaks under them
|
||||
# and they have a custom __enter__/__exit__ we don't handle this for the
|
||||
# same reason we don't handle generic context managers: there may be side effects
|
||||
# that are now affected by executing the funtion across two frames instead of one
|
||||
# Today we support the enter/exit of the default TorchFunctionMode as well as
|
||||
# DeviceContext (which is used for set_default_device)
|
||||
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
|
||||
not class_has_getattribute(ty)
|
||||
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
|
||||
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
|
||||
)
|
||||
|
||||
def __init__(self, value, source=None, **kwargs):
|
||||
if value is not None:
|
||||
super().__init__(value, **kwargs)
|
||||
self.value = value
|
||||
self.cm_obj = value # needed for BC with calling enter from CM code
|
||||
self.source = source
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
# We don't support locally created torch function modes yet
|
||||
# This shouldn't be called unless we have a source
|
||||
assert self.source
|
||||
self.source.reconstruct(codegen)
|
||||
|
||||
def _call_func(self, tx, values):
|
||||
unimplemented("torch function mode context manager is not supported yet")
|
||||
def module_name(self):
|
||||
return self.value.__module__
|
||||
|
||||
def fn_name(self):
|
||||
return type(self.value).__name__
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
|
||||
return call_torch_function(
|
||||
tx,
|
||||
self,
|
||||
build_torch_function_fn(tx, self.value, self.source),
|
||||
fn,
|
||||
types,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def enter(self, tx):
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
if isinstance(self.value, NoEnterTorchFunctionMode):
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
TorchInGraphFunctionVariable(
|
||||
torch._C._push_on_torch_function_stack
|
||||
).call_function(tx, [self], {})
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args):
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
|
||||
tx, [], {}
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def reconstruct_type(self, codegen):
|
||||
ty = NoEnterTorchFunctionMode
|
||||
codegen(
|
||||
AttrSource(
|
||||
codegen.tx.import_source(ty.__module__),
|
||||
ty.__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return True
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return False
|
||||
|
||||
|
||||
def _get_all_args(args, kwargs):
|
||||
@ -231,9 +393,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):
|
||||
|
||||
|
||||
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
|
||||
return tx.output.torch_function_enabled and any(
|
||||
has_overridden_args = any(
|
||||
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
|
||||
)
|
||||
tf_state = tx.symbolic_torch_function_state
|
||||
return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
|
||||
tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
|
||||
)
|
||||
|
||||
|
||||
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
|
||||
@ -245,11 +411,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
|
||||
_get_subclass_type,
|
||||
)
|
||||
|
||||
types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])
|
||||
|
||||
if tx.symbolic_torch_function_state.in_torch_function_mode():
|
||||
res = tx.symbolic_torch_function_state.call_torch_function_mode(
|
||||
tx, fn, types, args, kwargs
|
||||
)
|
||||
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
|
||||
return res
|
||||
|
||||
for arg in overloaded_args:
|
||||
res = arg.call_torch_function(
|
||||
tx,
|
||||
fn,
|
||||
TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
|
||||
types,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
|
@ -9,6 +9,7 @@ import inspect
|
||||
import itertools
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, Generic, List, TYPE_CHECKING
|
||||
@ -82,11 +83,6 @@ def is_forbidden_context_manager(ctx):
|
||||
from _pytest.python_api import RaisesContext
|
||||
from _pytest.recwarn import WarningsChecker
|
||||
|
||||
# TODO mlazos: Temporary to get this stack to pass
|
||||
# remove in subsequent PR
|
||||
from torch.overrides import BaseTorchFunctionMode
|
||||
|
||||
f_ctxs.append(BaseTorchFunctionMode)
|
||||
f_ctxs.append(RaisesContext)
|
||||
f_ctxs.append(WarningsChecker)
|
||||
except ImportError:
|
||||
@ -413,15 +409,25 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
and self.source
|
||||
and not is_forbidden_context_manager(self.value)
|
||||
):
|
||||
# import here to avoid an unfortunate circular dependency.
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .torch_function import TorchFunctionModeVariable
|
||||
|
||||
if issubclass(
|
||||
self.value, TorchFunctionMode
|
||||
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
|
||||
self.value
|
||||
):
|
||||
var_cls = TorchFunctionModeVariable
|
||||
else:
|
||||
var_cls = GenericContextWrappingVariable
|
||||
|
||||
cm_obj = tx.output.side_effects.track_object_new(
|
||||
self.source, self.value, GenericContextWrappingVariable, {}
|
||||
self.source, self.value, var_cls, {}
|
||||
)
|
||||
cm_obj.call_method(tx, "__init__", args, kwargs)
|
||||
return cm_obj
|
||||
|
||||
elif is_namedtuple_cls(self.value):
|
||||
fields = namedtuple_fields(self.value)
|
||||
# check if this a quasi-namedtuple or a real one
|
||||
@ -711,7 +717,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
if method is object.__init__:
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
if is_standard_setattr(method):
|
||||
if is_standard_setattr(method) or isinstance(self.value, threading.local):
|
||||
return self.method_setattr_standard(tx, *args, **kwargs)
|
||||
|
||||
# [NOTE] OrderedDict, dict subtypes must always have source
|
||||
@ -809,7 +815,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
def needs_slow_setattr(self):
|
||||
return not is_standard_setattr(
|
||||
inspect.getattr_static(self.value, "__setattr__", None)
|
||||
)
|
||||
) and not isinstance(self.value, threading.local)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
if (
|
||||
|
@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
|
||||
if (
|
||||
not torch.compiler.is_dynamo_compiling()
|
||||
and log.isEnabledFor(logging.DEBUG)
|
||||
and config.extended_debug_current_loc
|
||||
):
|
||||
frame = _find_user_code_frame()
|
||||
if frame is not None:
|
||||
log.debug(
|
||||
|
@ -28,6 +28,7 @@ from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_metadata_torch_function_mode,
|
||||
_temp_remove_pre_dispatch_torch_function_mode,
|
||||
disable_proxy_modes_tracing,
|
||||
ProxyTorchDispatchMode,
|
||||
@ -129,6 +130,10 @@ def cond(pred, true_fn, false_fn, operands):
|
||||
if torch.compiler.is_dynamo_compiling():
|
||||
return cond_op(pred, true_fn, false_fn, operands)
|
||||
|
||||
from torch._dynamo.backends.debugging import (
|
||||
make_eager_backend_with_torch_function_mode,
|
||||
)
|
||||
|
||||
if isinstance(pred, (bool, int, float)):
|
||||
log.warning(
|
||||
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
|
||||
@ -169,12 +174,15 @@ def cond(pred, true_fn, false_fn, operands):
|
||||
def _cond_op_wrapper(*args, **kwargs):
|
||||
return cond_op(*args, **kwargs)
|
||||
|
||||
with _set_compilation_env():
|
||||
with torch._dynamo.utils.disable_cache_limit():
|
||||
with _temp_remove_pre_dispatch_torch_function_mode():
|
||||
return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
|
||||
pred, true_fn, false_fn, operands
|
||||
)
|
||||
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
|
||||
else:
|
||||
backend = "eager"
|
||||
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
|
||||
pred, true_fn, false_fn, operands
|
||||
)
|
||||
|
||||
|
||||
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
|
||||
|
@ -15,7 +15,11 @@ from torch._higher_order_ops.utils import (
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_metadata_torch_function_mode,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
|
||||
|
||||
class WhileLoopOp(HigherOrderOperator):
|
||||
@ -113,6 +117,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
|
||||
- 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
|
||||
|
||||
"""
|
||||
from torch._dynamo.backends.debugging import (
|
||||
make_eager_backend_with_torch_function_mode,
|
||||
)
|
||||
|
||||
# Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
|
||||
# parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
|
||||
@ -140,9 +147,15 @@ def while_loop(cond_fn, body_fn, carried_inputs):
|
||||
return while_loop_op(*args, **kwargs)
|
||||
|
||||
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
|
||||
return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)(
|
||||
cond_fn, body_fn, carried_inputs, additional_inputs
|
||||
)
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
|
||||
else:
|
||||
backend = "eager"
|
||||
return torch.compile(
|
||||
_while_loop_op_wrapper, backend=backend, fullgraph=True
|
||||
)(cond_fn, body_fn, carried_inputs, additional_inputs)
|
||||
|
||||
|
||||
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
|
@ -2515,62 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
|
||||
public:
|
||||
TORCH_FUNCTION_MODE_STACK(
|
||||
const py::list& initial_stack,
|
||||
const py::list& ignored_types,
|
||||
py::object verbose_code_parts)
|
||||
: LeafGuard(std::move(verbose_code_parts)),
|
||||
_ref_stack(),
|
||||
_ignored_types() {
|
||||
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
|
||||
Py_ssize_t len = PyList_Size(initial_stack.ptr());
|
||||
for (Py_ssize_t idx = 0; idx < len; idx++) {
|
||||
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
|
||||
this->_ref_stack.push_back(Py_TYPE(mode));
|
||||
}
|
||||
|
||||
len = PyList_Size(ignored_types.ptr());
|
||||
for (Py_ssize_t idx = 0; idx < len; idx++) {
|
||||
PyObject* type_obj =
|
||||
PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
|
||||
if (PyType_Check(type_obj) == 0) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError, "ignored_types should contain a list of types");
|
||||
return;
|
||||
}
|
||||
PyTypeObject* type = (PyTypeObject*)type_obj;
|
||||
this->_ignored_types.insert(type);
|
||||
auto type = Py_TYPE(mode);
|
||||
this->_ref_stack.push_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
bool check_nopybind(PyObject* value) override {
|
||||
// Ignore value arg, only used to satisfy the interface
|
||||
size_t ref_ind = 0;
|
||||
int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
|
||||
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
|
||||
const size_t ref_stack_size = this->_ref_stack.size();
|
||||
|
||||
for (int64_t idx = 0; idx < len; idx++) {
|
||||
if (len != ref_stack_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int64_t idx = 0; (size_t)idx < len; idx++) {
|
||||
std::shared_ptr<c10::SafePyObject> mode =
|
||||
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
|
||||
|
||||
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
|
||||
// skip ignored types
|
||||
if (this->_ignored_types.count(mode_type) > 0) {
|
||||
continue;
|
||||
}
|
||||
// if we already have more non-ignored modes than the ref stack
|
||||
// or if the mode doesn't match at the current index, return false
|
||||
else if (
|
||||
(ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) ||
|
||||
mode_type != _ref_stack[ref_ind]) {
|
||||
if (mode_type != _ref_stack.at(idx)) {
|
||||
return false;
|
||||
}
|
||||
ref_ind++;
|
||||
}
|
||||
|
||||
return ref_ind == this->_ref_stack.size();
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<PyTypeObject*> _ref_stack;
|
||||
std::set<PyTypeObject*> _ignored_types;
|
||||
};
|
||||
|
||||
class TENSOR_MATCH : public LeafGuard {
|
||||
@ -3672,7 +3650,7 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
LeafGuard,
|
||||
std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
|
||||
py_m, "TORCH_FUNCTION_MODE_STACK")
|
||||
.def(py::init<py::list, py::list, py::list>())
|
||||
.def(py::init<py::list, py::list>())
|
||||
.def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
|
||||
py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
|
||||
py_m, "DATA_PTR_MATCH")
|
||||
@ -3903,10 +3881,9 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
"add_torch_function_mode_stack_guard",
|
||||
[](GuardManager& self,
|
||||
const py::list& initial_stack,
|
||||
const py::list& ignored_types,
|
||||
py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
|
||||
initial_stack, ignored_types, std::move(verbose_code_parts)));
|
||||
initial_stack, std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_data_ptr_guard",
|
||||
|
@ -17,7 +17,7 @@ import typing_extensions
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, ExitStack, nullcontext
|
||||
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
@ -1084,38 +1084,43 @@ class PythonKeyTracer(Tracer):
|
||||
return e
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]:
|
||||
from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
|
||||
def _make_temp_remove_mode_context_manager(
|
||||
mode_ty: Type[TorchFunctionMode],
|
||||
) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
|
||||
@contextmanager
|
||||
def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
|
||||
from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
|
||||
|
||||
temp_elements = []
|
||||
pre_dispatch_mode = None
|
||||
temp_elements = []
|
||||
removed_mode = None
|
||||
|
||||
while _len_torch_function_stack() > 0:
|
||||
mode = _pop_mode()
|
||||
if isinstance(mode, PreDispatchTorchFunctionMode):
|
||||
pre_dispatch_mode = mode
|
||||
break
|
||||
else:
|
||||
temp_elements.append(mode)
|
||||
while _len_torch_function_stack() > 0:
|
||||
mode = _pop_mode()
|
||||
if isinstance(mode, mode_ty):
|
||||
removed_mode = mode
|
||||
break
|
||||
else:
|
||||
temp_elements.append(mode)
|
||||
|
||||
for mode in reversed(temp_elements):
|
||||
_push_mode(mode)
|
||||
for mode in reversed(temp_elements):
|
||||
_push_mode(mode)
|
||||
|
||||
try:
|
||||
yield
|
||||
try:
|
||||
yield removed_mode
|
||||
|
||||
finally:
|
||||
if pre_dispatch_mode is not None:
|
||||
count = len(temp_elements)
|
||||
while count > 0:
|
||||
mode = _pop_mode()
|
||||
count -= 1
|
||||
finally:
|
||||
if removed_mode is not None:
|
||||
count = len(temp_elements)
|
||||
while count > 0:
|
||||
mode = _pop_mode()
|
||||
count -= 1
|
||||
|
||||
temp_elements.append(pre_dispatch_mode)
|
||||
temp_elements.append(removed_mode)
|
||||
|
||||
for mode in reversed(temp_elements):
|
||||
_push_mode(mode)
|
||||
for mode in reversed(temp_elements):
|
||||
_push_mode(mode)
|
||||
|
||||
return context_manager_fn
|
||||
|
||||
|
||||
@torch._disable_dynamo
|
||||
@ -1230,6 +1235,11 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
|
||||
TorchFunctionMetadataMode
|
||||
)
|
||||
|
||||
|
||||
# This mode is **only** used for pre_dispatch tracing.
|
||||
# In particular, we need to make sure that autograd/autocast API's
|
||||
# that do not desugar into dispatcher operators stay in the graph.
|
||||
@ -1258,6 +1268,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
|
||||
PreDispatchTorchFunctionMode
|
||||
)
|
||||
|
||||
|
||||
class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
# Ensure this is read-only; this exists only for legacy reasons
|
||||
@property
|
||||
|
@ -19,6 +19,7 @@ from torch._higher_order_ops.flex_attention import (
|
||||
)
|
||||
from torch._higher_order_ops.utils import _set_compilation_env
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_metadata_torch_function_mode,
|
||||
_temp_remove_pre_dispatch_torch_function_mode,
|
||||
)
|
||||
from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input
|
||||
@ -1027,6 +1028,10 @@ def flex_attention(
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
raise RuntimeError("flex_attention requires dynamo support")
|
||||
|
||||
from torch._dynamo.backends.debugging import (
|
||||
make_eager_backend_with_torch_function_mode,
|
||||
)
|
||||
|
||||
# Dynamo is expecting a callable with "__code__" attribute.
|
||||
# We cannot directly pass hop to it. So we wrap it in a dummy function.
|
||||
def _flex_attention_hop_wrapper(*args, **kwargs):
|
||||
@ -1035,18 +1040,25 @@ def flex_attention(
|
||||
with _set_compilation_env():
|
||||
with torch._dynamo.utils.disable_cache_limit():
|
||||
with _temp_remove_pre_dispatch_torch_function_mode():
|
||||
out, lse = torch.compile(
|
||||
_flex_attention_hop_wrapper, backend="eager", fullgraph=True
|
||||
)(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod,
|
||||
block_mask.as_tuple(),
|
||||
scale,
|
||||
kernel_options,
|
||||
)
|
||||
if return_lse:
|
||||
return out, lse * math.log(2)
|
||||
else:
|
||||
return out
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend = make_eager_backend_with_torch_function_mode(
|
||||
metadata_mode
|
||||
)
|
||||
else:
|
||||
backend = "eager"
|
||||
out, lse = torch.compile(
|
||||
_flex_attention_hop_wrapper, backend="eager", fullgraph=True
|
||||
)(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod,
|
||||
block_mask.as_tuple(),
|
||||
scale,
|
||||
kernel_options,
|
||||
)
|
||||
if return_lse:
|
||||
return out, lse * math.log(2)
|
||||
else:
|
||||
return out
|
||||
|
Reference in New Issue
Block a user