From e0fff31ae31bf3fc7eec39391f90f4893a27ee27 Mon Sep 17 00:00:00 2001
From: Zhengxu Chen
Date: Fri, 14 Nov 2025 18:11:39 +0000
Subject: [PATCH] [dynamo] Make global state guards and torch function stack
 guards droppable. (#167674)

Summary:
Prior to this PR we would always build the global state and torch function guards in all cases. This PR makes two changes to dynamo guards:
1. Create a new guard called "GLOBAL_STATE", which corresponds to the global state guard and can be filtered out using `guard_filter_fn`.
2. Repurpose the existing "TORCH_FUNCTION_STATE" guard to check the torch function mode stack.

Also adds a new helper `torch.compiler.skip_all_guards_unsafe`, which can be useful for use cases like vLLM.

Test Plan: CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
---
 docs/source/torch.compiler_api.md       |  1 +
 test/dynamo/test_comptime.py            |  7 ++++
 test/dynamo/test_guard_serialization.py | 43 +++++++++++++++++++--
 test/dynamo/test_utils.py               |  6 +--
 torch/_dynamo/guards.py                 | 50 ++++++++++++++----------
 torch/_dynamo/output_graph.py           |  1 +
 torch/compiler/__init__.py              | 18 +++++++++
 7 files changed, 97 insertions(+), 29 deletions(-)

diff --git a/docs/source/torch.compiler_api.md b/docs/source/torch.compiler_api.md
index 2b79b0e67007..66237db8163f 100644
--- a/docs/source/torch.compiler_api.md
+++ b/docs/source/torch.compiler_api.md
@@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
     skip_guard_on_all_nn_modules_unsafe
     keep_tensor_guards_unsafe
     skip_guard_on_globals_unsafe
+    skip_all_guards_unsafe
     nested_compile_region
 ```
diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py
index efd0f0e9f0f6..619d2800e281 100644
--- a/test/dynamo/test_comptime.py
+++ b/test/dynamo/test_comptime.py
@@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
             'obj_weakref': None
             'guarded_class': None
         }
+        global '' GLOBAL_STATE
+        {
+            'guard_types': None,
+            'code': None,
+            'obj_weakref': None
+            'guarded_class': None
+        }
         global '' TORCH_FUNCTION_STATE
         {
             'guard_types': None,
diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py
index d81032a457ab..efa9b7572b2b 100644
--- a/test/dynamo/test_guard_serialization.py
+++ b/test/dynamo/test_guard_serialization.py
@@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
         x = torch.randn(3, 2)
 
         with torch.enable_grad():
-            ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
         with torch.no_grad():
             self._test_check_fn(ref, loaded, {"x": x}, False)
         with torch.enable_grad():
@@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
         x = torch.randn(3, 2)
 
         with torch.enable_grad():
-            ref, _ = self._test_serialization("GRAD_MODE", fn, x)
+            ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
         with torch.no_grad():
             # Ensure guards state loading is not affected by the current global grad mode.
             guards_state = pickle.loads(self._cached_guards_state)
@@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
         try:
             x = torch.randn(3, 2)
             torch.use_deterministic_algorithms(True)
-            ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
             torch.use_deterministic_algorithms(False)
             self._test_check_fn(ref, loaded, {"x": x}, False)
             torch.use_deterministic_algorithms(True)
@@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
             ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
             self._test_check_fn(ref, loaded, {"x": x}, True)
             self._test_check_fn(ref, loaded, {"x": x}, False)
+        with GlobalTorchFunctionMode():
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
+            self._test_check_fn(ref, loaded, {"x": x}, True)
         with GlobalTorchFunctionMode():
             with torch._C.DisableTorchFunction():
                 self._test_check_fn(ref, loaded, {"x": x}, False)
@@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
         x = torch.randn(3, 2)
 
         with torch.enable_grad():
-            ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
         with torch.no_grad():
             self._test_check_fn(ref, loaded, {"x": x}, False)
         with torch.enable_grad():
@@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
             ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
         )
 
+    def test_global_state_guard_filter(self):
+        def foo(x):
+            return x + 1
+
+        x = torch.randn(3, 2)
+
+        with torch.no_grad():
+            compiled_fn = torch.compile(
+                foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
+            )
+            compiled_fn(x)
+
+        # Check global guards are gone.
+        with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
+            self.assertEqual(compiled_fn(x), foo(x))
+
+    def test_torch_function_state_filter(self):
+        def foo(x):
+            return x + 1
+
+        x = torch.randn(3, 2)
+
+        with GlobalTorchFunctionMode():
+            compiled_fn = torch.compile(
+                foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
+            )
+            compiled_fn(x)
+
+        # Check the torch function state guard is gone.
+        with torch.compiler.set_stance("fail_on_recompile"):
+            self.assertEqual(compiled_fn(x), foo(x))
+
 
 class SimpleModule(torch.nn.Module):
     def __init__(self, c):
diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py
index 66ebe17399ac..0662a7bf912b 100644
--- a/test/dynamo/test_utils.py
+++ b/test/dynamo/test_utils.py
@@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
  'graph_node_count': 3,
  'graph_node_shapes': None,
  'graph_op_count': 1,
- 'guard_count': 9,
+ 'guard_count': 10,
  'has_guarded_code': True,
  'inductor_code_gen_cumulative_compile_time_us': 0,
  'inductor_compile_time_s': 0.0,
@@ -608,7 +608,7 @@
  'tensorify_float_attempt': None,
  'tensorify_float_failure': None,
  'tensorify_float_success': None,
- 'triton_compile_time_us': None,
+ 'triton_compile_time_us': 0,
  'triton_kernel_compile_times_us': None,
  'triton_version': None}"""
             if _IS_WINDOWS
@@ -649,7 +649,7 @@
  'graph_node_count': 3,
  'graph_node_shapes': None,
  'graph_op_count': 1,
- 'guard_count': 9,
+ 'guard_count': 10,
  'has_guarded_code': True,
  'inductor_code_gen_cumulative_compile_time_us': 0,
  'inductor_compile_time_s': 0.0,
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 0f4d0d897b46..a75118f9e503 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -2507,12 +2507,30 @@ class GuardBuilder(GuardBuilderBase):
     def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
         pass  # we always guard on this via GlobalStateGuard()
 
-    def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
-        pass  # we always guard on this via GlobalStateGuard()
-
     def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
         pass  # we always guard on this via GlobalStateGuard()
 
+    def GLOBAL_STATE(self, guard: Guard) -> None:
+        output_graph = self.check_fn_manager.output_graph
+        assert output_graph is not None
+        global_state = output_graph.global_state_guard
+        self.check_fn_manager.global_state = global_state
+        self.guard_manager.root.add_global_state_guard(
+            global_state, ["___check_global_state()"]
+        )
+
+    def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
+        assert self.check_fn_manager.torch_function_mode_stack is not None
+        self.check_fn_manager.torch_function_mode_stack_check_fn = (
+            make_torch_function_mode_stack_guard(
+                self.check_fn_manager.torch_function_mode_stack
+            )
+        )
+        self.guard_manager.root.add_torch_function_mode_stack_guard(
+            self.check_fn_manager.torch_function_mode_stack,
+            ["___check_torch_function_mode_stack()"],
+        )
+
     def DEFAULT_DEVICE(self, guard: Guard) -> None:
         """Guard on CURRENT_DEVICE per torch.utils._device"""
         assert guard.source is GuardSource.GLOBAL
@@ -3532,6 +3550,8 @@ class CheckFunctionManager:
         self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
         self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
         self.runtime_global_scope = runtime_global_scope
+        self.global_state: Optional[torch._C._dynamo.guards.GlobalStateGuard] = None
+        self.torch_function_mode_stack_check_fn: Optional[Callable[[], bool]] = None
 
         if not justknobs_check("pytorch/compiler:guard_nn_modules"):
             log.warning("guard_nn_modules is turned off using justknobs killswitch")
@@ -3939,27 +3959,11 @@ class CheckFunctionManager:
         verbose_code_parts = []
         structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
 
-        assert self.torch_function_mode_stack is not None
-        torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
-            self.torch_function_mode_stack
-        )
-
         # Add compile id info in the guard manager for debugging purpose
         self.guard_manager.root.attach_compile_id(
             str(CompileContext.current_compile_id())
         )
 
-        # Insert the global_state guard
-        assert self.output_graph is not None
-        global_state = self.output_graph.global_state_guard
-        self.guard_manager.root.add_global_state_guard(
-            global_state, ["___check_global_state()"]
-        )
-
-        self.guard_manager.root.add_torch_function_mode_stack_guard(
-            self.torch_function_mode_stack,
-            ["___check_torch_function_mode_stack()"],
-        )
         # Clear references to torch_function modes held in the list
         self.torch_function_mode_stack = None
 
@@ -4105,12 +4109,14 @@
 
         if convert_frame.initial_global_state is None:
             # we should only hit this case in NopTests()
-            global_state = convert_frame.GlobalStateGuard()
+            check_global_state = convert_frame.GlobalStateGuard().check
+        else:
+            check_global_state = getattr(self.global_state, "check", None)
         closure_vars = {
             "___check_tensors": check_tensors_fn,
             "___check_tensors_verbose": check_tensors_verbose_fn,
-            "___check_global_state": global_state.check,
-            "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
+            "___check_global_state": check_global_state,
+            "___check_torch_function_mode_stack": self.torch_function_mode_stack_check_fn,
             **SYMPY_INTERP,
             **_get_closure_vars(),
         }
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 0a5dfc8fc2ed..87920771feb1 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -794,6 +794,7 @@ class OutputGraph(OutputGraphCommon):
 
         self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
 
+        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GLOBAL_STATE))
         self.guards.add(
             GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
         )
diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py
index 1e744f54362d..809ec86fa5ec 100644
--- a/torch/compiler/__init__.py
+++ b/torch/compiler/__init__.py
@@ -36,6 +36,7 @@ __all__ = [
     "skip_guard_on_all_nn_modules_unsafe",
     "keep_tensor_guards_unsafe",
     "skip_guard_on_globals_unsafe",
+    "skip_all_guards_unsafe",
     "nested_compile_region",
 ]
 
@@ -617,6 +618,23 @@ def skip_guard_on_globals_unsafe(guard_entries):
     return [not entry.is_global for entry in guard_entries]
 
 
+def skip_all_guards_unsafe(guard_entries):
+    """
+    A function for skipping all guards on a compiled function.
+
+    WARNING: This function drops all of the safety guarantees from the
+    Dynamo-compiled function. Use it with caution.
+
+    To use this API, pass the guard_filter_fn argument when calling torch.compile:
+
+    >> opt_mod = torch.compile(
+    >>     mod,
+    >>     options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
+    >> )
+    """
+    return [False for entry in guard_entries]
+
+
 def nested_compile_region(fn=None):
     """
     Tells **``torch.compile``** that the marked set of operations forms a nested