[dynamo] Make global state guards and torch function stack guards droppable. (#167674)

Summary:
Prior to this PR we will always build global and torch funciton guards in all cases.

In this PR we did 2 changes to dynamo guards:
1. Created a new guard called "GLOBAL_STATE" which corresponds to the global state guard and can be filtered out using guard_filter_fn
2. Repurpose the existing "TORCH_FUNCTION_STATE" guard for checking torch function mode stack.

Also added a new helper `torch.compiler.skip_all_guards_unsafe` which can be useful for use cases like vllm

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
This commit is contained in:
Zhengxu Chen
2025-11-14 18:11:39 +00:00
committed by PyTorch MergeBot
parent 7ede33b8e3
commit e0fff31ae3
7 changed files with 97 additions and 29 deletions

View File

@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe skip_guard_on_globals_unsafe
skip_all_guards_unsafe
nested_compile_region nested_compile_region
``` ```

View File

@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
'obj_weakref': None 'obj_weakref': None
'guarded_class': None 'guarded_class': None
} }
global '' GLOBAL_STATE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' TORCH_FUNCTION_STATE global '' TORCH_FUNCTION_STATE
{ {
'guard_types': None, 'guard_types': None,

View File

@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2) x = torch.randn(3, 2)
with torch.enable_grad(): with torch.enable_grad():
ref, loaded = self._test_serialization("GRAD_MODE", fn, x) ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad(): with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False) self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad(): with torch.enable_grad():
@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2) x = torch.randn(3, 2)
with torch.enable_grad(): with torch.enable_grad():
ref, _ = self._test_serialization("GRAD_MODE", fn, x) ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad(): with torch.no_grad():
# Ensure guards state loading is not affected by the current global grad mode. # Ensure guards state loading is not affected by the current global grad mode.
guards_state = pickle.loads(self._cached_guards_state) guards_state = pickle.loads(self._cached_guards_state)
@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
try: try:
x = torch.randn(3, 2) x = torch.randn(3, 2)
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x) ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
torch.use_deterministic_algorithms(False) torch.use_deterministic_algorithms(False)
self._test_check_fn(ref, loaded, {"x": x}, False) self._test_check_fn(ref, loaded, {"x": x}, False)
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x) ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True) self._test_check_fn(ref, loaded, {"x": x}, True)
self._test_check_fn(ref, loaded, {"x": x}, False) self._test_check_fn(ref, loaded, {"x": x}, False)
with GlobalTorchFunctionMode():
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with GlobalTorchFunctionMode(): with GlobalTorchFunctionMode():
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
self._test_check_fn(ref, loaded, {"x": x}, False) self._test_check_fn(ref, loaded, {"x": x}, False)
@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2) x = torch.randn(3, 2)
with torch.enable_grad(): with torch.enable_grad():
ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x) ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad(): with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False) self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad(): with torch.enable_grad():
@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
) )
def test_global_state_guard_filter(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with torch.no_grad():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)
# Check global guards are gone.
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))
def test_torch_function_state_filter(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with GlobalTorchFunctionMode():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)
# Check global guards are gone.
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))
class SimpleModule(torch.nn.Module): class SimpleModule(torch.nn.Module):
def __init__(self, c): def __init__(self, c):

View File

@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3, 'graph_node_count': 3,
'graph_node_shapes': None, 'graph_node_shapes': None,
'graph_op_count': 1, 'graph_op_count': 1,
'guard_count': 9, 'guard_count': 10,
'has_guarded_code': True, 'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0, 'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0, 'inductor_compile_time_s': 0.0,
@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
'tensorify_float_attempt': None, 'tensorify_float_attempt': None,
'tensorify_float_failure': None, 'tensorify_float_failure': None,
'tensorify_float_success': None, 'tensorify_float_success': None,
'triton_compile_time_us': None, 'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None, 'triton_kernel_compile_times_us': None,
'triton_version': None}""" 'triton_version': None}"""
if _IS_WINDOWS if _IS_WINDOWS
@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3, 'graph_node_count': 3,
'graph_node_shapes': None, 'graph_node_shapes': None,
'graph_op_count': 1, 'graph_op_count': 1,
'guard_count': 9, 'guard_count': 10,
'has_guarded_code': True, 'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0, 'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0, 'inductor_compile_time_s': 0.0,

View File

@ -2507,12 +2507,30 @@ class GuardBuilder(GuardBuilderBase):
def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None: def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
pass # we always guard on this via GlobalStateGuard() pass # we always guard on this via GlobalStateGuard()
def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
pass # we always guard on this via GlobalStateGuard()
def FSDP_TRAINING_STATE(self, guard: Guard) -> None: def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
pass # we always guard on this via GlobalStateGuard() pass # we always guard on this via GlobalStateGuard()
def GLOBAL_STATE(self, guard: Guard) -> None:
output_graph = self.check_fn_manager.output_graph
assert output_graph is not None
global_state = output_graph.global_state_guard
self.check_fn_manager.global_state = global_state
self.guard_manager.root.add_global_state_guard(
global_state, ["___check_global_state()"]
)
def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
assert self.check_fn_manager.torch_function_mode_stack is not None
self.check_fn_manager.torch_function_mode_stack_check_fn = (
make_torch_function_mode_stack_guard(
self.check_fn_manager.torch_function_mode_stack
)
)
self.guard_manager.root.add_torch_function_mode_stack_guard(
self.check_fn_manager.torch_function_mode_stack,
["___check_torch_function_mode_stack()"],
)
def DEFAULT_DEVICE(self, guard: Guard) -> None: def DEFAULT_DEVICE(self, guard: Guard) -> None:
"""Guard on CURRENT_DEVICE per torch.utils._device""" """Guard on CURRENT_DEVICE per torch.utils._device"""
assert guard.source is GuardSource.GLOBAL assert guard.source is GuardSource.GLOBAL
@ -3532,6 +3550,8 @@ class CheckFunctionManager:
self.additional_used_local_vars: OrderedSet[str] = OrderedSet() self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
self.additional_used_global_vars: OrderedSet[str] = OrderedSet() self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
self.runtime_global_scope = runtime_global_scope self.runtime_global_scope = runtime_global_scope
self.global_state: Optional[torch._C._dynamo.guards.GlobalStateGuard] = None
self.torch_function_mode_stack_check_fn: Optional[Callable[[], bool]] = None
if not justknobs_check("pytorch/compiler:guard_nn_modules"): if not justknobs_check("pytorch/compiler:guard_nn_modules"):
log.warning("guard_nn_modules is turned off using justknobs killswitch") log.warning("guard_nn_modules is turned off using justknobs killswitch")
@ -3939,27 +3959,11 @@ class CheckFunctionManager:
verbose_code_parts = [] verbose_code_parts = []
structured_guard_fns: list[Callable[[], dict[str, Any]]] = [] structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
assert self.torch_function_mode_stack is not None
torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
self.torch_function_mode_stack
)
# Add compile id info in the guard manager for debugging purpose # Add compile id info in the guard manager for debugging purpose
self.guard_manager.root.attach_compile_id( self.guard_manager.root.attach_compile_id(
str(CompileContext.current_compile_id()) str(CompileContext.current_compile_id())
) )
# Insert the global_state guard
assert self.output_graph is not None
global_state = self.output_graph.global_state_guard
self.guard_manager.root.add_global_state_guard(
global_state, ["___check_global_state()"]
)
self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack,
["___check_torch_function_mode_stack()"],
)
# Clear references to torch_function modes held in the list # Clear references to torch_function modes held in the list
self.torch_function_mode_stack = None self.torch_function_mode_stack = None
@ -4105,12 +4109,14 @@ class CheckFunctionManager:
if convert_frame.initial_global_state is None: if convert_frame.initial_global_state is None:
# we should only hit this case in NopTests() # we should only hit this case in NopTests()
global_state = convert_frame.GlobalStateGuard() check_global_state = convert_frame.GlobalStateGuard().check
else:
check_global_state = getattr(self.global_state, "check", None)
closure_vars = { closure_vars = {
"___check_tensors": check_tensors_fn, "___check_tensors": check_tensors_fn,
"___check_tensors_verbose": check_tensors_verbose_fn, "___check_tensors_verbose": check_tensors_verbose_fn,
"___check_global_state": global_state.check, "___check_global_state": check_global_state,
"___check_torch_function_mode_stack": torch_function_mode_stack_check_fn, "___check_torch_function_mode_stack": self.torch_function_mode_stack_check_fn,
**SYMPY_INTERP, **SYMPY_INTERP,
**_get_closure_vars(), **_get_closure_vars(),
} }

View File

@ -794,6 +794,7 @@ class OutputGraph(OutputGraphCommon):
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE)) self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GLOBAL_STATE))
self.guards.add( self.guards.add(
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE) GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
) )

View File

@ -36,6 +36,7 @@ __all__ = [
"skip_guard_on_all_nn_modules_unsafe", "skip_guard_on_all_nn_modules_unsafe",
"keep_tensor_guards_unsafe", "keep_tensor_guards_unsafe",
"skip_guard_on_globals_unsafe", "skip_guard_on_globals_unsafe",
"skip_all_guards_unsafe",
"nested_compile_region", "nested_compile_region",
] ]
@ -617,6 +618,23 @@ def skip_guard_on_globals_unsafe(guard_entries):
return [not entry.is_global for entry in guard_entries] return [not entry.is_global for entry in guard_entries]
def skip_all_guards_unsafe(guard_entries):
"""
A function for skipping all guards on a compiled function.
WARNING: This function will drop all the safety guarantees from Dynamo
compiled function. Use this with caution.
To use this API, use guard_filter_fn argument while calling torch.compile
>> opt_mod = torch.compile(
>> mod,
>> options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
>> )
"""
return [False for entry in guard_entries]
def nested_compile_region(fn=None): def nested_compile_region(fn=None):
""" """
Tells **``torch.compile``** that the marked set of operations forms a nested Tells **``torch.compile``** that the marked set of operations forms a nested