[DCP] Use multiprocess Pipes instead of Queues to improve communication contract with checkpointer process (#153488)

Summary:
### Diff Context
- PR introduces Pipes for multiprocess comms with checkpointer process.
- Pipes allow easier comms contract management due to close() API and catch-all feature when background process is dead (e.g. seg faults).

Test Plan: CI

Differential Revision: D74668559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153488
Approved by: https://github.com/saumishr
This commit is contained in:
Meet Vadakkanchery
2025-05-14 16:47:40 +00:00
committed by PyTorch MergeBot
parent 8799bffc34
commit b6b0080419

View File

@ -87,20 +87,23 @@ class _AsyncCheckpointProcess:
pg_init_info: _ProcessGroupInitInfo,
):
self.ctx = mp.get_context("spawn")
self._mp_queue_send: mp.Queue = self.ctx.Queue()
self._mp_queue_recv: mp.Queue = self.ctx.Queue()
self._process_pipe, child_end = self.ctx.Pipe()
self._save_process = self.ctx.Process(
target=self._checkpointing_subprocess,
args=(
pg_init_info,
self._mp_queue_send,
self._mp_queue_recv,
child_end,
),
daemon=True,
)
self._save_process.start()
# Close the parent's copy of child end after we pass it into the child,
# so the recv()s on it will fail-fast if the child process dies.
child_end.close()
# Wait for the checkpoint background process to initialize.
# Using default GLOO init timeout.
response = self._wait_for_response(timeout=1800)
@ -109,9 +112,37 @@ class _AsyncCheckpointProcess:
def __del__(self) -> None:
if self._save_process.is_alive():
logger.info("Terminating the checkpoint background process...")
self._mp_queue_send.put(_CheckpointSaveProcessControlOpts.TERMINATE)
self._send(_CheckpointSaveProcessControlOpts.TERMINATE)
self._save_process.join()
def _send(self, data: Any) -> None:
self._process_pipe.send(data)
def _wait_for_response(self, timeout: Optional[float] = None) -> Any:
if not self._save_process.is_alive():
logger.info("Checkpoint background process is dead calling join()...")
self._save_process.join()
raise RuntimeError(
f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}"
)
if timeout is not None and not self._process_pipe.poll(timeout=timeout):
raise RuntimeError(
f"Timed out after {timeout}s while waiting for response from checkpointer process pid: {self._save_process.pid}"
)
try:
response = self._process_pipe.recv()
except EOFError:
raise RuntimeError( # noqa: B904
f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}"
)
if isinstance(response, BaseException):
raise response
return response
def save(
self,
staged_state_dict: STATE_DICT_TYPE,
@ -129,21 +160,11 @@ class _AsyncCheckpointProcess:
storage_writer=storage_writer,
planner=planner,
)
self._mp_queue_send.put(async_cp_request)
self._send(async_cp_request)
result = self._wait_for_response()
assert isinstance(result, Metadata)
return result
def _wait_for_response(self, timeout: Optional[float] = None) -> Any:
if not self._save_process.is_alive():
logger.info("Checkpoint background process is dead calling join()...")
self._save_process.join()
raise RuntimeError("Checkpoint background process is dead.")
response = self._mp_queue_recv.get(timeout=timeout)
if isinstance(response, BaseException):
raise response
return response
@staticmethod
def _execute_save(
state_dict: STATE_DICT_TYPE,
@ -165,8 +186,7 @@ class _AsyncCheckpointProcess:
@staticmethod
def _checkpointing_subprocess(
pg_init_info: _ProcessGroupInitInfo,
recv: mp.Queue,
send: mp.Queue,
parent_conn,
) -> None:
try:
_init_logger(pg_init_info.global_rank)
@ -187,12 +207,12 @@ class _AsyncCheckpointProcess:
dist.barrier()
logger.info("Checkpoint background process is running...")
send.put(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
parent_conn.send(_CheckpointSaveProcessControlOpts.INIT_COMPLETE)
# Serving loop.
while True:
logger.info("Waiting for checkpoint save request...")
obj = recv.get()
obj = parent_conn.recv()
if (
isinstance(obj, _CheckpointSaveProcessControlOpts)
and obj == _CheckpointSaveProcessControlOpts.TERMINATE
@ -210,7 +230,7 @@ class _AsyncCheckpointProcess:
storage_writer=obj.storage_writer,
planner=obj.planner,
)
send.put(response)
parent_conn.send(response)
logger.info(
f"Submitted checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004
)
@ -218,11 +238,12 @@ class _AsyncCheckpointProcess:
logger.error(
f"Checkpoint background process encountered an exception: {e}" # noqa: G004
)
send.put(e)
parent_conn.send(e)
raise
finally:
logger.info("Checkpoint background process is shutting down...")
dist.destroy_process_group()
parent_conn.close()
_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None