mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[DCP] Use multiprocess Pipes instead of Queues to improve communication contract with checkpointer process (#153488)
Summary: ### Diff Context - PR introduces Pipes for multiprocess comms with checkpointer process. - Pipes allow easier comms contract management due to close() API and catch-all feature when background process is dead (e.g. seg faults). Test Plan: CI Differential Revision: D74668559 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153488 Approved by: https://github.com/saumishr
This commit is contained in:
committed by
PyTorch MergeBot
parent
8799bffc34
commit
b6b0080419
@ -87,20 +87,23 @@ class _AsyncCheckpointProcess:
|
||||
pg_init_info: _ProcessGroupInitInfo,
|
||||
):
|
||||
self.ctx = mp.get_context("spawn")
|
||||
self._mp_queue_send: mp.Queue = self.ctx.Queue()
|
||||
self._mp_queue_recv: mp.Queue = self.ctx.Queue()
|
||||
self._process_pipe, child_end = self.ctx.Pipe()
|
||||
|
||||
self._save_process = self.ctx.Process(
|
||||
target=self._checkpointing_subprocess,
|
||||
args=(
|
||||
pg_init_info,
|
||||
self._mp_queue_send,
|
||||
self._mp_queue_recv,
|
||||
child_end,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
self._save_process.start()
|
||||
|
||||
# Close the parent's copy of child end after we pass it into the child,
|
||||
# so the recv()s on it will fail-fast if the child process dies.
|
||||
child_end.close()
|
||||
|
||||
# Wait for the checkpoint background process to initialize.
|
||||
# Using default GLOO init timeout.
|
||||
response = self._wait_for_response(timeout=1800)
|
||||
@ -109,9 +112,37 @@ class _AsyncCheckpointProcess:
|
||||
def __del__(self) -> None:
|
||||
if self._save_process.is_alive():
|
||||
logger.info("Terminating the checkpoint background process...")
|
||||
self._mp_queue_send.put(_CheckpointSaveProcessControlOpts.TERMINATE)
|
||||
self._send(_CheckpointSaveProcessControlOpts.TERMINATE)
|
||||
self._save_process.join()
|
||||
|
||||
def _send(self, data: Any) -> None:
|
||||
self._process_pipe.send(data)
|
||||
|
||||
def _wait_for_response(self, timeout: Optional[float] = None) -> Any:
|
||||
if not self._save_process.is_alive():
|
||||
logger.info("Checkpoint background process is dead calling join()...")
|
||||
self._save_process.join()
|
||||
raise RuntimeError(
|
||||
f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}"
|
||||
)
|
||||
|
||||
if timeout is not None and not self._process_pipe.poll(timeout=timeout):
|
||||
raise RuntimeError(
|
||||
f"Timed out after {timeout}s while waiting for response from checkpointer process pid: {self._save_process.pid}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._process_pipe.recv()
|
||||
except EOFError:
|
||||
raise RuntimeError( # noqa: B904
|
||||
f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}"
|
||||
)
|
||||
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
|
||||
return response
|
||||
|
||||
def save(
|
||||
self,
|
||||
staged_state_dict: STATE_DICT_TYPE,
|
||||
@ -129,21 +160,11 @@ class _AsyncCheckpointProcess:
|
||||
storage_writer=storage_writer,
|
||||
planner=planner,
|
||||
)
|
||||
self._mp_queue_send.put(async_cp_request)
|
||||
self._send(async_cp_request)
|
||||
result = self._wait_for_response()
|
||||
assert isinstance(result, Metadata)
|
||||
return result
|
||||
|
||||
def _wait_for_response(self, timeout: Optional[float] = None) -> Any:
|
||||
if not self._save_process.is_alive():
|
||||
logger.info("Checkpoint background process is dead calling join()...")
|
||||
self._save_process.join()
|
||||
raise RuntimeError("Checkpoint background process is dead.")
|
||||
response = self._mp_queue_recv.get(timeout=timeout)
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _execute_save(
|
||||
state_dict: STATE_DICT_TYPE,
|
||||
@ -165,8 +186,7 @@ class _AsyncCheckpointProcess:
|
||||
@staticmethod
|
||||
def _checkpointing_subprocess(
|
||||
pg_init_info: _ProcessGroupInitInfo,
|
||||
recv: mp.Queue,
|
||||
send: mp.Queue,
|
||||
parent_conn,
|
||||
) -> None:
|
||||
try:
|
||||
_init_logger(pg_init_info.global_rank)
|
||||
@ -187,12 +207,12 @@ class _AsyncCheckpointProcess:
|
||||
dist.barrier()
|
||||
|
||||
logger.info("Checkpoint background process is running...")
|
||||
send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
|
||||
parent_conn.send(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
|
||||
|
||||
# Serving loop.
|
||||
while True:
|
||||
logger.info("Waiting for checkpoint save request...")
|
||||
obj = recv.get()
|
||||
obj = parent_conn.recv()
|
||||
if (
|
||||
isinstance(obj, _CheckpointSaveProcessControlOpts)
|
||||
and obj == _CheckpointSaveProcessControlOpts.TERMINATE
|
||||
@ -210,7 +230,7 @@ class _AsyncCheckpointProcess:
|
||||
storage_writer=obj.storage_writer,
|
||||
planner=obj.planner,
|
||||
)
|
||||
send.put(response)
|
||||
parent_conn.send(response)
|
||||
logger.info(
|
||||
f"Submitted checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004
|
||||
)
|
||||
@ -218,11 +238,12 @@ class _AsyncCheckpointProcess:
|
||||
logger.error(
|
||||
f"Checkpoint background process encountered an exception: {e}" # noqa: G004
|
||||
)
|
||||
send.put(e)
|
||||
parent_conn.send(e)
|
||||
raise
|
||||
finally:
|
||||
logger.info("Checkpoint background process is shutting down...")
|
||||
dist.destroy_process_group()
|
||||
parent_conn.close()
|
||||
|
||||
|
||||
_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None
|
||||
|
||||
Reference in New Issue
Block a user