Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dynamo] Ensure global state guard is preserved across serialization. (#157285)
Currently, every time we construct a GLOBAL_STATE guard, we create a fresh guard based on the current global state. For precompile, we want to be able to build a GLOBAL_STATE guard from an external source, e.g. serialized global state. This also covers the normal case, where we simply pass in the global state guard from Python.

Differential Revision: [D77400988](https://our.internmc.facebook.com/intern/diff/D77400988/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157285
Approved by: https://github.com/jansel
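As a rough illustration of the round trip this enables, here is a minimal sketch based on the updated tests below; it assumes a PyTorch build that already includes this change:

```python
import json

import torch

GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard

# Snapshot the current global state (grad mode, autocast state, etc.) as JSON.
guard = GlobalStateGuard()
serialized = guard.__getstate__()  # JSON string, suitable for storing with precompile artifacts

# Later, possibly under different global state, rebuild a guard from the snapshot
# instead of from the process's current state.
restored = GlobalStateGuard()
restored.__setstate__(serialized)
print(restored.check())  # True only if the current global state still matches the snapshot
```

The serialized form is plain JSON, which is what the tests below manipulate to flip individual fields (e.g. the autocast dtype) and verify that the guard check then fails.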
Committed by: PyTorch MergeBot
Parent: b146e1a264
Commit: 0f9c1b374f
@@ -305,6 +305,9 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             nonlocal ref_gm
             nonlocal loaded_gm

+            torch._dynamo.convert_frame.initial_global_state = (
+                torch._C._dynamo.guards.GlobalStateGuard()
+            )
             tracer = InstructionTranslator(
                 instructions,
                 self._frame_state.f_code,
@@ -341,6 +344,8 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             )
             ref_gm = check_fn_manager.guard_manager
             guards_state = check_fn_manager.guards_state
+            self._cached_guards_state = guards_state
+            self._cached_f_code = self._frame_state.f_code
             self.assertIsNotNone(guards_state)
             guards_state = pickle.loads(guards_state)

@@ -355,6 +360,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         try:
             transform_code_object(self._frame_state.f_code, transform)
         finally:
+            torch._dynamo.convert_frame.initial_global_state = None
             self._frame_state = None

         self.assertIsNotNone(ref_gm)
@@ -1138,6 +1144,25 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         with torch.enable_grad():
             self._test_check_fn(ref, loaded, {"x": x}, True)

+    def test_grad_mode_loading(self):
+        def fn(x):
+            return x + 1
+
+        x = torch.randn(3, 2)
+        with torch.enable_grad():
+            ref, _ = self._test_serialization("GRAD_MODE", fn, x)
+        with torch.no_grad():
+            # Ensure guards state loading is not affected by the current global grad mode.
+            guards_state = pickle.loads(self._cached_guards_state)
+            check_fn_manager = CheckFunctionManager(
+                self._cached_f_code,
+                guards_state.output_graph,
+                guards_serialization_mode="load",
+                shape_code_parts=guards_state.shape_code_parts,
+            )
+            loaded = check_fn_manager.guard_manager
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
     def test_deterministic_algorithms(self):
         def fn(x):
             return x + 1
@@ -3275,7 +3275,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
     def test_global_state_guard_serialization(self):
         GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
         guards = GlobalStateGuard()
-        serialized_guards = guards.dump()
+        serialized_guards = guards.__getstate__()
         json_guards = json.loads(serialized_guards)

         samples = []
@@ -3297,17 +3297,17 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
             samples.append(new_dict)

         for sample in samples:
-            guards.load(json.dumps(sample))
+            guards.__setstate__(json.dumps(sample))
             self.assertFalse(guards.check())

-        guards.load(json.dumps(json_guards))
+        guards.__setstate__(json.dumps(json_guards))
         self.assertTrue(guards.check())

         # Test on autocast states.
         def _test_autocast(dtype):
             with torch.autocast("cpu", dtype):
                 guards = GlobalStateGuard()
-                serialized_guards = guards.dump()
+                serialized_guards = guards.__getstate__()
                 json_guards = json.loads(serialized_guards)

                 for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]):
@@ -3316,7 +3316,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
                     type(json_guards["autocast_state"]["dtype"][i]), int
                 )
                 json_guards["autocast_state"]["dtype"][i] += 1
-                guards.load(json.dumps(json_guards))
+                guards.__setstate__(json.dumps(json_guards))
                 self.assertFalse(guards.check())

         _test_autocast(torch.float16)
@@ -100,7 +100,9 @@ class GuardManager:
         equals_val,
         verbose_code_parts: list[str],
     ) -> None: ...
-    def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
+    def add_global_state_guard(
+        self, initial_state, verbose_code_parts: list[str]
+    ) -> None: ...
     def add_torch_function_mode_stack_guard(
         self, initial_stack, verbose_code_parts: list[str]
     ) -> None: ...
@@ -3061,7 +3061,11 @@ class CheckFunctionManager:
         )

         # Insert the global_state guard
-        self.guard_manager.root.add_global_state_guard(["___check_global_state()"])
+        assert self.output_graph is not None
+        global_state = self.output_graph.global_state_guard
+        self.guard_manager.root.add_global_state_guard(
+            global_state, ["___check_global_state()"]
+        )

         self.guard_manager.root.add_torch_function_mode_stack_guard(
             self.torch_function_mode_stack,
@@ -3188,8 +3192,7 @@ class CheckFunctionManager:
             "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
         )

-        global_state = convert_frame.initial_global_state
-        if global_state is None:
+        if convert_frame.initial_global_state is None:
             # we should only hit this case in NopTests()
             global_state = convert_frame.GlobalStateGuard()
         closure_vars = {
@@ -310,6 +310,7 @@ class OutputGraphGuardsState:
     dual_level: int
     functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
     current_device: Optional[torch.device]
+    global_state_guard: torch._C._dynamo.guards.GlobalStateGuard

     export: bool = False
     export_constraints: bool = False
@@ -379,6 +380,9 @@ class OutputGraph(OutputGraphGuardsState):
             dual_level=torch.autograd.forward_ad._current_level,
             functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
             current_device=torch.utils._device.CURRENT_DEVICE,
+            # initial_global_state is only None during NopTest.
+            global_state_guard=torch._dynamo.convert_frame.initial_global_state
+            or torch._C._dynamo.guards.GlobalStateGuard(),
         )
         self.tracers = [SubgraphTracer(self, is_export=export)]
         # Map from graph input's `Source` to its `VariableTracker` to
@@ -675,6 +679,7 @@ class OutputGraph(OutputGraphGuardsState):
             dual_level=self.dual_level,
             functorch_layers=self.functorch_layers,
             current_device=self.current_device,
+            global_state_guard=self.global_state_guard,
             export=self.export,
             export_constraints=self.export_constraints,
             _guards=self.guards,
@@ -744,11 +744,11 @@ static PyMethodDef GlobalStateGuard_methods[] = {
     (PyCFunction)(void*)GlobalStateGuard_reason,
     METH_NOARGS,
     "Return string reason for guard check failing"},
-    {"dump",
+    {"__getstate__",
     (PyCFunction)(void*)GlobalStateGuard_dump,
     METH_NOARGS,
     "Return serialized json format"},
-    {"load",
+    {"__setstate__",
     (PyCFunction)(void*)GlobalStateGuard_load,
     METH_VARARGS,
     "Parse serialized json format"},
@@ -1889,9 +1889,19 @@ class DEFAULT_DEVICE : public LeafGuard {
 class GLOBAL_STATE : public LeafGuard {
  public:
   GLOBAL_STATE(py::object verbose_code_parts)
-      : LeafGuard(std::move(verbose_code_parts)) {
-    _guard = std::make_unique<GlobalStateGuard>();
+      : LeafGuard(std::move(verbose_code_parts)),
+        _guard(PyObject_New(GlobalStateGuard, &GlobalStateGuardType)) {
     _guard->init();
+    owner_ = py::reinterpret_steal<py::object>((PyObject*)_guard);
+  }
+
+  GLOBAL_STATE(py::object initial_state, py::object verbose_code_parts)
+      : LeafGuard(std::move(verbose_code_parts)),
+        owner_(std::move(initial_state)),
+        _guard((GlobalStateGuard*)owner_.ptr()) {
+    if (!PyObject_TypeCheck(owner_.ptr(), &GlobalStateGuardType)) {
+      throw py::type_error("GLOBAL_STATE expects a GlobalStateGuard");
+    }
   }

   bool check_nopybind(PyObject* value) override { // borrowed ref
@@ -1913,7 +1923,8 @@ class GLOBAL_STATE : public LeafGuard {
   }

  private:
-  std::unique_ptr<GlobalStateGuard> _guard;
+  py::object owner_;
+  GlobalStateGuard* _guard;
 };

 // Checks that an attr is absent in the object. We don't need the opposite
@@ -5842,9 +5853,11 @@ PyObject* torch_c_dynamo_guards_init() {
       })
      .def(
          "add_global_state_guard",
-          [](GuardManager& self, py::object verbose_code_parts) -> void {
-            self.add_leaf_guard(
-                std::make_shared<GLOBAL_STATE>(std::move(verbose_code_parts)));
+          [](GuardManager& self,
+             py::object initial_state,
+             py::object verbose_code_parts) -> void {
+            self.add_leaf_guard(std::make_shared<GLOBAL_STATE>(
+                std::move(initial_state), std::move(verbose_code_parts)));
          })
      .def(
          "add_torch_function_mode_stack_guard",