import logging
import os
import traceback
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from multiprocessing.connection import Connection
from typing import Any, Optional, Union

import torch.multiprocessing as mp
from torch.multiprocessing.spawn import ProcessExitedException

from .checkpoint_writer import CheckpointWriter
from .types import RankInfo, STATE_DICT


logger = logging.getLogger(__name__)


@dataclass
class CheckpointProcessConfig:
    """
    Configuration options for the CheckpointProcess.

    This class configures the timeouts used when initializing and shutting down
    the checkpoint subprocess.

    Attributes:
        subprocess_init_timeout_secs: Maximum time in seconds to wait for subprocess initialization.
        subprocess_shutdown_timeout_secs: Maximum time in seconds to wait for subprocess shutdown.
    """

    subprocess_init_timeout_secs: int = 30
    subprocess_shutdown_timeout_secs: int = 60
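

# Illustrative sketch (not part of this module's API surface): constructing a config
# with a longer init timeout, e.g. when the subprocess initialization callback does
# slow one-time setup such as creating a storage client. The values are arbitrary.
#
#     config = CheckpointProcessConfig(
#         subprocess_init_timeout_secs=120,   # allow slow one-time setup
#         subprocess_shutdown_timeout_secs=60,
#     )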


class RequestType(Enum):
    PING = "ping"
    WRITE_CHECKPOINT = "write_checkpoint"
    TERMINATE_PROCESS = "exit"


@dataclass
class WorkerRequest:
    """
    A dataclass for storing the command to be sent to the worker process.
    Note: This relies on pickling to send the command to the worker process. Handle
    backward compatibility accordingly.
    """

    request_type: RequestType
    payload: dict[str, Any]


@dataclass
class WorkerResponse:
    request_type: RequestType
    success: bool
    error_msg: Optional[str] = None
    payload: Optional[dict[str, Any]] = None
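

# Illustrative sketch of the note above: requests cross the process boundary as
# pickles, so both sides must agree on these dataclass definitions; a field renamed
# or removed on one side only will break the round trip below. Example only, not
# used by this module.
#
#     import pickle
#     request = WorkerRequest(request_type=RequestType.PING, payload={})
#     restored = pickle.loads(pickle.dumps(request))
#     assert restored.request_type is RequestType.PING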


class CheckpointProcess:
    """
    A checkpoint writer that forwards checkpoint writes to a dedicated subprocess.
    """

    def __init__(
        self,
        rank_info: RankInfo,
        config: CheckpointProcessConfig,
        subprocess_init_fn: Callable[[Any], None],
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
    ):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._rank_info = rank_info
        self._config = config
        self._subprocess_init_fn = subprocess_init_fn
        self._subprocess_init_args = subprocess_init_args
        self._checkpoint_writer_init_fn = checkpoint_writer_init_fn
        self._checkpoint_writer_init_args = checkpoint_writer_init_args
        self.process = None
        self._parent_end: Optional[Connection] = None
        self._child_end: Optional[Connection] = None

        self.process_creation_future = self._executor.submit(
            self._create_subprocess,
            config,
        )
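
    # Illustrative construction sketch; `my_rank_info`, `init_subprocess`, and
    # `make_writer` are hypothetical caller-supplied objects, not names defined in
    # this module.
    #
    #     checkpoint_process = CheckpointProcess(
    #         rank_info=my_rank_info,              # a RankInfo for this trainer
    #         config=CheckpointProcessConfig(),
    #         subprocess_init_fn=init_subprocess,  # e.g. sets up logging/env vars
    #         subprocess_init_args=(),
    #         checkpoint_writer_init_fn=make_writer,  # any Callable[..., CheckpointWriter]
    #         checkpoint_writer_init_args={},      # rank_info is filled in automatically
    #     )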

    def _create_subprocess(
        self,
        config: CheckpointProcessConfig,
    ) -> None:
        logger.info(
            "Creating checkpoint subprocess for rank %d", self._rank_info.global_rank
        )

        spawn_context = mp.get_context("spawn")
        self._parent_end, child_end = spawn_context.Pipe()

        # Known workaround for https://github.com/pytorch/pytorch/issues/37377
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"

        logger.debug("Spawning subprocess for rank_info=%s", self._rank_info)
        self.process = mp.spawn(
            fn=CheckpointProcess._subprocess,
            args=(
                self._rank_info,
                child_end,
                self._subprocess_init_fn,
                self._subprocess_init_args,
                self._checkpoint_writer_init_fn,
                self._checkpoint_writer_init_args,
            ),
            nprocs=1,
            join=False,
            daemon=True,
        )

        # Close the parent's copy of the child end of the pipe so recv() on the
        # parent end fails fast when the child process terminates unexpectedly.
        child_end.close()

        self._send(
            request_type=RequestType.PING,
            payload={},
        )

        logger.debug(
            "Waiting for checkpoint subprocess to initialize (timeout: %ds)",
            config.subprocess_init_timeout_secs,
        )

        # Wait for a response from the subprocess, or time out.
        if self._parent_end is None:
            raise AssertionError("Parent end of pipe should be initialized")
        if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs):
            msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize"
            logger.error(msg)
            raise TimeoutError(msg)

        self._recv()
        logger.info("Checkpoint subprocess initialized successfully")
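
    # Minimal sketch (plain multiprocessing, independent of this class) of why the
    # parent drops its copy of the child end above: once every handle to the peer
    # end is closed, recv() raises EOFError immediately instead of blocking forever.
    #
    #     from multiprocessing import Pipe
    #     a, b = Pipe()
    #     b.close()
    #     a.recv()  # raises EOFError right away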

    @staticmethod
    def _subprocess(
        sub_rank: int,
        rank_info: RankInfo,
        parent_pipe: Connection,
        subprocess_init_fn: Callable[[Any], None],
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
    ) -> None:
        logger.debug(
            "Checkpoint subprocess started for rank %d/%d (PID: %d)",
            rank_info.global_rank,
            rank_info.global_world_size,
            os.getpid(),
        )

        if sub_rank != 0:
            raise AssertionError(
                "We need only one checkpointer per parent training process"
            )
        request = WorkerRequest(request_type=RequestType.PING, payload={})

        try:
            # Call the initialization callback so app-specific setup of the subprocess can run.
            subprocess_init_fn(*subprocess_init_args)

            # Initialize the checkpoint writer - automatically include rank_info in init_args.
            writer_init_args = dict(checkpoint_writer_init_args)
            if "rank_info" not in writer_init_args:
                writer_init_args["rank_info"] = rank_info
            checkpoint_writer = checkpoint_writer_init_fn(**writer_init_args)

            while True:
                request = parent_pipe.recv()

                if request.request_type == RequestType.PING:
                    parent_pipe.send(
                        WorkerResponse(request_type=RequestType.PING, success=True)
                    )
                elif request.request_type == RequestType.WRITE_CHECKPOINT:
                    path = request.payload["path"]
                    logger.info("Writing checkpoint to %s", path)

                    checkpoint_writer.write(
                        path=path,
                        state_dict=request.payload["state_dict"],
                        **request.payload["kwargs"],
                    )

                    logger.info("Checkpoint written successfully to %s", path)
                    parent_pipe.send(
                        WorkerResponse(RequestType.WRITE_CHECKPOINT, success=True)
                    )
                elif request.request_type == RequestType.TERMINATE_PROCESS:
                    logger.debug("Received termination request.")
                    parent_pipe.send(
                        WorkerResponse(RequestType.TERMINATE_PROCESS, success=True)
                    )
                    logger.info("Subprocess terminated gracefully")
                    break
                else:
                    error_msg = f"Unknown request type: {request.request_type}"
                    logger.error(error_msg)
                    raise ValueError(error_msg)

        except Exception as e:
            error_text = traceback.format_exc()
            logger.error(
                "Exception in subprocess (%s): %s", type(e).__name__, error_text
            )

            # Communicate the exception to the parent process via the pipe.
            parent_pipe.send(
                WorkerResponse(
                    request_type=request.request_type,
                    success=False,
                    error_msg=error_text,
                )
            )
            parent_pipe.close()
            logger.error("Subprocess terminated due to exception: %s", e)

    def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None:
        try:
            if self._parent_end is None:
                raise AssertionError("Parent end of pipe should be initialized")
            self._parent_end.send(
                WorkerRequest(
                    request_type=request_type,
                    payload=payload,
                )
            )
        except OSError as e:
            error_msg = "Child process terminated unexpectedly"
            logger.error(
                "Communication failed during %s request: %s", request_type.value, e
            )
            raise RuntimeError(error_msg) from e

    def _recv(self) -> Optional[dict[str, Any]]:
        try:
            if self._parent_end is None:
                raise AssertionError("Parent end of pipe should be initialized")
            response = self._parent_end.recv()
            if response.success is False:
                error_msg = (
                    f"Unexpected response from worker process: {response.error_msg}"
                )
                logger.error(error_msg)
                raise RuntimeError(error_msg)
            return response.payload
        except (EOFError, BrokenPipeError, ConnectionResetError) as e:
            error_msg = f"Child process terminated unexpectedly: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def write(
        self,
        state_dict: Union[STATE_DICT, Future[STATE_DICT]],
        path: str,
        **kwargs: Any,
    ) -> Optional[Future[None]]:
        logger.debug("Waiting for subprocess initialization to complete")

        # wait until the process is started
        self.process_creation_future.result()

        return self._executor.submit(
            self._write,
            state_dict,
            path,
            **kwargs,
        )
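
    # Illustrative usage sketch; `checkpoint_process` and `model` are hypothetical
    # caller-side objects. write() returns a concurrent.futures.Future that resolves
    # once the subprocess has acknowledged the write.
    #
    #     write_future = checkpoint_process.write(
    #         state_dict=model.state_dict(),
    #         path="/tmp/checkpoints/step_100",
    #     )
    #     write_future.result()  # block until the subprocess reports completion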

    def _write(
        self,
        state_dict: Union[STATE_DICT, Future[STATE_DICT]],
        path: str,
        **kwargs: Any,
    ) -> None:
        logger.debug("Starting checkpoint write to %s", path)

        # wait for staging state_dict to be available
        if isinstance(state_dict, Future):
            logger.debug("Waiting for state_dict Future to resolve")
            sd = state_dict.result()
        else:
            sd = state_dict

        # Log state_dict info only if debug logging is enabled (performance-conscious)
        if logger.isEnabledFor(logging.DEBUG):
            if hasattr(sd, "keys"):
                logger.debug("State_dict contains %d keys", len(sd.keys()))

        self._send(
            request_type=RequestType.WRITE_CHECKPOINT,
            payload={
                "state_dict": sd,
                "path": path,
                "kwargs": kwargs,
            },
        )

        logger.debug("Waiting for write completion response")
        # wait for response
        self._recv()
        logger.debug("Checkpoint write to %s completed successfully", path)

    def close(self) -> None:
        logger.debug(
            "Closing CheckpointProcess for rank %d", self._rank_info.global_rank
        )
        self._executor.shutdown(wait=True, cancel_futures=True)

        if self.process and self.process.processes[0].is_alive():
            subprocess_pid = self.process.processes[0].pid
            # Send a graceful termination request to the subprocess.
            try:
                # pyrefly: ignore # missing-attribute
                self._parent_end.send(
                    WorkerRequest(
                        request_type=RequestType.TERMINATE_PROCESS,
                        payload={},
                    )
                )
            except BrokenPipeError:
                logger.warning(
                    "BrokenPipeError when sending termination request - subprocess (PID: %d) may have already terminated",
                    subprocess_pid,
                )
                # The subprocess terminated unexpectedly; the join() below will
                # raise a ProcessExitedException.

            logger.debug(
                "Waiting for subprocess to terminate gracefully (timeout: %ds)",
                self._config.subprocess_shutdown_timeout_secs,
            )

            try:
                if not self.process.join(
                    timeout=self._config.subprocess_shutdown_timeout_secs
                ):
                    # Graceful shutdown failed; kill the process.
                    logger.warning(
                        "Subprocess (PID: %d) did not terminate gracefully within %ds, killing it",
                        subprocess_pid,
                        self._config.subprocess_shutdown_timeout_secs,
                    )
                    self.process.processes[0].kill()
                    logger.info("Subprocess killed forcefully")
            except ProcessExitedException as e:
                logger.error(
                    "ProcessExitedException during subprocess termination: %s", e
                )
                raise

        logger.debug("CheckpointProcess closed successfully")
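

# Illustrative shutdown sketch; `checkpoint_process`, `state_dict`, and `path` are
# hypothetical caller-side objects. close() waits for any in-flight write to finish
# (queued writes that have not started are cancelled), asks the subprocess to exit
# gracefully, and kills it only if that times out.
#
#     try:
#         checkpoint_process.write(state_dict=state_dict, path=path).result()
#     finally:
#         checkpoint_process.close()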