mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-04 08:00:58 +08:00

Compare commits

3 Commits

cslpull91 ... mlazos/mla

| Author | SHA1 | Date |
|---|---|---|
|  | 7342afdad4 |  |
|  | 7f237dd271 |  |
|  | 7e1c068aa2 |  |
@@ -14,6 +14,17 @@ from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode


class TestMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}

        if func == torch.add:
            return torch.zeros(2, 2)

        return super().__torch_function__(func, types, args, kwargs)


class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -324,6 +335,199 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
            fn(inp)
        self.assertEqual(cnt.frame_count, 2)

    def test_nested_torch_function_mode(self):
        mode_1_called = False
        mode_2_called = False

        def reset_state():
            nonlocal mode_1_called
            nonlocal mode_2_called
            mode_1_called = False
            mode_2_called = False

        ones = torch.ones(2, 2)
        zeros = torch.zeros(2, 2)

        class TestMode1(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_1_called

                mode_1_called = True

                if func == torch.add:
                    return zeros

                return super().__torch_function__(func, types, args, kwargs)

        class TestMode2(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_2_called

                mode_2_called = True

                if func == torch.mul:
                    return ones

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        def fn_2(x):
            return torch.mul(x, 3) + torch.add(x, 3)

        inp = torch.ones(2, 2) + 1

        for fn_i in [fn, fn_2]:
            fn_opt = torch.compile(fn_i, fullgraph=True)
            with TestMode1(), TestMode2():
                expected = fn_i(inp), mode_1_called, mode_2_called
                reset_state()
                actual = fn_opt(inp), mode_1_called, mode_2_called
                reset_state()

            self.assertEqual(expected, actual)

    def test_torch_function_mode_disable(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        class TestMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                if func == torch.add:
                    return torch.zeros(2, 2)

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            with torch._C.DisableTorchFunctionSubclass():
                expected = fn(inp)
                actual = fn_opt(inp)

            self.assertEqual(expected, actual)

            with torch._C.DisableTorchFunction():
                expected = fn(inp)
                actual = fn_opt(inp)

            self.assertEqual(expected, actual)

    def test_torch_function_mode_highest_priority(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            expected = fn(inp)
            actual = fn_opt(inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_enter_exit(self):
        def fn(x, y):
            with TestMode():
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn, fullgraph=True)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_graph_break(self):
        def fn(x, y):
            with TestMode():
                torch._dynamo.graph_break()
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_and_pop_graph_break(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_and_pop_graph_break_mutation(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                z.y = 5
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)
                o = torch.mul(o, z.y)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@@ -9,12 +9,7 @@ from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.while_loop import while_loop
from torch._subclasses.functional_tensor import (
    CppFunctionalizeAPI,
    FunctionalTensor,
    FunctionalTensorMode,
    PythonFunctionalizeAPI,
)
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
@@ -24,6 +19,7 @@ from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
@@ -1557,6 +1553,7 @@ class TestControlFlowTraced(TestCase):
        self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_simple_with_linear_compile_check_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -1664,19 +1661,6 @@ def forward(self, arg0_1, arg1_1, arg2_1):
    """,  # noqa: B950
        )

    def _wrap_with_functionalize(self, fn, func_type):
        mode = None
        if func_type == "cpp":
            fn = CppFunctionalizeAPI().functionalize(fn)
        elif func_type == "python":
            fn = PythonFunctionalizeAPI().functionalize(fn)
            mode = FunctionalTensorMode()
        elif func_type == "functorch":
            fn = torch.func.functionalize(fn)
        else:
            assert func_type == "no"
        return fn, mode

    @parametrize("func_type", ["no", "cpp", "python", "functorch"])
    def test_while_loop_simple_functionalize_check_graph(self, func_type):
        fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"]
@@ -1819,6 +1803,7 @@ def forward(self, arg0_1):
        self._check_compile(fn, inp, backend=backend)

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with cross ref
    def test_while_loop_simple_with_linear_compile_check_graph(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        from torch._dynamo.testing import EagerAndRecordGraphs
@@ -1894,135 +1879,6 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
    return (child, child_1)""",  # noqa: B950
            )

    def test_while_loop_nested2_traced(self):
        fn, inp = WHILE_LOOP_TESTS["nested2"]
        graphs = self._check_tracing(fn, inp)
        gm = graphs["symbolic"]
        outer_body = gm.while_loop_body_graph_0
        outer_cond = gm.while_loop_cond_graph_0
        inner_body = outer_body.while_loop_body_graph_0
        inner_cond = outer_body.while_loop_cond_graph_0
        self.assertExpectedInline(
            gm.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
    while_loop_body_graph_0 = self.while_loop_body_graph_0
    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
    getitem = while_loop[0]
    getitem_1 = while_loop[1]
    getitem_2 = while_loop[2]
    getitem_3 = while_loop[3];  while_loop = None
    return (getitem, getitem_1, getitem_2, getitem_3)
    """,  # noqa: B950
        )
        self.assertExpectedInline(
            outer_body.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
    while_loop_body_graph_0 = self.while_loop_body_graph_0
    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
    getitem = while_loop[0]
    getitem_1 = while_loop[1]
    getitem_2 = while_loop[2]
    getitem_3 = while_loop[3];  while_loop = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1);  getitem = None
    clone = torch.ops.aten.clone.default(getitem_1);  getitem_1 = None
    mul = torch.ops.aten.mul.Tensor(getitem_2, 2);  getitem_2 = None
    div = torch.ops.aten.div.Tensor(getitem_3, 2);  getitem_3 = None
    return (sub, clone, mul, div)
    """,  # noqa: B950
        )
        self.assertExpectedInline(
            inner_body.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
    sub = torch.ops.aten.sub.Tensor(arg1_1, 1);  arg1_1 = None
    add = torch.ops.aten.add.Tensor(arg2_1, 3.14);  arg2_1 = None
    sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71);  arg3_1 = None
    return (clone, sub, add, sub_1)
    """,
        )
        self.assertExpectedInline(
            inner_cond.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
    gt = torch.ops.aten.gt.Scalar(arg1_1, 0);  arg1_1 = None
    return gt
    """,
        )

    def test_cond_nested_traced(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(x, pred2):
            z = cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def f(x, pred, pred2):
            return cond(pred, true_fn, false_fn, [x, pred2])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        result_true_true = graph.forward(
            x, torch.tensor(True), torch.tensor(True)
        )  # True + True -> x * x
        result_true_false = graph.forward(
            x, torch.tensor(True), torch.tensor(False)
        )  # True + False -> x + x
        result_false_true = graph.forward(
            x, torch.tensor(False), torch.tensor(True)
        )  # False + either -> cos
        result_false_false = graph.forward(
            x, torch.tensor(False), torch.tensor(False)
        )  # False + either -> cos

        self.assertNotEqual(result_true_true, result_true_false)
        self.assertFalse(torch.allclose(result_false_true, result_true_true))

        self.assertEqual(result_false_true, result_false_false)

        self.assertEqual(result_true_true, (x * x) + x)
        self.assertEqual(result_true_false, x + x + x)

        self.assertEqual(result_false_true, torch.cos(x))

        graph = make_fx(f, tracing_mode="symbolic")(
            x, torch.tensor(False), torch.tensor(False)
        )
        self.assertEqual(
            graph(x, torch.tensor(True), torch.tensor(True)),
            f(x, torch.tensor(True), torch.tensor(True)),
        )

    def test_cond_functionalized(self):
        def true_fn(x):
            y = x.sin()

@@ -181,12 +181,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
        foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1);  arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = foo_default = None
        return ()""",  # noqa: B950
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@@ -240,7 +238,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1);  arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
        return (getitem_4, getitem_5)""",  # noqa: B950
@@ -327,9 +325,8 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  \
arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1);  arg2_1 = arg3_1 = arg0_1 = arg1_1 = foo_default = None
        return ()""",  # noqa: B950
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@@ -403,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1);  arg3_1 = arg4_1 = arg1_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1);  arg5_1 = copy__1 = None
        foo_default = torch.ops.mylib.foo.default(arg1_1, [arg4_1, arg5_1], arg2_1, 2, arg3_1);  arg4_1 = arg5_1 = arg3_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
@@ -415,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1);  arg4_1 = copy__1 = None
        foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1);  arg3_1 = arg4_1 = arg2_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
@@ -504,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = None
        foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1);  arg3_1 = arg4_1 = arg2_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None

        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1);  arg4_1 = copy__1 = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy__1 = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
@@ -564,12 +560,12 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg2_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1);  arg2_1 = getitem_1 = copy__1 = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1);  arg1_1 = getitem_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_2);  arg2_1 = getitem_2 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
@@ -579,12 +575,12 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2);  arg0_1 = getitem_2 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1);  arg1_1 = getitem_1 = copy__1 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
@@ -596,8 +592,8 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1);  foo_default = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
        foo_default = torch.ops.mylib.foo.default(arg1_1, arg2_1);  foo_default = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg2_1)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy__1 = None
        return (add,)""",
@@ -609,8 +605,8 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1);  foo_default = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
        foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1);  foo_default = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy__1 = None
        return (add,)""",
@@ -895,8 +891,8 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1);  arg2_1 = arg3_1 = arg1_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,

@@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs):
    return gm.forward


def make_eager_backend_with_torch_function_mode(mode):
    """Used to trace HOPs (cond and while) for eager execution; the metadata
    TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
    in the HOP, so we need to run this mode externally and not trace it."""

    def fn(gm, fake_tensor_inputs, **kwargs):
        with mode:
            return gm.forward

    return fn
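
A usage sketch, not part of the diff — `my_mode` and `fn` below are hypothetical names; the point is only that the helper above returns an ordinary dynamo backend callable, so it can plausibly be handed to torch.compile:

    # Hypothetical usage: have the compiled region execute eagerly with the
    # given torch function mode handled outside of tracing.
    backend = make_eager_backend_with_torch_function_mode(my_mode)
    compiled_fn = torch.compile(fn, backend=backend)
    out = compiled_fn(torch.ones(2, 2))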


@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
    if kwargs:

@@ -94,6 +94,7 @@ from .symbolic_convert import (
from .trace_rules import is_numpy
from .utils import (
    CleanupManager,
    clear_torch_function_mode_stack,
    CompilationMetrics,
    counters,
    dynamo_timed,
@@ -108,6 +109,7 @@ from .utils import (
    orig_code_map,
    record_compilation_metrics,
    reset_graph_break_dup_checker,
    set_torch_function_mode_stack,
    setup_compile_debug,
    troubleshooting_url,
    write_record_to_file,
@@ -204,6 +206,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
            py_rng_state = random.getstate()
            torch_rng_state = torch.random.get_rng_state()
            cuda_rng_state = None
            prior_tf_mode_stack = torch.overrides._get_current_function_mode_stack()
            clear_torch_function_mode_stack()
            if torch.cuda.is_available():
                cuda_rng_state = torch.cuda.get_rng_state()
            allow_tf32 = torch._C._get_cublas_allow_tf32()
@@ -220,6 +224,10 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
            finally:
                cleanup.close()
                exit_stack.close()
                assert (
                    torch._C._len_torch_function_stack() == 0
                ), "Torch function mode stack state changed while dynamo tracing, please report a bug"
                set_torch_function_mode_stack(prior_tf_mode_stack)
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
                torch.use_deterministic_algorithms(
@@ -605,6 +613,10 @@ def _compile(
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    tf_mode_stack: List[
        torch.overrides.TorchFunctionMode
    ] = torch.overrides._get_current_function_mode_stack()

    @preserve_global_state
    def transform(
        instructions: List[Instruction], code_options: Dict[str, object]
@@ -618,6 +630,7 @@ def _compile(
            locals,
            globals,
            builtins,
            tf_mode_stack,
            code_options,
            compiler_fn,
            one_graph,

@@ -97,6 +97,7 @@ from .source import (
    ScriptObjectQualifiedNameSource,
    ShapeEnvSource,
    SubclassAttrListSource,
    TorchFunctionModeStackSource,
    TupleIteratorGetItemSource,
    TypeSource,
    UnspecializedBuiltinNNModuleSource,
@@ -110,6 +111,7 @@ from .utils import (
    dict_keys_repr,
    get_custom_getattr,
    get_torch_function_mode_stack,
    get_torch_function_mode_stack_at,
    guard_failures,
    istype,
    key_is_id,
@@ -313,6 +315,7 @@ CLOSURE_VARS = {
    "___dict_contains": lambda a, b: a in b,
    "___tuple_iterator_len": tuple_iterator_len,
    "___tuple_iterator_getitem": tuple_iterator_getitem,
    "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
    "__math_isnan": math.isnan,
    "__numpy_isnan": None if np is None else np.isnan,
    "inf": float("inf"),
@@ -900,6 +903,15 @@ class GuardBuilder(GuardBuilderBase):
        ):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager
        elif istype(source, TorchFunctionModeStackSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_torch_function_mode_stack_at(
                    source._get_index()
                ),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
@@ -2206,6 +2218,8 @@ class CheckFunctionManager:
        self.output_graph = output_graph
        w_builder = None

        # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
        # in case a set default device call was made in the graph.
        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )
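
An illustrative case for the NB above, not part of the diff: torch.set_default_device works by pushing a DeviceContext torch function mode, so a call like the hypothetical one below inside the compiled region would change the mode stack mid-trace, and guards must use the stack snapshot taken when tracing began.

    def fn(x):
        torch.set_default_device("cuda")  # pushes a DeviceContext mode onto the stack
        return x + 1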

@@ -1020,7 +1020,7 @@ class OutputGraph:
            prefix_insts.clear()

        for block in reversed(tx.block_stack):
            block.exit(tx)
            block.exit(tx, is_graph_break=reason.graph_break)

        self.cleanup_graph()
        tx.prune_dead_locals()

@@ -25,6 +25,26 @@ if TYPE_CHECKING:
        sys as sys,
    )

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across
# graph breaks.
# Today, the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since the stack is mutated
# after this occurs, and we replay these mutations, no cleanup logic
# needs to run once the graph break occurs: we simply replay the
# mutations to ensure the torch function mode stack is correct at the
# graph break, and reconstruct the stack normally when we compile the
# resume function on the other side of the break.
# However, to ensure we exit properly in the resume function, we need
# to re-enter the contexts as we do other contexts. These contexts do
# nothing on enter, but provide the correct exit logic to ensure the
# stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass
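
An illustrative reading, not part of the diff: by the time the resume function re-enters this context, the replayed stack mutations have already pushed the mode, so __enter__ must be a no-op while the inherited __exit__ still pops the mode and keeps the stack balanced.

    with NoEnterTorchFunctionMode():  # enter does NOT push onto the mode stack
        ...  # resumed user code; exit pops the already-present mode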


def index(iterator, item, start=0, end=None):
    from itertools import islice

@@ -608,7 +608,7 @@ class TorchFunctionModeStackSource(Source):
    ind: int

    def name(self):
        return ""
        return f"___get_torch_function_mode_stack_at({self._get_index()})"

    def _get_index(self):
        from .variables.torch_function import TorchFunctionModeStackVariable

@@ -19,20 +19,7 @@ import traceback
import types
import typing
import weakref
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from unittest.mock import patch

import torch
@@ -72,14 +59,12 @@ from .source import (
    GlobalWeakRefSource,
    LocalSource,
    Source,
    TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
    counters,
    get_fake_value,
    get_instruction_source_311,
    get_torch_function_mode_stack,
    graph_break_dup_warning_checker,
    istype,
    LazyString,
@@ -120,11 +105,10 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable


if TYPE_CHECKING:
    from .variables.torch_function import TorchFunctionModeVariable

from .variables.torch_function import (
    SymbolicTorchFunctionState,
    TorchFunctionModeVariable,
)
from .variables.user_defined import (
    RemovableHandleVariable,
    UserDefinedClassVariable,
@@ -283,9 +267,12 @@ class BlockStackEntry:
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
    def exit(self, tx, is_graph_break):
        assert self.with_context is not None
        return self.with_context.exit(tx)
        if (
            is_graph_break and self.with_context.exit_on_graph_break()
        ) or not is_graph_break:
            return self.with_context.exit(tx)
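
A paraphrase of the new condition above, not part of the diff: the context is exited on a normal unwind, and on a graph break only when the context opts in.

    should_exit = (not is_graph_break) or self.with_context.exit_on_graph_break()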


class ReturnValueOp(Exception):
@@ -651,8 +638,12 @@ def break_graph_if_unsupported(*, push):
            cleanup: List[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                # Don't exit any modes we have entered,
                # output bytecode will mutate the tf mode stack accordingly
                if isinstance(b.with_context, TorchFunctionModeVariable):
                    continue
                assert b.with_context is not None
                assert isinstance(b.with_context, ContextWrappingVariable)
                assert isinstance(b.with_context, (ContextWrappingVariable))
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
@@ -728,7 +719,7 @@ class InstructionTranslatorBase(
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
    symbolic_torch_function_state: SymbolicTorchFunctionState
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
@@ -2305,7 +2296,11 @@ class InstructionTranslatorBase(
        ):
            unimplemented(f"{inst.opname} {ctx}")

        if isinstance(ctx, GenericContextWrappingVariable):
        if (
            isinstance(ctx, GenericContextWrappingVariable)
            and not ctx.supports_graph_breaks()
        ):
            self.generic_context_manager_depth += 1

        # Need this redundant check for mypy
@@ -2548,7 +2543,7 @@ class InstructionTranslatorBase(
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
@@ -2563,7 +2558,7 @@ class InstructionTranslatorBase(
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
        self.symbolic_torch_function_state = symbolic_torch_function_state
        self.stack = []
        # stack of variable names for tracking 3.13 closures
        self.name_stack: list[Any] = []
@@ -2652,6 +2647,7 @@ class InstructionTranslator(InstructionTranslatorBase):
        f_locals,
        f_globals,
        f_builtins,
        torch_function_mode_stack,
        code_options,
        compiler_fn,
        one_graph,
@@ -2686,7 +2682,7 @@ class InstructionTranslator(InstructionTranslatorBase):
            symbolic_locals={},  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            symbolic_torch_function_mode_stack=collections.deque(),
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
            f_code=f_code,
            export=export,
            inline_depth=0,
@@ -2721,7 +2717,9 @@ class InstructionTranslator(InstructionTranslatorBase):
                if k in f_locals
            }

            self._init_torch_function_mode_stack()
            self.symbolic_torch_function_state = SymbolicTorchFunctionState(
                torch_function_mode_stack
            )

            self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
            if export:
@@ -2762,29 +2760,6 @@ class InstructionTranslator(InstructionTranslatorBase):
            )
            unimplemented(msg)

    def _init_torch_function_mode_stack(self):
        from .variables.torch_function import TorchFunctionModeStackVariable

        TorchFunctionModeStackVariable.reset()

        self.symbolic_torch_function_mode_stack: Deque[
            TorchFunctionModeVariable
        ] = collections.deque()
        # We want to retrieve all modes to properly reconstruct the stack if needed
        py_stack = get_torch_function_mode_stack(filter_ignored=False)

        if py_stack:
            has_device_context = isinstance(
                py_stack[0], torch.utils._device.DeviceContext
            )

        for i, val in enumerate(py_stack):
            self.symbolic_torch_function_mode_stack.append(
                variables.LazyVariableTracker.create(
                    val, source=TorchFunctionModeStackSource(i)
                )
            )

    def get_example_value(self, source: Source):
        if isinstance(source, LocalSource):
            return self.f_locals[source.local_name]
@@ -3116,7 +3091,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_mode_stack,
                parent.symbolic_torch_function_state,
                closure_cells,
                func,
            )
@@ -3126,7 +3101,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_mode_stack,
                parent.symbolic_torch_function_state,
                closure_cells,
                func,
            )
@@ -3179,7 +3154,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ) -> None:
@@ -3196,7 +3171,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
            symbolic_torch_function_state=symbolic_torch_function_state,
            instructions=instructions,
            code_options={k: getattr(code, k) for k in get_code_keys()},
            f_code=code,

@@ -3254,6 +3254,7 @@ MOD_INLINELIST = [
    "torch.testing",
    "torch.utils._content_store",
    "torch.utils._contextlib",
    "torch.utils._device",
    "torch.utils._foreach_utils",
    "torch.utils._python_dispatch",
    "torch.utils._pytree",
@@ -3588,7 +3589,9 @@ def lookup_inner(
            if reasons is not None:
                reasons.add("func name is patched_init")
            return SkipFunctionVariable
        elif name == "__torch_function__":
        elif name == "__torch_function__" or (
            obj and obj.__name__ == "__torch_function__"
        ):
            if reasons is not None:
                reasons.add("func name is __torch_function__")
            return UserFunctionVariable

@@ -63,7 +63,6 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
    _get_function_stack_at,
    _instruction_counter,
    _len_torch_function_stack,
    _pop_torch_function_stack,
@@ -3065,7 +3064,9 @@ def is_parameter_freezing():
def get_torch_function_mode_stack(filter_ignored=True):
    from .variables.torch_function import IGNORED_MODES

    stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
    stack = [
        get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
    ]
    if filter_ignored:
        stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]

@ -3085,6 +3086,11 @@ def set_torch_function_mode_stack(stack):
 | 
			
		||||
        _push_on_torch_function_stack(mode)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clear_torch_function_mode_stack():
 | 
			
		||||
    for i in range(_len_torch_function_stack()):
 | 
			
		||||
        _pop_torch_function_stack()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_guard_fn_signature(value):
 | 
			
		||||
    fn = value.__metadata_guard__
 | 
			
		||||
    sig = inspect.signature(fn)
 | 
			
		||||
 | 
			
		||||
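For orientation, a minimal eager-mode sketch (not part of the diff) of the stack these helpers manage, using the same private torch._C accessors imported above:

import torch
from torch.overrides import BaseTorchFunctionMode

# Entering a mode pushes it onto the global torch function mode stack;
# the get/set/clear utilities above are thin wrappers over these accessors.
with BaseTorchFunctionMode():
    assert torch._C._len_torch_function_stack() == 1
    assert isinstance(torch._C._get_function_stack_at(0), BaseTorchFunctionMode)
assert torch._C._len_torch_function_stack() == 0
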
@@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)

+    def supports_graph_breaks(self):
+        return True
+
+    def exit_on_graph_break(self):
+        return True
+

class GenericContextWrappingVariable(UserDefinedObjectVariable):
    # Some methods in ContextWrappingVariable assume the arguments are
@@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
        tx.generic_context_manager_depth -= 1
        return x

+    def supports_graph_breaks(self):
+        return False
+
+    def exit_on_graph_break(self):
+        return True
+

class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch grad requires grad"""
@@ -637,6 +649,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
+        tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
+        tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
        tx.output.set_torch_function_state(values[0])


@@ -149,6 +149,15 @@ tracing_state_functions = {
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])


+@functools.lru_cache(None)
+def get_overridable_functions():
+    from itertools import chain
+
+    from torch.overrides import get_overridable_functions as get_overridable_functions_
+
+    return set(chain(*get_overridable_functions_().values()))
+
+
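As a usage sketch (using only the public torch.overrides API, separate from the diff), the helper above flattens the per-namespace mapping into one set so that membership tests are a cheap O(1) lookup, which is why it is built once per process under lru_cache:

import torch
from itertools import chain
from torch.overrides import get_overridable_functions

# get_overridable_functions() returns {namespace: [callables]}.
overridable = set(chain(*get_overridable_functions().values()))
assert torch.add in overridable
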
class BaseTorchVariable(VariableTracker):
    """common base for all torch.* functions, classes, modules and other things"""

@@ -782,10 +791,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
-            if not tx.symbolic_torch_function_mode_stack:
+            if not tx.symbolic_torch_function_state.mode_stack:
                raise unimplemented("Popping from an empty torch function mode stack")
            TorchFunctionModeStackVariable.register_mutation(tx)
-            return tx.symbolic_torch_function_mode_stack.pop()
+            return tx.symbolic_torch_function_state.pop_torch_function_mode()

        @register(torch._C._push_on_torch_function_stack)
        def handle_push_torch_function(
@@ -793,7 +802,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
        ):
            assert len(args) == 1 and not kwargs
            TorchFunctionModeStackVariable.register_mutation(tx)
-            tx.symbolic_torch_function_mode_stack.append(args[0])
+            tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
            return ConstantVariable.create(None)

        @register(torch._C._len_torch_function_stack)
@@ -801,7 +810,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
-            return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
+            return ConstantVariable.create(
+                len(tx.symbolic_torch_function_state.mode_stack)
+            )

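A sketch of the user code these handlers make traceable (an assumption about this branch's support for mode-stack mutation inside a compiled frame; the mode object comes from outside the frame):

import torch
from torch.overrides import BaseTorchFunctionMode

mode = BaseTorchFunctionMode()

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # Each call below hits the corresponding handler during tracing, so only
    # the symbolic mode stack is mutated, not the real C-level stack.
    torch._C._push_on_torch_function_stack(mode)
    n = torch._C._len_torch_function_stack()
    torch._C._pop_torch_function_stack()
    return x + n
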
        @register(torch.set_default_device)
        def handle_set_default_device(
@@ -833,6 +844,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
        from . import ConstantVariable, SymNodeVariable, TensorVariable
        from .builder import wrap_fx_proxy

+        if self.torch_function_override_enabled(tx, args, kwargs):
+            return dispatch_torch_function(tx, self, args, kwargs)
+
        if self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        ):
@@ -850,147 +864,144 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            if result:
                return result

-        if can_dispatch_torch_function(tx, args, kwargs):
-            return dispatch_torch_function(tx, self, args, kwargs)
-        else:
-            any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
+        any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)

-            all_ints_or_floats = all(
-                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
-                for x in args
-            )
-            if (
-                getattr(self.value, "__module__", "") == "torch"
-                and self.value.__name__ in bin_ops
-                and any_symints_or_symfloats
-                and all_ints_or_floats
-            ):
-                msg = f"""\
+        all_ints_or_floats = all(
+            isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
+            for x in args
+        )
+        if (
+            getattr(self.value, "__module__", "") == "torch"
+            and self.value.__name__ in bin_ops
+            and any_symints_or_symfloats
+            and all_ints_or_floats
+        ):
+            msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
-                log.warning(msg)
-                unimplemented(msg)
+            log.warning(msg)
+            unimplemented(msg)

-            # TODO(voz): Replace w/ dynamic shape rewrite table.
-            # Ideally, we would be able to do this at ctor time, but alas we need a combination
-            # of value + args to determine this.
-            fn_ = self.value
-            if any_symints_or_symfloats:
-                torch_sym_op = f"_sym_{self.value.__name__}"
-                if getattr(self.value, "__module__", None) == "math" and hasattr(
-                    torch, torch_sym_op
-                ):
-                    fn_ = getattr(torch, torch_sym_op)
+        # TODO(voz): Replace w/ dynamic shape rewrite table.
+        # Ideally, we would be able to do this at ctor time, but alas we need a combination
+        # of value + args to determine this.
+        fn_ = self.value
+        if any_symints_or_symfloats:
+            torch_sym_op = f"_sym_{self.value.__name__}"
+            if getattr(self.value, "__module__", None) == "math" and hasattr(
+                torch, torch_sym_op
+            ):
+                fn_ = getattr(torch, torch_sym_op)

-            fake_out_shape = None
-            if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
-                # Calling fake tensor propagation can mutate the out= tensor in
-                # tx.output.tracked_fakes. tracked_fakes are used to apply
-                # symbolic_shape guards. Mutating them destroys the information
-                # prior to tracing, which is essential for creating right
-                # guards. So save the shape now, and check later if it has
-                # changed. If it has, graph break.
-                fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
+        fake_out_shape = None
+        if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
+            # Calling fake tensor propagation can mutate the out= tensor in
+            # tx.output.tracked_fakes. tracked_fakes are used to apply
+            # symbolic_shape guards. Mutating them destroys the information
+            # prior to tracing, which is essential for creating right
+            # guards. So save the shape now, and check later if it has
+            # changed. If it has, graph break.
+            fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape

-            tensor_variable = wrap_fx_proxy(
-                tx=tx,
-                proxy=tx.output.create_proxy(
-                    "call_function",
-                    fn_,
-                    *proxy_args_kwargs(args, kwargs),
-                ),
-            )
+        tensor_variable = wrap_fx_proxy(
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                fn_,
+                *proxy_args_kwargs(args, kwargs),
+            ),
+        )

-            if (
-                isinstance(tensor_variable, TensorVariable)
-                and "requires_grad" in kwargs
-                and kwargs["requires_grad"].as_python_constant()
-            ):
-                unimplemented(
-                    """factory functions that return tensors that require grad are not supported.
-Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
-                )
+        if (
+            isinstance(tensor_variable, TensorVariable)
+            and "requires_grad" in kwargs
+            and kwargs["requires_grad"].as_python_constant()
+        ):
+            unimplemented(
+                """factory functions that return tensors that require grad are not supported.
+Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
+            )

-            if "out" in kwargs and not (
-                isinstance(kwargs["out"], variables.ConstantVariable)
-                and kwargs["out"].as_python_constant() is None
-            ):
-                # out variants of torch operators like torch.sort and
-                # torch.sigmoid mutate the tensors in the out field. Track such
-                # tensors and rewrite the symbolic locals.
-                if isinstance(tensor_variable, TupleVariable):
-                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
-                    output_tensor_names = [
-                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
-                    ]
-                    for idx, name in enumerate(output_tensor_names):
-                        if name in tx.symbolic_locals:
-                            tx.symbolic_locals[name] = tensor_variable.items[idx]
-                    for out_tensor, result_tensor in zip(
-                        kwargs["out"].items, tensor_variable.items
-                    ):
-                        if (
-                            out_tensor.source
-                            and out_tensor in tx.output.graphargs
-                            and isinstance(out_tensor, variables.TensorVariable)
-                            and isinstance(result_tensor, variables.TensorVariable)
-                            and out_tensor.size != result_tensor.size
-                        ):
-                            # It's hard to get out variants with resizing on graph inputs work
-                            # properly across dynamo/aot/inductor, just fall back.
-                            unimplemented("out variants with resizing on graph inputs")
-                elif isinstance(tensor_variable, TensorVariable):
-                    assert isinstance(kwargs["out"], TensorVariable)
-                    assert "example_value" in kwargs["out"].proxy.node.meta
-                    fake_tensor = tensor_variable.proxy.node.meta["example_value"]
-                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
-                    if (
-                        kwargs["out"].source
-                        and kwargs["out"] in tx.output.graphargs
-                        and fake_out_shape != fake_tensor.shape
-                    ):
-                        # It's hard to get out variants with resizing on graph inputs work
-                        # properly across dynamo/aot/inductor, just fall back.
-                        unimplemented("out variants with resizing on graph inputs")
-                    if not torch._prims_common.is_contiguous(fake_out):
-                        # It's difficult to handle strides correctly in functionalization
-                        # when calling an out= op with a non-contiguous out argument
-                        unimplemented(
-                            "out= op was called where output tensor was non-contiguous"
-                        )
-                    name = tx.find_symbolic_locals_name(kwargs["out"])
-                    if name in tx.symbolic_locals:
-                        tx.symbolic_locals[name] = tensor_variable
-                elif (
-                    isinstance(tensor_variable, ConstantVariable)
-                    and tensor_variable.value is None
-                ):
-                    # Handle out-variant custom ops that return None.
-                    if isinstance(kwargs["out"], TensorVariable):
-                        assert "example_value" in kwargs["out"].proxy.node.meta
-                        fake_out = kwargs["out"].proxy.node.meta["example_value"]
-                        if not torch._prims_common.is_contiguous(fake_out):
-                            # It's difficult to handle strides correctly in functionalization
-                            # when calling an out= op with a non-contiguous out argument
-                            unimplemented(
-                                "out= op was called where output tensor was non-contiguous"
-                            )
-                    elif isinstance(kwargs["out"], ListVariable):
-                        for idx, x in enumerate(kwargs["out"].items):
-                            assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
-                            fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
-                            if not torch._prims_common.is_contiguous(fake_out):
-                                # It's difficult to handle strides correctly in functionalization
-                                # when calling an out= op with a non-contiguous out argument
-                                unimplemented(
-                                    "out= op was called where some of the output tensors were non-contiguous"
-                                )
-                else:
-                    unimplemented(f"out variant of {type(kwargs['out'])}")
-
-            return tensor_variable
+        if "out" in kwargs and not (
+            isinstance(kwargs["out"], variables.ConstantVariable)
+            and kwargs["out"].as_python_constant() is None
+        ):
+            # out variants of torch operators like torch.sort and
+            # torch.sigmoid mutate the tensors in the out field. Track such
+            # tensors and rewrite the symbolic locals.
+            if isinstance(tensor_variable, TupleVariable):
+                assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
+                output_tensor_names = [
+                    tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
+                ]
+                for idx, name in enumerate(output_tensor_names):
+                    if name in tx.symbolic_locals:
+                        tx.symbolic_locals[name] = tensor_variable.items[idx]
+                for out_tensor, result_tensor in zip(
+                    kwargs["out"].items, tensor_variable.items
+                ):
+                    if (
+                        out_tensor.source
+                        and out_tensor in tx.output.graphargs
+                        and isinstance(out_tensor, variables.TensorVariable)
+                        and isinstance(result_tensor, variables.TensorVariable)
+                        and out_tensor.size != result_tensor.size
+                    ):
+                        # It's hard to get out variants with resizing on graph inputs work
+                        # properly across dynamo/aot/inductor, just fall back.
+                        unimplemented("out variants with resizing on graph inputs")
+            elif isinstance(tensor_variable, TensorVariable):
+                assert isinstance(kwargs["out"], TensorVariable)
+                assert "example_value" in kwargs["out"].proxy.node.meta
+                fake_tensor = tensor_variable.proxy.node.meta["example_value"]
+                fake_out = kwargs["out"].proxy.node.meta["example_value"]
+                if (
+                    kwargs["out"].source
+                    and kwargs["out"] in tx.output.graphargs
+                    and fake_out_shape != fake_tensor.shape
+                ):
+                    # It's hard to get out variants with resizing on graph inputs work
+                    # properly across dynamo/aot/inductor, just fall back.
+                    unimplemented("out variants with resizing on graph inputs")
+                if not torch._prims_common.is_contiguous(fake_out):
+                    # It's difficult to handle strides correctly in functionalization
+                    # when calling an out= op with a non-contiguous out argument
+                    unimplemented(
+                        "out= op was called where output tensor was non-contiguous"
+                    )
+                name = tx.find_symbolic_locals_name(kwargs["out"])
+                if name in tx.symbolic_locals:
+                    tx.symbolic_locals[name] = tensor_variable
+            elif (
+                isinstance(tensor_variable, ConstantVariable)
+                and tensor_variable.value is None
+            ):
+                # Handle out-variant custom ops that return None.
+                if isinstance(kwargs["out"], TensorVariable):
+                    assert "example_value" in kwargs["out"].proxy.node.meta
+                    fake_out = kwargs["out"].proxy.node.meta["example_value"]
+                    if not torch._prims_common.is_contiguous(fake_out):
+                        # It's difficult to handle strides correctly in functionalization
+                        # when calling an out= op with a non-contiguous out argument
+                        unimplemented(
+                            "out= op was called where output tensor was non-contiguous"
+                        )
+                    name = tx.find_symbolic_locals_name(kwargs["out"])
+                    if name in tx.symbolic_locals:
+                        tx.symbolic_locals[name] = tensor_variable
+                elif isinstance(kwargs["out"], ListVariable):
+                    for idx, x in enumerate(kwargs["out"].items):
+                        assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
+                        fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
+                        if not torch._prims_common.is_contiguous(fake_out):
+                            # It's difficult to handle strides correctly in functionalization
+                            # when calling an out= op with a non-contiguous out argument
+                            unimplemented(
+                                "out= op was called where some of the output tensors were non-contiguous"
+                            )
+            else:
+                unimplemented(f"out variant of {type(kwargs['out'])}")
+
+        return tensor_variable

    def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
        """inline behavior of torch.nn.modules.utils._ntuple"""
@@ -1118,3 +1129,9 @@ Either create the tensor outside the compiled region, or do not set the tensor t
            source
        )
        return result
+
+    def torch_function_override_enabled(self, tx, args, kwargs):
+        return (
+            self.get_function() in get_overridable_functions()
+            and can_dispatch_torch_function(tx, args, kwargs)
+        )

@@ -1,20 +1,34 @@
# mypy: ignore-errors

+import collections
+import contextlib
+import inspect
-from typing import Dict, List, TYPE_CHECKING
+from typing import Deque, Dict, List, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
-from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
+from torch.overrides import (
+    _get_overloaded_args,
+    get_default_nowrap_functions,
+    TorchFunctionMode,
+)
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
+from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
-from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
+from ..utils import (
+    class_has_getattribute,
+    get_safe_global_name,
+    has_torch_function,
+    is_tensor_base_attr_getter,
+)
from .base import VariableTracker
from .constant import ConstantVariable
-from .ctx_manager import ContextWrappingVariable
+from .ctx_manager import GenericContextWrappingVariable
+from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
+from .user_defined import UserDefinedObjectVariable
@@ -59,6 +73,67 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}

+class SymbolicTorchFunctionState:
+    def __init__(self, py_stack):
+        # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
+        # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
+        # These are their definitions:
+        # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs has been entered
+        # (if either is entered, this will be False)
+        # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
+        # torch._C.DisableTorchFunction has been entered
+        # To disambiguate these and keep myself sane I added a C API to check whether all torch function
+        # concepts (modes and subclasses) are enabled.
+        # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
+        # the stack length from the enablement state of torch function modes.
+        # This is important because now if a mode is pushed while dynamo is tracing, we know whether
+        # or not torch function modes are enabled and whether we should trace it.
+        self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()
+
+        # This differs from the C API of the same name
+        # this will only be false iff we have entered torch._C.DisableTorchFunction
+        # and does not take into account the mode stack length, while the C API bundles these
+        # two concepts
+        self.torch_function_mode_enabled = (
+            not torch._C._is_torch_function_all_disabled()
+        )
+
+        self.cur_mode = None
+
+        TorchFunctionModeStackVariable.reset()
+
+        self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()
+
+        for i, val in enumerate(py_stack):
+            self.mode_stack.append(
+                LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
+            )
+
+    def in_torch_function_mode(self):
+        return len(self.mode_stack) > 0
+
+    def pop_torch_function_mode(self):
+        return self.mode_stack.pop()
+
+    def push_torch_function_mode(self, mode_var):
+        self.mode_stack.append(mode_var)
+
+    def call_torch_function_mode(self, tx, fn, types, args, kwargs):
+        with self._pop_mode_for_inlining() as cur_mode:
+            return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
+
+    @contextlib.contextmanager
+    def _pop_mode_for_inlining(self):
+        old_mode = self.cur_mode
+        self.cur_mode = self.pop_torch_function_mode()
+        try:
+            yield self.cur_mode
+        finally:
+            mode = self.cur_mode
+            self.cur_mode = old_mode
+            self.push_torch_function_mode(mode)

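To make the two knobs in the comment above concrete, an eager-mode sketch (the _is_torch_function_all_disabled query is the C API this branch describes adding):

import torch

assert torch._C._is_torch_function_enabled()  # neither knob entered

with torch._C.DisableTorchFunctionSubclass():
    # Subclass dispatch is off, but torch function modes would still run.
    assert not torch._C._is_torch_function_enabled()
    assert not torch._C._is_torch_function_all_disabled()

with torch._C.DisableTorchFunction():
    # Everything is off, including modes.
    assert not torch._C._is_torch_function_enabled()
    assert torch._C._is_torch_function_all_disabled()
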
class TorchFunctionModeStackVariable(VariableTracker):
    """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""

@@ -88,19 +163,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
    def register_mutation(cls, tx: "InstructionTranslator"):
        if cls.stack_value_singleton not in tx.output.side_effects:
            var = cls(
-                source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
+                source=Source(),
+                symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
            )
            tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
            tx.output.side_effects.mutation(var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
-        stack = tx.symbolic_torch_function_mode_stack
+        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            return
        else:
            cls.offset += 1
-            tx.symbolic_torch_function_mode_stack.insert(
+            stack.insert(
                0,
                TorchFunctionModeVariable(
                    None, source=TorchFunctionModeStackSource(-cls.offset)
@@ -109,7 +185,7 @@ class TorchFunctionModeStackVariable(VariableTracker):

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
-        stack = tx.symbolic_torch_function_mode_stack
+        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            stack.popleft()
            cls.offset -= 1
@@ -123,24 +199,91 @@ class TorchFunctionModeStackVariable(VariableTracker):
        return ind + cls.offset


-class TorchFunctionModeVariable(ContextWrappingVariable):
-    def __init__(self, value, **kwargs):
-        super().__init__(value, **kwargs)
-        self.value = value
-
+class TorchFunctionModeVariable(GenericContextWrappingVariable):
    @staticmethod
-    def get_global_mangled_name(tx, val):
-        return get_safe_global_name(
-            tx, f"__torch_function_mode_{val.__class__.__name__}", val
+    def is_supported_torch_function_mode(ty):
+        # Supported in this sense means we can support graph breaks under the
+        # context.
+        # We are able to trace custom modes but if there are graph breaks under them
+        # and they have a custom __enter__/__exit__ we don't handle this for the
+        # same reason we don't handle generic context managers: there may be side effects
+        # that are now affected by executing the function across two frames instead of one
+        # Today we support the enter/exit of the default TorchFunctionMode as well as
+        # DeviceContext (which is used for set_default_device)
+        return issubclass(ty, (DeviceContext, NoEnterTorchFunctionMode)) or (
+            not class_has_getattribute(ty)
+            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
+            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
        )

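For example, a mode like the following sketch qualifies: it defines no custom __getattribute__ and inherits TorchFunctionMode.__enter__/__exit__ unchanged, so enter/exit can be replayed across a graph break (the class name is illustrative):

import torch
from torch.overrides import TorchFunctionMode

class NegateAdd(TorchFunctionMode):
    # Only __torch_function__ is overridden; __enter__/__exit__ stay inherited.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.add:
            return -func(*args, **kwargs)
        return func(*args, **kwargs)

# is_supported_torch_function_mode(NegateAdd) would return True here.
with NegateAdd():
    print(torch.add(torch.ones(2), torch.ones(2)))  # tensor([-2., -2.])
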
+    def __init__(self, value, source=None, **kwargs):
+        super().__init__(value, **kwargs)
+        self.value = value
+        self.cm_obj = value  # needed for BC with calling enter from CM code
+        self.source = source

    def reconstruct(self, codegen):
        # We don't support locally created torch function modes yet
        # This shouldn't be called unless we have a source
        assert self.source
        self.source.reconstruct(codegen)

-    def _call_func(self, tx, values):
-        unimplemented("torch function mode context manager is not supported yet")
+    def module_name(self):
+        return self.value.__module__
+
+    def fn_name(self):
+        return type(self.value).__name__
+
+    def python_type(self):
+        return type(self.value)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        return call_torch_function(
            tx,
            self,
            build_torch_function_fn(tx, self.value, self.source),
            fn,
            types,
            args,
            kwargs,
        )

+    def enter(self, tx):
+        from .torch import TorchInGraphFunctionVariable
+
+        TorchInGraphFunctionVariable(
+            torch._C._push_on_torch_function_stack
+        ).call_function(tx, [self], {})
+        return ConstantVariable.create(None)
+
+    def exit(self, tx: "InstructionTranslator", *args):
+        from .torch import TorchInGraphFunctionVariable
+
+        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
+            tx, [], {}
+        )
+        return ConstantVariable.create(None)
+
+    def reconstruct_type(self, codegen):
+        ty = NoEnterTorchFunctionMode
+        # NoEnterDeviceTorchFunctionMode
+        # if isinstance(self.value, DeviceContext)
+        # else NoEnterTorchFunctionMode
+        # codegen(
+        #    AttrSource(
+        #        codegen.tx.import_source(torch._dynamo.polyfills.__name__), ty.__name__),
+        #    )
+        codegen(
+            AttrSource(
+                codegen.tx.import_source(ty.__module__),
+                ty.__name__,
+            )
+        )
+
+    def supports_graph_breaks(self):
+        return True
+
+    def exit_on_graph_break(self):
+        return False


def _get_all_args(args, kwargs):
@@ -231,9 +374,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
-    return tx.output.torch_function_enabled and any(
+    has_overridden_args = any(
        has_torch_function(arg) for arg in _get_all_args(args, kwargs)
    )
+    tf_state = tx.symbolic_torch_function_state
+    return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
+        tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
+    )

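An eager counterpart of the predicate's two arms (a sketch; plain tensors carry no __torch_function__ override, so only an active mode can trigger the second arm):

import torch
from torch.overrides import BaseTorchFunctionMode, has_torch_function

x = torch.ones(2)
assert not has_torch_function((x,))  # first arm: no overridden args

with BaseTorchFunctionMode():
    # Second arm: the mode stack is non-empty, so dispatch must consult the
    # mode even though no argument overrides __torch_function__.
    y = torch.mul(x, 2)
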
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
@@ -245,11 +392,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
        _get_subclass_type,
    )

+    types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])
+
+    if tx.symbolic_torch_function_state.in_torch_function_mode():
+        res = tx.symbolic_torch_function_state.call_torch_function_mode(
+            tx, fn, types, args, kwargs
+        )
+        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
+            return res
+
    for arg in overloaded_args:
        res = arg.call_torch_function(
            tx,
            fn,
-            TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
+            types,
            args,
            kwargs,
        )

@@ -82,11 +82,6 @@ def is_forbidden_context_manager(ctx):
        from _pytest.python_api import RaisesContext
        from _pytest.recwarn import WarningsChecker

-        # TODO mlazos: Temporary to get this stack to pass
-        # remove in subsequent PR
-        from torch.overrides import BaseTorchFunctionMode
-
-        f_ctxs.append(BaseTorchFunctionMode)
        f_ctxs.append(RaisesContext)
        f_ctxs.append(WarningsChecker)
    except ImportError:
@@ -413,15 +408,25 @@ class UserDefinedClassVariable(UserDefinedVariable):
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            # import here to avoid an unfortunate circular dependency.
+            from torch.overrides import TorchFunctionMode
+
            from .ctx_manager import GenericContextWrappingVariable
+            from .torch_function import TorchFunctionModeVariable
+
+            if issubclass(
+                self.value, TorchFunctionMode
+            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
+                self.value
+            ):
+                var_cls = TorchFunctionModeVariable
+            else:
+                var_cls = GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
-                self.source, self.value, GenericContextWrappingVariable, {}
+                self.source, self.value, var_cls, {}
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj

        elif is_namedtuple_cls(self.value):
            fields = namedtuple_fields(self.value)
            # check if this is a quasi-namedtuple or a real one

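A sketch of the user code this change targets: constructing and entering a supported mode inside a compiled frame now builds a TorchFunctionModeVariable instead of falling back to the generic context-manager path (the class name here is illustrative):

import torch
from torch.overrides import TorchFunctionMode

class MyMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)

@torch.compile(backend="eager")
def f(x):
    with MyMode():
        return x + 1
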
@@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
-        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
+        if (
+            not torch.compiler.is_dynamo_compiling()
+            and log.isEnabledFor(logging.DEBUG)
+            and config.extended_debug_current_loc
+        ):
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(

@@ -28,6 +28,7 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
+    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
@@ -129,6 +130,10 @@ def cond(pred, true_fn, false_fn, operands):
    if torch.compiler.is_dynamo_compiling():
        return cond_op(pred, true_fn, false_fn, operands)

+    from torch._dynamo.backends.debugging import (
+        make_eager_backend_with_torch_function_mode,
+    )
+
    if isinstance(pred, (bool, int, float)):
        log.warning(
            "Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
@@ -169,12 +174,15 @@ def cond(pred, true_fn, false_fn, operands):
    def _cond_op_wrapper(*args, **kwargs):
        return cond_op(*args, **kwargs)

-    with _set_compilation_env():
-        with torch._dynamo.utils.disable_cache_limit():
-            with _temp_remove_pre_dispatch_torch_function_mode():
-                return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
-                    pred, true_fn, false_fn, operands
-                )
+    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
+        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
+            if metadata_mode:
+                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+            else:
+                backend = "eager"
+            return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
+                pred, true_fn, false_fn, operands
+            )

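The backend wrapper used above comes from this stack's debugging helpers; a standalone usage sketch (assuming the helper re-enters the given mode when the compiled graph runs, as its use here implies):

import torch
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)
from torch.overrides import BaseTorchFunctionMode

# Wrap the eager backend so a mode that was temporarily removed for tracing
# is still active when the traced graph executes.
mode = BaseTorchFunctionMode()
backend = make_eager_backend_with_torch_function_mode(mode)

@torch.compile(backend=backend, fullgraph=True)
def f(x):
    return x.sin() + 1
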
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):

@@ -15,7 +15,11 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
+from torch.fx.experimental.proxy_tensor import (
+    _temp_remove_metadata_torch_function_mode,
+    ProxyTorchDispatchMode,
+    track_tensor_tree,
+)


class WhileLoopOp(HigherOrderOperator):
@@ -113,6 +117,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
        - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.

    """
+    from torch._dynamo.backends.debugging import (
+        make_eager_backend_with_torch_function_mode,
+    )
+
    # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
    # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
@@ -140,9 +147,15 @@ def while_loop(cond_fn, body_fn, carried_inputs):
        return while_loop_op(*args, **kwargs)

    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
-        return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)(
-            cond_fn, body_fn, carried_inputs, additional_inputs
-        )
+        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
+            if metadata_mode:
+                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+            else:
+                backend = "eager"
+            return torch.compile(
+                _while_loop_op_wrapper, backend=backend, fullgraph=True
+            )(cond_fn, body_fn, carried_inputs, additional_inputs)

@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)

@@ -17,7 +17,7 @@ import typing_extensions
import warnings
import weakref
from collections import defaultdict
-from contextlib import contextmanager, ExitStack, nullcontext
+from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import (
    Any,
@@ -1084,38 +1084,43 @@ class PythonKeyTracer(Tracer):
            return e


-@contextmanager
-def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]:
-    from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
+def _make_temp_remove_mode_context_manager(
+    mode_ty: Type[TorchFunctionMode],
+) -> Callable[[], _GeneratorContextManager[None]]:
+    @contextmanager
+    def context_manager_fn() -> Generator[None, None, None]:
+        from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode

-    temp_elements = []
-    pre_dispatch_mode = None
+        temp_elements = []
+        removed_mode = None

-    while _len_torch_function_stack() > 0:
-        mode = _pop_mode()
-        if isinstance(mode, PreDispatchTorchFunctionMode):
-            pre_dispatch_mode = mode
-            break
-        else:
-            temp_elements.append(mode)
+        while _len_torch_function_stack() > 0:
+            mode = _pop_mode()
+            if isinstance(mode, mode_ty):
+                removed_mode = mode
+                break
+            else:
+                temp_elements.append(mode)

-    for mode in reversed(temp_elements):
-        _push_mode(mode)
+        for mode in reversed(temp_elements):
+            _push_mode(mode)

-    try:
-        yield
+        try:
+            yield removed_mode

-    finally:
-        if pre_dispatch_mode is not None:
-            count = len(temp_elements)
-            while count > 0:
-                mode = _pop_mode()
-                count -= 1
+        finally:
+            if removed_mode is not None:
+                count = len(temp_elements)
+                while count > 0:
+                    mode = _pop_mode()
+                    count -= 1

-            temp_elements.append(pre_dispatch_mode)
+                temp_elements.append(removed_mode)

-            for mode in reversed(temp_elements):
-                _push_mode(mode)
+                for mode in reversed(temp_elements):
+                    _push_mode(mode)
+
+    return context_manager_fn

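A usage sketch of the factory (mirroring the two instantiations below; BaseTorchFunctionMode stands in for a concrete mode type):

import torch
from torch.overrides import BaseTorchFunctionMode
from torch.fx.experimental.proxy_tensor import (
    _make_temp_remove_mode_context_manager,
)

_temp_remove_base_mode = _make_temp_remove_mode_context_manager(BaseTorchFunctionMode)

with BaseTorchFunctionMode():
    with _temp_remove_base_mode() as removed:
        # The matching mode is plucked off the stack and handed back...
        assert isinstance(removed, BaseTorchFunctionMode)
        assert torch._C._len_torch_function_stack() == 0
    # ...and restored in its original position on exit.
    assert torch._C._len_torch_function_stack() == 1
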
@torch._disable_dynamo
@@ -1230,6 +1235,11 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
        return func(*args, **kwargs)


+_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
+    TorchFunctionMetadataMode
+)
+
+
# This mode is **only** used for pre_dispatch tracing.
# In particular, we need to make sure that autograd/autocast API's
# that do not desugar into dispatcher operators stay in the graph.
@@ -1258,6 +1268,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
        return func(*args, **kwargs)


+_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
+    PreDispatchTorchFunctionMode
+)
+
+
class ProxyTorchDispatchMode(TorchDispatchMode):
    # Ensure this is read-only; this exists only for legacy reasons
    @property

@@ -19,6 +19,7 @@ from torch._higher_order_ops.flex_attention import (
)
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
+    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input
@@ -1035,18 +1036,25 @@ def flex_attention(
    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
-                out, lse = torch.compile(
-                    _flex_attention_hop_wrapper, backend="eager", fullgraph=True
-                )(
-                    query,
-                    key,
-                    value,
-                    score_mod,
-                    block_mask.as_tuple(),
-                    scale,
-                    kernel_options,
-                )
-                if return_lse:
-                    return out, lse * math.log(2)
-                else:
-                    return out
+                with _temp_remove_metadata_torch_function_mode() as metadata_mode:
+                    if metadata_mode:
+                        backend = torch._dynamo.backends.debugging.make_eager_backend_with_torch_function_mode(
+                            metadata_mode
+                        )
+                    else:
+                        backend = "eager"
+                    out, lse = torch.compile(
+                        _flex_attention_hop_wrapper, backend=backend, fullgraph=True
+                    )(
+                        query,
+                        key,
+                        value,
+                        score_mod,
+                        block_mask.as_tuple(),
+                        scale,
+                        kernel_options,
+                    )
+                    if return_lse:
+                        return out, lse * math.log(2)
+                    else:
+                        return out
