Files
pytorch/torch/distributed/checkpoint/_experimental/checkpoint_process.py
Yuanyuan Chen 3255e7872b Enable all flake8-logging-format rules (#164655)
These rules are enabled by removing existing suppressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164655
Approved by: https://github.com/janeyx99, https://github.com/mlazos
2025-10-19 00:59:28 +00:00

362 lines
13 KiB
Python

import logging
import os
import traceback
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from multiprocessing.connection import Connection
from typing import Any, Optional, Union
import torch.multiprocessing as mp
from torch.multiprocessing.spawn import ProcessExitedException
from .checkpoint_writer import CheckpointWriter
from .types import RankInfo, STATE_DICT
logger = logging.getLogger(__name__)
@dataclass
class CheckpointProcessConfig:
    """
    Timeout settings for a CheckpointProcess.

    Both values are expressed in seconds.

    Attributes:
        subprocess_init_timeout_secs: Upper bound on how long to wait for the
            subprocess to finish initializing.
        subprocess_shutdown_timeout_secs: Upper bound on how long to wait for
            the subprocess to shut down.
    """

    # Seconds to wait for subprocess initialization.
    subprocess_init_timeout_secs: int = 30

    # Seconds to wait for subprocess shutdown.
    subprocess_shutdown_timeout_secs: int = 60
class RequestType(Enum):
    """Kinds of requests exchanged between the parent and the worker subprocess."""

    # Round-trip check; the worker answers with a successful response.
    PING = "ping"

    # Ask the worker to write a checkpoint described by the request payload.
    WRITE_CHECKPOINT = "write_checkpoint"

    # Ask the worker to leave its request loop.
    TERMINATE_PROCESS = "exit"
@dataclass
class WorkerRequest:
    """
    Command message sent from the parent to the worker process.

    Note: instances are pickled to cross the process boundary, so changes to
    this class must be made with backward compatibility in mind.
    """

    # Which operation the worker should perform.
    request_type: RequestType

    # Operation-specific arguments (an empty dict when there are none).
    payload: dict[str, Any]
@dataclass
class WorkerResponse:
    """
    Reply message sent from the worker process back to the parent.

    Field order is part of the contract: the worker constructs responses with
    ``request_type`` passed positionally.
    """

    # The request type this response answers.
    request_type: RequestType

    # True when the request was handled without error.
    success: bool

    # Error description (the worker sends a formatted traceback) on failure.
    error_msg: Optional[str] = None

    # Optional result data returned to the caller of the request.
    payload: Optional[dict[str, Any]] = None
class CheckpointProcess:
    """
    A checkpoint writer that writes checkpoints in a separate worker process.

    The parent talks to the worker over a pipe using pickled ``WorkerRequest`` /
    ``WorkerResponse`` messages. All parent-side pipe traffic (subprocess
    creation and every write) is funnelled through a single-thread executor,
    so requests and their responses are strictly serialized.

    Call ``close()`` when done: it stops the executor, asks the worker to
    terminate, and releases the parent's end of the pipe.
    """

    def __init__(
        self,
        rank_info: RankInfo,
        config: CheckpointProcessConfig,
        subprocess_init_fn: Callable[[Any], None],
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
    ):
        """
        Store the configuration and start creating the subprocess in the background.

        Args:
            rank_info: Rank information forwarded to the subprocess.
            config: Timeouts for subprocess initialization and shutdown.
            subprocess_init_fn: Callback invoked inside the subprocess as
                ``subprocess_init_fn(*subprocess_init_args)`` before the
                writer is built.
            subprocess_init_args: Positional arguments for ``subprocess_init_fn``.
            checkpoint_writer_init_fn: Factory invoked inside the subprocess
                with ``checkpoint_writer_init_args`` as keyword arguments.
            checkpoint_writer_init_args: Keyword arguments for the writer
                factory; ``rank_info`` is added in the subprocess if absent.
        """
        # One worker thread: subprocess creation and writes run in submission order.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._rank_info = rank_info
        self._config = config
        self._subprocess_init_fn = subprocess_init_fn
        self._subprocess_init_args = subprocess_init_args
        self._checkpoint_writer_init_fn = checkpoint_writer_init_fn
        self._checkpoint_writer_init_args = checkpoint_writer_init_args
        self.process = None
        self._parent_end: Optional[Connection] = None
        self._child_end: Optional[Connection] = None

        # Creation is asynchronous; write() waits on this future before submitting.
        self.process_creation_future = self._executor.submit(
            self._create_subprocess,
            config,
        )

    def _create_subprocess(
        self,
        config: CheckpointProcessConfig,
    ) -> None:
        """
        Spawn the worker subprocess and wait until it answers a PING.

        Args:
            config: Supplies ``subprocess_init_timeout_secs`` for the wait.

        Raises:
            TimeoutError: No response arrived within the init timeout.
            RuntimeError: The pipe failed or the worker reported a failure
                (raised by ``_send`` / ``_recv``).
        """
        logger.info(
            "Creating checkpoint subprocess for rank %d", self._rank_info.global_rank
        )

        spawn_context = mp.get_context("spawn")
        self._parent_end, child_end = spawn_context.Pipe()

        # Known workaround for https://github.com/pytorch/pytorch/issues/37377
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"

        logger.debug("Spawning subprocess for rank_info=%s", self._rank_info)
        try:
            self.process = mp.spawn(
                fn=CheckpointProcess._subprocess,
                args=(
                    self._rank_info,
                    child_end,
                    self._subprocess_init_fn,
                    self._subprocess_init_args,
                    self._checkpoint_writer_init_fn,
                    self._checkpoint_writer_init_args,
                ),
                nprocs=1,
                join=False,
                daemon=True,
            )
        finally:
            # close the child end of the pipe so recv on it will fail
            # fast when the child process is terminated unexpectedly.
            # Doing this in ``finally`` also releases the handle when the
            # spawn itself raises.
            child_end.close()

        self._send(
            request_type=RequestType.PING,
            payload={},
        )

        logger.debug(
            "Waiting for checkpoint subprocess to initialize (timeout: %ds)",
            config.subprocess_init_timeout_secs,
        )

        # wait for the timeout or a response from subprocess
        if self._parent_end is None:
            raise AssertionError("Parent end of pipe should be initialized")
        if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs):
            msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize"
            logger.error(msg)
            raise TimeoutError(msg)

        self._recv()
        logger.info("Checkpoint subprocess initialized successfully")

    @staticmethod
    def _subprocess(
        sub_rank: int,
        rank_info: RankInfo,
        parent_pipe: Connection,
        subprocess_init_fn: Callable[[Any], None],
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
    ) -> None:
        """
        Entry point of the worker subprocess: initialize, then serve requests.

        Runs the init callback, builds the checkpoint writer, and loops over
        requests read from ``parent_pipe`` until a TERMINATE_PROCESS request
        arrives. Any exception (during initialization or while handling a
        request) is reported to the parent as a failed ``WorkerResponse``
        carrying the formatted traceback, after which the function returns.

        Args:
            sub_rank: Index passed as the first argument by ``mp.spawn``;
                must be 0 because only one worker is spawned.
            rank_info: Rank information; injected into the writer's init
                arguments when the caller did not supply ``rank_info``.
            parent_pipe: This process's end of the pipe to the parent.
            subprocess_init_fn: Application-specific initialization callback.
            subprocess_init_args: Positional arguments for the callback.
            checkpoint_writer_init_fn: Factory that builds the writer.
            checkpoint_writer_init_args: Keyword arguments for the factory.

        Raises:
            AssertionError: ``sub_rank`` is not 0.
        """
        logger.debug(
            "Checkpoint subprocess started for rank %d/%d (PID: %d)",
            rank_info.global_rank,
            rank_info.global_world_size,
            os.getpid(),
        )

        if sub_rank != 0:
            raise AssertionError("We need only one checkpointer per parent training")

        # Placeholder so the except block can report a request type even when
        # initialization fails before any request was received; the parent's
        # first message is a PING.
        request = WorkerRequest(request_type=RequestType.PING, payload={})

        try:
            # Calling initialize callback, so we can perform app-specific initialization of the subprocess.
            subprocess_init_fn(*subprocess_init_args)

            # Initialize checkpoint writer - automatically include rank_info in init_args
            writer_init_args = dict(checkpoint_writer_init_args)
            if "rank_info" not in writer_init_args:
                writer_init_args["rank_info"] = rank_info
            checkpoint_writer = checkpoint_writer_init_fn(**writer_init_args)

            while True:
                request = parent_pipe.recv()

                if request.request_type == RequestType.PING:
                    parent_pipe.send(
                        WorkerResponse(request_type=RequestType.PING, success=True)
                    )
                elif request.request_type == RequestType.WRITE_CHECKPOINT:
                    path = request.payload["path"]
                    logger.info("Writing checkpoint to %s", path)

                    checkpoint_writer.write(
                        path=path,
                        state_dict=request.payload["state_dict"],
                        **request.payload["kwargs"],
                    )

                    logger.info("Checkpoint written successfully to %s", path)
                    parent_pipe.send(
                        WorkerResponse(RequestType.WRITE_CHECKPOINT, success=True)
                    )
                elif request.request_type == RequestType.TERMINATE_PROCESS:
                    logger.debug("Received termination request.")
                    parent_pipe.send(
                        WorkerResponse(RequestType.TERMINATE_PROCESS, success=True)
                    )
                    logger.info("Subprocess terminated gracefully")
                    break
                else:
                    error_msg = f"Unknown request type: {request.request_type}"
                    logger.error(error_msg)
                    raise ValueError(error_msg)

        except Exception as e:
            error_text = traceback.format_exc()
            logger.error(
                "Exception in subprocess (%s): %s", type(e).__name__, error_text
            )

            # Communicating exception via the pipe to the main process
            parent_pipe.send(
                WorkerResponse(
                    request_type=request.request_type,
                    success=False,
                    error_msg=error_text,
                )
            )
            parent_pipe.close()
            logger.exception("Subprocess terminated due to exception")

    def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None:
        """
        Send one request to the worker over the parent end of the pipe.

        Args:
            request_type: The operation to request.
            payload: Operation-specific arguments.

        Raises:
            AssertionError: The pipe has not been created yet.
            RuntimeError: Sending failed with an ``OSError``.
        """
        try:
            if self._parent_end is None:
                raise AssertionError("Parent end of pipe should be initialized")
            self._parent_end.send(
                WorkerRequest(
                    request_type=request_type,
                    payload=payload,
                )
            )
        except OSError as e:
            error_msg = "Child process terminated unexpectedly"
            logger.exception(
                "Communication failed during %s request", request_type.value
            )
            raise RuntimeError(error_msg) from e

    def _recv(self) -> Optional[dict[str, Any]]:
        """
        Block until the worker replies and return the response payload.

        Returns:
            The ``payload`` of the response (may be ``None``).

        Raises:
            AssertionError: The pipe has not been created yet.
            RuntimeError: The worker reported ``success=False``, or the pipe
                was closed/reset while waiting.
        """
        try:
            if self._parent_end is None:
                raise AssertionError("Parent end of pipe should be initialized")
            response = self._parent_end.recv()
            if response.success is False:
                error_msg = (
                    f"Unexpected response from worker process: {response.error_msg}"
                )
                logger.error(error_msg)
                raise RuntimeError(error_msg)
            return response.payload
        except (EOFError, BrokenPipeError, ConnectionResetError) as e:
            error_msg = f"Child process terminated unexpectedly: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def write(
        self,
        state_dict: Union[STATE_DICT, Future[STATE_DICT]],
        path: str,
        **kwargs: Any,
    ) -> Optional[Future[None]]:
        """
        Schedule a checkpoint write in the worker subprocess.

        Blocks the caller until subprocess creation has finished (re-raising
        any error from it), then queues the write on the executor.

        Args:
            state_dict: The state dict to write, or a Future resolving to it.
            path: Destination passed through to the checkpoint writer.
            **kwargs: Extra keyword arguments forwarded to the writer's ``write``.

        Returns:
            A Future that completes when the worker acknowledges the write,
            or carries the exception raised while performing it.
        """
        logger.debug("Waiting for subprocess initialization to complete")

        # wait until the process is started
        self.process_creation_future.result()

        return self._executor.submit(
            self._write,
            state_dict,
            path,
            **kwargs,
        )

    def _write(
        self,
        state_dict: Union[STATE_DICT, Future[STATE_DICT]],
        path: str,
        **kwargs: Any,
    ) -> None:
        """
        Resolve the state dict, send a WRITE_CHECKPOINT request, and await the reply.

        Runs on the executor thread.

        Args:
            state_dict: The state dict, or a Future that resolves to it.
            path: Destination passed through to the checkpoint writer.
            **kwargs: Extra keyword arguments forwarded to the writer.

        Raises:
            RuntimeError: The pipe failed or the worker reported a failure.
        """
        logger.debug("Starting checkpoint write to %s", path)

        # wait for staging state_dict to be available
        if isinstance(state_dict, Future):
            logger.debug("Waiting for state_dict Future to resolve")
            sd = state_dict.result()
        else:
            sd = state_dict

        # Log state_dict info only if debug logging is enabled (performance-conscious)
        if logger.isEnabledFor(logging.DEBUG):
            if hasattr(sd, "keys"):
                logger.debug("State_dict contains %d keys", len(sd.keys()))

        self._send(
            request_type=RequestType.WRITE_CHECKPOINT,
            payload={
                "state_dict": sd,
                "path": path,
                "kwargs": kwargs,
            },
        )

        logger.debug("Waiting for write completion response")
        # wait for response
        self._recv()
        logger.debug("Checkpoint write to %s completed successfully", path)

    def close(self) -> None:
        """
        Shut down the executor and the worker subprocess, then release the pipe.

        Pending executor work that has not started is cancelled. If the worker
        is still alive it is asked to terminate and joined for up to
        ``subprocess_shutdown_timeout_secs``; if the join does not complete in
        time the worker is killed. The parent end of the pipe is always closed
        on the way out, including when an exception propagates.

        Raises:
            ProcessExitedException: Re-raised from joining the subprocess.
        """
        logger.debug(
            "Closing CheckpointProcess for rank %d", self._rank_info.global_rank
        )

        try:
            self._executor.shutdown(wait=True, cancel_futures=True)

            if self.process and self.process.processes[0].is_alive():
                subprocess_pid = self.process.processes[0].pid
                # send graceful termination to sub process
                try:
                    # pyrefly: ignore # missing-attribute
                    self._parent_end.send(
                        WorkerRequest(
                            request_type=RequestType.TERMINATE_PROCESS,
                            payload={},
                        )
                    )
                except BrokenPipeError:
                    logger.warning(
                        "BrokenPipeError when sending termination request - subprocess (PID: %d) may have already terminated",
                        subprocess_pid,
                    )
                    # subprocess terminated unexpectedly and below code will raise a
                    # ProcessExitedException.

                logger.debug(
                    "Waiting for subprocess to terminate gracefully (timeout: %ds)",
                    self._config.subprocess_shutdown_timeout_secs,
                )

                try:
                    if not self.process.join(
                        timeout=self._config.subprocess_shutdown_timeout_secs
                    ):
                        # graceful shutdown failed, kill the process.
                        logger.warning(
                            "Subprocess (PID: %d) did not terminate gracefully within %ds, killing it",
                            subprocess_pid,
                            self._config.subprocess_shutdown_timeout_secs,
                        )
                        self.process.processes[0].kill()
                        logger.info("Subprocess killed forcefully")
                except ProcessExitedException:
                    logger.exception(
                        "ProcessExitedException during subprocess termination"
                    )
                    raise
        finally:
            # Release the parent's pipe handle on every exit path; without
            # this the connection stays open after the worker is gone.
            if self._parent_end is not None:
                self._parent_end.close()

        logger.debug("CheckpointProcess closed successfully")