[dynamo] Make global state guards and torch function stack guards droppable. (#167674)

Summary:
Prior to this PR, we always built the global state guard and the torch function stack guard in all cases.

In this PR we made two changes to Dynamo guards:
1. Created a new guard called "GLOBAL_STATE", which corresponds to the global state guard and can be filtered out using `guard_filter_fn` (see the sketch after this list).
2. Repurposed the existing "TORCH_FUNCTION_STATE" guard to check the torch function mode stack.
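
For example, a custom `guard_filter_fn` could drop just these two state guards while keeping everything else. A minimal sketch, assuming each filter entry exposes a `guard_type` string naming the guard (the helper name `drop_state_guards` is illustrative, not part of this PR):

```python
import torch

def drop_state_guards(guard_entries):
    # Keep every guard except the two droppable state guards from this PR.
    # Assumes each entry carries a `guard_type` string naming the guard kind.
    return [
        entry.guard_type not in ("GLOBAL_STATE", "TORCH_FUNCTION_STATE")
        for entry in guard_entries
    ]

compiled = torch.compile(
    lambda x: x + 1,
    options={"guard_filter_fn": drop_state_guards},
)
compiled(torch.randn(3, 2))
```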

Also added a new helper, `torch.compiler.skip_all_guards_unsafe`, which can be useful for use cases like vLLM; a usage sketch follows below.
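
A usage sketch for the new helper, mirroring the test added in this PR. Dropping every guard removes Dynamo's recompilation checks, so this is only appropriate when inputs are known to stay consistent:

```python
import torch

def foo(x):
    return x + 1

# Compile with every guard filtered out (unsafe: no recompile checks remain).
compiled_fn = torch.compile(
    foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(torch.randn(3, 2))
```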

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
Author: Zhengxu Chen
Date: 2025-11-14 18:11:39 +00:00
Committed by: PyTorch MergeBot
parent 7ede33b8e3
commit e0fff31ae3
7 changed files with 97 additions and 29 deletions

View File

@@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
+skip_all_guards_unsafe
nested_compile_region
```

View File

@@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
'obj_weakref': None
'guarded_class': None
}
global '' GLOBAL_STATE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' TORCH_FUNCTION_STATE
{
'guard_types': None,

View File

@@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)
with torch.enable_grad():
-ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
+ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
@@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)
with torch.enable_grad():
-ref, _ = self._test_serialization("GRAD_MODE", fn, x)
+ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
# Ensure guards state loading is not affected by the current global grad mode.
guards_state = pickle.loads(self._cached_guards_state)
@@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
try:
x = torch.randn(3, 2)
torch.use_deterministic_algorithms(True)
-ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
+ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
torch.use_deterministic_algorithms(False)
self._test_check_fn(ref, loaded, {"x": x}, False)
torch.use_deterministic_algorithms(True)
@@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
self._test_check_fn(ref, loaded, {"x": x}, False)
with GlobalTorchFunctionMode():
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with GlobalTorchFunctionMode():
with torch._C.DisableTorchFunction():
self._test_check_fn(ref, loaded, {"x": x}, False)
@@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)
with torch.enable_grad():
-ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
+ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
@@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
)
def test_global_state_guard_filter(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with torch.no_grad():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)
# Check global guards are gone.
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))
def test_torch_function_state_filter(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with GlobalTorchFunctionMode():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)
# Check torch function mode stack guards are gone.
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))
class SimpleModule(torch.nn.Module):
def __init__(self, c):

View File

@@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3,
'graph_node_shapes': None,
'graph_op_count': 1,
-'guard_count': 9,
+'guard_count': 10,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
@@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
'tensorify_float_attempt': None,
'tensorify_float_failure': None,
'tensorify_float_success': None,
-'triton_compile_time_us': None,
+'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None,
'triton_version': None}"""
if _IS_WINDOWS
@@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3,
'graph_node_shapes': None,
'graph_op_count': 1,
-'guard_count': 9,
+'guard_count': 10,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,

View File

@@ -2507,12 +2507,30 @@ class GuardBuilder(GuardBuilderBase):
def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
pass # we always guard on this via GlobalStateGuard()
-def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
-pass # we always guard on this via GlobalStateGuard()
def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
pass # we always guard on this via GlobalStateGuard()
def GLOBAL_STATE(self, guard: Guard) -> None:
output_graph = self.check_fn_manager.output_graph
assert output_graph is not None
global_state = output_graph.global_state_guard
self.check_fn_manager.global_state = global_state
self.guard_manager.root.add_global_state_guard(
global_state, ["___check_global_state()"]
)
def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
assert self.check_fn_manager.torch_function_mode_stack is not None
self.check_fn_manager.torch_function_mode_stack_check_fn = (
make_torch_function_mode_stack_guard(
self.check_fn_manager.torch_function_mode_stack
)
)
self.guard_manager.root.add_torch_function_mode_stack_guard(
self.check_fn_manager.torch_function_mode_stack,
["___check_torch_function_mode_stack()"],
)
def DEFAULT_DEVICE(self, guard: Guard) -> None:
"""Guard on CURRENT_DEVICE per torch.utils._device"""
assert guard.source is GuardSource.GLOBAL
@@ -3532,6 +3550,8 @@ class CheckFunctionManager:
self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
self.runtime_global_scope = runtime_global_scope
self.global_state: Optional[torch._C._dynamo.guards.GlobalStateGuard] = None
self.torch_function_mode_stack_check_fn: Optional[Callable[[], bool]] = None
if not justknobs_check("pytorch/compiler:guard_nn_modules"):
log.warning("guard_nn_modules is turned off using justknobs killswitch")
@@ -3939,27 +3959,11 @@ class CheckFunctionManager:
verbose_code_parts = []
structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
-assert self.torch_function_mode_stack is not None
-torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
-self.torch_function_mode_stack
-)
# Add compile id info in the guard manager for debugging purpose
self.guard_manager.root.attach_compile_id(
str(CompileContext.current_compile_id())
)
-# Insert the global_state guard
-assert self.output_graph is not None
-global_state = self.output_graph.global_state_guard
-self.guard_manager.root.add_global_state_guard(
-global_state, ["___check_global_state()"]
-)
-self.guard_manager.root.add_torch_function_mode_stack_guard(
-self.torch_function_mode_stack,
-["___check_torch_function_mode_stack()"],
-)
# Clear references to torch_function modes held in the list
self.torch_function_mode_stack = None
@@ -4105,12 +4109,14 @@ class CheckFunctionManager:
if convert_frame.initial_global_state is None:
# we should only hit this case in NopTests()
-global_state = convert_frame.GlobalStateGuard()
+check_global_state = convert_frame.GlobalStateGuard().check
+else:
+check_global_state = getattr(self.global_state, "check", None)
closure_vars = {
"___check_tensors": check_tensors_fn,
"___check_tensors_verbose": check_tensors_verbose_fn,
"___check_global_state": global_state.check,
"___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
"___check_global_state": check_global_state,
"___check_torch_function_mode_stack": self.torch_function_mode_stack_check_fn,
**SYMPY_INTERP,
**_get_closure_vars(),
}

View File

@@ -794,6 +794,7 @@ class OutputGraph(OutputGraphCommon):
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
+self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GLOBAL_STATE))
self.guards.add(
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
)

View File

@@ -36,6 +36,7 @@ __all__ = [
"skip_guard_on_all_nn_modules_unsafe",
"keep_tensor_guards_unsafe",
"skip_guard_on_globals_unsafe",
"skip_all_guards_unsafe",
"nested_compile_region",
]
@@ -617,6 +618,23 @@ def skip_guard_on_globals_unsafe(guard_entries):
return [not entry.is_global for entry in guard_entries]
def skip_all_guards_unsafe(guard_entries):
"""
A function for skipping all guards on a compiled function.
WARNING: This function drops all of the safety guarantees of a Dynamo
compiled function. Use it with caution.
To use this API, pass the guard_filter_fn argument when calling torch.compile:
>> opt_mod = torch.compile(
>> mod,
>> options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
>> )
"""
return [False for entry in guard_entries]
def nested_compile_region(fn=None):
"""
Tells **``torch.compile``** that the marked set of operations forms a nested