Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-03 23:45:05 +08:00)

Compare commits: ciflow/tru ... mlazos/tf- (7 commits)

| SHA1 |
|---|
| 153c25faf3 |
| b9e76c2b8b |
| a22fb0bed4 |
| 16dffb028b |
| c2ebc2de89 |
| 7cba4399d3 |
| 66a8797357 |
@@ -3380,6 +3380,21 @@ utils_device.CURRENT_DEVICE == None""".split(
        self.assertTrue(same(obj41.y, obj42.y))
        self.assertEqual(cnts.frame_count, 1)

    def test_thread_local_setattr(self):
        from threading import local

        loc = local()

        @torch.compile(fullgraph=True)
        def fn(x, l):
            l.x = x
            return x + 1

        x = torch.ones(2, 2)
        fn(x, loc)

        self.assertTrue(loc.x is x)

    def test_user_defined_class_name(self):
        class MyClassFoo:
            pass

@@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
@@ -14,6 +13,17 @@ from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode


class TestMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}

        if func == torch.add:
            return torch.zeros(2, 2)

        return super().__torch_function__(func, types, args, kwargs)


class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -57,9 +67,11 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):

    def setUp(self):
        torch.set_default_device(None)
        torch._dynamo.reset()

    def tearDown(self):
        torch.set_default_device(None)
        torch._dynamo.reset()

    def _run_torch_function_mode_guard_test(self):
        class TestMode1(BaseTorchFunctionMode):
@@ -94,70 +106,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
            fn(inp)
        self.assertEqual(cnt.frame_count, 4)

    def _run_ignored_mode_types_test(self):
        class IgnoredMode(BaseTorchFunctionMode):
            pass

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt.__call__, fullgraph=True)
        def fn(x):
            return x + 1

        inp = torch.ones(2, 2)

        with patch(
            "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
        ):
            # initial compile
            fn(inp)

            # no recompile, mode ignored
            # note: the ref stack is length 0, and the stack we are checking against has length 2
            # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
            with IgnoredMode(), IgnoredMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 1)

            # recompile due to new mode on the stack
            with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 2)

            # recompile
            # tests both ref stack len > runtime stack len for the above guard check
            # and ref stack len < runtime stack len for the initial zero mode case
            with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 3)

            # no recompile
            with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 3)

        # This is tricky, basically the ignored modes are baked into the guard
        # IgnoredMode will be ignored forever by that guard.
        # This is okay since we don't expect to be modifying IGNORED_MODES
        # in the middle of execution except for the purposes of testing.
        torch._dynamo.reset()

        with IgnoredMode():
            fn(inp)

        self.assertEqual(cnt.frame_count, 4)

    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_ignored_types_py(self):
        self._run_ignored_mode_types_test()

    def test_torch_function_mode_guards_ignored_types_cpp(self):
        self._run_ignored_mode_types_test()

    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_py(self):
        self._run_torch_function_mode_guard_test()
@@ -324,6 +272,218 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
            fn(inp)
        self.assertEqual(cnt.frame_count, 2)

    def test_nested_torch_function_mode(self):
        mode_1_called = False
        mode_2_called = False

        def reset_state():
            nonlocal mode_1_called
            nonlocal mode_2_called
            mode_1_called = False
            mode_2_called = False

        ones = torch.ones(2, 2)
        zeros = torch.zeros(2, 2)

        class TestMode1(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_1_called

                mode_1_called = True

                if func == torch.add:
                    return zeros

                return super().__torch_function__(func, types, args, kwargs)

        class TestMode2(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_2_called

                mode_2_called = True

                if func == torch.mul:
                    return ones

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        def fn_2(x):
            return torch.mul(x, 3) + torch.add(x, 3)

        inp = torch.ones(2, 2) + 1

        for fn_i in [fn, fn_2]:
            fn_opt = torch.compile(fn_i, fullgraph=True)
            with TestMode1(), TestMode2():
                expected = fn_i(inp), mode_1_called, mode_2_called
                reset_state()
                actual = fn_opt(inp), mode_1_called, mode_2_called
                reset_state()

            self.assertEqual(expected, actual)

    def test_torch_function_mode_disable(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        class TestMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                if func == torch.add:
                    return torch.zeros(2, 2)

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            with torch._C.DisableTorchFunctionSubclass():
                expected = fn(inp)
                actual = fn_opt(inp)

            self.assertEqual(expected, actual)

            with torch._C.DisableTorchFunction():
                expected = fn(inp)
                actual = fn_opt(inp)

            self.assertEqual(expected, actual)

    def test_torch_function_mode_highest_priority(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            expected = fn(inp)
            actual = fn_opt(inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_enter_exit(self):
        def fn(x, y):
            with TestMode():
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn, fullgraph=True)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_graph_break(self):
        def fn(x, y):
            with TestMode():
                torch._dynamo.graph_break()
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_and_pop_graph_break(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_restore_on_exc(self):
        @torch._dynamo.disable()
        def err():
            raise RuntimeError("test")

        @torch.compile()
        def fn(x):
            with TestMode():
                x += 1
                err()
                x += 2
                return x

        try:
            fn(torch.ones(2, 2))
        except RuntimeError:
            pass
        self.assertEqual(_len_torch_function_stack(), 0)

    def test_torch_function_mode_and_pop_graph_break_mutation(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                z.y = 5
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)
                o = torch.mul(o, z.y)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@@ -26,6 +26,7 @@ from torch.testing._internal.common_utils import (
    parametrize,
    requires_cuda,
    run_tests,
    skipIfCrossRef,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
@@ -2882,6 +2883,7 @@ def forward(self, pred_1, x_1):
            gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x)

    @skipIfNoDynamoSupport
    @skipIfCrossRef  # Arg order changes with crossref
    def test_scan_simple_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -2988,6 +2990,7 @@ class TestControlFlowTraced(TestCase):
        self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_simple_with_linear_compile_check_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -3250,6 +3253,7 @@ def forward(self, arg0_1):
        self._check_compile(fn, inp, backend=backend)

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with cross ref
    def test_while_loop_simple_with_linear_compile_check_graph(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef


class TestHelperModules:
@@ -139,6 +139,8 @@ class TestMetaDataPorting(QuantizationTestCase):
            self.assertEqual(v, node_tags[k])
        return m

    @skipIfCrossRef  # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
    # trace of the mode torch function impl doesn't match the traced graph stored lineno.
    def test_simple_metadata_porting(self):
        """
        Model under test

@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
)
from torch.export import export_for_training
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase


def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
@@ -117,6 +117,8 @@ class TestNumericDebugger(TestCase):

        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    @skipIfCrossRef  # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
    # trace of the mode torch function impl doesn't match the traced graph stored lineno.
    def test_re_export_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()

@@ -67,7 +67,7 @@ class GuardManager:
    ) -> None: ...
    def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
    def add_torch_function_mode_stack_guard(
        self, initial_stack, ignored_types, verbose_code_parts: list[str]
        self, initial_stack, verbose_code_parts: list[str]
    ) -> None: ...

class RootGuardManager(GuardManager):

@@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs):
    return gm.forward


def make_eager_backend_with_torch_function_mode(mode):
    """Used to trace HOPs (cond and while) for eager execution. The metadata
    TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
    in the HOP, so we need to externally run this mode and not trace it."""

    def fn(gm, fake_tensor_inputs, **kwargs):
        with mode:
            return gm.forward

    return fn


@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
    if kwargs:
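A minimal usage sketch of the new backend helper (illustrative only; it assumes the hunk above belongs to torch/_dynamo/backends/debugging.py, and `LoggingMode` is a hypothetical mode):

```python
# Sketch: hand the helper's backend to torch.compile so the produced graph runs
# eagerly while the torch function mode is managed outside of tracing.
# The import path is an assumption based on where this hunk appears to live.
import torch
from torch.overrides import BaseTorchFunctionMode
from torch._dynamo.backends.debugging import make_eager_backend_with_torch_function_mode


class LoggingMode(BaseTorchFunctionMode):  # hypothetical mode
    def __torch_function__(self, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        return super().__torch_function__(func, types, args, kwargs)


mode = LoggingMode()
backend = make_eager_backend_with_torch_function_mode(mode)

fn_opt = torch.compile(lambda x: torch.add(x, 1), backend=backend)
with mode:
    fn_opt(torch.ones(2, 2))  # the backend returns gm.forward, so the graph body runs eagerly
```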
@@ -112,6 +112,7 @@ from .utils import (
    troubleshooting_url,
    write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr


np: Optional[ModuleType]
@@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()

            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            exit_stack.enter_context(torch_function_mode_stack_state_mgr)
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                assert (
                    torch._C._len_torch_function_stack() == 0
                ), "Torch function mode stack state changed while dynamo tracing, please report a bug"
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
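`torch_function_mode_stack_state_mgr` itself is defined in the PR's variables/torch_function.py and is not shown in this diff. A rough sketch of the save/restore idea, built only from the helpers this diff adds to torch/_dynamo/utils.py, might look like the following (an assumption, not the actual implementation):

```python
# Rough sketch only; the real torch_function_mode_stack_state_mgr may differ.
import contextlib

from torch._dynamo.utils import (
    clear_torch_function_mode_stack,
    get_torch_function_mode_stack,
    set_torch_function_mode_stack,
)


@contextlib.contextmanager
def preserve_torch_function_mode_stack():
    saved = get_torch_function_mode_stack()  # snapshot the modes currently pushed
    clear_torch_function_mode_stack()        # one plausible reading of the length-0 assert above
    try:
        yield
    finally:
        set_torch_function_mode_stack(saved)  # put the user's modes back after tracing
```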
@@ -605,6 +609,10 @@ def _compile(
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    tf_mode_stack: List[
        torch.overrides.TorchFunctionMode
    ] = torch.overrides._get_current_function_mode_stack()

    @preserve_global_state
    def transform(
        instructions: List[Instruction], code_options: Dict[str, object]
@@ -618,6 +626,7 @@ def _compile(
            locals,
            globals,
            builtins,
            tf_mode_stack,
            code_options,
            compiler_fn,
            one_graph,

@@ -98,6 +98,7 @@ from .source import (
    ScriptObjectQualifiedNameSource,
    ShapeEnvSource,
    SubclassAttrListSource,
    TorchFunctionModeStackSource,
    TupleIteratorGetItemSource,
    TypeSource,
    UnspecializedBuiltinNNModuleSource,
@@ -111,6 +112,7 @@ from .utils import (
    dict_keys_repr,
    get_custom_getattr,
    get_torch_function_mode_stack,
    get_torch_function_mode_stack_at,
    guard_failures,
    istype,
    key_is_id,
@@ -314,6 +316,7 @@ CLOSURE_VARS = {
    "___dict_contains": lambda a, b: a in b,
    "___tuple_iterator_len": tuple_iterator_len,
    "___tuple_iterator_getitem": tuple_iterator_getitem,
    "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
    "__math_isnan": math.isnan,
    "__numpy_isnan": None if np is None else np.isnan,
    "inf": float("inf"),
@@ -901,6 +904,15 @@ class GuardBuilder(GuardBuilderBase):
        ):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager
        elif istype(source, TorchFunctionModeStackSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_torch_function_mode_stack_at(
                    source._get_index()
                ),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
@@ -2214,6 +2226,8 @@ class CheckFunctionManager:
        self.output_graph = output_graph
        w_builder = None

        # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
        # in case a set default device call was made in the graph.
        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )
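For context on the NB above: torch.set_default_device works by pushing a DeviceContext, which is itself a TorchFunctionMode, onto the torch function mode stack, so a set_default_device call inside the traced region changes the stack relative to what was recorded when tracing began. A small eager illustration (not part of the diff):

```python
# Illustration: set_default_device manipulates the torch function mode stack.
import torch

torch.set_default_device("cpu")
stack = torch.overrides._get_current_function_mode_stack()
print([type(m).__name__ for m in stack])  # expected to show a DeviceContext entry

torch.set_default_device(None)  # removes the DeviceContext again
print(torch.overrides._get_current_function_mode_stack())  # []
```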
@@ -2330,15 +2344,12 @@ class CheckFunctionManager:
        )

        if config.enable_cpp_guard_manager:
            from .variables.torch_function import IGNORED_MODES

            # Insert the global_state guard
            assert self.guard_manager  # to make mypy happy
            self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

            self.guard_manager.root.add_torch_function_mode_stack_guard(
                self.torch_function_mode_stack,
                list(IGNORED_MODES),
                ["___check_torch_function_mode_stack()"],
            )
            # Clear references to torch_function modes held in the list
@@ -2645,16 +2656,14 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
    types = [type(x) for x in intial_stack]
    from .variables.torch_function import IGNORED_MODES

    def check_torch_function_mode_stack():
        cur_stack = get_torch_function_mode_stack()

        if len(cur_stack) != len(types):
            return False

        for ty, mode in zip(types, cur_stack):
            if ty in IGNORED_MODES:
                continue
            if ty != type(mode):
                return False


@@ -78,7 +78,6 @@ from .utils import (
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_torch_function_mode_stack,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
@@ -250,6 +249,7 @@ class OutputGraph:
        local_scope: Scope,
        global_scope: Scope,
        f_code,
        torch_function_mode_stack,
    ):
        super().__init__()
        self.tracers = [SubgraphTracer(self, export_root=export)]
@@ -368,7 +368,7 @@ class OutputGraph:
        # This returns False if TF overall (both mode and subclass) is disabled OR the TF mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
        # This records the initial torch function mode stack for guarding
        self.torch_function_mode_stack = get_torch_function_mode_stack()
        self.torch_function_mode_stack = torch_function_mode_stack

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fallback to eager
@@ -1020,7 +1020,7 @@ class OutputGraph:
            prefix_insts.clear()

        for block in reversed(tx.block_stack):
            block.exit(tx)
            block.exit(tx, is_graph_break=reason.graph_break)

        self.cleanup_graph()
        tx.prune_dead_locals()

@@ -25,6 +25,26 @@ if TYPE_CHECKING:
        sys as sys,
    )

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across graph breaks.
# Today, entering a TorchFunctionMode (for the classes we support) simply
# pushes the mode onto the stack. Since the stack is mutated by this and we
# replay those mutations, no cleanup logic needs to run once the graph break
# occurs: we simply replay the mutations to ensure the torch function mode
# stack is correct at the graph break, and reconstruct the stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly in the resume function, we need to
# re-enter the contexts as we do other contexts. These contexts do nothing on
# enter, but provide the correct exit logic to ensure the stack state is
# correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass


def index(iterator, item, start=0, end=None):
    from itertools import islice

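To make the comment block above concrete, here is a sketch of the enter/exit asymmetry (illustrative only; `MyMode` is hypothetical, and the snippet assumes the inherited `__exit__` pops whichever mode is on top of the stack, which is the behavior the class above relies on):

```python
# Conceptual sketch, not code from the PR.
import torch
from torch.overrides import BaseTorchFunctionMode, _get_current_function_mode_stack


class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    # mirrors the class above: entering does not push, exiting still pops
    def __enter__(self):
        pass


class MyMode(BaseTorchFunctionMode):  # hypothetical user mode
    pass


mode = MyMode()
mode.__enter__()  # stands in for the mode pushed by the original frame before a graph break
print(len(_get_current_function_mode_stack()))  # 1

with NoEnterTorchFunctionMode():  # resume-function side: no extra push on enter ...
    torch.add(torch.ones(2, 2), 1)
print(len(_get_current_function_mode_stack()))  # ... but the exit popped the active mode: 0
```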
@@ -6,6 +6,7 @@ import types
from typing import Any, cast, Dict, List, Optional, Tuple

from .bytecode_transformation import (
    add_push_null,
    create_call_function,
    create_call_method,
    create_dup_top,
@@ -48,6 +49,109 @@ class ReenterWith:
    stack_index: int
    target_values: Optional[Tuple[Any, ...]] = None

    def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
        """
        Codegen based off of:
        try:
            (rest)
        except:
            (restore previous stack)

        """
        from .variables.torch_function import get_prev_stack_var_name

        except_jump_target = create_instruction(
            "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
        )
        cleanup_complete_jump_target = create_instruction("NOP")

        setup_finally: List[Instruction] = []

        if sys.version_info < (3, 11):
            setup_finally.append(
                create_instruction("SETUP_FINALLY", target=except_jump_target)
            )
        else:
            exn_tab_begin = create_instruction("NOP")
            exn_tab_end = create_instruction("NOP")
            exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
                exn_tab_begin,
                exn_tab_end,
                except_jump_target,
                self.stack_index + 1,
                False,
            )
            setup_finally.append(exn_tab_begin)

        def create_reset():
            insts = [
                create_instruction(
                    "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
                ),
                create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
            ]
            add_push_null(insts)
            return [
                *insts,
                create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()),
                *create_call_function(1, False),
                create_instruction("POP_TOP"),
            ]

        if sys.version_info < (3, 9):
            epilogue = [
                create_instruction("POP_BLOCK"),
                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
                except_jump_target,
                *create_reset(),
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                *create_reset(),
                create_instruction("RAISE_VARARGS", argval=0),
                create_instruction("POP_EXCEPT", argval=0),
                create_instruction("END_FINALLY"),
                cleanup_complete_jump_target,
            ]
        elif sys.version_info < (3, 11):
            epilogue = [
                create_instruction("POP_BLOCK"),
                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
                except_jump_target,
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                create_instruction("POP_TOP"),
                *create_reset(),
                create_instruction("RAISE_VARARGS", argval=0),
                create_instruction("POP_EXCEPT", argval=0),
                cleanup_complete_jump_target,
            ]
        else:
            finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
            finally_exn_tab_target = create_instruction("COPY", arg=3)
            except_jump_target.exn_tab_entry = InstructionExnTabEntry(
                except_jump_target,
                finally_exn_tab_end,
                finally_exn_tab_target,
                self.stack_index + 2,
                True,
            )
            epilogue = [
                exn_tab_end,
                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
                except_jump_target,  # PUSH_EXC_INFO
                create_instruction("POP_TOP"),
                *create_reset(),
                finally_exn_tab_end,
                finally_exn_tab_target,  # COPY 3
                create_instruction("POP_EXCEPT"),
                create_instruction("RERAISE", arg=1),  # RERAISE 1
                cleanup_complete_jump_target,
            ]

        cleanup[:] = epilogue + cleanup
        return setup_finally

    # If we do not want to destroy the stack, we can do the same thing as a
    # `SETUP_WITH` block, only that we store the context manager in a local_symbol
    def try_except(self, code_options, cleanup: List[Instruction]):

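Read together with the SideEffects codegen in the next hunk (which saves the previous mode stack into the local named by get_prev_stack_var_name() and rebuilds the traced stack), the emitted bytecode corresponds roughly to this source-level pattern (an approximation for readability, not literal codegen output; `resumed_frame` and `rest` are hypothetical names):

```python
# Approximate source-level shape of the generated prologue and exception handler.
from torch._dynamo.utils import (
    get_torch_function_mode_stack,
    set_torch_function_mode_stack,
)


def resumed_frame(traced_modes, rest):
    prev_stack = get_torch_function_mode_stack()   # stored under get_prev_stack_var_name()
    set_torch_function_mode_stack(traced_modes)    # rebuild the mode stack recorded at the break
    try:
        return rest()                              # "(rest)" in the docstring above
    except BaseException:
        set_torch_function_mode_stack(prev_stack)  # "(restore previous stack)"
        raise
```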
@@ -593,11 +593,22 @@ class SideEffects:
            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                # Needed in the finally block for stack restoration
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "get_torch_function_mode_stack"
                    )
                )
                cg.call_function(0, False)
                name = variables.torch_function.get_prev_stack_var_name()
                cg.code_options["co_varnames"] += (name,)
                cg.append_output(create_instruction("STORE_FAST", argval=name))
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )

                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))

@@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source):
    ind: int

    def name(self):
        return ""
        return f"___get_torch_function_mode_stack_at({self._get_index()})"

    def _get_index(self):
        from .variables.torch_function import TorchFunctionModeStackVariable

@@ -19,20 +19,7 @@ import traceback
import types
import typing
import weakref
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from unittest.mock import patch

import torch
@@ -72,14 +59,12 @@ from .source import (
    GlobalWeakRefSource,
    LocalSource,
    Source,
    TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
    counters,
    get_fake_value,
    get_instruction_source_311,
    get_torch_function_mode_stack,
    graph_break_dup_warning_checker,
    istype,
    LazyString,
@@ -120,11 +105,10 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable


if TYPE_CHECKING:
    from .variables.torch_function import TorchFunctionModeVariable

from .variables.torch_function import (
    SymbolicTorchFunctionState,
    TorchFunctionModeVariable,
)
from .variables.user_defined import (
    RemovableHandleVariable,
    UserDefinedClassVariable,
@@ -283,9 +267,12 @@ class BlockStackEntry:
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
    def exit(self, tx, is_graph_break):
        assert self.with_context is not None
        return self.with_context.exit(tx)
        if (
            is_graph_break and self.with_context.exit_on_graph_break()
        ) or not is_graph_break:
            return self.with_context.exit(tx)


class ReturnValueOp(Exception):
@@ -651,8 +638,17 @@ def break_graph_if_unsupported(*, push):
            cleanup: List[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                # Don't exit any modes we have entered,
                # output bytecode will mutate the tf mode stack accordingly
                if isinstance(b.with_context, TorchFunctionModeVariable):
                    cg.extend_output(
                        b.resume_fn().try_except_torch_function_mode(
                            cg.code_options, cleanup
                        )
                    )
                    continue
                assert b.with_context is not None
                assert isinstance(b.with_context, ContextWrappingVariable)
                assert isinstance(b.with_context, (ContextWrappingVariable))
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
@@ -728,7 +724,7 @@ class InstructionTranslatorBase(
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
    symbolic_torch_function_state: SymbolicTorchFunctionState
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
@@ -2305,7 +2301,10 @@ class InstructionTranslatorBase(
        ):
            unimplemented(f"{inst.opname} {ctx}")

        if isinstance(ctx, GenericContextWrappingVariable):
        if (
            isinstance(ctx, GenericContextWrappingVariable)
            and not ctx.supports_graph_breaks()
        ):
            self.generic_context_manager_depth += 1

        # Need this redundant check for mypy
@@ -2548,7 +2547,7 @@ class InstructionTranslatorBase(
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
@@ -2563,7 +2562,7 @@ class InstructionTranslatorBase(
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
        self.symbolic_torch_function_state = symbolic_torch_function_state
        self.stack = []
        # stack of variable names for tracking 3.13 closures
        self.name_stack: list[Any] = []
@@ -2652,6 +2651,7 @@ class InstructionTranslator(InstructionTranslatorBase):
        f_locals,
        f_globals,
        f_builtins,
        torch_function_mode_stack,
        code_options,
        compiler_fn,
        one_graph,
@@ -2677,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase):
                local_scope=f_locals,
                global_scope=f_globals,
                f_code=f_code,
                torch_function_mode_stack=torch_function_mode_stack,
            ),
            instructions=instructions,
            f_locals=f_locals,
@@ -2686,7 +2687,7 @@ class InstructionTranslator(InstructionTranslatorBase):
            symbolic_locals={},  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            symbolic_torch_function_mode_stack=collections.deque(),
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
            f_code=f_code,
            export=export,
            inline_depth=0,
@@ -2721,7 +2722,9 @@ class InstructionTranslator(InstructionTranslatorBase):
                if k in f_locals
            }

            self._init_torch_function_mode_stack()
            self.symbolic_torch_function_state = SymbolicTorchFunctionState(
                torch_function_mode_stack
            )

            self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
            if export:
@@ -2762,29 +2765,6 @@ class InstructionTranslator(InstructionTranslatorBase):
            )
            unimplemented(msg)

    def _init_torch_function_mode_stack(self):
        from .variables.torch_function import TorchFunctionModeStackVariable

        TorchFunctionModeStackVariable.reset()

        self.symbolic_torch_function_mode_stack: Deque[
            TorchFunctionModeVariable
        ] = collections.deque()
        # We want to retrieve all modes to properly reconstruct the stack if needed
        py_stack = get_torch_function_mode_stack(filter_ignored=False)

        if py_stack:
            has_device_context = isinstance(
                py_stack[0], torch.utils._device.DeviceContext
            )

        for i, val in enumerate(py_stack):
            self.symbolic_torch_function_mode_stack.append(
                variables.LazyVariableTracker.create(
                    val, source=TorchFunctionModeStackSource(i)
                )
            )

    def get_example_value(self, source: Source):
        if isinstance(source, LocalSource):
            return self.f_locals[source.local_name]
@@ -3116,7 +3096,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_mode_stack,
                parent.symbolic_torch_function_state,
                closure_cells,
                func,
            )
@@ -3126,7 +3106,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_mode_stack,
                parent.symbolic_torch_function_state,
                closure_cells,
                func,
            )
@@ -3179,7 +3159,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ) -> None:
@@ -3196,7 +3176,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
            symbolic_torch_function_state=symbolic_torch_function_state,
            instructions=instructions,
            code_options={k: getattr(code, k) for k in get_code_keys()},
            f_code=code,

@@ -163,6 +163,7 @@ def debug_insert_nops(
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
        torch_function_mode_stack=[],
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))

@@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
    "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
    "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
    "torch.set_default_device": UserFunctionVariable,
    "torch.sparse_bsc_tensor": SkipFunctionVariable,
    "torch.sparse_bsr_tensor": SkipFunctionVariable,
    "torch.sparse_csc_tensor": SkipFunctionVariable,
@@ -2797,7 +2798,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
        "torch.random.initial_seed",
        "torch.random.seed",
        "torch.return_types.pytree_register_structseq",
        "torch.set_default_device",
        "torch.set_default_dtype",
        "torch.set_default_tensor_type",
        "torch.set_deterministic_debug_mode",
@@ -3258,6 +3258,7 @@ MOD_INLINELIST = [
    "torch.testing",
    "torch.utils._content_store",
    "torch.utils._contextlib",
    "torch.utils._device",
    "torch.utils._foreach_utils",
    "torch.utils._python_dispatch",
    "torch.utils._pytree",
@@ -3592,7 +3593,9 @@ def lookup_inner(
            if reasons is not None:
                reasons.add("func name is patched_init")
            return SkipFunctionVariable
        elif name == "__torch_function__":
        elif name == "__torch_function__" or (
            obj and obj.__name__ == "__torch_function__"
        ):
            if reasons is not None:
                reasons.add("func name is __torch_function__")
            return UserFunctionVariable

@@ -63,7 +63,6 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
    _get_function_stack_at,
    _instruction_counter,
    _len_torch_function_stack,
    _pop_torch_function_stack,
@@ -3084,14 +3083,10 @@ def is_parameter_freezing():
    return torch._inductor.config.freezing and not torch.is_grad_enabled()


def get_torch_function_mode_stack(filter_ignored=True):
    from .variables.torch_function import IGNORED_MODES

    stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
    if filter_ignored:
        stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]

    return stack
def get_torch_function_mode_stack():
    return [
        get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
    ]


def get_torch_function_mode_stack_at(ind):
@@ -3107,6 +3102,11 @@ def set_torch_function_mode_stack(stack):
        _push_on_torch_function_stack(mode)


def clear_torch_function_mode_stack():
    for i in range(_len_torch_function_stack()):
        _pop_torch_function_stack()


def verify_guard_fn_signature(value):
    fn = value.__metadata_guard__
    sig = inspect.signature(fn)

@@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
    build_torch_function_fn,
    TensorWithTFOverrideVariable,
    torch_function_mode_stack_state_mgr,
    TorchFunctionModeVariable,
)
from .user_defined import (
@@ -1668,15 +1669,16 @@ class VariableBuilder:
                # but warning is not the end of the world
                assert isinstance(value.base, np.nditer)

        try:
            tensor_value = _util._try_convert_to_tensor(value)
            if readonly:
                from torch._prims_common import clone_preserve_strides
        with torch_function_mode_stack_state_mgr.temp_restore_stack():
            try:
                tensor_value = _util._try_convert_to_tensor(value)
                if readonly:
                    from torch._prims_common import clone_preserve_strides

                tensor_value = clone_preserve_strides(tensor_value)
        except NotImplementedError as e:
            # failed to convert to tensor, graph break
            unimplemented(str(e))
                    tensor_value = clone_preserve_strides(tensor_value)
            except NotImplementedError as e:
                # failed to convert to tensor, graph break
                unimplemented(str(e))

        # We do this because we want the full behavior of guarding the numpy ndarray as if it were
        # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here

@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
 | 
			
		||||
        if isinstance(args[0], UserFunctionVariable):
 | 
			
		||||
            return WrappedUserFunctionVariable(args[0], self)
 | 
			
		||||
 | 
			
		||||
    def supports_graph_breaks(self):
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def exit_on_graph_break(self):
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GenericContextWrappingVariable(UserDefinedObjectVariable):
 | 
			
		||||
    # Some methods in ContextWrappingVariable assumes the arguments are
 | 
			
		||||
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
 | 
			
		||||
        tx.generic_context_manager_depth -= 1
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def supports_graph_breaks(self):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def exit_on_graph_break(self):
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
 | 
			
		||||
    """represents torch grad requries grad"""
 | 
			
		||||
@ -637,6 +649,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
 | 
			
		||||
 | 
			
		||||
    def _call_func(self, tx: "InstructionTranslator", values):
 | 
			
		||||
        assert len(values) == 1
 | 
			
		||||
        tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
 | 
			
		||||
        tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
 | 
			
		||||
        tx.output.set_torch_function_state(values[0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -156,6 +156,25 @@ tracing_state_functions = {
 | 
			
		||||
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@functools.lru_cache(None)
 | 
			
		||||
def get_overridable_functions():
 | 
			
		||||
    from itertools import chain
 | 
			
		||||
 | 
			
		||||
    from torch.overrides import get_overridable_functions as get_overridable_functions_
 | 
			
		||||
 | 
			
		||||
    funcs = set(chain(*get_overridable_functions_().values()))
 | 
			
		||||
    more = {
 | 
			
		||||
        torch.ones,
 | 
			
		||||
        torch.ones_like,
 | 
			
		||||
        torch.zeros,
 | 
			
		||||
        torch.zeros_like,
 | 
			
		||||
        torch.empty,
 | 
			
		||||
        torch.full,
 | 
			
		||||
    }
 | 
			
		||||
    funcs.update(more)
 | 
			
		||||
    return funcs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseTorchVariable(VariableTracker):
 | 
			
		||||
    """common base for all torch.* functions, classes, modules and other things"""
 | 
			
		||||
 | 
			
		||||
@ -806,10 +825,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
            self, tx: "InstructionTranslator", *args, **kwargs
 | 
			
		||||
        ):
 | 
			
		||||
            assert not args and not kwargs
 | 
			
		||||
            if not tx.symbolic_torch_function_mode_stack:
 | 
			
		||||
            if not tx.symbolic_torch_function_state.mode_stack:
 | 
			
		||||
                raise unimplemented("Popping from an empty torch function mode stack")
 | 
			
		||||
            TorchFunctionModeStackVariable.register_mutation(tx)
 | 
			
		||||
            return tx.symbolic_torch_function_mode_stack.pop()
 | 
			
		||||
            return tx.symbolic_torch_function_state.pop_torch_function_mode()
 | 
			
		||||
 | 
			
		||||
        @register(torch._C._push_on_torch_function_stack)
 | 
			
		||||
        def handle_push_torch_function(
 | 
			
		||||
@ -817,7 +836,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
        ):
 | 
			
		||||
            assert len(args) == 1 and not kwargs
 | 
			
		||||
            TorchFunctionModeStackVariable.register_mutation(tx)
 | 
			
		||||
            tx.symbolic_torch_function_mode_stack.append(args[0])
 | 
			
		||||
            tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
 | 
			
		||||
            return ConstantVariable.create(None)
 | 
			
		||||
 | 
			
		||||
        @register(torch._C._len_torch_function_stack)
 | 
			
		||||
@ -825,7 +844,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
            self, tx: "InstructionTranslator", *args, **kwargs
 | 
			
		||||
        ):
 | 
			
		||||
            assert not args and not kwargs
 | 
			
		||||
            return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
 | 
			
		||||
            return ConstantVariable.create(
 | 
			
		||||
                len(tx.symbolic_torch_function_state.mode_stack)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        @register(torch._C._get_function_stack_at)
 | 
			
		||||
        def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
 | 
			
		||||
            assert len(args) == 1 and not kwargs
 | 
			
		||||
            ind = args[0].as_python_constant()
 | 
			
		||||
            assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
 | 
			
		||||
            return tx.symbolic_torch_function_state.mode_stack[ind]
 | 
			
		||||
 | 
			
		||||
        @register(torch.set_default_device)
 | 
			
		||||
        def handle_set_default_device(
 | 
			
		||||
@ -844,7 +872,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
            else:
 | 
			
		||||
                TorchFunctionModeStackVariable.register_device_context_insertion(tx)
 | 
			
		||||
 | 
			
		||||
            return None
 | 
			
		||||
            return ConstantVariable.create(None)
 | 
			
		||||
 | 
			
		||||
        return handlers
 | 
			
		||||
 | 
			
		||||
@ -857,6 +885,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
        from . import ConstantVariable, SymNodeVariable, TensorVariable
 | 
			
		||||
        from .builder import wrap_fx_proxy
 | 
			
		||||
 | 
			
		||||
        if self.torch_function_override_enabled(tx, args, kwargs):
 | 
			
		||||
            return dispatch_torch_function(tx, self, args, kwargs)
 | 
			
		||||
 | 
			
		||||
        if self.can_constant_fold_through() and check_unspec_or_constant_args(
 | 
			
		||||
            args, kwargs
 | 
			
		||||
        ):
 | 
			
		||||
@ -878,147 +909,144 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
            if result:
 | 
			
		||||
                return result
 | 
			
		||||
 | 
			
		||||
        if can_dispatch_torch_function(tx, args, kwargs):
 | 
			
		||||
            return dispatch_torch_function(tx, self, args, kwargs)
 | 
			
		||||
        else:
 | 
			
		||||
            any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
 | 
			
		||||
        any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
 | 
			
		||||
 | 
			
		||||
            all_ints_or_floats = all(
 | 
			
		||||
                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
 | 
			
		||||
                for x in args
 | 
			
		||||
            )
 | 
			
		||||
            if (
 | 
			
		||||
                getattr(self.value, "__module__", "") == "torch"
 | 
			
		||||
                and self.value.__name__ in bin_ops
 | 
			
		||||
                and any_symints_or_symfloats
 | 
			
		||||
                and all_ints_or_floats
 | 
			
		||||
            ):
 | 
			
		||||
                msg = f"""\
 | 
			
		||||
        all_ints_or_floats = all(
 | 
			
		||||
            isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
 | 
			
		||||
            for x in args
 | 
			
		||||
        )
 | 
			
		||||
        if (
 | 
			
		||||
            getattr(self.value, "__module__", "") == "torch"
 | 
			
		||||
            and self.value.__name__ in bin_ops
 | 
			
		||||
            and any_symints_or_symfloats
 | 
			
		||||
            and all_ints_or_floats
 | 
			
		||||
        ):
 | 
			
		||||
            msg = f"""\
 | 
			
		||||
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
 | 
			
		||||
To support this behavior, we need to allow const-propping tensors that store symint data.
 | 
			
		||||
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
 | 
			
		||||
"""
 | 
			
		||||
                log.warning(msg)
 | 
			
		||||
                unimplemented(msg)
 | 
			
		||||
            log.warning(msg)
 | 
			
		||||
            unimplemented(msg)
 | 
			
		||||
 | 
			
		||||
            # TODO(voz): Replace w/ dynamic shape rewrite table.
 | 
			
		||||
            # Ideally, we would be able to do this at ctor time, but alas we need a combination
 | 
			
		||||
            # of value + args to determine this.
 | 
			
		||||
            fn_ = self.value
 | 
			
		||||
            if any_symints_or_symfloats:
 | 
			
		||||
                torch_sym_op = f"_sym_{self.value.__name__}"
 | 
			
		||||
                if getattr(self.value, "__module__", None) == "math" and hasattr(
 | 
			
		||||
                    torch, torch_sym_op
 | 
			
		||||
                ):
 | 
			
		||||
                    fn_ = getattr(torch, torch_sym_op)
 | 
			
		||||
        # TODO(voz): Replace w/ dynamic shape rewrite table.
 | 
			
		||||
        # Ideally, we would be able to do this at ctor time, but alas we need a combination
 | 
			
		||||
        # of value + args to determine this.
 | 
			
		||||
        fn_ = self.value
 | 
			
		||||
        if any_symints_or_symfloats:
 | 
			
		||||
            torch_sym_op = f"_sym_{self.value.__name__}"
 | 
			
		||||
            if getattr(self.value, "__module__", None) == "math" and hasattr(
 | 
			
		||||
                torch, torch_sym_op
 | 
			
		||||
            ):
 | 
			
		||||
                fn_ = getattr(torch, torch_sym_op)
 | 
			
		||||
 | 
			
		||||
            fake_out_shape = None
 | 
			
		||||
            if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
 | 
			
		||||
                # Calling fake tensor propagation can mutate the out= tensor in
 | 
			
		||||
                # tx.output.tracked_fakes. tracked_fakes are used to apply
 | 
			
		||||
                # symbolic_shape guards. Mutating them destroys the information
 | 
			
		||||
                # prior to tracing, which is essential for creating right
 | 
			
		||||
                # guards. So save the shape now, and check later if it has
 | 
			
		||||
                # changed. If it has, graph break.
 | 
			
		||||
                fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
 | 
			
		||||
        fake_out_shape = None
 | 
			
		||||
        if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
 | 
			
		||||
            # Calling fake tensor propagation can mutate the out= tensor in
 | 
			
		||||
            # tx.output.tracked_fakes. tracked_fakes are used to apply
 | 
			
		||||
            # symbolic_shape guards. Mutating them destroys the information
 | 
			
		||||
            # prior to tracing, which is essential for creating right
 | 
			
		||||
            # guards. So save the shape now, and check later if it has
 | 
			
		||||
            # changed. If it has, graph break.
 | 
			
		||||
            fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
 | 
			
		||||
 | 
			
		||||
            tensor_variable = wrap_fx_proxy(
 | 
			
		||||
                tx=tx,
 | 
			
		||||
                proxy=tx.output.create_proxy(
 | 
			
		||||
                    "call_function",
 | 
			
		||||
                    fn_,
 | 
			
		||||
                    *proxy_args_kwargs(args, kwargs),
 | 
			
		||||
                ),
 | 
			
		||||
        tensor_variable = wrap_fx_proxy(
 | 
			
		||||
            tx=tx,
 | 
			
		||||
            proxy=tx.output.create_proxy(
 | 
			
		||||
                "call_function",
 | 
			
		||||
                fn_,
 | 
			
		||||
                *proxy_args_kwargs(args, kwargs),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            isinstance(tensor_variable, TensorVariable)
 | 
			
		||||
            and "requires_grad" in kwargs
 | 
			
		||||
            and kwargs["requires_grad"].as_python_constant()
 | 
			
		||||
        ):
 | 
			
		||||
            unimplemented(
 | 
			
		||||
                """factory functions that return tensors that require grad are not supported.
 | 
			
		||||
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                isinstance(tensor_variable, TensorVariable)
 | 
			
		||||
                and "requires_grad" in kwargs
 | 
			
		||||
                and kwargs["requires_grad"].as_python_constant()
 | 
			
		||||
            ):
 | 
			
		||||
                unimplemented(
 | 
			
		||||
                    """factory functions that return tensors that require grad are not supported.
 | 
			
		||||
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if "out" in kwargs and not (
 | 
			
		||||
                isinstance(kwargs["out"], variables.ConstantVariable)
 | 
			
		||||
                and kwargs["out"].as_python_constant() is None
 | 
			
		||||
            ):
 | 
			
		||||
                # out variants of torch operators like torch.sort and
 | 
			
		||||
                # torch.sigmoid mutate the tensors in the out field. Track such
 | 
			
		||||
                # tensors and rewrite the symbolic locals.
 | 
			
		||||
                if isinstance(tensor_variable, TupleVariable):
 | 
			
		||||
                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
 | 
			
		||||
                    output_tensor_names = [
 | 
			
		||||
                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
 | 
			
		||||
                    ]
 | 
			
		||||
                    for idx, name in enumerate(output_tensor_names):
 | 
			
		||||
                        if name in tx.symbolic_locals:
 | 
			
		||||
                            tx.symbolic_locals[name] = tensor_variable.items[idx]
 | 
			
		||||
                    for out_tensor, result_tensor in zip(
 | 
			
		||||
                        kwargs["out"].items, tensor_variable.items
 | 
			
		||||
                    ):
 | 
			
		||||
                        if (
 | 
			
		||||
                            out_tensor.source
 | 
			
		||||
                            and out_tensor in tx.output.graphargs
 | 
			
		||||
                            and isinstance(out_tensor, variables.TensorVariable)
 | 
			
		||||
                            and isinstance(result_tensor, variables.TensorVariable)
 | 
			
		||||
                            and out_tensor.size != result_tensor.size
 | 
			
		||||
                        ):
 | 
			
		||||
                            # It's hard to get out variants with resizing on graph inputs work
 | 
			
		||||
                            # properly across dynamo/aot/inductor, just fall back.
 | 
			
		||||
                            unimplemented("out variants with resizing on graph inputs")
 | 
			
		||||
                elif isinstance(tensor_variable, TensorVariable):
 | 
			
		||||
                    assert isinstance(kwargs["out"], TensorVariable)
 | 
			
		||||
                    assert "example_value" in kwargs["out"].proxy.node.meta
 | 
			
		||||
                    fake_tensor = tensor_variable.proxy.node.meta["example_value"]
 | 
			
		||||
                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
 | 
			
		||||
        if "out" in kwargs and not (
 | 
			
		||||
            isinstance(kwargs["out"], variables.ConstantVariable)
 | 
			
		||||
            and kwargs["out"].as_python_constant() is None
 | 
			
		||||
        ):
 | 
			
		||||
            # out variants of torch operators like torch.sort and
 | 
			
		||||
            # torch.sigmoid mutate the tensors in the out field. Track such
 | 
			
		||||
            # tensors and rewrite the symbolic locals.
 | 
			
		||||
            if isinstance(tensor_variable, TupleVariable):
 | 
			
		||||
                assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
 | 
			
		||||
                output_tensor_names = [
 | 
			
		||||
                    tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
 | 
			
		||||
                ]
 | 
			
		||||
                for idx, name in enumerate(output_tensor_names):
 | 
			
		||||
                    if name in tx.symbolic_locals:
 | 
			
		||||
                        tx.symbolic_locals[name] = tensor_variable.items[idx]
 | 
			
		||||
                for out_tensor, result_tensor in zip(
 | 
			
		||||
                    kwargs["out"].items, tensor_variable.items
 | 
			
		||||
                ):
 | 
			
		||||
                    if (
 | 
			
		||||
                        kwargs["out"].source
 | 
			
		||||
                        and kwargs["out"] in tx.output.graphargs
 | 
			
		||||
                        and fake_out_shape != fake_tensor.shape
 | 
			
		||||
                        out_tensor.source
 | 
			
		||||
                        and out_tensor in tx.output.graphargs
 | 
			
		||||
                        and isinstance(out_tensor, variables.TensorVariable)
 | 
			
		||||
                        and isinstance(result_tensor, variables.TensorVariable)
 | 
			
		||||
                        and out_tensor.size != result_tensor.size
 | 
			
		||||
                    ):
 | 
			
		||||
                        # It's hard to get out variants with resizing on graph inputs work
 | 
			
		||||
                        # properly across dynamo/aot/inductor, just fall back.
 | 
			
		||||
                        unimplemented("out variants with resizing on graph inputs")
 | 
			
		||||
            elif isinstance(tensor_variable, TensorVariable):
 | 
			
		||||
                assert isinstance(kwargs["out"], TensorVariable)
 | 
			
		||||
                assert "example_value" in kwargs["out"].proxy.node.meta
 | 
			
		||||
                fake_tensor = tensor_variable.proxy.node.meta["example_value"]
 | 
			
		||||
                fake_out = kwargs["out"].proxy.node.meta["example_value"]
 | 
			
		||||
                if (
 | 
			
		||||
                    kwargs["out"].source
 | 
			
		||||
                    and kwargs["out"] in tx.output.graphargs
 | 
			
		||||
                    and fake_out_shape != fake_tensor.shape
 | 
			
		||||
                ):
 | 
			
		||||
                    # It's hard to get out variants with resizing on graph inputs work
 | 
			
		||||
                    # properly across dynamo/aot/inductor, just fall back.
 | 
			
		||||
                    unimplemented("out variants with resizing on graph inputs")
 | 
			
		||||
                if not torch._prims_common.is_contiguous(fake_out):
 | 
			
		||||
                    # It's difficult to handle strides correctly in functionalization
 | 
			
		||||
                    # when calling an out= op with a non-contiguous out argument
 | 
			
		||||
                    unimplemented(
 | 
			
		||||
                        "out= op was called where output tensor was non-contiguous"
 | 
			
		||||
                    )
 | 
			
		||||
                name = tx.find_symbolic_locals_name(kwargs["out"])
 | 
			
		||||
                if name in tx.symbolic_locals:
 | 
			
		||||
                    tx.symbolic_locals[name] = tensor_variable
 | 
			
		||||
            elif (
 | 
			
		||||
                isinstance(tensor_variable, ConstantVariable)
 | 
			
		||||
                and tensor_variable.value is None
 | 
			
		||||
            ):
 | 
			
		||||
                # Handle out-variant custom ops that return None.
 | 
			
		||||
                if isinstance(kwargs["out"], TensorVariable):
 | 
			
		||||
                    assert "example_value" in kwargs["out"].proxy.node.meta
 | 
			
		||||
                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
 | 
			
		||||
                    if not torch._prims_common.is_contiguous(fake_out):
 | 
			
		||||
                        # It's difficult to handle strides correctly in functionalization
 | 
			
		||||
                        # when calling an out= op with a non-contiguous out argument
 | 
			
		||||
                        unimplemented(
 | 
			
		||||
                            "out= op was called where output tensor was non-contiguous"
 | 
			
		||||
                        )
 | 
			
		||||
                    name = tx.find_symbolic_locals_name(kwargs["out"])
 | 
			
		||||
                    if name in tx.symbolic_locals:
 | 
			
		||||
                        tx.symbolic_locals[name] = tensor_variable
 | 
			
		||||
                elif (
 | 
			
		||||
                    isinstance(tensor_variable, ConstantVariable)
 | 
			
		||||
                    and tensor_variable.value is None
 | 
			
		||||
                ):
 | 
			
		||||
                    # Handle out-variant custom ops that return None.
 | 
			
		||||
                    if isinstance(kwargs["out"], TensorVariable):
 | 
			
		||||
                        assert "example_value" in kwargs["out"].proxy.node.meta
 | 
			
		||||
                        fake_out = kwargs["out"].proxy.node.meta["example_value"]
 | 
			
		||||
                elif isinstance(kwargs["out"], ListVariable):
 | 
			
		||||
                    for idx, x in enumerate(kwargs["out"].items):
 | 
			
		||||
                        assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
 | 
			
		||||
                        fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
 | 
			
		||||
                        if not torch._prims_common.is_contiguous(fake_out):
 | 
			
		||||
                            # It's difficult to handle strides correctly in functionalization
 | 
			
		||||
                            # when calling an out= op with a non-contiguous out argument
 | 
			
		||||
                            unimplemented(
 | 
			
		||||
                                "out= op was called where output tensor was non-contiguous"
 | 
			
		||||
                                "out= op was called where some of the output tensors were non-contiguous"
 | 
			
		||||
                            )
 | 
			
		||||
                    elif isinstance(kwargs["out"], ListVariable):
 | 
			
		||||
                        for idx, x in enumerate(kwargs["out"].items):
 | 
			
		||||
                            assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
 | 
			
		||||
                            fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
 | 
			
		||||
                            if not torch._prims_common.is_contiguous(fake_out):
 | 
			
		||||
                                # It's difficult to handle strides correctly in functionalization
 | 
			
		||||
                                # when calling an out= op with a non-contiguous out argument
 | 
			
		||||
                                unimplemented(
 | 
			
		||||
                                    "out= op was called where some of the output tensors were non-contiguous"
 | 
			
		||||
                                )
 | 
			
		||||
                else:
 | 
			
		||||
                    unimplemented(f"out variant of {type(kwargs['out'])}")
 | 
			
		||||
            else:
 | 
			
		||||
                unimplemented(f"out variant of {type(kwargs['out'])}")
 | 
			
		||||
 | 
			
		||||
            return tensor_variable
 | 
			
		||||
        return tensor_variable
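
The comments above explain why out= tensors are tracked and why the saved fake_out_shape matters. A hedged sketch of the user-level pattern this path protects, assuming an eager backend; the exact graph-break behavior depends on contiguity and resizing as spelled out in those comments:

import torch

@torch.compile(backend="eager")
def fill_sigmoid(x, out):
    # torch.sigmoid mutates `out`; dynamo rewrites the corresponding symbolic
    # local so later reads observe the op's result rather than the stale input.
    torch.sigmoid(x, out=out)
    return out + 1

x = torch.randn(4)
out = torch.empty(4)
fill_sigmoid(x, out)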

    def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
        """inline behavior of torch.nn.modules.utils._ntuple"""
@ -1146,3 +1174,12 @@ Either create the tensor outside the compiled region, or do not set the tensor t
            source
        )
        return result

    def torch_function_override_enabled(self, tx, args, kwargs):
        return (
            self.get_function() in get_overridable_functions()
            or isinstance(
                self.get_function(),
                (torch._ops.OpOverload, torch._ops.OpOverloadPacket),
            )
        ) and can_dispatch_torch_function(tx, args, kwargs)

@ -1,20 +1,37 @@
# mypy: ignore-errors

import collections
import contextlib
import functools
import inspect
from typing import Dict, List, TYPE_CHECKING
from typing import Deque, Dict, List, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
    _get_overloaded_args,
    get_default_nowrap_functions,
    TorchFunctionMode,
)
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
    class_has_getattribute,
    clear_torch_function_mode_stack,
    get_safe_global_name,
    has_torch_function,
    is_tensor_base_attr_getter,
    set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
@ -52,11 +69,99 @@ banned_attrs = [
    if is_tensor_base_attr_getter(fn)
]

# Today set default device is placed in the graph and guarded on separately
# so we should not trace through it. In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}

@functools.lru_cache(None)
def get_prev_stack_var_name():
    from ..bytecode_transformation import unique_id

    return unique_id("___prev_torch_function_mode_stack")


# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
    def __init__(self):
        self.stack = []

    def __enter__(self):
        self.stack = torch.overrides._get_current_function_mode_stack()
        clear_torch_function_mode_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        set_torch_function_mode_stack(self.stack)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        prev = torch.overrides._get_current_function_mode_stack()
        set_torch_function_mode_stack(self.stack)
        try:
            yield
        finally:
            set_torch_function_mode_stack(prev)


torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
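
A short sketch of how the manager above is meant to be used: the compilation entry point clears the eager mode stack for the duration of tracing, and temp_restore_stack re-applies the saved modes around operations, such as tensor conversion in the builder, that should still see them. The exact call site inside dynamo is an assumption here; only the usage shape is taken from the code above.

mgr = TorchFunctionModeStackStateManager()

with mgr:  # saves the current eager mode stack and clears it
    ...    # tracing proceeds with an empty eager stack
    with mgr.temp_restore_stack():
        ...  # briefly run eager code under the user's original modes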


class SymbolicTorchFunctionState:
    def __init__(self, py_stack):
        # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
        # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
        # These are their definitions:
        # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
        # (if either are entered, this will be False)
        # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
        # torch._C.DisableTorchFunction has been entered
        # To disambiguate these and keep myself sane I added a C API to check whether all torch function
        # concepts (modes and subclasses) are enabled.
        # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
        # the stack length from the enablement state of torch function modes.
        # This is important because now if a mode is pushed while dynamo is tracing, we know whether
        # or not torch function modes are enabled and whether we should trace it.
        self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()

        # This differs from the C API of the same name
        # this will only be false iff we have entered torch._C.DisableTorchFunction
        # and does not take into account the mode stack length, while the C API bundles these
        # two concepts
        self.torch_function_mode_enabled = (
            not torch._C._is_torch_function_all_disabled()
        )

        self.cur_mode = None

        TorchFunctionModeStackVariable.reset()

        self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()

        for i, val in enumerate(py_stack):
            self.mode_stack.append(
                LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
            )

    def in_torch_function_mode(self):
        return len(self.mode_stack) > 0

    def pop_torch_function_mode(self):
        return self.mode_stack.pop()

    def push_torch_function_mode(self, mode_var):
        self.mode_stack.append(mode_var)

    def call_torch_function_mode(self, tx, fn, types, args, kwargs):
        with self._pop_mode_for_inlining() as cur_mode:
            return cur_mode.call_torch_function(tx, fn, types, args, kwargs)

    @contextlib.contextmanager
    def _pop_mode_for_inlining(self):
        old_mode = self.cur_mode
        self.cur_mode = self.pop_torch_function_mode()
        try:
            yield self.cur_mode
        finally:
            mode = self.cur_mode
            self.cur_mode = old_mode
            self.push_torch_function_mode(mode)
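
The long comment above distinguishes subclass enablement from mode enablement. A hedged sketch of the eager behavior it describes; torch._C._is_torch_function_all_disabled is the C API added in this stack, so its name and semantics are taken from the diff rather than from a released PyTorch build:

import torch

assert torch._C._is_torch_function_enabled()        # neither disable knob entered

with torch._C.DisableTorchFunctionSubclass():
    # subclass dispatch is off, so the combined predicate is False...
    assert not torch._C._is_torch_function_enabled()
    # ...but modes are not "all disabled"
    assert not torch._C._is_torch_function_all_disabled()

with torch._C.DisableTorchFunction():
    # only this knob flips the new predicate
    assert torch._C._is_torch_function_all_disabled()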


class TorchFunctionModeStackVariable(VariableTracker):
@ -88,19 +193,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
    def register_mutation(cls, tx: "InstructionTranslator"):
        if cls.stack_value_singleton not in tx.output.side_effects:
            var = cls(
                source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
                source=Source(),
                symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
            )
            tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
            tx.output.side_effects.mutation(var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_mode_stack
        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            return
        else:
            cls.offset += 1
            tx.symbolic_torch_function_mode_stack.insert(
            stack.insert(
                0,
                TorchFunctionModeVariable(
                    None, source=TorchFunctionModeStackSource(-cls.offset)
@ -109,7 +215,7 @@ class TorchFunctionModeStackVariable(VariableTracker):

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_mode_stack
        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            stack.popleft()
            cls.offset -= 1
@ -123,24 +229,88 @@ class TorchFunctionModeStackVariable(VariableTracker):
        return ind + cls.offset


class TorchFunctionModeVariable(ContextWrappingVariable):
    def __init__(self, value, **kwargs):
        super().__init__(value, **kwargs)
        self.value = value

class TorchFunctionModeVariable(GenericContextWrappingVariable):
    @staticmethod
    def get_global_mangled_name(tx, val):
        return get_safe_global_name(
            tx, f"__torch_function_mode_{val.__class__.__name__}", val
    def is_supported_torch_function_mode(ty):
        # Supported in this sense means we can support graph breaks under the
        # context.
        # We are able to trace custom modes but if there are graph breaks under them
        # and they have a custom __enter__/__exit__ we don't handle this for the
        # same reason we don't handle generic context managers: there may be side effects
        # that are now affected by executing the funtion across two frames instead of one
        # Today we support the enter/exit of the default TorchFunctionMode as well as
        # DeviceContext (which is used for set_default_device)
        return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
            not class_has_getattribute(ty)
            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
        )
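
A hedged illustration of the criterion above: a mode that keeps TorchFunctionMode's default __enter__/__exit__ (and defines no __getattribute__) can be re-entered for the user around a graph break, while one with a custom __enter__ cannot, for the same reason generic context managers are not replayed.

import torch
from torch.overrides import TorchFunctionMode

class PlainMode(TorchFunctionMode):
    # default __enter__/__exit__ inherited from TorchFunctionMode -> supported
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

class LoggingEnterMode(TorchFunctionMode):
    # custom __enter__ has side effects dynamo cannot safely replay -> unsupported
    def __enter__(self):
        print("entering")
        return super().__enter__()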

    def __init__(self, value, source=None, **kwargs):
        if value is not None:
            super().__init__(value, **kwargs)
        self.value = value
        self.cm_obj = value  # needed for BC with calling enter from CM code
        self.source = source

    def reconstruct(self, codegen):
        # We don't support locally created torch function modes yet
        # This shouldn't be called unless we have a source
        assert self.source
        self.source.reconstruct(codegen)

    def _call_func(self, tx, values):
        unimplemented("torch function mode context manager is not supported yet")
    def module_name(self):
        return self.value.__module__

    def fn_name(self):
        return type(self.value).__name__

    def python_type(self):
        return type(self.value)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        return call_torch_function(
            tx,
            self,
            build_torch_function_fn(tx, self.value, self.source),
            fn,
            types,
            args,
            kwargs,
        )

    def enter(self, tx):
        from .torch import TorchInGraphFunctionVariable

        if isinstance(self.value, NoEnterTorchFunctionMode):
            return ConstantVariable.create(None)

        TorchInGraphFunctionVariable(
            torch._C._push_on_torch_function_stack
        ).call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        from .torch import TorchInGraphFunctionVariable

        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
            tx, [], {}
        )
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        ty = NoEnterTorchFunctionMode
        codegen(
            AttrSource(
                codegen.tx.import_source(ty.__module__),
                ty.__name__,
            )
        )

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return False


def _get_all_args(args, kwargs):
@ -231,9 +401,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    return tx.output.torch_function_enabled and any(
    has_overridden_args = any(
        has_torch_function(arg) for arg in _get_all_args(args, kwargs)
    )
    tf_state = tx.symbolic_torch_function_state
    return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
        tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
    )
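
With the rewrite above, dispatch is triggered either by an argument that overrides __torch_function__ (when subclass handling is enabled) or merely by having a mode on the traced stack. A hedged end-to-end sketch under an eager backend; it assumes the mode tracing introduced in this stack handles the pattern:

import torch
from torch.overrides import TorchFunctionMode

class NegateAdds(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        return -out if func is torch.add else out

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # no subclass arguments here; the active mode alone routes dispatch
    return torch.add(x, 1)

with NegateAdds():
    print(f(torch.zeros(2)))  # expected: tensor([-1., -1.])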


def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
@ -245,11 +419,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
        _get_subclass_type,
    )

    types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])

    if tx.symbolic_torch_function_state.in_torch_function_mode():
        res = tx.symbolic_torch_function_state.call_torch_function_mode(
            tx, fn, types, args, kwargs
        )
        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
            return res

    for arg in overloaded_args:
        res = arg.call_torch_function(
            tx,
            fn,
            TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
            types,
            args,
            kwargs,
        )

@ -9,6 +9,7 @@ import inspect
import itertools
import random
import sys
import threading
import types
import warnings
from typing import Dict, Generic, List, TYPE_CHECKING
@ -82,11 +83,6 @@ def is_forbidden_context_manager(ctx):
        from _pytest.python_api import RaisesContext
        from _pytest.recwarn import WarningsChecker

        # TODO mlazos: Temporary to get this stack to pass
        # remove in subsequent PR
        from torch.overrides import BaseTorchFunctionMode

        f_ctxs.append(BaseTorchFunctionMode)
        f_ctxs.append(RaisesContext)
        f_ctxs.append(WarningsChecker)
    except ImportError:
@ -413,15 +409,25 @@ class UserDefinedClassVariable(UserDefinedVariable):
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            # import here to avoid an unfortunate circular dependency.
            from torch.overrides import TorchFunctionMode

            from .ctx_manager import GenericContextWrappingVariable
            from .torch_function import TorchFunctionModeVariable

            if issubclass(
                self.value, TorchFunctionMode
            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
                self.value
            ):
                var_cls = TorchFunctionModeVariable
            else:
                var_cls = GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
                self.source, self.value, GenericContextWrappingVariable, {}
                self.source, self.value, var_cls, {}
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj
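
The branch above decides which variable class models a context manager built inside the compiled region; supported TorchFunctionMode subclasses now get TorchFunctionModeVariable rather than the generic wrapper. A hedged sketch of the intended user-visible effect:

import torch
from torch.overrides import BaseTorchFunctionMode

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # constructed and entered entirely inside the traced frame
    with BaseTorchFunctionMode():
        return x + 1

f(torch.ones(3))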

        elif is_namedtuple_cls(self.value):
            fields = namedtuple_fields(self.value)
            # check if this a quasi-namedtuple or a real one
@ -711,7 +717,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
            if method is object.__init__:
                return ConstantVariable.create(None)

            if is_standard_setattr(method):
            if is_standard_setattr(method) or isinstance(self.value, threading.local):
                return self.method_setattr_standard(tx, *args, **kwargs)

            # [NOTE] OrderedDict, dict subtypes must always have source
@ -809,7 +815,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
    def needs_slow_setattr(self):
        return not is_standard_setattr(
            inspect.getattr_static(self.value, "__setattr__", None)
        )
        ) and not isinstance(self.value, threading.local)

    def unpack_var_sequence(self, tx):
        if (

@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
        if (
            not torch.compiler.is_dynamo_compiling()
            and log.isEnabledFor(logging.DEBUG)
            and config.extended_debug_current_loc
        ):
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(

@ -28,6 +28,7 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
@ -129,6 +130,10 @@ def cond(pred, true_fn, false_fn, operands):
    if torch.compiler.is_dynamo_compiling():
        return cond_op(pred, true_fn, false_fn, operands)

    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_mode,
    )

    if isinstance(pred, (bool, int, float)):
        log.warning(
            "Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
@ -169,12 +174,15 @@ def cond(pred, true_fn, false_fn, operands):
    def _cond_op_wrapper(*args, **kwargs):
        return cond_op(*args, **kwargs)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
                    pred, true_fn, false_fn, operands
                )
    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            if metadata_mode:
                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
            else:
                backend = "eager"
            return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
                pred, true_fn, false_fn, operands
            )
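
The same pattern recurs below for scan and while_loop: the metadata torch function mode is temporarily popped while the HOP wrapper is compiled and, if present, handed to an eager backend so the traced graph still runs under it. Schematically, using the same names as the hunk above:

from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
)

with _temp_remove_metadata_torch_function_mode() as metadata_mode:
    # metadata_mode is None when no metadata mode was active
    backend = (
        make_eager_backend_with_torch_function_mode(metadata_mode)
        if metadata_mode
        else "eager"
    )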


def create_fw_bw_graph_branches(true_fn, false_fn, *operands):

@ -20,6 +20,7 @@ from torch._higher_order_ops.utils import (
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
@ -118,10 +119,19 @@ def scan(
        return scan(*args, **kwargs)

    if not torch._dynamo.is_compiling():
        from torch._dynamo.backends.debugging import (
            make_eager_backend_with_torch_function_mode,
        )

        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
            return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)(
                combine_fn, init, xs, dim=dim, reverse=reverse
            )
            with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                if metadata_mode:
                    backend = make_eager_backend_with_torch_function_mode(metadata_mode)
                else:
                    backend = "eager"
                return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)(
                    combine_fn, init, xs, dim=dim, reverse=reverse
                )

    leaves_init, spec_init = pytree.tree_flatten(init)
    leaves_xs, spec_xs = pytree.tree_flatten(xs)

@ -15,7 +15,11 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class WhileLoopOp(HigherOrderOperator):
@ -113,6 +117,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
        - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.

    """
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_mode,
    )

    # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
    # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
@ -140,9 +147,15 @@ def while_loop(cond_fn, body_fn, carried_inputs):
        return while_loop_op(*args, **kwargs)

    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
        return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)(
            cond_fn, body_fn, carried_inputs, additional_inputs
        )
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                if metadata_mode:
                    backend = make_eager_backend_with_torch_function_mode(metadata_mode)
                else:
                    backend = "eager"
                return torch.compile(
                    _while_loop_op_wrapper, backend=backend, fullgraph=True
                )(cond_fn, body_fn, carried_inputs, additional_inputs)


@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)

@ -2515,62 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
 public:
  TORCH_FUNCTION_MODE_STACK(
      const py::list& initial_stack,
      const py::list& ignored_types,
      py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _ref_stack(),
        _ignored_types() {
      : LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
    Py_ssize_t len = PyList_Size(initial_stack.ptr());
    for (Py_ssize_t idx = 0; idx < len; idx++) {
      PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
      this->_ref_stack.push_back(Py_TYPE(mode));
    }

    len = PyList_Size(ignored_types.ptr());
    for (Py_ssize_t idx = 0; idx < len; idx++) {
      PyObject* type_obj =
          PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
      if (PyType_Check(type_obj) == 0) {
        PyErr_SetString(
            PyExc_TypeError, "ignored_types should contain a list of types");
        return;
      }
      PyTypeObject* type = (PyTypeObject*)type_obj;
      this->_ignored_types.insert(type);
      auto type = Py_TYPE(mode);
      this->_ref_stack.push_back(type);
    }
  }

  bool check_nopybind(PyObject* value) override {
    // Ignore value arg, only used to satisfy the interface
    size_t ref_ind = 0;
    int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
    const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
    const size_t ref_stack_size = this->_ref_stack.size();

    for (int64_t idx = 0; idx < len; idx++) {
    if (len != ref_stack_size) {
      return false;
    }

    for (int64_t idx = 0; (size_t)idx < len; idx++) {
      std::shared_ptr<c10::SafePyObject> mode =
          at::impl::PythonTorchFunctionTLS::get_stack_at(idx);

      PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
      // skip ignored types
      if (this->_ignored_types.count(mode_type) > 0) {
        continue;
      }
      // if we already have more non-ignored modes than the ref stack
      // or if the mode doesn't match at the current index, return false
      else if (
          (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) ||
          mode_type != _ref_stack[ref_ind]) {
      if (mode_type != _ref_stack.at(idx)) {
        return false;
      }
      ref_ind++;
    }

    return ref_ind == this->_ref_stack.size();
    return true;
  }

 private:
  std::vector<PyTypeObject*> _ref_stack;
  std::set<PyTypeObject*> _ignored_types;
};
 | 
			
		||||
 | 
			
		||||
class TENSOR_MATCH : public LeafGuard {
 | 
			
		||||
@ -3735,7 +3713,7 @@ PyObject* torch_c_dynamo_guards_init() {
 | 
			
		||||
      LeafGuard,
 | 
			
		||||
      std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
 | 
			
		||||
      py_m, "TORCH_FUNCTION_MODE_STACK")
 | 
			
		||||
      .def(py::init<py::list, py::list, py::list>())
 | 
			
		||||
      .def(py::init<py::list, py::list>())
 | 
			
		||||
      .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
 | 
			
		||||
  py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
 | 
			
		||||
      py_m, "DATA_PTR_MATCH")
 | 
			
		||||
@ -3972,10 +3950,9 @@ PyObject* torch_c_dynamo_guards_init() {
 | 
			
		||||
          "add_torch_function_mode_stack_guard",
 | 
			
		||||
          [](GuardManager& self,
 | 
			
		||||
             const py::list& initial_stack,
 | 
			
		||||
             const py::list& ignored_types,
 | 
			
		||||
             py::object verbose_code_parts) -> void {
 | 
			
		||||
            self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
 | 
			
		||||
                initial_stack, ignored_types, std::move(verbose_code_parts)));
 | 
			
		||||
                initial_stack, std::move(verbose_code_parts)));
 | 
			
		||||
          })
 | 
			
		||||
      .def(
 | 
			
		||||
          "add_data_ptr_guard",
 | 
			
		||||
 | 
			
		||||
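
After this change the guard no longer carries an ignored-types set: it requires the current torch function mode stack to have the same length as the recorded stack and the same mode type at every index. A rough Python equivalent of the new check, for illustration only; it assumes the private helper torch.overrides._get_current_function_mode_stack, and the function name itself is invented here.

from torch.overrides import _get_current_function_mode_stack


def torch_function_mode_stack_guard(ref_stack_types):
    # ref_stack_types: list of mode types captured when the guard was built.
    current = _get_current_function_mode_stack()
    if len(current) != len(ref_stack_types):
        return False
    # Exact positional match on mode types, with no skipping of "ignored" modes.
    return all(type(mode) is ty for mode, ty in zip(current, ref_stack_types))
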
@ -17,7 +17,7 @@ import typing_extensions
import warnings
import weakref
from collections import defaultdict
from contextlib import contextmanager, ExitStack, nullcontext
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import (
    Any,
@ -1084,38 +1084,43 @@ class PythonKeyTracer(Tracer):
            return e


@contextmanager
def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]:
    from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
def _make_temp_remove_mode_context_manager(
    mode_ty: Type[TorchFunctionMode],
) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
    @contextmanager
    def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
        from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode

    temp_elements = []
    pre_dispatch_mode = None
        temp_elements = []
        removed_mode = None

    while _len_torch_function_stack() > 0:
        mode = _pop_mode()
        if isinstance(mode, PreDispatchTorchFunctionMode):
            pre_dispatch_mode = mode
            break
        else:
            temp_elements.append(mode)
        while _len_torch_function_stack() > 0:
            mode = _pop_mode()
            if isinstance(mode, mode_ty):
                removed_mode = mode
                break
            else:
                temp_elements.append(mode)

    for mode in reversed(temp_elements):
        _push_mode(mode)
        for mode in reversed(temp_elements):
            _push_mode(mode)

    try:
        yield
        try:
            yield removed_mode

    finally:
        if pre_dispatch_mode is not None:
            count = len(temp_elements)
            while count > 0:
                mode = _pop_mode()
                count -= 1
        finally:
            if removed_mode is not None:
                count = len(temp_elements)
                while count > 0:
                    mode = _pop_mode()
                    count -= 1

            temp_elements.append(pre_dispatch_mode)
                temp_elements.append(removed_mode)

            for mode in reversed(temp_elements):
                _push_mode(mode)
                for mode in reversed(temp_elements):
                    _push_mode(mode)

    return context_manager_fn


@torch._disable_dynamo
@ -1230,6 +1235,11 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
        return func(*args, **kwargs)


_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
    TorchFunctionMetadataMode
)


# This mode is **only** used for pre_dispatch tracing.
# In particular, we need to make sure that autograd/autocast API's
# that do not desugar into dispatcher operators stay in the graph.
@ -1258,6 +1268,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
        return func(*args, **kwargs)


_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
    PreDispatchTorchFunctionMode
)


class ProxyTorchDispatchMode(TorchDispatchMode):
    # Ensure this is read-only; this exists only for legacy reasons
    @property
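
The two module-level context managers defined above are now instances of _make_temp_remove_mode_context_manager, parameterized by the mode type to pop. A short usage sketch (illustrative, not taken from the diff) of how callers such as while_loop and flex_attention consume the yielded mode:

from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
)

with _temp_remove_metadata_torch_function_mode() as metadata_mode:
    # metadata_mode is the TorchFunctionMetadataMode removed from the torch
    # function mode stack, or None if no such mode was active; on exit the
    # stack is restored to its original order.
    if metadata_mode is not None:
        ...  # e.g. hand it to a compile backend that re-enters the mode
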
@ -19,6 +19,7 @@ from torch._higher_order_ops.flex_attention import (
)
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input
@ -1033,6 +1034,10 @@ def flex_attention(
    if not torch._dynamo.is_dynamo_supported():
        raise RuntimeError("flex_attention requires dynamo support")

    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_mode,
    )

    # Dynamo is expecting a callable with "__code__" attribute.
    # We cannot directly pass hop to it. So we wrap it in a dummy function.
    def _flex_attention_hop_wrapper(*args, **kwargs):
@ -1041,18 +1046,25 @@ def flex_attention(
    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                out, lse = torch.compile(
                    _flex_attention_hop_wrapper, backend="eager", fullgraph=True
                )(
                    query,
                    key,
                    value,
                    score_mod,
                    block_mask.as_tuple(),
                    scale,
                    kernel_options,
                )
                if return_lse:
                    return out, lse * math.log(2)
                else:
                    return out
                with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                    if metadata_mode:
                        backend = make_eager_backend_with_torch_function_mode(
                            metadata_mode
                        )
                    else:
                        backend = "eager"
                    out, lse = torch.compile(
                        _flex_attention_hop_wrapper, backend=backend, fullgraph=True
                    )(
                        query,
                        key,
                        value,
                        score_mod,
                        block_mask.as_tuple(),
                        scale,
                        kernel_options,
                    )
                    if return_lse:
                        return out, lse * math.log(2)
                    else:
                        return out
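
Both the while_loop and flex_attention hunks follow the same backend-selection pattern. A condensed sketch of that pattern is below; the helper name _compile_with_optional_metadata_mode is invented for illustration, while the two imported functions are the ones the diff itself uses.

import torch
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
)


def _compile_with_optional_metadata_mode(fn, *args):
    # Pop the metadata mode so it does not interfere with dynamo's own tracing,
    # then re-enable it inside the eager backend if it was present.
    with _temp_remove_metadata_torch_function_mode() as metadata_mode:
        if metadata_mode:
            backend = make_eager_backend_with_torch_function_mode(metadata_mode)
        else:
            backend = "eager"
        return torch.compile(fn, backend=backend, fullgraph=True)(*args)
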