Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-03 23:45:05 +08:00)
Compare commits

8 Commits

ciflow/tru...mlazos/tf-
| Author | SHA1 | Date |
|---|---|---|
|  | 1cad5436f6 |  |
|  | 318075f6cb |  |
|  | 6e8b6b11cd |  |
|  | 9655ebd499 |  |
|  | df9ef729fd |  |
|  | 244dc9d802 |  |
|  | 39867e316c |  |
|  | e1db3582d8 |  |
@@ -1,5 +1,5 @@
add_loop_eager,                compile_time_instruction_count, 2834456320,  0.015
add_loop_eager_dynamic,        compile_time_instruction_count, 5528896630,  0.025
add_loop_eager,                compile_time_instruction_count, 3004749893,  0.015
add_loop_eager_dynamic,        compile_time_instruction_count, 5726573328,  0.025
add_loop_inductor,             compile_time_instruction_count, 24146845503, 0.015
add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025
add_loop_inductor_gpu,         compile_time_instruction_count, 22171041650, 0.015
@@ -701,7 +701,7 @@ class CompileTest(TestCase):
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
                ".default([arg3_1, arg2_1, arg1_1, arg0_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
@@ -717,8 +717,8 @@ class CompileTest(TestCase):
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()
        # out = AOTIRunnerUtil.run("cuda", func, (args,))
        # torch.cuda.synchronize()

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
@@ -938,6 +938,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
        else:
            return x - 1

    @make_test
    def test_tensor_size(x):
        fn = torch.Tensor.size
        return fn(x + 1)

    @make_test
    def test_tensor_dim(x):
        fn = torch.Tensor.dim
        return fn(x + 1)

    @make_test
    def test_tensor_is_inference(x):
        if x.is_inference():
@@ -646,10 +646,10 @@ print("arf")
        self.assertExpectedInline(
            munge_shape_guards(record.getMessage()),
            """\
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['y'].size()[0]  # return x + torch.cat([y, z])  # #:# in # #:# in #
+- LAMBDA_GUARD: L['z'].size()[0] == L['y'].size()[0]  # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['y'].size()[0], 3), 0)  # if x.size(0) % 3 == 0:  # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['y'].size()[0]  # return x + torch.cat([y, z])  # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""",  # noqa: B950
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0]  # return x + torch.cat([y, z])  # #:# in # #:# in #
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0]  # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0)  # if x.size(0) % 3 == 0:  # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['z'].size()[0]  # return x + torch.cat([y, z])  # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""",  # noqa: B950
        )

    @make_logging_test(guards=True)
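The guard messages above name two ways to avoid these specializations. A minimal sketch of applying them, not from this diff; the config attribute and the mark_unbacked call are quoted from the guard text itself and are internal, unstable APIs:

import torch

# Assumed knobs, quoted from the guard messages above (internal APIs, may change):
torch.fx.experimental._config.use_duck_shape = False  # opt out of duck sizing
x = torch.randn(6)
torch._dynamo.mark_unbacked(x, 0)  # avoid 0/1 specialization on dim 0 of this input

@torch.compile(dynamic=True)
def f(t):
    return t + 1

f(x)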
@@ -1,4 +1,6 @@
# Owner(s): ["module: dynamo"]

import operator
from unittest.mock import patch

import torch
@@ -10,6 +12,7 @@ from torch._C import (
    _push_on_torch_function_stack,
)
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
from torch.testing._internal.triton_utils import requires_cuda
from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
@@ -107,70 +110,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
            fn(inp)
        self.assertEqual(cnt.frame_count, 4)

    def _run_ignored_mode_types_test(self):
        class IgnoredMode(BaseTorchFunctionMode):
            pass

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt.__call__, fullgraph=True)
        def fn(x):
            return x + 1

        inp = torch.ones(2, 2)

        with patch(
            "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
        ):
            # initial compile
            fn(inp)

            # no recompile, mode ignored
            # note: the ref stack is length 0, and the stack we are checking against has length 2
            # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
            with IgnoredMode(), IgnoredMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 1)

            # recompile due to new mode on the stack
            with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 2)

            # recompile
            # tests both ref stack len > runtime stack len for the above guard check
            # and ref stack len < runtime stack len for the initial zero mode case
            with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 3)

            # no recompile
            with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 3)

        # This is tricky, basically the ignored modes are baked into the guard
        # IgnoredMode will be ignored forever by that guard.
        # This is okay since we don't expect to be modifying IGNORED_MODES
        # in the middle of execution except for the purposes of testing.
        torch._dynamo.reset()

        with IgnoredMode():
            fn(inp)

        self.assertEqual(cnt.frame_count, 4)

    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_ignored_types_py(self):
        self._run_ignored_mode_types_test()

    def test_torch_function_mode_guards_ignored_types_cpp(self):
        self._run_ignored_mode_types_test()

    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_py(self):
        self._run_torch_function_mode_guard_test()
@@ -461,6 +400,205 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):

        self.assertEqual(expected, actual)

    def test_torch_function_mode_enter_exit(self):
        def fn(x, y):
            with TestMode():
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn, fullgraph=True)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_graph_break(self):
        def fn(x, y):
            with TestMode():
                torch._dynamo.graph_break()
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_and_pop_graph_break(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_restore_on_exc(self):
        @torch._dynamo.disable()
        def err():
            raise RuntimeError("test")

        @torch.compile()
        def fn(x):
            with TestMode():
                x += 1
                err()
                x += 2
                return x

        try:
            fn(torch.ones(2, 2))
        except RuntimeError:
            pass
        self.assertEqual(_len_torch_function_stack(), 0)

    def test_torch_function_mode_and_pop_graph_break_mutation(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                z.y = 5
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)
                o = torch.mul(o, z.y)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    # Needs larger cache size since we recompile for each op
    @patch.object(torch._dynamo.config, "cache_size_limit", 48)
    def test_builtin_equivalent_funcs(self):
        from torch._dynamo.variables.torch_function import (
            bin_int_ops,
            bin_ops,
            BUILTIN_TO_TENSOR_FN_MAP,
            BUILTIN_TO_TENSOR_RFN_MAP,
            tensor_and_int_ops,
            un_int_ops,
            un_ops,
        )

        expected_func = None
        valid = False

        class FuncEquivMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal expected_func
                nonlocal valid
                if not kwargs:
                    kwargs = {}
                if torch._dynamo.is_compiling():
                    valid = expected_func == func
                return super().__torch_function__(func, types, args, kwargs)

        inp0 = torch.ones(1, 1)
        inp1 = torch.ones(1, 1)
        inp0_int = torch.ones(1, 1, dtype=torch.int32)
        inp1_int = torch.ones(1, 1, dtype=torch.int32)

        @torch.compile(fullgraph=True)
        def fn_un(op, inp):
            return op(inp)

        @torch.compile(fullgraph=True)
        def fn_un_int(op, inp):
            return op(inp)

        @torch.compile(fullgraph=True)
        def fn_bin(op, inp0, inp1):
            return op(inp0, inp1)

        @torch.compile(fullgraph=True)
        def fn_bin_int(op, inp0, inp1):
            return op(inp0, inp1)

        @torch.compile(fullgraph=True)
        def fn_tensor_and_int(op, inp0, inp1):
            return op(inp0, inp1)

        setups_and_oplists = [
            (lambda o: fn_un(o, inp0), un_ops),
            (lambda o: fn_un_int(o, inp0_int), un_int_ops),
            (lambda o: fn_bin(o, inp0, inp1), bin_ops),
            (lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
            (lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
        ]

        # gather the reverse functions
        rsetups_and_oplists = [
            (
                lambda o: fn_bin(o, 1, inp1),
                bin_ops,
            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
            (lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
            (lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
        ]

        skips = {operator.not_}  # Has local scalar dense call which graph breaks
        rskips = {
            operator.matmul,
            operator.imatmul,
            operator.getitem,
        }  # Doesn't type check with reversed args

        def run_checks(setups_and_oplists, skips, ref_map):
            nonlocal valid
            nonlocal expected_func
            for setup_fn, op_list in setups_and_oplists:
                for op in op_list:
                    if op in skips or op not in ref_map:
                        continue
                    with FuncEquivMode():
                        expected_func = ref_map[op]
                        setup_fn(op)
                        self.assertTrue(valid)

                    expected_func = None
                    valid = False

        run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
        run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)

    @requires_cuda
    def test_flex_attention(self):
        import torch
        from torch.nn.attention.flex_attention import create_block_mask, flex_attention

        torch.set_default_device("cuda")

        flex_attention = torch.compile(flex_attention, dynamic=False)

        prefix_lengths = torch.arange(8)

        def prefix_lm(b, h, q, kv):
            return prefix_lengths[b] >= kv

        # This runs in fullgraph already
        mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
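A small standalone sketch in the spirit of FuncEquivMode above, runnable in eager; it only records which funcs __torch_function__ observes and makes no claim about what Dynamo passes under torch.compile:

import torch
from torch.overrides import BaseTorchFunctionMode


class RecordingMode(BaseTorchFunctionMode):
    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.seen.append(func)  # record every torch function intercepted by the mode
        return super().__torch_function__(func, types, args, kwargs or {})


mode = RecordingMode()
with mode:
    torch.ones(2, 2) + 1
print(mode.seen)  # the torch functions hit while the mode was active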
@@ -672,7 +672,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
        wrapped2 = y.as_subclass(SigmoidToExpSubclass)

        def fn(w):
            return w.sigmoid()
            return w.exp()

        fn_opt = compile_full_eager(fn)
@@ -683,6 +683,38 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(res_exp, res_act)
        self.assertEqual(res_exp, res_exp2)

    def test_torch_function_call_on_method_arg(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if func == torch._C.TensorBase.add_:
                    func = torch._C.TensorBase.sub_

                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            def sigmoid(self):
                return None

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        z = torch.ones(2, 2)
        wrapped = y.as_subclass(LocalSubclass)
        wrapped2 = z.as_subclass(LocalSubclass)

        def fn(a, w):
            a.add_(w)
            return a

        fn_opt = torch.compile(fn)

        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
            res_exp = fn(x, wrapped)
            res_act = fn_opt(y, wrapped2)

        self.assertEqual(res_exp, res_act)

    def test_user_overidden_method_unsupported(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
@@ -49,9 +49,9 @@ def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1));  strict_graph_0 = sin = b_submodule_buffer1 = None
    getitem_2 = strict_mode[0];  strict_mode = None
    getitem = strict_mode[0];  strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3);  x = None
    return (getitem_2, add)""",
    return (getitem, add)""",
        )

        self.assertExpectedInline(
@@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfCrossRef,
    TEST_TRANSFORMERS,
    TestCase as TorchTestCase,
)
@@ -6989,6 +6990,7 @@ def forward(self, x):
        real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
        self.assertEqual(expected_names_and_ops, real_names_and_ops)

    @skipIfCrossRef  # Dynamo changes the order of ops under Torch function modes
    def test_placeholder_naming_collisions_hoo_subgraphs(self):
        # test collisions between user inputs, top-level nodes, and HOO subgraph nodes
        class Foo(torch.nn.Module):
@@ -8325,6 +8327,7 @@ class TestOneOffModelExportResult(TestCase):
    #     getitem = _scaled_dot_product_flash_attention_for_cpu[0];  _scaled_dot_product_flash_attention_for_cpu = None
    #     return (getitem,)""")

    @skipIfCrossRef
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Can't run fused SDPA on this platform",
@@ -4902,6 +4902,7 @@ def forward(self, arg0_1, arg1_1):
    return [getitem]""",  # noqa: B950
        )

    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
        def true_fn(x):
            return x + x.cos()
@@ -5252,6 +5253,7 @@ def forward(self, arg0_1):
        ):
            torch.cond(inp.sum() > 0, f, f, (inp, tmp))

    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_trace_set__and_mutate_intermediate(self):
        def f(a, tmp):
            a = a.clone()
@@ -180,12 +180,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",  # noqa: B950
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@@ -239,7 +237,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
        return (getitem_4, getitem_5)""",  # noqa: B950
@@ -402,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1);  arg3_1 = arg4_1 = arg1_1 = foo_default = None
        foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1);  arg4_1 = arg5_1 = arg1_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1);  arg5_1 = copy__1 = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1);  arg3_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
@@ -414,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = foo_default = None
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1);  arg3_1 = arg4_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1);  arg4_1 = copy__1 = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
@@ -503,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = None
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1);  arg3_1 = arg4_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None

        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1);  arg4_1 = copy__1 = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy__1 = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
@@ -67,7 +67,7 @@ class GuardManager:
    ) -> None: ...
    def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
    def add_torch_function_mode_stack_guard(
        self, initial_stack, ignored_types, verbose_code_parts: list[str]
        self, initial_stack, verbose_code_parts: list[str]
    ) -> None: ...

class RootGuardManager(GuardManager):
@@ -1,15 +1,22 @@
# mypy: allow-untyped-defs
from typing import Dict, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only


Tensor = torch.Tensor


__all__ = ["trace_wrapped"]
@@ -43,6 +50,27 @@ __all__ = ["trace_wrapped"]
# compiled autograd do we inline into the function.


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this torchfunctionmode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(
        self,
        func: OpOverload,
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))


def trace_wrapped(*args, **kwargs):
    with torch.no_grad():
        return _trace_wrapped_op(*args, **kwargs)
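A self-contained sketch of the interception pattern TransformGetItemToIndex uses; the mode name is made up and it only logs __getitem__ calls instead of rerouting them to aten.index:

import torch
from torch.overrides import TorchFunctionMode


class LogGetItem(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func is torch.Tensor.__getitem__:
            print("indexing a tensor of shape", tuple(args[0].shape), "with", args[1])
        return func(*args, **(kwargs or {}))


with LogGetItem():
    t = torch.arange(4)
    print(t[torch.tensor(1)])  # the mode sees the __getitem__ call before it runs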
@@ -32,13 +32,23 @@ def eager(gm, fake_tensor_inputs, **kwargs):


def make_eager_backend_with_torch_function_mode(mode):
    return make_eager_backend_with_torch_function_modes([mode])


def make_eager_backend_with_torch_function_modes(modes):
    """Used to trace HOPs (cond and while) for eager exectution, the metadata
    TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
    in the HOP, so we need to externally run this mode and not trace it."""
    from contextlib import ExitStack

    def fn(gm, fake_tensor_inputs, **kwargs):
        with mode:
            return gm.forward
        stack = ExitStack()
        for mode in modes:
            stack.enter_context(mode)

        result = gm.forward
        stack.close()
        return result

    return fn
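A plain-Python illustration of the ExitStack pattern the new backend uses (class names are invented): enter every mode in order, produce the result, then unwind them all in reverse:

from contextlib import ExitStack


class DemoMode:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        print("enter", self.name)
        return self

    def __exit__(self, *exc):
        print("exit", self.name)
        return False


def run_with_modes(modes, fn):
    stack = ExitStack()
    for mode in modes:
        stack.enter_context(mode)  # entered in the given order
    result = fn()
    stack.close()  # exited in reverse order
    return result


print(run_with_modes([DemoMode("a"), DemoMode("b")], lambda: 42))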
@@ -120,6 +120,7 @@ from .utils import (
    troubleshooting_url,
    write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr


np: Optional[ModuleType]
@@ -218,15 +219,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()

            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            exit_stack.enter_context(torch_function_mode_stack_state_mgr)
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                assert (
                    torch._C._len_torch_function_stack() == 0
                ), "Torch function mode stack state changed while dynamo tracing, please report a bug"
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
@@ -2356,15 +2356,12 @@ class CheckFunctionManager:
        )

        if config.enable_cpp_guard_manager:
            from .variables.torch_function import IGNORED_MODES

            # Insert the global_state guard
            assert self.guard_manager  # to make mypy happy
            self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

            self.guard_manager.root.add_torch_function_mode_stack_guard(
                self.torch_function_mode_stack,
                list(IGNORED_MODES),
                ["___check_torch_function_mode_stack()"],
            )
            # Clear references to torch_function modes held in the list
@@ -2671,18 +2668,14 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
    types = [type(x) for x in intial_stack]
    from .variables.torch_function import IGNORED_MODES

    def check_torch_function_mode_stack():
        cur_stack = get_torch_function_mode_stack()

        types_ = [ty for ty in types if ty not in IGNORED_MODES]
        cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES]

        if len(cur_stack_) != len(types_):
        if len(cur_stack) != len(types):
            return False

        for ty, mode in zip(types_, cur_stack_):
        for ty, mode in zip(types, cur_stack):
            if ty != type(mode):
                return False
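A pure-Python sketch of the simplified guard above once the IGNORED_MODES filtering is gone: record the mode types when the guard is built, then compare the live stack type for type:

def make_stack_guard(initial_stack):
    types = [type(m) for m in initial_stack]

    def check(current_stack):
        if len(current_stack) != len(types):
            return False
        return all(ty is type(m) for ty, m in zip(types, current_stack))

    return check


# stand-in objects; the real guard stores TorchFunctionMode instances
guard = make_stack_guard([1, "x"])
print(guard([2, "y"]), guard(["y", 2]), guard([1]))  # True False False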
@@ -78,7 +78,6 @@ from .utils import (
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_torch_function_mode_stack,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
@@ -250,6 +249,7 @@ class OutputGraph:
        local_scope: Scope,
        global_scope: Scope,
        f_code,
        torch_function_mode_stack,
    ):
        super().__init__()
        self.tracers = [SubgraphTracer(self, export_root=export)]
@@ -368,7 +368,7 @@ class OutputGraph:
        # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
        # This records the initial torch function mode stack for guarding
        self.torch_function_mode_stack = get_torch_function_mode_stack()
        self.torch_function_mode_stack = torch_function_mode_stack

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fallback to eager
@@ -1021,7 +1021,7 @@ class OutputGraph:
            prefix_insts.clear()

        for block in reversed(tx.block_stack):
            block.exit(tx)
            block.exit(tx, is_graph_break=reason.graph_break)

        self.cleanup_graph()
        tx.prune_dead_locals()
@@ -25,6 +25,26 @@ if TYPE_CHECKING:
        sys as sys,
    )

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
#  and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass


def index(iterator, item, start=0, end=None):
    from itertools import islice
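A toy demonstration of the "no-op enter, real exit" behavior the comment above describes; the class is hypothetical, and the private torch._C helpers are the same ones the tests earlier in this diff import:

import torch
from torch._C import _len_torch_function_stack, _push_on_torch_function_stack
from torch.overrides import BaseTorchFunctionMode


class NoEnterDemo(BaseTorchFunctionMode):
    def __enter__(self):  # no-op: the mode is assumed to already be on the stack
        return self


m = NoEnterDemo()
_push_on_torch_function_stack(m)  # simulate the replayed stack state before a resume function
with m:                           # enter does nothing...
    pass
print(_len_torch_function_stack())  # ...but the inherited exit popped the mode: prints 0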
@@ -90,27 +90,25 @@ class ReenterWith:
    stack_index: int
    target_values: Optional[Tuple[Any, ...]] = None

    # TODO(mlazos) - Uncomment with the reland of torch function mode support
    # def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
    #     """
    #     Codegen based off of:
    #     try:
    #         (rest)
    #     except:
    #         (restore previous tf mode stack)
    #         raise
    def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
        """
        Codegen based off of:
        try:
            (rest)
        except:
            (restore previous tf mode stack)
            raise
        """
        from .variables.torch_function import get_prev_stack_var_name

    #     """
    #     from .variables.torch_function import get_prev_stack_var_name
        setup_try_except, epilogue = _bytecode_from_template_with_split(
            _try_except_tf_mode_template,
            self.stack_index,
            varname_map={"stack_var_name": get_prev_stack_var_name()},
        )
        cleanup[:] = epilogue + cleanup

    #     setup_try_except, epilogue = _bytecode_from_template_with_split(
    #         _try_except_tf_mode_template,
    #         self.stack_index,
    #         varname_map={"stack_var_name": get_prev_stack_var_name()},
    #     )
    #     cleanup[:] = epilogue + cleanup

    #     return setup_try_except
        return setup_try_except

    # If we do not want to destroy the stack, we can do the same thing as a
    # `SETUP_WITH` block, only that we store the context manager in a local_symbol
@@ -629,11 +629,22 @@ class SideEffects:
            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                # Needed in the finally block for stack restoration
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "get_torch_function_mode_stack"
                    )
                )
                cg.call_function(0, False)
                name = variables.torch_function.get_prev_stack_var_name()
                cg.code_options["co_varnames"] += (name,)
                cg.append_output(create_instruction("STORE_FAST", argval=name))
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )

                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
@@ -267,13 +267,12 @@ class BlockStackEntry:
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        if hasattr(self, "graph_break") and isinstance(
            self.with_context, TorchFunctionModeVariable
        ):
            return
    def exit(self, tx, is_graph_break):
        assert self.with_context is not None
        return self.with_context.exit(tx)
        if (
            is_graph_break and self.with_context.exit_on_graph_break()
        ) or not is_graph_break:
            return self.with_context.exit(tx)


class ReturnValueOp(Exception):
@@ -657,10 +656,17 @@ def break_graph_if_unsupported(*, push):
            cleanup: List[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                # Don't exit any modes we have entered,
                # output bytecode will mutate the tf mode stack accordingly
                if isinstance(b.with_context, TorchFunctionModeVariable):
                    cg.extend_output(
                        b.resume_fn().try_except_torch_function_mode(
                            cg.code_options, cleanup
                        )
                    )
                    continue
                assert b.with_context is not None
                assert isinstance(
                    b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
                )
                assert isinstance(b.with_context, (ContextWrappingVariable))
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
@@ -2314,7 +2320,10 @@ class InstructionTranslatorBase(
        ):
            unimplemented(f"{inst.opname} {ctx}")

        if isinstance(ctx, GenericContextWrappingVariable):
        if (
            isinstance(ctx, GenericContextWrappingVariable)
            and not ctx.supports_graph_breaks()
        ):
            self.generic_context_manager_depth += 1

        # Need this redundant check for mypy
@@ -2687,6 +2696,7 @@ class InstructionTranslator(InstructionTranslatorBase):
                local_scope=f_locals,
                global_scope=f_globals,
                f_code=f_code,
                torch_function_mode_stack=torch_function_mode_stack,
            ),
            instructions=instructions,
            f_locals=f_locals,
@@ -187,6 +187,7 @@ def debug_insert_nops(
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
        torch_function_mode_stack=[],
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
@@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
    "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
    "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
    "torch.set_default_device": UserFunctionVariable,
    "torch.sparse_bsc_tensor": SkipFunctionVariable,
    "torch.sparse_bsr_tensor": SkipFunctionVariable,
    "torch.sparse_csc_tensor": SkipFunctionVariable,
@@ -2802,7 +2803,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
        "torch.random.initial_seed",
        "torch.random.seed",
        "torch.return_types.pytree_register_structseq",
        "torch.set_default_device",
        "torch.set_default_dtype",
        "torch.set_default_tensor_type",
        "torch.set_deterministic_debug_mode",
@@ -2912,6 +2912,9 @@ def get_tensor_method():
            method, (types.MethodDescriptorType, types.WrapperDescriptorType)
        ):
            s.add(method)

    # mlazos: this is a function which we handle specially in TensorVariable
    s.add(torch.Tensor.__contains__)  # type: ignore[arg-type]
    return frozenset(s)
@@ -2912,18 +2912,28 @@ def is_torch_function_object(value):


def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool:
    from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
    from torch._dynamo.variables import UserDefinedObjectVariable
    from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable

    if isinstance(vt, TensorWithTFOverrideVariable):
        return True
    # Note on lazy vars: The value will either be realized or not throughout the course of execution
    # if the value has a torch function, it will eventually be realized so we can realize it here
    # if the value does not have a torch function, it may or may not be realized
    # if it is realized it will be used and guards will be installed properly
    # if it is not used, guards won't be installed, and it doesn't matter
    # if the value has a torch function or not, so we should *not* realize it.
    # NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method
    # but mypy does not unfortunately
    if vt.is_realized() or (
        hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__")
    ):
        if isinstance(vt, TensorWithTFOverrideVariable):
            return True

    if isinstance(vt, LazyVariableTracker):
        LazyVariableTracker.realize(vt)
        return isinstance(vt, UserDefinedObjectVariable) and hasattr(
            vt.value, "__torch_function__"
        )

    return isinstance(vt, UserDefinedObjectVariable) and hasattr(
        vt.value, "__torch_function__"
    )
    return False


# see note [Tensor Fakification and Symbol Caching]
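A generic, non-PyTorch illustration of the "peek without realizing" idea in the note above (all names hypothetical): inspect a lazily built value's raw payload without forcing construction.

class LazyBox:
    def __init__(self, raw):
        self._raw = raw
        self._built = None

    def is_realized(self):
        return self._built is not None

    def peek_value(self):  # look at the raw payload without building anything
        return self._raw

    def realize(self):
        if self._built is None:
            self._built = {"wrapped": self._raw}  # stand-in for the expensive build step
        return self._built


box = LazyBox(42)
print(hasattr(box.peek_value(), "__torch_function__"), box.is_realized())  # False False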
@@ -3116,16 +3126,10 @@ def is_parameter_freezing():
    return torch._inductor.config.freezing and not torch.is_grad_enabled()


def get_torch_function_mode_stack(filter_ignored=True):
    from .variables.torch_function import IGNORED_MODES

    stack = [
def get_torch_function_mode_stack():
    return [
        get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
    ]
    if filter_ignored:
        stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]

    return stack


def get_torch_function_mode_stack_at(ind):
@@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
    build_torch_function_fn,
    TensorWithTFOverrideVariable,
    torch_function_mode_stack_state_mgr,
    TorchFunctionModeVariable,
)
from .user_defined import (
@@ -1669,15 +1670,16 @@ class VariableBuilder:
                # but warning is not the end of the world
                assert isinstance(value.base, np.nditer)

        try:
            tensor_value = _util._try_convert_to_tensor(value)
            if readonly:
                from torch._prims_common import clone_preserve_strides
        with torch_function_mode_stack_state_mgr.temp_restore_stack():
            try:
                tensor_value = _util._try_convert_to_tensor(value)
                if readonly:
                    from torch._prims_common import clone_preserve_strides

                tensor_value = clone_preserve_strides(tensor_value)
        except NotImplementedError as e:
            # failed to convert to tensor, graph break
            unimplemented(str(e))
                    tensor_value = clone_preserve_strides(tensor_value)
            except NotImplementedError as e:
                # failed to convert to tensor, graph break
                unimplemented(str(e))

        # We do this because we want the full behavior of guarding the numpy ndarray as if it were
        # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
@@ -200,7 +200,6 @@ class BuiltinVariable(VariableTracker):
            operator.ne,
            operator.eq,
            operator.sub,
            operator.getitem,
            operator.length_hint,
            operator.lshift,
            operator.rshift,
@@ -212,6 +211,7 @@ class BuiltinVariable(VariableTracker):
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.getitem,
            operator.imod,
            operator.iadd,
            operator.isub,
@ -858,6 +858,39 @@ class BuiltinVariable(VariableTracker):
        if kwargs and not self.tensor_args(*args, *kwargs.values()):
            return

        # insert handling for torch function here
        from .builder import SourcelessBuilder
        from .torch_function import (
            BUILTIN_TO_TENSOR_FN_MAP,
            BUILTIN_TO_TENSOR_RFN_MAP,
            can_dispatch_torch_function,
            dispatch_torch_function,
        )

        if can_dispatch_torch_function(tx, args, kwargs):
            # Only remap the fn to tensor methods if we aren't exporting
            # export serde does not handle method descriptors today
            if not tx.export:
                # Use sourceless builder, we built the map ourselves
                if not isinstance(args[0], TensorVariable):
                    if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
                        func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn]
                    else:
                        func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]

                    tmp = args[0]
                    # swap args and call reverse version of func
                    args[0] = args[1]
                    args[1] = tmp
                else:
                    func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
            else:
                func = self.fn

            fn_var = SourcelessBuilder.create(tx, func)

            return dispatch_torch_function(tx, fn_var, args, kwargs)

        fn = self.fn
        try:
            # Constant fold for constant tensor and python constants

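
Illustrative sketch (not part of the patch): the remapping above rewrites a builtin applied to a tensor into the corresponding torch.Tensor method, and when the tensor is not in the first argument position it swaps the arguments and uses the reflected method, matching Python's own r-op protocol. The exact entries of BUILTIN_TO_TENSOR_FN_MAP/RFN_MAP are populated elsewhere in this patch; the check below only shows the eager-mode equivalence being relied on.

import operator
import torch

t = torch.arange(3.0)

# `1 - t` reaches Dynamo as operator.sub(1, t); with no tensor first it is
# remapped to the reflected Tensor method with the arguments swapped.
assert torch.equal(operator.sub(1, t), torch.Tensor.__rsub__(t, 1))

# With a tensor first, the plain Tensor method is used directly.
assert torch.equal(operator.add(t, 1), torch.Tensor.add(t, 1))
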
@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return True


class GenericContextWrappingVariable(UserDefinedObjectVariable):
    # Some methods in ContextWrappingVariable assumes the arguments are
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
        tx.generic_context_manager_depth -= 1
        return x

    def supports_graph_breaks(self):
        return False

    def exit_on_graph_break(self):
        return True


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch grad requries grad"""

@ -1998,8 +1998,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
        fn: "VariableTracker",
        fn_name: str,
    ):
        from torch._higher_order_ops.flex_attention import TransformGetItemToIndex

        from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
        from .builder import SourcelessBuilder

        tx: InstructionTranslator = tx

@ -80,6 +80,14 @@ class LazyVariableTracker(VariableTracker):
            self.realize()
        return VariableTracker.clone(self.unwrap(), **kwargs)

    def peek_type(self) -> type[Any]:
        assert not self.is_realized()
        return type(self._cache.value)

    def peek_value(self) -> Any:
        assert not self.is_realized()
        return self._cache.value

    def __str__(self) -> str:
        if self.is_realized():
            return self.unwrap().__str__()
@ -510,9 +510,37 @@ class TensorVariable(VariableTracker):
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import SourcelessBuilder, VariableBuilder
        from .torch_function import can_dispatch_torch_function, dispatch_torch_function

        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal method invocation {name} in strict mode")

        # Only override builtin tensor methods
        # The user can manually add override handling
        # with a decorator for other methods (e.g. a dispatch subclass with other methods)
        has_torch_function_override = False
        try:
            inspect.getattr_static(torch.Tensor, name)
            has_torch_function_override = True
        except AttributeError:
            has_torch_function_override = False

        if (
            can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
            and has_torch_function_override
        ):
            if self.source:
                func_var = VariableBuilder(
                    tx, AttrSource(AttrSource(self.source, "__class__"), name)
                )(inspect.getattr_static(torch.Tensor, name))
            else:
                func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))

            return dispatch_torch_function(
                tx, func_var, tuple([self] + list(args)), kwargs
            )

        """
        Dispatch to a method-specific handler defined below.  If the
        handler returns None (or doesn't exist) we put the method call
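
Eager-mode sketch (not from the diff) of the behavior the branch above models: builtin tensor methods dispatch through an active TorchFunctionMode, so a method call on a TensorVariable is routed to dispatch_torch_function when such a mode (or a subclass override) is in play. Tensor.size is expected to dispatch this way as well, which is why only methods found on torch.Tensor via getattr_static are redirected.

import torch
from torch.overrides import TorchFunctionMode

class LogMethods(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("saw", getattr(func, "__name__", func))
        return func(*args, **kwargs)

x = torch.ones(2, 3)
with LogMethods():
    x.add(1)   # prints "saw add"; the builtin method goes through the mode
    x.size()   # metadata methods are expected to dispatch here too
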
@ -772,6 +800,30 @@ class TensorVariable(VariableTracker):
            self._warn_capture_scalar_outputs()
            unimplemented("Tensor.item")

    def method___getitem__(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator
        from .builder import wrap_fx_proxy

        tx = InstructionTranslator.current_tx()
        if isinstance(args[0], SymNodeVariable):
            # Standard indexing will force specialization due to
            # __index__.  Rewrite as a regular torch op which will
            # trace fine
            fn, args = torch.select, [
                variables.ConstantVariable.create(0),
                args[0],
            ]
        else:
            fn = operator.getitem

        proxy = tx.output.create_proxy(
            "call_function",
            fn,
            *proxy_args_kwargs([self] + list(args), kwargs),
        )

        return wrap_fx_proxy(tx, proxy)

    @staticmethod
    @functools.lru_cache(None)
    def _warn_capture_scalar_outputs():

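
For reference, a small eager check (not part of the diff) of the rewrite used in method___getitem__ above: for an index on dim 0, torch.select(x, 0, i) returns the same value as x[i], so indexing by a SymNodeVariable can be re-expressed as torch.select without changing semantics while avoiding the specializing __index__ conversion.

import torch

x = torch.arange(12).reshape(3, 4)
i = 2
assert torch.equal(x[i], torch.select(x, 0, i))
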
@ -159,7 +159,17 @@ def get_overridable_functions():

    from torch.overrides import get_overridable_functions as get_overridable_functions_

    return set(chain(*get_overridable_functions_().values()))
    funcs = set(chain(*get_overridable_functions_().values()))
    more = {
        torch.ones,
        torch.ones_like,
        torch.zeros,
        torch.zeros_like,
        torch.empty,
        torch.full,
    }
    funcs.update(more)
    return funcs


class BaseTorchVariable(VariableTracker):
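
The factory functions added to the overridable set above are precisely the kind of calls a TorchFunctionMode can redirect. A standalone eager example (independent of this patch) is the device context manager, which installs a DeviceContext mode that intercepts factory calls:

import torch

# torch.device(...) used as a context manager installs a DeviceContext
# (a TorchFunctionMode) that reroutes factory functions such as torch.ones.
with torch.device("meta"):
    t = torch.ones(2, 3)

print(t.device)  # meta
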
@ -835,6 +845,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                len(tx.symbolic_torch_function_state.mode_stack)
            )

        @register(torch._C._get_function_stack_at)
        def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
            assert len(args) == 1 and not kwargs
            ind = args[0].as_python_constant()
            assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
            return tx.symbolic_torch_function_state.mode_stack[ind]

        @register(torch.set_default_device)
        def handle_set_default_device(
            self, tx: "InstructionTranslator", *args, **kwargs
@ -852,7 +869,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            else:
                TorchFunctionModeStackVariable.register_device_context_insertion(tx)

            return None
            return ConstantVariable.create(None)

        return handlers

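
The handlers above mirror the private eager introspection hooks for the torch-function mode stack. A quick sketch of those hooks outside of Dynamo (private APIs, shown only for orientation):

import torch
from torch.overrides import TorchFunctionMode, _get_current_function_mode_stack

class A(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

with A(), A():
    print(torch._C._len_torch_function_stack())               # 2
    print(type(torch._C._get_function_stack_at(0)).__name__)  # A
    print(len(_get_current_function_mode_stack()))            # 2
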
@ -883,6 +900,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                ),
            )

        if self.is_tensor_method():
            return self.call_tensor_method(tx, args, kwargs)

        special_handler = self._get_handlers().get(self.value)
        if special_handler:
            result = special_handler(self, tx, *args, **kwargs)
@ -1155,6 +1175,16 @@ Either create the tensor outside the compiled region, or do not set the tensor t
        )
        return result

    def call_tensor_method(self, tx, args, kwargs):
        return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)

    def is_tensor_method(self):
        return (
            inspect.ismethoddescriptor(self.get_function())
            and hasattr(self.get_function(), "__objclass__")
            and self.get_function().__objclass__ == torch._C.TensorBase
        ) or self.get_function() is torch.Tensor.__contains__

    def torch_function_override_enabled(self, tx, args, kwargs):
        return (
            self.get_function() in get_overridable_functions()

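
A quick eager sanity check (not from the diff) of the is_tensor_method predicate above: builtin tensor methods are method descriptors whose __objclass__ is expected to be torch._C.TensorBase.

import inspect
import torch

fn = torch.Tensor.dim
print(inspect.ismethoddescriptor(fn))          # True
print(fn.__objclass__ is torch._C.TensorBase)  # True for builtin tensor methods
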
@ -2,22 +2,37 @@

import collections
import contextlib
import functools
import inspect
import operator
from typing import Deque, Dict, List, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
    _get_overloaded_args,
    BaseTorchFunctionMode,
    get_default_nowrap_functions,
    TorchFunctionMode,
)
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
    class_has_getattribute,
    clear_torch_function_mode_stack,
    get_safe_global_name,
    has_torch_function,
    is_tensor_base_attr_getter,
    set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -49,6 +64,125 @@ if TYPE_CHECKING:

# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py

bin_ops = [
    operator.pow,
    operator.mul,
    operator.matmul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    operator.ne,
    operator.eq,
    operator.sub,
    operator.ipow,
    operator.imul,
    operator.imatmul,
    operator.ifloordiv,
    operator.itruediv,
    operator.imod,
    operator.iadd,
    operator.isub,
]

bin_int_ops = [
    operator.and_,
    operator.or_,
    operator.xor,
    operator.iand,
    operator.ixor,
    operator.ior,
]

un_int_ops = [operator.invert]

tensor_and_int_ops = [
    operator.lshift,
    operator.rshift,
    operator.ilshift,
    operator.irshift,
    operator.getitem,
]

un_ops = [
    operator.abs,
    operator.pos,
    operator.neg,
    operator.not_,  # Note: this has a local scalar dense call
    operator.length_hint,
]

BUILTIN_TO_TENSOR_FN_MAP = {}

# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position,
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}


def populate_builtin_to_tensor_fn_map():
    global BUILTIN_TO_TENSOR_FN_MAP

    most_recent_func = None

    class GetMethodMode(BaseTorchFunctionMode):
        """
        Mode to extract the correct methods from torch function invocations
        (Used to get the correct torch.Tensor methods from builtins)
        """

        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            nonlocal most_recent_func
            most_recent_func = func
            return func(*args, **kwargs)

    inp0 = torch.ones(1)
    inp1 = torch.ones(1)
    inp0_int = torch.ones(1, dtype=torch.int32)
    inp1_int = torch.ones(1, dtype=torch.int32)
    with GetMethodMode():
        setups_and_oplists = [
            (lambda o: o(inp0), un_ops),
            (lambda o: o(inp0_int), un_int_ops),
            (lambda o: o(inp0, inp1), bin_ops),
            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
        ]
        for setup_fn, op_list in setups_and_oplists:
            for op in op_list:
                setup_fn(op)
                assert most_recent_func is not None
                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func

        # gather the reverse functions
        rsetups_and_oplists = [
            (
                lambda o: o(1, inp1),
                bin_ops,
            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
            (lambda o: o(1, inp1_int), bin_int_ops),
            (lambda o: o(0, inp0_int), tensor_and_int_ops),
        ]

        rskips = {operator.matmul, operator.imatmul, operator.getitem}
        for setup_fn, op_list in rsetups_and_oplists:
            for op in op_list:
                if op in rskips:
                    continue
                setup_fn(op)
                assert most_recent_func is not None
                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func


populate_builtin_to_tensor_fn_map()

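
Standalone illustration (not part of the patch) of what the population above captures: with a mode like GetMethodMode active, calling a builtin on tensors records the torch.Tensor method it resolves to, and that method becomes the map entry for the builtin.

import operator
import torch
from torch.overrides import BaseTorchFunctionMode

captured = {}

class Capture(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        captured["func"] = func
        return super().__torch_function__(func, types, args, kwargs or {})

a, b = torch.ones(1), torch.ones(1)
with Capture():
    operator.add(a, b)

print(captured["func"])  # the torch.Tensor method operator.add resolves to
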
banned_attrs = [
    fn.__self__.__name__
@ -56,11 +190,38 @@ banned_attrs = [
    if is_tensor_base_attr_getter(fn)
]

# Today set default device is placed in the graph and guarded on separately
# so we should not trace through it. In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}

@functools.lru_cache(None)
def get_prev_stack_var_name():
    from ..bytecode_transformation import unique_id

    return unique_id("___prev_torch_function_mode_stack")


# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
    def __init__(self):
        self.stack = []

    def __enter__(self):
        self.stack = torch.overrides._get_current_function_mode_stack()
        clear_torch_function_mode_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        set_torch_function_mode_stack(self.stack)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        prev = torch.overrides._get_current_function_mode_stack()
        set_torch_function_mode_stack(self.stack)
        try:
            yield
        finally:
            set_torch_function_mode_stack(prev)


torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()


class SymbolicTorchFunctionState:
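
Rough usage sketch of the manager above (illustrative only; the import path exists only with this change, and the real call sites are inside Dynamo's frame-conversion and builder code): entering the manager saves and clears the user's mode stack for tracing, temp_restore_stack puts it back temporarily, and exit restores it.

import torch
from torch.overrides import TorchFunctionMode, _get_current_function_mode_stack
from torch._dynamo.variables.torch_function import torch_function_mode_stack_state_mgr

class MyMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

with MyMode():
    with torch_function_mode_stack_state_mgr:
        # the user's mode stack has been saved and cleared for tracing
        assert len(_get_current_function_mode_stack()) == 0
        with torch_function_mode_stack_state_mgr.temp_restore_stack():
            # temporarily live again, e.g. while converting ndarrays to tensors
            assert len(_get_current_function_mode_stack()) == 1
    # on __exit__ the original stack is active again
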
@ -189,9 +350,26 @@ class TorchFunctionModeStackVariable(VariableTracker):
        return ind + cls.offset


class TorchFunctionModeVariable(ContextWrappingVariable):
class TorchFunctionModeVariable(GenericContextWrappingVariable):
    @staticmethod
    def is_supported_torch_function_mode(ty):
        # Supported in this sense means we can support graph breaks under the
        # context.
        # We are able to trace custom modes but if there are graph breaks under them
        # and they have a custom __enter__/__exit__ we don't handle this for the
        # same reason we don't handle generic context managers: there may be side effects
        # that are now affected by executing the funtion across two frames instead of one
        # Today we support the enter/exit of the default TorchFunctionMode as well as
        # DeviceContext (which is used for set_default_device)
        return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
            not class_has_getattribute(ty)
            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
        )

    def __init__(self, value, source=None, **kwargs):
        super().__init__(value, **kwargs)
        if value is not None:
            super().__init__(value, **kwargs)
        self.value = value
        self.cm_obj = value  # needed for BC with calling enter from CM code
        self.source = source
@ -221,8 +399,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable):
            kwargs,
        )

    def _call_func(self, tx: "InstructionTranslator", values):
        unimplemented("enter/exit for torch function mode NYI")
    def enter(self, tx):
        from .torch import TorchInGraphFunctionVariable

        if isinstance(self.value, NoEnterTorchFunctionMode):
            return ConstantVariable.create(None)

        TorchInGraphFunctionVariable(
            torch._C._push_on_torch_function_stack
        ).call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        from .torch import TorchInGraphFunctionVariable

        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
            tx, [], {}
        )
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        ty = NoEnterTorchFunctionMode
        codegen(
            AttrSource(
                codegen.tx.import_source(ty.__module__),
                ty.__name__,
            )
        )

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return False


def _get_all_args(args, kwargs):
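
For intuition, the predicate above accepts exactly the modes that do not customize entry/exit. A standalone sketch mirroring that check (ignoring the NoEnterTorchFunctionMode/DeviceContext special cases):

import inspect
from torch.overrides import TorchFunctionMode

class PlainMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

class CustomEnterMode(PlainMode):
    def __enter__(self):
        print("side effect on enter")
        return super().__enter__()

def inherits_default_enter_exit(ty):
    return (
        inspect.getattr_static(ty, "__enter__") is TorchFunctionMode.__enter__
        and inspect.getattr_static(ty, "__exit__") is TorchFunctionMode.__exit__
    )

print(inherits_default_enter_exit(PlainMode))        # True: graph breaks can be handled
print(inherits_default_enter_exit(CustomEnterMode))  # False: treated as a generic ctx manager
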
@ -233,7 +442,6 @@ def _flatten_vts(vts):
    from collections import deque

    from .dicts import ConstDictVariable
    from .lazy import LazyVariableTracker
    from .lists import ListVariable

    vts = deque(vts)
@ -241,13 +449,17 @@ def _flatten_vts(vts):

    while vts:
        vt = vts.pop()
        LazyVariableTracker.realize_all(vt)
        if isinstance(vt, ListVariable):
            vts.extend(vt.items)
        elif isinstance(vt, ConstDictVariable):
            vts.extend(vt.items.values())
        else:
            output.append(vt)

        if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
            vt.realize()

        if vt.is_realized():
            if isinstance(vt, ListVariable):
                vts.extend(vt.items)
            elif isinstance(vt, ConstDictVariable):
                vts.extend(vt.items.values())

        output.append(vt)

    return output

@ -301,8 +513,15 @@ def call_torch_function(


def build_torch_function_fn(tx: "InstructionTranslator", value, source):
    from types import FunctionType

    from .builder import SourcelessBuilder, VariableBuilder

    func = value.__torch_function__.__func__

    if not isinstance(func, FunctionType):
        unimplemented("Builtin/C++ torch function implementations NYI")

    if source:
        return VariableBuilder(
            tx,

@ -413,10 +413,22 @@ class UserDefinedClassVariable(UserDefinedVariable):
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            from torch.overrides import TorchFunctionMode

            from .ctx_manager import GenericContextWrappingVariable
            from .torch_function import TorchFunctionModeVariable

            if issubclass(
                self.value, TorchFunctionMode
            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
                self.value
            ):
                var_cls = TorchFunctionModeVariable
            else:
                var_cls = GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
                self.source, self.value, GenericContextWrappingVariable, {}
                self.source, self.value, var_cls, {}
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj

@ -11,7 +11,7 @@ from torch._higher_order_ops.utils import (
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator, OpOverload
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    make_fx,
@ -19,7 +19,6 @@ from torch.fx.experimental.proxy_tensor import (
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode


# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
@ -69,27 +68,6 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch
    return new_out


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this torchfunctionmode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(
        self,
        func: OpOverload,
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))


class FlexAttentionHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention", cacheable=True)
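
The class moved above keeps scalar-tensor indexing expressed as a tensor op. In eager the two forms agree in value; the difference only matters while tracing score_mod, where the implicit __index__ conversion is avoided (sketch, not part of the diff):

import torch

x = torch.arange(5)
q = torch.tensor(2)                        # 0-dim index, like q_idx in score_mod

plain = x[q]                               # normal __getitem__ path
rewritten = torch.ops.aten.index(x, [q])   # what the mode rewrites it to

assert torch.equal(plain, rewritten)
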
@ -185,6 +163,8 @@ def _math_attention_inner(
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)
@ -318,6 +298,8 @@ def trace_flex_attention(
    This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
    access this graph module in inductor to inline the score_mod function to the triton template.
    """
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    example_out = flex_attention(
        query,
        key,
@ -414,6 +396,8 @@ def flex_attention_functionalize(
    guard against any mutations in the score_mod function, to the other_buffers since those
    are free variables.
    """
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
@ -715,6 +699,8 @@ def flex_attention_autograd(
    score_mod_other_buffers: Tuple[Tensor, ...] = (),
    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    with TransformGetItemToIndex():
        input_requires_grad = any(t.requires_grad for t in (query, key, value))
        if torch.is_grad_enabled() and input_requires_grad:
@ -765,6 +751,8 @@ def sdpa_dense_backward(
    score_mod_other_buffers: Tuple,
    mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    # Get outputs before calling repeat interleave
    actual_grad_query = torch.empty_like(query)
    actual_grad_key = torch.empty_like(key)
@ -892,6 +880,8 @@ def trace_flex_attention_backward(
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    example_out = flex_attention_backward(
        query,
        key,

@ -8,6 +8,8 @@ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_imp
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
@ -18,14 +20,26 @@ from torch.utils._python_dispatch import _get_current_dispatch_mode

@exposed_in("torch")
def strict_mode(callable, operands):
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_modes,
    )

    if torch.compiler.is_dynamo_compiling():
        return strict_mode_op(callable, operands)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
                callable, operands
            )
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
                modes = [metadata_mode, predispatch_mode]
                modes = [mode for mode in modes if mode is not None]
                if modes:
                    backend = make_eager_backend_with_torch_function_modes(modes)
                else:
                    backend = "eager"
                with torch._dynamo.utils.disable_cache_limit():
                    return torch.compile(
                        strict_mode_op, backend=backend, fullgraph=True
                    )(callable, operands)


class StrictMode(HigherOrderOperator):

@ -2540,90 +2540,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
 public:
  TORCH_FUNCTION_MODE_STACK(
      const py::list& initial_stack,
      const py::list& ignored_types,
      py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _ref_stack(),
        _ignored_types() {
      : LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
    Py_ssize_t len = PyList_Size(initial_stack.ptr());
    for (Py_ssize_t idx = 0; idx < len; idx++) {
      PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
      auto type = Py_TYPE(mode);
      this->_ref_stack.push_back(type);
    }

    len = PyList_Size(ignored_types.ptr());
    for (Py_ssize_t idx = 0; idx < len; idx++) {
      PyObject* type_obj =
          PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
      if (PyType_Check(type_obj) == 0) {
        PyErr_SetString(
            PyExc_TypeError, "ignored_types should contain a list of types");
        return;
      }
      PyTypeObject* type = (PyTypeObject*)type_obj;
      this->_ignored_types.insert(type);
    }
  }

  bool check_nopybind(PyObject* value) override {
    // Ignore value arg, only used to satisfy the interface
    size_t ref_ind = 0;
    const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
    const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
    const size_t ref_stack_size = this->_ref_stack.size();

    int64_t idx = 0;
    while ((idx < len) && (ref_ind < ref_stack_size)) {
    if (len != ref_stack_size) {
      return false;
    }

    for (int64_t idx = 0; (size_t)idx < len; idx++) {
      std::shared_ptr<c10::SafePyObject> mode =
          at::impl::PythonTorchFunctionTLS::get_stack_at(idx);

      PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
      bool act_ignored = this->_ignored_types.count(mode_type) > 0;
      bool ref_ignored =
          this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
      // skip ignored types
      if (act_ignored && ref_ignored) {
        idx++;
        ref_ind++;
        continue;
      } else if (ref_ignored) {
        ref_ind++;
        continue;
      } else if (act_ignored) {
        idx++;
        continue;
      }
      // if we already have more non-ignored modes than the ref stack
      // or if the mode doesn't match at the current index, return false
      else if (mode_type != _ref_stack.at(ref_ind)) {
        return false;
      }
      ref_ind++;
      idx++;
    }

    for (; ref_ind < ref_stack_size; ref_ind++) {
      if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
      if (mode_type != _ref_stack.at(idx)) {
        return false;
      }
    }

    for (; idx < len; idx++) {
      std::shared_ptr<c10::SafePyObject> mode =
          at::impl::PythonTorchFunctionTLS::get_stack_at(idx);

      PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
      if (!(this->_ignored_types.count(mode_type) > 0)) {
        return false;
      }
    }

    return ref_ind == ref_stack_size && idx == len;
    return true;
  }

 private:
  std::vector<PyTypeObject*> _ref_stack;
  std::set<PyTypeObject*> _ignored_types;
};

class TENSOR_MATCH : public LeafGuard {
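
In effect the simplified guard above reduces to the following Python paraphrase (not code from the patch): the live torch-function mode stack must have the same length as the stack recorded at compile time and the same type at every index, with no ignored-type skipping anymore.

from torch.overrides import _get_current_function_mode_stack

def mode_stack_guard(ref_stack_types):
    # ref_stack_types: list of mode types captured when the guard was built
    live = _get_current_function_mode_stack()
    if len(live) != len(ref_stack_types):
        return False
    return all(type(mode) is ty for mode, ty in zip(live, ref_stack_types))
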
@ -3792,7 +3742,7 @@ PyObject* torch_c_dynamo_guards_init() {
      LeafGuard,
      std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
      py_m, "TORCH_FUNCTION_MODE_STACK")
      .def(py::init<py::list, py::list, py::list>())
      .def(py::init<py::list, py::list>())
      .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
  py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
      py_m, "DATA_PTR_MATCH")
@ -4029,10 +3979,9 @@ PyObject* torch_c_dynamo_guards_init() {
          "add_torch_function_mode_stack_guard",
          [](GuardManager& self,
             const py::list& initial_stack,
             const py::list& ignored_types,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
                initial_stack, ignored_types, std::move(verbose_code_parts)));
                initial_stack, std::move(verbose_code_parts)));
          })
      .def(
          "add_data_ptr_guard",

@ -13,10 +13,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
    TransformGetItemToIndex,
)
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,