PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
Author: Aaron Orenstein (2025-01-18 14:58:05 -08:00)
Committed by: PyTorch MergeBot
Parent: c64e657632
Commit: 316808e4e9
47 changed files with 311 additions and 344 deletions
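
For context on what "PEP585 update" means here: since Python 3.9, the built-in containers (dict, list, set, tuple, type) and the collections.abc classes are subscriptable as generics, so the corresponding typing aliases (Dict, List, Set, Tuple, Type, Sequence, Generator, Iterable, ...) are deprecated and can be dropped. The diff below applies that rewrite mechanically; the snippet that follows is a minimal illustrative sketch of the pattern, with made-up names that do not come from the changed files:

# Before (PEP 484 style): container generics imported from typing.
#   from typing import Dict, List, Optional, Set, Tuple
#   def split(items: List[int], bins: int) -> Dict[int, List[int]]: ...
#   seen: Set[Tuple[str, int]] = set()

# After (PEP 585 style, Python 3.9+): subscript the builtins directly and
# import ABCs from collections.abc; Optional/Union/Callable stay in typing.
from collections.abc import Sequence
from typing import Optional


def split(items: Sequence[int], bins: int) -> dict[int, list[int]]:
    """Toy helper: distribute items round-robin into ``bins`` buckets."""
    buckets: dict[int, list[int]] = {b: [] for b in range(bins)}
    for i, item in enumerate(items):
        buckets[i % bins].append(item)
    return buckets


seen: set[tuple[str, int]] = set()
maybe_limit: Optional[int] = None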

View File

@ -1,5 +1,5 @@
from concurrent.futures import Future
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
@ -13,7 +13,7 @@ from torch.distributed.checkpoint.storage import (
)
__all__: List[str] = []
__all__: list[str] = []
class _Checkpointer:
@ -90,7 +90,7 @@ class _Checkpointer:
planner=self.save_planner,
)
def load(self, state_dict: Dict[str, Any]) -> None:
def load(self, state_dict: dict[str, Any]) -> None:
"""Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
loader.load(
state_dict,

View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING
from typing import TYPE_CHECKING
from torch.distributed.checkpoint.planner import SavePlan, WriteItem
@ -13,16 +13,16 @@ __all__ = ["dedup_save_plans"]
def dedup_save_plans(
all_plans: List[SavePlan],
all_plans: list[SavePlan],
save_to_lowest_rank: bool = False,
) -> List[SavePlan]:
) -> list[SavePlan]:
"""
Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
"""
write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
for plan_idx, plan in enumerate(all_plans):
for write_item in plan.items:
# map each write item to its plan
@ -30,7 +30,7 @@ def dedup_save_plans(
write_item_idx_to_write_item[write_item.index] = write_item
# put item in the plan with the smallest size and remove it from the other plan_indices
to_remove: List[Set] = [set() for _ in range(len(all_plans))]
to_remove: list[set] = [set() for _ in range(len(all_plans))]
plan_to_size = [0] * len(all_plans)
for write_item_idx, plan_indices in write_item_to_plan_indices.items():
if save_to_lowest_rank:

View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import logging
from typing import Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING
from torch.distributed.checkpoint.planner import SavePlan
@ -31,9 +31,9 @@ logger = init_logger()
# TODO add docstring for dedup_tensors
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
def dedup_tensors(all_plans: list[SavePlan]) -> list[SavePlan]:
all_plans = list(all_plans)
key_to_plan: Dict[MetadataIndex, List[int]] = {}
key_to_plan: dict[MetadataIndex, list[int]] = {}
for plan_idx, plan in enumerate(all_plans):
for write_item in plan.items:
key_to_plan.setdefault(write_item.index, []).append(plan_idx)
@ -42,7 +42,7 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
# Remove duplicates by always keeping the first entry.
# Compute the per-rank remove set.
plan_to_keys: Dict[int, List[MetadataIndex]] = {}
plan_to_keys: dict[int, list[MetadataIndex]] = {}
for key, plans in replicated_items.items():
for plan_idx in plans[1:]:
plan_to_keys.setdefault(plan_idx, []).append(key)

View File

@ -3,10 +3,10 @@
import io
import os
from collections.abc import Sequence
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
from fsspec.core import url_to_fs

View File

@ -1,5 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
@ -21,7 +20,7 @@ Change set_element to recreate the right type for tuple, OrderedDict, and NamedT
"""
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
FLATTEN_MAPPING = dict[str, OBJ_PATH]
# TODO: Update Docstring for nested_dict.py
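
The alias change above (FLATTEN_MAPPING = dict[str, OBJ_PATH]) relies on PEP 585 generics being ordinary runtime objects rather than typing-only constructs. A small sketch with hypothetical names (the real OBJ_PATH is defined elsewhere in the package and is not reproduced here):

import types
import typing

# Hypothetical stand-ins; not the package's real definitions.
ObjPath = tuple[str, ...]
FlattenMapping = dict[str, ObjPath]  # analogous to FLATTEN_MAPPING above

mapping: FlattenMapping = {"model.weight": ("model", "weight")}

# The alias is a types.GenericAlias, so it stays introspectable at runtime.
assert isinstance(FlattenMapping, types.GenericAlias)
assert typing.get_args(FlattenMapping) == (str, ObjPath)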

View File

@ -1,5 +1,5 @@
import os
from typing import List, Type, Union
from typing import Union
from .filesystem import FileSystemReader, FileSystemWriter
from .storage import StorageReader, StorageWriter
@ -21,7 +21,7 @@ def _storage_setup(
"storage_reader/storage_writer is None."
)
targets: List[Type[Union[StorageReader, StorageWriter]]] = []
targets: list[type[Union[StorageReader, StorageWriter]]] = []
if reader:
targets = [
FileSystemReader,

View File

@ -1,15 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import (
Callable,
cast,
Collection,
List,
Mapping,
MutableMapping,
Optional,
TypeVar,
Union,
)
from collections.abc import Collection, Mapping, MutableMapping
from typing import Callable, cast, Optional, TypeVar, Union
import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
@ -123,7 +114,7 @@ def set_element(
"""Set ``value`` in ``root_dict`` along the ``path`` object path."""
cur_container = cast(CONTAINER_TYPE, root_dict)
def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None:
while len(lst) <= idx:
lst.append(None)
@ -144,7 +135,7 @@ def set_element(
key = path[-1]
if type(key) == int:
extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
cur_container[key] = value

View File

@ -1,5 +1,5 @@
import traceback as tb
from typing import Any, Dict
from typing import Any
WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary]
@ -22,12 +22,12 @@ def _is_wrapped_exception(obj: Any) -> bool:
class CheckpointException(BaseException):
"""Exception raised if failure was detected as part of a checkpoint load or save."""
def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
super().__init__(msg, failures)
self._failures = failures
@property
def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
def failures(self) -> dict[int, WRAPPED_EXCEPTION]:
"""Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
return self._failures

View File

@ -7,7 +7,7 @@ import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, cast, Optional, Union
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
@ -106,8 +106,8 @@ class DefaultSavePlanner(SavePlanner):
return self.plan
def create_global_plan(
self, all_plans: List[SavePlan]
) -> tuple[List[SavePlan], Metadata]:
self, all_plans: list[SavePlan]
) -> tuple[list[SavePlan], Metadata]:
all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
global_plan, metadata = create_default_global_save_plan(all_plans)
@ -234,7 +234,7 @@ class DefaultLoadPlanner(LoadPlanner):
self.state_dict, self.metadata, not self.allow_partial_load
)
def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
return create_default_global_load_plan(global_plan)
def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
@ -293,7 +293,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
if key in self.keys:
True
unflattened_keys: List[str] = []
unflattened_keys: list[str] = []
planner_data = metadata.planner_data.get(key)
for unflattened_key in planner_data:
if unflattened_keys:
@ -334,7 +334,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
def create_default_local_load_plan(
state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True
state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
) -> LoadPlan:
requests = []
"""
@ -376,8 +376,8 @@ def create_default_local_load_plan(
def create_default_global_load_plan(
all_plans: List[LoadPlan],
) -> List[LoadPlan]:
all_plans: list[LoadPlan],
) -> list[LoadPlan]:
"""
Create global load plan used by DefaultLoadPlanner.
@ -388,7 +388,7 @@ def create_default_global_load_plan(
def create_default_local_save_plan(
state_dict: Dict[str, Any], is_coordinator: bool
state_dict: dict[str, Any], is_coordinator: bool
) -> SavePlan:
"""
Create the ``SavePlan`` used by DefaultSavePlanner.
@ -415,9 +415,9 @@ def create_default_local_save_plan(
def create_default_global_save_plan(
all_plans: List[SavePlan],
all_plans: list[SavePlan],
rewrite_index_hints: bool = True,
) -> tuple[List[SavePlan], Metadata]:
) -> tuple[list[SavePlan], Metadata]:
"""
Create the global plan and metadata used by DefaultSavePlanner.
@ -426,7 +426,7 @@ def create_default_global_save_plan(
The only global planning change is to update index hints in all ``MetadataIndex`` objects if
``rewrite_index_hints`` is True.
"""
md: Dict[str, STORAGE_TYPES] = {}
md: dict[str, STORAGE_TYPES] = {}
new_plans = []
for plan in all_plans:
new_items = []
@ -506,7 +506,7 @@ def _check_box_bounds(
return True
def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
all_good = True
for key, value in metadata.state_dict_metadata.items():
if isinstance(value, BytesStorageMetadata):

View File

@ -10,24 +10,12 @@ import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Generator, Iterable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from io import UnsupportedOperation
from pathlib import Path
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
IO,
Iterable,
Iterator,
List,
Optional,
Union,
)
from typing import Any, Callable, cast, IO, Optional, Union
# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer
@ -113,7 +101,7 @@ class _TensorLoader(ABC):
class _SerialCpuLoader(_TensorLoader):
def __init__(self, resolve_fun: Callable) -> None:
self.resolve_fun = resolve_fun
self.items: List[tuple[int, object]] = []
self.items: list[tuple[int, object]] = []
def add(self, size: int, obj: object) -> None:
self.items.append((size, obj))
@ -141,7 +129,7 @@ class _OverlappingCpuLoader(_TensorLoader):
inflight_threshhold: int = 1_000_000,
) -> None:
self.resolve_fun = resolve_fun
self.items: List[tuple[int, object]] = []
self.items: list[tuple[int, object]] = []
self.inflight_threshhold = inflight_threshhold
self.in_flight_data = 0
self.current_items: collections.deque = collections.deque()
@ -161,7 +149,7 @@ class _OverlappingCpuLoader(_TensorLoader):
def _done(self) -> bool:
return self.idx >= len(self.items)
def _drain(self) -> List[tuple[torch.Tensor, object]]:
def _drain(self) -> list[tuple[torch.Tensor, object]]:
drained = []
if self.in_flight_data >= self.inflight_threshhold:
self.stream.synchronize()
@ -243,7 +231,7 @@ class _StorageWriterTransforms:
def transform_save_stream(
self, write_item: WriteItem, raw_stream: io.IOBase
) -> tuple[IO[bytes], List[str]]:
) -> tuple[IO[bytes], list[str]]:
# In order to avoid leaking fds, transformers' close must
# cascade to wrapped streams, but since this function can
# append to the raw stream, we can't close the actual stream.
@ -285,14 +273,14 @@ def _item_size(item: WriteItem) -> int:
return size * torch._utils._element_size(dtype)
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]:
if bins == 1:
return [items]
bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
buckets: list[list[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_w.sort(key=_item_size, reverse=True)
@ -584,7 +572,7 @@ class _FileSystemWriter(StorageWriter):
return plan
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
for i, plan in enumerate(plans)
@ -595,7 +583,7 @@ class _FileSystemWriter(StorageWriter):
self,
plan: SavePlan,
planner: SavePlanner,
) -> Future[List[WriteResult]]:
) -> Future[list[WriteResult]]:
storage_plan: _StoragePrefix = plan.storage_data
file_count = 0
@ -656,11 +644,11 @@ class _FileSystemWriter(StorageWriter):
while True:
res += result_queue.get_nowait()
except queue.Empty:
fut: Future[List[WriteResult]] = Future()
fut: Future[list[WriteResult]] = Future()
fut.set_result(res)
return fut
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
storage_md = {}
for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@ -737,7 +725,7 @@ class FileSystemReader(StorageReader):
super().__init__()
self.fs = FileSystem()
self.path = self.fs.init_path(path)
self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
self.storage_data: dict[MetadataIndex, _StorageInfo] = {}
self.load_id = _generate_uuid()
self.transforms = _StorageReaderTransforms(_extension_registry)
@ -752,7 +740,7 @@ class FileSystemReader(StorageReader):
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
# group requests by file
per_file: Dict[str, List[ReadItem]] = {}
per_file: dict[str, list[ReadItem]] = {}
for read_item in plan.items:
item_md = self.storage_data[read_item.storage_index]
path = item_md.relative_path
@ -828,7 +816,7 @@ class FileSystemReader(StorageReader):
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
return plans
@property

View File

@ -2,7 +2,7 @@
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union
from typing import cast, Optional, Union
import torch
import torch.distributed as dist
@ -133,7 +133,7 @@ class BroadcastingTorchSaveReader(StorageReader):
"""Implementation of the StorageReader method"""
return plan
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
def prepare_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
"""Implementation of the StorageReader method"""
return global_plan
@ -177,7 +177,7 @@ class DynamicMetaLoadPlanner(DefaultLoadPlanner):
"""Setups of the planner, extnding default behavior by creating the Metadata object from the state dict"""
super().set_up_planner(state_dict, metadata, is_coordinator)
state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
state_dict_metadata: dict[str, STORAGE_TYPES] = {}
for key, tensor in self.state_dict.items():
if not torch.is_tensor(tensor):
raise RuntimeError(

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import time
from typing import Any, Callable, Dict, List, TypeVar
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec
from uuid import uuid4
@ -9,7 +9,7 @@ import torch.distributed.c10d_logger as c10d_logger
from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
__all__: List[str] = []
__all__: list[str] = []
global _dcp_logger
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
@ -18,7 +18,7 @@ _T = TypeVar("_T")
_P = ParamSpec("_P")
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
"""
Extracts log data from dcp method args
"""
@ -52,7 +52,7 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
return msg_dict
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs))

View File

@ -1,10 +1,9 @@
import logging
from typing import List
from torch.distributed.logging_handlers import _log_handlers
__all__: List[str] = []
__all__: list[str] = []
DCP_LOGGER_NAME = "dcp_logger"

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import os
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Optional, Union
import torch
from torch.distributed.checkpoint.stateful import StatefulT
@ -113,7 +114,7 @@ class TensorProperties:
class TensorStorageMetadata:
properties: TensorProperties
size: torch.Size
chunks: List[ChunkStorageMetadata]
chunks: list[ChunkStorageMetadata]
@dataclass
@ -122,7 +123,7 @@ class BytesStorageMetadata:
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]]
@dataclass
@ -137,7 +138,7 @@ class Metadata:
"""This class represents the metadata of the checkpoint."""
# Keys are the same from the `state_dict` used.
state_dict_metadata: Dict[str, STORAGE_TYPES]
state_dict_metadata: dict[str, STORAGE_TYPES]
# It is the responsibility of the planner and storage plugins to ensure
# backward compatibility of the planner_data and storage_data. DCP will
# also ensure the backward compatibility of the metadata in this file and

View File

@ -1,7 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from typing import cast, Dict, List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import cast, Optional, Union
import torch
import torch.distributed as dist
@ -41,7 +42,7 @@ from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DTensor
STATE_DICT_2D_LAYOUT = Dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]
STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]
# TODO: Update docstrings for optimizer.py
@ -77,7 +78,7 @@ def _create_colwise_spec(
]
return ChunkShardingSpec(
dim=0,
placements=cast(List[Union[_remote_device, str]], placements),
placements=cast(list[Union[_remote_device, str]], placements),
)
@ -154,11 +155,11 @@ def _get_state_dict_2d_layout(
class _ReaderWithOffset(DefaultLoadPlanner):
translation: Dict[MetadataIndex, MetadataIndex]
translation: dict[MetadataIndex, MetadataIndex]
state_dict: STATE_DICT_TYPE
metadata: Metadata
def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None:
super().__init__()
self.fqn_to_offset = fqn_to_offset
self.metadata = Metadata({})
@ -284,7 +285,7 @@ def load_sharded_optimizer_state_dict(
# Create a state_dict for optimizer state
state_dict: STATE_DICT_TYPE = {}
fqn_to_offset: Dict[str, Sequence[int]] = {}
fqn_to_offset: dict[str, Sequence[int]] = {}
for key, value in metadata.state_dict_metadata.items():
key_path = metadata.planner_data[key]
if key_path[0] != optimizer_key:

View File

@ -4,7 +4,7 @@ import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import torch
from torch.distributed.checkpoint.metadata import (
@ -94,14 +94,14 @@ class ReadItem:
@dataclass(frozen=True)
class SavePlan:
items: List[WriteItem]
items: list[WriteItem]
storage_data: Any = None
planner_data: Any = None
@dataclass
class LoadPlan:
items: List[ReadItem]
items: list[ReadItem]
storage_data: Any = None
planner_data: Any = None
@ -231,8 +231,8 @@ class SavePlanner(abc.ABC):
@abc.abstractmethod
def create_global_plan(
self, all_plans: List[SavePlan]
) -> tuple[List[SavePlan], Metadata]:
self, all_plans: list[SavePlan]
) -> tuple[list[SavePlan], Metadata]:
"""
Compute the global checkpoint plan and return the local plan of each rank.
@ -364,7 +364,7 @@ class LoadPlanner:
"""
@abc.abstractmethod
def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
"""
Compute the global load plan and return plans for each rank.

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import io
from typing import Any, Callable, cast, Dict, List
from typing import Any, Callable, cast
import torch
import torch.distributed as dist
@ -33,7 +33,7 @@ from .resharding import (
)
__all__: List[str] = ["create_read_items_for_chunk_list"]
__all__: list[str] = ["create_read_items_for_chunk_list"]
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
@ -149,8 +149,8 @@ def _create_read_item_for_tensor(
def create_read_items_for_chunk_list(
fqn: str,
checkpoint_md: TensorStorageMetadata,
local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
"""
Create a list of ``ReadItem`` based on the checkpoint and local chunks.
@ -218,7 +218,7 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
return SavePlan(requests)
def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
if hasattr(object, "__create_write_items__"):
# DTensor implements _Checkpointable
return object.__create_write_items__(fqn, object)
@ -244,7 +244,7 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
)
def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
if hasattr(tensor, "__create_chunk_list__"):
# DTensor implements _Checkpointable
local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined]
@ -263,7 +263,7 @@ def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
return local_chunks
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
if not isinstance(md, BytesStorageMetadata):
try:
local_chunks = _create_chunk_list(obj)
@ -286,7 +286,7 @@ def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
]
def _init_state_dict(state_dict: Dict[str, Any]) -> Any:
def _init_state_dict(state_dict: dict[str, Any]) -> Any:
"""
Initializes meta tensor if the meta tensor is DTensor or torch.Tensor.
"""

View File

@ -1,10 +1,9 @@
# mypy: allow-untyped-defs
from typing import List
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
__all__: List[str] = []
__all__: list[str] = []
def _check_shard_metadata_pair_overlap(
@ -27,7 +26,7 @@ def _check_shard_metadata_pair_overlap(
def _shards_get_overlap_region_wrt_saved_tensor(
saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[tuple[int, int, int, int]]:
) -> list[tuple[int, int, int, int]]:
"""
Return the overlapping region between saved_shard and current_shard.

View File

@ -3,21 +3,10 @@ import contextlib
import functools
import gc
import warnings
from collections.abc import Generator, Iterable
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
List,
no_type_check,
Optional,
Set,
Union,
)
from typing import Any, Callable, cast, no_type_check, Optional, Union
import torch
import torch.distributed as dist
@ -76,17 +65,17 @@ _PG = "param_groups"
_PARAMS = "params"
_STATE = "state"
FQNS_T = Set[str]
FQNS_T = set[str]
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
ValueType = Union[
PrimitiveType, List[PrimitiveType], tuple[PrimitiveType], Dict[str, "ValueType"]
PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"]
]
DictValueType = Dict[str, ValueType]
ListDictValueType = List[DictValueType]
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]
DictValueType = dict[str, ValueType]
ListDictValueType = list[DictValueType]
OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]]
_patched_state_dict: Set[Callable] = set()
_patched_state_dict: set[Callable] = set()
@contextlib.contextmanager
@ -149,20 +138,20 @@ class StateDictOptions:
@dataclass
class _StateDictInfo(StateDictOptions):
fqn_param_mapping: Dict[
fqn_param_mapping: dict[
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
] = field(default_factory=dict)
shared_params_mapping: Dict[
shared_params_mapping: dict[
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
] = field(default_factory=dict)
submodule_prefixes: Set[str] = field(default_factory=set)
submodule_prefixes: set[str] = field(default_factory=set)
handle_model: bool = True
handle_optim: bool = True
fsdp_context: Callable = contextlib.nullcontext
fsdp_modules: List[nn.Module] = field(default_factory=list)
fsdp_modules: list[nn.Module] = field(default_factory=list)
@functools.lru_cache(maxsize=None)
@functools.cache
def _get_fqns(
model: nn.Module,
name: str,
@ -230,7 +219,7 @@ class _EXTRA_STATE:
def _iterate_valid_model_state(model):
visited_modules: Set[nn.Module] = set()
visited_modules: set[nn.Module] = set()
def recurse(module: nn.Module, curr_fqn: str) -> Generator:
visited_modules.add(module)
@ -265,7 +254,7 @@ def _verify_options(
optims: tuple[torch.optim.Optimizer, ...],
optim_only: bool,
*,
submodules: Optional[Set[nn.Module]] = None,
submodules: Optional[set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
"""
@ -285,11 +274,11 @@ def _verify_options(
options = options or StateDictOptions()
fqn_param_mapping: Dict[
Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
fqn_param_mapping: dict[
Union[str, torch.Tensor], Union[set[str], torch.Tensor]
] = {}
shared_params_mapping: Dict[
Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
shared_params_mapping: dict[
Union[str, torch.Tensor], Union[set[str], torch.Tensor]
] = {}
for name, param in _iterate_valid_model_state(model):
if isinstance(param, _EXTRA_STATE):
@ -298,7 +287,7 @@ def _verify_options(
fqns = _get_fqns(model, name)
fqn = fqn_param_mapping.get(param, None)
if fqn is not None:
cast(Set[str], fqn_param_mapping[param]).update(fqns)
cast(set[str], fqn_param_mapping[param]).update(fqns)
shared_params_mapping[param] = fqn_param_mapping[param]
else:
# We need to do copy as _get_fqns is lru_cached
@ -311,7 +300,7 @@ def _verify_options(
for fqn in fqns_:
shared_params_mapping[fqn] = cast(torch.Tensor, param_)
submodule_prefixes: Set[str] = set()
submodule_prefixes: set[str] = set()
if submodules:
submodules = set(submodules)
for name, module in model.named_modules():
@ -384,14 +373,14 @@ def _verify_options(
shared_params_mapping=shared_params_mapping,
submodule_prefixes=submodule_prefixes,
fsdp_context=fsdp_context,
fsdp_modules=cast(List[nn.Module], fsdp_modules),
fsdp_modules=cast(list[nn.Module], fsdp_modules),
handle_model=not optim_only,
handle_optim=(len(optims) > 0),
)
def _verify_state_dict(
model_state_dict: Dict[str, ValueType],
model_state_dict: dict[str, ValueType],
optim_state_dict: OptimizerStateType,
info: _StateDictInfo,
) -> None:
@ -443,8 +432,8 @@ def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Ca
def _maybe_full_or_cpu_state_dict(
state_dict: Dict[str, Any], info: _StateDictInfo
) -> Dict[str, Any]:
state_dict: dict[str, Any], info: _StateDictInfo
) -> dict[str, Any]:
if info.full_state_dict:
ranks_only = (
()
@ -463,7 +452,7 @@ def _maybe_full_or_cpu_state_dict(
@torch.no_grad()
def _get_model_state_dict(
model: nn.Module, info: _StateDictInfo
) -> Dict[str, ValueType]:
) -> dict[str, ValueType]:
if not info.handle_model:
return {}
@ -500,7 +489,7 @@ def _get_model_state_dict(
state_dict[fqn] = state_dict.pop(key)
if info.submodule_prefixes:
new_state_dict: Dict[str, ValueType] = {}
new_state_dict: dict[str, ValueType] = {}
# TODO: make this faster.
for fqn in state_dict.keys():
for prefix in info.submodule_prefixes:
@ -531,7 +520,7 @@ def _get_model_state_dict(
@torch.no_grad()
def _load_model_state_dict(
model: nn.Module,
state_dict: Dict[str, ValueType],
state_dict: dict[str, ValueType],
info: _StateDictInfo,
) -> _IncompatibleKeys:
if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
@ -636,7 +625,7 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None:
optim.zero_grad(set_to_none=True)
def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]:
"""
This API flattens the optimizer state_dict to support optimizer resharding for
MPMD, e.g., pipeline parallelism.
@ -686,7 +675,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
f"Type is {type(v)}."
)
ret: Dict[str, ValueType] = {}
ret: dict[str, ValueType] = {}
for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
for k, v in cast(DictValueType, state).items():
_raise_if_type_not_supported(v)
@ -694,7 +683,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
for param_group in cast(ListDictValueType, state_dict[_PG]):
fqns = param_group.pop(_PARAMS)
for fqn in cast(List[str], fqns):
for fqn in cast(list[str], fqns):
for k, v in param_group.items():
ret[f"{_PG}.{fqn}.{k}"] = v
return ret
@ -702,7 +691,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
def _unflatten_optim_state_dict(
optim: torch.optim.Optimizer,
state_dict: Dict[str, ValueType],
state_dict: dict[str, ValueType],
info: _StateDictInfo,
) -> OptimizerStateType:
"""
@ -728,7 +717,7 @@ def _unflatten_optim_state_dict(
f"{_STATE}.{fqn}.{state_name}"
]
first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0]
first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0]
for k in param_group.keys():
if k == _PARAMS:
continue
@ -833,7 +822,7 @@ def _split_optim_state_dict(
state: DictValueType = {}
pg_state: ListDictValueType = []
return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
pg_mapping: Dict[int, int] = {}
pg_mapping: dict[int, int] = {}
if all(
isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
@ -849,7 +838,7 @@ def _split_optim_state_dict(
for loaded_param_group in cast(
ListDictValueType, optim_state_dict[_PG]
):
if fqn in cast(List[str], loaded_param_group[_PARAMS]):
if fqn in cast(list[str], loaded_param_group[_PARAMS]):
in_params = True
break
else:
@ -865,7 +854,7 @@ def _split_optim_state_dict(
for loaded_param_group in cast(
ListDictValueType, optim_state_dict[_PG]
):
if fqn in cast(List[str], loaded_param_group[_PARAMS]):
if fqn in cast(list[str], loaded_param_group[_PARAMS]):
pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
@ -900,7 +889,7 @@ def _load_optim_state_dict(
)
else:
optim_state_dict = _unflatten_optim_state_dict(
optim, cast(Dict[str, ValueType], state_dict), info
optim, cast(dict[str, ValueType], state_dict), info
)
else:
optim_state_dict = {}
@ -919,7 +908,7 @@ def _load_optim_state_dict(
fqn = fqns.pop()
fqn_with_compiler = fqns_with_compiler.pop()
for g in optim_state_dict[_PG]:
val = cast(Dict[str, Any], g)
val = cast(dict[str, Any], g)
params = [
key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
]
@ -978,9 +967,9 @@ def _load_optim_state_dict(
def get_model_state_dict(
model: nn.Module,
*,
submodules: Optional[Set[nn.Module]] = None,
submodules: Optional[set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> Dict[str, ValueType]:
) -> dict[str, ValueType]:
"""
Return the model state_dict of ``model``.
@ -1016,7 +1005,7 @@ def get_optimizer_state_dict(
model: nn.Module,
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
*,
submodules: Optional[Set[nn.Module]] = None,
submodules: Optional[set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> OptimizerStateType:
"""
@ -1061,9 +1050,9 @@ def get_state_dict(
model: nn.Module,
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
*,
submodules: Optional[Set[nn.Module]] = None,
submodules: Optional[set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> tuple[Dict[str, ValueType], OptimizerStateType]:
) -> tuple[dict[str, ValueType], OptimizerStateType]:
"""
Return the model state_dict and optimizers state_dict.
@ -1148,8 +1137,8 @@ def get_state_dict(
def _unflatten_model_state_dict(
model: nn.Module,
state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
) -> Dict[str, ValueType]:
state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]],
) -> dict[str, ValueType]:
if not state_dict:
return {}
@ -1161,8 +1150,8 @@ def _unflatten_model_state_dict(
"same functionality.",
FutureWarning,
)
cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
new_state_dict: Dict[str, ValueType] = {}
cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)
new_state_dict: dict[str, ValueType] = {}
for submodule, sub_state_dict in cast_state_dict.items():
for name, m in model.named_modules():
if m != submodule:
@ -1176,12 +1165,12 @@ def _unflatten_model_state_dict(
)
return new_state_dict
else:
return cast(Dict[str, ValueType], state_dict)
return cast(dict[str, ValueType], state_dict)
def set_model_state_dict(
model: nn.Module,
model_state_dict: Dict[str, ValueType],
model_state_dict: dict[str, ValueType],
*,
options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
@ -1208,7 +1197,7 @@ def set_model_state_dict(
:type model_state_dict: typing.Dict[str, ValueType]
"""
model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
model, model_state_dict
)
with _gc_context():
@ -1261,7 +1250,7 @@ def set_state_dict(
model: nn.Module,
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
*,
model_state_dict: Dict[str, ValueType],
model_state_dict: dict[str, ValueType],
optim_state_dict: OptimizerStateType,
options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
@ -1299,7 +1288,7 @@ def set_state_dict(
:type optim_state_dict: typing.OptimizerStateType
"""
model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
model, model_state_dict
)
with _gc_context():
@ -1363,7 +1352,7 @@ def _patch_model_state_dict(
options=options,
)
def load_state_dict_call(state_dict: Dict[str, Any]):
def load_state_dict_call(state_dict: dict[str, Any]):
_load_state_dict_call(model_state_dict=state_dict)
model.load_state_dict = load_state_dict_call
@ -1422,7 +1411,7 @@ def _patch_optimizer_state_dict(
options=options,
)
def load_state_dict_call(state_dict: Dict[str, Any]):
def load_state_dict_call(state_dict: dict[str, Any]):
_load_state_dict_call(optim_state_dict=state_dict)
_patched_state_dict.add(state_dict_call)
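
One change earlier in this file's diff is not a pure type-hint rewrite: the @functools.lru_cache(maxsize=None) decorator on _get_fqns becomes @functools.cache. The two are equivalent; functools.cache, added in Python 3.9, is a thin alias for lru_cache(maxsize=None), i.e. an unbounded memoizing cache. A small illustrative sketch with a hypothetical function (not the real _get_fqns):

import functools


@functools.cache  # same behavior as @functools.lru_cache(maxsize=None)
def fib(n: int) -> int:
    """Hypothetical memoized helper standing in for a cached lookup."""
    return n if n < 2 else fib(n - 1) + fib(n - 2)


print(fib(30))           # each distinct argument is computed once
print(fib.cache_info())  # CacheInfo(hits=28, misses=31, maxsize=None, currsize=31)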

View File

@ -2,7 +2,7 @@
# mypy: allow-untyped-defs
import os
import warnings
from typing import Any, cast, Dict, Optional, Set, Union
from typing import Any, cast, Optional, Union
from typing_extensions import deprecated
import torch
@ -27,7 +27,7 @@ __all__ = ["load_state_dict", "load"]
category=FutureWarning,
)
def load_state_dict(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
storage_reader: StorageReader,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
@ -51,7 +51,7 @@ def load_state_dict(
@_dcp_method_logger(log_exceptions=True)
@_api_bc_check
def load(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_reader: Optional[StorageReader] = None,
@ -190,7 +190,7 @@ def load(
def _load_state_dict(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
storage_reader: StorageReader,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
@ -241,12 +241,12 @@ def _load_state_dict(
def _load_state_dict_from_keys(
keys: Optional[Union[Set[str], str]] = None,
keys: Optional[Union[set[str], str]] = None,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_reader: Optional[StorageReader] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Load only the specified keys from the checkpoint, if no keys are specified, the entire
checkpoint will be loaded. Note, this method completely loads the checkpoint into the
@ -311,7 +311,7 @@ def _load_state_dict_from_keys(
if isinstance(keys, str):
keys = {keys}
sd: Dict[str, Any] = {}
sd: dict[str, Any] = {}
_load_state_dict(
state_dict=sd,
storage_reader=storage_reader,

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, runtime_checkable, TypeVar
from typing import Any, runtime_checkable, TypeVar
from typing_extensions import Protocol
@ -11,7 +11,7 @@ class Stateful(Protocol):
Stateful protocol for objects that can be checkpointed and restored.
"""
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
"""
Objects should return their state_dict representation as a dictionary.
The output of this function will be checkpointed, and later restored in
@ -28,7 +28,7 @@ class Stateful(Protocol):
...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""
Restore the object's state from the provided state_dict.

View File

@ -1,7 +1,7 @@
import abc
import os
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.planner import (
@ -86,7 +86,7 @@ class StorageWriter(abc.ABC):
"""
@abc.abstractmethod
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
"""
Perform centralized planning of storage.
@ -105,7 +105,7 @@ class StorageWriter(abc.ABC):
@abc.abstractmethod
def write_data(
self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]:
) -> Future[list[WriteResult]]:
"""
Write all items from ``plan`` using ``planner`` to resolve the data.
@ -127,7 +127,7 @@ class StorageWriter(abc.ABC):
"""
@abc.abstractmethod
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
"""
Write the metadata and marks the current checkpoint as successful.
@ -236,7 +236,7 @@ class StorageReader(abc.ABC):
"""
@abc.abstractmethod
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
"""
Perform centralized planning of storage loading.

View File

@ -5,10 +5,11 @@ import io
import itertools
import os
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from functools import wraps
from pstats import Stats
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
from typing import Any, Callable, cast, Optional, TypeVar, Union
import torch
import torch.distributed as dist
@ -31,20 +32,20 @@ R = TypeVar("R")
def _get_failure_dict(
results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
results: list[Union[T, WRAPPED_EXCEPTION]]
) -> dict[int, WRAPPED_EXCEPTION]:
return cast(
Dict[int, WRAPPED_EXCEPTION],
dict[int, WRAPPED_EXCEPTION],
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
)
def _all_gather_keys(
local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> List[Any]:
local_dict: dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> list[Any]:
"""Gathers all keys, and returns them sorted."""
keys = list(local_dict.keys())
gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
gathered_keys: list[list[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
dist.all_gather_object(gathered_keys, keys, group=group)
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
@ -95,11 +96,11 @@ class _DistWrapper:
)
return cast(T, object_list[0])
def gather_object(self, object: T) -> Optional[List[T]]:
def gather_object(self, object: T) -> Optional[list[T]]:
"""Implement functionality similar to c10d::gather_object but without distributed enabled."""
if self.use_dist:
gather_objs = (
cast(List[T], [None] * dist.get_world_size(self.group))
cast(list[T], [None] * dist.get_world_size(self.group))
if self.is_coordinator
else None
)
@ -115,10 +116,10 @@ class _DistWrapper:
result = [object]
return result
def all_gather_object(self, object: T) -> List[T]:
def all_gather_object(self, object: T) -> list[T]:
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
if self.use_dist:
gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
gather_objs = cast(list[T], [None] * dist.get_world_size(self.group))
dist.all_gather_object(
object_list=gather_objs, obj=object, group=self.group
@ -127,10 +128,10 @@ class _DistWrapper:
gather_objs = [object]
return gather_objs
def scatter_object(self, object_list: Optional[List[T]]) -> T:
def scatter_object(self, object_list: Optional[list[T]]) -> T:
"""Implement functionality similar to c10d::scatter_object but without distributed enabled."""
if self.use_dist:
gather_result = cast(List[T], [None])
gather_result = cast(list[T], [None])
dist.scatter_object_list(
scatter_object_output_list=gather_result,
scatter_object_input_list=object_list if self.is_coordinator else None,
@ -148,7 +149,7 @@ class _DistWrapper:
self,
step: str,
map_fun: Callable[[], T],
reduce_fun: Callable[[List[T]], List[R]],
reduce_fun: Callable[[list[T]], list[R]],
) -> R:
"""
Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
@ -166,7 +167,7 @@ class _DistWrapper:
local_data = _wrap_exception(e)
all_data = self.gather_object(local_data)
all_results: Optional[List[Union[R, CheckpointException]]] = None
all_results: Optional[list[Union[R, CheckpointException]]] = None
if self.is_coordinator:
assert all_data is not None
node_failures = _get_failure_dict(all_data)
@ -175,8 +176,8 @@ class _DistWrapper:
try:
# N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
all_results = cast(
List[Union[R, CheckpointException]],
reduce_fun(cast(List[T], all_data)),
list[Union[R, CheckpointException]],
reduce_fun(cast(list[T], all_data)),
)
except BaseException as e:
node_failures[self.rank] = _wrap_exception(e)
@ -195,7 +196,7 @@ class _DistWrapper:
self,
step: str,
map_fun: Callable[[], T],
reduce_fun: Callable[[List[T]], R],
reduce_fun: Callable[[list[T]], R],
) -> R:
"""
Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
@ -219,7 +220,7 @@ class _DistWrapper:
node_failures = _get_failure_dict(all_data)
if len(node_failures) == 0:
try:
result = reduce_fun(cast(List[T], all_data))
result = reduce_fun(cast(list[T], all_data))
except BaseException as e:
node_failures[self.rank] = _wrap_exception(e)
@ -235,7 +236,7 @@ class _DistWrapper:
self,
step: str,
map_fun: Callable[[], T],
) -> List[T]:
) -> list[T]:
"""
Compute a value on each rank, then all_gather them.
@ -254,7 +255,7 @@ class _DistWrapper:
node_failures = _get_failure_dict(all_results)
if len(node_failures) > 0:
raise CheckpointException(step, node_failures)
return cast(List[T], all_results)
return cast(list[T], all_results)
def broadcast(
self,
@ -331,11 +332,11 @@ def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) ->
return obj
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]:
return [i_a + i_b for i_a, i_b in zip(a, b)]
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:
return [i_a - i_b for i_a, i_b in zip(a, b)]
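
The _DistWrapper hunks above keep docstrings describing a recurring coordination pattern: every rank computes a local value, the coordinator gathers and reduces them, and one result is scattered (or broadcast) back, with per-rank exceptions collected into a CheckpointException. A much-simplified, single-process sketch of the reduce-then-scatter flow, with no torch.distributed calls and purely illustrative names:

from collections.abc import Callable


def reduce_scatter_sketch(
    ranks: int,
    map_fun: Callable[[int], int],
    reduce_fun: Callable[[list[int]], list[int]],
) -> list[int]:
    """Toy model: each 'rank' maps, rank 0 reduces, one result goes back per rank."""
    local_values = [map_fun(rank) for rank in range(ranks)]  # compute on each rank
    per_rank_results = reduce_fun(local_values)              # centralized reduce on rank 0
    assert len(per_rank_results) == ranks                    # scatter: one result per rank
    return per_rank_results


# Example: every rank reports its payload size; rank 0 assigns write offsets.
offsets = reduce_scatter_sketch(
    ranks=4,
    map_fun=lambda rank: 10 * (rank + 1),
    reduce_fun=lambda sizes: [sum(sizes[:i]) for i in range(len(sizes))],
)
print(offsets)  # [0, 10, 30, 60]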

View File

@ -18,7 +18,7 @@ from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util
@ -80,7 +80,7 @@ class WorkerSpec:
fn: Optional[Callable] = None
# TODO @kiuk - make entrypoint a required field
entrypoint: Union[Callable, str, None] = None
args: Tuple = ()
args: tuple = ()
max_restarts: int = 3
monitor_interval: float = 0.1
master_port: Optional[int] = None
@ -320,7 +320,7 @@ class _RoleInstanceInfo:
return -1
@staticmethod
def find_role_boundaries(roles_infos: List, role: str) -> tuple[int, int]:
def find_role_boundaries(roles_infos: list, role: str) -> tuple[int, int]:
start_idx, end_idx = -1, -1
for idx, role_info in enumerate(roles_infos):
if role_info.role == role:
@ -357,8 +357,8 @@ class RunResult:
"""
state: WorkerState
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
return_values: dict[int, Any] = field(default_factory=dict)
failures: dict[int, ProcessFailure] = field(default_factory=dict)
def is_failed(self) -> bool:
return self.state == WorkerState.FAILED
@ -448,7 +448,7 @@ class SimpleElasticAgent(ElasticAgent):
return self._worker_group
@abc.abstractmethod
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
r"""Start ``worker_group.spec.local_world_size`` number of workers.
This is according to worker spec for the worker group .
@ -554,7 +554,7 @@ class SimpleElasticAgent(ElasticAgent):
@prof
def _assign_worker_ranks(
self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List[Worker]:
) -> list[Worker]:
"""Determine proper ranks for worker processes.
Fast Path: when all workers have the same role and world size. We calculate

View File

@ -15,7 +15,7 @@ import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
@ -163,7 +163,7 @@ class LocalElasticAgent(SimpleElasticAgent):
self._logs_specs = logs_specs
self._health_check_server: Optional[HealthCheckServer] = None
def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
watchdog_enabled = os.getenv(enable_watchdog_env_name)
watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
@ -256,7 +256,7 @@ class LocalElasticAgent(SimpleElasticAgent):
md["signal"] = str(request.signal)
md_str = json.dumps(md)
state = "RUNNING"
metadata: Dict[str, EventMetadataValue] = {
metadata: dict[str, EventMetadataValue] = {
"run_id": spec.rdzv_handler.get_run_id(),
"global_rank": None,
"group_rank": wg.group_rank,
@ -288,7 +288,7 @@ class LocalElasticAgent(SimpleElasticAgent):
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
@ -297,9 +297,9 @@ class LocalElasticAgent(SimpleElasticAgent):
use_agent_store: bool = spec.rdzv_handler.use_agent_store
logger.info("use_agent_store: %s", use_agent_store)
args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
log_line_prefixes: Optional[Dict[int, str]] = (
args: dict[int, tuple] = {}
envs: dict[int, dict[str, str]] = {}
log_line_prefixes: Optional[dict[int, str]] = (
{} if self._log_line_prefix_template else None
)
for worker in worker_group.workers:

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Generator
from contextlib import contextmanager, ExitStack
from typing import Generator
from torch.distributed.elastic.multiprocessing.errors import record

View File

@ -37,7 +37,7 @@ from .api import ( # noqa: F401
)
_events_loggers: Dict[str, logging.Logger] = {}
_events_loggers: dict[str, logging.Logger] = {}
def _get_or_create_logger(destination: str = "null") -> logging.Logger:

View File

@ -10,7 +10,7 @@
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, Optional, Union
from typing import Optional, Union
__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
@ -42,7 +42,7 @@ class Event:
name: str
source: EventSource
timestamp: int = 0
metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
metadata: dict[str, EventMetadataValue] = field(default_factory=dict)
def __str__(self):
return self.serialize()

View File

@ -7,10 +7,9 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict
_log_handlers: Dict[str, logging.Handler] = {
_log_handlers: dict[str, logging.Handler] = {
"console": logging.StreamHandler(),
"dynamic_rendezvous": logging.NullHandler(),
"null": logging.NullHandler(),

View File

@ -11,7 +11,7 @@ import abc
import time
from collections import namedtuple
from functools import wraps
from typing import Dict, Optional
from typing import Optional
from typing_extensions import deprecated
@ -37,7 +37,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
class MetricsConfig:
__slots__ = ["params"]
def __init__(self, params: Optional[Dict[str, str]] = None):
def __init__(self, params: Optional[dict[str, str]] = None):
self.params = params
if self.params is None:
self.params = {}
@ -72,7 +72,7 @@ class MetricStream:
)
_metrics_map: Dict[str, MetricHandler] = {}
_metrics_map: dict[str, MetricHandler] = {}
_default_metrics_handler: MetricHandler = NullMetricHandler()

View File

@ -100,10 +100,10 @@ __all__ = [
def start_processes(
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
start_method: str = "spawn",
) -> PContext:
"""

View File

@ -24,7 +24,7 @@ from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@ -100,7 +100,7 @@ def _get_default_signal() -> signal.Signals:
return signal.SIGTERM
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))
@ -122,7 +122,7 @@ class Std(IntFlag):
ALL = OUT | ERR
@classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]:
"""
Example:
::
@ -143,7 +143,7 @@ class Std(IntFlag):
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
return to_std(vm)
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
d: Dict[int, Std] = {}
d: dict[int, Std] = {}
for m in vm.split(","):
i, v = m.split(":")
d[int(i)] = to_std(v)
@ -155,8 +155,8 @@ class Std(IntFlag):
def to_map(
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
val_or_map: Union[Std, dict[int, Std]], local_world_size: int
) -> dict[int, Std]:
"""
Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This method is a convenience
@ -184,11 +184,11 @@ class LogsDest:
For each log type, holds mapping of local rank ids to file paths.
"""
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
tee_stdouts: Dict[int, str] = field(default_factory=dict)
tee_stderrs: Dict[int, str] = field(default_factory=dict)
error_files: Dict[int, str] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)
tee_stdouts: dict[int, str] = field(default_factory=dict)
tee_stderrs: dict[int, str] = field(default_factory=dict)
error_files: dict[int, str] = field(default_factory=dict)
class LogsSpecs(ABC):
@ -211,9 +211,9 @@ class LogsSpecs(ABC):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
) -> None:
self._root_log_dir = log_dir
self._redirects = redirects
@ -223,7 +223,7 @@ class LogsSpecs(ABC):
@abstractmethod
def reify(
self,
envs: Dict[int, Dict[str, str]],
envs: dict[int, dict[str, str]],
) -> LogsDest:
"""
Given the environment variables, builds destination of log files for each of the local ranks.
@ -249,9 +249,9 @@ class DefaultLogsSpecs(LogsSpecs):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
) -> None:
if log_dir != os.devnull:
if not log_dir:
@ -278,7 +278,7 @@ class DefaultLogsSpecs(LogsSpecs):
def reify(
self,
envs: Dict[int, Dict[str, str]],
envs: dict[int, dict[str, str]],
) -> LogsDest:
"""
Uses following scheme to build log destination paths:
@ -331,8 +331,8 @@ class DefaultLogsSpecs(LogsSpecs):
SYS_STREAM = "" # special case to indicate to output to console
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
tee_stdouts: Dict[int, str] = {}
tee_stderrs: Dict[int, str] = {}
tee_stdouts: dict[int, str] = {}
tee_stderrs: dict[int, str] = {}
error_files = {}
for local_rank in range(nprocs):
@ -414,10 +414,10 @@ class RunProcsResult:
"""
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
return_values: dict[int, Any] = field(default_factory=dict)
failures: dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)
def is_failed(self) -> bool:
return len(self.failures) > 0
@ -438,10 +438,10 @@ class PContext(abc.ABC):
self,
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
self.name = name
# validate that all mappings have the same number of keys and
@ -544,7 +544,7 @@ class PContext(abc.ABC):
return None
@abc.abstractmethod
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError
@ -587,11 +587,11 @@ def get_std_cm(std_rd: str, redirect_fn):
def _wrap(
local_rank: int,
fn: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
stdout_redirects: dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
ret_vals: dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
@ -621,11 +621,11 @@ class MultiprocessContext(PContext):
self,
name: str,
entrypoint: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
super().__init__(
name,
@ -644,7 +644,7 @@ class MultiprocessContext(PContext):
}
# see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {}
self._return_values: dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
# Note: the set() method should ONLY be invoked when all processes have finished
# successfully. If any process died on event.wait(), calling the set() method will deadlock.
@ -755,7 +755,7 @@ class MultiprocessContext(PContext):
stderrs=self.stderrs,
)
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
assert self._pc is not None # assertion for mypy type checking
return dict(enumerate(self._pc.pids()))
@ -803,10 +803,10 @@ class SubprocessContext(PContext):
self,
name: str,
entrypoint: str,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
super().__init__(
name,
@ -818,9 +818,9 @@ class SubprocessContext(PContext):
)
# bookkeeping state: which local_ranks are still running, their failures, and their subprocess handlers
self._running_local_ranks: Set[int] = set(range(self.nprocs))
self._failures: Dict[int, ProcessFailure] = {}
self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
self._running_local_ranks: set[int] = set(range(self.nprocs))
self._failures: dict[int, ProcessFailure] = {}
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
def _start(self):
if self.subprocess_handlers:
@ -884,7 +884,7 @@ class SubprocessContext(PContext):
else: # there are no failures and procs still running
return None
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
return {
local_rank: sh.proc.pid
for local_rank, sh in self.subprocess_handlers.items()
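
To illustrate the dict-based `redirects`/`tee` annotations updated above, here is a minimal usage sketch. Only the constructor and `reify` signatures appear in this hunk, so the import path and the `Std` members used below are assumptions.

```python
# Minimal sketch, assuming DefaultLogsSpecs and Std are re-exported from
# torch.distributed.elastic.multiprocessing (the module path is not shown in this diff).
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std

logs_specs = DefaultLogsSpecs(
    log_dir="/tmp/torchelastic",   # hypothetical log directory
    redirects={0: Std.ALL},        # per-rank mapping: redirect rank 0's stdout/stderr to files
    tee=Std.ERR,                   # a single Std applied to every local rank
    local_ranks_filter={0},        # only rank 0's lines reach the console
)

# reify() maps per-rank environment dicts to a LogsDest (per the signature above).
dest = logs_specs.reify({0: {"RANK": "0"}, 1: {"RANK": "1"}})
```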

View File

@ -78,7 +78,7 @@ __all__ = [
logger = get_logger(__name__)
JSON = Dict
JSON = dict
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"
@ -143,7 +143,7 @@ class ProcessFailure:
else:
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"
def _get_error_data(self, error_file_data: Dict[str, Any]) -> tuple[str, int]:
def _get_error_data(self, error_file_data: dict[str, Any]) -> tuple[str, int]:
message = error_file_data["message"]
if isinstance(message, str):
timestamp = int(error_file_data.get("timestamp", 0))
@ -231,7 +231,7 @@ class ChildFailedError(Exception):
of trainer 1's error file to the scheduler's init process.
"""
def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]):
self.name = name
self.failures = failures
assert (
@ -248,7 +248,7 @@ class ChildFailedError(Exception):
root_rank, _root_failure = self.get_first_failure()
root_failure_fmt: str = ""
other_failures_fmt: List[str] = []
other_failures_fmt: list[str] = []
width = len(title)
for idx, (rank, failure) in enumerate(self.failures.items()):
fmt, w = self._format_failure(idx, rank, failure)
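
As a reading aid for the `_get_error_data` hunk above, a hedged sketch of the error-file payload it consumes; the concrete values are made up, only the `message`/`timestamp` keys are taken from the hunk.

```python
from typing import Any

error_file_data: dict[str, Any] = {
    "message": "RuntimeError: CUDA error: device-side assert triggered",
    "timestamp": 1737244685,
}

# Mirrors the branch shown in _get_error_data: a plain-string message carries
# its timestamp as a sibling key, defaulting to 0 when absent.
message = error_file_data["message"]
if isinstance(message, str):
    timestamp = int(error_file_data.get("timestamp", 0))
print(message, timestamp)
```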

View File

@ -13,7 +13,7 @@ import os
import time
import traceback
import warnings
from typing import Any, Dict, Optional
from typing import Any, Optional
__all__ = ["ErrorHandler"]
@ -86,7 +86,7 @@ class ErrorHandler:
def override_error_code_in_rootcause_data(
self,
rootcause_error_file: str,
rootcause_error: Dict[str, Any],
rootcause_error: dict[str, Any],
error_code: int = 0,
):
"""Modify the rootcause_error read from the file, to correctly set the exit code."""

View File

@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
@ -15,8 +14,8 @@ __all__ = ["get_subprocess_handler"]
def get_subprocess_handler(
entrypoint: str,
args: Tuple,
env: Dict[str, str],
args: tuple,
env: dict[str, str],
stdout: str,
stderr: str,
local_rank_id: int,

View File

@ -9,7 +9,7 @@ import os
import signal
import subprocess
import sys
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional
__all__ = ["SubprocessHandler"]
@ -34,8 +34,8 @@ class SubprocessHandler:
def __init__(
self,
entrypoint: str,
args: Tuple,
env: Dict[str, str],
args: tuple,
env: dict[str, str],
stdout: Optional[str],
stderr: Optional[str],
local_rank_id: int,
@ -50,8 +50,8 @@ class SubprocessHandler:
self.local_rank_id = local_rank_id
self.proc: subprocess.Popen = self._popen(args_str, env_vars)
def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
kwargs: Dict[str, Any] = {}
def _popen(self, args: tuple, env: dict[str, str]) -> subprocess.Popen:
kwargs: dict[str, Any] = {}
if not IS_WINDOWS:
kwargs["start_new_session"] = True
return subprocess.Popen(
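
A standalone sketch of the platform-conditional `Popen` kwargs pattern used by `_popen` above (stdout/stderr handling omitted); the `IS_WINDOWS` check is stood in for by a `sys.platform` test here.

```python
import subprocess
import sys
from typing import Any

def spawn(args: tuple, env: dict[str, str]) -> subprocess.Popen:
    kwargs: dict[str, Any] = {}
    if sys.platform != "win32":
        # start a new session so the whole process group can be signalled together
        kwargs["start_new_session"] = True
    return subprocess.Popen(args=args, env=env, **kwargs)

proc = spawn(("/bin/echo", "hello"), {"PATH": "/usr/bin:/bin"})
proc.wait()
```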

View File

@ -12,7 +12,7 @@ import os
import time
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
from typing import Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
@ -89,9 +89,9 @@ class TailLog:
def __init__(
self,
name: str,
log_files: Dict[int, str],
log_files: dict[int, str],
dst: TextIO,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
interval_sec: float = 0.1,
):
n = len(log_files)
@ -106,10 +106,10 @@ class TailLog:
self._dst = dst
self._log_files = log_files
self._log_line_prefixes = log_line_prefixes
self._finished_events: Dict[int, Event] = {
self._finished_events: dict[int, Event] = {
local_rank: Event() for local_rank in log_files.keys()
}
self._futs: List[Future] = []
self._futs: list[Future] = []
self._interval_sec = interval_sec
self._stopped = False
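
A hedged usage sketch for `TailLog`, based only on the constructor signature shown here; the import path and the `start()`/`stop()` methods are assumptions about code outside this hunk.

```python
import sys
from torch.distributed.elastic.multiprocessing.tail_log import TailLog  # assumed path

log_files: dict[int, str] = {0: "/tmp/rank0.log", 1: "/tmp/rank1.log"}  # hypothetical files
prefixes: dict[int, str] = {0: "[rank0]: ", 1: "[rank1]: "}

tail = TailLog("trainer", log_files, sys.stdout, log_line_prefixes=prefixes)
tail.start()   # assumed: begins tailing each rank's file into sys.stdout
# ... workers write to their log files ...
tail.stop()    # assumed: flushes remaining lines and joins the tail threads
```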

View File

@ -8,7 +8,7 @@
import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional
from typing import Any, Callable, ClassVar, Optional
from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port
@ -325,7 +325,7 @@ RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry:
"""Represent a registry of :py:class:`RendezvousHandler` backends."""
_registry: Dict[str, RendezvousHandlerCreator]
_registry: dict[str, RendezvousHandlerCreator]
def __init__(self) -> None:
self._registry = {}
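
A standalone sketch of the `dict[str, RendezvousHandlerCreator]` mapping annotated above, with hypothetical helper names; it is not the class's actual registration API.

```python
from typing import Callable

# Stand-ins for the real types; only the mapping shape is taken from the hunk.
class Params: ...
class Handler: ...

Creator = Callable[[Params], Handler]
_registry: dict[str, Creator] = {}

def register(scheme: str, creator: Creator) -> None:
    _registry[scheme] = creator

def create_handler(scheme: str, params: Params) -> Handler:
    return _registry[scheme](params)

register("c10d", lambda params: Handler())
handler = create_handler("c10d", Params())
```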

View File

@ -17,7 +17,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Optional
import torch.distributed as dist
from torch.distributed import Store
@ -298,10 +298,10 @@ class _RendezvousState:
complete: bool
deadline: Optional[datetime]
closed: bool
participants: Dict[_NodeDesc, int]
wait_list: Set[_NodeDesc]
redundancy_list: Set[_NodeDesc]
last_heartbeats: Dict[_NodeDesc, datetime]
participants: dict[_NodeDesc, int]
wait_list: set[_NodeDesc]
redundancy_list: set[_NodeDesc]
last_heartbeats: dict[_NodeDesc, datetime]
def __init__(self) -> None:
self.round = 0
@ -377,7 +377,7 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
_token: Token
_dirty: bool
_last_sync_time: float
_dead_nodes: List[_NodeDesc]
_dead_nodes: list[_NodeDesc]
def __init__(
self,

View File

@ -13,20 +13,20 @@ import time
import weakref
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Optional, Union
__all__ = ["parse_rendezvous_endpoint"]
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
"""Extract key-value pairs from a rendezvous configuration string.
Args:
config_str:
A string in format <key1>=<value1>,...,<keyN>=<valueN>.
"""
config: Dict[str, str] = {}
config: dict[str, str] = {}
config_str = config_str.strip()
if not config_str:
@ -196,7 +196,7 @@ class _PeriodicTimer:
interval: float
function: Callable[..., None]
args: tuple[Any, ...]
kwargs: Dict[str, Any]
kwargs: dict[str, Any]
stop_event: Event
_name: Optional[str]
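
Looping back to the `_parse_rendezvous_config` docstring earlier in this file: a hedged, standalone re-implementation of the described `<key>=<value>` comma-separated format; the real helper's error handling may differ.

```python
def parse_config(config_str: str) -> dict[str, str]:
    config: dict[str, str] = {}
    config_str = config_str.strip()
    if not config_str:
        return config
    for pair in config_str.split(","):
        key, _, value = pair.partition("=")
        config[key.strip()] = value.strip()
    return config

assert parse_config("join_timeout=600,last_call_timeout=30") == {
    "join_timeout": "600",
    "last_call_timeout": "30",
}
```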

View File

@ -10,7 +10,7 @@ import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional
__all__ = [
@ -103,7 +103,7 @@ class RequestQueue(abc.ABC):
"""
@abc.abstractmethod
def get(self, size: int, timeout: float) -> List[TimerRequest]:
def get(self, size: int, timeout: float) -> list[TimerRequest]:
"""
Gets up to ``size`` number of timer requests in a blocking fashion
(no more than ``timeout`` seconds).
@ -134,7 +134,7 @@ class TimerServer(abc.ABC):
self._stop_signaled = False
@abc.abstractmethod
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
def register_timers(self, timer_requests: list[TimerRequest]) -> None:
"""
Processes the incoming timer requests and registers them with the server.
The timer request can either be an acquire-timer or a release-timer request.
@ -143,13 +143,13 @@ class TimerServer(abc.ABC):
"""
@abc.abstractmethod
def clear_timers(self, worker_ids: Set[Any]) -> None:
def clear_timers(self, worker_ids: set[Any]) -> None:
"""
Clears all timers for the given ``worker_ids``.
"""
@abc.abstractmethod
def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
def get_expired_timers(self, deadline: float) -> dict[str, list[TimerRequest]]:
"""
Returns all expired timers for each worker_id. An expired timer
is a timer for which the expiration_time is less than or equal to
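
A hedged, standalone sketch of the grouping contract described above: collect every timer whose expiration_time is at or before the deadline, keyed by worker_id. The record type here is illustrative; the real `TimerRequest` carries more fields.

```python
from dataclasses import dataclass

@dataclass
class Req:  # illustrative stand-in for TimerRequest
    worker_id: str
    scope_id: str
    expiration_time: float

def get_expired(timers: list[Req], deadline: float) -> dict[str, list[Req]]:
    expired: dict[str, list[Req]] = {}
    for request in timers:
        if request.expiration_time <= deadline:
            expired.setdefault(request.worker_id, []).append(request)
    return expired

reqs = [Req("w0", "train_step", 10.0), Req("w1", "train_step", 99.0)]
print(get_expired(reqs, deadline=20.0))  # only w0's timer has expired
```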

View File

@ -7,7 +7,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
from torch.distributed.elastic.utils.logging import get_logger
@ -19,7 +18,7 @@ __all__ = ["log_debug_info_for_expired_timers"]
def log_debug_info_for_expired_timers(
run_id: str,
expired_timers: Dict[int, List[str]],
expired_timers: dict[int, list[str]],
):
if expired_timers:
logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)

View File

@ -13,7 +13,7 @@ import signal
import sys
import threading
import time
from typing import Callable, Dict, List, Optional, Set
from typing import Callable, Optional
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
from torch.distributed.elastic.timer.debug_info_logging import (
@ -201,7 +201,7 @@ class FileTimerServer:
self._run_id = run_id
self._max_interval = max_interval
self._daemon = daemon
self._timers: Dict[tuple[int, str], FileTimerRequest] = {}
self._timers: dict[tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None
@ -354,12 +354,12 @@ class FileTimerServer:
self.clear_timers(reaped_worker_pids)
def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
def _get_scopes(self, timer_requests: list[FileTimerRequest]) -> list[str]:
return [r.scope_id for r in timer_requests]
def _get_requests(
self, fd: io.TextIOWrapper, max_interval: float
) -> List[FileTimerRequest]:
) -> list[FileTimerRequest]:
start = time.time()
requests = []
while not self._stop_signaled or self._run_once:
@ -394,7 +394,7 @@ class FileTimerServer:
break
return requests
def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
def register_timers(self, timer_requests: list[FileTimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_pid
scope_id = request.scope_id
@ -409,14 +409,14 @@ class FileTimerServer:
else:
self._timers[key] = request
def clear_timers(self, worker_pids: Set[int]) -> None:
def clear_timers(self, worker_pids: set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_pids or not FileTimerServer.is_process_running(pid):
del self._timers[(pid, scope_id)]
def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
def get_expired_timers(self, deadline: float) -> dict[int, list[FileTimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[int, List[FileTimerRequest]] = {}
expired_timers: dict[int, list[FileTimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_pid, [])

View File

@ -10,7 +10,7 @@ import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set
from typing import Any
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
@ -56,7 +56,7 @@ class MultiprocessingRequestQueue(RequestQueue):
def size(self) -> int:
return self._mp_queue.qsize()
def get(self, size, timeout: float) -> List[TimerRequest]:
def get(self, size, timeout: float) -> list[TimerRequest]:
requests = []
wait = timeout
for _ in range(0, size):
@ -88,9 +88,9 @@ class LocalTimerServer(TimerServer):
self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
):
super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
self._timers: Dict[tuple[Any, str], TimerRequest] = {}
self._timers: dict[tuple[Any, str], TimerRequest] = {}
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
def register_timers(self, timer_requests: list[TimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_id
scope_id = request.scope_id
@ -102,14 +102,14 @@ class LocalTimerServer(TimerServer):
else:
self._timers[(pid, scope_id)] = request
def clear_timers(self, worker_ids: Set[int]) -> None:
def clear_timers(self, worker_ids: set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_ids:
self._timers.pop((pid, scope_id))
def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[Any, List[TimerRequest]] = {}
expired_timers: dict[Any, list[TimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_id, [])

View File

@ -9,7 +9,7 @@
import os
import socket
from string import Template
from typing import Any, List
from typing import Any
def get_env_variable_or_raise(env_name: str) -> str:
@ -51,7 +51,7 @@ class macros:
local_rank = "${local_rank}"
@staticmethod
def substitute(args: List[Any], local_rank: str) -> List[str]:
def substitute(args: list[Any], local_rank: str) -> list[str]:
args_sub = []
for arg in args:
if isinstance(arg, str):
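
A standalone sketch of the `${local_rank}` substitution that `macros.substitute` performs above, using `string.Template` as imported in this file; whether the real helper uses `substitute` or `safe_substitute` is not visible in this hunk.

```python
from string import Template
from typing import Any

def substitute(args: list[Any], local_rank: str) -> list[str]:
    args_sub = []
    for arg in args:
        if isinstance(arg, str):
            args_sub.append(Template(arg).safe_substitute(local_rank=local_rank))
        else:
            args_sub.append(arg)
    return args_sub

print(substitute(["--local-rank", "${local_rank}"], "3"))  # ['--local-rank', '3']
```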

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
from typing import Callable, Iterator, TypeVar
from collections.abc import Iterator
from typing import Callable, TypeVar
from typing_extensions import Self

View File

@ -7,9 +7,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Iterable
from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Iterable, List, Optional
from typing import Callable, Optional
import torch
@ -85,7 +86,7 @@ def synchronize(
world_size: int,
key_prefix: str,
timeout: float = 300,
) -> List[bytes]:
) -> list[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
The ``data`` will be available on each of the agents.
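
A hedged usage sketch for `synchronize`; only `world_size`, `key_prefix`, `timeout`, and the `list[bytes]` return appear in this hunk, so the leading parameters (the c10d store, this agent's payload bytes, and its rank) as well as the module path and store setup are assumptions.

```python
import torch.distributed as dist
from torch.distributed.elastic.utils.store import synchronize  # assumed module path

rank, world_size = 0, 2  # each agent runs this with its own rank
store = dist.TCPStore("localhost", 29500, world_size=world_size, is_master=(rank == 0))

payload = f"agent-{rank}".encode()
all_payloads: list[bytes] = synchronize(
    store, payload, rank, world_size=world_size, key_prefix="torchelastic/sync", timeout=300
)
# all_payloads holds the bytes contributed by every one of the world_size agents.
```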