mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
c64e657632
commit
316808e4e9
@ -24,7 +24,7 @@ from dataclasses import dataclass, field
|
||||
from enum import IntFlag
|
||||
from multiprocessing import synchronize
|
||||
from types import FrameType
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
|
||||
@ -100,7 +100,7 @@ def _get_default_signal() -> signal.Signals:
|
||||
return signal.SIGTERM
|
||||
|
||||
|
||||
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
|
||||
def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
|
||||
actual_keys = set(d.keys())
|
||||
expected_keys = set(range(nprocs))
|
||||
|
||||
@ -122,7 +122,7 @@ class Std(IntFlag):
|
||||
ALL = OUT | ERR
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
|
||||
def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]:
|
||||
"""
|
||||
Example:
|
||||
::
|
||||
@ -143,7 +143,7 @@ class Std(IntFlag):
|
||||
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
|
||||
return to_std(vm)
|
||||
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
|
||||
d: Dict[int, Std] = {}
|
||||
d: dict[int, Std] = {}
|
||||
for m in vm.split(","):
|
||||
i, v = m.split(":")
|
||||
d[int(i)] = to_std(v)
|
||||
@ -155,8 +155,8 @@ class Std(IntFlag):
|
||||
|
||||
|
||||
def to_map(
|
||||
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
|
||||
) -> Dict[int, Std]:
|
||||
val_or_map: Union[Std, dict[int, Std]], local_world_size: int
|
||||
) -> dict[int, Std]:
|
||||
"""
|
||||
Certain APIs take redirect settings either as a single value (e.g. apply to all
|
||||
local ranks) or as an explicit user-provided mapping. This method is a convenience
|
||||
@ -184,11 +184,11 @@ class LogsDest:
|
||||
For each log type, holds mapping of local rank ids to file paths.
|
||||
"""
|
||||
|
||||
stdouts: Dict[int, str] = field(default_factory=dict)
|
||||
stderrs: Dict[int, str] = field(default_factory=dict)
|
||||
tee_stdouts: Dict[int, str] = field(default_factory=dict)
|
||||
tee_stderrs: Dict[int, str] = field(default_factory=dict)
|
||||
error_files: Dict[int, str] = field(default_factory=dict)
|
||||
stdouts: dict[int, str] = field(default_factory=dict)
|
||||
stderrs: dict[int, str] = field(default_factory=dict)
|
||||
tee_stdouts: dict[int, str] = field(default_factory=dict)
|
||||
tee_stderrs: dict[int, str] = field(default_factory=dict)
|
||||
error_files: dict[int, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
class LogsSpecs(ABC):
|
||||
@ -211,9 +211,9 @@ class LogsSpecs(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir: Optional[str] = None,
|
||||
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
|
||||
tee: Union[Std, Dict[int, Std]] = Std.NONE,
|
||||
local_ranks_filter: Optional[Set[int]] = None,
|
||||
redirects: Union[Std, dict[int, Std]] = Std.NONE,
|
||||
tee: Union[Std, dict[int, Std]] = Std.NONE,
|
||||
local_ranks_filter: Optional[set[int]] = None,
|
||||
) -> None:
|
||||
self._root_log_dir = log_dir
|
||||
self._redirects = redirects
|
||||
@ -223,7 +223,7 @@ class LogsSpecs(ABC):
|
||||
@abstractmethod
|
||||
def reify(
|
||||
self,
|
||||
envs: Dict[int, Dict[str, str]],
|
||||
envs: dict[int, dict[str, str]],
|
||||
) -> LogsDest:
|
||||
"""
|
||||
Given the environment variables, builds destination of log files for each of the local ranks.
|
||||
@ -249,9 +249,9 @@ class DefaultLogsSpecs(LogsSpecs):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir: Optional[str] = None,
|
||||
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
|
||||
tee: Union[Std, Dict[int, Std]] = Std.NONE,
|
||||
local_ranks_filter: Optional[Set[int]] = None,
|
||||
redirects: Union[Std, dict[int, Std]] = Std.NONE,
|
||||
tee: Union[Std, dict[int, Std]] = Std.NONE,
|
||||
local_ranks_filter: Optional[set[int]] = None,
|
||||
) -> None:
|
||||
if log_dir != os.devnull:
|
||||
if not log_dir:
|
||||
@ -278,7 +278,7 @@ class DefaultLogsSpecs(LogsSpecs):
|
||||
|
||||
def reify(
|
||||
self,
|
||||
envs: Dict[int, Dict[str, str]],
|
||||
envs: dict[int, dict[str, str]],
|
||||
) -> LogsDest:
|
||||
"""
|
||||
Uses following scheme to build log destination paths:
|
||||
@ -331,8 +331,8 @@ class DefaultLogsSpecs(LogsSpecs):
|
||||
SYS_STREAM = "" # special case to indicate to output to console
|
||||
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
|
||||
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
|
||||
tee_stdouts: Dict[int, str] = {}
|
||||
tee_stderrs: Dict[int, str] = {}
|
||||
tee_stdouts: dict[int, str] = {}
|
||||
tee_stderrs: dict[int, str] = {}
|
||||
error_files = {}
|
||||
|
||||
for local_rank in range(nprocs):
|
||||
@ -414,10 +414,10 @@ class RunProcsResult:
|
||||
|
||||
"""
|
||||
|
||||
return_values: Dict[int, Any] = field(default_factory=dict)
|
||||
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
|
||||
stdouts: Dict[int, str] = field(default_factory=dict)
|
||||
stderrs: Dict[int, str] = field(default_factory=dict)
|
||||
return_values: dict[int, Any] = field(default_factory=dict)
|
||||
failures: dict[int, ProcessFailure] = field(default_factory=dict)
|
||||
stdouts: dict[int, str] = field(default_factory=dict)
|
||||
stderrs: dict[int, str] = field(default_factory=dict)
|
||||
|
||||
def is_failed(self) -> bool:
|
||||
return len(self.failures) > 0
|
||||
@ -438,10 +438,10 @@ class PContext(abc.ABC):
|
||||
self,
|
||||
name: str,
|
||||
entrypoint: Union[Callable, str],
|
||||
args: Dict[int, Tuple],
|
||||
envs: Dict[int, Dict[str, str]],
|
||||
args: dict[int, tuple],
|
||||
envs: dict[int, dict[str, str]],
|
||||
logs_specs: LogsSpecs,
|
||||
log_line_prefixes: Optional[Dict[int, str]] = None,
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
):
|
||||
self.name = name
|
||||
# validate that all mappings have the same number of keys and
|
||||
@ -544,7 +544,7 @@ class PContext(abc.ABC):
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def pids(self) -> Dict[int, int]:
|
||||
def pids(self) -> dict[int, int]:
|
||||
"""Return pids of processes mapped by their respective local_ranks."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -587,11 +587,11 @@ def get_std_cm(std_rd: str, redirect_fn):
|
||||
def _wrap(
|
||||
local_rank: int,
|
||||
fn: Callable,
|
||||
args: Dict[int, Tuple],
|
||||
envs: Dict[int, Dict[str, str]],
|
||||
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
|
||||
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
|
||||
ret_vals: Dict[int, mp.SimpleQueue],
|
||||
args: dict[int, tuple],
|
||||
envs: dict[int, dict[str, str]],
|
||||
stdout_redirects: dict[int, str], # redirect file for stdout (to console if None)
|
||||
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
|
||||
ret_vals: dict[int, mp.SimpleQueue],
|
||||
queue_finished_reading_event: synchronize.Event,
|
||||
) -> None:
|
||||
# get the per-rank params up front so we fail fast if no mapping is found
|
||||
@ -621,11 +621,11 @@ class MultiprocessContext(PContext):
|
||||
self,
|
||||
name: str,
|
||||
entrypoint: Callable,
|
||||
args: Dict[int, Tuple],
|
||||
envs: Dict[int, Dict[str, str]],
|
||||
args: dict[int, tuple],
|
||||
envs: dict[int, dict[str, str]],
|
||||
start_method: str,
|
||||
logs_specs: LogsSpecs,
|
||||
log_line_prefixes: Optional[Dict[int, str]] = None,
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
@ -644,7 +644,7 @@ class MultiprocessContext(PContext):
|
||||
}
|
||||
|
||||
# see comments in ``join()`` for what this is
|
||||
self._return_values: Dict[int, Any] = {}
|
||||
self._return_values: dict[int, Any] = {}
|
||||
self._pc: Optional[mp.ProcessContext] = None
|
||||
# Note: set method should ONLY be invoked for the use case when all processes finished
|
||||
# successfully. If any process died on event.wait() calling set() method will deadlock.
|
||||
@ -755,7 +755,7 @@ class MultiprocessContext(PContext):
|
||||
stderrs=self.stderrs,
|
||||
)
|
||||
|
||||
def pids(self) -> Dict[int, int]:
|
||||
def pids(self) -> dict[int, int]:
|
||||
assert self._pc is not None # assertion for mypy type checking
|
||||
return dict(enumerate(self._pc.pids()))
|
||||
|
||||
@ -803,10 +803,10 @@ class SubprocessContext(PContext):
|
||||
self,
|
||||
name: str,
|
||||
entrypoint: str,
|
||||
args: Dict[int, Tuple],
|
||||
envs: Dict[int, Dict[str, str]],
|
||||
args: dict[int, tuple],
|
||||
envs: dict[int, dict[str, str]],
|
||||
logs_specs: LogsSpecs,
|
||||
log_line_prefixes: Optional[Dict[int, str]] = None,
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
@ -818,9 +818,9 @@ class SubprocessContext(PContext):
|
||||
)
|
||||
|
||||
# state vector; _vdone[local_rank] -> is local_rank finished or not
|
||||
self._running_local_ranks: Set[int] = set(range(self.nprocs))
|
||||
self._failures: Dict[int, ProcessFailure] = {}
|
||||
self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
|
||||
self._running_local_ranks: set[int] = set(range(self.nprocs))
|
||||
self._failures: dict[int, ProcessFailure] = {}
|
||||
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
|
||||
|
||||
def _start(self):
|
||||
if self.subprocess_handlers:
|
||||
@ -884,7 +884,7 @@ class SubprocessContext(PContext):
|
||||
else: # there are no failures and procs still running
|
||||
return None
|
||||
|
||||
def pids(self) -> Dict[int, int]:
|
||||
def pids(self) -> dict[int, int]:
|
||||
return {
|
||||
local_rank: sh.proc.pid
|
||||
for local_rank, sh in self.subprocess_handlers.items()
|
||||
|
Reference in New Issue
Block a user