Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-20 10:34:57 +08:00)
[dynamo] Make global state guards and torch function stack guards droppable. (#167674)
Summary: Prior to this PR, the global state guard and the torch function stack guard were always built in all cases. This PR makes two changes to Dynamo guards:

1. Created a new guard called "GLOBAL_STATE", which corresponds to the global state guard and can be filtered out using guard_filter_fn.
2. Repurposed the existing "TORCH_FUNCTION_STATE" guard for checking the torch function mode stack.

Also added a new helper `torch.compiler.skip_all_guards_unsafe`, which can be useful for use cases like vLLM.

Test Plan: CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
Committed by: PyTorch MergeBot
Parent: 7ede33b8e3
Commit: e0fff31ae3
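A minimal usage sketch of the new helper, adapted from the tests added in this PR (the `foo` function and the tensor shape are illustrative, not part of the commit):

```python
import torch

def foo(x):
    return x + 1

# Compile with every guard dropped. This is unsafe: Dynamo can no longer
# detect when the compiled code stops being valid (e.g. grad mode flips),
# so it keeps reusing the same compiled artifact.
compiled_foo = torch.compile(
    foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)

with torch.no_grad():
    compiled_foo(torch.randn(3, 2))

# Because the GLOBAL_STATE guard was dropped, changing grad mode does not
# trigger a recompile; "fail_on_recompile" would raise if it did.
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
    print(compiled_foo(torch.randn(3, 2)))
```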
@@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
 skip_guard_on_all_nn_modules_unsafe
 keep_tensor_guards_unsafe
 skip_guard_on_globals_unsafe
+skip_all_guards_unsafe
 nested_compile_region
 ```
@@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
     'obj_weakref': None
     'guarded_class': None
 }
+global '' GLOBAL_STATE
+{
+    'guard_types': None,
+    'code': None,
+    'obj_weakref': None
+    'guarded_class': None
+}
 global '' TORCH_FUNCTION_STATE
 {
     'guard_types': None,
@@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
 
         x = torch.randn(3, 2)
         with torch.enable_grad():
-            ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
         with torch.no_grad():
             self._test_check_fn(ref, loaded, {"x": x}, False)
         with torch.enable_grad():
@@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
 
         x = torch.randn(3, 2)
         with torch.enable_grad():
-            ref, _ = self._test_serialization("GRAD_MODE", fn, x)
+            ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
         with torch.no_grad():
             # Ensure guards state loading is not affected by the current global grad mode.
             guards_state = pickle.loads(self._cached_guards_state)
@@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
         try:
             x = torch.randn(3, 2)
             torch.use_deterministic_algorithms(True)
-            ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
             torch.use_deterministic_algorithms(False)
             self._test_check_fn(ref, loaded, {"x": x}, False)
             torch.use_deterministic_algorithms(True)
@@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
             ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
             self._test_check_fn(ref, loaded, {"x": x}, True)
         self._test_check_fn(ref, loaded, {"x": x}, False)
+        with GlobalTorchFunctionMode():
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
+            self._test_check_fn(ref, loaded, {"x": x}, True)
         with GlobalTorchFunctionMode():
             with torch._C.DisableTorchFunction():
                 self._test_check_fn(ref, loaded, {"x": x}, False)
@@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
         x = torch.randn(3, 2)
 
         with torch.enable_grad():
-            ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
+            ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
         with torch.no_grad():
             self._test_check_fn(ref, loaded, {"x": x}, False)
         with torch.enable_grad():
@@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
             ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
         )
 
+    def test_global_state_guard_filter(self):
+        def foo(x):
+            return x + 1
+
+        x = torch.randn(3, 2)
+
+        with torch.no_grad():
+            compiled_fn = torch.compile(
+                foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
+            )
+            compiled_fn(x)
+
+        # Check global guards are gone.
+        with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
+            self.assertEqual(compiled_fn(x), foo(x))
+
+    def test_torch_function_state_filter(self):
+        def foo(x):
+            return x + 1
+
+        x = torch.randn(3, 2)
+
+        with GlobalTorchFunctionMode():
+            compiled_fn = torch.compile(
+                foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
+            )
+            compiled_fn(x)
+
+        # Check global guards are gone.
+        with torch.compiler.set_stance("fail_on_recompile"):
+            self.assertEqual(compiled_fn(x), foo(x))
+
 
 class SimpleModule(torch.nn.Module):
     def __init__(self, c):
@@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
 'graph_node_count': 3,
 'graph_node_shapes': None,
 'graph_op_count': 1,
-'guard_count': 9,
+'guard_count': 10,
 'has_guarded_code': True,
 'inductor_code_gen_cumulative_compile_time_us': 0,
 'inductor_compile_time_s': 0.0,
@@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
 'tensorify_float_attempt': None,
 'tensorify_float_failure': None,
 'tensorify_float_success': None,
-'triton_compile_time_us': None,
+'triton_compile_time_us': 0,
 'triton_kernel_compile_times_us': None,
 'triton_version': None}"""
 if _IS_WINDOWS
@@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
 'graph_node_count': 3,
 'graph_node_shapes': None,
 'graph_op_count': 1,
-'guard_count': 9,
+'guard_count': 10,
 'has_guarded_code': True,
 'inductor_code_gen_cumulative_compile_time_us': 0,
 'inductor_compile_time_s': 0.0,
@@ -2507,12 +2507,30 @@ class GuardBuilder(GuardBuilderBase):
     def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
         pass  # we always guard on this via GlobalStateGuard()
 
-    def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
-        pass  # we always guard on this via GlobalStateGuard()
-
     def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
         pass  # we always guard on this via GlobalStateGuard()
 
+    def GLOBAL_STATE(self, guard: Guard) -> None:
+        output_graph = self.check_fn_manager.output_graph
+        assert output_graph is not None
+        global_state = output_graph.global_state_guard
+        self.check_fn_manager.global_state = global_state
+        self.guard_manager.root.add_global_state_guard(
+            global_state, ["___check_global_state()"]
+        )
+
+    def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
+        assert self.check_fn_manager.torch_function_mode_stack is not None
+        self.check_fn_manager.torch_function_mode_stack_check_fn = (
+            make_torch_function_mode_stack_guard(
+                self.check_fn_manager.torch_function_mode_stack
+            )
+        )
+        self.guard_manager.root.add_torch_function_mode_stack_guard(
+            self.check_fn_manager.torch_function_mode_stack,
+            ["___check_torch_function_mode_stack()"],
+        )
+
     def DEFAULT_DEVICE(self, guard: Guard) -> None:
         """Guard on CURRENT_DEVICE per torch.utils._device"""
         assert guard.source is GuardSource.GLOBAL
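The new GLOBAL_STATE builder above stashes the output graph's GlobalStateGuard on the CheckFunctionManager and installs it on the guard-manager root, so the existing `___check_global_state()` closure keeps working once the guard becomes droppable. A minimal sketch of what that snapshot/check pair observes, assuming GlobalStateGuard captures flags such as grad mode and deterministic algorithms at construction time (the states the serialization tests above exercise):

```python
import torch
from torch._C._dynamo.guards import GlobalStateGuard

# Snapshot the current global interpreter state.
with torch.enable_grad():
    snapshot = GlobalStateGuard()
    print(snapshot.check())  # True: state still matches the snapshot

with torch.no_grad():
    print(snapshot.check())  # False: grad mode differs from the snapshot
```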
@@ -3532,6 +3550,8 @@ class CheckFunctionManager:
         self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
         self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
         self.runtime_global_scope = runtime_global_scope
+        self.global_state: Optional[torch._C._dynamo.guards.GlobalStateGuard] = None
+        self.torch_function_mode_stack_check_fn: Optional[Callable[[], bool]] = None
 
         if not justknobs_check("pytorch/compiler:guard_nn_modules"):
             log.warning("guard_nn_modules is turned off using justknobs killswitch")
@@ -3939,27 +3959,11 @@ class CheckFunctionManager:
         verbose_code_parts = []
         structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
 
-        assert self.torch_function_mode_stack is not None
-        torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
-            self.torch_function_mode_stack
-        )
-
         # Add compile id info in the guard manager for debugging purpose
         self.guard_manager.root.attach_compile_id(
             str(CompileContext.current_compile_id())
         )
 
-        # Insert the global_state guard
-        assert self.output_graph is not None
-        global_state = self.output_graph.global_state_guard
-        self.guard_manager.root.add_global_state_guard(
-            global_state, ["___check_global_state()"]
-        )
-
-        self.guard_manager.root.add_torch_function_mode_stack_guard(
-            self.torch_function_mode_stack,
-            ["___check_torch_function_mode_stack()"],
-        )
         # Clear references to torch_function modes held in the list
         self.torch_function_mode_stack = None
 
@@ -4105,12 +4109,14 @@ class CheckFunctionManager:
 
         if convert_frame.initial_global_state is None:
             # we should only hit this case in NopTests()
-            global_state = convert_frame.GlobalStateGuard()
+            check_global_state = convert_frame.GlobalStateGuard().check
+        else:
+            check_global_state = getattr(self.global_state, "check", None)
         closure_vars = {
             "___check_tensors": check_tensors_fn,
             "___check_tensors_verbose": check_tensors_verbose_fn,
-            "___check_global_state": global_state.check,
-            "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
+            "___check_global_state": check_global_state,
+            "___check_torch_function_mode_stack": self.torch_function_mode_stack_check_fn,
             **SYMPY_INTERP,
             **_get_closure_vars(),
         }
@@ -794,6 +794,7 @@ class OutputGraph(OutputGraphCommon):
 
         self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
 
+        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GLOBAL_STATE))
         self.guards.add(
             GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
         )
@@ -36,6 +36,7 @@ __all__ = [
     "skip_guard_on_all_nn_modules_unsafe",
     "keep_tensor_guards_unsafe",
    "skip_guard_on_globals_unsafe",
+    "skip_all_guards_unsafe",
     "nested_compile_region",
 ]
 
@@ -617,6 +618,23 @@ def skip_guard_on_globals_unsafe(guard_entries):
     return [not entry.is_global for entry in guard_entries]
 
 
+def skip_all_guards_unsafe(guard_entries):
+    """
+    A function for skipping all guards on a compiled function.
+
+    WARNING: This function will drop all the safety guarantees from Dynamo
+    compiled function. Use this with caution.
+
+    To use this API, use guard_filter_fn argument while calling torch.compile
+
+    >> opt_mod = torch.compile(
+    >> mod,
+    >> options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
+    >> )
+    """
+    return [False for entry in guard_entries]
+
+
 def nested_compile_region(fn=None):
     """
     Tells **``torch.compile``** that the marked set of operations forms a nested
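Since the global state and torch function stack guards now appear as named guard entries, they can also be dropped selectively rather than wholesale via skip_all_guards_unsafe. A sketch of such a filter, assuming each guard entry carries a guard_type name in the way the other helpers in this module rely on entry fields (e.g. entry.is_global above); the filter function itself is illustrative, not part of the commit:

```python
import torch

def drop_global_state_guards(guard_entries):
    # Keep every guard except the global-state and torch-function-stack
    # guards made droppable by this PR. guard_type is assumed to carry the
    # guard name ("GLOBAL_STATE", "TORCH_FUNCTION_STATE", ...).
    return [
        entry.guard_type not in ("GLOBAL_STATE", "TORCH_FUNCTION_STATE")
        for entry in guard_entries
    ]

def foo(x):
    return x + 1

compiled_foo = torch.compile(
    foo, options={"guard_filter_fn": drop_global_state_guards}
)
compiled_foo(torch.randn(3, 2))
```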