PEP 585 update - torch/distributed/elastic and torch/distributed/checkpoint (#145163)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein
2025-01-18 14:58:05 -08:00
committed by PyTorch MergeBot
parent c64e657632
commit 316808e4e9
47 changed files with 311 additions and 344 deletions

View File

@ -24,7 +24,7 @@ from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@ -100,7 +100,7 @@ def _get_default_signal() -> signal.Signals:
return signal.SIGTERM
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))
@ -122,7 +122,7 @@ class Std(IntFlag):
ALL = OUT | ERR
@classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]:
"""
Example:
::
@ -143,7 +143,7 @@ class Std(IntFlag):
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
return to_std(vm)
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
d: Dict[int, Std] = {}
d: dict[int, Std] = {}
for m in vm.split(","):
i, v = m.split(":")
d[int(i)] = to_std(v)
@ -155,8 +155,8 @@ class Std(IntFlag):
def to_map(
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
val_or_map: Union[Std, dict[int, Std]], local_world_size: int
) -> dict[int, Std]:
"""
Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This method is a convenience
@ -184,11 +184,11 @@ class LogsDest:
For each log type, holds mapping of local rank ids to file paths.
"""
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
tee_stdouts: Dict[int, str] = field(default_factory=dict)
tee_stderrs: Dict[int, str] = field(default_factory=dict)
error_files: Dict[int, str] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)
tee_stdouts: dict[int, str] = field(default_factory=dict)
tee_stderrs: dict[int, str] = field(default_factory=dict)
error_files: dict[int, str] = field(default_factory=dict)
class LogsSpecs(ABC):
@ -211,9 +211,9 @@ class LogsSpecs(ABC):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
) -> None:
self._root_log_dir = log_dir
self._redirects = redirects
@ -223,7 +223,7 @@ class LogsSpecs(ABC):
@abstractmethod
def reify(
self,
envs: Dict[int, Dict[str, str]],
envs: dict[int, dict[str, str]],
) -> LogsDest:
"""
Given the environment variables, builds destination of log files for each of the local ranks.
@ -249,9 +249,9 @@ class DefaultLogsSpecs(LogsSpecs):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
) -> None:
if log_dir != os.devnull:
if not log_dir:
@ -278,7 +278,7 @@ class DefaultLogsSpecs(LogsSpecs):
def reify(
self,
envs: Dict[int, Dict[str, str]],
envs: dict[int, dict[str, str]],
) -> LogsDest:
"""
Uses following scheme to build log destination paths:
@ -331,8 +331,8 @@ class DefaultLogsSpecs(LogsSpecs):
SYS_STREAM = "" # special case to indicate to output to console
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
tee_stdouts: Dict[int, str] = {}
tee_stderrs: Dict[int, str] = {}
tee_stdouts: dict[int, str] = {}
tee_stderrs: dict[int, str] = {}
error_files = {}
for local_rank in range(nprocs):
@ -414,10 +414,10 @@ class RunProcsResult:
"""
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
return_values: dict[int, Any] = field(default_factory=dict)
failures: dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)
def is_failed(self) -> bool:
return len(self.failures) > 0
@ -438,10 +438,10 @@ class PContext(abc.ABC):
self,
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
self.name = name
# validate that all mappings have the same number of keys and
@ -544,7 +544,7 @@ class PContext(abc.ABC):
return None
@abc.abstractmethod
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError
@ -587,11 +587,11 @@ def get_std_cm(std_rd: str, redirect_fn):
def _wrap(
local_rank: int,
fn: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
stdout_redirects: dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
ret_vals: dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
@ -621,11 +621,11 @@ class MultiprocessContext(PContext):
self,
name: str,
entrypoint: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
super().__init__(
name,
@ -644,7 +644,7 @@ class MultiprocessContext(PContext):
}
# see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {}
self._return_values: dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
# Note: set method should ONLY be invoked for the use case when all processes finished
# successfully. If any process died on event.wait() calling set() method will deadlock.
@ -755,7 +755,7 @@ class MultiprocessContext(PContext):
stderrs=self.stderrs,
)
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
assert self._pc is not None # assertion for mypy type checking
return dict(enumerate(self._pc.pids()))
@ -803,10 +803,10 @@ class SubprocessContext(PContext):
self,
name: str,
entrypoint: str,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
super().__init__(
name,
@ -818,9 +818,9 @@ class SubprocessContext(PContext):
)
# state vector; _vdone[local_rank] -> is local_rank finished or not
self._running_local_ranks: Set[int] = set(range(self.nprocs))
self._failures: Dict[int, ProcessFailure] = {}
self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
self._running_local_ranks: set[int] = set(range(self.nprocs))
self._failures: dict[int, ProcessFailure] = {}
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
def _start(self):
if self.subprocess_handlers:
@ -884,7 +884,7 @@ class SubprocessContext(PContext):
else: # there are no failures and procs still running
return None
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
return {
local_rank: sh.proc.pid
for local_rank, sh in self.subprocess_handlers.items()