Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

Compare commits: ciflow/bin ... mlazos/tf- (7 commits)
	
| Author | SHA1 | Date |
|---|---|---|
|  | ac3dabf652 |  |
|  | 54ab06fc07 |  |
|  | 32542724be |  |
|  | dfbb990dc4 |  |
|  | 194d46e91c |  |
|  | 9094fb5c7c |  |
|  | ec6b49eed9 |  |
| @ -3380,6 +3380,21 @@ utils_device.CURRENT_DEVICE == None""".split( | ||||
|         self.assertTrue(same(obj41.y, obj42.y)) | ||||
|         self.assertEqual(cnts.frame_count, 1) | ||||
|  | ||||
|     def test_thread_local_setattr(self): | ||||
|         from threading import local | ||||
|  | ||||
|         loc = local() | ||||
|  | ||||
|         @torch.compile(fullgraph=True) | ||||
|         def fn(x, l): | ||||
|             l.x = x | ||||
|             return x + 1 | ||||
|  | ||||
|         x = torch.ones(2, 2) | ||||
|         fn(x, loc) | ||||
|  | ||||
|         self.assertTrue(loc.x is x) | ||||
|  | ||||
|     def test_user_defined_class_name(self): | ||||
|         class MyClassFoo: | ||||
|             pass | ||||
|  | ||||
| @ -1,5 +1,4 @@ | ||||
| # Owner(s): ["module: dynamo"] | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| @ -14,6 +13,17 @@ from torch.utils._device import DeviceContext | ||||
| from torch.utils._python_dispatch import TorchDispatchMode | ||||
|  | ||||
|  | ||||
| class TestMode(BaseTorchFunctionMode): | ||||
|     def __torch_function__(self, func, types, args, kwargs=None): | ||||
|         if not kwargs: | ||||
|             kwargs = {} | ||||
|  | ||||
|         if func == torch.add: | ||||
|             return torch.zeros(2, 2) | ||||
|  | ||||
|         return super().__torch_function__(func, types, args, kwargs) | ||||
|  | ||||
|  | ||||
| class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
| @ -57,9 +67,11 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         torch.set_default_device(None) | ||||
|         torch._dynamo.reset() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         torch.set_default_device(None) | ||||
|         torch._dynamo.reset() | ||||
|  | ||||
|     def _run_torch_function_mode_guard_test(self): | ||||
|         class TestMode1(BaseTorchFunctionMode): | ||||
| @ -94,70 +106,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): | ||||
|             fn(inp) | ||||
|         self.assertEqual(cnt.frame_count, 4) | ||||
|  | ||||
|     def _run_ignored_mode_types_test(self): | ||||
|         class IgnoredMode(BaseTorchFunctionMode): | ||||
|             pass | ||||
|  | ||||
|         cnt = torch._dynamo.testing.CompileCounter() | ||||
|  | ||||
|         @torch.compile(backend=cnt.__call__, fullgraph=True) | ||||
|         def fn(x): | ||||
|             return x + 1 | ||||
|  | ||||
|         inp = torch.ones(2, 2) | ||||
|  | ||||
|         with patch( | ||||
|             "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} | ||||
|         ): | ||||
|             # initial compile | ||||
|             fn(inp) | ||||
|  | ||||
|             # no recompile, mode ignored | ||||
|             # note: the ref stack is length 0, and the stack we are checking against has length 2 | ||||
|             # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack | ||||
|             with IgnoredMode(), IgnoredMode(): | ||||
|                 fn(inp) | ||||
|  | ||||
|             self.assertEqual(cnt.frame_count, 1) | ||||
|  | ||||
|             # recompile due to new mode on the stack | ||||
|             with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): | ||||
|                 fn(inp) | ||||
|  | ||||
|             self.assertEqual(cnt.frame_count, 2) | ||||
|  | ||||
|             # recompile | ||||
|             # tests both ref stack len > runtime stack len for the above guard check | ||||
|             # and ref stack len < runtime stack len for the initial zero mode case | ||||
|             with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): | ||||
|                 fn(inp) | ||||
|  | ||||
|             self.assertEqual(cnt.frame_count, 3) | ||||
|  | ||||
|             # no recompile | ||||
|             with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): | ||||
|                 fn(inp) | ||||
|  | ||||
|             self.assertEqual(cnt.frame_count, 3) | ||||
|  | ||||
|         # This is tricky, basically the ignored modes are baked into the guard | ||||
|         # IgnoredMode will be ignored forever by that guard. | ||||
|         # This is okay since we don't expect to be modifying IGNORED_MODES | ||||
|         # in the middle of execution except for the purposes of testing. | ||||
|         torch._dynamo.reset() | ||||
|  | ||||
|         with IgnoredMode(): | ||||
|             fn(inp) | ||||
|  | ||||
|         self.assertEqual(cnt.frame_count, 4) | ||||
|  | ||||
|     @torch._dynamo.config.patch("enable_cpp_guard_manager", False) | ||||
|     def test_torch_function_mode_guards_ignored_types_py(self): | ||||
|         self._run_ignored_mode_types_test() | ||||
|  | ||||
|     def test_torch_function_mode_guards_ignored_types_cpp(self): | ||||
|         self._run_ignored_mode_types_test() | ||||
|  | ||||
|     @torch._dynamo.config.patch("enable_cpp_guard_manager", False) | ||||
|     def test_torch_function_mode_guards_py(self): | ||||
|         self._run_torch_function_mode_guard_test() | ||||
| @ -324,6 +272,218 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): | ||||
|             fn(inp) | ||||
|         self.assertEqual(cnt.frame_count, 2) | ||||
|  | ||||
|     def test_nested_torch_function_mode(self): | ||||
|         mode_1_called = False | ||||
|         mode_2_called = False | ||||
|  | ||||
|         def reset_state(): | ||||
|             nonlocal mode_1_called | ||||
|             nonlocal mode_2_called | ||||
|             mode_1_called = False | ||||
|             mode_2_called = False | ||||
|  | ||||
|         ones = torch.ones(2, 2) | ||||
|         zeros = torch.zeros(2, 2) | ||||
|  | ||||
|         class TestMode1(BaseTorchFunctionMode): | ||||
|             def __torch_function__(self, func, types, args, kwargs=None): | ||||
|                 if not kwargs: | ||||
|                     kwargs = {} | ||||
|  | ||||
|                 nonlocal mode_1_called | ||||
|  | ||||
|                 mode_1_called = True | ||||
|  | ||||
|                 if func == torch.add: | ||||
|                     return zeros | ||||
|  | ||||
|                 return super().__torch_function__(func, types, args, kwargs) | ||||
|  | ||||
|         class TestMode2(BaseTorchFunctionMode): | ||||
|             def __torch_function__(self, func, types, args, kwargs=None): | ||||
|                 if not kwargs: | ||||
|                     kwargs = {} | ||||
|  | ||||
|                 nonlocal mode_2_called | ||||
|  | ||||
|                 mode_2_called = True | ||||
|  | ||||
|                 if func == torch.mul: | ||||
|                     return ones | ||||
|  | ||||
|                 return super().__torch_function__(func, types, args, kwargs) | ||||
|  | ||||
|         def fn(x): | ||||
|             return torch.add(x, 3) | ||||
|  | ||||
|         def fn_2(x): | ||||
|             return torch.mul(x, 3) + torch.add(x, 3) | ||||
|  | ||||
|         inp = torch.ones(2, 2) + 1 | ||||
|  | ||||
|         for fn_i in [fn, fn_2]: | ||||
|             fn_opt = torch.compile(fn_i, fullgraph=True) | ||||
|             with TestMode1(), TestMode2(): | ||||
|                 expected = fn_i(inp), mode_1_called, mode_2_called | ||||
|                 reset_state() | ||||
|                 actual = fn_opt(inp), mode_1_called, mode_2_called | ||||
|                 reset_state() | ||||
|  | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_torch_function_mode_disable(self): | ||||
|         class TestSubclass(torch.Tensor): | ||||
|             @classmethod | ||||
|             def __torch_function__(cls, func, types, args, kwargs=None): | ||||
|                 if not kwargs: | ||||
|                     kwargs = {} | ||||
|                 if func == torch.add: | ||||
|                     return torch.ones(2, 2) | ||||
|                 return super().__torch_function__(func, types, args, kwargs) | ||||
|  | ||||
|         class TestMode(BaseTorchFunctionMode): | ||||
|             def __torch_function__(self, func, types, args, kwargs=None): | ||||
|                 if not kwargs: | ||||
|                     kwargs = {} | ||||
|  | ||||
|                 if func == torch.add: | ||||
|                     return torch.zeros(2, 2) | ||||
|  | ||||
|                 return super().__torch_function__(func, types, args, kwargs) | ||||
|  | ||||
|         def fn(x): | ||||
|             return torch.add(x, 3) | ||||
|  | ||||
|         inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) | ||||
|  | ||||
|         fn_opt = torch.compile(fn, fullgraph=True) | ||||
|         with TestMode(), torch._dynamo.config.patch( | ||||
|             "traceable_tensor_subclasses", {TestSubclass} | ||||
|         ): | ||||
|             with torch._C.DisableTorchFunctionSubclass(): | ||||
|                 expected = fn(inp) | ||||
|                 actual = fn_opt(inp) | ||||
|  | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|             with torch._C.DisableTorchFunction(): | ||||
|                 expected = fn(inp) | ||||
|                 actual = fn_opt(inp) | ||||
|  | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_torch_function_mode_highest_priority(self): | ||||
|         class TestSubclass(torch.Tensor): | ||||
|             @classmethod | ||||
|             def __torch_function__(cls, func, types, args, kwargs=None): | ||||
|                 if not kwargs: | ||||
|                     kwargs = {} | ||||
|                 if func == torch.add: | ||||
|                     return torch.ones(2, 2) | ||||
|                 return super().__torch_function__(func, types, args, kwargs) | ||||
|  | ||||
|         def fn(x): | ||||
|             return torch.add(x, 3) | ||||
|  | ||||
|         inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) | ||||
|  | ||||
|         fn_opt = torch.compile(fn, fullgraph=True) | ||||
|         with TestMode(), torch._dynamo.config.patch( | ||||
|             "traceable_tensor_subclasses", {TestSubclass} | ||||
|         ): | ||||
|             expected = fn(inp) | ||||
|             actual = fn_opt(inp) | ||||
|  | ||||
|         self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_torch_function_mode_enter_exit(self): | ||||
|         def fn(x, y): | ||||
|             with TestMode(): | ||||
|                 o = torch.add(x, 3) | ||||
|  | ||||
|             return torch.add(o, y) | ||||
|  | ||||
|         inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) | ||||
|         fn_opt = torch.compile(fn, fullgraph=True) | ||||
|  | ||||
|         expected = fn(*inp) | ||||
|         actual = fn_opt(*inp) | ||||
|  | ||||
|         self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_torch_function_mode_graph_break(self): | ||||
|         def fn(x, y): | ||||
|             with TestMode(): | ||||
|                 torch._dynamo.graph_break() | ||||
|                 o = torch.add(x, 3) | ||||
|  | ||||
|             return torch.add(o, y) | ||||
|  | ||||
|         inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) | ||||
|         fn_opt = torch.compile(fn) | ||||
|  | ||||
|         expected = fn(*inp) | ||||
|         actual = fn_opt(*inp) | ||||
|  | ||||
|         self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_torch_function_mode_and_pop_graph_break(self): | ||||
|         def fn(x, y): | ||||
|             with TestMode(): | ||||
|                 z = _pop_torch_function_stack() | ||||
|                 torch._dynamo.graph_break() | ||||
|                 _push_on_torch_function_stack(z) | ||||
|                 o = torch.add(x, 3) | ||||
|  | ||||
|             return torch.add(o, y) | ||||
|  | ||||
|         inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) | ||||
|         fn_opt = torch.compile(fn) | ||||
|  | ||||
|         expected = fn(*inp) | ||||
|         actual = fn_opt(*inp) | ||||
|  | ||||
|         self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_torch_function_mode_restore_on_exc(self): | ||||
|         @torch._dynamo.disable() | ||||
|         def err(): | ||||
|             raise RuntimeError("test") | ||||
|  | ||||
|         @torch.compile() | ||||
|         def fn(x): | ||||
|             with TestMode(): | ||||
|                 x += 1 | ||||
|                 err() | ||||
|                 x += 2 | ||||
|                 return x | ||||
|  | ||||
|         try: | ||||
|             fn(torch.ones(2, 2)) | ||||
|         except RuntimeError: | ||||
|             pass | ||||
|         self.assertEqual(_len_torch_function_stack(), 0) | ||||
|  | ||||
|     def test_torch_function_mode_and_pop_graph_break_mutation(self): | ||||
|         def fn(x, y): | ||||
|             with TestMode(): | ||||
|                 z = _pop_torch_function_stack() | ||||
|                 z.y = 5 | ||||
|                 torch._dynamo.graph_break() | ||||
|                 _push_on_torch_function_stack(z) | ||||
|                 o = torch.add(x, 3) | ||||
|                 o = torch.mul(o, z.y) | ||||
|  | ||||
|             return torch.add(o, y) | ||||
|  | ||||
|         inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) | ||||
|         fn_opt = torch.compile(fn) | ||||
|  | ||||
|         expected = fn(*inp) | ||||
|         actual = fn_opt(*inp) | ||||
|  | ||||
|         self.assertEqual(expected, actual) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     from torch._dynamo.test_case import run_tests | ||||
|  | ||||
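A stand-alone sketch of the pattern these tests exercise: a `TorchFunctionMode` that stays active while a `fullgraph=True` compiled function runs. It assumes this PR's traceable-mode support is in place; the mode name is illustrative and the interception of `torch.add` mirrors the `TestMode` used above.

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class AddReturnsZeros(BaseTorchFunctionMode):  # illustrative name
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.add:
            return torch.zeros(2, 2)
        return super().__torch_function__(func, types, args, kwargs)


@torch.compile(fullgraph=True)
def fn(x):
    return torch.add(x, 3)


with AddReturnsZeros():
    out = fn(torch.ones(2, 2))

print(out)  # zeros(2, 2): the mode intercepted torch.add inside the compiled region
```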
| @ -24,6 +24,7 @@ from torch.testing._internal.common_utils import ( | ||||
|     IS_WINDOWS, | ||||
|     parametrize, | ||||
|     run_tests, | ||||
|     skipIfCrossRef, | ||||
|     skipIfTorchDynamo, | ||||
|     TEST_WITH_TORCHDYNAMO, | ||||
|     TestCase, | ||||
| @ -1557,6 +1558,7 @@ class TestControlFlowTraced(TestCase): | ||||
|         self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) | ||||
|  | ||||
|     @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") | ||||
|     @skipIfCrossRef  # Arg order changes with crossref | ||||
|     def test_cond_simple_with_linear_compile_check_graph(self): | ||||
|         from torch._dynamo.testing import EagerAndRecordGraphs | ||||
|  | ||||
| @ -1819,6 +1821,7 @@ def forward(self, arg0_1): | ||||
|         self._check_compile(fn, inp, backend=backend) | ||||
|  | ||||
|     @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") | ||||
|     @skipIfCrossRef  # Arg order changes with cross ref | ||||
|     def test_while_loop_simple_with_linear_compile_check_graph(self): | ||||
|         fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] | ||||
|         from torch._dynamo.testing import EagerAndRecordGraphs | ||||
|  | ||||
| @ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import ( | ||||
| from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR | ||||
| from torch.fx import Node | ||||
| from torch.testing._internal.common_quantization import QuantizationTestCase | ||||
| from torch.testing._internal.common_utils import IS_WINDOWS | ||||
| from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef | ||||
|  | ||||
|  | ||||
| class TestHelperModules: | ||||
| @ -139,6 +139,8 @@ class TestMetaDataPorting(QuantizationTestCase): | ||||
|             self.assertEqual(v, node_tags[k]) | ||||
|         return m | ||||
|  | ||||
|     @skipIfCrossRef  # mlazos: retracing an FX graph with a torch function mode doesn't propagate metadata, because the stack | ||||
|     # trace of the mode's torch function impl doesn't match the lineno stored in the traced graph. | ||||
|     def test_simple_metadata_porting(self): | ||||
|         """ | ||||
|         Model under test | ||||
|  | ||||
| @ -67,7 +67,7 @@ class GuardManager: | ||||
|     ) -> None: ... | ||||
|     def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... | ||||
|     def add_torch_function_mode_stack_guard( | ||||
|         self, initial_stack, ignored_types, verbose_code_parts: list[str] | ||||
|         self, initial_stack, verbose_code_parts: list[str] | ||||
|     ) -> None: ... | ||||
|  | ||||
| class RootGuardManager(GuardManager): | ||||
|  | ||||
| @ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs): | ||||
|     return gm.forward | ||||
|  | ||||
|  | ||||
| def make_eager_backend_with_torch_function_mode(mode): | ||||
|     """Used to trace HOPs (cond and while) for eager exectution, the metadata | ||||
|     TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks | ||||
|     in the HOP, so we need to externally run this mode and not trace it.""" | ||||
|  | ||||
|     def fn(gm, fake_tensor_inputs, **kwargs): | ||||
|         with mode: | ||||
|             return gm.forward | ||||
|  | ||||
|     return fn | ||||
|  | ||||
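A possible way to use this helper (a sketch: the import path below is an assumption based on where this diff appears to add the function, and no runtime behavior is claimed beyond the mode being entered while the backend runs):

```python
import torch
from torch.overrides import BaseTorchFunctionMode

# Assumed import path; this diff appears to add the helper to the eager
# debugging backends module.
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)

mode = BaseTorchFunctionMode()
backend = make_eager_backend_with_torch_function_mode(mode)


@torch.compile(backend=backend, fullgraph=True)
def fn(x):
    return torch.add(x, 1)


out = fn(torch.ones(2, 2))  # `mode` is entered only while the backend runs
```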
|  | ||||
| @register_backend | ||||
| def eager_noexcept(gm, fake_tensor_inputs, **kwargs): | ||||
|     if kwargs: | ||||
|  | ||||
| @ -112,6 +112,7 @@ from .utils import ( | ||||
|     troubleshooting_url, | ||||
|     write_record_to_file, | ||||
| ) | ||||
| from .variables.torch_function import torch_function_mode_stack_state_mgr | ||||
|  | ||||
|  | ||||
| np: Optional[ModuleType] | ||||
| @ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: | ||||
|             prior_fwd_from_src = torch.fx.graph_module._forward_from_src | ||||
|             torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result | ||||
|             cleanup = setup_compile_debug() | ||||
|  | ||||
|             exit_stack = contextlib.ExitStack() | ||||
|             exit_stack.enter_context( | ||||
|                 torch.fx._symbolic_trace._maybe_revert_all_patches() | ||||
|             ) | ||||
|             exit_stack.enter_context(torch_function_mode_stack_state_mgr) | ||||
|             try: | ||||
|                 return fn(*args, **kwargs) | ||||
|             finally: | ||||
|                 cleanup.close() | ||||
|                 assert ( | ||||
|                     torch._C._len_torch_function_stack() == 0 | ||||
|                 ), "Torch function mode stack state changed while dynamo tracing, please report a bug" | ||||
|                 exit_stack.close() | ||||
|                 torch._C._set_grad_enabled(prior_grad_mode) | ||||
|                 torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) | ||||
| @ -605,6 +609,10 @@ def _compile( | ||||
|     output: Optional[OutputGraph] = None | ||||
|     tracer: Optional[InstructionTranslator] = None | ||||
|  | ||||
|     tf_mode_stack: List[ | ||||
|         torch.overrides.TorchFunctionMode | ||||
|     ] = torch.overrides._get_current_function_mode_stack() | ||||
|  | ||||
|     @preserve_global_state | ||||
|     def transform( | ||||
|         instructions: List[Instruction], code_options: Dict[str, object] | ||||
| @ -618,6 +626,7 @@ def _compile( | ||||
|             locals, | ||||
|             globals, | ||||
|             builtins, | ||||
|             tf_mode_stack, | ||||
|             code_options, | ||||
|             compiler_fn, | ||||
|             one_graph, | ||||
|  | ||||
| @ -97,6 +97,7 @@ from .source import ( | ||||
|     ScriptObjectQualifiedNameSource, | ||||
|     ShapeEnvSource, | ||||
|     SubclassAttrListSource, | ||||
|     TorchFunctionModeStackSource, | ||||
|     TupleIteratorGetItemSource, | ||||
|     TypeSource, | ||||
|     UnspecializedBuiltinNNModuleSource, | ||||
| @ -110,6 +111,7 @@ from .utils import ( | ||||
|     dict_keys_repr, | ||||
|     get_custom_getattr, | ||||
|     get_torch_function_mode_stack, | ||||
|     get_torch_function_mode_stack_at, | ||||
|     guard_failures, | ||||
|     istype, | ||||
|     key_is_id, | ||||
| @ -313,6 +315,7 @@ CLOSURE_VARS = { | ||||
|     "___dict_contains": lambda a, b: a in b, | ||||
|     "___tuple_iterator_len": tuple_iterator_len, | ||||
|     "___tuple_iterator_getitem": tuple_iterator_getitem, | ||||
|     "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, | ||||
|     "__math_isnan": math.isnan, | ||||
|     "__numpy_isnan": None if np is None else np.isnan, | ||||
|     "inf": float("inf"), | ||||
| @ -900,6 +903,15 @@ class GuardBuilder(GuardBuilderBase): | ||||
|         ): | ||||
|             assert base_guard_manager  # to make mypy happy | ||||
|             out = base_guard_manager | ||||
|         elif istype(source, TorchFunctionModeStackSource): | ||||
|             out = root_guard_manager.lambda_manager( | ||||
|                 python_lambda=lambda _: get_torch_function_mode_stack_at( | ||||
|                     source._get_index() | ||||
|                 ), | ||||
|                 source=source_name, | ||||
|                 example_value=example_value, | ||||
|                 guard_manager_enum=guard_manager_enum, | ||||
|             ) | ||||
|         elif istype(source, GradSource): | ||||
|             assert base_guard_manager  # to make mypy happy | ||||
|             out = base_guard_manager.grad_manager( | ||||
| @ -2206,6 +2218,8 @@ class CheckFunctionManager: | ||||
|         self.output_graph = output_graph | ||||
|         w_builder = None | ||||
|  | ||||
|         # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing | ||||
|         # in case a set default device call was made in the graph. | ||||
|         self.torch_function_mode_stack = ( | ||||
|             output_graph.torch_function_mode_stack if output_graph else None | ||||
|         ) | ||||
| @ -2322,15 +2336,12 @@ class CheckFunctionManager: | ||||
|         ) | ||||
|  | ||||
|         if config.enable_cpp_guard_manager: | ||||
|             from .variables.torch_function import IGNORED_MODES | ||||
|  | ||||
|             # Insert the global_state guard | ||||
|             assert self.guard_manager  # to make mypy happy | ||||
|             self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) | ||||
|  | ||||
|             self.guard_manager.root.add_torch_function_mode_stack_guard( | ||||
|                 self.torch_function_mode_stack, | ||||
|                 list(IGNORED_MODES), | ||||
|                 ["___check_torch_function_mode_stack()"], | ||||
|             ) | ||||
|             # Clear references to torch_function modes held in the list | ||||
| @ -2637,16 +2648,14 @@ def is_recompiles_verbose_enabled(): | ||||
| # this will only be used if cpp guards are disabled | ||||
| def make_torch_function_mode_stack_guard(intial_stack): | ||||
|     types = [type(x) for x in intial_stack] | ||||
|     from .variables.torch_function import IGNORED_MODES | ||||
|  | ||||
|     def check_torch_function_mode_stack(): | ||||
|         cur_stack = get_torch_function_mode_stack() | ||||
|  | ||||
|         if len(cur_stack) != len(types): | ||||
|             return False | ||||
|  | ||||
|         for ty, mode in zip(types, cur_stack): | ||||
|             if ty in IGNORED_MODES: | ||||
|                 continue | ||||
|             if ty != type(mode): | ||||
|                 return False | ||||
|  | ||||
|  | ||||
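With the `IGNORED_MODES` filtering removed, the Python guard now compares the full stack. A stand-alone sketch of the check it performs (names are illustrative, not the real guard plumbing):

```python
from typing import Callable, List


def make_stack_guard(initial_stack: List[object]) -> Callable[[List[object]], bool]:
    # Record the mode types seen at compile time.
    recorded_types = [type(m) for m in initial_stack]

    def check(current_stack: List[object]) -> bool:
        # Fail on any length or per-position type mismatch.
        if len(current_stack) != len(recorded_types):
            return False
        return all(ty is type(m) for ty, m in zip(recorded_types, current_stack))

    return check


guard = make_stack_guard(["a"])   # one str-typed entry recorded
print(guard(["b"]), guard([1]))   # True False
```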
| @ -78,7 +78,6 @@ from .utils import ( | ||||
|     get_instruction_source_311, | ||||
|     get_locals_to_steal, | ||||
|     get_static_address_type, | ||||
|     get_torch_function_mode_stack, | ||||
|     graph_break_reasons, | ||||
|     increment_op_count, | ||||
|     lazy_format_graph_code, | ||||
| @ -250,6 +249,7 @@ class OutputGraph: | ||||
|         local_scope: Scope, | ||||
|         global_scope: Scope, | ||||
|         f_code, | ||||
|         torch_function_mode_stack, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.tracers = [SubgraphTracer(self, export_root=export)] | ||||
| @ -368,7 +368,7 @@ class OutputGraph: | ||||
|         # This returns false if TF Overall (both mode and subclass) is disabled OR the TF Mode stack is empty | ||||
|         self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() | ||||
|         # This records the initial torch function mode stack for guarding | ||||
|         self.torch_function_mode_stack = get_torch_function_mode_stack() | ||||
|         self.torch_function_mode_stack = torch_function_mode_stack | ||||
|  | ||||
|         # Tracks if the output graph has a user defined allowed function in the | ||||
|         # graph. This is used later to determine if we should fallback to eager | ||||
| @ -1020,7 +1020,7 @@ class OutputGraph: | ||||
|             prefix_insts.clear() | ||||
|  | ||||
|         for block in reversed(tx.block_stack): | ||||
|             block.exit(tx) | ||||
|             block.exit(tx, is_graph_break=reason.graph_break) | ||||
|  | ||||
|         self.cleanup_graph() | ||||
|         tx.prune_dead_locals() | ||||
|  | ||||
| @ -25,6 +25,26 @@ if TYPE_CHECKING: | ||||
|         sys as sys, | ||||
|     ) | ||||
|  | ||||
| from torch.overrides import BaseTorchFunctionMode | ||||
|  | ||||
|  | ||||
| # These classes handle support for TorchFunctionModes across graph breaks. | ||||
| # Today, entering a TorchFunctionMode (for the classes we support) simply | ||||
| # pushes the mode onto the stack. Since the stack is mutated on enter and we | ||||
| # replay these mutations, no cleanup logic needs to run once a graph break | ||||
| # occurs: replaying the mutations ensures the torch function mode stack is | ||||
| # correct at the graph break, and the stack is reconstructed normally when we | ||||
| # compile the resume function on the other side of the break. | ||||
| # However, to ensure we exit properly in the resume function, we need to | ||||
| # re-enter the contexts as we do other contexts. These contexts do nothing on | ||||
| # enter, but provide the correct exit logic to ensure the stack state is | ||||
| # correct. | ||||
| class NoEnterTorchFunctionMode(BaseTorchFunctionMode): | ||||
|     def __enter__(self): | ||||
|         pass | ||||
|  | ||||
|  | ||||
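A small script illustrating the enter/exit asymmetry described above. It assumes, as with today's `BaseTorchFunctionMode`, that `__exit__` simply pops the top of the torch function mode stack; the class body is copied from the diff for self-containment.

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass  # the real mode was already pushed before the graph break


mode = BaseTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)      # simulate the replayed push
with NoEnterTorchFunctionMode():                  # enter: no extra push
    assert torch._C._len_torch_function_stack() == 1
# exit popped the top of the stack (i.e. `mode`), assuming __exit__ pops as today
assert torch._C._len_torch_function_stack() == 0
```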
| def index(iterator, item, start=0, end=None): | ||||
|     from itertools import islice | ||||
|  | ||||
| @ -48,6 +48,107 @@ class ReenterWith: | ||||
|     stack_index: int | ||||
|     target_values: Optional[Tuple[Any, ...]] = None | ||||
|  | ||||
|     def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): | ||||
|         """ | ||||
|         Codegen based off of: | ||||
|         try: | ||||
|             (rest) | ||||
|         except: | ||||
|             (restore the previous torch function mode stack) | ||||
|             raise | ||||
|         """ | ||||
|         except_jump_target = create_instruction( | ||||
|             "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" | ||||
|         ) | ||||
|         cleanup_complete_jump_target = create_instruction("NOP") | ||||
|  | ||||
|         setup_finally: List[Instruction] = [] | ||||
|  | ||||
|         if sys.version_info < (3, 11): | ||||
|             setup_finally.append( | ||||
|                 create_instruction("SETUP_FINALLY", target=except_jump_target) | ||||
|             ) | ||||
|         else: | ||||
|             exn_tab_begin = create_instruction("NOP") | ||||
|             exn_tab_end = create_instruction("NOP") | ||||
|             exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( | ||||
|                 exn_tab_begin, | ||||
|                 exn_tab_end, | ||||
|                 except_jump_target, | ||||
|                 self.stack_index + 1, | ||||
|                 False, | ||||
|             ) | ||||
|             setup_finally.append(exn_tab_begin) | ||||
|  | ||||
|         def create_reset(): | ||||
|             insts = [ | ||||
|                 create_instruction( | ||||
|                     "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" | ||||
|                 ), | ||||
|                 create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), | ||||
|             ] | ||||
|             return [ | ||||
|                 *insts, | ||||
|                 create_instruction( | ||||
|                     "LOAD_FAST", argval="___prev_torch_function_mode_stack" | ||||
|                 ), | ||||
|                 *create_call_function(1, True), | ||||
|                 create_instruction("POP_TOP"), | ||||
|             ] | ||||
|  | ||||
|         if sys.version_info < (3, 9): | ||||
|             epilogue = [ | ||||
|                 create_instruction("POP_BLOCK"), | ||||
|                 create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), | ||||
|                 except_jump_target, | ||||
|                 *create_reset(), | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 *create_reset(), | ||||
|                 create_instruction("RAISE_VARARGS", argval=0), | ||||
|                 create_instruction("POP_EXCEPT", argval=0), | ||||
|                 create_instruction("END_FINALLY"), | ||||
|                 cleanup_complete_jump_target, | ||||
|             ] | ||||
|         elif sys.version_info < (3, 11): | ||||
|             epilogue = [ | ||||
|                 create_instruction("POP_BLOCK"), | ||||
|                 create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), | ||||
|                 except_jump_target, | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 *create_reset(), | ||||
|                 create_instruction("RAISE_VARARGS", argval=0), | ||||
|                 create_instruction("POP_EXCEPT", argval=0), | ||||
|                 cleanup_complete_jump_target, | ||||
|             ] | ||||
|         else: | ||||
|             finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) | ||||
|             finally_exn_tab_target = create_instruction("COPY", arg=3) | ||||
|             except_jump_target.exn_tab_entry = InstructionExnTabEntry( | ||||
|                 except_jump_target, | ||||
|                 finally_exn_tab_end, | ||||
|                 finally_exn_tab_target, | ||||
|                 self.stack_index + 2, | ||||
|                 True, | ||||
|             ) | ||||
|             epilogue = [ | ||||
|                 exn_tab_end, | ||||
|                 create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), | ||||
|                 except_jump_target,  # PUSH_EXC_INFO | ||||
|                 create_instruction("POP_TOP"), | ||||
|                 *create_reset(), | ||||
|                 finally_exn_tab_end, | ||||
|                 finally_exn_tab_target,  # COPY 3 | ||||
|                 create_instruction("POP_EXCEPT"), | ||||
|                 create_instruction("RERAISE", arg=1),  # RERAISE 1 | ||||
|                 cleanup_complete_jump_target, | ||||
|             ] | ||||
|  | ||||
|         cleanup[:] = epilogue + cleanup | ||||
|         return setup_finally | ||||
|  | ||||
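In plain Python, the epilogue generated above amounts to restoring the mode stack that was saved before the graph break whenever the resumed code raises. A rough, runnable rendering (it catches instead of re-raising so the script completes; `prev_stack` stands in for the `___prev_torch_function_mode_stack` local):

```python
import torch
from torch._dynamo.utils import (
    get_torch_function_mode_stack,
    set_torch_function_mode_stack,
)
from torch.overrides import BaseTorchFunctionMode

prev_stack = get_torch_function_mode_stack()  # saved before the graph break
try:
    # Resumed user code mutates the stack, then fails.
    torch._C._push_on_torch_function_stack(BaseTorchFunctionMode())
    raise RuntimeError("simulated failure in the resumed frame")
except RuntimeError:
    # The generated bytecode restores the saved stack and then re-raises;
    # here we swallow the exception so the script finishes.
    set_torch_function_mode_stack(prev_stack)

assert get_torch_function_mode_stack() == prev_stack
```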
|     # If we do not want to destroy the stack, we can do the same thing as a | ||||
|     # `SETUP_WITH` block, only that we store the context manager in a local_symbol | ||||
|     def try_except(self, code_options, cleanup: List[Instruction]): | ||||
|  | ||||
| @ -593,16 +593,19 @@ class SideEffects: | ||||
|             elif isinstance( | ||||
|                 var, variables.torch_function.TorchFunctionModeStackVariable | ||||
|             ): | ||||
|                 cg.add_push_null( | ||||
|                     lambda: cg.load_import_from( | ||||
|                         utils.__name__, "set_torch_function_mode_stack" | ||||
|                     ) | ||||
|                 ) | ||||
|                 # Needed in the finally block for stack restoration | ||||
|                 cg.load_import_from(utils.__name__, "get_torch_function_mode_stack") | ||||
|                 cg.call_function(0, True) | ||||
|                 name = "___prev_torch_function_mode_stack" | ||||
|                 cg.code_options["co_varnames"] += (name,) | ||||
|                 cg.append_output(create_instruction("STORE_FAST", argval=name)) | ||||
|                 cg.load_import_from(utils.__name__, "set_torch_function_mode_stack") | ||||
|  | ||||
|                 cg.foreach(var.symbolic_stack) | ||||
|                 cg.append_output( | ||||
|                     create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) | ||||
|                 ) | ||||
|                 cg.call_function(1, False) | ||||
|                 cg.call_function(1, True) | ||||
|                 cg.append_output(create_instruction("POP_TOP")) | ||||
|             elif self.is_attribute_mutation(var): | ||||
|                 # Applying mutations involves two steps: 1) Push all | ||||
|  | ||||
| @ -608,7 +608,7 @@ class TorchFunctionModeStackSource(Source): | ||||
|     ind: int | ||||
|  | ||||
|     def name(self): | ||||
|         return "" | ||||
|         return f"___get_torch_function_mode_stack_at({self._get_index()})" | ||||
|  | ||||
|     def _get_index(self): | ||||
|         from .variables.torch_function import TorchFunctionModeStackVariable | ||||
|  | ||||
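The returned name resolves through the `___get_torch_function_mode_stack_at` closure variable added to `CLOSURE_VARS` earlier in this diff. A small sketch of what that lookup evaluates to:

```python
from torch._dynamo.utils import get_torch_function_mode_stack_at
from torch.overrides import BaseTorchFunctionMode

with BaseTorchFunctionMode() as m:
    # ___get_torch_function_mode_stack_at(0) resolves to the mode at index 0
    # of the current torch function mode stack.
    assert get_torch_function_mode_stack_at(0) is m
```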
| @ -19,20 +19,7 @@ import traceback | ||||
| import types | ||||
| import typing | ||||
| import weakref | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Callable, | ||||
|     cast, | ||||
|     Deque, | ||||
|     Dict, | ||||
|     List, | ||||
|     Optional, | ||||
|     Set, | ||||
|     Tuple, | ||||
|     Type, | ||||
|     TYPE_CHECKING, | ||||
|     Union, | ||||
| ) | ||||
| from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import torch | ||||
| @ -72,14 +59,12 @@ from .source import ( | ||||
|     GlobalWeakRefSource, | ||||
|     LocalSource, | ||||
|     Source, | ||||
|     TorchFunctionModeStackSource, | ||||
| ) | ||||
| from .trace_rules import is_builtin_constant, is_forbidden | ||||
| from .utils import ( | ||||
|     counters, | ||||
|     get_fake_value, | ||||
|     get_instruction_source_311, | ||||
|     get_torch_function_mode_stack, | ||||
|     graph_break_dup_warning_checker, | ||||
|     istype, | ||||
|     LazyString, | ||||
| @ -120,11 +105,10 @@ from .variables.misc import ( | ||||
| ) | ||||
| from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable | ||||
| from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from .variables.torch_function import TorchFunctionModeVariable | ||||
|  | ||||
| from .variables.torch_function import ( | ||||
|     SymbolicTorchFunctionState, | ||||
|     TorchFunctionModeVariable, | ||||
| ) | ||||
| from .variables.user_defined import ( | ||||
|     RemovableHandleVariable, | ||||
|     UserDefinedClassVariable, | ||||
| @ -283,9 +267,12 @@ class BlockStackEntry: | ||||
|         else: | ||||
|             return ReenterWith(self.stack_index) | ||||
|  | ||||
|     def exit(self, tx): | ||||
|     def exit(self, tx, is_graph_break): | ||||
|         assert self.with_context is not None | ||||
|         return self.with_context.exit(tx) | ||||
|         if ( | ||||
|             is_graph_break and self.with_context.exit_on_graph_break() | ||||
|         ) or not is_graph_break: | ||||
|             return self.with_context.exit(tx) | ||||
|  | ||||
|  | ||||
| class ReturnValueOp(Exception): | ||||
| @ -651,8 +638,17 @@ def break_graph_if_unsupported(*, push): | ||||
|             cleanup: List[Instruction] = [] | ||||
|             # Reconstruct the context variable CLASS in the block stack | ||||
|             for b in self.block_stack: | ||||
|                 # Don't exit any modes we have entered; the | ||||
|                 # output bytecode will mutate the tf mode stack accordingly | ||||
|                 if isinstance(b.with_context, TorchFunctionModeVariable): | ||||
|                     cg.extend_output( | ||||
|                         b.resume_fn().try_except_torch_function_mode( | ||||
|                             cg.code_options, cleanup | ||||
|                         ) | ||||
|                     ) | ||||
|                     continue | ||||
|                 assert b.with_context is not None | ||||
|                 assert isinstance(b.with_context, ContextWrappingVariable) | ||||
|                 assert isinstance(b.with_context, (ContextWrappingVariable)) | ||||
|                 b.with_context.reconstruct_type(cg) | ||||
|                 cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) | ||||
|             self.output.add_output_instructions(cg.get_instructions()) | ||||
| @ -728,7 +724,7 @@ class InstructionTranslatorBase( | ||||
|     output: OutputGraph | ||||
|     symbolic_locals: Dict[str, VariableTracker] | ||||
|     symbolic_globals: Dict[str, VariableTracker] | ||||
|     symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] | ||||
|     symbolic_torch_function_state: SymbolicTorchFunctionState | ||||
|     stack: List[VariableTracker] | ||||
|     instruction_pointer: Optional[int] | ||||
|     current_instruction: Instruction | ||||
| @ -2305,7 +2301,10 @@ class InstructionTranslatorBase( | ||||
|         ): | ||||
|             unimplemented(f"{inst.opname} {ctx}") | ||||
|  | ||||
|         if isinstance(ctx, GenericContextWrappingVariable): | ||||
|         if ( | ||||
|             isinstance(ctx, GenericContextWrappingVariable) | ||||
|             and not ctx.supports_graph_breaks() | ||||
|         ): | ||||
|             self.generic_context_manager_depth += 1 | ||||
|  | ||||
|         # Need this redundant check for mypy | ||||
| @ -2548,7 +2547,7 @@ class InstructionTranslatorBase( | ||||
|         code_options: Dict[str, Any], | ||||
|         symbolic_locals: Dict[str, VariableTracker], | ||||
|         symbolic_globals: Dict[str, VariableTracker], | ||||
|         symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], | ||||
|         symbolic_torch_function_state: SymbolicTorchFunctionState, | ||||
|         f_code: types.CodeType, | ||||
|         export: bool, | ||||
|         inline_depth: int, | ||||
| @ -2563,7 +2562,7 @@ class InstructionTranslatorBase( | ||||
|         self.output = output | ||||
|         self.symbolic_locals = symbolic_locals | ||||
|         self.symbolic_globals = symbolic_globals | ||||
|         self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack | ||||
|         self.symbolic_torch_function_state = symbolic_torch_function_state | ||||
|         self.stack = [] | ||||
|         # stack of variable names for tracking 3.13 closures | ||||
|         self.name_stack: list[Any] = [] | ||||
| @ -2652,6 +2651,7 @@ class InstructionTranslator(InstructionTranslatorBase): | ||||
|         f_locals, | ||||
|         f_globals, | ||||
|         f_builtins, | ||||
|         torch_function_mode_stack, | ||||
|         code_options, | ||||
|         compiler_fn, | ||||
|         one_graph, | ||||
| @ -2677,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase): | ||||
|                 local_scope=f_locals, | ||||
|                 global_scope=f_globals, | ||||
|                 f_code=f_code, | ||||
|                 torch_function_mode_stack=torch_function_mode_stack, | ||||
|             ), | ||||
|             instructions=instructions, | ||||
|             f_locals=f_locals, | ||||
| @ -2686,7 +2687,7 @@ class InstructionTranslator(InstructionTranslatorBase): | ||||
|             symbolic_locals={},  # set below | ||||
|             # A global var is inserted only after a STORE_GLOBAL happens to it | ||||
|             symbolic_globals={}, | ||||
|             symbolic_torch_function_mode_stack=collections.deque(), | ||||
|             symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below | ||||
|             f_code=f_code, | ||||
|             export=export, | ||||
|             inline_depth=0, | ||||
| @ -2721,7 +2722,9 @@ class InstructionTranslator(InstructionTranslatorBase): | ||||
|                 if k in f_locals | ||||
|             } | ||||
|  | ||||
|             self._init_torch_function_mode_stack() | ||||
|             self.symbolic_torch_function_state = SymbolicTorchFunctionState( | ||||
|                 torch_function_mode_stack | ||||
|             ) | ||||
|  | ||||
|             self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] | ||||
|             if export: | ||||
| @ -2762,29 +2765,6 @@ class InstructionTranslator(InstructionTranslatorBase): | ||||
|             ) | ||||
|             unimplemented(msg) | ||||
|  | ||||
|     def _init_torch_function_mode_stack(self): | ||||
|         from .variables.torch_function import TorchFunctionModeStackVariable | ||||
|  | ||||
|         TorchFunctionModeStackVariable.reset() | ||||
|  | ||||
|         self.symbolic_torch_function_mode_stack: Deque[ | ||||
|             TorchFunctionModeVariable | ||||
|         ] = collections.deque() | ||||
|         # We want to retrieve all modes to properly reconstruct the stack if needed | ||||
|         py_stack = get_torch_function_mode_stack(filter_ignored=False) | ||||
|  | ||||
|         if py_stack: | ||||
|             has_device_context = isinstance( | ||||
|                 py_stack[0], torch.utils._device.DeviceContext | ||||
|             ) | ||||
|  | ||||
|         for i, val in enumerate(py_stack): | ||||
|             self.symbolic_torch_function_mode_stack.append( | ||||
|                 variables.LazyVariableTracker.create( | ||||
|                     val, source=TorchFunctionModeStackSource(i) | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|     def get_example_value(self, source: Source): | ||||
|         if isinstance(source, LocalSource): | ||||
|             return self.f_locals[source.local_name] | ||||
| @ -3116,7 +3096,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase): | ||||
|                 code, | ||||
|                 sub_locals, | ||||
|                 parent.symbolic_globals, | ||||
|                 parent.symbolic_torch_function_mode_stack, | ||||
|                 parent.symbolic_torch_function_state, | ||||
|                 closure_cells, | ||||
|                 func, | ||||
|             ) | ||||
| @ -3126,7 +3106,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase): | ||||
|                 code, | ||||
|                 sub_locals, | ||||
|                 parent.symbolic_globals, | ||||
|                 parent.symbolic_torch_function_mode_stack, | ||||
|                 parent.symbolic_torch_function_state, | ||||
|                 closure_cells, | ||||
|                 func, | ||||
|             ) | ||||
| @ -3179,7 +3159,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase): | ||||
|         code: types.CodeType, | ||||
|         symbolic_locals: Dict[str, VariableTracker], | ||||
|         symbolic_globals: Dict[str, VariableTracker], | ||||
|         symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], | ||||
|         symbolic_torch_function_state: SymbolicTorchFunctionState, | ||||
|         closure_cells: Dict[str, VariableTracker], | ||||
|         funcvar: BaseUserFunctionVariable, | ||||
|     ) -> None: | ||||
| @ -3196,7 +3176,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase): | ||||
|             f_builtins=f_builtins, | ||||
|             symbolic_locals=symbolic_locals, | ||||
|             symbolic_globals=symbolic_globals, | ||||
|             symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, | ||||
|             symbolic_torch_function_state=symbolic_torch_function_state, | ||||
|             instructions=instructions, | ||||
|             code_options={k: getattr(code, k) for k in get_code_keys()}, | ||||
|             f_code=code, | ||||
|  | ||||
| @ -163,6 +163,7 @@ def debug_insert_nops( | ||||
|         local_scope=locals(), | ||||
|         global_scope=globals(), | ||||
|         f_code=frame.f_code, | ||||
|         torch_function_mode_stack=[], | ||||
|     ) | ||||
|  | ||||
|     return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) | ||||
|  | ||||
| @ -303,6 +303,7 @@ manual_torch_name_rule_map = { | ||||
|     "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, | ||||
|     "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, | ||||
|     "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, | ||||
|     "torch.set_default_device": UserFunctionVariable, | ||||
|     "torch.sparse_bsc_tensor": SkipFunctionVariable, | ||||
|     "torch.sparse_bsr_tensor": SkipFunctionVariable, | ||||
|     "torch.sparse_csc_tensor": SkipFunctionVariable, | ||||
| @ -2795,7 +2796,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( | ||||
|         "torch.random.initial_seed", | ||||
|         "torch.random.seed", | ||||
|         "torch.return_types.pytree_register_structseq", | ||||
|         "torch.set_default_device", | ||||
|         "torch.set_default_dtype", | ||||
|         "torch.set_default_tensor_type", | ||||
|         "torch.set_deterministic_debug_mode", | ||||
| @ -3254,6 +3254,7 @@ MOD_INLINELIST = [ | ||||
|     "torch.testing", | ||||
|     "torch.utils._content_store", | ||||
|     "torch.utils._contextlib", | ||||
|     "torch.utils._device", | ||||
|     "torch.utils._foreach_utils", | ||||
|     "torch.utils._python_dispatch", | ||||
|     "torch.utils._pytree", | ||||
| @ -3588,7 +3589,9 @@ def lookup_inner( | ||||
|             if reasons is not None: | ||||
|                 reasons.add("func name is patched_init") | ||||
|             return SkipFunctionVariable | ||||
|         elif name == "__torch_function__": | ||||
|         elif name == "__torch_function__" or ( | ||||
|             obj and obj.__name__ == "__torch_function__" | ||||
|         ): | ||||
|             if reasons is not None: | ||||
|                 reasons.add("func name is __torch_function__") | ||||
|             return UserFunctionVariable | ||||
|  | ||||
| @ -63,7 +63,6 @@ import torch.fx.experimental.symbolic_shapes | ||||
| import torch.utils._pytree as pytree | ||||
| from torch import fx | ||||
| from torch._C import ( | ||||
|     _get_function_stack_at, | ||||
|     _instruction_counter, | ||||
|     _len_torch_function_stack, | ||||
|     _pop_torch_function_stack, | ||||
| @ -3062,14 +3061,10 @@ def is_parameter_freezing(): | ||||
|     return torch._inductor.config.freezing and not torch.is_grad_enabled() | ||||
|  | ||||
|  | ||||
| def get_torch_function_mode_stack(filter_ignored=True): | ||||
|     from .variables.torch_function import IGNORED_MODES | ||||
|  | ||||
|     stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] | ||||
|     if filter_ignored: | ||||
|         stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] | ||||
|  | ||||
|     return stack | ||||
| def get_torch_function_mode_stack(): | ||||
|     return [ | ||||
|         get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) | ||||
|     ] | ||||
|  | ||||
|  | ||||
| def get_torch_function_mode_stack_at(ind): | ||||
| @ -3085,6 +3080,11 @@ def set_torch_function_mode_stack(stack): | ||||
|         _push_on_torch_function_stack(mode) | ||||
|  | ||||
|  | ||||
| def clear_torch_function_mode_stack(): | ||||
|     for i in range(_len_torch_function_stack()): | ||||
|         _pop_torch_function_stack() | ||||
|  | ||||
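A quick demonstration of the stack helpers above (these are private `torch._dynamo` utilities and may change; the counts assume no other modes are active):

```python
from torch._dynamo.utils import get_torch_function_mode_stack
from torch.overrides import BaseTorchFunctionMode

with BaseTorchFunctionMode(), BaseTorchFunctionMode():
    print(len(get_torch_function_mode_stack()))  # 2
print(len(get_torch_function_mode_stack()))      # 0, once both modes have exited
```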
|  | ||||
| def verify_guard_fn_signature(value): | ||||
|     fn = value.__metadata_guard__ | ||||
|     sig = inspect.signature(fn) | ||||
|  | ||||
| @ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable | ||||
| from .torch_function import ( | ||||
|     build_torch_function_fn, | ||||
|     TensorWithTFOverrideVariable, | ||||
|     torch_function_mode_stack_state_mgr, | ||||
|     TorchFunctionModeVariable, | ||||
| ) | ||||
| from .user_defined import ( | ||||
| @ -1663,15 +1664,16 @@ class VariableBuilder: | ||||
|                 # but warning is not the end of the world | ||||
|                 assert isinstance(value.base, np.nditer) | ||||
|  | ||||
|         try: | ||||
|             tensor_value = _util._try_convert_to_tensor(value) | ||||
|             if readonly: | ||||
|                 from torch._prims_common import clone_preserve_strides | ||||
|         with torch_function_mode_stack_state_mgr.temp_restore_stack(): | ||||
|             try: | ||||
|                 tensor_value = _util._try_convert_to_tensor(value) | ||||
|                 if readonly: | ||||
|                     from torch._prims_common import clone_preserve_strides | ||||
|  | ||||
|                 tensor_value = clone_preserve_strides(tensor_value) | ||||
|         except NotImplementedError as e: | ||||
|             # failed to convert to tensor, graph break | ||||
|             unimplemented(str(e)) | ||||
|                     tensor_value = clone_preserve_strides(tensor_value) | ||||
|             except NotImplementedError as e: | ||||
|                 # failed to convert to tensor, graph break | ||||
|                 unimplemented(str(e)) | ||||
|  | ||||
|         # We do this because we want the full behavior of guarding the numpy ndarray as if it were | ||||
|         # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here | ||||
|  | ||||
| @ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker): | ||||
|         if isinstance(args[0], UserFunctionVariable): | ||||
|             return WrappedUserFunctionVariable(args[0], self) | ||||
|  | ||||
|     def supports_graph_breaks(self): | ||||
|         return True | ||||
|  | ||||
|     def exit_on_graph_break(self): | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class GenericContextWrappingVariable(UserDefinedObjectVariable): | ||||
|     # Some methods in ContextWrappingVariable assumes the arguments are | ||||
| @ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable): | ||||
|         tx.generic_context_manager_depth -= 1 | ||||
|         return x | ||||
|  | ||||
|     def supports_graph_breaks(self): | ||||
|         return False | ||||
|  | ||||
|     def exit_on_graph_break(self): | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): | ||||
|     """represents torch grad requries grad""" | ||||
| @ -637,6 +649,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable): | ||||
|  | ||||
|     def _call_func(self, tx: "InstructionTranslator", values): | ||||
|         assert len(values) == 1 | ||||
|         tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] | ||||
|         tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] | ||||
|         tx.output.set_torch_function_state(values[0]) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -149,6 +149,18 @@ tracing_state_functions = { | ||||
| bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) | ||||
|  | ||||
|  | ||||
| @functools.lru_cache(None) | ||||
| def get_overridable_functions(): | ||||
|     from itertools import chain | ||||
|  | ||||
|     from torch.overrides import get_overridable_functions as get_overridable_functions_ | ||||
|  | ||||
|     funcs = set(chain(*get_overridable_functions_().values())) | ||||
|     more = {torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty} | ||||
|     funcs.update(more) | ||||
|     return funcs | ||||
|  | ||||
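For context, a sketch of the set the cached helper above computes: everything `torch.overrides.get_overridable_functions` reports, plus a handful of factory functions added on top.

```python
from itertools import chain

import torch
from torch.overrides import get_overridable_functions

funcs = set(chain(*get_overridable_functions().values()))
funcs.update({torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty})
print(torch.add in funcs)  # True
```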
|  | ||||
| class BaseTorchVariable(VariableTracker): | ||||
|     """common base for all torch.* functions, classes, modules and other things""" | ||||
|  | ||||
| @ -782,10 +794,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|             self, tx: "InstructionTranslator", *args, **kwargs | ||||
|         ): | ||||
|             assert not args and not kwargs | ||||
|             if not tx.symbolic_torch_function_mode_stack: | ||||
|             if not tx.symbolic_torch_function_state.mode_stack: | ||||
|                 raise unimplemented("Popping from an empty torch function mode stack") | ||||
|             TorchFunctionModeStackVariable.register_mutation(tx) | ||||
|             return tx.symbolic_torch_function_mode_stack.pop() | ||||
|             return tx.symbolic_torch_function_state.pop_torch_function_mode() | ||||
|  | ||||
|         @register(torch._C._push_on_torch_function_stack) | ||||
|         def handle_push_torch_function( | ||||
| @ -793,7 +805,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|         ): | ||||
|             assert len(args) == 1 and not kwargs | ||||
|             TorchFunctionModeStackVariable.register_mutation(tx) | ||||
|             tx.symbolic_torch_function_mode_stack.append(args[0]) | ||||
|             tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) | ||||
|             return ConstantVariable.create(None) | ||||
|  | ||||
|         @register(torch._C._len_torch_function_stack) | ||||
| @ -801,7 +813,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|             self, tx: "InstructionTranslator", *args, **kwargs | ||||
|         ): | ||||
|             assert not args and not kwargs | ||||
|             return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) | ||||
|             return ConstantVariable.create( | ||||
|                 len(tx.symbolic_torch_function_state.mode_stack) | ||||
|             ) | ||||
|  | ||||
|         @register(torch._C._get_function_stack_at) | ||||
|         def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): | ||||
|             assert len(args) == 1 and not kwargs | ||||
|             ind = args[0].as_python_constant() | ||||
|             assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) | ||||
|             return tx.symbolic_torch_function_state.mode_stack[ind] | ||||
|  | ||||
|         @register(torch.set_default_device) | ||||
|         def handle_set_default_device( | ||||
| @ -820,7 +841,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|             else: | ||||
|                 TorchFunctionModeStackVariable.register_device_context_insertion(tx) | ||||
|  | ||||
|             return None | ||||
|             return ConstantVariable.create(None) | ||||
|  | ||||
|         return handlers | ||||
|  | ||||
| @ -833,6 +854,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|         from . import ConstantVariable, SymNodeVariable, TensorVariable | ||||
|         from .builder import wrap_fx_proxy | ||||
|  | ||||
|         if self.torch_function_override_enabled(tx, args, kwargs): | ||||
|             return dispatch_torch_function(tx, self, args, kwargs) | ||||
|  | ||||
|         if self.can_constant_fold_through() and check_unspec_or_constant_args( | ||||
|             args, kwargs | ||||
|         ): | ||||
| @ -850,147 +874,144 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|             if result: | ||||
|                 return result | ||||
|  | ||||
|         if can_dispatch_torch_function(tx, args, kwargs): | ||||
|             return dispatch_torch_function(tx, self, args, kwargs) | ||||
|         else: | ||||
|             any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) | ||||
|         any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) | ||||
|  | ||||
|             all_ints_or_floats = all( | ||||
|                 isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) | ||||
|                 for x in args | ||||
|             ) | ||||
|             if ( | ||||
|                 getattr(self.value, "__module__", "") == "torch" | ||||
|                 and self.value.__name__ in bin_ops | ||||
|                 and any_symints_or_symfloats | ||||
|                 and all_ints_or_floats | ||||
|             ): | ||||
|                 msg = f"""\ | ||||
|         all_ints_or_floats = all( | ||||
|             isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) | ||||
|             for x in args | ||||
|         ) | ||||
|         if ( | ||||
|             getattr(self.value, "__module__", "") == "torch" | ||||
|             and self.value.__name__ in bin_ops | ||||
|             and any_symints_or_symfloats | ||||
|             and all_ints_or_floats | ||||
|         ): | ||||
|             msg = f"""\ | ||||
| Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. | ||||
| To support this behavior, we need to allow const-propping tensors that store symint data. | ||||
| For now, dynamo will explicitly graph break when it encounters user code with this behavior. | ||||
| """ | ||||
|                 log.warning(msg) | ||||
|                 unimplemented(msg) | ||||
|             log.warning(msg) | ||||
|             unimplemented(msg) | ||||
|  | ||||
|             # TODO(voz): Replace w/ dynamic shape rewrite table. | ||||
|             # Ideally, we would be able to do this at ctor time, but alas we need a combination | ||||
|             # of value + args to determine this. | ||||
|             fn_ = self.value | ||||
|             if any_symints_or_symfloats: | ||||
|                 torch_sym_op = f"_sym_{self.value.__name__}" | ||||
|                 if getattr(self.value, "__module__", None) == "math" and hasattr( | ||||
|                     torch, torch_sym_op | ||||
|                 ): | ||||
|                     fn_ = getattr(torch, torch_sym_op) | ||||
|         # TODO(voz): Replace w/ dynamic shape rewrite table. | ||||
|         # Ideally, we would be able to do this at ctor time, but alas we need a combination | ||||
|         # of value + args to determine this. | ||||
|         fn_ = self.value | ||||
|         if any_symints_or_symfloats: | ||||
|             torch_sym_op = f"_sym_{self.value.__name__}" | ||||
|             if getattr(self.value, "__module__", None) == "math" and hasattr( | ||||
|                 torch, torch_sym_op | ||||
|             ): | ||||
|                 fn_ = getattr(torch, torch_sym_op) | ||||
|  | ||||
|             fake_out_shape = None | ||||
|             if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): | ||||
|                 # Calling fake tensor propagation can mutate the out= tensor in | ||||
|                 # tx.output.tracked_fakes. tracked_fakes are used to apply | ||||
|                 # symbolic_shape guards. Mutating them destroys the information | ||||
|                 # prior to tracing, which is essential for creating right | ||||
|                 # guards. So save the shape now, and check later if it has | ||||
|                 # changed. If it has, graph break. | ||||
|                 fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape | ||||
|         fake_out_shape = None | ||||
|         if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): | ||||
|             # Calling fake tensor propagation can mutate the out= tensor in | ||||
|             # tx.output.tracked_fakes. tracked_fakes are used to apply | ||||
|             # symbolic_shape guards. Mutating them destroys the information | ||||
|             # prior to tracing, which is essential for creating right | ||||
|             # guards. So save the shape now, and check later if it has | ||||
|             # changed. If it has, graph break. | ||||
|             fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape | ||||
|  | ||||
|             tensor_variable = wrap_fx_proxy( | ||||
|                 tx=tx, | ||||
|                 proxy=tx.output.create_proxy( | ||||
|                     "call_function", | ||||
|                     fn_, | ||||
|                     *proxy_args_kwargs(args, kwargs), | ||||
|                 ), | ||||
|         tensor_variable = wrap_fx_proxy( | ||||
|             tx=tx, | ||||
|             proxy=tx.output.create_proxy( | ||||
|                 "call_function", | ||||
|                 fn_, | ||||
|                 *proxy_args_kwargs(args, kwargs), | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|         if ( | ||||
|             isinstance(tensor_variable, TensorVariable) | ||||
|             and "requires_grad" in kwargs | ||||
|             and kwargs["requires_grad"].as_python_constant() | ||||
|         ): | ||||
|             unimplemented( | ||||
|                 """factory functions that return tensors that require grad are not supported. | ||||
| Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" | ||||
|             ) | ||||
|  | ||||
|             if ( | ||||
|                 isinstance(tensor_variable, TensorVariable) | ||||
|                 and "requires_grad" in kwargs | ||||
|                 and kwargs["requires_grad"].as_python_constant() | ||||
|             ): | ||||
|                 unimplemented( | ||||
|                     """factory functions that return tensors that require grad are not supported. | ||||
| Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" | ||||
|                 ) | ||||
|  | ||||
|             if "out" in kwargs and not ( | ||||
|                 isinstance(kwargs["out"], variables.ConstantVariable) | ||||
|                 and kwargs["out"].as_python_constant() is None | ||||
|             ): | ||||
|                 # out variants of torch operators like torch.sort and | ||||
|                 # torch.sigmoid mutate the tensors in the out field. Track such | ||||
|                 # tensors and rewrite the symbolic locals. | ||||
|                 if isinstance(tensor_variable, TupleVariable): | ||||
|                     assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) | ||||
|                     output_tensor_names = [ | ||||
|                         tx.find_symbolic_locals_name(x) for x in kwargs["out"].items | ||||
|                     ] | ||||
|                     for idx, name in enumerate(output_tensor_names): | ||||
|                         if name in tx.symbolic_locals: | ||||
|                             tx.symbolic_locals[name] = tensor_variable.items[idx] | ||||
|                     for out_tensor, result_tensor in zip( | ||||
|                         kwargs["out"].items, tensor_variable.items | ||||
|                     ): | ||||
|                         if ( | ||||
|                             out_tensor.source | ||||
|                             and out_tensor in tx.output.graphargs | ||||
|                             and isinstance(out_tensor, variables.TensorVariable) | ||||
|                             and isinstance(result_tensor, variables.TensorVariable) | ||||
|                             and out_tensor.size != result_tensor.size | ||||
|                         ): | ||||
|                             # It's hard to get out variants with resizing on graph inputs work | ||||
|                             # properly across dynamo/aot/inductor, just fall back. | ||||
|                             unimplemented("out variants with resizing on graph inputs") | ||||
|                 elif isinstance(tensor_variable, TensorVariable): | ||||
|                     assert isinstance(kwargs["out"], TensorVariable) | ||||
|                     assert "example_value" in kwargs["out"].proxy.node.meta | ||||
|                     fake_tensor = tensor_variable.proxy.node.meta["example_value"] | ||||
|                     fake_out = kwargs["out"].proxy.node.meta["example_value"] | ||||
|         if "out" in kwargs and not ( | ||||
|             isinstance(kwargs["out"], variables.ConstantVariable) | ||||
|             and kwargs["out"].as_python_constant() is None | ||||
|         ): | ||||
|             # out variants of torch operators like torch.sort and | ||||
|             # torch.sigmoid mutate the tensors in the out field. Track such | ||||
|             # tensors and rewrite the symbolic locals. | ||||
|             if isinstance(tensor_variable, TupleVariable): | ||||
|                 assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) | ||||
|                 output_tensor_names = [ | ||||
|                     tx.find_symbolic_locals_name(x) for x in kwargs["out"].items | ||||
|                 ] | ||||
|                 for idx, name in enumerate(output_tensor_names): | ||||
|                     if name in tx.symbolic_locals: | ||||
|                         tx.symbolic_locals[name] = tensor_variable.items[idx] | ||||
|                 for out_tensor, result_tensor in zip( | ||||
|                     kwargs["out"].items, tensor_variable.items | ||||
|                 ): | ||||
|                     if ( | ||||
|                         kwargs["out"].source | ||||
|                         and kwargs["out"] in tx.output.graphargs | ||||
|                         and fake_out_shape != fake_tensor.shape | ||||
|                         out_tensor.source | ||||
|                         and out_tensor in tx.output.graphargs | ||||
|                         and isinstance(out_tensor, variables.TensorVariable) | ||||
|                         and isinstance(result_tensor, variables.TensorVariable) | ||||
|                         and out_tensor.size != result_tensor.size | ||||
|                     ): | ||||
|                         # It's hard to get out variants with resizing on graph inputs work | ||||
|                         # properly across dynamo/aot/inductor, just fall back. | ||||
|                         unimplemented("out variants with resizing on graph inputs") | ||||
|             elif isinstance(tensor_variable, TensorVariable): | ||||
|                 assert isinstance(kwargs["out"], TensorVariable) | ||||
|                 assert "example_value" in kwargs["out"].proxy.node.meta | ||||
|                 fake_tensor = tensor_variable.proxy.node.meta["example_value"] | ||||
|                 fake_out = kwargs["out"].proxy.node.meta["example_value"] | ||||
|                 if ( | ||||
|                     kwargs["out"].source | ||||
|                     and kwargs["out"] in tx.output.graphargs | ||||
|                     and fake_out_shape != fake_tensor.shape | ||||
|                 ): | ||||
|                     # It's hard to get out variants with resizing on graph inputs work | ||||
|                     # properly across dynamo/aot/inductor, just fall back. | ||||
|                     unimplemented("out variants with resizing on graph inputs") | ||||
|                 if not torch._prims_common.is_contiguous(fake_out): | ||||
|                     # It's difficult to handle strides correctly in functionalization | ||||
|                     # when calling an out= op with a non-contiguous out argument | ||||
|                     unimplemented( | ||||
|                         "out= op was called where output tensor was non-contiguous" | ||||
|                     ) | ||||
|                 name = tx.find_symbolic_locals_name(kwargs["out"]) | ||||
|                 if name in tx.symbolic_locals: | ||||
|                     tx.symbolic_locals[name] = tensor_variable | ||||
|             elif ( | ||||
|                 isinstance(tensor_variable, ConstantVariable) | ||||
|                 and tensor_variable.value is None | ||||
|             ): | ||||
|                 # Handle out-variant custom ops that return None. | ||||
|                 if isinstance(kwargs["out"], TensorVariable): | ||||
|                     assert "example_value" in kwargs["out"].proxy.node.meta | ||||
|                     fake_out = kwargs["out"].proxy.node.meta["example_value"] | ||||
|                     if not torch._prims_common.is_contiguous(fake_out): | ||||
|                         # It's difficult to handle strides correctly in functionalization | ||||
|                         # when calling an out= op with a non-contiguous out argument | ||||
|                         unimplemented( | ||||
|                             "out= op was called where output tensor was non-contiguous" | ||||
|                         ) | ||||
|                     name = tx.find_symbolic_locals_name(kwargs["out"]) | ||||
|                     if name in tx.symbolic_locals: | ||||
|                         tx.symbolic_locals[name] = tensor_variable | ||||
|                 elif ( | ||||
|                     isinstance(tensor_variable, ConstantVariable) | ||||
|                     and tensor_variable.value is None | ||||
|                 ): | ||||
|                     # Handle out-variant custom ops that return None. | ||||
|                     if isinstance(kwargs["out"], TensorVariable): | ||||
|                         assert "example_value" in kwargs["out"].proxy.node.meta | ||||
|                         fake_out = kwargs["out"].proxy.node.meta["example_value"] | ||||
|                 elif isinstance(kwargs["out"], ListVariable): | ||||
|                     for idx, x in enumerate(kwargs["out"].items): | ||||
|                         assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined] | ||||
|                         fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined] | ||||
|                         if not torch._prims_common.is_contiguous(fake_out): | ||||
|                             # It's difficult to handle strides correctly in functionalization | ||||
|                             # when calling an out= op with a non-contiguous out argument | ||||
|                             unimplemented( | ||||
|                                 "out= op was called where output tensor was non-contiguous" | ||||
|                                 "out= op was called where some of the output tensors were non-contiguous" | ||||
|                             ) | ||||
|                     elif isinstance(kwargs["out"], ListVariable): | ||||
|                         for idx, x in enumerate(kwargs["out"].items): | ||||
|                             assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined] | ||||
|                             fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined] | ||||
|                             if not torch._prims_common.is_contiguous(fake_out): | ||||
|                                 # It's difficult to handle strides correctly in functionalization | ||||
|                                 # when calling an out= op with a non-contiguous out argument | ||||
|                                 unimplemented( | ||||
|                                     "out= op was called where some of the output tensors were non-contiguous" | ||||
|                                 ) | ||||
|                 else: | ||||
|                     unimplemented(f"out variant of {type(kwargs['out'])}") | ||||
|             else: | ||||
|                 unimplemented(f"out variant of {type(kwargs['out'])}") | ||||
|  | ||||
|             return tensor_variable | ||||
|         return tensor_variable | ||||
|  | ||||
|     def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): | ||||
|         """inline behavior of torch.nn.modules.utils._ntuple""" | ||||
| @ -1118,3 +1139,12 @@ Either create the tensor outside the compiled region, or do not set the tensor t | ||||
|             source | ||||
|         ) | ||||
|         return result | ||||
|  | ||||
|     def torch_function_override_enabled(self, tx, args, kwargs): | ||||
|         return ( | ||||
|             self.get_function() in get_overridable_functions() | ||||
|             or isinstance( | ||||
|                 self.get_function(), | ||||
|                 (torch._ops.OpOverload, torch._ops.OpOverloadPacket), | ||||
|             ) | ||||
|         ) and can_dispatch_torch_function(tx, args, kwargs) | ||||
|  | ||||
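A short sketch of the two kinds of callables the torch_function_override_enabled check above accepts, assuming a standard torch build; the aten overload below is only an example.

import torch
from torch.overrides import get_overridable_functions

# Functions listed by torch.overrides participate in __torch_function__ dispatch.
overridable = {fn for fns in get_overridable_functions().values() for fn in fns}
assert torch.add in overridable

# Operator overloads and packets are the other accepted kind.
assert isinstance(torch.ops.aten.add.Tensor, torch._ops.OpOverload)
assert isinstance(torch.ops.aten.add, torch._ops.OpOverloadPacket)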
| @ -1,20 +1,36 @@ | ||||
| # mypy: ignore-errors | ||||
|  | ||||
| import collections | ||||
| import contextlib | ||||
| import inspect | ||||
| from typing import Dict, List, TYPE_CHECKING | ||||
| from typing import Deque, Dict, List, TYPE_CHECKING | ||||
|  | ||||
| import torch._C | ||||
| import torch.utils._pytree as pytree | ||||
| from torch._guards import Source | ||||
| from torch.overrides import _get_overloaded_args, get_default_nowrap_functions | ||||
| from torch.overrides import ( | ||||
|     _get_overloaded_args, | ||||
|     get_default_nowrap_functions, | ||||
|     TorchFunctionMode, | ||||
| ) | ||||
| from torch.utils._device import DeviceContext | ||||
|  | ||||
| from ..exc import unimplemented | ||||
| from ..guards import GuardBuilder, install_guard | ||||
| from ..polyfills import NoEnterTorchFunctionMode | ||||
| from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource | ||||
| from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter | ||||
| from ..utils import ( | ||||
|     class_has_getattribute, | ||||
|     clear_torch_function_mode_stack, | ||||
|     get_safe_global_name, | ||||
|     has_torch_function, | ||||
|     is_tensor_base_attr_getter, | ||||
|     set_torch_function_mode_stack, | ||||
| ) | ||||
| from .base import VariableTracker | ||||
| from .constant import ConstantVariable | ||||
| from .ctx_manager import ContextWrappingVariable | ||||
| from .ctx_manager import GenericContextWrappingVariable | ||||
| from .lazy import LazyVariableTracker | ||||
| from .lists import TupleVariable | ||||
| from .tensor import TensorSubclassVariable, TensorVariable | ||||
| from .user_defined import UserDefinedObjectVariable | ||||
| @ -52,11 +68,92 @@ banned_attrs = [ | ||||
|     if is_tensor_base_attr_getter(fn) | ||||
| ] | ||||
|  | ||||
| # Today set default device is placed in the graph and guarded on separately | ||||
| # so we should not trace through it. In the future we can trace it once | ||||
| # mode tracing is implemented and not put in the graph, but this is more | ||||
| # of a BE project and can be evaluated later | ||||
| IGNORED_MODES = {DeviceContext} | ||||
|  | ||||
| # Used to clear/restore the python torch function mode stack and temporarily restore it as needed | ||||
| class TorchFunctionModeStackStateManager: | ||||
|     def __init__(self): | ||||
|         self.stack = [] | ||||
|  | ||||
|     def __enter__(self): | ||||
|         self.stack = torch.overrides._get_current_function_mode_stack() | ||||
|         clear_torch_function_mode_stack() | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_value, traceback): | ||||
|         set_torch_function_mode_stack(self.stack) | ||||
|         self.stack = [] | ||||
|  | ||||
|     @contextlib.contextmanager | ||||
|     def temp_restore_stack(self): | ||||
|         prev = torch.overrides._get_current_function_mode_stack() | ||||
|         set_torch_function_mode_stack(self.stack) | ||||
|         try: | ||||
|             yield | ||||
|         finally: | ||||
|             set_torch_function_mode_stack(prev) | ||||
|  | ||||
|  | ||||
| torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() | ||||
|  | ||||
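A hedged usage sketch of the manager defined above (using a fresh instance of the class rather than the module-level singleton): entering it snapshots and clears the Python-level mode stack, temp_restore_stack re-installs the snapshot for a block, and exit restores the original stack. This assumes the dynamo stack helpers preserve the mode objects as their names suggest.

import torch
from torch.overrides import BaseTorchFunctionMode, _get_current_function_mode_stack

mode = BaseTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)        # one mode is active beforehand

mgr = TorchFunctionModeStackStateManager()          # the class defined above
with mgr:
    # The stack was snapshotted and cleared for tracing.
    assert _get_current_function_mode_stack() == []
    with mgr.temp_restore_stack():
        # The snapshot is temporarily visible again.
        assert _get_current_function_mode_stack() == [mode]
    assert _get_current_function_mode_stack() == []
# On exit the original stack is restored.
assert _get_current_function_mode_stack() == [mode]
torch._C._pop_torch_function_stack()                # clean up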
|  | ||||
| class SymbolicTorchFunctionState: | ||||
|     def __init__(self, py_stack): | ||||
|         # This is annoyingly complicated because of how the torch function subclass + mode C API was designed | ||||
|         # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass | ||||
|         # These are their definitions: | ||||
|         # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered | ||||
|         # (if either are entered, this will be False) | ||||
|         # 2) torch._C._is_torch_function_mode_enabled is False iff the torch mode stack is empty OR | ||||
|         # torch._C.DisableTorchFunction has been entered (it bundles the stack length with the disable state) | ||||
|         # To disambiguate these and keep myself sane I added a C API to check whether all torch function | ||||
|         # concepts (modes and subclasses) are enabled. | ||||
|         # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate | ||||
|         # the stack length from the enablement state of torch function modes. | ||||
|         # This is important because now if a mode is pushed while dynamo is tracing, we know whether | ||||
|         # or not torch function modes are enabled and whether we should trace it. | ||||
|         self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() | ||||
|  | ||||
|         # This differs from the C API of the same name: | ||||
|         # it is False only if we have entered torch._C.DisableTorchFunction, | ||||
|         # and it does not take the mode stack length into account, while the C API bundles these | ||||
|         # two concepts | ||||
|         self.torch_function_mode_enabled = ( | ||||
|             not torch._C._is_torch_function_all_disabled() | ||||
|         ) | ||||
|  | ||||
|         self.cur_mode = None | ||||
|  | ||||
|         TorchFunctionModeStackVariable.reset() | ||||
|  | ||||
|         self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() | ||||
|  | ||||
|         for i, val in enumerate(py_stack): | ||||
|             self.mode_stack.append( | ||||
|                 LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) | ||||
|             ) | ||||
|  | ||||
|     def in_torch_function_mode(self): | ||||
|         return len(self.mode_stack) > 0 | ||||
|  | ||||
|     def pop_torch_function_mode(self): | ||||
|         return self.mode_stack.pop() | ||||
|  | ||||
|     def push_torch_function_mode(self, mode_var): | ||||
|         self.mode_stack.append(mode_var) | ||||
|  | ||||
|     def call_torch_function_mode(self, tx, fn, types, args, kwargs): | ||||
|         with self._pop_mode_for_inlining() as cur_mode: | ||||
|             return cur_mode.call_torch_function(tx, fn, types, args, kwargs) | ||||
|  | ||||
|     @contextlib.contextmanager | ||||
|     def _pop_mode_for_inlining(self): | ||||
|         old_mode = self.cur_mode | ||||
|         self.cur_mode = self.pop_torch_function_mode() | ||||
|         try: | ||||
|             yield self.cur_mode | ||||
|         finally: | ||||
|             mode = self.cur_mode | ||||
|             self.cur_mode = old_mode | ||||
|             self.push_torch_function_mode(mode) | ||||
|  | ||||
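A pure-Python analogue of _pop_mode_for_inlining above, with hypothetical names: the innermost mode is taken off the symbolic stack while its handler is inlined and pushed back afterwards, so calls made inside the handler do not re-dispatch to the same mode.

import collections
import contextlib

mode_stack = collections.deque(["outer_mode", "inner_mode"])

@contextlib.contextmanager
def pop_for_inlining(stack):
    cur = stack.pop()       # remove the innermost mode for the duration
    try:
        yield cur
    finally:
        stack.append(cur)   # restore it once inlining finishes

with pop_for_inlining(mode_stack) as cur:
    assert cur == "inner_mode"
    assert list(mode_stack) == ["outer_mode"]
assert list(mode_stack) == ["outer_mode", "inner_mode"]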
|  | ||||
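The enablement flags recorded in SymbolicTorchFunctionState.__init__ can also be observed directly in eager mode; a minimal sketch, assuming a build that exposes the _is_torch_function_all_disabled API referenced above.

import torch
from torch.overrides import BaseTorchFunctionMode

assert torch._C._is_torch_function_enabled()           # no disable knob entered
assert not torch._C._is_torch_function_mode_enabled()  # and the mode stack is empty

with BaseTorchFunctionMode():
    assert torch._C._is_torch_function_mode_enabled()  # non-empty stack, nothing disabled
    with torch._C.DisableTorchFunctionSubclass():
        # The subclass knob flips _is_torch_function_enabled ...
        assert not torch._C._is_torch_function_enabled()
        # ... but modes remain enabled, which is the distinction tracked above.
        assert torch._C._is_torch_function_mode_enabled()
        assert not torch._C._is_torch_function_all_disabled()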
| class TorchFunctionModeStackVariable(VariableTracker): | ||||
| @ -88,19 +185,20 @@ class TorchFunctionModeStackVariable(VariableTracker): | ||||
|     def register_mutation(cls, tx: "InstructionTranslator"): | ||||
|         if cls.stack_value_singleton not in tx.output.side_effects: | ||||
|             var = cls( | ||||
|                 source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack | ||||
|                 source=Source(), | ||||
|                 symbolic_stack=tx.symbolic_torch_function_state.mode_stack, | ||||
|             ) | ||||
|             tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) | ||||
|             tx.output.side_effects.mutation(var) | ||||
|  | ||||
|     @classmethod | ||||
|     def register_device_context_insertion(cls, tx: "InstructionTranslator"): | ||||
|         stack = tx.symbolic_torch_function_mode_stack | ||||
|         stack = tx.symbolic_torch_function_state.mode_stack | ||||
|         if stack and cls.is_device_context(stack[0]): | ||||
|             return | ||||
|         else: | ||||
|             cls.offset += 1 | ||||
|             tx.symbolic_torch_function_mode_stack.insert( | ||||
|             stack.insert( | ||||
|                 0, | ||||
|                 TorchFunctionModeVariable( | ||||
|                     None, source=TorchFunctionModeStackSource(-cls.offset) | ||||
| @ -109,7 +207,7 @@ class TorchFunctionModeStackVariable(VariableTracker): | ||||
|  | ||||
|     @classmethod | ||||
|     def clear_default_device(cls, tx: "InstructionTranslator"): | ||||
|         stack = tx.symbolic_torch_function_mode_stack | ||||
|         stack = tx.symbolic_torch_function_state.mode_stack | ||||
|         if stack and cls.is_device_context(stack[0]): | ||||
|             stack.popleft() | ||||
|             cls.offset -= 1 | ||||
| @ -123,24 +221,88 @@ class TorchFunctionModeStackVariable(VariableTracker): | ||||
|         return ind + cls.offset | ||||
|  | ||||
|  | ||||
| class TorchFunctionModeVariable(ContextWrappingVariable): | ||||
|     def __init__(self, value, **kwargs): | ||||
|         super().__init__(value, **kwargs) | ||||
|         self.value = value | ||||
|  | ||||
| class TorchFunctionModeVariable(GenericContextWrappingVariable): | ||||
|     @staticmethod | ||||
|     def get_global_mangled_name(tx, val): | ||||
|         return get_safe_global_name( | ||||
|             tx, f"__torch_function_mode_{val.__class__.__name__}", val | ||||
|     def is_supported_torch_function_mode(ty): | ||||
|         # Supported in this sense means we can support graph breaks under the | ||||
|         # context. | ||||
|         # We are able to trace custom modes but if there are graph breaks under them | ||||
|         # and they have a custom __enter__/__exit__ we don't handle this for the | ||||
|         # same reason we don't handle generic context managers: there may be side effects | ||||
|         # that are affected by executing the function across two frames instead of one. | ||||
|         # Today we support the enter/exit of the default TorchFunctionMode as well as | ||||
|         # DeviceContext (which is used for set_default_device) | ||||
|         return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( | ||||
|             not class_has_getattribute(ty) | ||||
|             and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ | ||||
|             and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ | ||||
|         ) | ||||
|  | ||||
|     def __init__(self, value, source=None, **kwargs): | ||||
|         if value is not None: | ||||
|             super().__init__(value, **kwargs) | ||||
|         self.value = value | ||||
|         self.cm_obj = value  # needed for BC with calling enter from CM code | ||||
|         self.source = source | ||||
|  | ||||
|     def reconstruct(self, codegen): | ||||
|         # We don't support locally created torch function modes yet | ||||
|         # This shouldn't be called unless we have a source | ||||
|         assert self.source | ||||
|         self.source.reconstruct(codegen) | ||||
|  | ||||
|     def _call_func(self, tx, values): | ||||
|         unimplemented("torch function mode context manager is not supported yet") | ||||
|     def module_name(self): | ||||
|         return self.value.__module__ | ||||
|  | ||||
|     def fn_name(self): | ||||
|         return type(self.value).__name__ | ||||
|  | ||||
|     def python_type(self): | ||||
|         return type(self.value) | ||||
|  | ||||
|     def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): | ||||
|         return call_torch_function( | ||||
|             tx, | ||||
|             self, | ||||
|             build_torch_function_fn(tx, self.value, self.source), | ||||
|             fn, | ||||
|             types, | ||||
|             args, | ||||
|             kwargs, | ||||
|         ) | ||||
|  | ||||
|     def enter(self, tx): | ||||
|         from .torch import TorchInGraphFunctionVariable | ||||
|  | ||||
|         if isinstance(self.value, NoEnterTorchFunctionMode): | ||||
|             return ConstantVariable.create(None) | ||||
|  | ||||
|         TorchInGraphFunctionVariable( | ||||
|             torch._C._push_on_torch_function_stack | ||||
|         ).call_function(tx, [self], {}) | ||||
|         return ConstantVariable.create(None) | ||||
|  | ||||
|     def exit(self, tx: "InstructionTranslator", *args): | ||||
|         from .torch import TorchInGraphFunctionVariable | ||||
|  | ||||
|         TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( | ||||
|             tx, [], {} | ||||
|         ) | ||||
|         return ConstantVariable.create(None) | ||||
|  | ||||
|     def reconstruct_type(self, codegen): | ||||
|         ty = NoEnterTorchFunctionMode | ||||
|         codegen( | ||||
|             AttrSource( | ||||
|                 codegen.tx.import_source(ty.__module__), | ||||
|                 ty.__name__, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def supports_graph_breaks(self): | ||||
|         return True | ||||
|  | ||||
|     def exit_on_graph_break(self): | ||||
|         return False | ||||
|  | ||||
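A sketch of the is_supported_torch_function_mode criterion above, using standard inspect semantics: a mode that inherits the base __enter__/__exit__ qualifies, while one that overrides them does not (both classes below are hypothetical).

import inspect
from torch.overrides import TorchFunctionMode

class PlainMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

class CustomEnterMode(PlainMode):
    def __enter__(self):                 # custom side effect on enter
        print("entering")
        return super().__enter__()

assert inspect.getattr_static(PlainMode, "__enter__") == TorchFunctionMode.__enter__
assert inspect.getattr_static(CustomEnterMode, "__enter__") != TorchFunctionMode.__enter__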
|  | ||||
| def _get_all_args(args, kwargs): | ||||
| @ -231,9 +393,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): | ||||
|  | ||||
|  | ||||
| def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): | ||||
|     return tx.output.torch_function_enabled and any( | ||||
|     has_overridden_args = any( | ||||
|         has_torch_function(arg) for arg in _get_all_args(args, kwargs) | ||||
|     ) | ||||
|     tf_state = tx.symbolic_torch_function_state | ||||
|     return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( | ||||
|         tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): | ||||
| @ -245,11 +411,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): | ||||
|         _get_subclass_type, | ||||
|     ) | ||||
|  | ||||
|     types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) | ||||
|  | ||||
|     if tx.symbolic_torch_function_state.in_torch_function_mode(): | ||||
|         res = tx.symbolic_torch_function_state.call_torch_function_mode( | ||||
|             tx, fn, types, args, kwargs | ||||
|         ) | ||||
|         if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): | ||||
|             return res | ||||
|  | ||||
|     for arg in overloaded_args: | ||||
|         res = arg.call_torch_function( | ||||
|             tx, | ||||
|             fn, | ||||
|             TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), | ||||
|             types, | ||||
|             args, | ||||
|             kwargs, | ||||
|         ) | ||||
|  | ||||
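For reference, an eager-mode sketch of the dispatch order the code above mirrors, assuming standard __torch_function__ semantics: an active mode is consulted before any subclass argument, and the subclass handler still runs when the mode defers to func. Both classes are hypothetical.

import torch
from torch.overrides import TorchFunctionMode

calls = []

class LoggingSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        calls.append("subclass")
        return super().__torch_function__(func, types, args, kwargs or {})

class LoggingMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        calls.append("mode")
        return func(*args, **(kwargs or {}))

x = torch.ones(2).as_subclass(LoggingSubclass)
calls.clear()                            # ignore dispatches during construction
with LoggingMode():
    x + 1
assert calls[0] == "mode" and "subclass" in calls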
| @ -9,6 +9,7 @@ import inspect | ||||
| import itertools | ||||
| import random | ||||
| import sys | ||||
| import threading | ||||
| import types | ||||
| import warnings | ||||
| from typing import Dict, Generic, List, TYPE_CHECKING | ||||
| @ -82,11 +83,6 @@ def is_forbidden_context_manager(ctx): | ||||
|         from _pytest.python_api import RaisesContext | ||||
|         from _pytest.recwarn import WarningsChecker | ||||
|  | ||||
|         # TODO mlazos: Temporary to get this stack to pass | ||||
|         # remove in subsequent PR | ||||
|         from torch.overrides import BaseTorchFunctionMode | ||||
|  | ||||
|         f_ctxs.append(BaseTorchFunctionMode) | ||||
|         f_ctxs.append(RaisesContext) | ||||
|         f_ctxs.append(WarningsChecker) | ||||
|     except ImportError: | ||||
| @ -413,15 +409,25 @@ class UserDefinedClassVariable(UserDefinedVariable): | ||||
|             and self.source | ||||
|             and not is_forbidden_context_manager(self.value) | ||||
|         ): | ||||
|             # import here to avoid an unfortunate circular dependency. | ||||
|             from torch.overrides import TorchFunctionMode | ||||
|  | ||||
|             from .ctx_manager import GenericContextWrappingVariable | ||||
|             from .torch_function import TorchFunctionModeVariable | ||||
|  | ||||
|             if issubclass( | ||||
|                 self.value, TorchFunctionMode | ||||
|             ) and TorchFunctionModeVariable.is_supported_torch_function_mode( | ||||
|                 self.value | ||||
|             ): | ||||
|                 var_cls = TorchFunctionModeVariable | ||||
|             else: | ||||
|                 var_cls = GenericContextWrappingVariable | ||||
|  | ||||
|             cm_obj = tx.output.side_effects.track_object_new( | ||||
|                 self.source, self.value, GenericContextWrappingVariable, {} | ||||
|                 self.source, self.value, var_cls, {} | ||||
|             ) | ||||
|             cm_obj.call_method(tx, "__init__", args, kwargs) | ||||
|             return cm_obj | ||||
|  | ||||
|         elif is_namedtuple_cls(self.value): | ||||
|             fields = namedtuple_fields(self.value) | ||||
|             # check if this a quasi-namedtuple or a real one | ||||
| @ -711,7 +717,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): | ||||
|             if method is object.__init__: | ||||
|                 return ConstantVariable.create(None) | ||||
|  | ||||
|             if is_standard_setattr(method): | ||||
|             if is_standard_setattr(method) or isinstance(self.value, threading.local): | ||||
|                 return self.method_setattr_standard(tx, *args, **kwargs) | ||||
|  | ||||
|             # [NOTE] OrderedDict, dict subtypes must always have source | ||||
| @ -809,7 +815,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): | ||||
|     def needs_slow_setattr(self): | ||||
|         return not is_standard_setattr( | ||||
|             inspect.getattr_static(self.value, "__setattr__", None) | ||||
|         ) | ||||
|         ) and not isinstance(self.value, threading.local) | ||||
|  | ||||
|     def unpack_var_sequence(self, tx): | ||||
|         if ( | ||||
|  | ||||
| @ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): | ||||
|  | ||||
|     def __torch_function__(self, func, types, args=(), kwargs=None): | ||||
|         kwargs = kwargs or {} | ||||
|         if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: | ||||
|         if ( | ||||
|             not torch.compiler.is_dynamo_compiling() | ||||
|             and log.isEnabledFor(logging.DEBUG) | ||||
|             and config.extended_debug_current_loc | ||||
|         ): | ||||
|             frame = _find_user_code_frame() | ||||
|             if frame is not None: | ||||
|                 log.debug( | ||||
|  | ||||
| @ -28,6 +28,7 @@ from torch._ops import HigherOrderOperator | ||||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||||
| from torch._subclasses.functional_tensor import disable_functional_mode | ||||
| from torch.fx.experimental.proxy_tensor import ( | ||||
|     _temp_remove_metadata_torch_function_mode, | ||||
|     _temp_remove_pre_dispatch_torch_function_mode, | ||||
|     disable_proxy_modes_tracing, | ||||
|     ProxyTorchDispatchMode, | ||||
| @ -129,6 +130,10 @@ def cond(pred, true_fn, false_fn, operands): | ||||
|     if torch.compiler.is_dynamo_compiling(): | ||||
|         return cond_op(pred, true_fn, false_fn, operands) | ||||
|  | ||||
|     from torch._dynamo.backends.debugging import ( | ||||
|         make_eager_backend_with_torch_function_mode, | ||||
|     ) | ||||
|  | ||||
|     if isinstance(pred, (bool, int, float)): | ||||
|         log.warning( | ||||
|             "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." | ||||
| @ -169,12 +174,15 @@ def cond(pred, true_fn, false_fn, operands): | ||||
|     def _cond_op_wrapper(*args, **kwargs): | ||||
|         return cond_op(*args, **kwargs) | ||||
|  | ||||
|     with _set_compilation_env(): | ||||
|         with torch._dynamo.utils.disable_cache_limit(): | ||||
|             with _temp_remove_pre_dispatch_torch_function_mode(): | ||||
|                 return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( | ||||
|                     pred, true_fn, false_fn, operands | ||||
|                 ) | ||||
|     with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): | ||||
|         with _temp_remove_metadata_torch_function_mode() as metadata_mode: | ||||
|             if metadata_mode: | ||||
|                 backend = make_eager_backend_with_torch_function_mode(metadata_mode) | ||||
|             else: | ||||
|                 backend = "eager" | ||||
|             return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( | ||||
|                 pred, true_fn, false_fn, operands | ||||
|             ) | ||||
|  | ||||
|  | ||||
| def create_fw_bw_graph_branches(true_fn, false_fn, *operands): | ||||
|  | ||||
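For orientation, a minimal public-API usage sketch of the cond wrapper being modified above; this assumes a torch build where torch.cond is available and dynamo is supported on the platform.

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
# The predicate is a boolean tensor; both branches are traced, one is executed.
out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))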
| @ -15,7 +15,11 @@ from torch._higher_order_ops.utils import ( | ||||
| ) | ||||
| from torch._ops import HigherOrderOperator | ||||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||||
| from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree | ||||
| from torch.fx.experimental.proxy_tensor import ( | ||||
|     _temp_remove_metadata_torch_function_mode, | ||||
|     ProxyTorchDispatchMode, | ||||
|     track_tensor_tree, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class WhileLoopOp(HigherOrderOperator): | ||||
| @ -113,6 +117,9 @@ def while_loop(cond_fn, body_fn, carried_inputs): | ||||
|         - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. | ||||
|  | ||||
|     """ | ||||
|     from torch._dynamo.backends.debugging import ( | ||||
|         make_eager_backend_with_torch_function_mode, | ||||
|     ) | ||||
|  | ||||
|     # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. | ||||
|     # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. | ||||
| @ -140,9 +147,15 @@ def while_loop(cond_fn, body_fn, carried_inputs): | ||||
|         return while_loop_op(*args, **kwargs) | ||||
|  | ||||
|     with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): | ||||
|         return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( | ||||
|             cond_fn, body_fn, carried_inputs, additional_inputs | ||||
|         ) | ||||
|         with _temp_remove_metadata_torch_function_mode() as metadata_mode: | ||||
|                 if metadata_mode: | ||||
|                     backend = make_eager_backend_with_torch_function_mode(metadata_mode) | ||||
|                 else: | ||||
|                     backend = "eager" | ||||
|                 return torch.compile( | ||||
|                     _while_loop_op_wrapper, backend=backend, fullgraph=True | ||||
|                 )(cond_fn, body_fn, carried_inputs, additional_inputs) | ||||
|  | ||||
|  | ||||
| @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) | ||||
|  | ||||
| @ -2515,62 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { | ||||
|  public: | ||||
|   TORCH_FUNCTION_MODE_STACK( | ||||
|       const py::list& initial_stack, | ||||
|       const py::list& ignored_types, | ||||
|       py::object verbose_code_parts) | ||||
|       : LeafGuard(std::move(verbose_code_parts)), | ||||
|         _ref_stack(), | ||||
|         _ignored_types() { | ||||
|       : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { | ||||
|     Py_ssize_t len = PyList_Size(initial_stack.ptr()); | ||||
|     for (Py_ssize_t idx = 0; idx < len; idx++) { | ||||
|       PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref | ||||
|       this->_ref_stack.push_back(Py_TYPE(mode)); | ||||
|     } | ||||
|  | ||||
|     len = PyList_Size(ignored_types.ptr()); | ||||
|     for (Py_ssize_t idx = 0; idx < len; idx++) { | ||||
|       PyObject* type_obj = | ||||
|           PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref | ||||
|       if (PyType_Check(type_obj) == 0) { | ||||
|         PyErr_SetString( | ||||
|             PyExc_TypeError, "ignored_types should contain a list of types"); | ||||
|         return; | ||||
|       } | ||||
|       PyTypeObject* type = (PyTypeObject*)type_obj; | ||||
|       this->_ignored_types.insert(type); | ||||
|       auto type = Py_TYPE(mode); | ||||
|       this->_ref_stack.push_back(type); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   bool check_nopybind(PyObject* value) override { | ||||
|     // Ignore value arg, only used to satisfy the interface | ||||
|     size_t ref_ind = 0; | ||||
|     int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); | ||||
|     const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); | ||||
|     const size_t ref_stack_size = this->_ref_stack.size(); | ||||
|  | ||||
|     for (int64_t idx = 0; idx < len; idx++) { | ||||
|     if (len != ref_stack_size) { | ||||
|       return false; | ||||
|     } | ||||
|  | ||||
|     for (int64_t idx = 0; (size_t)idx < len; idx++) { | ||||
|       std::shared_ptr<c10::SafePyObject> mode = | ||||
|           at::impl::PythonTorchFunctionTLS::get_stack_at(idx); | ||||
|  | ||||
|       PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); | ||||
|       // skip ignored types | ||||
|       if (this->_ignored_types.count(mode_type) > 0) { | ||||
|         continue; | ||||
|       } | ||||
|       // if we already have more non-ignored modes than the ref stack | ||||
|       // or if the mode doesn't match at the current index, return false | ||||
|       else if ( | ||||
|           (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || | ||||
|           mode_type != _ref_stack[ref_ind]) { | ||||
|       if (mode_type != _ref_stack.at(idx)) { | ||||
|         return false; | ||||
|       } | ||||
|       ref_ind++; | ||||
|     } | ||||
|  | ||||
|     return ref_ind == this->_ref_stack.size(); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   std::vector<PyTypeObject*> _ref_stack; | ||||
|   std::set<PyTypeObject*> _ignored_types; | ||||
| }; | ||||
|  | ||||
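A Python analogue of the simplified guard check above (a sketch, not the C++ implementation): with the ignored-type set removed, the guard is an exact, order-sensitive comparison of the runtime mode types against the reference stack recorded at compile time.

def torch_function_mode_stack_guard(ref_types, runtime_modes):
    # ref_types: mode types recorded when the guard was built
    # runtime_modes: the current torch function mode stack, bottom to top
    if len(runtime_modes) != len(ref_types):
        return False
    return all(type(mode) is ty for mode, ty in zip(runtime_modes, ref_types))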
| class TENSOR_MATCH : public LeafGuard { | ||||
| @ -3672,7 +3650,7 @@ PyObject* torch_c_dynamo_guards_init() { | ||||
|       LeafGuard, | ||||
|       std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>( | ||||
|       py_m, "TORCH_FUNCTION_MODE_STACK") | ||||
|       .def(py::init<py::list, py::list, py::list>()) | ||||
|       .def(py::init<py::list, py::list>()) | ||||
|       .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); | ||||
|   py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>( | ||||
|       py_m, "DATA_PTR_MATCH") | ||||
| @ -3903,10 +3881,9 @@ PyObject* torch_c_dynamo_guards_init() { | ||||
|           "add_torch_function_mode_stack_guard", | ||||
|           [](GuardManager& self, | ||||
|              const py::list& initial_stack, | ||||
|              const py::list& ignored_types, | ||||
|              py::object verbose_code_parts) -> void { | ||||
|             self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>( | ||||
|                 initial_stack, ignored_types, std::move(verbose_code_parts))); | ||||
|                 initial_stack, std::move(verbose_code_parts))); | ||||
|           }) | ||||
|       .def( | ||||
|           "add_data_ptr_guard", | ||||
|  | ||||
| @ -17,7 +17,7 @@ import typing_extensions | ||||
| import warnings | ||||
| import weakref | ||||
| from collections import defaultdict | ||||
| from contextlib import contextmanager, ExitStack, nullcontext | ||||
| from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext | ||||
| from dataclasses import dataclass | ||||
| from typing import ( | ||||
|     Any, | ||||
| @ -1084,38 +1084,43 @@ class PythonKeyTracer(Tracer): | ||||
|             return e | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: | ||||
|     from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode | ||||
| def _make_temp_remove_mode_context_manager( | ||||
|     mode_ty: Type[TorchFunctionMode], | ||||
| ) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: | ||||
|     @contextmanager | ||||
|     def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: | ||||
|         from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode | ||||
|  | ||||
|     temp_elements = [] | ||||
|     pre_dispatch_mode = None | ||||
|         temp_elements = [] | ||||
|         removed_mode = None | ||||
|  | ||||
|     while _len_torch_function_stack() > 0: | ||||
|         mode = _pop_mode() | ||||
|         if isinstance(mode, PreDispatchTorchFunctionMode): | ||||
|             pre_dispatch_mode = mode | ||||
|             break | ||||
|         else: | ||||
|             temp_elements.append(mode) | ||||
|         while _len_torch_function_stack() > 0: | ||||
|             mode = _pop_mode() | ||||
|             if isinstance(mode, mode_ty): | ||||
|                 removed_mode = mode | ||||
|                 break | ||||
|             else: | ||||
|                 temp_elements.append(mode) | ||||
|  | ||||
|     for mode in reversed(temp_elements): | ||||
|         _push_mode(mode) | ||||
|         for mode in reversed(temp_elements): | ||||
|             _push_mode(mode) | ||||
|  | ||||
|     try: | ||||
|         yield | ||||
|         try: | ||||
|             yield removed_mode | ||||
|  | ||||
|     finally: | ||||
|         if pre_dispatch_mode is not None: | ||||
|             count = len(temp_elements) | ||||
|             while count > 0: | ||||
|                 mode = _pop_mode() | ||||
|                 count -= 1 | ||||
|         finally: | ||||
|             if removed_mode is not None: | ||||
|                 count = len(temp_elements) | ||||
|                 while count > 0: | ||||
|                     mode = _pop_mode() | ||||
|                     count -= 1 | ||||
|  | ||||
|             temp_elements.append(pre_dispatch_mode) | ||||
|                 temp_elements.append(removed_mode) | ||||
|  | ||||
|             for mode in reversed(temp_elements): | ||||
|                 _push_mode(mode) | ||||
|                 for mode in reversed(temp_elements): | ||||
|                     _push_mode(mode) | ||||
|  | ||||
|     return context_manager_fn | ||||
|  | ||||
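A hedged usage sketch of a context manager produced by the factory above: it yields the removed mode when one of the requested type was active, otherwise None, and leaves the rest of the stack untouched; here the stack is assumed to be empty.

from torch.fx.experimental.proxy_tensor import (
    _temp_remove_pre_dispatch_torch_function_mode,
)
from torch.overrides import _len_torch_function_stack

before = _len_torch_function_stack()
with _temp_remove_pre_dispatch_torch_function_mode() as removed:
    assert removed is None                    # no pre-dispatch mode was active
assert _len_torch_function_stack() == before  # stack restored on exit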
|  | ||||
| @torch._disable_dynamo | ||||
| @ -1230,6 +1235,11 @@ class TorchFunctionMetadataMode(TorchFunctionMode): | ||||
|         return func(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| _temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( | ||||
|     TorchFunctionMetadataMode | ||||
| ) | ||||
|  | ||||
|  | ||||
| # This mode is **only** used for pre_dispatch tracing. | ||||
| # In particular, we need to make sure that autograd/autocast API's | ||||
| # that do not desugar into dispatcher operators stay in the graph. | ||||
| @ -1258,6 +1268,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode): | ||||
|         return func(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| _temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( | ||||
|     PreDispatchTorchFunctionMode | ||||
| ) | ||||
|  | ||||
|  | ||||
| class ProxyTorchDispatchMode(TorchDispatchMode): | ||||
|     # Ensure this is read-only; this exists only for legacy reasons | ||||
|     @property | ||||
|  | ||||
| @ -19,6 +19,7 @@ from torch._higher_order_ops.flex_attention import ( | ||||
| ) | ||||
| from torch._higher_order_ops.utils import _set_compilation_env | ||||
| from torch.fx.experimental.proxy_tensor import ( | ||||
|     _temp_remove_metadata_torch_function_mode, | ||||
|     _temp_remove_pre_dispatch_torch_function_mode, | ||||
| ) | ||||
| from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input | ||||
| @ -1027,6 +1028,10 @@ def flex_attention( | ||||
|     if not torch._dynamo.is_dynamo_supported(): | ||||
|         raise RuntimeError("flex_attention requires dynamo support") | ||||
|  | ||||
|     from torch._dynamo.backends.debugging import ( | ||||
|         make_eager_backend_with_torch_function_mode, | ||||
|     ) | ||||
|  | ||||
|     # Dynamo is expecting a callable with "__code__" attribute. | ||||
|     # We cannot directly pass hop to it. So we wrap it in a dummy function. | ||||
|     def _flex_attention_hop_wrapper(*args, **kwargs): | ||||
| @ -1035,18 +1040,25 @@ def flex_attention( | ||||
|     with _set_compilation_env(): | ||||
|         with torch._dynamo.utils.disable_cache_limit(): | ||||
|             with _temp_remove_pre_dispatch_torch_function_mode(): | ||||
|                 out, lse = torch.compile( | ||||
|                     _flex_attention_hop_wrapper, backend="eager", fullgraph=True | ||||
|                 )( | ||||
|                     query, | ||||
|                     key, | ||||
|                     value, | ||||
|                     score_mod, | ||||
|                     block_mask.as_tuple(), | ||||
|                     scale, | ||||
|                     kernel_options, | ||||
|                 ) | ||||
|                 if return_lse: | ||||
|                     return out, lse * math.log(2) | ||||
|                 else: | ||||
|                     return out | ||||
|                 with _temp_remove_metadata_torch_function_mode() as metadata_mode: | ||||
|                     if metadata_mode: | ||||
|                         backend = make_eager_backend_with_torch_function_mode( | ||||
|                             metadata_mode | ||||
|                         ) | ||||
|                     else: | ||||
|                         backend = "eager" | ||||
|                     out, lse = torch.compile( | ||||
|                         _flex_attention_hop_wrapper, backend=backend, fullgraph=True | ||||
|                     )( | ||||
|                         query, | ||||
|                         key, | ||||
|                         value, | ||||
|                         score_mod, | ||||
|                         block_mask.as_tuple(), | ||||
|                         scale, | ||||
|                         kernel_options, | ||||
|                     ) | ||||
|                     if return_lse: | ||||
|                         return out, lse * math.log(2) | ||||
|                     else: | ||||
|                         return out | ||||
|  | ||||