mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-15 06:48:48 +08:00
[dynamo] Make global state guards and torch function stack guards droppable. (#167674)
Summary: Prior to this PR we will always build global and torch funciton guards in all cases. In this PR we did 2 changes to dynamo guards: 1. Created a new guard called "GLOBAL_STATE" which corresponds to the global state guard and can be filtered out using guard_filter_fn 2. Repurpose the existing "TORCH_FUNCTION_STATE" guard for checking torch function mode stack. Also added a new helper `torch.compiler.skip_all_guards_unsafe` which can be useful for use cases like vllm Test Plan: CI Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
7ede33b8e3
commit
e0fff31ae3
@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
|
||||
skip_guard_on_all_nn_modules_unsafe
|
||||
keep_tensor_guards_unsafe
|
||||
skip_guard_on_globals_unsafe
|
||||
skip_all_guards_unsafe
|
||||
nested_compile_region
|
||||
```
|
||||
|
||||
@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' GLOBAL_STATE
|
||||
{
|
||||
'guard_types': None,
|
||||
'code': None,
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' TORCH_FUNCTION_STATE
|
||||
{
|
||||
'guard_types': None,
|
||||
|
||||
@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with torch.enable_grad():
|
||||
ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
with torch.no_grad():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch.enable_grad():
|
||||
@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with torch.enable_grad():
|
||||
ref, _ = self._test_serialization("GRAD_MODE", fn, x)
|
||||
ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
with torch.no_grad():
|
||||
# Ensure guards state loading is not affected by the current global grad mode.
|
||||
guards_state = pickle.loads(self._cached_guards_state)
|
||||
@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
try:
|
||||
x = torch.randn(3, 2)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
torch.use_deterministic_algorithms(False)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with GlobalTorchFunctionMode():
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
with GlobalTorchFunctionMode():
|
||||
with torch._C.DisableTorchFunction():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
x = torch.randn(3, 2)
|
||||
|
||||
with torch.enable_grad():
|
||||
ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
with torch.no_grad():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch.enable_grad():
|
||||
@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
|
||||
)
|
||||
|
||||
def test_global_state_guard_filter(self):
|
||||
def foo(x):
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
|
||||
with torch.no_grad():
|
||||
compiled_fn = torch.compile(
|
||||
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
|
||||
)
|
||||
compiled_fn(x)
|
||||
|
||||
# Check global guards are gone.
|
||||
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
|
||||
self.assertEqual(compiled_fn(x), foo(x))
|
||||
|
||||
def test_torch_function_state_filter(self):
|
||||
def foo(x):
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
|
||||
with GlobalTorchFunctionMode():
|
||||
compiled_fn = torch.compile(
|
||||
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
|
||||
)
|
||||
compiled_fn(x)
|
||||
|
||||
# Check global guards are gone.
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
self.assertEqual(compiled_fn(x), foo(x))
|
||||
|
||||
|
||||
class SimpleModule(torch.nn.Module):
|
||||
def __init__(self, c):
|
||||
|
||||
@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
|
||||
'graph_node_count': 3,
|
||||
'graph_node_shapes': None,
|
||||
'graph_op_count': 1,
|
||||
'guard_count': 9,
|
||||
'guard_count': 10,
|
||||
'has_guarded_code': True,
|
||||
'inductor_code_gen_cumulative_compile_time_us': 0,
|
||||
'inductor_compile_time_s': 0.0,
|
||||
@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
|
||||
'tensorify_float_attempt': None,
|
||||
'tensorify_float_failure': None,
|
||||
'tensorify_float_success': None,
|
||||
'triton_compile_time_us': None,
|
||||
'triton_compile_time_us': 0,
|
||||
'triton_kernel_compile_times_us': None,
|
||||
'triton_version': None}"""
|
||||
if _IS_WINDOWS
|
||||
@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
|
||||
'graph_node_count': 3,
|
||||
'graph_node_shapes': None,
|
||||
'graph_op_count': 1,
|
||||
'guard_count': 9,
|
||||
'guard_count': 10,
|
||||
'has_guarded_code': True,
|
||||
'inductor_code_gen_cumulative_compile_time_us': 0,
|
||||
'inductor_compile_time_s': 0.0,
|
||||
|
||||
@ -2507,12 +2507,30 @@ class GuardBuilder(GuardBuilderBase):
|
||||
def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
|
||||
pass # we always guard on this via GlobalStateGuard()
|
||||
|
||||
def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
|
||||
pass # we always guard on this via GlobalStateGuard()
|
||||
|
||||
def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
|
||||
pass # we always guard on this via GlobalStateGuard()
|
||||
|
||||
def GLOBAL_STATE(self, guard: Guard) -> None:
|
||||
output_graph = self.check_fn_manager.output_graph
|
||||
assert output_graph is not None
|
||||
global_state = output_graph.global_state_guard
|
||||
self.check_fn_manager.global_state = global_state
|
||||
self.guard_manager.root.add_global_state_guard(
|
||||
global_state, ["___check_global_state()"]
|
||||
)
|
||||
|
||||
def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
|
||||
assert self.check_fn_manager.torch_function_mode_stack is not None
|
||||
self.check_fn_manager.torch_function_mode_stack_check_fn = (
|
||||
make_torch_function_mode_stack_guard(
|
||||
self.check_fn_manager.torch_function_mode_stack
|
||||
)
|
||||
)
|
||||
self.guard_manager.root.add_torch_function_mode_stack_guard(
|
||||
self.check_fn_manager.torch_function_mode_stack,
|
||||
["___check_torch_function_mode_stack()"],
|
||||
)
|
||||
|
||||
def DEFAULT_DEVICE(self, guard: Guard) -> None:
|
||||
"""Guard on CURRENT_DEVICE per torch.utils._device"""
|
||||
assert guard.source is GuardSource.GLOBAL
|
||||
@ -3532,6 +3550,8 @@ class CheckFunctionManager:
|
||||
self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
|
||||
self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
|
||||
self.runtime_global_scope = runtime_global_scope
|
||||
self.global_state: Optional[torch._C._dynamo.guards.GlobalStateGuard] = None
|
||||
self.torch_function_mode_stack_check_fn: Optional[Callable[[], bool]] = None
|
||||
|
||||
if not justknobs_check("pytorch/compiler:guard_nn_modules"):
|
||||
log.warning("guard_nn_modules is turned off using justknobs killswitch")
|
||||
@ -3939,27 +3959,11 @@ class CheckFunctionManager:
|
||||
verbose_code_parts = []
|
||||
structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
|
||||
|
||||
assert self.torch_function_mode_stack is not None
|
||||
torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
|
||||
self.torch_function_mode_stack
|
||||
)
|
||||
|
||||
# Add compile id info in the guard manager for debugging purpose
|
||||
self.guard_manager.root.attach_compile_id(
|
||||
str(CompileContext.current_compile_id())
|
||||
)
|
||||
|
||||
# Insert the global_state guard
|
||||
assert self.output_graph is not None
|
||||
global_state = self.output_graph.global_state_guard
|
||||
self.guard_manager.root.add_global_state_guard(
|
||||
global_state, ["___check_global_state()"]
|
||||
)
|
||||
|
||||
self.guard_manager.root.add_torch_function_mode_stack_guard(
|
||||
self.torch_function_mode_stack,
|
||||
["___check_torch_function_mode_stack()"],
|
||||
)
|
||||
# Clear references to torch_function modes held in the list
|
||||
self.torch_function_mode_stack = None
|
||||
|
||||
@ -4105,12 +4109,14 @@ class CheckFunctionManager:
|
||||
|
||||
if convert_frame.initial_global_state is None:
|
||||
# we should only hit this case in NopTests()
|
||||
global_state = convert_frame.GlobalStateGuard()
|
||||
check_global_state = convert_frame.GlobalStateGuard().check
|
||||
else:
|
||||
check_global_state = getattr(self.global_state, "check", None)
|
||||
closure_vars = {
|
||||
"___check_tensors": check_tensors_fn,
|
||||
"___check_tensors_verbose": check_tensors_verbose_fn,
|
||||
"___check_global_state": global_state.check,
|
||||
"___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
|
||||
"___check_global_state": check_global_state,
|
||||
"___check_torch_function_mode_stack": self.torch_function_mode_stack_check_fn,
|
||||
**SYMPY_INTERP,
|
||||
**_get_closure_vars(),
|
||||
}
|
||||
|
||||
@ -794,6 +794,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
|
||||
|
||||
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GLOBAL_STATE))
|
||||
self.guards.add(
|
||||
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
|
||||
)
|
||||
|
||||
@ -36,6 +36,7 @@ __all__ = [
|
||||
"skip_guard_on_all_nn_modules_unsafe",
|
||||
"keep_tensor_guards_unsafe",
|
||||
"skip_guard_on_globals_unsafe",
|
||||
"skip_all_guards_unsafe",
|
||||
"nested_compile_region",
|
||||
]
|
||||
|
||||
@ -617,6 +618,23 @@ def skip_guard_on_globals_unsafe(guard_entries):
|
||||
return [not entry.is_global for entry in guard_entries]
|
||||
|
||||
|
||||
def skip_all_guards_unsafe(guard_entries):
|
||||
"""
|
||||
A function for skipping all guards on a compiled function.
|
||||
|
||||
WARNING: This function will drop all the safety guarantees from Dynamo
|
||||
compiled function. Use this with caution.
|
||||
|
||||
To use this API, use guard_filter_fn argument while calling torch.compile
|
||||
|
||||
>> opt_mod = torch.compile(
|
||||
>> mod,
|
||||
>> options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
|
||||
>> )
|
||||
"""
|
||||
return [False for entry in guard_entries]
|
||||
|
||||
|
||||
def nested_compile_region(fn=None):
|
||||
"""
|
||||
Tells **``torch.compile``** that the marked set of operations forms a nested
|
||||
|
||||
Reference in New Issue
Block a user