PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
Authored by: Aaron Orenstein, 2025-01-18 14:58:05 -08:00
Committed by: PyTorch MergeBot
parent c64e657632
commit 316808e4e9
47 changed files with 311 additions and 344 deletions
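Background for the change: PEP 585 (Python 3.9) made the built-in container types usable as generics, so annotations can use `dict`, `list`, `set`, `tuple`, and `type` directly instead of `typing.Dict`/`List`/`Set`/`Tuple`/`Type`, and abstract collection types such as `Sequence`, `Generator`, `Iterable`, and `Iterator` are imported from `collections.abc` rather than `typing`. That is the pattern applied throughout the hunks below. A minimal before/after sketch (hypothetical helper, illustrative only, not a line taken from this diff):

# Before: typing-module generics (pre-PEP 585 style)
from typing import Dict, List, Optional

def bucket_items(sizes: List[int], index: Dict[str, int]) -> Optional[List[int]]:
    # hypothetical helper used only to illustrate the annotation change
    return [sizes[i] for i in index.values()] or None

# After: built-in generics, available since Python 3.9
from typing import Optional

def bucket_items(sizes: list[int], index: dict[str, int]) -> Optional[list[int]]:
    return [sizes[i] for i in index.values()] or None

The runtime behavior is identical; only the spelling of the annotations changes.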

View File

@@ -1,5 +1,5 @@
 from concurrent.futures import Future
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 import torch.distributed as dist
 import torch.distributed.checkpoint.state_dict_loader as loader
@@ -13,7 +13,7 @@ from torch.distributed.checkpoint.storage import (
 )
-__all__: List[str] = []
+__all__: list[str] = []
 class _Checkpointer:
@@ -90,7 +90,7 @@ class _Checkpointer:
 planner=self.save_planner,
 )
-def load(self, state_dict: Dict[str, Any]) -> None:
+def load(self, state_dict: dict[str, Any]) -> None:
 """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
 loader.load(
 state_dict,

View File

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import dataclasses
 from collections import defaultdict
-from typing import Dict, List, Set, TYPE_CHECKING
+from typing import TYPE_CHECKING
 from torch.distributed.checkpoint.planner import SavePlan, WriteItem
@@ -13,16 +13,16 @@ __all__ = ["dedup_save_plans"]
 def dedup_save_plans(
-all_plans: List[SavePlan],
+all_plans: list[SavePlan],
 save_to_lowest_rank: bool = False,
-) -> List[SavePlan]:
+) -> list[SavePlan]:
 """
 Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
 a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
 """
-write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
+write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
-write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
+write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
 for plan_idx, plan in enumerate(all_plans):
 for write_item in plan.items:
 # map each write item to its plan
@@ -30,7 +30,7 @@ def dedup_save_plans(
 write_item_idx_to_write_item[write_item.index] = write_item
 # put item in the plan with the smallest size and remove it from the other plan_indices
-to_remove: List[Set] = [set() for _ in range(len(all_plans))]
+to_remove: list[set] = [set() for _ in range(len(all_plans))]
 plan_to_size = [0] * len(all_plans)
 for write_item_idx, plan_indices in write_item_to_plan_indices.items():
 if save_to_lowest_rank:

View File

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import dataclasses
 import logging
-from typing import Dict, List, TYPE_CHECKING
+from typing import TYPE_CHECKING
 from torch.distributed.checkpoint.planner import SavePlan
@@ -31,9 +31,9 @@ logger = init_logger()
 # TODO add docstring for dedup_tensors
-def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
+def dedup_tensors(all_plans: list[SavePlan]) -> list[SavePlan]:
 all_plans = list(all_plans)
-key_to_plan: Dict[MetadataIndex, List[int]] = {}
+key_to_plan: dict[MetadataIndex, list[int]] = {}
 for plan_idx, plan in enumerate(all_plans):
 for write_item in plan.items:
 key_to_plan.setdefault(write_item.index, []).append(plan_idx)
@@ -42,7 +42,7 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
 # Remove duplicates by always keeping the first entry.
 # Compute the per-rank remove set.
-plan_to_keys: Dict[int, List[MetadataIndex]] = {}
+plan_to_keys: dict[int, list[MetadataIndex]] = {}
 for key, plans in replicated_items.items():
 for plan_idx in plans[1:]:
 plan_to_keys.setdefault(plan_idx, []).append(key)

View File

@@ -3,10 +3,10 @@
 import io
 import os
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Generator, Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union
 from fsspec.core import url_to_fs

View File

@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import Dict
 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
@@ -21,7 +20,7 @@ Change set_element to recreate the right type for tuple, OrderedDict, and NamedT
 """
-FLATTEN_MAPPING = Dict[str, OBJ_PATH]
+FLATTEN_MAPPING = dict[str, OBJ_PATH]
 # TODO: Update Docstring for nested_dict.py

View File

@@ -1,5 +1,5 @@
 import os
-from typing import List, Type, Union
+from typing import Union
 from .filesystem import FileSystemReader, FileSystemWriter
 from .storage import StorageReader, StorageWriter
@@ -21,7 +21,7 @@ def _storage_setup(
 "storage_reader/storage_writer is None."
 )
-targets: List[Type[Union[StorageReader, StorageWriter]]] = []
+targets: list[type[Union[StorageReader, StorageWriter]]] = []
 if reader:
 targets = [
 FileSystemReader,

View File

@@ -1,15 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import (
-Callable,
-cast,
-Collection,
-List,
-Mapping,
-MutableMapping,
-Optional,
-TypeVar,
-Union,
-)
+from collections.abc import Collection, Mapping, MutableMapping
+from typing import Callable, cast, Optional, TypeVar, Union
 import torch
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
@@ -123,7 +114,7 @@ def set_element(
 """Set ``value`` in ``root_dict`` along the ``path`` object path."""
 cur_container = cast(CONTAINER_TYPE, root_dict)
-def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
+def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None:
 while len(lst) <= idx:
 lst.append(None)
@@ -144,7 +135,7 @@ def set_element(
 key = path[-1]
 if type(key) == int:
-extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
+extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
 cur_container[key] = value

View File

@@ -1,5 +1,5 @@
 import traceback as tb
-from typing import Any, Dict
+from typing import Any
 WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary]
@@ -22,12 +22,12 @@ def _is_wrapped_exception(obj: Any) -> bool:
 class CheckpointException(BaseException):
 """Exception raised if failure was detected as part of a checkpoint load or save."""
-def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
+def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
 super().__init__(msg, failures)
 self._failures = failures
 @property
-def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
+def failures(self) -> dict[int, WRAPPED_EXCEPTION]:
 """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
 return self._failures

View File

@@ -7,7 +7,7 @@ import logging
 import operator
 from collections import ChainMap
 from functools import reduce
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import Any, cast, Optional, Union
 import torch
 from torch.distributed._shard._utils import narrow_tensor_by_index
@@ -106,8 +106,8 @@ class DefaultSavePlanner(SavePlanner):
 return self.plan
 def create_global_plan(
-self, all_plans: List[SavePlan]
+self, all_plans: list[SavePlan]
-) -> tuple[List[SavePlan], Metadata]:
+) -> tuple[list[SavePlan], Metadata]:
 all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
 global_plan, metadata = create_default_global_save_plan(all_plans)
@@ -234,7 +234,7 @@ class DefaultLoadPlanner(LoadPlanner):
 self.state_dict, self.metadata, not self.allow_partial_load
 )
-def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
 return create_default_global_load_plan(global_plan)
 def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
@@ -293,7 +293,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
 if key in self.keys:
 True
-unflattened_keys: List[str] = []
+unflattened_keys: list[str] = []
 planner_data = metadata.planner_data.get(key)
 for unflattened_key in planner_data:
 if unflattened_keys:
@@ -334,7 +334,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
 def create_default_local_load_plan(
-state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True
+state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
 ) -> LoadPlan:
 requests = []
 """
@@ -376,8 +376,8 @@ def create_default_local_load_plan(
 def create_default_global_load_plan(
-all_plans: List[LoadPlan],
+all_plans: list[LoadPlan],
-) -> List[LoadPlan]:
+) -> list[LoadPlan]:
 """
 Create global load plan used by DefaultLoadPlanner.
@@ -388,7 +388,7 @@ def create_default_global_load_plan(
 def create_default_local_save_plan(
-state_dict: Dict[str, Any], is_coordinator: bool
+state_dict: dict[str, Any], is_coordinator: bool
 ) -> SavePlan:
 """
 Create the ``SavePlan`` used by DefaultSavePlanner.
@@ -415,9 +415,9 @@ def create_default_local_save_plan(
 def create_default_global_save_plan(
-all_plans: List[SavePlan],
+all_plans: list[SavePlan],
 rewrite_index_hints: bool = True,
-) -> tuple[List[SavePlan], Metadata]:
+) -> tuple[list[SavePlan], Metadata]:
 """
 Create the global plan and metadata used by DefaultSavePlanner.
@@ -426,7 +426,7 @@ def create_default_global_save_plan(
 The only global planning change is to update index hints in all ``MetadataIndex`` objects if
 ``rewrite_index_hints`` is True.
 """
-md: Dict[str, STORAGE_TYPES] = {}
+md: dict[str, STORAGE_TYPES] = {}
 new_plans = []
 for plan in all_plans:
 new_items = []
@@ -506,7 +506,7 @@ def _check_box_bounds(
 return True
-def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
+def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
 all_good = True
 for key, value in metadata.state_dict_metadata.items():
 if isinstance(value, BytesStorageMetadata):

View File

@@ -10,24 +10,12 @@ import threading
 import uuid
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Generator, Iterable, Iterator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 from io import UnsupportedOperation
 from pathlib import Path
-from typing import (
-Any,
-Callable,
-cast,
-Dict,
-Generator,
-IO,
-Iterable,
-Iterator,
-List,
-Optional,
-Union,
-)
+from typing import Any, Callable, cast, IO, Optional, Union
 # introduced as collections.abc.Buffer in Python 3.12
 from typing_extensions import Buffer
@@ -113,7 +101,7 @@ class _TensorLoader(ABC):
 class _SerialCpuLoader(_TensorLoader):
 def __init__(self, resolve_fun: Callable) -> None:
 self.resolve_fun = resolve_fun
-self.items: List[tuple[int, object]] = []
+self.items: list[tuple[int, object]] = []
 def add(self, size: int, obj: object) -> None:
 self.items.append((size, obj))
@@ -141,7 +129,7 @@ class _OverlappingCpuLoader(_TensorLoader):
 inflight_threshhold: int = 1_000_000,
 ) -> None:
 self.resolve_fun = resolve_fun
-self.items: List[tuple[int, object]] = []
+self.items: list[tuple[int, object]] = []
 self.inflight_threshhold = inflight_threshhold
 self.in_flight_data = 0
 self.current_items: collections.deque = collections.deque()
@@ -161,7 +149,7 @@ class _OverlappingCpuLoader(_TensorLoader):
 def _done(self) -> bool:
 return self.idx >= len(self.items)
-def _drain(self) -> List[tuple[torch.Tensor, object]]:
+def _drain(self) -> list[tuple[torch.Tensor, object]]:
 drained = []
 if self.in_flight_data >= self.inflight_threshhold:
 self.stream.synchronize()
@@ -243,7 +231,7 @@ class _StorageWriterTransforms:
 def transform_save_stream(
 self, write_item: WriteItem, raw_stream: io.IOBase
-) -> tuple[IO[bytes], List[str]]:
+) -> tuple[IO[bytes], list[str]]:
 # In order to avoid leaking fds, transformers' close must
 # cascade to wrapped streams, but since this function can
 # append to the raw stream, we can't close the actual stream.
@@ -285,14 +273,14 @@ def _item_size(item: WriteItem) -> int:
 return size * torch._utils._element_size(dtype)
-def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
+def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]:
 if bins == 1:
 return [items]
 bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
 tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
-buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
+buckets: list[list[WriteItem]] = [[] for _ in range(bins)]
 bucket_sizes = [0 for _ in range(bins)]
 tensor_w.sort(key=_item_size, reverse=True)
@@ -584,7 +572,7 @@ class _FileSystemWriter(StorageWriter):
 return plan
-def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
+def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
 new_plans = [
 dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
 for i, plan in enumerate(plans)
@@ -595,7 +583,7 @@ class _FileSystemWriter(StorageWriter):
 self,
 plan: SavePlan,
 planner: SavePlanner,
-) -> Future[List[WriteResult]]:
+) -> Future[list[WriteResult]]:
 storage_plan: _StoragePrefix = plan.storage_data
 file_count = 0
@@ -656,11 +644,11 @@ class _FileSystemWriter(StorageWriter):
 while True:
 res += result_queue.get_nowait()
 except queue.Empty:
-fut: Future[List[WriteResult]] = Future()
+fut: Future[list[WriteResult]] = Future()
 fut.set_result(res)
 return fut
-def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
 storage_md = {}
 for wr_list in results:
 storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@@ -737,7 +725,7 @@ class FileSystemReader(StorageReader):
 super().__init__()
 self.fs = FileSystem()
 self.path = self.fs.init_path(path)
-self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
+self.storage_data: dict[MetadataIndex, _StorageInfo] = {}
 self.load_id = _generate_uuid()
 self.transforms = _StorageReaderTransforms(_extension_registry)
@@ -752,7 +740,7 @@ class FileSystemReader(StorageReader):
 def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
 # group requests by file
-per_file: Dict[str, List[ReadItem]] = {}
+per_file: dict[str, list[ReadItem]] = {}
 for read_item in plan.items:
 item_md = self.storage_data[read_item.storage_index]
 path = item_md.relative_path
@@ -828,7 +816,7 @@ class FileSystemReader(StorageReader):
 def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
 return plan
-def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
+def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
 return plans
 @property

View File

@@ -2,7 +2,7 @@
 import argparse
 import os
 from enum import Enum
-from typing import cast, Dict, List, Optional, Union
+from typing import cast, Optional, Union
 import torch
 import torch.distributed as dist
@@ -133,7 +133,7 @@ class BroadcastingTorchSaveReader(StorageReader):
 """Implementation of the StorageReader method"""
 return plan
-def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+def prepare_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
 """Implementation of the StorageReader method"""
 return global_plan
@@ -177,7 +177,7 @@ class DynamicMetaLoadPlanner(DefaultLoadPlanner):
 """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict"""
 super().set_up_planner(state_dict, metadata, is_coordinator)
-state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
+state_dict_metadata: dict[str, STORAGE_TYPES] = {}
 for key, tensor in self.state_dict.items():
 if not torch.is_tensor(tensor):
 raise RuntimeError(

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import functools
 import time
-from typing import Any, Callable, Dict, List, TypeVar
+from typing import Any, Callable, TypeVar
 from typing_extensions import ParamSpec
 from uuid import uuid4
@@ -9,7 +9,7 @@ import torch.distributed.c10d_logger as c10d_logger
 from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
-__all__: List[str] = []
+__all__: list[str] = []
 global _dcp_logger
 _dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
@@ -18,7 +18,7 @@ _T = TypeVar("_T")
 _P = ParamSpec("_P")
-def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
+def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
 """
 Extracts log data from dcp method args
 """
@@ -52,7 +52,7 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
 return msg_dict
-def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
+def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
 msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
 msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs))

View File

@@ -1,10 +1,9 @@
 import logging
-from typing import List
 from torch.distributed.logging_handlers import _log_handlers
-__all__: List[str] = []
+__all__: list[str] = []
 DCP_LOGGER_NAME = "dcp_logger"

View File

@@ -1,8 +1,9 @@
 # mypy: allow-untyped-defs
 import os
+from collections.abc import Sequence
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Optional, Union
 import torch
 from torch.distributed.checkpoint.stateful import StatefulT
@@ -113,7 +114,7 @@ class TensorProperties:
 class TensorStorageMetadata:
 properties: TensorProperties
 size: torch.Size
-chunks: List[ChunkStorageMetadata]
+chunks: list[ChunkStorageMetadata]
 @dataclass
@@ -122,7 +123,7 @@ class BytesStorageMetadata:
 STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
-STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
+STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]]
 @dataclass
@@ -137,7 +138,7 @@ class Metadata:
 """This class represents the metadata of the checkpoint."""
 # Keys are the same from the `state_dict` used.
-state_dict_metadata: Dict[str, STORAGE_TYPES]
+state_dict_metadata: dict[str, STORAGE_TYPES]
 # It is the responsibility of the planner and storage plugins to ensure
 # backward compatibility of the planner_data and storage_data. DCP will
 # also ensure the backward compatibility of the metadata in this file and

View File

@@ -1,7 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import dataclasses
-from typing import cast, Dict, List, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import cast, Optional, Union
 import torch
 import torch.distributed as dist
@@ -41,7 +42,7 @@ from torch.distributed.remote_device import _remote_device
 from torch.distributed.tensor import DTensor
-STATE_DICT_2D_LAYOUT = Dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]
+STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]
 # TODO: Update docstrings for optimizer.py
@@ -77,7 +78,7 @@ def _create_colwise_spec(
 ]
 return ChunkShardingSpec(
 dim=0,
-placements=cast(List[Union[_remote_device, str]], placements),
+placements=cast(list[Union[_remote_device, str]], placements),
 )
@@ -154,11 +155,11 @@ def _get_state_dict_2d_layout(
 class _ReaderWithOffset(DefaultLoadPlanner):
-translation: Dict[MetadataIndex, MetadataIndex]
+translation: dict[MetadataIndex, MetadataIndex]
 state_dict: STATE_DICT_TYPE
 metadata: Metadata
-def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
+def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None:
 super().__init__()
 self.fqn_to_offset = fqn_to_offset
 self.metadata = Metadata({})
@@ -284,7 +285,7 @@ def load_sharded_optimizer_state_dict(
 # Create a state_dict for optimizer state
 state_dict: STATE_DICT_TYPE = {}
-fqn_to_offset: Dict[str, Sequence[int]] = {}
+fqn_to_offset: dict[str, Sequence[int]] = {}
 for key, value in metadata.state_dict_metadata.items():
 key_path = metadata.planner_data[key]
 if key_path[0] != optimizer_key:

View File

@@ -4,7 +4,7 @@ import operator
 from dataclasses import dataclass
 from enum import auto, Enum
 from functools import reduce
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
 import torch
 from torch.distributed.checkpoint.metadata import (
@@ -94,14 +94,14 @@ class ReadItem:
 @dataclass(frozen=True)
 class SavePlan:
-items: List[WriteItem]
+items: list[WriteItem]
 storage_data: Any = None
 planner_data: Any = None
 @dataclass
 class LoadPlan:
-items: List[ReadItem]
+items: list[ReadItem]
 storage_data: Any = None
 planner_data: Any = None
@@ -231,8 +231,8 @@ class SavePlanner(abc.ABC):
 @abc.abstractmethod
 def create_global_plan(
-self, all_plans: List[SavePlan]
+self, all_plans: list[SavePlan]
-) -> tuple[List[SavePlan], Metadata]:
+) -> tuple[list[SavePlan], Metadata]:
 """
 Compute the global checkpoint plan and return the local plan of each rank.
@@ -364,7 +364,7 @@ class LoadPlanner:
 """
 @abc.abstractmethod
-def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
 """
 Compute the global load plan and return plans for each rank.

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import io
-from typing import Any, Callable, cast, Dict, List
+from typing import Any, Callable, cast
 import torch
 import torch.distributed as dist
@@ -33,7 +33,7 @@ from .resharding import (
 )
-__all__: List[str] = ["create_read_items_for_chunk_list"]
+__all__: list[str] = ["create_read_items_for_chunk_list"]
 def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
@@ -149,8 +149,8 @@ def _create_read_item_for_tensor(
 def create_read_items_for_chunk_list(
 fqn: str,
 checkpoint_md: TensorStorageMetadata,
-local_chunks: List[ChunkStorageMetadata],
+local_chunks: list[ChunkStorageMetadata],
-) -> List[ReadItem]:
+) -> list[ReadItem]:
 """
 Create a list of ``ReadItem`` based on the checkpoint and local chunks.
@@ -218,7 +218,7 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
 return SavePlan(requests)
-def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
+def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
 if hasattr(object, "__create_write_items__"):
 # DTensor implements _Checkpointable
 return object.__create_write_items__(fqn, object)
@@ -244,7 +244,7 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
 )
-def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
+def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
 if hasattr(tensor, "__create_chunk_list__"):
 # DTensor implements _Checkpointable
 local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined]
@@ -263,7 +263,7 @@ def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
 return local_chunks
-def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
+def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
 if not isinstance(md, BytesStorageMetadata):
 try:
 local_chunks = _create_chunk_list(obj)
@@ -286,7 +286,7 @@ def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
 ]
-def _init_state_dict(state_dict: Dict[str, Any]) -> Any:
+def _init_state_dict(state_dict: dict[str, Any]) -> Any:
 """
 Initializes meta tensor if the meta tensor is DTensor or torch.Tensor.
 """

View File

@@ -1,10 +1,9 @@
 # mypy: allow-untyped-defs
-from typing import List
 from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
-__all__: List[str] = []
+__all__: list[str] = []
 def _check_shard_metadata_pair_overlap(
@@ -27,7 +26,7 @@ def _check_shard_metadata_pair_overlap(
 def _shards_get_overlap_region_wrt_saved_tensor(
 saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
-) -> List[tuple[int, int, int, int]]:
+) -> list[tuple[int, int, int, int]]:
 """
 Return the overlapping region between saved_shard and current_shard.

View File

@@ -3,21 +3,10 @@ import contextlib
 import functools
 import gc
 import warnings
+from collections.abc import Generator, Iterable
 from dataclasses import asdict, dataclass, field
 from itertools import chain
-from typing import (
-Any,
-Callable,
-cast,
-Dict,
-Generator,
-Iterable,
-List,
-no_type_check,
-Optional,
-Set,
-Union,
-)
+from typing import Any, Callable, cast, no_type_check, Optional, Union
 import torch
 import torch.distributed as dist
@@ -76,17 +65,17 @@ _PG = "param_groups"
 _PARAMS = "params"
 _STATE = "state"
-FQNS_T = Set[str]
+FQNS_T = set[str]
 PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
 ValueType = Union[
-PrimitiveType, List[PrimitiveType], tuple[PrimitiveType], Dict[str, "ValueType"]
+PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"]
 ]
-DictValueType = Dict[str, ValueType]
+DictValueType = dict[str, ValueType]
-ListDictValueType = List[DictValueType]
+ListDictValueType = list[DictValueType]
-OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]
+OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]]
-_patched_state_dict: Set[Callable] = set()
+_patched_state_dict: set[Callable] = set()
 @contextlib.contextmanager
@@ -149,20 +138,20 @@ class StateDictOptions:
 @dataclass
 class _StateDictInfo(StateDictOptions):
-fqn_param_mapping: Dict[
+fqn_param_mapping: dict[
 Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
 ] = field(default_factory=dict)
-shared_params_mapping: Dict[
+shared_params_mapping: dict[
 Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
 ] = field(default_factory=dict)
-submodule_prefixes: Set[str] = field(default_factory=set)
+submodule_prefixes: set[str] = field(default_factory=set)
 handle_model: bool = True
 handle_optim: bool = True
 fsdp_context: Callable = contextlib.nullcontext
-fsdp_modules: List[nn.Module] = field(default_factory=list)
+fsdp_modules: list[nn.Module] = field(default_factory=list)
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_fqns(
 model: nn.Module,
 name: str,
@@ -230,7 +219,7 @@ class _EXTRA_STATE:
 def _iterate_valid_model_state(model):
-visited_modules: Set[nn.Module] = set()
+visited_modules: set[nn.Module] = set()
 def recurse(module: nn.Module, curr_fqn: str) -> Generator:
 visited_modules.add(module)
@@ -265,7 +254,7 @@ def _verify_options(
 optims: tuple[torch.optim.Optimizer, ...],
 optim_only: bool,
 *,
-submodules: Optional[Set[nn.Module]] = None,
+submodules: Optional[set[nn.Module]] = None,
 options: Optional[StateDictOptions] = None,
 ) -> _StateDictInfo:
 """
@@ -285,11 +274,11 @@ def _verify_options(
 options = options or StateDictOptions()
-fqn_param_mapping: Dict[
+fqn_param_mapping: dict[
-Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
+Union[str, torch.Tensor], Union[set[str], torch.Tensor]
 ] = {}
-shared_params_mapping: Dict[
+shared_params_mapping: dict[
-Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
+Union[str, torch.Tensor], Union[set[str], torch.Tensor]
 ] = {}
 for name, param in _iterate_valid_model_state(model):
 if isinstance(param, _EXTRA_STATE):
@@ -298,7 +287,7 @@ def _verify_options(
 fqns = _get_fqns(model, name)
 fqn = fqn_param_mapping.get(param, None)
 if fqn is not None:
-cast(Set[str], fqn_param_mapping[param]).update(fqns)
+cast(set[str], fqn_param_mapping[param]).update(fqns)
 shared_params_mapping[param] = fqn_param_mapping[param]
 else:
 # We need to do copy as _get_fqns is lru_cached
@@ -311,7 +300,7 @@ def _verify_options(
 for fqn in fqns_:
 shared_params_mapping[fqn] = cast(torch.Tensor, param_)
-submodule_prefixes: Set[str] = set()
+submodule_prefixes: set[str] = set()
 if submodules:
 submodules = set(submodules)
 for name, module in model.named_modules():
@@ -384,14 +373,14 @@ def _verify_options(
 shared_params_mapping=shared_params_mapping,
 submodule_prefixes=submodule_prefixes,
 fsdp_context=fsdp_context,
-fsdp_modules=cast(List[nn.Module], fsdp_modules),
+fsdp_modules=cast(list[nn.Module], fsdp_modules),
 handle_model=not optim_only,
 handle_optim=(len(optims) > 0),
 )
 def _verify_state_dict(
-model_state_dict: Dict[str, ValueType],
+model_state_dict: dict[str, ValueType],
 optim_state_dict: OptimizerStateType,
 info: _StateDictInfo,
 ) -> None:
@@ -443,8 +432,8 @@ def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Ca
 def _maybe_full_or_cpu_state_dict(
-state_dict: Dict[str, Any], info: _StateDictInfo
+state_dict: dict[str, Any], info: _StateDictInfo
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
 if info.full_state_dict:
 ranks_only = (
 ()
@@ -463,7 +452,7 @@ def _maybe_full_or_cpu_state_dict(
 @torch.no_grad()
 def _get_model_state_dict(
 model: nn.Module, info: _StateDictInfo
-) -> Dict[str, ValueType]:
+) -> dict[str, ValueType]:
 if not info.handle_model:
 return {}
@@ -500,7 +489,7 @@ def _get_model_state_dict(
 state_dict[fqn] = state_dict.pop(key)
 if info.submodule_prefixes:
-new_state_dict: Dict[str, ValueType] = {}
+new_state_dict: dict[str, ValueType] = {}
 # TODO: make this faster.
 for fqn in state_dict.keys():
 for prefix in info.submodule_prefixes:
@@ -531,7 +520,7 @@ def _get_model_state_dict(
 @torch.no_grad()
 def _load_model_state_dict(
 model: nn.Module,
-state_dict: Dict[str, ValueType],
+state_dict: dict[str, ValueType],
 info: _StateDictInfo,
 ) -> _IncompatibleKeys:
 if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
@@ -636,7 +625,7 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None:
 optim.zero_grad(set_to_none=True)
-def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
+def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]:
 """
 This API flattens the optimizer state_dict to support optimizer resharding for
 MPMD, e.g., pipeline parallelism.
@@ -686,7 +675,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
 f"Type is {type(v)}."
 )
-ret: Dict[str, ValueType] = {}
+ret: dict[str, ValueType] = {}
 for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
 for k, v in cast(DictValueType, state).items():
 _raise_if_type_not_supported(v)
@@ -694,7 +683,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
 for param_group in cast(ListDictValueType, state_dict[_PG]):
 fqns = param_group.pop(_PARAMS)
-for fqn in cast(List[str], fqns):
+for fqn in cast(list[str], fqns):
 for k, v in param_group.items():
 ret[f"{_PG}.{fqn}.{k}"] = v
 return ret
@@ -702,7 +691,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
 def _unflatten_optim_state_dict(
 optim: torch.optim.Optimizer,
-state_dict: Dict[str, ValueType],
+state_dict: dict[str, ValueType],
 info: _StateDictInfo,
 ) -> OptimizerStateType:
 """
@@ -728,7 +717,7 @@ def _unflatten_optim_state_dict(
 f"{_STATE}.{fqn}.{state_name}"
 ]
-first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0]
+first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0]
 for k in param_group.keys():
 if k == _PARAMS:
 continue
@@ -833,7 +822,7 @@ def _split_optim_state_dict(
 state: DictValueType = {}
 pg_state: ListDictValueType = []
 return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
-pg_mapping: Dict[int, int] = {}
+pg_mapping: dict[int, int] = {}
 if all(
 isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
@@ -849,7 +838,7 @@ def _split_optim_state_dict(
 for loaded_param_group in cast(
 ListDictValueType, optim_state_dict[_PG]
 ):
-if fqn in cast(List[str], loaded_param_group[_PARAMS]):
+if fqn in cast(list[str], loaded_param_group[_PARAMS]):
 in_params = True
 break
 else:
@@ -865,7 +854,7 @@ def _split_optim_state_dict(
 for loaded_param_group in cast(
 ListDictValueType, optim_state_dict[_PG]
 ):
-if fqn in cast(List[str], loaded_param_group[_PARAMS]):
+if fqn in cast(list[str], loaded_param_group[_PARAMS]):
 pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
 for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
@@ -900,7 +889,7 @@ def _load_optim_state_dict(
 )
 else:
 optim_state_dict = _unflatten_optim_state_dict(
-optim, cast(Dict[str, ValueType], state_dict), info
+optim, cast(dict[str, ValueType], state_dict), info
 )
 else:
 optim_state_dict = {}
@@ -919,7 +908,7 @@ def _load_optim_state_dict(
 fqn = fqns.pop()
 fqn_with_compiler = fqns_with_compiler.pop()
 for g in optim_state_dict[_PG]:
-val = cast(Dict[str, Any], g)
+val = cast(dict[str, Any], g)
 params = [
 key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
 ]
@@ -978,9 +967,9 @@ def _load_optim_state_dict(
 def get_model_state_dict(
 model: nn.Module,
 *,
-submodules: Optional[Set[nn.Module]] = None,
+submodules: Optional[set[nn.Module]] = None,
 options: Optional[StateDictOptions] = None,
-) -> Dict[str, ValueType]:
+) -> dict[str, ValueType]:
 """
 Return the model state_dict of ``model``.
@@ -1016,7 +1005,7 @@ def get_optimizer_state_dict(
 model: nn.Module,
 optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
 *,
-submodules: Optional[Set[nn.Module]] = None,
+submodules: Optional[set[nn.Module]] = None,
 options: Optional[StateDictOptions] = None,
 ) -> OptimizerStateType:
 """
@@ -1061,9 +1050,9 @@ def get_state_dict(
 model: nn.Module,
 optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
 *,
-submodules: Optional[Set[nn.Module]] = None,
+submodules: Optional[set[nn.Module]] = None,
 options: Optional[StateDictOptions] = None,
-) -> tuple[Dict[str, ValueType], OptimizerStateType]:
+) -> tuple[dict[str, ValueType], OptimizerStateType]:
 """
 Return the model state_dict and optimizers state_dict.
@@ -1148,8 +1137,8 @@ def get_state_dict(
 def _unflatten_model_state_dict(
 model: nn.Module,
-state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
+state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]],
-) -> Dict[str, ValueType]:
+) -> dict[str, ValueType]:
 if not state_dict:
 return {}
@@ -1161,8 +1150,8 @@ def _unflatten_model_state_dict(
 "same functionality.",
 FutureWarning,
 )
-cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
+cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)
-new_state_dict: Dict[str, ValueType] = {}
+new_state_dict: dict[str, ValueType] = {}
 for submodule, sub_state_dict in cast_state_dict.items():
 for name, m in model.named_modules():
 if m != submodule:
@@ -1176,12 +1165,12 @@ def _unflatten_model_state_dict(
 )
 return new_state_dict
 else:
-return cast(Dict[str, ValueType], state_dict)
+return cast(dict[str, ValueType], state_dict)
 def set_model_state_dict(
 model: nn.Module,
-model_state_dict: Dict[str, ValueType],
+model_state_dict: dict[str, ValueType],
 *,
 options: Optional[StateDictOptions] = None,
 ) -> _IncompatibleKeys:
@@ -1208,7 +1197,7 @@ def set_model_state_dict(
 :type model_state_dict: typing.Dict[str, ValueType]
 """
-model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
+model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
 model, model_state_dict
 )
 with _gc_context():
@@ -1261,7 +1250,7 @@ def set_state_dict(
 model: nn.Module,
 optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
 *,
-model_state_dict: Dict[str, ValueType],
+model_state_dict: dict[str, ValueType],
 optim_state_dict: OptimizerStateType,
 options: Optional[StateDictOptions] = None,
 ) -> _IncompatibleKeys:
@@ -1299,7 +1288,7 @@ def set_state_dict(
 :type optim_state_dict: typing.OptimizerStateType
 """
-model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
+model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
 model, model_state_dict
 )
 with _gc_context():
@@ -1363,7 +1352,7 @@ def _patch_model_state_dict(
 options=options,
 )
-def load_state_dict_call(state_dict: Dict[str, Any]):
+def load_state_dict_call(state_dict: dict[str, Any]):
 _load_state_dict_call(model_state_dict=state_dict)
 model.load_state_dict = load_state_dict_call
@@ -1422,7 +1411,7 @@ def _patch_optimizer_state_dict(
 options=options,
 )
-def load_state_dict_call(state_dict: Dict[str, Any]):
+def load_state_dict_call(state_dict: dict[str, Any]):
 _load_state_dict_call(optim_state_dict=state_dict)
 _patched_state_dict.add(state_dict_call)
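Aside from the annotation updates, one hunk in this file also swaps `@functools.lru_cache(maxsize=None)` for `@functools.cache` on `_get_fqns`. Since Python 3.9, `functools.cache` is simply an unbounded `lru_cache`, so the caching behavior is unchanged. A small illustration with a hypothetical function (not code from this diff):

import functools

@functools.cache  # equivalent to @functools.lru_cache(maxsize=None)
def fib(n: int) -> int:
    # each distinct argument is computed once and then served from the cache
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(64))  # fast, because intermediate results are memoized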

View File

@@ -2,7 +2,7 @@
 # mypy: allow-untyped-defs
 import os
 import warnings
-from typing import Any, cast, Dict, Optional, Set, Union
+from typing import Any, cast, Optional, Union
 from typing_extensions import deprecated
 import torch
@@ -27,7 +27,7 @@ __all__ = ["load_state_dict", "load"]
 category=FutureWarning,
 )
 def load_state_dict(
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
 storage_reader: StorageReader,
 process_group: Optional[dist.ProcessGroup] = None,
 coordinator_rank: int = 0,
@@ -51,7 +51,7 @@ def load_state_dict(
 @_dcp_method_logger(log_exceptions=True)
 @_api_bc_check
 def load(
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
 *,
 checkpoint_id: Union[str, os.PathLike, None] = None,
 storage_reader: Optional[StorageReader] = None,
@@ -190,7 +190,7 @@ def load(
 def _load_state_dict(
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
 storage_reader: StorageReader,
 process_group: Optional[dist.ProcessGroup] = None,
 coordinator_rank: int = 0,
@@ -241,12 +241,12 @@ def _load_state_dict(
 def _load_state_dict_from_keys(
-keys: Optional[Union[Set[str], str]] = None,
+keys: Optional[Union[set[str], str]] = None,
 *,
 checkpoint_id: Union[str, os.PathLike, None] = None,
 storage_reader: Optional[StorageReader] = None,
 process_group: Optional[dist.ProcessGroup] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
 """
 Load only the specified keys from the checkpoint, if no keys are specified, the entire
 checkpoint will be loaded. Note, this method completely loads the checkpoint into the
@@ -311,7 +311,7 @@ def _load_state_dict_from_keys(
 if isinstance(keys, str):
 keys = {keys}
-sd: Dict[str, Any] = {}
+sd: dict[str, Any] = {}
 _load_state_dict(
 state_dict=sd,
 storage_reader=storage_reader,

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, runtime_checkable, TypeVar
+from typing import Any, runtime_checkable, TypeVar
 from typing_extensions import Protocol
@@ -11,7 +11,7 @@ class Stateful(Protocol):
 Stateful protocol for objects that can be checkpointed and restored.
 """
-def state_dict(self) -> Dict[str, Any]:
+def state_dict(self) -> dict[str, Any]:
 """
 Objects should return their state_dict representation as a dictionary.
 The output of this function will be checkpointed, and later restored in
@@ -28,7 +28,7 @@ class Stateful(Protocol):
 ...
-def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+def load_state_dict(self, state_dict: dict[str, Any]) -> None:
 """
 Restore the object's state from the provided state_dict.

View File

@ -1,7 +1,7 @@
import abc import abc
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Optional, Union from typing import Any, Optional, Union
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.planner import ( from torch.distributed.checkpoint.planner import (
@ -86,7 +86,7 @@ class StorageWriter(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
""" """
Perform centralized planning of storage. Perform centralized planning of storage.
@ -105,7 +105,7 @@ class StorageWriter(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def write_data( def write_data(
self, plan: SavePlan, planner: SavePlanner self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]: ) -> Future[list[WriteResult]]:
""" """
Write all items from ``plan`` using ``planner`` to resolve the data. Write all items from ``plan`` using ``planner`` to resolve the data.
@ -127,7 +127,7 @@ class StorageWriter(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
""" """
Write the metadata and marks the current checkpoint as successful. Write the metadata and marks the current checkpoint as successful.
@ -236,7 +236,7 @@ class StorageReader(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
""" """
Perform centralized planning of storage loading. Perform centralized planning of storage loading.
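The list[SavePlan] and list[LoadPlan] signatures above are what custom storage layers now override. A small helper written in the same style, purely for illustration (drop_empty_plans is hypothetical, not part of the PR):

from torch.distributed.checkpoint.planner import SavePlan

def drop_empty_plans(plans: list[SavePlan]) -> list[SavePlan]:
    # Keep only plans that actually carry write items.
    return [plan for plan in plans if len(plan.items) > 0]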


@ -5,10 +5,11 @@ import io
import itertools import itertools
import os import os
import warnings import warnings
from collections.abc import Sequence
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
from pstats import Stats from pstats import Stats
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union from typing import Any, Callable, cast, Optional, TypeVar, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -31,20 +32,20 @@ R = TypeVar("R")
def _get_failure_dict( def _get_failure_dict(
results: List[Union[T, WRAPPED_EXCEPTION]] results: list[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]: ) -> dict[int, WRAPPED_EXCEPTION]:
return cast( return cast(
Dict[int, WRAPPED_EXCEPTION], dict[int, WRAPPED_EXCEPTION],
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
) )
def _all_gather_keys( def _all_gather_keys(
local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None local_dict: dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> List[Any]: ) -> list[Any]:
"""Gathers all keys, and returns them sorted.""" """Gathers all keys, and returns them sorted."""
keys = list(local_dict.keys()) keys = list(local_dict.keys())
gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] gathered_keys: list[list[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
dist.all_gather_object(gathered_keys, keys, group=group) dist.all_gather_object(gathered_keys, keys, group=group)
return sorted(set(itertools.chain.from_iterable(gathered_keys))) return sorted(set(itertools.chain.from_iterable(gathered_keys)))
@ -95,11 +96,11 @@ class _DistWrapper:
) )
return cast(T, object_list[0]) return cast(T, object_list[0])
def gather_object(self, object: T) -> Optional[List[T]]: def gather_object(self, object: T) -> Optional[list[T]]:
"""Implement functionality similar to c10d::gather_object but without distributed enabled.""" """Implement functionality similar to c10d::gather_object but without distributed enabled."""
if self.use_dist: if self.use_dist:
gather_objs = ( gather_objs = (
cast(List[T], [None] * dist.get_world_size(self.group)) cast(list[T], [None] * dist.get_world_size(self.group))
if self.is_coordinator if self.is_coordinator
else None else None
) )
@ -115,10 +116,10 @@ class _DistWrapper:
result = [object] result = [object]
return result return result
def all_gather_object(self, object: T) -> List[T]: def all_gather_object(self, object: T) -> list[T]:
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled.""" """Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
if self.use_dist: if self.use_dist:
gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) gather_objs = cast(list[T], [None] * dist.get_world_size(self.group))
dist.all_gather_object( dist.all_gather_object(
object_list=gather_objs, obj=object, group=self.group object_list=gather_objs, obj=object, group=self.group
@ -127,10 +128,10 @@ class _DistWrapper:
gather_objs = [object] gather_objs = [object]
return gather_objs return gather_objs
def scatter_object(self, object_list: Optional[List[T]]) -> T: def scatter_object(self, object_list: Optional[list[T]]) -> T:
"""Implement functionality similar to c10d::scatter_object but without distributed enabled.""" """Implement functionality similar to c10d::scatter_object but without distributed enabled."""
if self.use_dist: if self.use_dist:
gather_result = cast(List[T], [None]) gather_result = cast(list[T], [None])
dist.scatter_object_list( dist.scatter_object_list(
scatter_object_output_list=gather_result, scatter_object_output_list=gather_result,
scatter_object_input_list=object_list if self.is_coordinator else None, scatter_object_input_list=object_list if self.is_coordinator else None,
@ -148,7 +149,7 @@ class _DistWrapper:
self, self,
step: str, step: str,
map_fun: Callable[[], T], map_fun: Callable[[], T],
reduce_fun: Callable[[List[T]], List[R]], reduce_fun: Callable[[list[T]], list[R]],
) -> R: ) -> R:
""" """
Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
@ -166,7 +167,7 @@ class _DistWrapper:
local_data = _wrap_exception(e) local_data = _wrap_exception(e)
all_data = self.gather_object(local_data) all_data = self.gather_object(local_data)
all_results: Optional[List[Union[R, CheckpointException]]] = None all_results: Optional[list[Union[R, CheckpointException]]] = None
if self.is_coordinator: if self.is_coordinator:
assert all_data is not None assert all_data is not None
node_failures = _get_failure_dict(all_data) node_failures = _get_failure_dict(all_data)
@ -175,8 +176,8 @@ class _DistWrapper:
try: try:
# N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
all_results = cast( all_results = cast(
List[Union[R, CheckpointException]], list[Union[R, CheckpointException]],
reduce_fun(cast(List[T], all_data)), reduce_fun(cast(list[T], all_data)),
) )
except BaseException as e: except BaseException as e:
node_failures[self.rank] = _wrap_exception(e) node_failures[self.rank] = _wrap_exception(e)
@ -195,7 +196,7 @@ class _DistWrapper:
self, self,
step: str, step: str,
map_fun: Callable[[], T], map_fun: Callable[[], T],
reduce_fun: Callable[[List[T]], R], reduce_fun: Callable[[list[T]], R],
) -> R: ) -> R:
""" """
Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
@ -219,7 +220,7 @@ class _DistWrapper:
node_failures = _get_failure_dict(all_data) node_failures = _get_failure_dict(all_data)
if len(node_failures) == 0: if len(node_failures) == 0:
try: try:
result = reduce_fun(cast(List[T], all_data)) result = reduce_fun(cast(list[T], all_data))
except BaseException as e: except BaseException as e:
node_failures[self.rank] = _wrap_exception(e) node_failures[self.rank] = _wrap_exception(e)
@ -235,7 +236,7 @@ class _DistWrapper:
self, self,
step: str, step: str,
map_fun: Callable[[], T], map_fun: Callable[[], T],
) -> List[T]: ) -> list[T]:
""" """
Compute a value on each rank, then all_gather them. Compute a value on each rank, then all_gather them.
@ -254,7 +255,7 @@ class _DistWrapper:
node_failures = _get_failure_dict(all_results) node_failures = _get_failure_dict(all_results)
if len(node_failures) > 0: if len(node_failures) > 0:
raise CheckpointException(step, node_failures) raise CheckpointException(step, node_failures)
return cast(List[T], all_results) return cast(list[T], all_results)
def broadcast( def broadcast(
self, self,
@ -331,11 +332,11 @@ def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) ->
return obj return obj
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]: def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]:
return [i_a + i_b for i_a, i_b in zip(a, b)] return [i_a + i_b for i_a, i_b in zip(a, b)]
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]: def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:
return [i_a - i_b for i_a, i_b in zip(a, b)] return [i_a - i_b for i_a, i_b in zip(a, b)]
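This file also shows the second half of the migration: abstract container types such as Sequence now come from collections.abc rather than typing, while concrete return types use the built-in generics. Reproducing the two helpers above verbatim makes the split visible:

from collections.abc import Sequence

def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]:
    return [i_a + i_b for i_a, i_b in zip(a, b)]

def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:
    return [i_a - i_b for i_a, i_b in zip(a, b)]

assert _element_wise_add([0, 2, 4], (1, 1, 1)) == [1, 3, 5]  # Sequence accepts tuples too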


@ -18,7 +18,7 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Optional, Union
import torch.distributed.elastic.rendezvous as rdzv import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util import torch.distributed.elastic.utils.store as store_util
@ -80,7 +80,7 @@ class WorkerSpec:
fn: Optional[Callable] = None fn: Optional[Callable] = None
# TODO @kiuk - make entrypoint a required field # TODO @kiuk - make entrypoint a required field
entrypoint: Union[Callable, str, None] = None entrypoint: Union[Callable, str, None] = None
args: Tuple = () args: tuple = ()
max_restarts: int = 3 max_restarts: int = 3
monitor_interval: float = 0.1 monitor_interval: float = 0.1
master_port: Optional[int] = None master_port: Optional[int] = None
@ -320,7 +320,7 @@ class _RoleInstanceInfo:
return -1 return -1
@staticmethod @staticmethod
def find_role_boundaries(roles_infos: List, role: str) -> tuple[int, int]: def find_role_boundaries(roles_infos: list, role: str) -> tuple[int, int]:
start_idx, end_idx = -1, -1 start_idx, end_idx = -1, -1
for idx, role_info in enumerate(roles_infos): for idx, role_info in enumerate(roles_infos):
if role_info.role == role: if role_info.role == role:
@ -357,8 +357,8 @@ class RunResult:
""" """
state: WorkerState state: WorkerState
return_values: Dict[int, Any] = field(default_factory=dict) return_values: dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict) failures: dict[int, ProcessFailure] = field(default_factory=dict)
def is_failed(self) -> bool: def is_failed(self) -> bool:
return self.state == WorkerState.FAILED return self.state == WorkerState.FAILED
@ -448,7 +448,7 @@ class SimpleElasticAgent(ElasticAgent):
return self._worker_group return self._worker_group
@abc.abstractmethod @abc.abstractmethod
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
r"""Start ``worker_group.spec.local_world_size`` number of workers. r"""Start ``worker_group.spec.local_world_size`` number of workers.
This is done according to the worker spec for the worker group. This is done according to the worker spec for the worker group.
@ -554,7 +554,7 @@ class SimpleElasticAgent(ElasticAgent):
@prof @prof
def _assign_worker_ranks( def _assign_worker_ranks(
self, store, group_rank: int, group_world_size: int, spec: WorkerSpec self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List[Worker]: ) -> list[Worker]:
"""Determine proper ranks for worker processes. """Determine proper ranks for worker processes.
Fast Path: when all workers have the same role and world size. We calculate Fast Path: when all workers have the same role and world size. We calculate
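The RunResult fields above now use built-in dict generics; construction is unchanged. A hedged usage sketch, assuming the public module path below:

from torch.distributed.elastic.agent.server.api import RunResult, WorkerState

result = RunResult(
    state=WorkerState.SUCCEEDED,
    return_values={0: "rank0 output", 1: "rank1 output"},
)
assert not result.is_failed()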


@ -15,7 +15,7 @@ import socket
import time import time
import uuid import uuid
from string import Template from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING from typing import Any, Optional, TYPE_CHECKING
import torch.distributed.elastic.timer as timer import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events from torch.distributed.elastic import events
@ -163,7 +163,7 @@ class LocalElasticAgent(SimpleElasticAgent):
self._logs_specs = logs_specs self._logs_specs = logs_specs
self._health_check_server: Optional[HealthCheckServer] = None self._health_check_server: Optional[HealthCheckServer] = None
def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
watchdog_enabled = os.getenv(enable_watchdog_env_name) watchdog_enabled = os.getenv(enable_watchdog_env_name)
watchdog_file_env_name = TORCHELASTIC_TIMER_FILE watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
@ -256,7 +256,7 @@ class LocalElasticAgent(SimpleElasticAgent):
md["signal"] = str(request.signal) md["signal"] = str(request.signal)
md_str = json.dumps(md) md_str = json.dumps(md)
state = "RUNNING" state = "RUNNING"
metadata: Dict[str, EventMetadataValue] = { metadata: dict[str, EventMetadataValue] = {
"run_id": spec.rdzv_handler.get_run_id(), "run_id": spec.rdzv_handler.get_run_id(),
"global_rank": None, "global_rank": None,
"group_rank": wg.group_rank, "group_rank": wg.group_rank,
@ -288,7 +288,7 @@ class LocalElasticAgent(SimpleElasticAgent):
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`. # `torch.distributed.elastic.metrics.prof`.
@prof @prof
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
spec = worker_group.spec spec = worker_group.spec
store = worker_group.store store = worker_group.store
assert store is not None assert store is not None
@ -297,9 +297,9 @@ class LocalElasticAgent(SimpleElasticAgent):
use_agent_store: bool = spec.rdzv_handler.use_agent_store use_agent_store: bool = spec.rdzv_handler.use_agent_store
logger.info("use_agent_store: %s", use_agent_store) logger.info("use_agent_store: %s", use_agent_store)
args: Dict[int, Tuple] = {} args: dict[int, tuple] = {}
envs: Dict[int, Dict[str, str]] = {} envs: dict[int, dict[str, str]] = {}
log_line_prefixes: Optional[Dict[int, str]] = ( log_line_prefixes: Optional[dict[int, str]] = (
{} if self._log_line_prefix_template else None {} if self._log_line_prefix_template else None
) )
for worker in worker_group.workers: for worker in worker_group.workers:


@ -1,6 +1,6 @@
import os import os
from collections.abc import Generator
from contextlib import contextmanager, ExitStack from contextlib import contextmanager, ExitStack
from typing import Generator
from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.elastic.multiprocessing.errors import record


@ -37,7 +37,7 @@ from .api import ( # noqa: F401
) )
_events_loggers: Dict[str, logging.Logger] = {} _events_loggers: dict[str, logging.Logger] = {}
def _get_or_create_logger(destination: str = "null") -> logging.Logger: def _get_or_create_logger(destination: str = "null") -> logging.Logger:


@ -10,7 +10,7 @@
import json import json
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from enum import Enum from enum import Enum
from typing import Dict, Optional, Union from typing import Optional, Union
__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] __all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
@ -42,7 +42,7 @@ class Event:
name: str name: str
source: EventSource source: EventSource
timestamp: int = 0 timestamp: int = 0
metadata: Dict[str, EventMetadataValue] = field(default_factory=dict) metadata: dict[str, EventMetadataValue] = field(default_factory=dict)
def __str__(self): def __str__(self):
return self.serialize() return self.serialize()
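Event's metadata field is now a plain dict keyed by string. A hedged construction example (EventSource.AGENT is assumed here to be one of the enum's members):

from torch.distributed.elastic.events import Event, EventSource

evt = Event(
    name="worker.start",
    source=EventSource.AGENT,
    metadata={"run_id": "demo", "global_rank": 0},
)
print(evt)  # __str__ delegates to serialize()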


@ -7,10 +7,9 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
from typing import Dict
_log_handlers: Dict[str, logging.Handler] = { _log_handlers: dict[str, logging.Handler] = {
"console": logging.StreamHandler(), "console": logging.StreamHandler(),
"dynamic_rendezvous": logging.NullHandler(), "dynamic_rendezvous": logging.NullHandler(),
"null": logging.NullHandler(), "null": logging.NullHandler(),


@ -11,7 +11,7 @@ import abc
import time import time
from collections import namedtuple from collections import namedtuple
from functools import wraps from functools import wraps
from typing import Dict, Optional from typing import Optional
from typing_extensions import deprecated from typing_extensions import deprecated
@ -37,7 +37,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
class MetricsConfig: class MetricsConfig:
__slots__ = ["params"] __slots__ = ["params"]
def __init__(self, params: Optional[Dict[str, str]] = None): def __init__(self, params: Optional[dict[str, str]] = None):
self.params = params self.params = params
if self.params is None: if self.params is None:
self.params = {} self.params = {}
@ -72,7 +72,7 @@ class MetricStream:
) )
_metrics_map: Dict[str, MetricHandler] = {} _metrics_map: dict[str, MetricHandler] = {}
_default_metrics_handler: MetricHandler = NullMetricHandler() _default_metrics_handler: MetricHandler = NullMetricHandler()


@ -100,10 +100,10 @@ __all__ = [
def start_processes( def start_processes(
name: str, name: str,
entrypoint: Union[Callable, str], entrypoint: Union[Callable, str],
args: Dict[int, Tuple], args: dict[int, tuple],
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs, logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None, log_line_prefixes: Optional[dict[int, str]] = None,
start_method: str = "spawn", start_method: str = "spawn",
) -> PContext: ) -> PContext:
""" """


@ -24,7 +24,7 @@ from dataclasses import dataclass, field
from enum import IntFlag from enum import IntFlag
from multiprocessing import synchronize from multiprocessing import synchronize
from types import FrameType from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union from typing import Any, Callable, Optional, Union
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@ -100,7 +100,7 @@ def _get_default_signal() -> signal.Signals:
return signal.SIGTERM return signal.SIGTERM
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str): def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys()) actual_keys = set(d.keys())
expected_keys = set(range(nprocs)) expected_keys = set(range(nprocs))
@ -122,7 +122,7 @@ class Std(IntFlag):
ALL = OUT | ERR ALL = OUT | ERR
@classmethod @classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]: def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]:
""" """
Example: Example:
:: ::
@ -143,7 +143,7 @@ class Std(IntFlag):
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0) if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
return to_std(vm) return to_std(vm)
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2) elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
d: Dict[int, Std] = {} d: dict[int, Std] = {}
for m in vm.split(","): for m in vm.split(","):
i, v = m.split(":") i, v = m.split(":")
d[int(i)] = to_std(v) d[int(i)] = to_std(v)
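With the conventional flag values (NONE=0, OUT=1, ERR=2, ALL=3), the two branches above behave roughly as follows (hedged sketch):

from torch.distributed.elastic.multiprocessing import Std

assert Std.from_str("3") == Std.ALL                         # single value applied to all ranks
assert Std.from_str("0:1,1:2") == {0: Std.OUT, 1: Std.ERR}  # per-rank mapping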
@ -155,8 +155,8 @@ class Std(IntFlag):
def to_map( def to_map(
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int val_or_map: Union[Std, dict[int, Std]], local_world_size: int
) -> Dict[int, Std]: ) -> dict[int, Std]:
""" """
Certain APIs take redirect settings either as a single value (e.g. apply to all Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This method is a convenience local ranks) or as an explicit user-provided mapping. This method is a convenience
@ -184,11 +184,11 @@ class LogsDest:
For each log type, holds mapping of local rank ids to file paths. For each log type, holds mapping of local rank ids to file paths.
""" """
stdouts: Dict[int, str] = field(default_factory=dict) stdouts: dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict) stderrs: dict[int, str] = field(default_factory=dict)
tee_stdouts: Dict[int, str] = field(default_factory=dict) tee_stdouts: dict[int, str] = field(default_factory=dict)
tee_stderrs: Dict[int, str] = field(default_factory=dict) tee_stderrs: dict[int, str] = field(default_factory=dict)
error_files: Dict[int, str] = field(default_factory=dict) error_files: dict[int, str] = field(default_factory=dict)
class LogsSpecs(ABC): class LogsSpecs(ABC):
@ -211,9 +211,9 @@ class LogsSpecs(ABC):
def __init__( def __init__(
self, self,
log_dir: Optional[str] = None, log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE, redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE, tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None, local_ranks_filter: Optional[set[int]] = None,
) -> None: ) -> None:
self._root_log_dir = log_dir self._root_log_dir = log_dir
self._redirects = redirects self._redirects = redirects
@ -223,7 +223,7 @@ class LogsSpecs(ABC):
@abstractmethod @abstractmethod
def reify( def reify(
self, self,
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
) -> LogsDest: ) -> LogsDest:
""" """
Given the environment variables, builds destination of log files for each of the local ranks. Given the environment variables, builds destination of log files for each of the local ranks.
@ -249,9 +249,9 @@ class DefaultLogsSpecs(LogsSpecs):
def __init__( def __init__(
self, self,
log_dir: Optional[str] = None, log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE, redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE, tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None, local_ranks_filter: Optional[set[int]] = None,
) -> None: ) -> None:
if log_dir != os.devnull: if log_dir != os.devnull:
if not log_dir: if not log_dir:
@ -278,7 +278,7 @@ class DefaultLogsSpecs(LogsSpecs):
def reify( def reify(
self, self,
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
) -> LogsDest: ) -> LogsDest:
""" """
Uses following scheme to build log destination paths: Uses following scheme to build log destination paths:
@ -331,8 +331,8 @@ class DefaultLogsSpecs(LogsSpecs):
SYS_STREAM = "" # special case to indicate to output to console SYS_STREAM = "" # special case to indicate to output to console
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
tee_stdouts: Dict[int, str] = {} tee_stdouts: dict[int, str] = {}
tee_stderrs: Dict[int, str] = {} tee_stderrs: dict[int, str] = {}
error_files = {} error_files = {}
for local_rank in range(nprocs): for local_rank in range(nprocs):
@ -414,10 +414,10 @@ class RunProcsResult:
""" """
return_values: Dict[int, Any] = field(default_factory=dict) return_values: dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict) failures: dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: Dict[int, str] = field(default_factory=dict) stdouts: dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict) stderrs: dict[int, str] = field(default_factory=dict)
def is_failed(self) -> bool: def is_failed(self) -> bool:
return len(self.failures) > 0 return len(self.failures) > 0
@ -438,10 +438,10 @@ class PContext(abc.ABC):
self, self,
name: str, name: str,
entrypoint: Union[Callable, str], entrypoint: Union[Callable, str],
args: Dict[int, Tuple], args: dict[int, tuple],
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs, logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None, log_line_prefixes: Optional[dict[int, str]] = None,
): ):
self.name = name self.name = name
# validate that all mappings have the same number of keys and # validate that all mappings have the same number of keys and
@ -544,7 +544,7 @@ class PContext(abc.ABC):
return None return None
@abc.abstractmethod @abc.abstractmethod
def pids(self) -> Dict[int, int]: def pids(self) -> dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks.""" """Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError raise NotImplementedError
@ -587,11 +587,11 @@ def get_std_cm(std_rd: str, redirect_fn):
def _wrap( def _wrap(
local_rank: int, local_rank: int,
fn: Callable, fn: Callable,
args: Dict[int, Tuple], args: dict[int, tuple],
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None) stdout_redirects: dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None) stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue], ret_vals: dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event, queue_finished_reading_event: synchronize.Event,
) -> None: ) -> None:
# get the per-rank params up front so we fail fast if no mapping is found # get the per-rank params up front so we fail fast if no mapping is found
@ -621,11 +621,11 @@ class MultiprocessContext(PContext):
self, self,
name: str, name: str,
entrypoint: Callable, entrypoint: Callable,
args: Dict[int, Tuple], args: dict[int, tuple],
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
start_method: str, start_method: str,
logs_specs: LogsSpecs, logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None, log_line_prefixes: Optional[dict[int, str]] = None,
): ):
super().__init__( super().__init__(
name, name,
@ -644,7 +644,7 @@ class MultiprocessContext(PContext):
} }
# see comments in ``join()`` for what this is # see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {} self._return_values: dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None self._pc: Optional[mp.ProcessContext] = None
# Note: set method should ONLY be invoked for the use case when all processes finished # Note: set method should ONLY be invoked for the use case when all processes finished
# successfully. If any process died on event.wait() calling set() method will deadlock. # successfully. If any process died on event.wait() calling set() method will deadlock.
@ -755,7 +755,7 @@ class MultiprocessContext(PContext):
stderrs=self.stderrs, stderrs=self.stderrs,
) )
def pids(self) -> Dict[int, int]: def pids(self) -> dict[int, int]:
assert self._pc is not None # assertion for mypy type checking assert self._pc is not None # assertion for mypy type checking
return dict(enumerate(self._pc.pids())) return dict(enumerate(self._pc.pids()))
@ -803,10 +803,10 @@ class SubprocessContext(PContext):
self, self,
name: str, name: str,
entrypoint: str, entrypoint: str,
args: Dict[int, Tuple], args: dict[int, tuple],
envs: Dict[int, Dict[str, str]], envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs, logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None, log_line_prefixes: Optional[dict[int, str]] = None,
): ):
super().__init__( super().__init__(
name, name,
@ -818,9 +818,9 @@ class SubprocessContext(PContext):
) )
# state vector; _vdone[local_rank] -> is local_rank finished or not # state vector; _vdone[local_rank] -> is local_rank finished or not
self._running_local_ranks: Set[int] = set(range(self.nprocs)) self._running_local_ranks: set[int] = set(range(self.nprocs))
self._failures: Dict[int, ProcessFailure] = {} self._failures: dict[int, ProcessFailure] = {}
self.subprocess_handlers: Dict[int, SubprocessHandler] = {} self.subprocess_handlers: dict[int, SubprocessHandler] = {}
def _start(self): def _start(self):
if self.subprocess_handlers: if self.subprocess_handlers:
@ -884,7 +884,7 @@ class SubprocessContext(PContext):
else: # there are no failures and procs still running else: # there are no failures and procs still running
return None return None
def pids(self) -> Dict[int, int]: def pids(self) -> dict[int, int]:
return { return {
local_rank: sh.proc.pid local_rank: sh.proc.pid
for local_rank, sh in self.subprocess_handlers.items() for local_rank, sh in self.subprocess_handlers.items()


@ -78,7 +78,7 @@ __all__ = [
logger = get_logger(__name__) logger = get_logger(__name__)
JSON = Dict JSON = dict
_EMPTY_ERROR_DATA = {"message": "<NONE>"} _EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>" _NOT_AVAILABLE = "<N/A>"
@ -143,7 +143,7 @@ class ProcessFailure:
else: else:
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html" self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"
def _get_error_data(self, error_file_data: Dict[str, Any]) -> tuple[str, int]: def _get_error_data(self, error_file_data: dict[str, Any]) -> tuple[str, int]:
message = error_file_data["message"] message = error_file_data["message"]
if isinstance(message, str): if isinstance(message, str):
timestamp = int(error_file_data.get("timestamp", 0)) timestamp = int(error_file_data.get("timestamp", 0))
@ -231,7 +231,7 @@ class ChildFailedError(Exception):
of trainer 1's error file to the scheduler's init process. of trainer 1's error file to the scheduler's init process.
""" """
def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]): def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]):
self.name = name self.name = name
self.failures = failures self.failures = failures
assert ( assert (
@ -248,7 +248,7 @@ class ChildFailedError(Exception):
root_rank, _root_failure = self.get_first_failure() root_rank, _root_failure = self.get_first_failure()
root_failure_fmt: str = "" root_failure_fmt: str = ""
other_failures_fmt: List[str] = [] other_failures_fmt: list[str] = []
width = len(title) width = len(title)
for idx, (rank, failure) in enumerate(self.failures.items()): for idx, (rank, failure) in enumerate(self.failures.items()):
fmt, w = self._format_failure(idx, rank, failure) fmt, w = self._format_failure(idx, rank, failure)
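ChildFailedError is populated from per-rank error files; the documented way those files get written is to wrap the worker entrypoint with the record decorator from this module:

from torch.distributed.elastic.multiprocessing.errors import record

@record
def main() -> None:
    raise RuntimeError("boom")  # uncaught errors are captured into the error file

if __name__ == "__main__":
    main()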


@ -13,7 +13,7 @@ import os
import time import time
import traceback import traceback
import warnings import warnings
from typing import Any, Dict, Optional from typing import Any, Optional
__all__ = ["ErrorHandler"] __all__ = ["ErrorHandler"]
@ -86,7 +86,7 @@ class ErrorHandler:
def override_error_code_in_rootcause_data( def override_error_code_in_rootcause_data(
self, self,
rootcause_error_file: str, rootcause_error_file: str,
rootcause_error: Dict[str, Any], rootcause_error: dict[str, Any],
error_code: int = 0, error_code: int = 0,
): ):
"""Modify the rootcause_error read from the file, to correctly set the exit code.""" """Modify the rootcause_error read from the file, to correctly set the exit code."""


@ -3,7 +3,6 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler, SubprocessHandler,
@ -15,8 +14,8 @@ __all__ = ["get_subprocess_handler"]
def get_subprocess_handler( def get_subprocess_handler(
entrypoint: str, entrypoint: str,
args: Tuple, args: tuple,
env: Dict[str, str], env: dict[str, str],
stdout: str, stdout: str,
stderr: str, stderr: str,
local_rank_id: int, local_rank_id: int,


@ -9,7 +9,7 @@ import os
import signal import signal
import subprocess import subprocess
import sys import sys
from typing import Any, Dict, Optional, Tuple from typing import Any, Optional
__all__ = ["SubprocessHandler"] __all__ = ["SubprocessHandler"]
@ -34,8 +34,8 @@ class SubprocessHandler:
def __init__( def __init__(
self, self,
entrypoint: str, entrypoint: str,
args: Tuple, args: tuple,
env: Dict[str, str], env: dict[str, str],
stdout: Optional[str], stdout: Optional[str],
stderr: Optional[str], stderr: Optional[str],
local_rank_id: int, local_rank_id: int,
@ -50,8 +50,8 @@ class SubprocessHandler:
self.local_rank_id = local_rank_id self.local_rank_id = local_rank_id
self.proc: subprocess.Popen = self._popen(args_str, env_vars) self.proc: subprocess.Popen = self._popen(args_str, env_vars)
def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen: def _popen(self, args: tuple, env: dict[str, str]) -> subprocess.Popen:
kwargs: Dict[str, Any] = {} kwargs: dict[str, Any] = {}
if not IS_WINDOWS: if not IS_WINDOWS:
kwargs["start_new_session"] = True kwargs["start_new_session"] = True
return subprocess.Popen( return subprocess.Popen(


@ -12,7 +12,7 @@ import os
import time import time
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event from threading import Event
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING from typing import Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
@ -89,9 +89,9 @@ class TailLog:
def __init__( def __init__(
self, self,
name: str, name: str,
log_files: Dict[int, str], log_files: dict[int, str],
dst: TextIO, dst: TextIO,
log_line_prefixes: Optional[Dict[int, str]] = None, log_line_prefixes: Optional[dict[int, str]] = None,
interval_sec: float = 0.1, interval_sec: float = 0.1,
): ):
n = len(log_files) n = len(log_files)
@ -106,10 +106,10 @@ class TailLog:
self._dst = dst self._dst = dst
self._log_files = log_files self._log_files = log_files
self._log_line_prefixes = log_line_prefixes self._log_line_prefixes = log_line_prefixes
self._finished_events: Dict[int, Event] = { self._finished_events: dict[int, Event] = {
local_rank: Event() for local_rank in log_files.keys() local_rank: Event() for local_rank in log_files.keys()
} }
self._futs: List[Future] = [] self._futs: list[Future] = []
self._interval_sec = interval_sec self._interval_sec = interval_sec
self._stopped = False self._stopped = False


@ -8,7 +8,7 @@
import socket import socket
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional from typing import Any, Callable, ClassVar, Optional
from torch.distributed import Store from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port from torch.distributed.elastic.utils.distributed import get_free_port
@ -325,7 +325,7 @@ RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry: class RendezvousHandlerRegistry:
"""Represent a registry of :py:class:`RendezvousHandler` backends.""" """Represent a registry of :py:class:`RendezvousHandler` backends."""
_registry: Dict[str, RendezvousHandlerCreator] _registry: dict[str, RendezvousHandlerCreator]
def __init__(self) -> None: def __init__(self) -> None:
self._registry = {} self._registry = {}


@ -17,7 +17,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set from typing import Any, Callable, Optional
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import Store from torch.distributed import Store
@ -298,10 +298,10 @@ class _RendezvousState:
complete: bool complete: bool
deadline: Optional[datetime] deadline: Optional[datetime]
closed: bool closed: bool
participants: Dict[_NodeDesc, int] participants: dict[_NodeDesc, int]
wait_list: Set[_NodeDesc] wait_list: set[_NodeDesc]
redundancy_list: Set[_NodeDesc] redundancy_list: set[_NodeDesc]
last_heartbeats: Dict[_NodeDesc, datetime] last_heartbeats: dict[_NodeDesc, datetime]
def __init__(self) -> None: def __init__(self) -> None:
self.round = 0 self.round = 0
@ -377,7 +377,7 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
_token: Token _token: Token
_dirty: bool _dirty: bool
_last_sync_time: float _last_sync_time: float
_dead_nodes: List[_NodeDesc] _dead_nodes: list[_NodeDesc]
def __init__( def __init__(
self, self,


@ -13,20 +13,20 @@ import time
import weakref import weakref
from datetime import timedelta from datetime import timedelta
from threading import Event, Thread from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Optional, Union
__all__ = ["parse_rendezvous_endpoint"] __all__ = ["parse_rendezvous_endpoint"]
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]: def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
"""Extract key-value pairs from a rendezvous configuration string. """Extract key-value pairs from a rendezvous configuration string.
Args: Args:
config_str: config_str:
A string in format <key1>=<value1>,...,<keyN>=<valueN>. A string in format <key1>=<value1>,...,<keyN>=<valueN>.
""" """
config: Dict[str, str] = {} config: dict[str, str] = {}
config_str = config_str.strip() config_str = config_str.strip()
if not config_str: if not config_str:
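Given the <key>=<value> format described in the docstring, the helper above should behave as follows (it is a private function; the call is shown only to illustrate the dict[str, str] return type):

from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config

cfg = _parse_rendezvous_config("join_timeout=300,last_call_timeout=30")
assert cfg == {"join_timeout": "300", "last_call_timeout": "30"}  # values stay strings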
@ -196,7 +196,7 @@ class _PeriodicTimer:
interval: float interval: float
function: Callable[..., None] function: Callable[..., None]
args: tuple[Any, ...] args: tuple[Any, ...]
kwargs: Dict[str, Any] kwargs: dict[str, Any]
stop_event: Event stop_event: Event
_name: Optional[str] _name: Optional[str]


@ -10,7 +10,7 @@ import threading
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from inspect import getframeinfo, stack from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set from typing import Any, Optional
__all__ = [ __all__ = [
@ -103,7 +103,7 @@ class RequestQueue(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def get(self, size: int, timeout: float) -> List[TimerRequest]: def get(self, size: int, timeout: float) -> list[TimerRequest]:
""" """
Gets up to ``size`` number of timer requests in a blocking fashion Gets up to ``size`` number of timer requests in a blocking fashion
(no more than ``timeout`` seconds). (no more than ``timeout`` seconds).
@ -134,7 +134,7 @@ class TimerServer(abc.ABC):
self._stop_signaled = False self._stop_signaled = False
@abc.abstractmethod @abc.abstractmethod
def register_timers(self, timer_requests: List[TimerRequest]) -> None: def register_timers(self, timer_requests: list[TimerRequest]) -> None:
""" """
Processes the incoming timer requests and registers them with the server. Processes the incoming timer requests and registers them with the server.
The timer request can either be an acquire-timer or a release-timer request. The timer request can either be an acquire-timer or a release-timer request.
@ -143,13 +143,13 @@ class TimerServer(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def clear_timers(self, worker_ids: Set[Any]) -> None: def clear_timers(self, worker_ids: set[Any]) -> None:
""" """
Clears all timers for the given ``worker_ids``. Clears all timers for the given ``worker_ids``.
""" """
@abc.abstractmethod @abc.abstractmethod
def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]: def get_expired_timers(self, deadline: float) -> dict[str, list[TimerRequest]]:
""" """
Returns all expired timers for each worker_id. An expired timer Returns all expired timers for each worker_id. An expired timer
is a timer for which the expiration_time is less than or equal to is a timer for which the expiration_time is less than or equal to


@ -7,7 +7,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Dict, List
from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.elastic.utils.logging import get_logger
@ -19,7 +18,7 @@ __all__ = ["log_debug_info_for_expired_timers"]
def log_debug_info_for_expired_timers( def log_debug_info_for_expired_timers(
run_id: str, run_id: str,
expired_timers: Dict[int, List[str]], expired_timers: dict[int, list[str]],
): ):
if expired_timers: if expired_timers:
logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers) logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)


@ -13,7 +13,7 @@ import signal
import sys import sys
import threading import threading
import time import time
from typing import Callable, Dict, List, Optional, Set from typing import Callable, Optional
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
from torch.distributed.elastic.timer.debug_info_logging import ( from torch.distributed.elastic.timer.debug_info_logging import (
@ -201,7 +201,7 @@ class FileTimerServer:
self._run_id = run_id self._run_id = run_id
self._max_interval = max_interval self._max_interval = max_interval
self._daemon = daemon self._daemon = daemon
self._timers: Dict[tuple[int, str], FileTimerRequest] = {} self._timers: dict[tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None self._watchdog_thread: Optional[threading.Thread] = None
@ -354,12 +354,12 @@ class FileTimerServer:
self.clear_timers(reaped_worker_pids) self.clear_timers(reaped_worker_pids)
def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]: def _get_scopes(self, timer_requests: list[FileTimerRequest]) -> list[str]:
return [r.scope_id for r in timer_requests] return [r.scope_id for r in timer_requests]
def _get_requests( def _get_requests(
self, fd: io.TextIOWrapper, max_interval: float self, fd: io.TextIOWrapper, max_interval: float
) -> List[FileTimerRequest]: ) -> list[FileTimerRequest]:
start = time.time() start = time.time()
requests = [] requests = []
while not self._stop_signaled or self._run_once: while not self._stop_signaled or self._run_once:
@ -394,7 +394,7 @@ class FileTimerServer:
break break
return requests return requests
def register_timers(self, timer_requests: List[FileTimerRequest]) -> None: def register_timers(self, timer_requests: list[FileTimerRequest]) -> None:
for request in timer_requests: for request in timer_requests:
pid = request.worker_pid pid = request.worker_pid
scope_id = request.scope_id scope_id = request.scope_id
@ -409,14 +409,14 @@ class FileTimerServer:
else: else:
self._timers[key] = request self._timers[key] = request
def clear_timers(self, worker_pids: Set[int]) -> None: def clear_timers(self, worker_pids: set[int]) -> None:
for pid, scope_id in list(self._timers.keys()): for pid, scope_id in list(self._timers.keys()):
if pid in worker_pids or not FileTimerServer.is_process_running(pid): if pid in worker_pids or not FileTimerServer.is_process_running(pid):
del self._timers[(pid, scope_id)] del self._timers[(pid, scope_id)]
def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]: def get_expired_timers(self, deadline: float) -> dict[int, list[FileTimerRequest]]:
# pid -> [timer_requests...] # pid -> [timer_requests...]
expired_timers: Dict[int, List[FileTimerRequest]] = {} expired_timers: dict[int, list[FileTimerRequest]] = {}
for request in self._timers.values(): for request in self._timers.values():
if request.expiration_time <= deadline: if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_pid, []) expired_scopes = expired_timers.setdefault(request.worker_pid, [])


@ -10,7 +10,7 @@ import os
import signal import signal
import time import time
from queue import Empty from queue import Empty
from typing import Any, Dict, List, Set from typing import Any
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
@ -56,7 +56,7 @@ class MultiprocessingRequestQueue(RequestQueue):
def size(self) -> int: def size(self) -> int:
return self._mp_queue.qsize() return self._mp_queue.qsize()
def get(self, size, timeout: float) -> List[TimerRequest]: def get(self, size, timeout: float) -> list[TimerRequest]:
requests = [] requests = []
wait = timeout wait = timeout
for _ in range(0, size): for _ in range(0, size):
@ -88,9 +88,9 @@ class LocalTimerServer(TimerServer):
self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
): ):
super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
self._timers: Dict[tuple[Any, str], TimerRequest] = {} self._timers: dict[tuple[Any, str], TimerRequest] = {}
def register_timers(self, timer_requests: List[TimerRequest]) -> None: def register_timers(self, timer_requests: list[TimerRequest]) -> None:
for request in timer_requests: for request in timer_requests:
pid = request.worker_id pid = request.worker_id
scope_id = request.scope_id scope_id = request.scope_id
@ -102,14 +102,14 @@ class LocalTimerServer(TimerServer):
else: else:
self._timers[(pid, scope_id)] = request self._timers[(pid, scope_id)] = request
def clear_timers(self, worker_ids: Set[int]) -> None: def clear_timers(self, worker_ids: set[int]) -> None:
for pid, scope_id in list(self._timers.keys()): for pid, scope_id in list(self._timers.keys()):
if pid in worker_ids: if pid in worker_ids:
self._timers.pop((pid, scope_id)) self._timers.pop((pid, scope_id))
def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]: def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]:
# pid -> [timer_requests...] # pid -> [timer_requests...]
expired_timers: Dict[Any, List[TimerRequest]] = {} expired_timers: dict[Any, list[TimerRequest]] = {}
for request in self._timers.values(): for request in self._timers.values():
if request.expiration_time <= deadline: if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_id, []) expired_scopes = expired_timers.setdefault(request.worker_id, [])
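For context on how these pieces fit together, a hedged sketch of the documented local-timer flow (queue creation, the polling interval, and the 60-second budget are illustrative):

import multiprocessing as mp

import torch.distributed.elastic.timer as timer

mp_queue = mp.get_context("spawn").Queue()
server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
server.start()

timer.configure(timer.LocalTimerClient(mp_queue))
with timer.expires(after=60):
    pass  # work that must complete within 60 seconds

server.stop()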


@ -9,7 +9,7 @@
import os import os
import socket import socket
from string import Template from string import Template
from typing import Any, List from typing import Any
def get_env_variable_or_raise(env_name: str) -> str: def get_env_variable_or_raise(env_name: str) -> str:
@ -51,7 +51,7 @@ class macros:
local_rank = "${local_rank}" local_rank = "${local_rank}"
@staticmethod @staticmethod
def substitute(args: List[Any], local_rank: str) -> List[str]: def substitute(args: list[Any], local_rank: str) -> list[str]:
args_sub = [] args_sub = []
for arg in args: for arg in args:
if isinstance(arg, str): if isinstance(arg, str):
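substitute templates only the string arguments; everything else passes through untouched. A hedged example, assuming the module path below:

from torch.distributed.elastic.utils.api import macros

new_args = macros.substitute(["--local-rank", macros.local_rank, 42], "3")
assert new_args == ["--local-rank", "3", 42]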


@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Callable, Iterator, TypeVar from collections.abc import Iterator
from typing import Callable, TypeVar
from typing_extensions import Self from typing_extensions import Self


@ -7,9 +7,10 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
from datetime import timedelta from datetime import timedelta
from typing import Callable, Iterable, List, Optional from typing import Callable, Optional
import torch import torch
@ -85,7 +86,7 @@ def synchronize(
world_size: int, world_size: int,
key_prefix: str, key_prefix: str,
timeout: float = 300, timeout: float = 300,
) -> List[bytes]: ) -> list[bytes]:
""" """
Synchronizes ``world_size`` agents between each other using the underlying c10d store. Synchronizes ``world_size`` agents between each other using the underlying c10d store.
The ``data`` will be available on each of the agents. The ``data`` will be available on each of the agents.