[dynamo] Ensure global state guard is preserved across serialization. (#157285)

Currently, every time we construct a GLOBAL_STATE guard, we create a fresh guard from the current global state. For precompile, we instead want to build the GLOBAL_STATE guard from an external source, e.g. a serialized global state. The same mechanism also covers the normal case, where the global state guard is simply passed in from Python.
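A minimal sketch of the round trip this enables, using the __getstate__/__setstate__ bindings introduced below (the JSON layout of the serialized state is an internal detail):

    import torch

    GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard

    # Snapshot the current global state (grad mode, autocast, determinism, ...).
    guard = GlobalStateGuard()
    serialized = guard.__getstate__()  # JSON string

    # Later, e.g. on the precompile load path, rebuild the guard from the
    # serialized state instead of snapshotting the live global state.
    restored = GlobalStateGuard()
    restored.__setstate__(serialized)
    assert restored.check()  # passes while the global state still matches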

Differential Revision: [D77400988](https://our.internmc.facebook.com/intern/diff/D77400988/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157285
Approved by: https://github.com/jansel
Author: zhxchen17
Date: 2025-06-30 12:13:16 -07:00
Committed by: PyTorch MergeBot
Parent: b146e1a264
Commit: 0f9c1b374f
6 changed files with 65 additions and 17 deletions

View File

@@ -305,6 +305,9 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             nonlocal ref_gm
             nonlocal loaded_gm
+            torch._dynamo.convert_frame.initial_global_state = (
+                torch._C._dynamo.guards.GlobalStateGuard()
+            )
             tracer = InstructionTranslator(
                 instructions,
                 self._frame_state.f_code,
@@ -341,6 +344,8 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             )
             ref_gm = check_fn_manager.guard_manager
             guards_state = check_fn_manager.guards_state
+            self._cached_guards_state = guards_state
+            self._cached_f_code = self._frame_state.f_code
             self.assertIsNotNone(guards_state)
             guards_state = pickle.loads(guards_state)
@@ -355,6 +360,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         try:
             transform_code_object(self._frame_state.f_code, transform)
         finally:
+            torch._dynamo.convert_frame.initial_global_state = None
             self._frame_state = None

         self.assertIsNotNone(ref_gm)
@@ -1138,6 +1144,25 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         with torch.enable_grad():
             self._test_check_fn(ref, loaded, {"x": x}, True)

+    def test_grad_mode_loading(self):
+        def fn(x):
+            return x + 1
+
+        x = torch.randn(3, 2)
+        with torch.enable_grad():
+            ref, _ = self._test_serialization("GRAD_MODE", fn, x)
+        with torch.no_grad():
+            # Ensure guards state loading is not affected by the current global grad mode.
+            guards_state = pickle.loads(self._cached_guards_state)
+            check_fn_manager = CheckFunctionManager(
+                self._cached_f_code,
+                guards_state.output_graph,
+                guards_serialization_mode="load",
+                shape_code_parts=guards_state.shape_code_parts,
+            )
+            loaded = check_fn_manager.guard_manager
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
     def test_deterministic_algorithms(self):
         def fn(x):
             return x + 1

View File

@@ -3275,7 +3275,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
     def test_global_state_guard_serialization(self):
         GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
         guards = GlobalStateGuard()
-        serialized_guards = guards.dump()
+        serialized_guards = guards.__getstate__()
        json_guards = json.loads(serialized_guards)
        samples = []
@@ -3297,17 +3297,17 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
            samples.append(new_dict)

        for sample in samples:
-            guards.load(json.dumps(sample))
+            guards.__setstate__(json.dumps(sample))
            self.assertFalse(guards.check())

-        guards.load(json.dumps(json_guards))
+        guards.__setstate__(json.dumps(json_guards))
        self.assertTrue(guards.check())

        # Test on autocast states.
        def _test_autocast(dtype):
            with torch.autocast("cpu", dtype):
                guards = GlobalStateGuard()
-                serialized_guards = guards.dump()
+                serialized_guards = guards.__getstate__()
                json_guards = json.loads(serialized_guards)
                for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]):
@@ -3316,7 +3316,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
                        type(json_guards["autocast_state"]["dtype"][i]), int
                    )
                    json_guards["autocast_state"]["dtype"][i] += 1
-                guards.load(json.dumps(json_guards))
+                guards.__setstate__(json.dumps(json_guards))
                self.assertFalse(guards.check())

        _test_autocast(torch.float16)

View File

@@ -100,7 +100,9 @@ class GuardManager:
        equals_val,
        verbose_code_parts: list[str],
    ) -> None: ...
-    def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
+    def add_global_state_guard(
+        self, initial_state, verbose_code_parts: list[str]
+    ) -> None: ...
    def add_torch_function_mode_stack_guard(
        self, initial_stack, verbose_code_parts: list[str]
    ) -> None: ...
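For illustration, the rebound method can also be driven directly from Python; a hedged sketch (CheckFunctionManager is the normal caller, and the standalone RootGuardManager here is purely for demonstration):

    from torch._C._dynamo.guards import GlobalStateGuard, RootGuardManager

    root = RootGuardManager()
    # The guard is now built from an explicitly supplied state object instead
    # of snapshotting the global state at guard-construction time.
    root.add_global_state_guard(GlobalStateGuard(), ["___check_global_state()"])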

View File

@@ -3061,7 +3061,11 @@ class CheckFunctionManager:
        )

        # Insert the global_state guard
-        self.guard_manager.root.add_global_state_guard(["___check_global_state()"])
+        assert self.output_graph is not None
+        global_state = self.output_graph.global_state_guard
+        self.guard_manager.root.add_global_state_guard(
+            global_state, ["___check_global_state()"]
+        )

        self.guard_manager.root.add_torch_function_mode_stack_guard(
            self.torch_function_mode_stack,
@@ -3188,8 +3192,7 @@ class CheckFunctionManager:
            "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
        )

-        global_state = convert_frame.initial_global_state
-        if global_state is None:
+        if convert_frame.initial_global_state is None:
            # we should only hit this case in NopTests()
            global_state = convert_frame.GlobalStateGuard()
        closure_vars = {

View File

@@ -310,6 +310,7 @@ class OutputGraphGuardsState:
    dual_level: int
    functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
    current_device: Optional[torch.device]
+    global_state_guard: torch._C._dynamo.guards.GlobalStateGuard

    export: bool = False
    export_constraints: bool = False
@@ -379,6 +380,9 @@ class OutputGraph(OutputGraphGuardsState):
            dual_level=torch.autograd.forward_ad._current_level,
            functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
            current_device=torch.utils._device.CURRENT_DEVICE,
+            # initial_global_state is only None during NopTest.
+            global_state_guard=torch._dynamo.convert_frame.initial_global_state
+            or torch._C._dynamo.guards.GlobalStateGuard(),
        )
        self.tracers = [SubgraphTracer(self, is_export=export)]
        # Map from graph input's `Source` to its `VariableTracker` to
@@ -675,6 +679,7 @@ class OutputGraph(OutputGraphGuardsState):
            dual_level=self.dual_level,
            functorch_layers=self.functorch_layers,
            current_device=self.current_device,
+            global_state_guard=self.global_state_guard,
            export=self.export,
            export_constraints=self.export_constraints,
            _guards=self.guards,

View File

@@ -744,11 +744,11 @@ static PyMethodDef GlobalStateGuard_methods[] = {
     (PyCFunction)(void*)GlobalStateGuard_reason,
     METH_NOARGS,
     "Return string reason for guard check failing"},
-    {"dump",
+    {"__getstate__",
     (PyCFunction)(void*)GlobalStateGuard_dump,
     METH_NOARGS,
     "Return serialized json format"},
-    {"load",
+    {"__setstate__",
     (PyCFunction)(void*)GlobalStateGuard_load,
     METH_VARARGS,
     "Parse serialized json format"},
@@ -1889,9 +1889,19 @@ class DEFAULT_DEVICE : public LeafGuard {
 class GLOBAL_STATE : public LeafGuard {
  public:
   GLOBAL_STATE(py::object verbose_code_parts)
-      : LeafGuard(std::move(verbose_code_parts)) {
-    _guard = std::make_unique<GlobalStateGuard>();
+      : LeafGuard(std::move(verbose_code_parts)),
+        _guard(PyObject_New(GlobalStateGuard, &GlobalStateGuardType)) {
     _guard->init();
+    owner_ = py::reinterpret_steal<py::object>((PyObject*)_guard);
+  }
+
+  GLOBAL_STATE(py::object initial_state, py::object verbose_code_parts)
+      : LeafGuard(std::move(verbose_code_parts)),
+        owner_(std::move(initial_state)),
+        _guard((GlobalStateGuard*)owner_.ptr()) {
+    if (!PyObject_TypeCheck(owner_.ptr(), &GlobalStateGuardType)) {
+      throw py::type_error("GLOBAL_STATE expects a GlobalStateGuard");
+    }
   }

   bool check_nopybind(PyObject* value) override { // borrowed ref
@@ -1913,7 +1923,8 @@ class GLOBAL_STATE : public LeafGuard {
   }

  private:
-  std::unique_ptr<GlobalStateGuard> _guard;
+  py::object owner_;
+  GlobalStateGuard* _guard;
 };

 // Checks that an attr is absent in the object. We don't need the opposite
@@ -5842,9 +5853,11 @@ PyObject* torch_c_dynamo_guards_init() {
       })
       .def(
           "add_global_state_guard",
-          [](GuardManager& self, py::object verbose_code_parts) -> void {
-            self.add_leaf_guard(
-                std::make_shared<GLOBAL_STATE>(std::move(verbose_code_parts)));
+          [](GuardManager& self,
+             py::object initial_state,
+             py::object verbose_code_parts) -> void {
+            self.add_leaf_guard(std::make_shared<GLOBAL_STATE>(
+                std::move(initial_state), std::move(verbose_code_parts)));
           })
       .def(
           "add_torch_function_mode_stack_guard",