mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
c64e657632
commit
316808e4e9
@ -1,5 +1,5 @@
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint.state_dict_loader as loader
|
||||
@ -13,7 +13,7 @@ from torch.distributed.checkpoint.storage import (
|
||||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
class _Checkpointer:
|
||||
@ -90,7 +90,7 @@ class _Checkpointer:
|
||||
planner=self.save_planner,
|
||||
)
|
||||
|
||||
def load(self, state_dict: Dict[str, Any]) -> None:
|
||||
def load(self, state_dict: dict[str, Any]) -> None:
|
||||
"""Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
|
||||
loader.load(
|
||||
state_dict,
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import dataclasses
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.distributed.checkpoint.planner import SavePlan, WriteItem
|
||||
|
||||
@ -13,16 +13,16 @@ __all__ = ["dedup_save_plans"]
|
||||
|
||||
|
||||
def dedup_save_plans(
|
||||
all_plans: List[SavePlan],
|
||||
all_plans: list[SavePlan],
|
||||
save_to_lowest_rank: bool = False,
|
||||
) -> List[SavePlan]:
|
||||
) -> list[SavePlan]:
|
||||
"""
|
||||
Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
|
||||
a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
|
||||
"""
|
||||
|
||||
write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
|
||||
write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
|
||||
write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
|
||||
write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
|
||||
for plan_idx, plan in enumerate(all_plans):
|
||||
for write_item in plan.items:
|
||||
# map each write item to its plan
|
||||
@ -30,7 +30,7 @@ def dedup_save_plans(
|
||||
write_item_idx_to_write_item[write_item.index] = write_item
|
||||
|
||||
# put item in the plan with the smallest size and remove it from the other plan_indices
|
||||
to_remove: List[Set] = [set() for _ in range(len(all_plans))]
|
||||
to_remove: list[set] = [set() for _ in range(len(all_plans))]
|
||||
plan_to_size = [0] * len(all_plans)
|
||||
for write_item_idx, plan_indices in write_item_to_plan_indices.items():
|
||||
if save_to_lowest_rank:
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.distributed.checkpoint.planner import SavePlan
|
||||
|
||||
@ -31,9 +31,9 @@ logger = init_logger()
|
||||
|
||||
|
||||
# TODO add docstring for dedup_tensors
|
||||
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
|
||||
def dedup_tensors(all_plans: list[SavePlan]) -> list[SavePlan]:
|
||||
all_plans = list(all_plans)
|
||||
key_to_plan: Dict[MetadataIndex, List[int]] = {}
|
||||
key_to_plan: dict[MetadataIndex, list[int]] = {}
|
||||
for plan_idx, plan in enumerate(all_plans):
|
||||
for write_item in plan.items:
|
||||
key_to_plan.setdefault(write_item.index, []).append(plan_idx)
|
||||
@ -42,7 +42,7 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
|
||||
|
||||
# Remove duplicates by always keeping the first entry.
|
||||
# Compute the per-rank remove set.
|
||||
plan_to_keys: Dict[int, List[MetadataIndex]] = {}
|
||||
plan_to_keys: dict[int, list[MetadataIndex]] = {}
|
||||
for key, plans in replicated_items.items():
|
||||
for plan_idx in plans[1:]:
|
||||
plan_to_keys.setdefault(plan_idx, []).append(key)
|
||||
|
@ -3,10 +3,10 @@
|
||||
|
||||
import io
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Generator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
from fsspec.core import url_to_fs
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import Dict
|
||||
|
||||
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
|
||||
|
||||
@ -21,7 +20,7 @@ Change set_element to recreate the right type for tuple, OrderedDict, and NamedT
|
||||
"""
|
||||
|
||||
|
||||
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
|
||||
FLATTEN_MAPPING = dict[str, OBJ_PATH]
|
||||
|
||||
|
||||
# TODO: Update Docstring for nested_dict.py
|
||||
|
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import List, Type, Union
|
||||
from typing import Union
|
||||
|
||||
from .filesystem import FileSystemReader, FileSystemWriter
|
||||
from .storage import StorageReader, StorageWriter
|
||||
@ -21,7 +21,7 @@ def _storage_setup(
|
||||
"storage_reader/storage_writer is None."
|
||||
)
|
||||
|
||||
targets: List[Type[Union[StorageReader, StorageWriter]]] = []
|
||||
targets: list[type[Union[StorageReader, StorageWriter]]] = []
|
||||
if reader:
|
||||
targets = [
|
||||
FileSystemReader,
|
||||
|
@ -1,15 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import (
|
||||
Callable,
|
||||
cast,
|
||||
Collection,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Collection, Mapping, MutableMapping
|
||||
from typing import Callable, cast, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
||||
@ -123,7 +114,7 @@ def set_element(
|
||||
"""Set ``value`` in ``root_dict`` along the ``path`` object path."""
|
||||
cur_container = cast(CONTAINER_TYPE, root_dict)
|
||||
|
||||
def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
|
||||
def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None:
|
||||
while len(lst) <= idx:
|
||||
lst.append(None)
|
||||
|
||||
@ -144,7 +135,7 @@ def set_element(
|
||||
|
||||
key = path[-1]
|
||||
if type(key) == int:
|
||||
extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
|
||||
extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
|
||||
|
||||
cur_container[key] = value
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import traceback as tb
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
|
||||
WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary]
|
||||
@ -22,12 +22,12 @@ def _is_wrapped_exception(obj: Any) -> bool:
|
||||
class CheckpointException(BaseException):
|
||||
"""Exception raised if failure was detected as part of a checkpoint load or save."""
|
||||
|
||||
def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
|
||||
def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
|
||||
super().__init__(msg, failures)
|
||||
self._failures = failures
|
||||
|
||||
@property
|
||||
def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
|
||||
def failures(self) -> dict[int, WRAPPED_EXCEPTION]:
|
||||
"""Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
|
||||
return self._failures
|
||||
|
||||
|
@ -7,7 +7,7 @@ import logging
|
||||
import operator
|
||||
from collections import ChainMap
|
||||
from functools import reduce
|
||||
from typing import Any, cast, Dict, List, Optional, Union
|
||||
from typing import Any, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard._utils import narrow_tensor_by_index
|
||||
@ -106,8 +106,8 @@ class DefaultSavePlanner(SavePlanner):
|
||||
return self.plan
|
||||
|
||||
def create_global_plan(
|
||||
self, all_plans: List[SavePlan]
|
||||
) -> tuple[List[SavePlan], Metadata]:
|
||||
self, all_plans: list[SavePlan]
|
||||
) -> tuple[list[SavePlan], Metadata]:
|
||||
all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
|
||||
|
||||
global_plan, metadata = create_default_global_save_plan(all_plans)
|
||||
@ -234,7 +234,7 @@ class DefaultLoadPlanner(LoadPlanner):
|
||||
self.state_dict, self.metadata, not self.allow_partial_load
|
||||
)
|
||||
|
||||
def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
|
||||
def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
|
||||
return create_default_global_load_plan(global_plan)
|
||||
|
||||
def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
|
||||
@ -293,7 +293,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
|
||||
if key in self.keys:
|
||||
True
|
||||
|
||||
unflattened_keys: List[str] = []
|
||||
unflattened_keys: list[str] = []
|
||||
planner_data = metadata.planner_data.get(key)
|
||||
for unflattened_key in planner_data:
|
||||
if unflattened_keys:
|
||||
@ -334,7 +334,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
|
||||
|
||||
|
||||
def create_default_local_load_plan(
|
||||
state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True
|
||||
state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
|
||||
) -> LoadPlan:
|
||||
requests = []
|
||||
"""
|
||||
@ -376,8 +376,8 @@ def create_default_local_load_plan(
|
||||
|
||||
|
||||
def create_default_global_load_plan(
|
||||
all_plans: List[LoadPlan],
|
||||
) -> List[LoadPlan]:
|
||||
all_plans: list[LoadPlan],
|
||||
) -> list[LoadPlan]:
|
||||
"""
|
||||
Create global load plan used by DefaultLoadPlanner.
|
||||
|
||||
@ -388,7 +388,7 @@ def create_default_global_load_plan(
|
||||
|
||||
|
||||
def create_default_local_save_plan(
|
||||
state_dict: Dict[str, Any], is_coordinator: bool
|
||||
state_dict: dict[str, Any], is_coordinator: bool
|
||||
) -> SavePlan:
|
||||
"""
|
||||
Create the ``SavePlan`` used by DefaultSavePlanner.
|
||||
@ -415,9 +415,9 @@ def create_default_local_save_plan(
|
||||
|
||||
|
||||
def create_default_global_save_plan(
|
||||
all_plans: List[SavePlan],
|
||||
all_plans: list[SavePlan],
|
||||
rewrite_index_hints: bool = True,
|
||||
) -> tuple[List[SavePlan], Metadata]:
|
||||
) -> tuple[list[SavePlan], Metadata]:
|
||||
"""
|
||||
Create the global plan and metadata used by DefaultSavePlanner.
|
||||
|
||||
@ -426,7 +426,7 @@ def create_default_global_save_plan(
|
||||
The only global planning change is to update index hints in all ``MetadataIndex`` objects if
|
||||
``rewrite_index_hints`` is True.
|
||||
"""
|
||||
md: Dict[str, STORAGE_TYPES] = {}
|
||||
md: dict[str, STORAGE_TYPES] = {}
|
||||
new_plans = []
|
||||
for plan in all_plans:
|
||||
new_items = []
|
||||
@ -506,7 +506,7 @@ def _check_box_bounds(
|
||||
return True
|
||||
|
||||
|
||||
def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
|
||||
def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
|
||||
all_good = True
|
||||
for key, value in metadata.state_dict_metadata.items():
|
||||
if isinstance(value, BytesStorageMetadata):
|
||||
|
@ -10,24 +10,12 @@ import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Generator, Iterable, Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from io import UnsupportedOperation
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
IO,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, IO, Optional, Union
|
||||
|
||||
# introduced as collections.abc.Buffer in Python 3.12
|
||||
from typing_extensions import Buffer
|
||||
@ -113,7 +101,7 @@ class _TensorLoader(ABC):
|
||||
class _SerialCpuLoader(_TensorLoader):
|
||||
def __init__(self, resolve_fun: Callable) -> None:
|
||||
self.resolve_fun = resolve_fun
|
||||
self.items: List[tuple[int, object]] = []
|
||||
self.items: list[tuple[int, object]] = []
|
||||
|
||||
def add(self, size: int, obj: object) -> None:
|
||||
self.items.append((size, obj))
|
||||
@ -141,7 +129,7 @@ class _OverlappingCpuLoader(_TensorLoader):
|
||||
inflight_threshhold: int = 1_000_000,
|
||||
) -> None:
|
||||
self.resolve_fun = resolve_fun
|
||||
self.items: List[tuple[int, object]] = []
|
||||
self.items: list[tuple[int, object]] = []
|
||||
self.inflight_threshhold = inflight_threshhold
|
||||
self.in_flight_data = 0
|
||||
self.current_items: collections.deque = collections.deque()
|
||||
@ -161,7 +149,7 @@ class _OverlappingCpuLoader(_TensorLoader):
|
||||
def _done(self) -> bool:
|
||||
return self.idx >= len(self.items)
|
||||
|
||||
def _drain(self) -> List[tuple[torch.Tensor, object]]:
|
||||
def _drain(self) -> list[tuple[torch.Tensor, object]]:
|
||||
drained = []
|
||||
if self.in_flight_data >= self.inflight_threshhold:
|
||||
self.stream.synchronize()
|
||||
@ -243,7 +231,7 @@ class _StorageWriterTransforms:
|
||||
|
||||
def transform_save_stream(
|
||||
self, write_item: WriteItem, raw_stream: io.IOBase
|
||||
) -> tuple[IO[bytes], List[str]]:
|
||||
) -> tuple[IO[bytes], list[str]]:
|
||||
# In order to avoid leaking fds, transformers' close must
|
||||
# cascade to wrapped streams, but since this function can
|
||||
# append to the raw stream, we can't close the actual stream.
|
||||
@ -285,14 +273,14 @@ def _item_size(item: WriteItem) -> int:
|
||||
return size * torch._utils._element_size(dtype)
|
||||
|
||||
|
||||
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
|
||||
def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]:
|
||||
if bins == 1:
|
||||
return [items]
|
||||
|
||||
bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
|
||||
tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
|
||||
|
||||
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
|
||||
buckets: list[list[WriteItem]] = [[] for _ in range(bins)]
|
||||
bucket_sizes = [0 for _ in range(bins)]
|
||||
|
||||
tensor_w.sort(key=_item_size, reverse=True)
|
||||
@ -584,7 +572,7 @@ class _FileSystemWriter(StorageWriter):
|
||||
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
|
||||
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
|
||||
new_plans = [
|
||||
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
|
||||
for i, plan in enumerate(plans)
|
||||
@ -595,7 +583,7 @@ class _FileSystemWriter(StorageWriter):
|
||||
self,
|
||||
plan: SavePlan,
|
||||
planner: SavePlanner,
|
||||
) -> Future[List[WriteResult]]:
|
||||
) -> Future[list[WriteResult]]:
|
||||
storage_plan: _StoragePrefix = plan.storage_data
|
||||
file_count = 0
|
||||
|
||||
@ -656,11 +644,11 @@ class _FileSystemWriter(StorageWriter):
|
||||
while True:
|
||||
res += result_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
fut: Future[List[WriteResult]] = Future()
|
||||
fut: Future[list[WriteResult]] = Future()
|
||||
fut.set_result(res)
|
||||
return fut
|
||||
|
||||
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
|
||||
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
|
||||
storage_md = {}
|
||||
for wr_list in results:
|
||||
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
|
||||
@ -737,7 +725,7 @@ class FileSystemReader(StorageReader):
|
||||
super().__init__()
|
||||
self.fs = FileSystem()
|
||||
self.path = self.fs.init_path(path)
|
||||
self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
|
||||
self.storage_data: dict[MetadataIndex, _StorageInfo] = {}
|
||||
self.load_id = _generate_uuid()
|
||||
self.transforms = _StorageReaderTransforms(_extension_registry)
|
||||
|
||||
@ -752,7 +740,7 @@ class FileSystemReader(StorageReader):
|
||||
|
||||
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
|
||||
# group requests by file
|
||||
per_file: Dict[str, List[ReadItem]] = {}
|
||||
per_file: dict[str, list[ReadItem]] = {}
|
||||
for read_item in plan.items:
|
||||
item_md = self.storage_data[read_item.storage_index]
|
||||
path = item_md.relative_path
|
||||
@ -828,7 +816,7 @@ class FileSystemReader(StorageReader):
|
||||
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
|
||||
def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
|
||||
return plans
|
||||
|
||||
@property
|
||||
|
@ -2,7 +2,7 @@
|
||||
import argparse
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import cast, Dict, List, Optional, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -133,7 +133,7 @@ class BroadcastingTorchSaveReader(StorageReader):
|
||||
"""Implementation of the StorageReader method"""
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
|
||||
def prepare_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
|
||||
"""Implementation of the StorageReader method"""
|
||||
return global_plan
|
||||
|
||||
@ -177,7 +177,7 @@ class DynamicMetaLoadPlanner(DefaultLoadPlanner):
|
||||
"""Setups of the planner, extnding default behavior by creating the Metadata object from the state dict"""
|
||||
super().set_up_planner(state_dict, metadata, is_coordinator)
|
||||
|
||||
state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
|
||||
state_dict_metadata: dict[str, STORAGE_TYPES] = {}
|
||||
for key, tensor in self.state_dict.items():
|
||||
if not torch.is_tensor(tensor):
|
||||
raise RuntimeError(
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, TypeVar
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
from uuid import uuid4
|
||||
|
||||
@ -9,7 +9,7 @@ import torch.distributed.c10d_logger as c10d_logger
|
||||
from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
global _dcp_logger
|
||||
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
|
||||
@ -18,7 +18,7 @@ _T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
|
||||
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Extracts log data from dcp method args
|
||||
"""
|
||||
@ -52,7 +52,7 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
|
||||
return msg_dict
|
||||
|
||||
|
||||
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
|
||||
def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
|
||||
msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
|
||||
msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs))
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from torch.distributed.logging_handlers import _log_handlers
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
DCP_LOGGER_NAME = "dcp_logger"
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.checkpoint.stateful import StatefulT
|
||||
@ -113,7 +114,7 @@ class TensorProperties:
|
||||
class TensorStorageMetadata:
|
||||
properties: TensorProperties
|
||||
size: torch.Size
|
||||
chunks: List[ChunkStorageMetadata]
|
||||
chunks: list[ChunkStorageMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -122,7 +123,7 @@ class BytesStorageMetadata:
|
||||
|
||||
|
||||
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
|
||||
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
|
||||
STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -137,7 +138,7 @@ class Metadata:
|
||||
"""This class represents the metadata of the checkpoint."""
|
||||
|
||||
# Keys are the same from the `state_dict` used.
|
||||
state_dict_metadata: Dict[str, STORAGE_TYPES]
|
||||
state_dict_metadata: dict[str, STORAGE_TYPES]
|
||||
# It is the responsibility of the planner and storage plugins to ensure
|
||||
# backward compatibility of the planner_data and storage_data. DCP will
|
||||
# also ensure the backward compatibility of the metadata in this file and
|
||||
|
@ -1,7 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
import dataclasses
|
||||
from typing import cast, Dict, List, Optional, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -41,7 +42,7 @@ from torch.distributed.remote_device import _remote_device
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
|
||||
STATE_DICT_2D_LAYOUT = Dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]
|
||||
STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]
|
||||
|
||||
|
||||
# TODO: Update docstrings for optimizer.py
|
||||
@ -77,7 +78,7 @@ def _create_colwise_spec(
|
||||
]
|
||||
return ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=cast(List[Union[_remote_device, str]], placements),
|
||||
placements=cast(list[Union[_remote_device, str]], placements),
|
||||
)
|
||||
|
||||
|
||||
@ -154,11 +155,11 @@ def _get_state_dict_2d_layout(
|
||||
|
||||
|
||||
class _ReaderWithOffset(DefaultLoadPlanner):
|
||||
translation: Dict[MetadataIndex, MetadataIndex]
|
||||
translation: dict[MetadataIndex, MetadataIndex]
|
||||
state_dict: STATE_DICT_TYPE
|
||||
metadata: Metadata
|
||||
|
||||
def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
|
||||
def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None:
|
||||
super().__init__()
|
||||
self.fqn_to_offset = fqn_to_offset
|
||||
self.metadata = Metadata({})
|
||||
@ -284,7 +285,7 @@ def load_sharded_optimizer_state_dict(
|
||||
# Create a state_dict for optimizer state
|
||||
state_dict: STATE_DICT_TYPE = {}
|
||||
|
||||
fqn_to_offset: Dict[str, Sequence[int]] = {}
|
||||
fqn_to_offset: dict[str, Sequence[int]] = {}
|
||||
for key, value in metadata.state_dict_metadata.items():
|
||||
key_path = metadata.planner_data[key]
|
||||
if key_path[0] != optimizer_key:
|
||||
|
@ -4,7 +4,7 @@ import operator
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from functools import reduce
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
@ -94,14 +94,14 @@ class ReadItem:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SavePlan:
|
||||
items: List[WriteItem]
|
||||
items: list[WriteItem]
|
||||
storage_data: Any = None
|
||||
planner_data: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadPlan:
|
||||
items: List[ReadItem]
|
||||
items: list[ReadItem]
|
||||
storage_data: Any = None
|
||||
planner_data: Any = None
|
||||
|
||||
@ -231,8 +231,8 @@ class SavePlanner(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_global_plan(
|
||||
self, all_plans: List[SavePlan]
|
||||
) -> tuple[List[SavePlan], Metadata]:
|
||||
self, all_plans: list[SavePlan]
|
||||
) -> tuple[list[SavePlan], Metadata]:
|
||||
"""
|
||||
Compute the global checkpoint plan and return the local plan of each rank.
|
||||
|
||||
@ -364,7 +364,7 @@ class LoadPlanner:
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
|
||||
def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
|
||||
"""
|
||||
Compute the global load plan and return plans for each rank.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import io
|
||||
from typing import Any, Callable, cast, Dict, List
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -33,7 +33,7 @@ from .resharding import (
|
||||
)
|
||||
|
||||
|
||||
__all__: List[str] = ["create_read_items_for_chunk_list"]
|
||||
__all__: list[str] = ["create_read_items_for_chunk_list"]
|
||||
|
||||
|
||||
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
|
||||
@ -149,8 +149,8 @@ def _create_read_item_for_tensor(
|
||||
def create_read_items_for_chunk_list(
|
||||
fqn: str,
|
||||
checkpoint_md: TensorStorageMetadata,
|
||||
local_chunks: List[ChunkStorageMetadata],
|
||||
) -> List[ReadItem]:
|
||||
local_chunks: list[ChunkStorageMetadata],
|
||||
) -> list[ReadItem]:
|
||||
"""
|
||||
Create a list of ``ReadItem`` based on the checkpoint and local chunks.
|
||||
|
||||
@ -218,7 +218,7 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
|
||||
return SavePlan(requests)
|
||||
|
||||
|
||||
def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
|
||||
def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
|
||||
if hasattr(object, "__create_write_items__"):
|
||||
# DTensor implements _Checkpointable
|
||||
return object.__create_write_items__(fqn, object)
|
||||
@ -244,7 +244,7 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
|
||||
)
|
||||
|
||||
|
||||
def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
|
||||
def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
|
||||
if hasattr(tensor, "__create_chunk_list__"):
|
||||
# DTensor implements _Checkpointable
|
||||
local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined]
|
||||
@ -263,7 +263,7 @@ def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
|
||||
return local_chunks
|
||||
|
||||
|
||||
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
|
||||
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
|
||||
if not isinstance(md, BytesStorageMetadata):
|
||||
try:
|
||||
local_chunks = _create_chunk_list(obj)
|
||||
@ -286,7 +286,7 @@ def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
|
||||
]
|
||||
|
||||
|
||||
def _init_state_dict(state_dict: Dict[str, Any]) -> Any:
|
||||
def _init_state_dict(state_dict: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Initializes meta tensor if the meta tensor is DTensor or torch.Tensor.
|
||||
"""
|
||||
|
@ -1,10 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List
|
||||
|
||||
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
def _check_shard_metadata_pair_overlap(
|
||||
@ -27,7 +26,7 @@ def _check_shard_metadata_pair_overlap(
|
||||
|
||||
def _shards_get_overlap_region_wrt_saved_tensor(
|
||||
saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
|
||||
) -> List[tuple[int, int, int, int]]:
|
||||
) -> list[tuple[int, int, int, int]]:
|
||||
"""
|
||||
Return the overlapping region between saved_shard and current_shard.
|
||||
|
||||
|
@ -3,21 +3,10 @@ import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
no_type_check,
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, no_type_check, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -76,17 +65,17 @@ _PG = "param_groups"
|
||||
_PARAMS = "params"
|
||||
_STATE = "state"
|
||||
|
||||
FQNS_T = Set[str]
|
||||
FQNS_T = set[str]
|
||||
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
|
||||
ValueType = Union[
|
||||
PrimitiveType, List[PrimitiveType], tuple[PrimitiveType], Dict[str, "ValueType"]
|
||||
PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"]
|
||||
]
|
||||
DictValueType = Dict[str, ValueType]
|
||||
ListDictValueType = List[DictValueType]
|
||||
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]
|
||||
DictValueType = dict[str, ValueType]
|
||||
ListDictValueType = list[DictValueType]
|
||||
OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]]
|
||||
|
||||
|
||||
_patched_state_dict: Set[Callable] = set()
|
||||
_patched_state_dict: set[Callable] = set()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -149,20 +138,20 @@ class StateDictOptions:
|
||||
|
||||
@dataclass
|
||||
class _StateDictInfo(StateDictOptions):
|
||||
fqn_param_mapping: Dict[
|
||||
fqn_param_mapping: dict[
|
||||
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
|
||||
] = field(default_factory=dict)
|
||||
shared_params_mapping: Dict[
|
||||
shared_params_mapping: dict[
|
||||
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
|
||||
] = field(default_factory=dict)
|
||||
submodule_prefixes: Set[str] = field(default_factory=set)
|
||||
submodule_prefixes: set[str] = field(default_factory=set)
|
||||
handle_model: bool = True
|
||||
handle_optim: bool = True
|
||||
fsdp_context: Callable = contextlib.nullcontext
|
||||
fsdp_modules: List[nn.Module] = field(default_factory=list)
|
||||
fsdp_modules: list[nn.Module] = field(default_factory=list)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _get_fqns(
|
||||
model: nn.Module,
|
||||
name: str,
|
||||
@ -230,7 +219,7 @@ class _EXTRA_STATE:
|
||||
|
||||
|
||||
def _iterate_valid_model_state(model):
|
||||
visited_modules: Set[nn.Module] = set()
|
||||
visited_modules: set[nn.Module] = set()
|
||||
|
||||
def recurse(module: nn.Module, curr_fqn: str) -> Generator:
|
||||
visited_modules.add(module)
|
||||
@ -265,7 +254,7 @@ def _verify_options(
|
||||
optims: tuple[torch.optim.Optimizer, ...],
|
||||
optim_only: bool,
|
||||
*,
|
||||
submodules: Optional[Set[nn.Module]] = None,
|
||||
submodules: Optional[set[nn.Module]] = None,
|
||||
options: Optional[StateDictOptions] = None,
|
||||
) -> _StateDictInfo:
|
||||
"""
|
||||
@ -285,11 +274,11 @@ def _verify_options(
|
||||
|
||||
options = options or StateDictOptions()
|
||||
|
||||
fqn_param_mapping: Dict[
|
||||
Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
|
||||
fqn_param_mapping: dict[
|
||||
Union[str, torch.Tensor], Union[set[str], torch.Tensor]
|
||||
] = {}
|
||||
shared_params_mapping: Dict[
|
||||
Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
|
||||
shared_params_mapping: dict[
|
||||
Union[str, torch.Tensor], Union[set[str], torch.Tensor]
|
||||
] = {}
|
||||
for name, param in _iterate_valid_model_state(model):
|
||||
if isinstance(param, _EXTRA_STATE):
|
||||
@ -298,7 +287,7 @@ def _verify_options(
|
||||
fqns = _get_fqns(model, name)
|
||||
fqn = fqn_param_mapping.get(param, None)
|
||||
if fqn is not None:
|
||||
cast(Set[str], fqn_param_mapping[param]).update(fqns)
|
||||
cast(set[str], fqn_param_mapping[param]).update(fqns)
|
||||
shared_params_mapping[param] = fqn_param_mapping[param]
|
||||
else:
|
||||
# We need to do copy as _get_fqns is lru_cached
|
||||
@ -311,7 +300,7 @@ def _verify_options(
|
||||
for fqn in fqns_:
|
||||
shared_params_mapping[fqn] = cast(torch.Tensor, param_)
|
||||
|
||||
submodule_prefixes: Set[str] = set()
|
||||
submodule_prefixes: set[str] = set()
|
||||
if submodules:
|
||||
submodules = set(submodules)
|
||||
for name, module in model.named_modules():
|
||||
@ -384,14 +373,14 @@ def _verify_options(
|
||||
shared_params_mapping=shared_params_mapping,
|
||||
submodule_prefixes=submodule_prefixes,
|
||||
fsdp_context=fsdp_context,
|
||||
fsdp_modules=cast(List[nn.Module], fsdp_modules),
|
||||
fsdp_modules=cast(list[nn.Module], fsdp_modules),
|
||||
handle_model=not optim_only,
|
||||
handle_optim=(len(optims) > 0),
|
||||
)
|
||||
|
||||
|
||||
def _verify_state_dict(
|
||||
model_state_dict: Dict[str, ValueType],
|
||||
model_state_dict: dict[str, ValueType],
|
||||
optim_state_dict: OptimizerStateType,
|
||||
info: _StateDictInfo,
|
||||
) -> None:
|
||||
@ -443,8 +432,8 @@ def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Ca
|
||||
|
||||
|
||||
def _maybe_full_or_cpu_state_dict(
|
||||
state_dict: Dict[str, Any], info: _StateDictInfo
|
||||
) -> Dict[str, Any]:
|
||||
state_dict: dict[str, Any], info: _StateDictInfo
|
||||
) -> dict[str, Any]:
|
||||
if info.full_state_dict:
|
||||
ranks_only = (
|
||||
()
|
||||
@ -463,7 +452,7 @@ def _maybe_full_or_cpu_state_dict(
|
||||
@torch.no_grad()
|
||||
def _get_model_state_dict(
|
||||
model: nn.Module, info: _StateDictInfo
|
||||
) -> Dict[str, ValueType]:
|
||||
) -> dict[str, ValueType]:
|
||||
if not info.handle_model:
|
||||
return {}
|
||||
|
||||
@ -500,7 +489,7 @@ def _get_model_state_dict(
|
||||
state_dict[fqn] = state_dict.pop(key)
|
||||
|
||||
if info.submodule_prefixes:
|
||||
new_state_dict: Dict[str, ValueType] = {}
|
||||
new_state_dict: dict[str, ValueType] = {}
|
||||
# TODO: make this faster.
|
||||
for fqn in state_dict.keys():
|
||||
for prefix in info.submodule_prefixes:
|
||||
@ -531,7 +520,7 @@ def _get_model_state_dict(
|
||||
@torch.no_grad()
|
||||
def _load_model_state_dict(
|
||||
model: nn.Module,
|
||||
state_dict: Dict[str, ValueType],
|
||||
state_dict: dict[str, ValueType],
|
||||
info: _StateDictInfo,
|
||||
) -> _IncompatibleKeys:
|
||||
if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
|
||||
@ -636,7 +625,7 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None:
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
|
||||
def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]:
|
||||
"""
|
||||
This API flattens the optimizer state_dict to support optimizer resharding for
|
||||
MPMD, e.g., pipeline parallelism.
|
||||
@ -686,7 +675,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
|
||||
f"Type is {type(v)}."
|
||||
)
|
||||
|
||||
ret: Dict[str, ValueType] = {}
|
||||
ret: dict[str, ValueType] = {}
|
||||
for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
|
||||
for k, v in cast(DictValueType, state).items():
|
||||
_raise_if_type_not_supported(v)
|
||||
@ -694,7 +683,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
|
||||
|
||||
for param_group in cast(ListDictValueType, state_dict[_PG]):
|
||||
fqns = param_group.pop(_PARAMS)
|
||||
for fqn in cast(List[str], fqns):
|
||||
for fqn in cast(list[str], fqns):
|
||||
for k, v in param_group.items():
|
||||
ret[f"{_PG}.{fqn}.{k}"] = v
|
||||
return ret
|
||||
@ -702,7 +691,7 @@ def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, Value
|
||||
|
||||
def _unflatten_optim_state_dict(
|
||||
optim: torch.optim.Optimizer,
|
||||
state_dict: Dict[str, ValueType],
|
||||
state_dict: dict[str, ValueType],
|
||||
info: _StateDictInfo,
|
||||
) -> OptimizerStateType:
|
||||
"""
|
||||
@ -728,7 +717,7 @@ def _unflatten_optim_state_dict(
|
||||
f"{_STATE}.{fqn}.{state_name}"
|
||||
]
|
||||
|
||||
first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0]
|
||||
first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0]
|
||||
for k in param_group.keys():
|
||||
if k == _PARAMS:
|
||||
continue
|
||||
@ -833,7 +822,7 @@ def _split_optim_state_dict(
|
||||
state: DictValueType = {}
|
||||
pg_state: ListDictValueType = []
|
||||
return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
|
||||
pg_mapping: Dict[int, int] = {}
|
||||
pg_mapping: dict[int, int] = {}
|
||||
|
||||
if all(
|
||||
isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
|
||||
@ -849,7 +838,7 @@ def _split_optim_state_dict(
|
||||
for loaded_param_group in cast(
|
||||
ListDictValueType, optim_state_dict[_PG]
|
||||
):
|
||||
if fqn in cast(List[str], loaded_param_group[_PARAMS]):
|
||||
if fqn in cast(list[str], loaded_param_group[_PARAMS]):
|
||||
in_params = True
|
||||
break
|
||||
else:
|
||||
@ -865,7 +854,7 @@ def _split_optim_state_dict(
|
||||
for loaded_param_group in cast(
|
||||
ListDictValueType, optim_state_dict[_PG]
|
||||
):
|
||||
if fqn in cast(List[str], loaded_param_group[_PARAMS]):
|
||||
if fqn in cast(list[str], loaded_param_group[_PARAMS]):
|
||||
pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
|
||||
|
||||
for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
|
||||
@ -900,7 +889,7 @@ def _load_optim_state_dict(
|
||||
)
|
||||
else:
|
||||
optim_state_dict = _unflatten_optim_state_dict(
|
||||
optim, cast(Dict[str, ValueType], state_dict), info
|
||||
optim, cast(dict[str, ValueType], state_dict), info
|
||||
)
|
||||
else:
|
||||
optim_state_dict = {}
|
||||
@ -919,7 +908,7 @@ def _load_optim_state_dict(
|
||||
fqn = fqns.pop()
|
||||
fqn_with_compiler = fqns_with_compiler.pop()
|
||||
for g in optim_state_dict[_PG]:
|
||||
val = cast(Dict[str, Any], g)
|
||||
val = cast(dict[str, Any], g)
|
||||
params = [
|
||||
key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
|
||||
]
|
||||
@ -978,9 +967,9 @@ def _load_optim_state_dict(
|
||||
def get_model_state_dict(
|
||||
model: nn.Module,
|
||||
*,
|
||||
submodules: Optional[Set[nn.Module]] = None,
|
||||
submodules: Optional[set[nn.Module]] = None,
|
||||
options: Optional[StateDictOptions] = None,
|
||||
) -> Dict[str, ValueType]:
|
||||
) -> dict[str, ValueType]:
|
||||
"""
|
||||
Return the model state_dict of ``model``.
|
||||
|
||||
@ -1016,7 +1005,7 @@ def get_optimizer_state_dict(
|
||||
model: nn.Module,
|
||||
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
|
||||
*,
|
||||
submodules: Optional[Set[nn.Module]] = None,
|
||||
submodules: Optional[set[nn.Module]] = None,
|
||||
options: Optional[StateDictOptions] = None,
|
||||
) -> OptimizerStateType:
|
||||
"""
|
||||
@ -1061,9 +1050,9 @@ def get_state_dict(
|
||||
model: nn.Module,
|
||||
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
|
||||
*,
|
||||
submodules: Optional[Set[nn.Module]] = None,
|
||||
submodules: Optional[set[nn.Module]] = None,
|
||||
options: Optional[StateDictOptions] = None,
|
||||
) -> tuple[Dict[str, ValueType], OptimizerStateType]:
|
||||
) -> tuple[dict[str, ValueType], OptimizerStateType]:
|
||||
"""
|
||||
Return the model state_dict and optimizers state_dict.
|
||||
|
||||
@ -1148,8 +1137,8 @@ def get_state_dict(
|
||||
|
||||
def _unflatten_model_state_dict(
|
||||
model: nn.Module,
|
||||
state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
|
||||
) -> Dict[str, ValueType]:
|
||||
state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]],
|
||||
) -> dict[str, ValueType]:
|
||||
if not state_dict:
|
||||
return {}
|
||||
|
||||
@ -1161,8 +1150,8 @@ def _unflatten_model_state_dict(
|
||||
"same functionality.",
|
||||
FutureWarning,
|
||||
)
|
||||
cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
|
||||
new_state_dict: Dict[str, ValueType] = {}
|
||||
cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)
|
||||
new_state_dict: dict[str, ValueType] = {}
|
||||
for submodule, sub_state_dict in cast_state_dict.items():
|
||||
for name, m in model.named_modules():
|
||||
if m != submodule:
|
||||
@ -1176,12 +1165,12 @@ def _unflatten_model_state_dict(
|
||||
)
|
||||
return new_state_dict
|
||||
else:
|
||||
return cast(Dict[str, ValueType], state_dict)
|
||||
return cast(dict[str, ValueType], state_dict)
|
||||
|
||||
|
||||
def set_model_state_dict(
|
||||
model: nn.Module,
|
||||
model_state_dict: Dict[str, ValueType],
|
||||
model_state_dict: dict[str, ValueType],
|
||||
*,
|
||||
options: Optional[StateDictOptions] = None,
|
||||
) -> _IncompatibleKeys:
|
||||
@ -1208,7 +1197,7 @@ def set_model_state_dict(
|
||||
|
||||
:type model_state_dict: typing.Dict[str, ValueType]
|
||||
"""
|
||||
model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
|
||||
model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
|
||||
model, model_state_dict
|
||||
)
|
||||
with _gc_context():
|
||||
@ -1261,7 +1250,7 @@ def set_state_dict(
|
||||
model: nn.Module,
|
||||
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
|
||||
*,
|
||||
model_state_dict: Dict[str, ValueType],
|
||||
model_state_dict: dict[str, ValueType],
|
||||
optim_state_dict: OptimizerStateType,
|
||||
options: Optional[StateDictOptions] = None,
|
||||
) -> _IncompatibleKeys:
|
||||
@ -1299,7 +1288,7 @@ def set_state_dict(
|
||||
:type optim_state_dict: typing.OptimizerStateType
|
||||
"""
|
||||
|
||||
model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
|
||||
model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
|
||||
model, model_state_dict
|
||||
)
|
||||
with _gc_context():
|
||||
@ -1363,7 +1352,7 @@ def _patch_model_state_dict(
|
||||
options=options,
|
||||
)
|
||||
|
||||
def load_state_dict_call(state_dict: Dict[str, Any]):
|
||||
def load_state_dict_call(state_dict: dict[str, Any]):
|
||||
_load_state_dict_call(model_state_dict=state_dict)
|
||||
|
||||
model.load_state_dict = load_state_dict_call
|
||||
@ -1422,7 +1411,7 @@ def _patch_optimizer_state_dict(
|
||||
options=options,
|
||||
)
|
||||
|
||||
def load_state_dict_call(state_dict: Dict[str, Any]):
|
||||
def load_state_dict_call(state_dict: dict[str, Any]):
|
||||
_load_state_dict_call(optim_state_dict=state_dict)
|
||||
|
||||
_patched_state_dict.add(state_dict_call)
|
||||
|
@ -2,7 +2,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, cast, Dict, Optional, Set, Union
|
||||
from typing import Any, cast, Optional, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -27,7 +27,7 @@ __all__ = ["load_state_dict", "load"]
|
||||
category=FutureWarning,
|
||||
)
|
||||
def load_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
storage_reader: StorageReader,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
coordinator_rank: int = 0,
|
||||
@ -51,7 +51,7 @@ def load_state_dict(
|
||||
@_dcp_method_logger(log_exceptions=True)
|
||||
@_api_bc_check
|
||||
def load(
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
*,
|
||||
checkpoint_id: Union[str, os.PathLike, None] = None,
|
||||
storage_reader: Optional[StorageReader] = None,
|
||||
@ -190,7 +190,7 @@ def load(
|
||||
|
||||
|
||||
def _load_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
storage_reader: StorageReader,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
coordinator_rank: int = 0,
|
||||
@ -241,12 +241,12 @@ def _load_state_dict(
|
||||
|
||||
|
||||
def _load_state_dict_from_keys(
|
||||
keys: Optional[Union[Set[str], str]] = None,
|
||||
keys: Optional[Union[set[str], str]] = None,
|
||||
*,
|
||||
checkpoint_id: Union[str, os.PathLike, None] = None,
|
||||
storage_reader: Optional[StorageReader] = None,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Load only the specified keys from the checkpoint, if no keys are specified, the entire
|
||||
checkpoint will be loaded. Note, this method completely loads the checkpoint into the
|
||||
@ -311,7 +311,7 @@ def _load_state_dict_from_keys(
|
||||
if isinstance(keys, str):
|
||||
keys = {keys}
|
||||
|
||||
sd: Dict[str, Any] = {}
|
||||
sd: dict[str, Any] = {}
|
||||
_load_state_dict(
|
||||
state_dict=sd,
|
||||
storage_reader=storage_reader,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, runtime_checkable, TypeVar
|
||||
from typing import Any, runtime_checkable, TypeVar
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ class Stateful(Protocol):
|
||||
Stateful protocol for objects that can be checkpointed and restored.
|
||||
"""
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Objects should return their state_dict representation as a dictionary.
|
||||
The output of this function will be checkpointed, and later restored in
|
||||
@ -28,7 +28,7 @@ class Stateful(Protocol):
|
||||
|
||||
...
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
||||
"""
|
||||
Restore the object's state from the provided state_dict.
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import abc
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
|
||||
from torch.distributed.checkpoint.planner import (
|
||||
@ -86,7 +86,7 @@ class StorageWriter(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
|
||||
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
|
||||
"""
|
||||
Perform centralized planning of storage.
|
||||
|
||||
@ -105,7 +105,7 @@ class StorageWriter(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def write_data(
|
||||
self, plan: SavePlan, planner: SavePlanner
|
||||
) -> Future[List[WriteResult]]:
|
||||
) -> Future[list[WriteResult]]:
|
||||
"""
|
||||
Write all items from ``plan`` using ``planner`` to resolve the data.
|
||||
|
||||
@ -127,7 +127,7 @@ class StorageWriter(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
|
||||
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
|
||||
"""
|
||||
Write the metadata and marks the current checkpoint as successful.
|
||||
|
||||
@ -236,7 +236,7 @@ class StorageReader(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
|
||||
def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
|
||||
"""
|
||||
Perform centralized planning of storage loading.
|
||||
|
||||
|
@ -5,10 +5,11 @@ import io
|
||||
import itertools
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from pstats import Stats
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -31,20 +32,20 @@ R = TypeVar("R")
|
||||
|
||||
|
||||
def _get_failure_dict(
|
||||
results: List[Union[T, WRAPPED_EXCEPTION]]
|
||||
) -> Dict[int, WRAPPED_EXCEPTION]:
|
||||
results: list[Union[T, WRAPPED_EXCEPTION]]
|
||||
) -> dict[int, WRAPPED_EXCEPTION]:
|
||||
return cast(
|
||||
Dict[int, WRAPPED_EXCEPTION],
|
||||
dict[int, WRAPPED_EXCEPTION],
|
||||
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
|
||||
)
|
||||
|
||||
|
||||
def _all_gather_keys(
|
||||
local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
|
||||
) -> List[Any]:
|
||||
local_dict: dict[Any, Any], group: Optional[dist.ProcessGroup] = None
|
||||
) -> list[Any]:
|
||||
"""Gathers all keys, and returns them sorted."""
|
||||
keys = list(local_dict.keys())
|
||||
gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
|
||||
gathered_keys: list[list[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
|
||||
|
||||
dist.all_gather_object(gathered_keys, keys, group=group)
|
||||
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
|
||||
@ -95,11 +96,11 @@ class _DistWrapper:
|
||||
)
|
||||
return cast(T, object_list[0])
|
||||
|
||||
def gather_object(self, object: T) -> Optional[List[T]]:
|
||||
def gather_object(self, object: T) -> Optional[list[T]]:
|
||||
"""Implement functionality similar to c10d::gather_object but without distributed enabled."""
|
||||
if self.use_dist:
|
||||
gather_objs = (
|
||||
cast(List[T], [None] * dist.get_world_size(self.group))
|
||||
cast(list[T], [None] * dist.get_world_size(self.group))
|
||||
if self.is_coordinator
|
||||
else None
|
||||
)
|
||||
@ -115,10 +116,10 @@ class _DistWrapper:
|
||||
result = [object]
|
||||
return result
|
||||
|
||||
def all_gather_object(self, object: T) -> List[T]:
|
||||
def all_gather_object(self, object: T) -> list[T]:
|
||||
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
|
||||
if self.use_dist:
|
||||
gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
|
||||
gather_objs = cast(list[T], [None] * dist.get_world_size(self.group))
|
||||
|
||||
dist.all_gather_object(
|
||||
object_list=gather_objs, obj=object, group=self.group
|
||||
@ -127,10 +128,10 @@ class _DistWrapper:
|
||||
gather_objs = [object]
|
||||
return gather_objs
|
||||
|
||||
def scatter_object(self, object_list: Optional[List[T]]) -> T:
|
||||
def scatter_object(self, object_list: Optional[list[T]]) -> T:
|
||||
"""Implement functionality similar to c10d::scatter_object but without distributed enabled."""
|
||||
if self.use_dist:
|
||||
gather_result = cast(List[T], [None])
|
||||
gather_result = cast(list[T], [None])
|
||||
dist.scatter_object_list(
|
||||
scatter_object_output_list=gather_result,
|
||||
scatter_object_input_list=object_list if self.is_coordinator else None,
|
||||
@ -148,7 +149,7 @@ class _DistWrapper:
|
||||
self,
|
||||
step: str,
|
||||
map_fun: Callable[[], T],
|
||||
reduce_fun: Callable[[List[T]], List[R]],
|
||||
reduce_fun: Callable[[list[T]], list[R]],
|
||||
) -> R:
|
||||
"""
|
||||
Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
|
||||
@ -166,7 +167,7 @@ class _DistWrapper:
|
||||
local_data = _wrap_exception(e)
|
||||
|
||||
all_data = self.gather_object(local_data)
|
||||
all_results: Optional[List[Union[R, CheckpointException]]] = None
|
||||
all_results: Optional[list[Union[R, CheckpointException]]] = None
|
||||
if self.is_coordinator:
|
||||
assert all_data is not None
|
||||
node_failures = _get_failure_dict(all_data)
|
||||
@ -175,8 +176,8 @@ class _DistWrapper:
|
||||
try:
|
||||
# N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
|
||||
all_results = cast(
|
||||
List[Union[R, CheckpointException]],
|
||||
reduce_fun(cast(List[T], all_data)),
|
||||
list[Union[R, CheckpointException]],
|
||||
reduce_fun(cast(list[T], all_data)),
|
||||
)
|
||||
except BaseException as e:
|
||||
node_failures[self.rank] = _wrap_exception(e)
|
||||
@ -195,7 +196,7 @@ class _DistWrapper:
|
||||
self,
|
||||
step: str,
|
||||
map_fun: Callable[[], T],
|
||||
reduce_fun: Callable[[List[T]], R],
|
||||
reduce_fun: Callable[[list[T]], R],
|
||||
) -> R:
|
||||
"""
|
||||
Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
|
||||
@ -219,7 +220,7 @@ class _DistWrapper:
|
||||
node_failures = _get_failure_dict(all_data)
|
||||
if len(node_failures) == 0:
|
||||
try:
|
||||
result = reduce_fun(cast(List[T], all_data))
|
||||
result = reduce_fun(cast(list[T], all_data))
|
||||
except BaseException as e:
|
||||
node_failures[self.rank] = _wrap_exception(e)
|
||||
|
||||
@ -235,7 +236,7 @@ class _DistWrapper:
|
||||
self,
|
||||
step: str,
|
||||
map_fun: Callable[[], T],
|
||||
) -> List[T]:
|
||||
) -> list[T]:
|
||||
"""
|
||||
Compute a value on each rank, then all_gather them.
|
||||
|
||||
@ -254,7 +255,7 @@ class _DistWrapper:
|
||||
node_failures = _get_failure_dict(all_results)
|
||||
if len(node_failures) > 0:
|
||||
raise CheckpointException(step, node_failures)
|
||||
return cast(List[T], all_results)
|
||||
return cast(list[T], all_results)
|
||||
|
||||
def broadcast(
|
||||
self,
|
||||
@ -331,11 +332,11 @@ def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) ->
|
||||
return obj
|
||||
|
||||
|
||||
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
|
||||
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]:
|
||||
return [i_a + i_b for i_a, i_b in zip(a, b)]
|
||||
|
||||
|
||||
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
|
||||
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:
|
||||
return [i_a - i_b for i_a, i_b in zip(a, b)]
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch.distributed.elastic.rendezvous as rdzv
|
||||
import torch.distributed.elastic.utils.store as store_util
|
||||
@ -80,7 +80,7 @@ class WorkerSpec:
|
||||
fn: Optional[Callable] = None
|
||||
# TODO @kiuk - make entrypoint a required field
|
||||
entrypoint: Union[Callable, str, None] = None
|
||||
args: Tuple = ()
|
||||
args: tuple = ()
|
||||
max_restarts: int = 3
|
||||
monitor_interval: float = 0.1
|
||||
master_port: Optional[int] = None
|
||||
@ -320,7 +320,7 @@ class _RoleInstanceInfo:
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def find_role_boundaries(roles_infos: List, role: str) -> tuple[int, int]:
|
||||
def find_role_boundaries(roles_infos: list, role: str) -> tuple[int, int]:
|
||||
start_idx, end_idx = -1, -1
|
||||
for idx, role_info in enumerate(roles_infos):
|
||||
if role_info.role == role:
|
||||
@ -357,8 +357,8 @@ class RunResult:
|
||||
"""
|
||||
|
||||
state: WorkerState
|
||||
return_values: Dict[int, Any] = field(default_factory=dict)
|
||||
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
|
||||
return_values: dict[int, Any] = field(default_factory=dict)
|
||||
failures: dict[int, ProcessFailure] = field(default_factory=dict)
|
||||
|
||||
def is_failed(self) -> bool:
|
||||
return self.state == WorkerState.FAILED
|
||||
@ -448,7 +448,7 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
return self._worker_group
|
||||
|
||||
@abc.abstractmethod
|
||||
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
|
||||
def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
|
||||
r"""Start ``worker_group.spec.local_world_size`` number of workers.
|
||||
|
||||
This is according to worker spec for the worker group .
|
||||
@ -554,7 +554,7 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
@prof
|
||||
def _assign_worker_ranks(
|
||||
self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
|
||||
) -> List[Worker]:
|
||||
) -> list[Worker]:
|
||||
"""Determine proper ranks for worker processes.
|
||||
|
||||
Fast Path: when all workers have the same role and world size. We calculate
|
||||
|
@ -15,7 +15,7 @@ import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
@ -163,7 +163,7 @@ class LocalElasticAgent(SimpleElasticAgent):
self._logs_specs = logs_specs
self._health_check_server: Optional[HealthCheckServer] = None

def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
watchdog_enabled = os.getenv(enable_watchdog_env_name)
watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
@ -256,7 +256,7 @@ class LocalElasticAgent(SimpleElasticAgent):
md["signal"] = str(request.signal)
md_str = json.dumps(md)
state = "RUNNING"
metadata: Dict[str, EventMetadataValue] = {
metadata: dict[str, EventMetadataValue] = {
"run_id": spec.rdzv_handler.get_run_id(),
"global_rank": None,
"group_rank": wg.group_rank,
@ -288,7 +288,7 @@ class LocalElasticAgent(SimpleElasticAgent):
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
@ -297,9 +297,9 @@ class LocalElasticAgent(SimpleElasticAgent):
use_agent_store: bool = spec.rdzv_handler.use_agent_store
logger.info("use_agent_store: %s", use_agent_store)

args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
log_line_prefixes: Optional[Dict[int, str]] = (
args: dict[int, tuple] = {}
envs: dict[int, dict[str, str]] = {}
log_line_prefixes: Optional[dict[int, str]] = (
{} if self._log_line_prefix_template else None
)
for worker in worker_group.workers:

@ -1,6 +1,6 @@
import os
from collections.abc import Generator
from contextlib import contextmanager, ExitStack
from typing import Generator

from torch.distributed.elastic.multiprocessing.errors import record

@ -37,7 +37,7 @@ from .api import ( # noqa: F401
)


_events_loggers: Dict[str, logging.Logger] = {}
_events_loggers: dict[str, logging.Logger] = {}


def _get_or_create_logger(destination: str = "null") -> logging.Logger:

@ -10,7 +10,7 @@
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, Optional, Union
from typing import Optional, Union


__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
@ -42,7 +42,7 @@ class Event:
name: str
source: EventSource
timestamp: int = 0
metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
metadata: dict[str, EventMetadataValue] = field(default_factory=dict)

def __str__(self):
return self.serialize()

@ -7,10 +7,9 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict


_log_handlers: Dict[str, logging.Handler] = {
_log_handlers: dict[str, logging.Handler] = {
"console": logging.StreamHandler(),
"dynamic_rendezvous": logging.NullHandler(),
"null": logging.NullHandler(),

@ -11,7 +11,7 @@ import abc
import time
from collections import namedtuple
from functools import wraps
from typing import Dict, Optional
from typing import Optional
from typing_extensions import deprecated


@ -37,7 +37,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
class MetricsConfig:
__slots__ = ["params"]

def __init__(self, params: Optional[Dict[str, str]] = None):
def __init__(self, params: Optional[dict[str, str]] = None):
self.params = params
if self.params is None:
self.params = {}
@ -72,7 +72,7 @@ class MetricStream:
)


_metrics_map: Dict[str, MetricHandler] = {}
_metrics_map: dict[str, MetricHandler] = {}
_default_metrics_handler: MetricHandler = NullMetricHandler()


@ -100,10 +100,10 @@ __all__ = [
def start_processes(
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
start_method: str = "spawn",
) -> PContext:
"""

@ -24,7 +24,7 @@ from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@ -100,7 +100,7 @@ def _get_default_signal() -> signal.Signals:
return signal.SIGTERM


def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))

@ -122,7 +122,7 @@ class Std(IntFlag):
ALL = OUT | ERR

@classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]:
"""
Example:
::
@ -143,7 +143,7 @@ class Std(IntFlag):
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
return to_std(vm)
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
d: Dict[int, Std] = {}
d: dict[int, Std] = {}
for m in vm.split(","):
i, v = m.split(":")
d[int(i)] = to_std(v)
@ -155,8 +155,8 @@


def to_map(
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
val_or_map: Union[Std, dict[int, Std]], local_world_size: int
) -> dict[int, Std]:
"""
Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This method is a convenience
@ -184,11 +184,11 @@ class LogsDest:
For each log type, holds mapping of local rank ids to file paths.
"""

stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
tee_stdouts: Dict[int, str] = field(default_factory=dict)
tee_stderrs: Dict[int, str] = field(default_factory=dict)
error_files: Dict[int, str] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)
tee_stdouts: dict[int, str] = field(default_factory=dict)
tee_stderrs: dict[int, str] = field(default_factory=dict)
error_files: dict[int, str] = field(default_factory=dict)


class LogsSpecs(ABC):
@ -211,9 +211,9 @@ class LogsSpecs(ABC):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
) -> None:
self._root_log_dir = log_dir
self._redirects = redirects
@ -223,7 +223,7 @@ class LogsSpecs(ABC):
@abstractmethod
def reify(
self,
envs: Dict[int, Dict[str, str]],
envs: dict[int, dict[str, str]],
) -> LogsDest:
"""
Given the environment variables, builds destination of log files for each of the local ranks.
@ -249,9 +249,9 @@ class DefaultLogsSpecs(LogsSpecs):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
) -> None:
if log_dir != os.devnull:
if not log_dir:
@ -278,7 +278,7 @@ class DefaultLogsSpecs(LogsSpecs):

def reify(
self,
envs: Dict[int, Dict[str, str]],
envs: dict[int, dict[str, str]],
) -> LogsDest:
"""
Uses following scheme to build log destination paths:
@ -331,8 +331,8 @@ class DefaultLogsSpecs(LogsSpecs):
SYS_STREAM = "" # special case to indicate to output to console
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
tee_stdouts: Dict[int, str] = {}
tee_stderrs: Dict[int, str] = {}
tee_stdouts: dict[int, str] = {}
tee_stderrs: dict[int, str] = {}
error_files = {}

for local_rank in range(nprocs):
@ -414,10 +414,10 @@ class RunProcsResult:

"""

return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
return_values: dict[int, Any] = field(default_factory=dict)
failures: dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)

def is_failed(self) -> bool:
return len(self.failures) > 0
@ -438,10 +438,10 @@ class PContext(abc.ABC):
self,
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
self.name = name
# validate that all mappings have the same number of keys and
@ -544,7 +544,7 @@ class PContext(abc.ABC):
return None

@abc.abstractmethod
def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError

@ -587,11 +587,11 @@ def get_std_cm(std_rd: str, redirect_fn):
def _wrap(
local_rank: int,
fn: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
stdout_redirects: dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
ret_vals: dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
@ -621,11 +621,11 @@ class MultiprocessContext(PContext):
self,
name: str,
entrypoint: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
super().__init__(
name,
@ -644,7 +644,7 @@ class MultiprocessContext(PContext):
}

# see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {}
self._return_values: dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
# Note: set method should ONLY be invoked for the use case when all processes finished
# successfully. If any process died on event.wait() calling set() method will deadlock.
@ -755,7 +755,7 @@ class MultiprocessContext(PContext):
stderrs=self.stderrs,
)

def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
assert self._pc is not None # assertion for mypy type checking
return dict(enumerate(self._pc.pids()))

@ -803,10 +803,10 @@ class SubprocessContext(PContext):
self,
name: str,
entrypoint: str,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
):
super().__init__(
name,
@ -818,9 +818,9 @@ class SubprocessContext(PContext):
)

# state vector; _vdone[local_rank] -> is local_rank finished or not
self._running_local_ranks: Set[int] = set(range(self.nprocs))
self._failures: Dict[int, ProcessFailure] = {}
self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
self._running_local_ranks: set[int] = set(range(self.nprocs))
self._failures: dict[int, ProcessFailure] = {}
self.subprocess_handlers: dict[int, SubprocessHandler] = {}

def _start(self):
if self.subprocess_handlers:
@ -884,7 +884,7 @@ class SubprocessContext(PContext):
else: # there are no failures and procs still running
return None

def pids(self) -> Dict[int, int]:
def pids(self) -> dict[int, int]:
return {
local_rank: sh.proc.pid
for local_rank, sh in self.subprocess_handlers.items()

@ -78,7 +78,7 @@ __all__ = [
logger = get_logger(__name__)


JSON = Dict
JSON = dict

_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"
@ -143,7 +143,7 @@ class ProcessFailure:
else:
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"

def _get_error_data(self, error_file_data: Dict[str, Any]) -> tuple[str, int]:
def _get_error_data(self, error_file_data: dict[str, Any]) -> tuple[str, int]:
message = error_file_data["message"]
if isinstance(message, str):
timestamp = int(error_file_data.get("timestamp", 0))
@ -231,7 +231,7 @@ class ChildFailedError(Exception):
of trainer 1's error file to the scheduler's init process.
"""

def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]):
self.name = name
self.failures = failures
assert (
@ -248,7 +248,7 @@ class ChildFailedError(Exception):
root_rank, _root_failure = self.get_first_failure()

root_failure_fmt: str = ""
other_failures_fmt: List[str] = []
other_failures_fmt: list[str] = []
width = len(title)
for idx, (rank, failure) in enumerate(self.failures.items()):
fmt, w = self._format_failure(idx, rank, failure)

@ -13,7 +13,7 @@ import os
import time
import traceback
import warnings
from typing import Any, Dict, Optional
from typing import Any, Optional


__all__ = ["ErrorHandler"]
@ -86,7 +86,7 @@ class ErrorHandler:
def override_error_code_in_rootcause_data(
self,
rootcause_error_file: str,
rootcause_error: Dict[str, Any],
rootcause_error: dict[str, Any],
error_code: int = 0,
):
"""Modify the rootcause_error read from the file, to correctly set the exit code."""

@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple

from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
@ -15,8 +14,8 @@ __all__ = ["get_subprocess_handler"]

def get_subprocess_handler(
entrypoint: str,
args: Tuple,
env: Dict[str, str],
args: tuple,
env: dict[str, str],
stdout: str,
stderr: str,
local_rank_id: int,

@ -9,7 +9,7 @@ import os
import signal
import subprocess
import sys
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional


__all__ = ["SubprocessHandler"]
@ -34,8 +34,8 @@ class SubprocessHandler:
def __init__(
self,
entrypoint: str,
args: Tuple,
env: Dict[str, str],
args: tuple,
env: dict[str, str],
stdout: Optional[str],
stderr: Optional[str],
local_rank_id: int,
@ -50,8 +50,8 @@ class SubprocessHandler:
self.local_rank_id = local_rank_id
self.proc: subprocess.Popen = self._popen(args_str, env_vars)

def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
kwargs: Dict[str, Any] = {}
def _popen(self, args: tuple, env: dict[str, str]) -> subprocess.Popen:
kwargs: dict[str, Any] = {}
if not IS_WINDOWS:
kwargs["start_new_session"] = True
return subprocess.Popen(

@ -12,7 +12,7 @@ import os
import time
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
from typing import Optional, TextIO, TYPE_CHECKING


if TYPE_CHECKING:
@ -89,9 +89,9 @@ class TailLog:
def __init__(
self,
name: str,
log_files: Dict[int, str],
log_files: dict[int, str],
dst: TextIO,
log_line_prefixes: Optional[Dict[int, str]] = None,
log_line_prefixes: Optional[dict[int, str]] = None,
interval_sec: float = 0.1,
):
n = len(log_files)
@ -106,10 +106,10 @@ class TailLog:
self._dst = dst
self._log_files = log_files
self._log_line_prefixes = log_line_prefixes
self._finished_events: Dict[int, Event] = {
self._finished_events: dict[int, Event] = {
local_rank: Event() for local_rank in log_files.keys()
}
self._futs: List[Future] = []
self._futs: list[Future] = []
self._interval_sec = interval_sec
self._stopped = False

@ -8,7 +8,7 @@
import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional
from typing import Any, Callable, ClassVar, Optional

from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port
@ -325,7 +325,7 @@ RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry:
"""Represent a registry of :py:class:`RendezvousHandler` backends."""

_registry: Dict[str, RendezvousHandlerCreator]
_registry: dict[str, RendezvousHandlerCreator]

def __init__(self) -> None:
self._registry = {}

@ -17,7 +17,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Optional

import torch.distributed as dist
from torch.distributed import Store
@ -298,10 +298,10 @@ class _RendezvousState:
complete: bool
deadline: Optional[datetime]
closed: bool
participants: Dict[_NodeDesc, int]
wait_list: Set[_NodeDesc]
redundancy_list: Set[_NodeDesc]
last_heartbeats: Dict[_NodeDesc, datetime]
participants: dict[_NodeDesc, int]
wait_list: set[_NodeDesc]
redundancy_list: set[_NodeDesc]
last_heartbeats: dict[_NodeDesc, datetime]

def __init__(self) -> None:
self.round = 0
@ -377,7 +377,7 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
_token: Token
_dirty: bool
_last_sync_time: float
_dead_nodes: List[_NodeDesc]
_dead_nodes: list[_NodeDesc]

def __init__(
self,

@ -13,20 +13,20 @@ import time
import weakref
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Optional, Union


__all__ = ["parse_rendezvous_endpoint"]


def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
"""Extract key-value pairs from a rendezvous configuration string.

Args:
config_str:
A string in format <key1>=<value1>,...,<keyN>=<valueN>.
"""
config: Dict[str, str] = {}
config: dict[str, str] = {}

config_str = config_str.strip()
if not config_str:
@ -196,7 +196,7 @@ class _PeriodicTimer:
interval: float
function: Callable[..., None]
args: tuple[Any, ...]
kwargs: Dict[str, Any]
kwargs: dict[str, Any]
stop_event: Event

_name: Optional[str]

@ -10,7 +10,7 @@ import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional


__all__ = [
@ -103,7 +103,7 @@ class RequestQueue(abc.ABC):
"""

@abc.abstractmethod
def get(self, size: int, timeout: float) -> List[TimerRequest]:
def get(self, size: int, timeout: float) -> list[TimerRequest]:
"""
Gets up to ``size`` number of timer requests in a blocking fashion
(no more than ``timeout`` seconds).
@ -134,7 +134,7 @@ class TimerServer(abc.ABC):
self._stop_signaled = False

@abc.abstractmethod
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
def register_timers(self, timer_requests: list[TimerRequest]) -> None:
"""
Processes the incoming timer requests and registers them with the server.
The timer request can either be a acquire-timer or release-timer request.
@ -143,13 +143,13 @@ class TimerServer(abc.ABC):
"""

@abc.abstractmethod
def clear_timers(self, worker_ids: Set[Any]) -> None:
def clear_timers(self, worker_ids: set[Any]) -> None:
"""
Clears all timers for the given ``worker_ids``.
"""

@abc.abstractmethod
def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
def get_expired_timers(self, deadline: float) -> dict[str, list[TimerRequest]]:
"""
Returns all expired timers for each worker_id. An expired timer
is a timer for which the expiration_time is less than or equal to

@ -7,7 +7,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

from torch.distributed.elastic.utils.logging import get_logger

@ -19,7 +18,7 @@ __all__ = ["log_debug_info_for_expired_timers"]

def log_debug_info_for_expired_timers(
run_id: str,
expired_timers: Dict[int, List[str]],
expired_timers: dict[int, list[str]],
):
if expired_timers:
logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)

@ -13,7 +13,7 @@ import signal
import sys
import threading
import time
from typing import Callable, Dict, List, Optional, Set
from typing import Callable, Optional

from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
from torch.distributed.elastic.timer.debug_info_logging import (
@ -201,7 +201,7 @@ class FileTimerServer:
self._run_id = run_id
self._max_interval = max_interval
self._daemon = daemon
self._timers: Dict[tuple[int, str], FileTimerRequest] = {}
self._timers: dict[tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None

@ -354,12 +354,12 @@ class FileTimerServer:

self.clear_timers(reaped_worker_pids)

def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
def _get_scopes(self, timer_requests: list[FileTimerRequest]) -> list[str]:
return [r.scope_id for r in timer_requests]

def _get_requests(
self, fd: io.TextIOWrapper, max_interval: float
) -> List[FileTimerRequest]:
) -> list[FileTimerRequest]:
start = time.time()
requests = []
while not self._stop_signaled or self._run_once:
@ -394,7 +394,7 @@ class FileTimerServer:
break
return requests

def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
def register_timers(self, timer_requests: list[FileTimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_pid
scope_id = request.scope_id
@ -409,14 +409,14 @@ class FileTimerServer:
else:
self._timers[key] = request

def clear_timers(self, worker_pids: Set[int]) -> None:
def clear_timers(self, worker_pids: set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_pids or not FileTimerServer.is_process_running(pid):
del self._timers[(pid, scope_id)]

def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
def get_expired_timers(self, deadline: float) -> dict[int, list[FileTimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[int, List[FileTimerRequest]] = {}
expired_timers: dict[int, list[FileTimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_pid, [])

@ -10,7 +10,7 @@ import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set
from typing import Any

from .api import RequestQueue, TimerClient, TimerRequest, TimerServer

@ -56,7 +56,7 @@ class MultiprocessingRequestQueue(RequestQueue):
def size(self) -> int:
return self._mp_queue.qsize()

def get(self, size, timeout: float) -> List[TimerRequest]:
def get(self, size, timeout: float) -> list[TimerRequest]:
requests = []
wait = timeout
for _ in range(0, size):
@ -88,9 +88,9 @@ class LocalTimerServer(TimerServer):
self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
):
super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
self._timers: Dict[tuple[Any, str], TimerRequest] = {}
self._timers: dict[tuple[Any, str], TimerRequest] = {}

def register_timers(self, timer_requests: List[TimerRequest]) -> None:
def register_timers(self, timer_requests: list[TimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_id
scope_id = request.scope_id
@ -102,14 +102,14 @@ class LocalTimerServer(TimerServer):
else:
self._timers[(pid, scope_id)] = request

def clear_timers(self, worker_ids: Set[int]) -> None:
def clear_timers(self, worker_ids: set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_ids:
self._timers.pop((pid, scope_id))

def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[Any, List[TimerRequest]] = {}
expired_timers: dict[Any, list[TimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_id, [])

@ -9,7 +9,7 @@
import os
import socket
from string import Template
from typing import Any, List
from typing import Any


def get_env_variable_or_raise(env_name: str) -> str:
@ -51,7 +51,7 @@ class macros:
local_rank = "${local_rank}"

@staticmethod
def substitute(args: List[Any], local_rank: str) -> List[str]:
def substitute(args: list[Any], local_rank: str) -> list[str]:
args_sub = []
for arg in args:
if isinstance(arg, str):

@ -1,6 +1,7 @@
#!/usr/bin/env python3

from typing import Callable, Iterator, TypeVar
from collections.abc import Iterator
from typing import Callable, TypeVar
from typing_extensions import Self

@ -7,9 +7,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Iterable
from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Iterable, List, Optional
from typing import Callable, Optional

import torch

@ -85,7 +86,7 @@ def synchronize(
world_size: int,
key_prefix: str,
timeout: float = 300,
) -> List[bytes]:
) -> list[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
The ``data`` will be available on each of the agents.