Mirror of https://github.com/pytorch/pytorch.git — synced 2025-11-19 18:14:54 +08:00

Compare commits (1 commit): `dev/joona/...` → `ciflow/ind...`

| Author | SHA1 | Date |
|---|---|---|
|  | 59642e6a24 |  |
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterator
 from enum import auto, Enum
 from functools import partial
-from typing import Any, Optional
+from typing import Any

 import torch
 import torch.nn as nn
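This first hunk sets the pattern repeated throughout the commit: `Optional[X]` from `typing` becomes `X | None`, and `Union[X, Y]` becomes `X | Y`, per PEP 604. A minimal sketch of the equivalence (the function names are hypothetical; as annotations, the `|` spellings need Python 3.10+ or `from __future__ import annotations` on older interpreters):

```python
from typing import Optional, Union

# Before: typing-module spellings.
def lookup_old(key: str, default: Optional[int] = None) -> Union[int, str]: ...

# After: PEP 604 union syntax, equivalent in meaning.
def lookup_new(key: str, default: int | None = None) -> int | str: ...
```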
@@ -251,7 +251,7 @@ def apply_activation_checkpointing(
     model,
     checkpoint_wrapper_fn=checkpoint_wrapper,
     check_fn=lambda _: True,
-    auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
+    auto_wrap_policy: Callable[[nn.Module, bool, int], bool] | None = None,
 ):
     """
     Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration.
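For context, a usage sketch of the `apply_activation_checkpointing` API whose `auto_wrap_policy` annotation changes above; the private module path and the toy model are assumptions:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Wrap every Linear submodule with activation checkpointing, in place.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, nn.Linear),
)
```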
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import functools
-from typing import Optional

 import torch
 import torch.distributed as dist
@@ -136,7 +135,7 @@ def _low_precision_hook(
     prec: torch.dtype,
     state: LowPrecisionState,
     grad: torch.Tensor,
-    output: Optional[torch.Tensor],
+    output: torch.Tensor | None,
 ):
     if grad.dtype != prec:
         grad.data = grad.data.to(prec)
@@ -151,7 +150,7 @@ def _low_precision_hook(


 def fp16_compress_hook(
-    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
+    state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None
 ):
     r"""
     Implement FSDP communication hook for a simple gradient compression approach.
@@ -172,7 +171,7 @@ def fp16_compress_hook(


 def bf16_compress_hook(
-    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
+    state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None
 ):
     r"""
     Implement FSDP communication hook for a simple gradient compression approach .
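These hooks compress gradients during FSDP communication. A hedged sketch of how such a hook might be registered — only the names visible in the hunks above come from the diff; the `LowPrecisionState` constructor argument and the surrounding setup are assumptions:

```python
import torch.distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assuming the default process group is already initialized and `model` exists.
fsdp_model = FSDP(model)
state = default_hooks.LowPrecisionState(process_group=dist.group.WORLD)  # assumed ctor
# Gradients are cast to fp16 for the all-reduce, then restored afterwards.
fsdp_model.register_comm_hook(state, default_hooks.fp16_compress_hook)
```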
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import weakref
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import Any

 import torch
 import torch.distributed as dist
@@ -47,7 +47,7 @@ def _perform_local_step(
     # expects `None` in a list position to indicate that the corresponding
     # parameter should not be updated
     num_local_optim_params = len(zero.optim.param_groups[0]["params"])
-    gradients: list[Optional[torch.Tensor]] = [
+    gradients: list[torch.Tensor | None] = [
         _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
     ]
     assert bucket_index in overlap_info.offsets, (
@@ -2,7 +2,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import Any, NamedTuple, Optional
+from typing import Any, NamedTuple

 import torch
 import torch.distributed as dist
@@ -228,9 +228,9 @@ class Join:

     def __exit__(
         self,
-        type: Optional[type[BaseException]],
-        value: Optional[BaseException],
-        traceback: Optional[TracebackType],
+        type: type[BaseException] | None,
+        value: BaseException | None,
+        traceback: TracebackType | None,
     ):
         r"""
         Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.
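`Join.__exit__` above is the exit path of the `Join` context manager, which lets ranks with uneven input counts finish collectives gracefully. A usage sketch under the usual DDP setup (`model` and `loader` are assumed to exist):

```python
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model)
# Ranks may see different numbers of batches; Join shadows the collectives
# of already-finished ranks until every rank has joined.
with Join([model]):
    for batch in loader:
        loss = model(batch).sum()
        loss.backward()
```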
@@ -2,7 +2,6 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Optional, Union

 import torch
 import torch.distributed as dist
@@ -23,7 +22,7 @@ class ModelAverager(ABC):
                 will be used. (default: ``None``)
     """

-    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
+    def __init__(self, process_group: dist.ProcessGroup | None = None):
         self.process_group = (
             process_group if process_group is not None else _not_none(dist.group.WORLD)
         )
@@ -88,7 +87,7 @@ class PeriodicModelAverager(ModelAverager):
     """

     def __init__(
-        self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
+        self, period, warmup_steps=0, process_group: dist.ProcessGroup | None = None
     ):
         super().__init__(process_group)
         if warmup_steps < 0:
@@ -108,9 +107,7 @@ class PeriodicModelAverager(ModelAverager):

     def average_parameters(
         self,
-        params: Union[
-            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
-        ],
+        params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
     ):
         """
         Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``.
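For context on the averager API whose signatures change above, a usage sketch; `model`, `optimizer`, and `loader` are assumed to exist:

```python
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager

# Average parameters across ranks every 4 steps, after 100 warmup steps.
averager = PeriodicModelAverager(period=4, warmup_steps=100)
for batch in loader:
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Internally counts steps and all-reduces parameters when period elapses.
    averager.average_parameters(model.parameters())
```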
@@ -4,7 +4,6 @@ import logging
 import warnings
 from collections import OrderedDict
 from collections.abc import Iterable
-from typing import Union

 import torch
 import torch.distributed as dist
@@ -160,9 +159,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):

     def average_parameters(
         self,
-        params: Union[
-            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
-        ],
+        params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
     ):
         """
         Averages parameters or parameter groups of an optimizer.
@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import itertools
 from collections.abc import Iterable, Iterator
-from typing import Union

 import torch
 import torch.distributed as dist
@@ -51,10 +50,7 @@ def average_parameters(


 def get_params_to_average(
-    params: Union[
-        Iterable[torch.nn.Parameter],
-        Iterable[dict[str, torch.nn.Parameter]],
-    ],
+    params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
 ):
     """
     Return a list of parameters that need to average.
@@ -83,9 +79,7 @@ def get_params_to_average(


 def average_parameters_or_parameter_groups(
-    params: Union[
-        Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
-    ],
+    params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
     process_group: ProcessGroup,
 ):
     """Averages parameters of a model or parameter groups of an optimizer."""
@@ -13,7 +13,7 @@ import importlib
 import logging
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
+from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar


 if TYPE_CHECKING:
@@ -37,19 +37,19 @@ T = TypeVar("T")

 @dataclass
 class SyncPayload(Generic[T]):
-    stage_name: Optional[str]
+    stage_name: str | None
     success: bool
     payload: T
-    exception: Optional[Exception] = None
+    exception: Exception | None = None


 def broadcast(
-    data_or_fn: Union[T, Callable[[], T]],
+    data_or_fn: T | Callable[[], T],
     *,
     success: bool = True,
-    stage_name: Optional[str] = None,
+    stage_name: str | None = None,
     rank: int = 0,
-    pg: Optional[dist.ProcessGroup] = None,
+    pg: dist.ProcessGroup | None = None,
 ) -> T:
     """
     Broadcasts the data payload from rank 0 to all other ranks.
@@ -79,8 +79,8 @@ def broadcast(
             "Data or Function is expected to be None if not successful"
         )

-    payload: Optional[T] = None
-    exception: Optional[Exception] = None
+    payload: T | None = None
+    exception: Exception | None = None
     # if no pg is passed then execute if rank is 0
     if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
         # determine if it is an executable function or data payload only
@@ -124,9 +124,9 @@ def broadcast(


 def all_gather(
-    data_or_fn: Union[T, Callable[[], T]],
-    stage_name: Optional[str] = None,
-    pg: Optional[dist.ProcessGroup] = None,
+    data_or_fn: T | Callable[[], T],
+    stage_name: str | None = None,
+    pg: dist.ProcessGroup | None = None,
 ) -> list[T]:
     """
     A simple all_gather primitive with basic synchronization guard logic,
@@ -144,8 +144,8 @@ def all_gather(
     Example usage:
     >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
     """
-    payload: Optional[T] = None
-    exception: Optional[Exception] = None
+    payload: T | None = None
+    exception: Exception | None = None
     success = True
     # determine if it is an executable function or data payload only
     if callable(data_or_fn):
@@ -247,7 +247,7 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
         raise AssertionError("ranks should all be positive")
     if len(set(ranks)) != len(ranks):
         raise AssertionError("ranks should not contain duplicates")
-    curr: Optional[Union[int, range]] = None
+    curr: int | range | None = None
     ranges = []
     while ranks:
         x = ranks.pop(0)
@@ -345,9 +345,7 @@ def _desync_table_str(tag: str, value_ranks: dict[Any, set[int]]) -> str:
     return str(f"{headers}\n{row_str}")


-def _check_rng_sync(
-    generator: torch.Generator, group: dist.ProcessGroup
-) -> Optional[str]:
+def _check_rng_sync(generator: torch.Generator, group: dist.ProcessGroup) -> str | None:
     value_ranks, value_header = _check_rng_sync_internal(generator, group)
     log_str = None
     if len(value_ranks) > 1:
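`broadcast` and `all_gather` here accept either a plain value or a zero-argument callable. A sketch of the intended call pattern; `make_run_id` and `allocate_id` are hypothetical callables (the latter name is borrowed from the docstring example above):

```python
import torch.distributed as dist
from torch.distributed.collective_utils import all_gather, broadcast

# With a callable, broadcast() runs it only on the source rank and ships
# the result; other ranks block until they receive it (or the exception).
run_id = broadcast(data_or_fn=make_run_id, stage_name="init", rank=0, pg=dist.group.WORLD)

# all_gather() runs the callable on every rank and returns all results;
# if any rank fails, every rank raises, keeping the job in lockstep.
ids = all_gather(data_or_fn=allocate_id, pg=dist.group.WORLD)
```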
@@ -1,5 +1,4 @@
 from datetime import timedelta
-from typing import Optional

 from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT

@@ -19,7 +18,7 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
 try:
     from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT

-    default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
+    default_pg_nccl_timeout: timedelta | None = _DEFAULT_PG_NCCL_TIMEOUT
 except ImportError:
     # if C++ NCCL support is not compiled, we don't have access to the default nccl value.
     # if anyone is actually trying to use nccl in this state, it should error.
@@ -65,7 +65,7 @@ else:
             "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
         )

-    BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]]
+    BackendConfig = tuple[str | None, C10dBackend.Options | None]
     torch.serialization.add_safe_globals([_MeshLayout])

     class _MeshEnv(threading.local):
@@ -175,7 +175,7 @@ else:

         _device_type: str
         _rank_map: torch.Tensor
-        _mesh_dim_names: Optional[tuple[str, ...]]
+        _mesh_dim_names: tuple[str, ...] | None
         _layout: _MeshLayout
         _root_mesh: Optional["DeviceMesh"] = None
         # Record flatten mesh name to its flattened mesh in root mesh.
@@ -184,14 +184,14 @@ else:
         def __init__(
             self,
             device_type: str,
-            mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
+            mesh: Union[torch.Tensor, "ArrayLike"] | None = None,
             *,
-            mesh_dim_names: Optional[tuple[str, ...]] = None,
-            backend_override: Optional[tuple[BackendConfig, ...]] = None,
+            mesh_dim_names: tuple[str, ...] | None = None,
+            backend_override: tuple[BackendConfig, ...] | None = None,
             _init_backend: bool = True,
-            _rank: Optional[int] = None,
-            _layout: Optional[_MeshLayout] = None,
-            _rank_map: Optional[torch.Tensor] = None,
+            _rank: int | None = None,
+            _layout: _MeshLayout | None = None,
+            _rank_map: torch.Tensor | None = None,
             _root_mesh: Optional["DeviceMesh"] = None,
         ) -> None:
             # no-op in OSS, logs API usage metrics in meta-internal runs
@@ -292,7 +292,7 @@ else:
                 raise AssertionError(
                     f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
                 )
-            self._coordinate_on_dim: Optional[list[int]] = (
+            self._coordinate_on_dim: list[int] | None = (
                 rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
             )

@@ -317,7 +317,7 @@ else:
             )

         @property
-        def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
+        def mesh_dim_names(self) -> tuple[str, ...] | None:
             """Returns the names of mesh dimensions."""
             return self._mesh_dim_names

@@ -378,7 +378,7 @@ else:
             rank_map: torch.Tensor,
             dim_name: str,
             backend_override: BackendConfig,
-        ) -> Optional[str]:
+        ) -> str | None:
             # Generate a 2D global mesh tensor for the current dim for PG creation.
             pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
             backend, pg_options = backend_override
@@ -471,7 +471,7 @@ else:
         def _init_process_groups(
             layout: _MeshLayout,
             rank_map: torch.Tensor,
-            mesh_dim_names: Optional[tuple[str, ...]],
+            mesh_dim_names: tuple[str, ...] | None,
             backend_override: tuple[BackendConfig, ...],
         ) -> list[str]:
             # group_name associated with each mesh dimension, each
@@ -543,9 +543,7 @@ else:
                 and self._thread_id == other._thread_id
             )

-        def __getitem__(
-            self, mesh_dim_names: Union[str, tuple[str, ...]]
-        ) -> "DeviceMesh":
+        def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh":
             """
             Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh.
             The submesh created consists of the dimensions and the communicators indicated by
@@ -613,7 +611,7 @@ else:
             submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names)
             return submesh

-        def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
+        def get_group(self, mesh_dim: int | str | None = None) -> ProcessGroup:
             """
             Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
             DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.
@@ -705,7 +703,7 @@ else:

         def _create_flatten_mesh(
             self,
-            mesh_dim_name: Optional[str] = None,
+            mesh_dim_name: str | None = None,
             backend_override: BackendConfig = (None, None),
         ) -> "DeviceMesh":
             root_mesh = self._get_root_mesh()
@@ -754,7 +752,7 @@ else:

             return res_flattened_mesh

-        def _get_root_mesh_dim(self) -> Optional[int]:
+        def _get_root_mesh_dim(self) -> int | None:
             """
             Returns the index of the mesh dim in the root mesh.
             The device_mesh passed in needs to be sliced out from the root mesh
@@ -893,11 +891,11 @@ else:

         @staticmethod
         def from_group(
-            group: Union[ProcessGroup, list[ProcessGroup]],
+            group: ProcessGroup | list[ProcessGroup],
             device_type: str,
-            mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
+            mesh: Union[torch.Tensor, "ArrayLike"] | None = None,
             *,
-            mesh_dim_names: Optional[tuple[str, ...]] = None,
+            mesh_dim_names: tuple[str, ...] | None = None,
         ) -> "DeviceMesh":
             """
             Constructs a :class:`DeviceMesh` with ``device_type`` from an
@@ -986,7 +984,7 @@ else:
             device_mesh._dim_group_names = [group.group_name for group in groups]
             return device_mesh

-        def size(self, mesh_dim: Optional[int] = None) -> int:
+        def size(self, mesh_dim: int | None = None) -> int:
             if mesh_dim is not None:
                 return self._layout[mesh_dim].numel()
             return self._layout.numel()
@@ -1005,7 +1003,7 @@ else:
             """
             return get_rank()

-        def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        def get_local_rank(self, mesh_dim: int | str | None = None) -> int:
             """
             Returns the local rank of the given mesh_dim of the DeviceMesh.

@@ -1049,7 +1047,7 @@ else:
             )
             return not_none(get_rank(mesh_dim_group))

-        def get_coordinate(self) -> Optional[list[int]]:
+        def get_coordinate(self) -> list[int] | None:
             """
             Return the relative indices of this rank relative to all
             dimensions of the mesh. If this rank is not part of the mesh, return None.
@@ -1058,10 +1056,11 @@ else:

         def _flatten(
             self,
-            mesh_dim_name: Optional[str] = None,
-            backend_override: Union[
-                None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
-            ] = None,
+            mesh_dim_name: str | None = None,
+            backend_override: None
+            | str
+            | C10dBackend.Options
+            | tuple[str, C10dBackend.Options] = None,
         ) -> "DeviceMesh":
             """
             Returns a 1D DeviceMesh by flattening the current DeviceMesh.
@@ -1095,7 +1094,7 @@ else:
             mesh_sizes: tuple[int, ...],
             mesh_dim_names: tuple[str, ...],
             backend_override: tuple[
-                tuple[Optional[str], Optional[C10dBackend.Options]], ...
+                tuple[str | None, C10dBackend.Options | None], ...
             ] = ((None, None),),
         ) -> "DeviceMesh":
             inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))
@@ -1140,15 +1139,13 @@ else:

         def _unflatten(
             self,
-            dim: Union[int, str],
+            dim: int | str,
             mesh_sizes: tuple[int, ...],
             mesh_dim_names: tuple[str, ...],
-            backend_override: Optional[
-                dict[
-                    str,
-                    Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
-                ]
-            ] = None,
+            backend_override: dict[
+                str, str | C10dBackend.Options | tuple[str, C10dBackend.Options]
+            ]
+            | None = None,
         ) -> "DeviceMesh":
             """
             Returns a DeviceMesh by unflatten the current DeviceMesh.
@@ -1239,11 +1236,11 @@ else:

     def _normalize_backend_override(
         backend_override: dict[
-            Union[int, str],
-            Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
+            int | str,
+            str | C10dBackend.Options | tuple[str, C10dBackend.Options],
         ],
         ndim: int,
-        mesh_dim_names: Optional[tuple[str, ...]] = None,
+        mesh_dim_names: tuple[str, ...] | None = None,
     ) -> Iterator[BackendConfig]:
         if mesh_dim_names is None:
             mesh_dim_names = ()
@@ -1278,13 +1275,11 @@ else:
         device_type: str,
         mesh_shape: tuple[int, ...],
         *,
-        mesh_dim_names: Optional[tuple[str, ...]] = None,
-        backend_override: Optional[
-            dict[
-                Union[int, str],
-                Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
-            ]
-        ] = None,
+        mesh_dim_names: tuple[str, ...] | None = None,
+        backend_override: dict[
+            int | str, str | C10dBackend.Options | tuple[str, C10dBackend.Options]
+        ]
+        | None = None,
    ) -> DeviceMesh:
        """
        Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
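The DeviceMesh signatures above are usually exercised through `init_device_mesh`. A usage sketch for a hypothetical 8-GPU host split into 2 data-parallel by 4 tensor-parallel groups:

```python
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_group = mesh_2d.get_group("tp")      # ProcessGroup for this rank's TP dimension
dp_mesh = mesh_2d["dp"]                 # 1-D submesh sliced by dim name
tp_rank = mesh_2d.get_local_rank("tp")  # this rank's index within its TP group
```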
@@ -17,7 +17,7 @@ import warnings
 from collections import namedtuple
 from collections.abc import Callable
 from datetime import timedelta
-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING
 from typing_extensions import deprecated

 import torch
@@ -309,7 +309,7 @@ class Backend(str):  # noqa: SLOT000
         name,
         func,
         extended_api=False,
-        devices: Optional[Union[str, list[str]]] = None,
+        devices: str | list[str] | None = None,
     ) -> None:
         """
         Register a new backend with the given name and instantiating function.
@@ -504,10 +504,10 @@ class P2POp:
         self,
         op: Callable,
         tensor: torch.Tensor,
-        peer: Optional[int] = None,
-        group: Optional[ProcessGroup] = None,
+        peer: int | None = None,
+        group: ProcessGroup | None = None,
         tag: int = 0,
-        group_peer: Optional[int] = None,
+        group_peer: int | None = None,
     ):
         """Init."""
         self.op = op
@@ -523,10 +523,10 @@ class P2POp:
         cls,
         op: Callable,
         tensor: torch.Tensor,
-        peer: Optional[int] = None,
-        group: Optional[ProcessGroup] = None,
+        peer: int | None = None,
+        group: ProcessGroup | None = None,
         tag: int = 0,
-        group_peer: Optional[int] = None,
+        group_peer: int | None = None,
     ):
         """Create and return a new instance of the class."""
         _check_op(op)
@@ -566,9 +566,9 @@ class _CollOp:
         self,
         op: Callable,
         tensor: torch.Tensor,
-        dst_tensor: Optional[torch.Tensor] = None,
-        redop: Optional[ReduceOp] = None,
-        root: Optional[int] = None,
+        dst_tensor: torch.Tensor | None = None,
+        redop: ReduceOp | None = None,
+        root: int | None = None,
     ):
         self.op = op
         self.tensor = tensor
@@ -587,7 +587,7 @@ _pg_backend_config: dict[ProcessGroup, str] = {}
 _group_count = 0
 _tags_to_pg: dict[str, list[ProcessGroup]] = {}
 _pg_to_tag: dict[ProcessGroup, str] = {}
-_backend: Optional[str] = None
+_backend: str | None = None


 class _World:
@@ -605,7 +605,7 @@ class _World:
         self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {}

     @property
-    def default_pg(self) -> Optional[ProcessGroup]:
+    def default_pg(self) -> ProcessGroup | None:
         """
         Process group that includes all ranks of the cluster.

@@ -730,11 +730,11 @@ class _WorldMeta(type):

     # Points to the default PG once initialized.
     @property
-    def WORLD(cls) -> Optional[ProcessGroup]:
+    def WORLD(cls) -> ProcessGroup | None:
         return _world.default_pg

     @WORLD.setter
-    def WORLD(cls, pg: Optional[ProcessGroup]):
+    def WORLD(cls, pg: ProcessGroup | None):
         _world.default_pg = pg

@@ -772,12 +772,12 @@ def _check_valid_timeout(timeout: Any) -> None:


 # Default process group state
-_default_pg_init_method: Optional[str] = None
+_default_pg_init_method: str | None = None

 STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"


-def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str:
+def _get_object_coll_device(group: ProcessGroup | None = None) -> str:
     """
     .. note:: This is an internal helper and does not have backward
         compatibility, please use with caution.
@@ -843,7 +843,7 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str:
     return devices[0].type


-def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device:
+def _get_pg_default_device(group: ProcessGroup | None = None) -> torch.device:
     """
     .. note:: This method will be deprecated, it only stays for
         backward-compatiblity reason. Alternatives:
@@ -923,7 +923,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
     return rv


-def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]:
+def _device_capability(group: ProcessGroup | None = None) -> list[str]:
     """
     Return the device type(s) supported by ``group``.

@@ -1007,7 +1007,7 @@ def _store_based_barrier(
     )


-def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool:
+def _rank_not_in_group(group: ProcessGroup | None) -> bool:
     """Check if the current process's rank is not in a given group."""
     if group is None:
         return False
@@ -1089,7 +1089,7 @@ def _get_global_rank(group, rank) -> int:
     return get_global_rank(group, rank)


-def get_process_group_ranks(group: Optional[ProcessGroup]) -> list[int]:
+def get_process_group_ranks(group: ProcessGroup | None) -> list[int]:
     """
     Get all ranks associated with ``group``.

@@ -1148,7 +1148,7 @@ def _check_tensor_list(param, param_name) -> None:
     )


-def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup:
+def _group_or_default_group(group: ProcessGroup | None = None) -> ProcessGroup:
     if group is None or group is GroupMember.WORLD:
         group = _get_default_group()
     return group
@@ -1156,8 +1156,8 @@ def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGrou

 def _canonicalize_group_rank(
     group: ProcessGroup,
-    global_rank: Optional[int] = None,
-    group_rank: Optional[int] = None,
+    global_rank: int | None = None,
+    group_rank: int | None = None,
     return_global: bool = False,
 ) -> int:
     """
@@ -1361,7 +1361,7 @@ def _update_default_pg(pg) -> None:
     torch._C._distributed_c10d._set_global_rank(rank)


-def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
+def get_backend_config(group: ProcessGroup | None = None) -> str:
     """
     Return the backend configuration of the given process group.

@@ -1381,7 +1381,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
     return str(not_none(backend_config))


-def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
+def get_backend(group: ProcessGroup | None = None) -> Backend:
     """
     Return the backend of the given process group.

@@ -1407,7 +1407,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
     return Backend(not_none(pg_store)[0])


-def get_default_backend_for_device(device: Union[str, torch.device]) -> str:
+def get_default_backend_for_device(device: str | torch.device) -> str:
     """
     Return the default backend for the given device.

@@ -1441,7 +1441,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int:
     return -1


-def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]:
+def _get_pg_config(group: ProcessGroup | None = None) -> dict[str, Any]:
     """
     Return the pg configuration of the given process group.

@@ -1473,7 +1473,7 @@ def get_pg_count() -> int:
     return _world.group_count


-def get_node_local_rank(fallback_rank: Optional[int] = None) -> int:
+def get_node_local_rank(fallback_rank: int | None = None) -> int:
     """
     Return the local rank of the current process relative to the node.

@@ -1526,7 +1526,7 @@ def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None:
         backend._add_ephemeral_timeout(timeout)


-def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
+def _set_pg_timeout(timeout: timedelta, group: ProcessGroup | None = None) -> None:
     """
     Set the timeout for the given process group when users want to use a different timeout instead of
     default values.
@@ -1575,16 +1575,16 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
 @_exception_logger
 @_time_logger
 def init_process_group(
-    backend: Optional[str] = None,
-    init_method: Optional[str] = None,
-    timeout: Optional[timedelta] = None,
+    backend: str | None = None,
+    init_method: str | None = None,
+    timeout: timedelta | None = None,
     world_size: int = -1,
     rank: int = -1,
-    store: Optional[Store] = None,
+    store: Store | None = None,
     group_name: str = "",
-    pg_options: Optional[Any] = None,
-    device_id: Optional[Union[torch.device, int]] = None,
-    _ranks: Optional[list[int]] = None,
+    pg_options: Any | None = None,
+    device_id: torch.device | int | None = None,
+    _ranks: list[int] | None = None,
 ) -> None:
     """
     Initialize the default distributed process group.
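A minimal sketch of the `init_process_group` / `destroy_process_group` lifecycle whose signatures are modernized above, assuming the rendezvous variables (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`) are set in the environment:

```python
import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="env://")
try:
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}")
finally:
    dist.destroy_process_group()  # tear down the default group on exit
```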
@@ -2216,7 +2216,7 @@ def _new_process_group_helper(
     return pg, prefix_store


-def destroy_process_group(group: Optional[ProcessGroup] = None):
+def destroy_process_group(group: ProcessGroup | None = None):
     """
     Destroy a given process group, and deinitialize the distributed package.

@@ -2305,7 +2305,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         _unregister_process_group(pg.group_name)


-def _abort_process_group(group: Optional[ProcessGroup] = None):
+def _abort_process_group(group: ProcessGroup | None = None):
     """
     Abort a given process group. If group.WORLD (i.e. `None`) is given, all
     process groups including the default one will be aborted.
@@ -2397,7 +2397,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
         _unregister_process_group(pg.group_name)


-def get_rank(group: Optional[ProcessGroup] = None) -> int:
+def get_rank(group: ProcessGroup | None = None) -> int:
     """
     Return the rank of the current process in the provided ``group``, default otherwise.

@@ -2424,7 +2424,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int:
     return get_group_rank(group, default_pg.rank())


-def get_world_size(group: Optional[ProcessGroup] = None) -> int:
+def get_world_size(group: ProcessGroup | None = None) -> int:
     """
     Return the number of processes in the current process group.

@@ -2445,11 +2445,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:

 def isend(
     tensor: torch.Tensor,
-    dst: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    dst: int | None = None,
+    group: ProcessGroup | None = None,
     tag: int = 0,
-    group_dst: Optional[int] = None,
-) -> Optional[Work]:
+    group_dst: int | None = None,
+) -> Work | None:
     """
     Send a tensor asynchronously.

@@ -2490,11 +2490,11 @@ def isend(

 def irecv(
     tensor: torch.Tensor,
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
     tag: int = 0,
-    group_src: Optional[int] = None,
-) -> Optional[Work]:
+    group_src: int | None = None,
+) -> Work | None:
     """
     Receives a tensor asynchronously.

@@ -2536,10 +2536,10 @@ def irecv(
 @_exception_logger
 def send(
     tensor: torch.Tensor,
-    dst: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    dst: int | None = None,
+    group: ProcessGroup | None = None,
     tag: int = 0,
-    group_dst: Optional[int] = None,
+    group_dst: int | None = None,
 ) -> None:
     """
     Send a tensor synchronously.
@@ -2568,10 +2568,10 @@ def send(
 @_exception_logger
 def recv(
     tensor: torch.Tensor,
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
     tag: int = 0,
-    group_src: Optional[int] = None,
+    group_src: int | None = None,
 ) -> int:
     """
     Receives a tensor synchronously.
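The `Work | None` return type in `isend`/`irecv` above means callers should guard before waiting, since a rank outside the group gets `None` back. A two-rank sketch:

```python
import torch
import torch.distributed as dist

# Rank 0 posts a non-blocking send; rank 1 posts the matching receive.
t = torch.ones(4)
if dist.get_rank() == 0:
    work = dist.isend(t, dst=1)
else:
    work = dist.irecv(t, src=0)
if work is not None:
    work.wait()  # complete the transfer before using `t`
```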
@@ -2623,7 +2623,7 @@ class _CoalescingManager:
     def __init__(self) -> None:
         self.works: list[Work] = []

-    def append(self, work: Optional[Work] = None):
+    def append(self, work: Work | None = None):
         if work:
             self.works.append(work)

@@ -2634,8 +2634,8 @@ class _CoalescingManager:

 @contextlib.contextmanager
 def _coalescing_manager(
-    group: Optional[ProcessGroup] = None,
-    device: Optional[torch.device] = None,
+    group: ProcessGroup | None = None,
+    device: torch.device | None = None,
     async_ops: bool = False,
 ):
     """
@@ -2731,13 +2731,13 @@ def _coalescing_manager(

 class _TimeEstimator:
     def __init__(self) -> None:
-        self.estimated_time: Optional[float] = None
+        self.estimated_time: float | None = None


 @contextlib.contextmanager
 def _time_estimator(
-    group: Optional[ProcessGroup] = None,
-    device: Optional[torch.device] = None,
+    group: ProcessGroup | None = None,
+    device: torch.device | None = None,
 ):
     """
     Context manager used to estimate time of collectives.
@@ -2862,10 +2862,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
 @_exception_logger
 def broadcast(
     tensor: torch.Tensor,
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
     async_op: bool = False,
-    group_src: Optional[int] = None,
+    group_src: int | None = None,
 ):
     """
     Broadcasts the tensor to the whole group.
@@ -3084,11 +3084,11 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
 @_exception_logger
 def reduce(
     tensor: torch.Tensor,
-    dst: Optional[int] = None,
+    dst: int | None = None,
     op=ReduceOp.SUM,
-    group: Optional[ProcessGroup] = None,
+    group: ProcessGroup | None = None,
     async_op: bool = False,
-    group_dst: Optional[int] = None,
+    group_dst: int | None = None,
 ):
     """
     Reduces the tensor data across all machines.
@@ -3268,10 +3268,10 @@ def all_gather_object(object_list, obj, group=None):
 @_exception_logger
 def gather_object(
     obj: Any,
-    object_gather_list: Optional[list[Any]] = None,
-    dst: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
-    group_dst: Optional[int] = None,
+    object_gather_list: list[Any] | None = None,
+    dst: int | None = None,
+    group: ProcessGroup | None = None,
+    group_dst: int | None = None,
 ):
     """
     Gathers picklable objects from the whole group in a single process.
@@ -3399,10 +3399,10 @@ def gather_object(
 @_exception_logger
 def send_object_list(
     object_list: list[Any],
-    dst: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
-    device: Optional[torch.device] = None,
-    group_dst: Optional[int] = None,
+    dst: int | None = None,
+    group: ProcessGroup | None = None,
+    device: torch.device | None = None,
+    group_dst: int | None = None,
     use_batch: bool = False,
 ):
     """
@@ -3517,10 +3517,10 @@ def send_object_list(
 @_exception_logger
 def recv_object_list(
     object_list: list[Any],
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
-    device: Optional[torch.device] = None,
-    group_src: Optional[int] = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
+    device: torch.device | None = None,
+    group_src: int | None = None,
     use_batch: bool = False,
 ):
     """
@@ -3659,10 +3659,10 @@ def recv_object_list(
 @_exception_logger
 def broadcast_object_list(
     object_list: list[Any],
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
-    device: Optional[torch.device] = None,
-    group_src: Optional[int] = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
+    device: torch.device | None = None,
+    group_src: int | None = None,
 ):
     """
     Broadcasts picklable objects in ``object_list`` to the whole group.
@@ -3791,10 +3791,10 @@ def broadcast_object_list(
 @_exception_logger
 def scatter_object_list(
     scatter_object_output_list: list[Any],
-    scatter_object_input_list: Optional[list[Any]] = None,
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
-    group_src: Optional[int] = None,
+    scatter_object_input_list: list[Any] | None = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
+    group_src: int | None = None,
 ):
     """
     Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
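A sketch of the object-collective call pattern shared by the functions above; note the destination list must be pre-sized with placeholders on non-source ranks:

```python
import torch.distributed as dist

# Rank 0 fills the list; other ranks pass placeholders of the same length.
objs = [{"lr": 0.1}, "checkpoint-42"] if dist.get_rank() == 0 else [None, None]
dist.broadcast_object_list(objs, src=0)
# After the call, every rank sees rank 0's (picklable) objects.
```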
@@ -4265,11 +4265,11 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
 @_exception_logger
 def gather(
     tensor: torch.Tensor,
-    gather_list: Optional[list[torch.Tensor]] = None,
-    dst: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    gather_list: list[torch.Tensor] | None = None,
+    dst: int | None = None,
+    group: ProcessGroup | None = None,
     async_op: bool = False,
-    group_dst: Optional[int] = None,
+    group_dst: int | None = None,
 ):
     """
     Gathers a list of tensors in a single process.
@@ -4348,11 +4348,11 @@ def gather(
 @_exception_logger
 def scatter(
     tensor: torch.Tensor,
-    scatter_list: Optional[list[torch.Tensor]] = None,
-    src: Optional[int] = None,
-    group: Optional[ProcessGroup] = None,
+    scatter_list: list[torch.Tensor] | None = None,
+    src: int | None = None,
+    group: ProcessGroup | None = None,
     async_op: bool = False,
-    group_src: Optional[int] = None,
+    group_src: int | None = None,
 ):
     """
     Scatters a list of tensors to all processes in a group.
@@ -4895,7 +4895,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False

 @_exception_logger
 def barrier(
-    group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None
+    group: ProcessGroup | None = GroupMember.WORLD, async_op=False, device_ids=None
 ):
     """
     Synchronize all processes.
@@ -4967,7 +4967,7 @@ def barrier(


 def monitored_barrier(
-    group: Optional[ProcessGroup] = GroupMember.WORLD,
+    group: ProcessGroup | None = GroupMember.WORLD,
     timeout=None,
     wait_all_ranks=False,
 ):
@@ -5104,7 +5104,7 @@ def _process_group_name(ranks, use_hashed_name):
     return pg_name


-def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
+def _get_backend_from_str(backend: str | None = None) -> Backend:
     # Default to the same backend as the global process group
     # if backend is not specified.
     if not backend:
@@ -5124,12 +5124,12 @@ def _is_safe_to_split() -> bool:

 @_time_logger
 def split_group(
-    parent_pg: Optional[ProcessGroup] = None,
-    split_ranks: Optional[list] = None,
-    timeout: Optional[timedelta] = None,
-    pg_options: Optional[Any] = None,
-    group_desc: Optional[str] = None,
-) -> Optional[ProcessGroup]:
+    parent_pg: ProcessGroup | None = None,
+    split_ranks: list | None = None,
+    timeout: timedelta | None = None,
+    pg_options: Any | None = None,
+    group_desc: str | None = None,
+) -> ProcessGroup | None:
     """
     Create a new process group split from the given parent process group.

@@ -5290,7 +5290,7 @@ def new_group(
     pg_options=None,
     use_local_synchronization=False,
     group_desc=None,
-    device_id: Optional[torch.device] = None,
+    device_id: torch.device | None = None,
 ):
     """
     Create a new distributed group.
@@ -5380,7 +5380,7 @@ def _new_group_with_tag(
     pg_tag=None,
     use_local_synchronization=False,
     group_desc=None,
-    device_id: Optional[torch.device] = None,
+    device_id: torch.device | None = None,
 ):
     """
     Variant of ``new_group`` that exposes tag creation.
@@ -5693,7 +5693,7 @@ def new_subgroups_by_enumeration(
     return cur_subgroup, subgroups


-def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]:
+def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> ProcessGroup | None:
     if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
         tag = f"user:{tag}"

@@ -5765,9 +5765,9 @@ SHRINK_ABORT = 0x01
 @_time_logger
 def shrink_group(
     ranks_to_exclude: list[int],
-    group: Optional[ProcessGroup] = None,
+    group: ProcessGroup | None = None,
     shrink_flags: int = SHRINK_DEFAULT,
-    pg_options: Optional[Any] = None,
+    pg_options: Any | None = None,
 ) -> ProcessGroup:
     """
     Shrinks a process group by excluding specified ranks.
@@ -5857,7 +5857,7 @@ def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> N
     )


-def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
+def _prepare_shrink_target_group(group: ProcessGroup | None) -> dict:
     """Prepare and validate the target group for shrinking."""
     target_pg = group if group is not None else _get_default_group()

@@ -6107,7 +6107,7 @@ def _create_shrunk_process_group(
     return new_pg


-def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
+def _destroy_all_other_groups(exclude_group: ProcessGroup | None = None) -> None:
    """
    Destroy all process groups except the excluded group and clean up all global state.

@@ -6223,9 +6223,9 @@ def _update_process_group_global_state(
     store: Store,
     group_name: str,
     backend_config: str,
-    rank_mapping: Optional[dict[int, int]] = None,
-    pg_tag: Optional[str] = None,
-    user_tag: Optional[str] = None,
+    rank_mapping: dict[int, int] | None = None,
+    pg_tag: str | None = None,
+    user_tag: str | None = None,
 ) -> None:
     """
     Update all global state dictionaries for a process group.
@@ -19,7 +19,7 @@ from collections.abc import Callable
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any

 import torch.distributed.elastic.rendezvous as rdzv
 import torch.distributed.elastic.utils.store as store_util
@@ -89,19 +89,19 @@ class WorkerSpec:
     role: str
     local_world_size: int
     rdzv_handler: rdzv.RendezvousHandler
-    fn: Optional[Callable] = None
+    fn: Callable | None = None
     # TODO @kiuk - make entrypoint a required field
-    entrypoint: Union[Callable, str, None] = None
+    entrypoint: Callable | str | None = None
     args: tuple = ()
     max_restarts: int = 3
     monitor_interval: float = 0.1
-    master_port: Optional[int] = None
-    master_addr: Optional[str] = None
-    local_addr: Optional[str] = None
+    master_port: int | None = None
+    master_addr: str | None = None
+    local_addr: str | None = None
     event_log_handler: str = "null"
-    numa_options: Optional[NumaOptions] = None
-    duplicate_stdout_filters: Optional[list[str]] = None
-    duplicate_stderr_filters: Optional[list[str]] = None
+    numa_options: NumaOptions | None = None
+    duplicate_stdout_filters: list[str] | None = None
+    duplicate_stderr_filters: list[str] | None = None
     virtual_local_rank: bool = False

     def __post_init__(self):
@@ -807,11 +807,11 @@ class SimpleElasticAgent(ElasticAgent):
         self,
         state: str,
         source: EventSource,
-        worker: Optional[Worker] = None,
-        raw_error: Optional[str] = None,
-        duration_ms: Optional[float] = None,
-        exit_code: Optional[int] = None,
-        worker_pid: Optional[int] = None,
+        worker: Worker | None = None,
+        raw_error: str | None = None,
+        duration_ms: float | None = None,
+        exit_code: int | None = None,
+        worker_pid: int | None = None,
     ) -> Event:
         wg = self._worker_group
         spec = wg.spec
@@ -15,7 +15,7 @@ import socket
 import time
 import uuid
 from string import Template
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

 import torch.distributed.elastic.timer as timer
 from torch.distributed.elastic import events
@@ -152,16 +152,16 @@ class LocalElasticAgent(SimpleElasticAgent):
         logs_specs: LogsSpecs,
         start_method="spawn",
         exit_barrier_timeout: float = 300,
-        log_line_prefix_template: Optional[str] = None,
+        log_line_prefix_template: str | None = None,
     ):
         super().__init__(spec, exit_barrier_timeout)
         self._start_method = start_method
-        self._pcontext: Optional[PContext] = None
+        self._pcontext: PContext | None = None
         self._rdzv_handler = spec.rdzv_handler
         self._log_line_prefix_template = log_line_prefix_template
-        self._worker_watchdog: Optional[timer.FileTimerServer] = None
+        self._worker_watchdog: timer.FileTimerServer | None = None
         self._logs_specs = logs_specs
-        self._health_check_server: Optional[HealthCheckServer] = None
+        self._health_check_server: HealthCheckServer | None = None

     def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
@@ -244,7 +244,7 @@ class LocalElasticAgent(SimpleElasticAgent):
     def _log_watchdog_event(
         self,
         name: str,
-        request: Optional[timer.FileTimerRequest],
+        request: timer.FileTimerRequest | None,
     ) -> None:
         wg = self._worker_group
         spec = wg.spec
@@ -297,7 +297,7 @@ class LocalElasticAgent(SimpleElasticAgent):

         args: dict[int, tuple] = {}
         envs: dict[int, dict[str, str]] = {}
-        log_line_prefixes: Optional[dict[int, str]] = (
+        log_line_prefixes: dict[int, str] | None = (
             {} if self._log_line_prefix_template else None
         )
         for worker in worker_group.workers:
@@ -86,10 +86,10 @@ def construct_and_record_rdzv_event(
     node_state: NodeState,
     name: str = "",
     hostname: str = "",
-    pid: Optional[int] = None,
+    pid: int | None = None,
     master_endpoint: str = "",
-    local_id: Optional[int] = None,
-    rank: Optional[int] = None,
+    local_id: int | None = None,
+    rank: int | None = None,
 ) -> None:
     """
     Initialize rendezvous event object and record its operations.
@@ -10,7 +10,7 @@
 import json
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Optional, Union
+from typing import Union


 __all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
@@ -95,8 +95,8 @@ class RdzvEvent:
     pid: int
     node_state: NodeState
     master_endpoint: str = ""
-    rank: Optional[int] = None
-    local_id: Optional[int] = None
+    rank: int | None = None
+    local_id: int | None = None
     error_trace: str = ""

     def __str__(self):
@@ -158,7 +158,7 @@ from .api import (  # noqa: F401
 )


-def initialize_metrics(cfg: Optional[MetricsConfig] = None):
+def initialize_metrics(cfg: MetricsConfig | None = None):
     pass


@@ -11,7 +11,6 @@ import abc
 import time
 from collections import namedtuple
 from functools import wraps
-from typing import Optional
 from typing_extensions import deprecated


@@ -37,7 +36,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
 class MetricsConfig:
     __slots__ = ["params"]

-    def __init__(self, params: Optional[dict[str, str]] = None):
+    def __init__(self, params: dict[str, str] | None = None):
         self.params = params
         if self.params is None:
             self.params = {}
@@ -77,7 +76,7 @@ _default_metrics_handler: MetricHandler = NullMetricHandler()


 # pyre-fixme[9]: group has type `str`; used as `None`.
-def configure(handler: MetricHandler, group: Optional[str] = None):
+def configure(handler: MetricHandler, group: str | None = None):
     if group is None:
         global _default_metrics_handler
         # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
@@ -102,15 +102,15 @@ __all__ = [

 def start_processes(
     name: str,
-    entrypoint: Union[Callable, str],
+    entrypoint: Callable | str,
     args: dict[int, tuple],
     envs: dict[int, dict[str, str]],
     logs_specs: LogsSpecs,
-    log_line_prefixes: Optional[dict[int, str]] = None,
+    log_line_prefixes: dict[int, str] | None = None,
     start_method: str = "spawn",
-    numa_options: Optional[NumaOptions] = None,
-    duplicate_stdout_filters: Optional[list[str]] = None,
-    duplicate_stderr_filters: Optional[list[str]] = None,
+    numa_options: NumaOptions | None = None,
+    duplicate_stdout_filters: list[str] | None = None,
+    duplicate_stderr_filters: list[str] | None = None,
 ) -> PContext:
     """
     Start ``n`` copies of ``entrypoint`` processes with the provided options.
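A hedged sketch of `start_processes` with the `logs_specs` parameter shown above; the entrypoint function and the log directory are assumptions:

```python
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes

def trainer(tag: str) -> None:
    print("hello from", tag)

# Launch two local copies of `trainer`, one per local rank; args/envs are
# keyed by local rank, as the signature above requires.
ctx = start_processes(
    name="trainer",
    entrypoint=trainer,
    args={0: ("rank0",), 1: ("rank1",)},
    envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
    logs_specs=DefaultLogsSpecs(log_dir="/tmp/elastic_logs"),
)
ctx.wait()  # block until all copies finish
```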
@@ -25,7 +25,7 @@ from dataclasses import dataclass, field
 from enum import IntFlag
 from multiprocessing import synchronize
 from types import FrameType
-from typing import Any, Optional, TextIO, Union
+from typing import Any, TextIO, Union

 import torch.multiprocessing as mp
 from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@@ -73,7 +73,7 @@ class SignalException(Exception):
         self.sigval = sigval


-def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
+def _terminate_process_handler(signum: int, frame: FrameType | None) -> None:
     """Termination handler that raises exceptions on the main process.

     When the process receives death signal(SIGTERM, SIGINT), this termination handler will
@@ -156,9 +156,7 @@ class Std(IntFlag):
         )


-def to_map(
-    val_or_map: Union[Std, dict[int, Std]], local_world_size: int
-) -> dict[int, Std]:
+def to_map(val_or_map: Std | dict[int, Std], local_world_size: int) -> dict[int, Std]:
     """
     Certain APIs take redirect settings either as a single value (e.g. apply to all
     local ranks) or as an explicit user-provided mapping. This method is a convenience
@@ -216,10 +214,10 @@ class LogsSpecs(ABC):

     def __init__(
         self,
-        log_dir: Optional[str] = None,
-        redirects: Union[Std, dict[int, Std]] = Std.NONE,
-        tee: Union[Std, dict[int, Std]] = Std.NONE,
-        local_ranks_filter: Optional[set[int]] = None,
+        log_dir: str | None = None,
+        redirects: Std | dict[int, Std] = Std.NONE,
+        tee: Std | dict[int, Std] = Std.NONE,
+        local_ranks_filter: set[int] | None = None,
     ) -> None:
         self._root_log_dir = log_dir
         self._redirects = redirects
@@ -254,10 +252,10 @@ class DefaultLogsSpecs(LogsSpecs):

     def __init__(
         self,
-        log_dir: Optional[str] = None,
-        redirects: Union[Std, dict[int, Std]] = Std.NONE,
-        tee: Union[Std, dict[int, Std]] = Std.NONE,
-        local_ranks_filter: Optional[set[int]] = None,
+        log_dir: str | None = None,
+        redirects: Std | dict[int, Std] = Std.NONE,
+        tee: Std | dict[int, Std] = Std.NONE,
+        local_ranks_filter: set[int] | None = None,
     ) -> None:
         if log_dir != os.devnull:
             if not log_dir:
@@ -275,7 +273,7 @@ class DefaultLogsSpecs(LogsSpecs):
     def root_log_dir(self) -> str:
         return str(self._root_log_dir)

-    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
+    def _make_log_dir(self, log_dir: str | None, rdzv_run_id: str):
         base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
         os.makedirs(base_log_dir, exist_ok=True)
         dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
@ -465,13 +463,13 @@ class PContext(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
entrypoint: Union[Callable, str],
|
||||
entrypoint: Callable | str,
|
||||
args: dict[int, tuple],
|
||||
envs: dict[int, dict[str, str]],
|
||||
logs_specs: LogsSpecs,
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
duplicate_stdout_filters: Optional[list[str]] = None,
|
||||
duplicate_stderr_filters: Optional[list[str]] = None,
|
||||
log_line_prefixes: dict[int, str] | None = None,
|
||||
duplicate_stdout_filters: list[str] | None = None,
|
||||
duplicate_stderr_filters: list[str] | None = None,
|
||||
):
|
||||
self.name = name
|
||||
# validate that all mappings have the same number of keys and
|
||||
@ -491,8 +489,8 @@ class PContext(abc.ABC):
|
||||
self.stderrs = logs_dest.stderrs
|
||||
self.error_files = logs_dest.error_files
|
||||
self.nprocs = nprocs
|
||||
self.filtered_stdout: Optional[TextIO] = None
|
||||
self.filtered_stderr: Optional[TextIO] = None
|
||||
self.filtered_stdout: TextIO | None = None
|
||||
self.filtered_stderr: TextIO | None = None
|
||||
|
||||
self._tail_logs = [
|
||||
TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes),
|
||||
@ -582,7 +580,7 @@ class PContext(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _poll(self) -> Optional[RunProcsResult]:
|
||||
def _poll(self) -> RunProcsResult | None:
|
||||
"""
|
||||
Poll the run status of the processes running under this context.
|
||||
This method follows an "all-or-nothing" policy and returns
|
||||
@ -592,7 +590,7 @@ class PContext(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
|
||||
def wait(self, timeout: float = -1, period: float = 1) -> RunProcsResult | None:
|
||||
"""
|
||||
Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
|
||||
for the processes to be done. Returns ``None`` if the processes are still running
|
||||
@ -646,9 +644,7 @@ class PContext(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(
|
||||
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
|
||||
) -> None:
|
||||
def close(self, death_sig: signal.Signals | None = None, timeout: int = 30) -> None:
|
||||
r"""
|
||||
Terminates all processes managed by this context and cleans up any
|
||||
meta resources (e.g. redirect, error_file files).
|
||||
@ -685,7 +681,7 @@ def _wrap(
|
||||
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
|
||||
ret_vals: dict[int, mp.SimpleQueue],
|
||||
queue_finished_reading_event: synchronize.Event,
|
||||
numa_options: Optional[NumaOptions],
|
||||
numa_options: NumaOptions | None,
|
||||
) -> None:
|
||||
# get the per-rank params up front so we fail fast if no mapping is found
|
||||
args_ = args[local_rank]
|
||||
@ -721,10 +717,10 @@ class MultiprocessContext(PContext):
|
||||
envs: dict[int, dict[str, str]],
|
||||
start_method: str,
|
||||
logs_specs: LogsSpecs,
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
numa_options: Optional[NumaOptions] = None,
|
||||
duplicate_stdout_filters: Optional[list[str]] = None,
|
||||
duplicate_stderr_filters: Optional[list[str]] = None,
|
||||
log_line_prefixes: dict[int, str] | None = None,
|
||||
numa_options: NumaOptions | None = None,
|
||||
duplicate_stdout_filters: list[str] | None = None,
|
||||
duplicate_stderr_filters: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
@ -746,12 +742,12 @@ class MultiprocessContext(PContext):
|
||||
|
||||
# see comments in ``join()`` for what this is
|
||||
self._return_values: dict[int, Any] = {}
|
||||
self._pc: Optional[mp.ProcessContext] = None
|
||||
self._pc: mp.ProcessContext | None = None
|
||||
# Note: set method should ONLY be invoked for the use case when all processes finished
|
||||
# successfully. If any process died on event.wait() calling set() method will deadlock.
|
||||
self._worker_finished_event = mp.get_context(self.start_method).Event()
|
||||
|
||||
self._numa_options: Optional[NumaOptions] = numa_options
|
||||
self._numa_options: NumaOptions | None = numa_options
|
||||
|
||||
def _start(self):
|
||||
if self._pc:
|
||||
@ -780,7 +776,7 @@ class MultiprocessContext(PContext):
|
||||
def _is_done(self) -> bool:
|
||||
return len(self._return_values) == self.nprocs
|
||||
|
||||
def _poll(self) -> Optional[RunProcsResult]:
|
||||
def _poll(self) -> RunProcsResult | None:
|
||||
assert self._pc is not None # assertion for mypy type checker
|
||||
|
||||
try:
|
||||
@ -910,10 +906,10 @@ class SubprocessContext(PContext):
|
||||
args: dict[int, tuple],
|
||||
envs: dict[int, dict[str, str]],
|
||||
logs_specs: LogsSpecs,
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
numa_options: Optional[NumaOptions] = None,
|
||||
duplicate_stdout_filters: Optional[list[str]] = None,
|
||||
duplicate_stderr_filters: Optional[list[str]] = None,
|
||||
log_line_prefixes: dict[int, str] | None = None,
|
||||
numa_options: NumaOptions | None = None,
|
||||
duplicate_stdout_filters: list[str] | None = None,
|
||||
duplicate_stderr_filters: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
@ -930,7 +926,7 @@ class SubprocessContext(PContext):
|
||||
self._running_local_ranks: set[int] = set(range(self.nprocs))
|
||||
self._failures: dict[int, ProcessFailure] = {}
|
||||
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
|
||||
self._numa_options: Optional[NumaOptions] = numa_options
|
||||
self._numa_options: NumaOptions | None = numa_options
|
||||
|
||||
def _start(self):
|
||||
if self.subprocess_handlers:
|
||||
@ -965,7 +961,7 @@ class SubprocessContext(PContext):
|
||||
)
|
||||
# else: --> succeeded; nothing to do
|
||||
|
||||
def _poll(self) -> Optional[RunProcsResult]:
|
||||
def _poll(self) -> RunProcsResult | None:
|
||||
done_local_ranks: set[int] = set()
|
||||
self._capture_process_failures(done_local_ranks)
|
||||
|
||||
|
||||
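Taken together, the hunks in this file apply one mechanical rewrite: `Optional[X]` becomes `X | None` and `Union[X, Y]` becomes `X | Y` (PEP 604, Python 3.10+). A minimal runnable illustration of the equivalence; the function below is illustrative, not part of this patch:

```python
from typing import Optional, Union

# PEP 604 unions compare equal to the typing-module spellings.
assert int | None == Optional[int]
assert int | str == Union[int, str]

# Annotations written either way describe the same runtime contract.
def poll(timeout: float | None = None) -> str | None:
    # Both 'timeout' and the return value may be None.
    return None if timeout is None else f"waited {timeout}s"

print(poll())     # None
print(poll(1.5))  # waited 1.5s
```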
@@ -312,8 +312,8 @@ class ChildFailedError(Exception):


 def record(
-    fn: Callable[_P, _R], error_handler: Optional[ErrorHandler] = None
-) -> Callable[_P, Union[_R, None]]:
+    fn: Callable[_P, _R], error_handler: ErrorHandler | None = None
+) -> Callable[_P, _R | None]:
     """
     Syntactic sugar to record errors/exceptions that happened in the decorated
     function using the provided ``error_handler``.
@@ -353,7 +353,7 @@ def record(
     if not error_handler:
         error_handler = get_error_handler()

-    def wrap(f: Callable[_P, _R]) -> Callable[_P, Union[_R, None]]:
+    def wrap(f: Callable[_P, _R]) -> Callable[_P, _R | None]:
         @wraps(f)
         def wrapper(*args: _P.args, **kwargs: _P.kwargs):
             assert error_handler is not None  # assertion for mypy type checker

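The `record` decorator above wraps an entrypoint so that any uncaught exception is written to the error file that the elastic agent reads back. A typical usage sketch, assuming a torchrun-launched script:

```python
from torch.distributed.elastic.multiprocessing.errors import record

@record
def main() -> None:
    # Any uncaught exception raised here is recorded via the default
    # ErrorHandler before being re-raised, so the agent can report it.
    raise RuntimeError("worker failed")

if __name__ == "__main__":
    main()
```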
@@ -13,7 +13,7 @@ import os
 import time
 import traceback
 import warnings
-from typing import Any, Optional
+from typing import Any


 __all__ = ["ErrorHandler"]
@@ -33,7 +33,7 @@ class ErrorHandler:
     Subclasses should override ``initialize()`` and ``record_exception()``.
     """

-    def _get_error_file_path(self) -> Optional[str]:
+    def _get_error_file_path(self) -> str | None:
         """
         Return the error file path.

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional

 from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
     SubprocessHandler,
@@ -21,7 +20,7 @@ def get_subprocess_handler(
     stdout: str,
     stderr: str,
     local_rank_id: int,
-    numa_options: Optional[NumaOptions] = None,
+    numa_options: NumaOptions | None = None,
 ) -> SubprocessHandler:
     return SubprocessHandler(
         entrypoint=entrypoint,

@@ -9,7 +9,7 @@ import os
 import signal
 import sys
 from subprocess import Popen
-from typing import Any, Optional
+from typing import Any

 from torch.numa.binding import maybe_wrap_command_args_with_numa_binding, NumaOptions

@@ -38,10 +38,10 @@ class SubprocessHandler:
         entrypoint: str,
         args: tuple,
         env: dict[str, str],
-        stdout: Optional[str],
-        stderr: Optional[str],
+        stdout: str | None,
+        stderr: str | None,
         local_rank_id: int,
-        numa_options: Optional[NumaOptions],
+        numa_options: NumaOptions | None,
     ):
         self._stdout = open(stdout, "w") if stdout else None
         self._stderr = open(stderr, "w") if stderr else None
@@ -76,7 +76,7 @@ class SubprocessHandler:
             **kwargs,
         )

-    def close(self, death_sig: Optional[signal.Signals] = None) -> None:
+    def close(self, death_sig: signal.Signals | None = None) -> None:
         if not death_sig:
             death_sig = _get_default_signal()
         if IS_WINDOWS:

@@ -13,7 +13,7 @@ import time
 from collections.abc import Callable
 from concurrent.futures.thread import ThreadPoolExecutor
 from threading import Event
-from typing import Optional, TextIO, TYPE_CHECKING
+from typing import TextIO, TYPE_CHECKING


 if TYPE_CHECKING:
@@ -30,7 +30,7 @@ def tail_logfile(
     dst: TextIO,
     finished: Event,
     interval_sec: float,
-    log_line_filter: Optional[Callable[[str], bool]] = None,
+    log_line_filter: Callable[[str], bool] | None = None,
 ):
     while not os.path.exists(file):
         if finished.is_set():
@@ -98,7 +98,7 @@ class TailLog:
         name: str,
         log_files: dict[int, str],
         dst: TextIO,
-        log_line_prefixes: Optional[dict[int, str]] = None,
+        log_line_prefixes: dict[int, str] | None = None,
         interval_sec: float = 0.1,
         log_line_filter: Callable[[str], bool] = (lambda _: True),
     ):

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional
+from typing import Any


 """
@@ -65,11 +65,11 @@ class Client:
         raise EtcdStubError

     def write(
-        self, key: str, value: Any, ttl: Optional[int] = None, **kwargs: Any
+        self, key: str, value: Any, ttl: int | None = None, **kwargs: Any
     ) -> None:
         raise EtcdStubError

     def test_and_set(
-        self, key: str, value: Any, prev_value: Any, ttl: Optional[int] = None
+        self, key: str, value: Any, prev_value: Any, ttl: int | None = None
     ) -> None:
         raise EtcdStubError

@@ -9,7 +9,7 @@ import socket
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar

 from torch.distributed import Store
 from torch.distributed.elastic.utils.distributed import get_free_port
@@ -72,8 +72,8 @@ class RendezvousStoreInfo:
     def build(
         rank: int,
         store: Store,
-        local_addr: Optional[str],
-        server_port: Optional[int] = None,
+        local_addr: str | None,
+        server_port: int | None = None,
     ) -> "RendezvousStoreInfo":
         """Factory method; finds an unused port on the rank0 host and shares the addr/port info with all ranks.

@@ -137,7 +137,7 @@ class RendezvousInfo:
         return self._world_size

     @property
-    def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
+    def bootstrap_store_info(self) -> RendezvousStoreInfo | None:
         """Store information that can be used by trainer code to bootstrap distributed comms."""
         return self._bootstrap_store_info

@@ -265,7 +265,7 @@ class RendezvousParameters:
         run_id: str,
         min_nodes: int,
         max_nodes: int,
-        local_addr: Optional[str] = None,
+        local_addr: str | None = None,
         **kwargs,
     ):
         if not backend:
@@ -293,7 +293,7 @@ class RendezvousParameters:
         """Return the value for ``key`` if ``key`` exists, else ``default``."""
         return self.config.get(key, default)

-    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
+    def get_as_bool(self, key: str, default: bool | None = None) -> bool | None:
         """Return the value for ``key`` as a ``bool``."""
         value = self.get(key, default)
         if value is None or isinstance(value, bool):
@@ -312,7 +312,7 @@ class RendezvousParameters:
                 f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
             )

-    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
+    def get_as_int(self, key: str, default: int | None = None) -> int | None:
         """Return the value for ``key`` as an ``int``."""
         value = self.get(key, default)
         if value is None:
@@ -350,7 +350,7 @@ class RendezvousHandlerRegistry:
         if not backend:
             raise ValueError("The rendezvous backend name must be a non-empty string.")

-        current_creator: Optional[RendezvousHandlerCreator]
+        current_creator: RendezvousHandlerCreator | None
         try:
             current_creator = self._registry[backend]
         except KeyError:

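For context, `get_as_bool` tolerates several spellings of a boolean in the rendezvous config and raises for everything else. A rough, self-contained sketch of the semantics implied by the surrounding lines; the accepted string spellings here are an assumption, not copied from the patch:

```python
def get_as_bool_sketch(value: object) -> bool | None:
    # None and real booleans pass through unchanged.
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, int) and value in (0, 1):
        return bool(value)
    if isinstance(value, str):
        if value.lower() in ("1", "true", "t", "yes", "y"):
            return True
        if value.lower() in ("0", "false", "f", "no", "n"):
            return False
    raise ValueError("does not represent a valid boolean value")

assert get_as_bool_sketch("true") is True
assert get_as_bool_sketch(0) is False
assert get_as_bool_sketch(None) is None
```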
@@ -11,7 +11,7 @@ import os
 import tempfile
 from base64 import b64decode, b64encode
 from datetime import timedelta
-from typing import Any, cast, Optional
+from typing import Any, cast

 from torch.distributed import FileStore, Store, TCPStore
 from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
@@ -70,15 +70,15 @@ class C10dRendezvousBackend(RendezvousBackend):
         """See base class."""
         return "c10d"

-    def get_state(self) -> Optional[tuple[bytes, Token]]:
+    def get_state(self) -> tuple[bytes, Token] | None:
         """See base class."""
         base64_state: bytes = self._call_store("get", self._key)

         return self._decode_state(base64_state)

     def set_state(
-        self, state: bytes, token: Optional[Token] = None
-    ) -> Optional[tuple[bytes, Token, bool]]:
+        self, state: bytes, token: Token | None = None
+    ) -> tuple[bytes, Token, bool] | None:
         """See base class."""
         base64_state_str: str = b64encode(state).decode()

@@ -117,7 +117,7 @@ class C10dRendezvousBackend(RendezvousBackend):
             "The connection to the C10d store has failed. See inner exception for details."
         ) from exc

-    def _decode_state(self, base64_state: bytes) -> Optional[tuple[bytes, Token]]:
+    def _decode_state(self, base64_state: bytes) -> tuple[bytes, Token] | None:
         if base64_state == self._NULL_SENTINEL.encode():
             return None

@@ -18,7 +18,7 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
 from enum import Enum
-from typing import Any, Optional
+from typing import Any

 import torch.distributed as dist
 from torch.distributed import Store
@@ -68,7 +68,7 @@ class RendezvousBackend(ABC):
         """Get the name of the backend."""

     @abstractmethod
-    def get_state(self) -> Optional[tuple[bytes, Token]]:
+    def get_state(self) -> tuple[bytes, Token] | None:
         """Get the rendezvous state.

         Returns:
@@ -84,8 +84,8 @@ class RendezvousBackend(ABC):

     @abstractmethod
     def set_state(
-        self, state: bytes, token: Optional[Token] = None
-    ) -> Optional[tuple[bytes, Token, bool]]:
+        self, state: bytes, token: Token | None = None
+    ) -> tuple[bytes, Token, bool] | None:
         """Set the rendezvous state.

         The new rendezvous state is set conditionally:
@@ -154,10 +154,10 @@ class RendezvousTimeout:

     def __init__(
         self,
-        join: Optional[timedelta] = None,
-        last_call: Optional[timedelta] = None,
-        close: Optional[timedelta] = None,
-        heartbeat: Optional[timedelta] = None,
+        join: timedelta | None = None,
+        last_call: timedelta | None = None,
+        close: timedelta | None = None,
+        heartbeat: timedelta | None = None,
     ) -> None:
         self._set_timeouts(
             join=join, last_call=last_call, close=close, heartbeat=heartbeat
@@ -183,7 +183,7 @@ class RendezvousTimeout:
         """Get the keep-alive heartbeat timeout."""
         return self._heartbeat

-    def _set_timeouts(self, **timeouts: Optional[timedelta]):
+    def _set_timeouts(self, **timeouts: timedelta | None):
         for name, timeout in timeouts.items():
             if timeout is None:
                 timeout = self._DEFAULT_TIMEOUTS[name]
@@ -258,7 +258,7 @@ class _NodeDescGenerator:
         # An integer that is incremented with each call to generate().
         self._local_id = 0

-    def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
+    def generate(self, local_addr: str | None = None) -> _NodeDesc:
         # This method can be called by multiple threads concurrently; therefore,
         # we must increment the integer atomically.
         with self._lock:
@@ -297,7 +297,7 @@ class _RendezvousState:

     round: int
     complete: bool
-    deadline: Optional[datetime]
+    deadline: datetime | None
     closed: bool
     participants: dict[_NodeDesc, int]
     wait_list: set[_NodeDesc]
@@ -345,7 +345,7 @@ class _RendezvousStateHolder(ABC):
         """Get the local state."""

     @abstractmethod
-    def sync(self) -> Optional[bool]:
+    def sync(self) -> bool | None:
         """Reads or writes the latest state.

         Returns:
@@ -408,13 +408,13 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
         """See base class."""
         return self._state

-    def sync(self) -> Optional[bool]:
+    def sync(self) -> bool | None:
         """See base class."""
-        state_bits: Optional[bytes] = None
+        state_bits: bytes | None = None

         token = None

-        has_set: Optional[bool]
+        has_set: bool | None

         if self._dirty:
             has_set = False
@@ -574,7 +574,7 @@ class _RendezvousOpExecutor(ABC):
         self,
         state_handler: Callable[[_RendezvousContext, float], _Action],
         deadline: float,
-        update_deadline: Optional[Callable[[timedelta], float]] = None,
+        update_deadline: Callable[[timedelta], float] | None = None,
     ) -> None:
         """Execute a rendezvous operation.

@@ -638,7 +638,7 @@ class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
         self,
         state_handler: Callable[[_RendezvousContext, float], _Action],
         deadline: float,
-        update_deadline: Optional[Callable[[timedelta], float]] = None,
+        update_deadline: Callable[[timedelta], float] | None = None,
     ) -> None:
         """See base class."""
         action = None
@@ -1006,7 +1006,7 @@ class DynamicRendezvousHandler(RendezvousHandler):
     _state_holder: _RendezvousStateHolder
     _op_executor: _RendezvousOpExecutor
     _heartbeat_lock: threading.Lock
-    _keep_alive_timer: Optional[_PeriodicTimer]
+    _keep_alive_timer: _PeriodicTimer | None

     @classmethod
     def from_backend(
@@ -1016,8 +1016,8 @@ class DynamicRendezvousHandler(RendezvousHandler):
         backend: RendezvousBackend,
         min_nodes: int,
         max_nodes: int,
-        local_addr: Optional[str] = None,
-        timeout: Optional[RendezvousTimeout] = None,
+        local_addr: str | None = None,
+        timeout: RendezvousTimeout | None = None,
         keep_alive_interval: int = 5,
         keep_alive_max_attempt: int = 3,
     ):
@@ -1102,15 +1102,15 @@ class DynamicRendezvousHandler(RendezvousHandler):
         self._keep_alive_timer = None

         # Cached shared store server reference
-        self._shared_tcp_store_server: Optional[dist.Store] = None
+        self._shared_tcp_store_server: dist.Store | None = None

-        self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None
+        self._bootstrap_store_info: RendezvousStoreInfo | None = None

     def _record(
         self,
         message: str,
         node_state: NodeState = NodeState.RUNNING,
-        rank: Optional[int] = None,
+        rank: int | None = None,
     ) -> None:
         construct_and_record_rdzv_event(
             name=f"{self.__class__.__name__}.{get_method_name()}",
@@ -1379,7 +1379,7 @@ class DynamicRendezvousHandler(RendezvousHandler):
         return time.monotonic() + timeout.total_seconds()


-def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]:
+def _get_timeout(params: RendezvousParameters, key: str) -> timedelta | None:
     timeout = params.get_as_int(key + "_timeout")
     if timeout is None:
         return None

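`RendezvousBackend.set_state` is a conditional (compare-and-set) write: per its docstring above, the new state is only accepted when the supplied token still matches the stored one. A self-contained sketch of the read-modify-write loop this contract enables, using a toy in-memory backend; the class below is illustrative, not part of PyTorch:

```python
class InMemoryBackend:
    """Toy backend following the get_state/set_state token contract."""

    def __init__(self) -> None:
        self._state: bytes = b""
        self._token: int = 0

    def get_state(self) -> tuple[bytes, int] | None:
        return (self._state, self._token) if self._state else None

    def set_state(self, state: bytes, token: int | None = None) -> tuple[bytes, int, bool]:
        if token == self._token:  # token matches: the write wins
            self._state, self._token = state, self._token + 1
            return self._state, self._token, True
        return self._state, self._token, False  # lost the race: caller retries


def append_entry(backend: InMemoryBackend, entry: bytes) -> None:
    while True:
        got = backend.get_state()
        state, token = got if got else (b"", 0)
        _, _, succeeded = backend.set_state(state + entry, token)
        if succeeded:
            return  # our write was accepted; otherwise re-read and retry


b = InMemoryBackend()
append_entry(b, b"node-1;")
append_entry(b, b"node-2;")
assert b.get_state()[0] == b"node-1;node-2;"
```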
@@ -12,7 +12,6 @@ import logging
 import sys
 import threading
 import time
-from typing import Optional


 try:
@@ -153,7 +152,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
     +--------------------------------------------+--------------------------+
     """

-    def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]):
+    def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: str | None):
         """
         Args:
             rdzv_impl: the implementation of the rendezvous
@@ -542,7 +541,7 @@ class EtcdRendezvous:

         # When reaching min workers, or changing state to frozen, we'll set
         # the active_version node to be ephemeral.
-        set_ttl: Optional[int] = None
+        set_ttl: int | None = None
         if len(state["participants"]) == self._num_max_workers:
             state["status"] = "frozen"
             state["keep_alives"] = []

@@ -7,7 +7,7 @@

 import binascii
 from base64 import b64decode, b64encode
-from typing import cast, Optional
+from typing import cast

 import urllib3.exceptions  # type: ignore[import]

@@ -49,8 +49,8 @@ class EtcdRendezvousBackend(RendezvousBackend):
         self,
         client: etcd.Client,
         run_id: str,
-        key_prefix: Optional[str] = None,
-        ttl: Optional[int] = None,
+        key_prefix: str | None = None,
+        ttl: int | None = None,
     ) -> None:
         if not run_id:
             raise ValueError("The run id must be a non-empty string.")
@@ -72,7 +72,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
         """See base class."""
         return "etcd-v2"

-    def get_state(self) -> Optional[tuple[bytes, Token]]:
+    def get_state(self) -> tuple[bytes, Token] | None:
         """See base class."""
         try:
             result = self._client.read(self._key)
@@ -86,8 +86,8 @@ class EtcdRendezvousBackend(RendezvousBackend):
         return self._decode_state(result)

     def set_state(
-        self, state: bytes, token: Optional[Token] = None
-    ) -> Optional[tuple[bytes, Token, bool]]:
+        self, state: bytes, token: Token | None = None
+    ) -> tuple[bytes, Token, bool] | None:
         """See base class."""
         base64_state = b64encode(state).decode()

@@ -15,7 +15,7 @@ import socket
 import subprocess
 import tempfile
 import time
-from typing import Optional, TextIO, Union
+from typing import TextIO


 try:
@@ -64,7 +64,7 @@ def find_free_port():
     raise RuntimeError("Failed to create a socket")


-def stop_etcd(subprocess, data_dir: Optional[str] = None):
+def stop_etcd(subprocess, data_dir: str | None = None):
     if subprocess and subprocess.poll() is None:
         logger.info("stopping etcd server")
         subprocess.terminate()
@@ -107,7 +107,7 @@ class EtcdServer:
         etcd_binary_path: path of etcd server binary (see above for fallback path)
     """

-    def __init__(self, data_dir: Optional[str] = None):
+    def __init__(self, data_dir: str | None = None):
         self._port = -1
         self._host = "localhost"

@@ -123,7 +123,7 @@ class EtcdServer:
             data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
         )
         self._etcd_cmd = None
-        self._etcd_proc: Optional[subprocess.Popen] = None
+        self._etcd_proc: subprocess.Popen | None = None

     def _get_etcd_server_process(self) -> subprocess.Popen:
         if not self._etcd_proc:
@@ -149,7 +149,7 @@ class EtcdServer:
         self,
         timeout: int = 60,
         num_retries: int = 3,
-        stderr: Union[int, TextIO, None] = None,
+        stderr: int | TextIO | None = None,
     ) -> None:
         """
         Start the server and wait for it to be ready. When this function returns, the server is ready to take requests.
@@ -185,7 +185,7 @@ class EtcdServer:
         atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)

     def _start(
-        self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
+        self, data_dir: str, timeout: int = 60, stderr: int | TextIO | None = None
     ) -> None:
         sock = find_free_port()
         sock_peer = find_free_port()

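For tests, `EtcdServer` spins up a local etcd binary and tears it down again. A hedged lifecycle sketch; it assumes an etcd binary is discoverable on the machine, which is an environment-specific assumption:

```python
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer

server = EtcdServer()  # data dir defaults to a fresh temp directory
server.start(timeout=60)
try:
    print(server.get_endpoint())  # e.g. "localhost:<free-port>"
finally:
    server.stop()
```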
@@ -9,7 +9,6 @@ import datetime
 import random
 import time
 from base64 import b64decode, b64encode
-from typing import Optional

 # pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
 from torch.distributed import Store
@@ -40,7 +39,7 @@ class EtcdStore(Store):
         etcd_client,
         etcd_store_prefix,
         # Default timeout same as in c10d/Store.hpp
-        timeout: Optional[datetime.timedelta] = None,
+        timeout: datetime.timedelta | None = None,
     ):
         super().__init__()  # required for pybind trampoline.

@@ -121,7 +120,7 @@ class EtcdStore(Store):
         except etcd.EtcdCompareFailed:
             cas_delay()

-    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
+    def wait(self, keys, override_timeout: datetime.timedelta | None = None):
         """
         Wait until all of the keys are published, or until timeout.

@@ -9,7 +9,7 @@

 import datetime
 import logging
-from typing import cast, Optional
+from typing import cast

 from torch.distributed import PrefixStore, Store, TCPStore
 from torch.distributed.elastic.rendezvous import (
@@ -51,7 +51,7 @@ class StaticTCPRendezvous(RendezvousHandler):
         self.world_size = world_size
         self.run_id = run_id
         self.timeout = datetime.timedelta(seconds=timeout)
-        self._store: Optional[Store] = None
+        self._store: Store | None = None

     def get_backend(self) -> str:
         return "static"

@@ -14,7 +14,7 @@ import weakref
 from collections.abc import Callable
 from datetime import timedelta
 from threading import Event, Thread
-from typing import Any, Optional, Union
+from typing import Any


 __all__ = ["parse_rendezvous_endpoint"]
@@ -44,7 +44,7 @@ def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
                 "<key1>=<value1>,...,<keyN>=<valueN>."
             )

-        value: Optional[str]
+        value: str | None
         if values:
             value = values[0].strip()
         else:
@@ -58,7 +58,7 @@ def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
     return config


-def _try_parse_port(port_str: str) -> Optional[int]:
+def _try_parse_port(port_str: str) -> int | None:
     """Try to extract the port number from ``port_str``."""
     if port_str and re.match(r"^[0-9]{1,5}$", port_str):
         return int(port_str)
@@ -66,7 +66,7 @@ def _try_parse_port(port_str: str) -> Optional[int]:


 def parse_rendezvous_endpoint(
-    endpoint: Optional[str], default_port: int
+    endpoint: str | None, default_port: int
 ) -> tuple[str, int]:
     """Extract the hostname and the port number from a rendezvous endpoint.

@@ -166,7 +166,7 @@ def _matches_machine_hostname(host: str) -> bool:
     return False


-def _delay(seconds: Union[float, tuple[float, float]]) -> None:
+def _delay(seconds: float | tuple[float, float]) -> None:
     """Suspend the current thread for ``seconds``.

     Args:
@@ -200,9 +200,9 @@ class _PeriodicTimer:
         kwargs: dict[str, Any]
         stop_event: Event

-    _name: Optional[str]
-    _thread: Optional[Thread]
-    _finalizer: Optional[weakref.finalize]
+    _name: str | None
+    _thread: Thread | None
+    _finalizer: weakref.finalize | None

     # The context that is shared between the timer and the background thread.
     _ctx: _Context
@@ -227,7 +227,7 @@ class _PeriodicTimer:
         self._finalizer = None

     @property
-    def name(self) -> Optional[str]:
+    def name(self) -> str | None:
         """Get the name of the timer."""
         return self._name

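`parse_rendezvous_endpoint` splits a `host[:port]` string, falling back to a default port (and to localhost when no endpoint is given). A small usage sketch of the behavior described by its docstring; the expected outputs shown are my reading of that contract:

```python
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint

assert parse_rendezvous_endpoint("node0:29500", default_port=29400) == ("node0", 29500)
assert parse_rendezvous_endpoint("node0", default_port=29400) == ("node0", 29400)
# An empty or None endpoint resolves to localhost with the default port.
assert parse_rendezvous_endpoint(None, default_port=29400) == ("localhost", 29400)
```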
@@ -10,7 +10,7 @@ import threading
 import time
 from contextlib import contextmanager
 from inspect import getframeinfo, stack
-from typing import Any, Optional
+from typing import Any


 __all__ = [
@@ -130,7 +130,7 @@ class TimerServer(abc.ABC):
         self._request_queue = request_queue
         self._max_interval = max_interval
         self._daemon = daemon
-        self._watchdog_thread: Optional[threading.Thread] = None
+        self._watchdog_thread: threading.Thread | None = None
         self._stop_signaled = False

     @abc.abstractmethod
@@ -234,7 +234,7 @@ class TimerServer(abc.ABC):
         logger.info("No watchdog thread running, doing nothing")


-_timer_client: Optional[TimerClient] = None
+_timer_client: TimerClient | None = None


 def configure(timer_client: TimerClient):
@@ -247,9 +247,7 @@ def configure(timer_client: TimerClient):


 @contextmanager
-def expires(
-    after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
-):
+def expires(after: float, scope: str | None = None, client: TimerClient | None = None):
     """
     Acquires a countdown timer that expires in ``after`` seconds from now,
     unless the code-block that it wraps is finished within the timeframe.

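`expires` wraps a block in a countdown: if the block does not finish within `after` seconds, the configured timer client arranges for the worker to be signaled. A hedged usage sketch, assuming a timer server/client pair has already been wired up for this process:

```python
import torch.distributed.elastic.timer as timer

# Assumes something like timer.configure(timer.LocalTimerClient(mp_queue))
# has run alongside a LocalTimerServer; the exact wiring is
# deployment-specific and not shown here.
def train_step(batch) -> None:
    with timer.expires(after=30):
        ...  # body must complete within 30 seconds or the process is signaled
```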
@@ -14,7 +14,7 @@ import sys
 import threading
 import time
 from collections.abc import Callable
-from typing import Optional, TypeVar
+from typing import TypeVar
 from typing_extensions import ParamSpec

 from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
@@ -131,7 +131,7 @@ class FileTimerClient(TimerClient):
         self.signal = signal

     @_retry(max_retries=10, sleep_time=0.1)
-    def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
+    def _open_non_blocking(self) -> io.TextIOWrapper | None:
         # The server may have crashed or may not have started yet.
         # In that case, calling open() in blocking mode would block the client.
         # To avoid this issue, open it in non-blocking mode; an OSError will
@@ -200,7 +200,7 @@ class FileTimerServer:
         run_id: str,
         max_interval: float = 10,
         daemon: bool = True,
-        log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None,
+        log_event: Callable[[str, FileTimerRequest | None], None] | None = None,
     ) -> None:
         self._file_path = file_path
         self._run_id = run_id
@@ -208,7 +208,7 @@ class FileTimerServer:
         self._daemon = daemon
         self._timers: dict[tuple[int, str], FileTimerRequest] = {}
         self._stop_signaled = False
-        self._watchdog_thread: Optional[threading.Thread] = None
+        self._watchdog_thread: threading.Thread | None = None

         self._is_client_started = False
         if os.path.exists(self._file_path):

@@ -8,7 +8,7 @@

 import math
 from collections.abc import Iterator, Sized
-from typing import cast, Optional, TypeVar
+from typing import cast, TypeVar

 import torch
 from torch.utils.data import Dataset
@@ -44,8 +44,8 @@ class ElasticDistributedSampler(DistributedSampler[T]):
     def __init__(
         self,
         dataset: Dataset[T],
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
+        num_replicas: int | None = None,
+        rank: int | None = None,
         start_index: int = 0,
     ):
         super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)

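`ElasticDistributedSampler` extends `DistributedSampler` with a `start_index`, so a restarted job can resume partway through a dataset. A hedged sketch with explicit `num_replicas`/`rank` (passing both avoids needing an initialized process group):

```python
import torch
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
# Resume from index 4; this rank sees its share of the remaining samples.
sampler = ElasticDistributedSampler(dataset, num_replicas=2, rank=0, start_index=4)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
for batch in loader:
    print(batch)
```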
@@ -10,7 +10,6 @@ import datetime
 import os
 import socket
 from contextlib import closing
-from typing import Optional

 import torch.distributed as dist
 from torch.distributed.elastic.utils.logging import get_logger
@@ -35,7 +34,7 @@ def create_c10d_store(
     timeout: float = (60 * 10),  # 10 min
     wait_for_workers: bool = True,
     retries=3,
-    use_libuv: Optional[bool] = None,
+    use_libuv: bool | None = None,
 ):
     if use_libuv is not None:
         logger.warning(

@@ -10,12 +10,11 @@ import inspect
 import logging
 import os
 import warnings
-from typing import Optional

 from torch.distributed.elastic.utils.log_level import get_log_level


-def get_logger(name: Optional[str] = None) -> logging.Logger:
+def get_logger(name: str | None = None) -> logging.Logger:
     """
     Util function to set up a simple logger that writes
     to stderr. The loglevel is fetched from the LOGLEVEL
@@ -32,13 +31,13 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
     return _setup_logger(name or _derive_module_name(depth=2))


-def _setup_logger(name: Optional[str] = None) -> logging.Logger:
+def _setup_logger(name: str | None = None) -> logging.Logger:
     logger = logging.getLogger(name)
     logger.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
     return logger


-def _derive_module_name(depth: int = 1) -> Optional[str]:
+def _derive_module_name(depth: int = 1) -> str | None:
     """
     Derives the name of the caller module from the stack frames.

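`get_logger` derives the module name from the caller's stack frame when no name is given, so typical usage is one line; only the `LOGLEVEL` environment handling varies by deployment:

```python
from torch.distributed.elastic.utils.logging import get_logger

logger = get_logger(__name__)  # falls back to the caller's module if omitted
logger.warning("etcd endpoint not set; using default")
```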
@@ -10,7 +10,6 @@
 from collections.abc import Callable, Iterable
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Optional

 import torch

@@ -109,7 +108,7 @@ def _try_detecting_missing_ranks(
     rank: int,
     rank_decoder: Callable[[int], str],
     trace_timeout: float,
-) -> Optional[Iterable[str]]:
+) -> Iterable[str] | None:
     store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")

     def _find_missing_ranks():
@@ -169,8 +168,8 @@ def barrier(
     world_size: int,
     key_prefix: str,
     barrier_timeout: float = 300,
-    rank: Optional[int] = None,
-    rank_tracing_decoder: Optional[Callable[[int], str]] = None,
+    rank: int | None = None,
+    rank_tracing_decoder: Callable[[int], str] | None = None,
     trace_timeout: float = 10,
 ) -> None:
     """

@@ -11,7 +11,7 @@ import sys
 import uuid
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Optional, Union
+from typing import Any

 import torch
 import torch.distributed.elastic.rendezvous.registry as rdzv_registry
@@ -90,7 +90,7 @@ class LaunchConfig:
     min_nodes: int
     max_nodes: int
     nproc_per_node: int
-    logs_specs: Optional[LogsSpecs] = None
+    logs_specs: LogsSpecs | None = None
     run_id: str = ""
     role: str = "default_role"
     rdzv_endpoint: str = ""
@@ -100,14 +100,14 @@ class LaunchConfig:
     max_restarts: int = 3
     monitor_interval: float = 0.1
     start_method: str = "spawn"
-    log_line_prefix_template: Optional[str] = None
+    log_line_prefix_template: str | None = None
     metrics_cfg: dict[str, str] = field(default_factory=dict)
-    local_addr: Optional[str] = None
+    local_addr: str | None = None
     event_log_handler: str = "null"
-    numa_options: Optional[NumaOptions] = None
+    numa_options: NumaOptions | None = None
     signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
-    duplicate_stdout_filters: Optional[list[str]] = None
-    duplicate_stderr_filters: Optional[list[str]] = None
+    duplicate_stdout_filters: list[str] | None = None
+    duplicate_stderr_filters: list[str] | None = None
     virtual_local_rank: bool = False

     def __post_init__(self):
@@ -161,7 +161,7 @@ class elastic_launch:
     def __init__(
         self,
         config: LaunchConfig,
-        entrypoint: Union[Callable, str, None],
+        entrypoint: Callable | str | None,
     ):
         self._config = config
         self._entrypoint = entrypoint
@@ -170,9 +170,7 @@ class elastic_launch:
         return launch_agent(self._config, self._entrypoint, list(args))


-def _get_entrypoint_name(
-    entrypoint: Union[Callable, str, None], args: list[Any]
-) -> str:
+def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str:
     """Retrieve entrypoint name with the rule:
     1. If entrypoint is a function, use ``entrypoint.__qualname__``.
     2. If entrypoint is a string, check its value:
@@ -194,7 +192,7 @@ def _get_entrypoint_name(

 def _get_addr_and_port(
     rdzv_parameters: RendezvousParameters,
-) -> tuple[Optional[str], Optional[int]]:
+) -> tuple[str | None, int | None]:
     if rdzv_parameters.backend != "static":
         return (None, None)
     endpoint = rdzv_parameters.endpoint
@@ -213,7 +211,7 @@ def _get_addr_and_port(

 def launch_agent(
     config: LaunchConfig,
-    entrypoint: Union[Callable, str, None],
+    entrypoint: Callable | str | None,
     args: list[Any],
 ) -> dict[int, Any]:
     if not config.run_id:

@@ -5,7 +5,7 @@ import io
 import sys
 import types
 from collections.abc import Callable, Iterator, Mapping
-from typing import Any, Optional, TypeVar, Union
+from typing import Any, TypeVar, Union
 from typing_extensions import Self

 import torch
@@ -122,8 +122,8 @@ class _RemoteModule(nn.Module):
         self,
         remote_device: str,
         module_cls: type[nn.Module],
-        args: Optional[tuple] = None,
-        kwargs: Optional[dict[str, Any]] = None,
+        args: tuple | None = None,
+        kwargs: dict[str, Any] | None = None,
         _module_interface_cls: Any = None,
     ):
         """
@@ -310,32 +310,32 @@ class _RemoteModule(nn.Module):
         )

     def register_buffer(
-        self, name: str, tensor: Optional[Tensor], persistent: bool = True
+        self, name: str, tensor: Tensor | None, persistent: bool = True
     ) -> None:
         _raise_not_supported(self.register_buffer.__name__)

-    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
+    def register_parameter(self, name: str, param: Parameter | None) -> None:
         _raise_not_supported(self.register_parameter.__name__)

-    def add_module(self, name: str, module: Optional[Module]) -> None:
+    def add_module(self, name: str, module: Module | None) -> None:
         _raise_not_supported(self.add_module.__name__)

     def apply(self, fn: Callable[[Module], None]) -> Self:  # type: ignore[return]
         _raise_not_supported(self.apply.__name__)

-    def cuda(self, device: Optional[Union[int, device]] = None) -> Self:  # type: ignore[return]
+    def cuda(self, device: int | device | None = None) -> Self:  # type: ignore[return]
         _raise_not_supported(self.cuda.__name__)

-    def ipu(self, device: Optional[Union[int, device]] = None) -> Self:  # type: ignore[return]
+    def ipu(self, device: int | device | None = None) -> Self:  # type: ignore[return]
         _raise_not_supported(self.ipu.__name__)

-    def xpu(self, device: Optional[Union[int, device]] = None) -> Self:  # type: ignore[return]
+    def xpu(self, device: int | device | None = None) -> Self:  # type: ignore[return]
         _raise_not_supported(self.xpu.__name__)

     def cpu(self) -> Self:  # type: ignore[return]
         _raise_not_supported(self.cpu.__name__)

-    def type(self, dst_type: Union[dtype, str]) -> Self:  # type: ignore[return]
+    def type(self, dst_type: dtype | str) -> Self:  # type: ignore[return]
         _raise_not_supported(self.type.__name__)

     def float(self) -> Self:  # type: ignore[return]
@@ -355,19 +355,16 @@ class _RemoteModule(nn.Module):

     def register_backward_hook(  # type: ignore[return]
         self,
-        hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]],
+        hook: Callable[[Module, _grad_t, _grad_t], None | _grad_t],
         # pyrefly: ignore  [bad-return]
     ) -> RemovableHandle:
         _raise_not_supported(self.register_backward_hook.__name__)

     def register_forward_pre_hook(  # type: ignore[return]
         self,
-        hook: Union[
-            Callable[[T, tuple[Any, ...]], Optional[Any]],
-            Callable[
-                [T, tuple[Any, ...], dict[str, Any]],
-                Optional[tuple[Any, dict[str, Any]]],
-            ],
-        ],
+        hook: Callable[[T, tuple[Any, ...]], Any | None]
+        | Callable[
+            [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None
+        ],
         prepend: bool = False,
         with_kwargs: bool = False,
@@ -377,10 +374,8 @@ class _RemoteModule(nn.Module):

     def register_forward_hook(  # type: ignore[return, override]
         self,
-        hook: Union[
-            Callable[[T, tuple[Any, ...], Any], Optional[Any]],
-            Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]],
-        ],
+        hook: Callable[[T, tuple[Any, ...], Any], Any | None]
+        | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None],
         prepend: bool = False,
         with_kwargs: bool = False,
         # pyrefly: ignore  [bad-return]
@@ -435,7 +430,7 @@ class _RemoteModule(nn.Module):

     def named_modules(
         self,
-        memo: Optional[set[Module]] = None,
+        memo: set[Module] | None = None,
         prefix: str = "",
         remove_duplicate: bool = True,
     ):
@@ -694,8 +689,8 @@ class RemoteModule(_RemoteModule):
         self,
         remote_device: str,
         module_cls: type[nn.Module],
-        args: Optional[tuple] = None,
-        kwargs: Optional[dict[str, Any]] = None,
+        args: tuple | None = None,
+        kwargs: dict[str, Any] | None = None,
     ):
         super().__init__(remote_device, module_cls, args, kwargs)

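`RemoteModule` constructs `module_cls(*args, **kwargs)` on a remote RPC worker; most mutating `nn.Module` methods are stubbed out with `_raise_not_supported`, as the hunks above show. A hedged creation sketch, assuming an RPC world with a worker named "worker1":

```python
import torch.nn as nn
from torch.distributed.nn.api.remote_module import RemoteModule

def make_remote_linear() -> RemoteModule:
    # Requires torch.distributed.rpc.init_rpc(...) on all participants first.
    return RemoteModule(
        "worker1/cpu",  # owning worker and device for the module
        nn.Linear,      # class constructed remotely
        args=(20, 30),  # positional args forwarded to nn.Linear
    )
# Calling .forward(...) on the result then executes on worker1 via RPC.
```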
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -53,7 +52,7 @@ class _FunctionalAdadelta:

         self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -70,7 +69,7 @@ class _FunctionalAdagrad:
             "step": torch.tensor(0.0),
         }

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -68,7 +67,7 @@ class _FunctionalAdam:
         # param group as it's not a common use case.
         self.param_group = {"params": params}

-    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+    def step_param(self, param: Tensor, grad: Tensor | None):
         """
         Similar to step, but operates on a single parameter and optionally a
         gradient tensor.
@@ -128,7 +127,7 @@ class _FunctionalAdam:
             found_inf=None,
         )

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -64,7 +63,7 @@ class _FunctionalAdamax:
         # param group as it's not a common use case.
         self.param_group = {"params": params}

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -68,7 +67,7 @@ class _FunctionalAdamW:
         # param group as it's not a common use case.
         self.param_group = {"params": params}

-    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+    def step_param(self, param: Tensor, grad: Tensor | None):
         params_with_grad = []
         grads = []
         exp_avgs = []
@@ -129,7 +128,7 @@ class _FunctionalAdamW:
             has_complex=has_complex,
         )

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -57,7 +56,7 @@ class _FunctionalRMSprop:

         self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -51,7 +50,7 @@ class _FunctionalRprop:

         self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Optional

 import torch
 import torch.optim._functional as F
@@ -56,7 +55,7 @@ class _FunctionalSGD:
         # param group as it's not a common use case.
         self.param_group = {"params": params}

-    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+    def step_param(self, param: Tensor, grad: Tensor | None):
         """Similar to self.step, but operates on a single parameter and
         its gradient.
         """
@@ -67,7 +66,7 @@ class _FunctionalSGD:
         dampening = self.defaults["dampening"]
         lr = self.defaults["lr"]
         params = [param]
-        momentum_buffer_list: list[Optional[Tensor]] = []
+        momentum_buffer_list: list[Tensor | None] = []
         grads = []

         has_sparse_grad = False
@@ -106,11 +105,11 @@ class _FunctionalSGD:
             if momentum_buffer is not None:
                 state["momentum_buffer"] = momentum_buffer

-    def step(self, gradients: list[Optional[Tensor]]):
+    def step(self, gradients: list[Tensor | None]):
         params = self.param_group["params"]
         params_with_grad = []
         grads = []
-        momentum_buffer_list: list[Optional[Tensor]] = []
+        momentum_buffer_list: list[Tensor | None] = []
         lr = self.defaults["lr"]
         weight_decay = self.defaults["weight_decay"]
         momentum = self.defaults["momentum"]

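All of these functional optimizers share the `step(gradients)` convention retyped above: the list is positionally aligned with `param_group["params"]`, and a `None` entry means "skip this parameter", which is exactly how the `_ScriptLocalOptimizer` hunk further down builds its list. A minimal runnable sketch of that alignment, with a hand-rolled SGD update standing in for the real optimizer:

```python
import torch

params = [torch.ones(2, requires_grad=True), torch.ones(3, requires_grad=True)]

# One slot per parameter; None marks a parameter with no update this step.
gradients: list[torch.Tensor | None] = [torch.full((2,), 0.1), None]

lr = 0.5
with torch.no_grad():
    for p, g in zip(params, gradients):
        if g is not None:  # functional optimizers skip None slots
            p.add_(g, alpha=-lr)

print(params[0])  # tensor([0.9500, 0.9500], ...)
print(params[1])  # unchanged
```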
@@ -2,7 +2,7 @@ import logging
 import warnings
 from collections.abc import Callable, Collection, Mapping
 from copy import deepcopy
-from typing import Any, Optional, overload, Union
+from typing import Any, overload

 import torch
 import torch.nn as nn
@@ -62,10 +62,10 @@ class _NamedOptimizer(optim.Optimizer):

     def __init__(
         self,
-        named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
+        named_parameters: Mapping[str, torch.Tensor | ShardedTensor],
         optimizer_class: optim.Optimizer,
-        param_groups: Optional[Collection[Mapping[str, Any]]] = None,
-        module: Optional[nn.Module] = None,
+        param_groups: Collection[Mapping[str, Any]] | None = None,
+        module: nn.Module | None = None,
         *args: tuple[Any, ...],
         **kwargs: dict[str, Any],
     ) -> None:
@@ -152,7 +152,7 @@ class _NamedOptimizer(optim.Optimizer):
     @overload
     def step(self, closure: Callable[[], float]) -> float: ...

-    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+    def step(self, closure: Callable[[], float] | None = None) -> float | None:
         """
         Perform a single optimization step.

@@ -2,7 +2,6 @@
 import logging
 from collections import defaultdict
 from threading import Lock
-from typing import Optional

 import torch
 import torch.distributed.autograd as dist_autograd
@@ -51,7 +50,7 @@ class _ScriptLocalOptimizer(nn.Module):
     def step(self, autograd_ctx_id: int):
         all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
         # apply functional optimizer step with a list of gradients
-        grads: list[Optional[Tensor]] = [
+        grads: list[Tensor | None] = [
             all_local_grads[p] if p in all_local_grads else None  # noqa: SIM401
             for p in self._local_params
         ]

@@ -13,7 +13,7 @@ import io
 import logging
 from collections.abc import Callable
 from itertools import chain
-from typing import Any, Optional, Union
+from typing import Any

 import torch
 import torch.distributed as dist
@@ -173,7 +173,7 @@ class _DDPBucketAssignment:
         # DDP guarantees all parameters in the bucket have the same device
         # pyrefly: ignore  [read-only]
         self.device: torch.device = self.parameters[0].device
-        self.tensor: Optional[torch.Tensor] = None
+        self.tensor: torch.Tensor | None = None


 class _OverlapStatus(enum.IntEnum):
@@ -252,7 +252,7 @@ class _OverlapInfo:
         # Group Ranks
         self.assigned_ranks_per_bucket: list[set[int]] = []
         self.num_bucket_assignments: int = 0
-        self.total_size: Optional[int] = None
+        self.total_size: int | None = None

         # Modified per iteration
         self.broadcast_handles: list[Any] = []
@@ -377,7 +377,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         self,
         params,
         optimizer_class: type[Optimizer],
-        process_group: Optional[Any] = None,
+        process_group: Any | None = None,
         parameters_as_bucket_view: bool = False,
         overlap_with_ddp: bool = False,
         **defaults: Any,
@@ -649,7 +649,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):

     def _partition_parameters(
         self,
-        params_per_rank: Optional[list[list[torch.Tensor]]] = None,
+        params_per_rank: list[list[torch.Tensor]] | None = None,
     ) -> list[list[dict]]:
         r"""
         Partitions parameters across distributed data parallel ranks.
@@ -869,7 +869,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
     def _get_min_index(
         self,
         values: list[int],
-        disallowed_indices: Optional[set[int]] = None,
+        disallowed_indices: set[int] | None = None,
     ) -> int:
         r"""
         Return ``values.index(min(values))``, except only uses one pass.
@@ -1036,10 +1036,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):

     def _local_step(
         self,
-        gradients: Optional[list[Optional[torch.Tensor]]] = None,
-        closure: Optional[Callable[[], float]] = None,
+        gradients: list[torch.Tensor | None] | None = None,
+        closure: Callable[[], float] | None = None,
         **kwargs: Any,
-    ) -> Optional[float]:
+    ) -> float | None:
         r"""
         Perform a single optimizer step without syncing parameters across ranks.

@@ -1111,9 +1111,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
     # pyrefly: ignore  [bad-override]
     def step(
         self,
-        closure: Optional[Callable[[], float]] = None,
+        closure: Callable[[], float] | None = None,
         **kwargs: Any,
-    ) -> Optional[float]:
+    ) -> float | None:
         r"""
         Perform a single optimizer step and sync parameters across all ranks.

@@ -1403,7 +1403,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
     def _verify_and_init_params(
         self,
         params: Any,
-    ) -> Union[list[torch.Tensor], list[dict]]:
+    ) -> list[torch.Tensor] | list[dict]:
         r"""
         Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.

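`ZeroRedundancyOptimizer` shards optimizer state across ranks while wrapping an ordinary optimizer class. A hedged usage sketch; it assumes an initialized process group and one model replica per rank:

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def build_optimizer(model: torch.nn.Module) -> ZeroRedundancyOptimizer:
    # dist.init_process_group(...) must already have run on every rank.
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,  # the wrapped per-shard optimizer
        lr=1e-3,
    )
```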
@@ -8,7 +8,7 @@ from collections.abc import Callable
 from enum import Enum
 from inspect import Parameter, Signature, signature
 from types import MethodType
-from typing import Any, Optional, Union
+from typing import Any, Union

 import torch
 import torch.fx as fx
@@ -165,7 +165,7 @@ def _insert_stage_symbolic_backward(
     # We will only emit backward operations for nodes that can contribute
     # to the specified loss value.
     live_nodes = {loss_node: None}
-    val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
+    val_to_grad: dict[fx.Node, fx.Node | None] = {loss_node: None}

     def assign_or_accumulate_grad(forward_node, grad_value):
         if forward_node in val_to_grad and forward_node.op != "placeholder":
@@ -186,7 +186,7 @@ def _insert_stage_symbolic_backward(
             fx.node.map_arg(node.args, add_to_live_nodes)
             fx.node.map_arg(node.kwargs, add_to_live_nodes)
             if node.op == "call_module":
-                output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
+                output_grads: tuple[fx.Node | None, ...] | fx.Node | None
                 if node in tuples:
                     stage_output = tuples[node]
                     output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
@@ -680,11 +680,10 @@ class Pipe(torch.nn.Module):
     def _from_traced(
         mod: torch.nn.Module,
        exported_program: ExportedProgram,
-        multi_use_param_spec: Optional[MultiUseParamSpec] = None,
+        multi_use_param_spec: MultiUseParamSpec | None = None,
         output_loss_value_spec=None,
-        split_policy: Optional[
-            Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
-        ] = None,
+        split_policy: Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+        | None = None,
     ):
         """
         Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
@@ -1012,7 +1011,7 @@ class Pipe(torch.nn.Module):
     def _trace_with_export(
         mod: torch.nn.Module,
         example_args: tuple[Any, ...],
-        example_kwargs: Optional[dict[str, Any]] = None,
+        example_kwargs: dict[str, Any] | None = None,
     ) -> ExportedProgram:
         logger.info("Tracing model ...")
         try:
@@ -1032,8 +1031,8 @@ class Pipe(torch.nn.Module):
     def from_tracing(
         mod: torch.nn.Module,
         example_args: tuple[Any, ...],
-        example_kwargs: Optional[dict[str, Any]] = None,
-        split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
+        example_kwargs: dict[str, Any] | None = None,
+        split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None,
     ):
         # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
         # stages instead of TRANSMIT'ting it
@@ -1120,7 +1119,7 @@ class Pipe(torch.nn.Module):
         self,
         stage_index: int,
         device: torch.device,
-        group: Optional[ProcessGroup] = None,
+        group: ProcessGroup | None = None,
     ) -> _PipelineStage:
         """
         Create a `PipelineStage` given a stage index and distributed group.
@@ -1209,9 +1208,9 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
 def pipeline(
     module: torch.nn.Module,
     mb_args: tuple[Any, ...],
-    mb_kwargs: Optional[dict[str, Any]] = None,
-    split_spec: Optional[dict[str, SplitPoint]] = None,
-    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
+    mb_kwargs: dict[str, Any] | None = None,
+    split_spec: dict[str, SplitPoint] | None = None,
+    split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None,
 ) -> Pipe:
     """
     Split a module based on a specification.

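`pipeline` splits a module into stages at annotated points; a `split_spec` entry of `SplitPoint.BEGINNING` starts a new stage at the named submodule. A hedged sketch (the model and submodule names are illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import pipeline, SplitPoint

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(16, 32)
        self.head = nn.Linear(32, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(x))

mb = torch.randn(8, 16)  # one example micro-batch for tracing
pipe = pipeline(Net(), mb_args=(mb,), split_spec={"head": SplitPoint.BEGINNING})
print(pipe.num_stages)  # expect 2 stages: [embed] and [head]
```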
@ -3,7 +3,7 @@
import collections
import logging
from collections.abc import Iterator
from typing import Any, Optional, Union
from typing import Any

import torch
from torch.autograd.graph import GradientEdge, Node
@ -15,7 +15,7 @@ from ._debug import map_debug_info
logger = logging.getLogger(__name__)


def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
"""
Get the grad function or grad accumulator for a tensor.

@ -142,10 +142,10 @@ def get_param_groups(

def stage_backward_input(
stage_outputs_or_loss: list[torch.Tensor],
output_grads: Optional[list[torch.Tensor]],
output_grads: list[torch.Tensor] | None,
input_values: list[torch.Tensor],
weights: Iterator[Parameter],
) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]:
) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]]]:
"""
Compute the gradients for only the stage inputs with
respect to the stage outputs (if non-last stage) or loss (if last stage)
@ -225,10 +225,10 @@ def stage_backward_input(

def stage_backward_weight(
weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False
) -> tuple[Optional[torch.Tensor], ...]:
) -> tuple[torch.Tensor | None, ...]:
# map weights to param_group_weights
grad_acc_to_weight = {}
weight_grads: list[Optional[torch.Tensor]] = []
weight_grads: list[torch.Tensor | None] = []
for index, weight in enumerate(weights):
grad_acc = _get_grad_fn_or_grad_acc(weight)
grad_acc_to_weight[grad_acc] = weight, index
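The two functions above implement the zero-bubble split of backward into an input-gradient phase and a deferred weight-gradient phase. A self-contained sketch of the same idea using plain autograd (toy tensors, not the pipeline code path):

```python
import torch

x = torch.randn(4, 8, requires_grad=True)   # stage input
w = torch.randn(8, 8, requires_grad=True)   # stage weight
loss = (x @ w).sum()

# Phase 1: input grads only, keeping the graph alive for the weight pass.
(gx,) = torch.autograd.grad(loss, inputs=(x,), retain_graph=True)
# Phase 2 (can be scheduled much later): weight grads.
(gw,) = torch.autograd.grad(loss, inputs=(w,))
```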
@ -282,8 +282,8 @@ def stage_backward(
stage_output,
output_grads,
input_values,
outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used
) -> tuple[Optional[torch.Tensor], ...]:
outputs_with_grads_idxs: list[int] | None = None, # deprecated, not used
) -> tuple[torch.Tensor | None, ...]:
"""
This is a helper function to:
1. compute the gradients for the stage inputs, and
@ -303,7 +303,7 @@ def stage_backward(
# stage_output may be a composite datatype like dict. Extract all individual
# tensor values here
stage_output_tensors: list[torch.Tensor] = []
output_grad_tensors: list[Optional[torch.Tensor]] = []
output_grad_tensors: list[torch.Tensor | None] = []

def extract_tensors_with_grads(
output_val,
@ -363,7 +363,7 @@ def stage_backward(
)

# Extract gradients wrt the input values
grad_inputs: list[Optional[torch.Tensor]] = []
grad_inputs: list[torch.Tensor | None] = []
for val in input_values:
if isinstance(val, torch.Tensor):
grad_inputs.append(val.grad)

@ -10,7 +10,7 @@ visualize_schedule(ops, "test.png")
"""

import collections
from typing import NamedTuple, Optional, Union
from typing import NamedTuple
from unittest import mock

from torch.distributed.pipelining.schedules import (
@ -32,13 +32,13 @@ class OpKey(NamedTuple):


def get_schedule_ops(
schedule: Union[str, type[_PipelineSchedule]],
schedule: str | type[_PipelineSchedule],
pp_degree: int,
num_microbatches: int,
num_stages_per_rank: Optional[int] = None,
num_stages_per_rank: int | None = None,
add_spacing: bool = False,
with_comms: bool = False,
) -> list[list[Optional[_Action]]]:
) -> list[list[_Action | None]]:
"""
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
where each inner list represents a rank and each element in the inner list represents an action.
@ -86,7 +86,7 @@ def get_schedule_ops(
assert schedule_instance.pipeline_order is not None

# Convert to List[List[_Action]]
all_actions: list[list[Optional[_Action]]] = []
all_actions: list[list[_Action | None]] = []
if with_comms:
runtime = _PipelineScheduleRuntime(stages, num_microbatches)
runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order)
@ -136,8 +136,8 @@ action_type_to_color_mapping = {


def add_schedule_op_spacing(
schedule: list[list[Optional[_Action]]],
) -> list[list[Optional[_Action]]]:
schedule: list[list[_Action | None]],
) -> list[list[_Action | None]]:
"""
Add spacing to the schedule based on dependencies between ranks.

@ -169,7 +169,7 @@ def add_schedule_op_spacing(
)

num_ranks = len(schedule)
spaced_schedule: list[list[Optional[_Action]]] = [[] for _ in range(num_ranks)]
spaced_schedule: list[list[_Action | None]] = [[] for _ in range(num_ranks)]
rank_ops = [collections.deque(ops) for ops in schedule]

# Track completion times: (stage_index, action_type, microbatch_index) -> completion_time
@ -331,8 +331,8 @@ def add_schedule_op_spacing(


def visualize_schedule(
schedule: list[list[Optional[_Action]]],
filename: Optional[str] = None,
schedule: list[list[_Action | None]],
filename: str | None = None,
) -> None:
"""
Visualize the schedule using matplotlib.

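Putting the three helpers together, a hedged end-to-end sketch of this private visualizer module (assumes the signatures shown above; "1F1B" names a built-in schedule):

```python
ops = get_schedule_ops("1F1B", pp_degree=4, num_microbatches=8)
visualize_schedule(add_schedule_op_spacing(ops), filename="test.png")
```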
@ -3,7 +3,6 @@

import logging
from dataclasses import dataclass
from typing import Union

import torch
from torch import fx
@ -76,8 +75,8 @@ def validate_tensor_metadata(desc, expected, given):

def validate_tensors_metadata(
desc,
expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...],
actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...],
):
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(

@ -3,7 +3,7 @@
import logging
import operator
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any

import torch
from torch.fx.node import map_aggregate
@ -307,10 +307,10 @@ def _shard_dict_of_args(

def split_args_kwargs_into_chunks(
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
kwargs: dict[str, Any] | None,
chunks: int,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
) -> tuple[list[tuple], list[dict]]:
"""
Given a sequence of args and kwargs, split them into a number of chunks

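A hedged sketch of the microbatch splitter whose signature is updated above; with no chunk specs, tensors are split along dim 0:

```python
import torch
from torch.distributed.pipelining.microbatch import split_args_kwargs_into_chunks

args = (torch.randn(8, 16),)
kwargs = {"mask": torch.ones(8, 1)}
# Split the batch into 4 microbatches of 2 rows each.
arg_mbs, kwarg_mbs = split_args_kwargs_into_chunks(args, kwargs, chunks=4)
assert len(arg_mbs) == 4 and arg_mbs[0][0].shape[0] == 2
assert kwarg_mbs[0]["mask"].shape[0] == 2
```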
@ -11,7 +11,7 @@ from collections import Counter, defaultdict
from collections.abc import Callable
from enum import Enum
from functools import lru_cache
from typing import Any, cast, NamedTuple, Optional, Protocol, Union
from typing import Any, cast, NamedTuple, Protocol

import torch
import torch.distributed as dist
@ -131,8 +131,8 @@ _action_regex = re.compile(
class _Action(NamedTuple):
stage_index: int
computation_type: _ComputationType
microbatch_index: Optional[int] = None
sub_actions: Optional[tuple["_Action", ...]] = None
microbatch_index: int | None = None
sub_actions: tuple["_Action", ...] | None = None

def __str__(self):
return self.__repr__()
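For orientation, the schedule IR above can be exercised directly (private API; the rendered string form is approximate):

```python
from torch.distributed.pipelining.schedules import _Action, _ComputationType

# Forward compute of microbatch 3 on stage 0.
act = _Action(stage_index=0, computation_type=_ComputationType.FORWARD, microbatch_index=3)
print(act)  # renders in the compact action notation, e.g. "0F3"
```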
@ -220,8 +220,8 @@ def _get_profiler_function_name(action: _Action) -> str:


def _format_pipeline_order(
pipeline_order: dict[int, list[Optional[_Action]]],
error_step_number: Optional[int] = None,
pipeline_order: dict[int, list[_Action | None]],
error_step_number: int | None = None,
) -> str:
"""
Formats the pipeline order in a timestep (row) x rank (column) grid of actions
@ -286,10 +286,10 @@ class _PipelineSchedule(ABC):
def __init__(
self,
n_microbatches: int,
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable[..., torch.Tensor] | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
):
# From arguments
@ -360,10 +360,10 @@ class _PipelineSchedule(ABC):
@abstractmethod
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -382,7 +382,7 @@ class _PipelineSchedule(ABC):
self,
*args,
target=None,
losses: Optional[list] = None,
losses: list | None = None,
return_outputs=True,
**kwargs,
):
@ -399,7 +399,7 @@ class _PipelineSchedule(ABC):
"""
raise NotImplementedError

def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs):
def eval(self, *args, target=None, losses: list | None = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
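A hedged sketch of driving a schedule with `step()`; `stage`, `inputs`, and `labels` are placeholders for a configured PipelineStage and this rank's data:

```python
import torch
from torch.distributed.pipelining import ScheduleGPipe

schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=torch.nn.functional.cross_entropy)
losses: list[torch.Tensor] = []
if stage.is_first:
    schedule.step(inputs)                        # whole batch; chunked internally
elif stage.is_last:
    schedule.step(target=labels, losses=losses)  # per-microbatch losses collected here
else:
    schedule.step()
```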
@ -421,10 +421,10 @@ class _PipelineSchedule(ABC):

def _check_inputs(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
) -> tuple[list, list]:
"""
Pre-process/check inputs
@ -463,7 +463,7 @@ class _PipelineSchedule(ABC):
def _split_inputs(
self,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
):
"""
Splits a full-batch input into chunks (i.e. microbatches) and returns
@ -494,9 +494,7 @@ class _PipelineSchedule(ABC):
)


def _batch_p2p(
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> list[dist.Work]:
def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]:
"""
Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
"""
@ -508,7 +506,7 @@ def _batch_p2p(


def _sorted_batch_p2p(
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
p2p_ops: list[dist.P2POp], desc: str | None = None
) -> dict[int, list[dist.Work]]:
"""
Sorts the list of P2P ops by the peer rank, and then calls
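`_batch_p2p` is a thin wrapper; the underlying public API looks like this (hedged: assumes an initialized process group and a valid peer rank `dst`):

```python
import torch
import torch.distributed as dist

tensor = torch.ones(4)
ops = [dist.P2POp(dist.isend, tensor, dst)]
for work in dist.batch_isend_irecv(ops):  # what _batch_p2p calls internally
    work.wait()
```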
@ -557,10 +555,10 @@ class PipelineScheduleSingle(_PipelineSchedule):
self,
stage: _PipelineStageBase,
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
):
# Init parent
@ -584,7 +582,7 @@ class PipelineScheduleSingle(_PipelineSchedule):
or equal to the number of stages ({self._num_stages})."
)

self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
self.pipeline_order: dict[int, list[_Action | None]] | None = (
self._get_pipeline_order()
)

@ -608,7 +606,7 @@ or equal to the number of stages ({self._num_stages})."
self,
*args,
target=None,
losses: Optional[list] = None,
losses: list | None = None,
return_outputs: bool = True,
**kwargs,
):
@ -656,7 +654,7 @@ or equal to the number of stages ({self._num_stages})."
else:
return None

def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
"""
Returns the pipeline execution order as a schedule IR.

@ -683,10 +681,10 @@ class _ScheduleForwardOnly(PipelineScheduleSingle):

def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -734,10 +732,10 @@ class ScheduleGPipe(PipelineScheduleSingle):

def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -812,7 +810,7 @@ class ScheduleGPipe(PipelineScheduleSingle):

self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)

def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
"""
Returns the pipeline order for GPipe schedule.

@ -822,7 +820,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
pp_group_size = self._num_stages

for rank in range(pp_group_size):
actions: list[Optional[_Action]] = []
actions: list[_Action | None] = []

# 1. Initial delay based on rank position
warmup_delay = rank
@ -853,10 +851,10 @@ class Schedule1F1B(PipelineScheduleSingle):

def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -995,7 +993,7 @@ class Schedule1F1B(PipelineScheduleSingle):

self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)

def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
"""
Returns the pipeline order for 1F1B schedule.

@ -1005,7 +1003,7 @@ class Schedule1F1B(PipelineScheduleSingle):
pp_group_size = self._num_stages

for rank in range(pp_group_size):
actions: list[Optional[_Action]] = []
actions: list[_Action | None] = []

# 1. Warmup phase: initial delay based on rank
actions.extend([None] * rank)
@ -1069,13 +1067,13 @@ def _requires_reduce_grad(action_type: _ComputationType) -> bool:


def _add_reduce_grad(
actions: list[Optional[_Action]], n_microbatches: int
) -> list[Optional[_Action]]:
actions: list[_Action | None], n_microbatches: int
) -> list[_Action | None]:
"""
REDUCE_GRAD refers to joint across minibatches grad reduction.
reduce_grad frees memory and we want to schedule it just after the last "backward"-like stage.
"""
actions_with_reduce_grad: list[Optional[_Action]] = []
actions_with_reduce_grad: list[_Action | None] = []
cnt: dict[int, int] = defaultdict(int)

def _leaf_action(a, to_schedule):
@ -1102,7 +1100,7 @@ def _add_reduce_grad(


def _add_unshard_reshard(
compute_actions: list[Optional[_Action]],
compute_actions: list[_Action | None],
max_active_stages: int = 3,
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP.
@ -1117,9 +1115,7 @@ def _add_unshard_reshard(
(to account for having one f and one b active, and something else prefetching?)
"""

def next_stage_indices(
count: int, next_actions: list[Optional[_Action]]
) -> list[int]:
def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]:
"""Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
seen: set[int] = set()
ret: list[int] = []
@ -1187,7 +1183,7 @@ def _add_unshard_reshard(


def _merge_bw(
compute_actions: list[Optional[_Action]],
compute_actions: list[_Action | None],
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
(note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
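A toy model of `_merge_bw`'s rule on (stage, type, microbatch) triples; the real function operates on `_Action` tuples and also drops `None` entries:

```python
def merge_bw_toy(ops: list[tuple[int, str, int]]) -> list[tuple[int, str, int]]:
    out: list[tuple[int, str, int]] = []
    i = 0
    while i < len(ops):
        stage, ctype, mb = ops[i]
        # Adjacent I then W for the same (stage, microbatch) fuse into one B.
        if ctype == "I" and i + 1 < len(ops) and ops[i + 1] == (stage, "W", mb):
            out.append((stage, "B", mb))
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out

assert merge_bw_toy([(0, "I", 0), (0, "W", 0), (1, "F", 2)]) == [(0, "B", 0), (1, "F", 2)]
```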
@ -1259,9 +1255,7 @@ def _add_send_recv(
recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
return send, recv

def _ready_to_schedule(
action: Optional[_Action], prev_actions: set[_Action]
) -> bool:
def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool:
"""We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
This helps ensure a sane (non-hanging) ordering of sends and recvs.
But it also means we might not be able to schedule our next compute action yet.
@ -1343,7 +1337,7 @@ def _add_send_recv(


def _validate_schedule(
actions: dict[int, list[Optional[_Action]]],
actions: dict[int, list[_Action | None]],
pp_group_size: int,
num_stages: int,
num_microbatches: int,
@ -1479,11 +1473,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
use_full_backward: Optional[bool] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
use_full_backward: bool | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -1516,7 +1510,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self._should_compute_loss = lambda stage: stage.is_last and has_loss

# This will be set during init of derived schedules
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}

# When using a custom backward function, we may or may not need autograd to be used
# for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
@ -1559,7 +1553,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self._stages_backward_initialized = True

def _validate_and_set_stage_mapping(
self, actions: dict[int, list[Optional[_Action]]]
self, actions: dict[int, list[_Action | None]]
) -> None:
"""
Allocates the stage index to rank mapping which is needed for communication
@ -1600,7 +1594,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self,
*args,
target=None,
losses: Optional[list] = None,
losses: list | None = None,
return_outputs: bool = True,
**kwargs,
):
@ -1657,10 +1651,10 @@ class PipelineScheduleMulti(_PipelineSchedule):

def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -1851,10 +1845,10 @@ class _PipelineContext:
def __init__(
self,
schedule_ref: _PipelineSchedule,
arg_mbs: Optional[list[tuple]] = None,
kwarg_mbs: Optional[list[dict]] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list[tuple] | None = None,
kwarg_mbs: list[dict] | None = None,
target_mbs: list | None = None,
losses: list | None = None,
):
self.schedule_ref = schedule_ref
self.arg_mbs = arg_mbs
@ -1931,7 +1925,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):

def _prepare_schedule_with_comms(
self,
actions: dict[int, list[Optional[_Action]]],
actions: dict[int, list[_Action | None]],
format: str = "compute_only",
):
"""
@ -2042,10 +2036,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):

def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -2306,8 +2300,8 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Union[Callable, _Loss]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | _Loss | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2323,7 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
# ========================================================================
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
@ -2340,7 +2334,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):

# Store the list of operations used for that rank
# Pre-padding, rank starts with no-ops based on the warmup.
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[_Action | None] = [None for _ in range(rank)]

for stage_index in stage_indices:
rank_ops.extend(
@ -2380,7 +2374,7 @@ def _get_1f1b_rank_ops(

# Store the list of operations used for that rank
# Pre-padding, rank starts with no-ops based on the warmup.
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[_Action | None] = [None for _ in range(rank)]
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
# Formula:
@ -2520,10 +2514,10 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2551,7 +2545,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2559,7 +2553,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)

def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
def get_rank_warmup_ops(rank):
# Warms up operations for last stage
warmups_ops_last_stage = (
@ -2634,10 +2628,10 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2667,7 +2661,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2682,7 +2676,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)

def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
def get_rank_warmup_ops(rank):
# Warms up operations for last stage
warmups_ops_last_stage = (
@ -2760,7 +2754,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
return False

seen_ops: set[tuple[int, _ComputationType, int]] = set()
result: dict[int, list[Optional[_Action]]] = {}
result: dict[int, list[_Action | None]] = {}
next_pointer: dict[int, int] = {}
bubbles_added: dict[int, int] = {}
total_bubbles_added = 0
@ -2833,10 +2827,10 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2872,7 +2866,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2880,11 +2874,11 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)

def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
# max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least
# as large of the number of microbatches needed to fully utilize the pipeline
n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[_Action | None] = [None for _ in range(rank)]

# Forward and backward action counts for stage chunk 0 and chunk 1
f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
@ -3011,10 +3005,10 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -3055,7 +3049,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -3063,8 +3057,8 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)

def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
actions: list[Optional[_Action]] = []
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
actions: list[_Action | None] = []
counters: dict[
tuple[int, _ComputationType], int
] = {} # (stage_index, computation_type) -> mb_index
@ -3273,12 +3267,12 @@ def _simulate_comms_compute(

_prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}

def add_to_schedule(rank: int, action: Optional[_Action]):
def add_to_schedule(rank: int, action: _Action | None):
_schedule[rank].append(action)
if action is not None:
_prev_ops_rank[rank].add(action)

def _ready_to_schedule(action: Optional[_Action]) -> bool:
def _ready_to_schedule(action: _Action | None) -> bool:
if action is None:
return True


@ -4,7 +4,7 @@ import logging
import operator
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, cast, Optional, Union
from typing import Any, cast, Union

import torch
import torch.distributed as dist
@ -99,7 +99,7 @@ InputInfo = Union[_RecvInfo, _RootArgPlaceholder]


def _make_tensor_from_meta(
example: Union[torch.Tensor, FakeTensor],
example: torch.Tensor | FakeTensor,
device: torch.device,
) -> torch.Tensor:
"""
@ -126,8 +126,8 @@ class _PipelineStageBase(ABC):
stage_index: int,
num_stages: int,
device: torch.device,
group: Optional[dist.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
group: dist.ProcessGroup | None = None,
dw_builder: Callable[[], Callable[..., None]] | None = None,
):
"""
Args:
@ -176,11 +176,11 @@ class _PipelineStageBase(ABC):
)

# Run time states
self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None
self._outputs_meta: tuple[torch.Tensor, ...] | None = None
# map microbatch ID to list of forward tensor args
self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {}
# map microbatch ID to list of backward grad tensor args
self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {}
self.bwd_cache: dict[int, tuple[torch.Tensor | None, ...]] = {}
# Caching chunk outputs for final output merge or reduction
self.output_chunks: list[Any] = []

@ -196,10 +196,10 @@ class _PipelineStageBase(ABC):

# Backward infra will created lazily
self.grad_recv_info: dict = {}
self.grad_send_info: Optional[list] = None
self.grad_send_info: list | None = None

# To be populated later by the Schedule
self.chunks: Optional[int] = None
self.chunks: int | None = None
self.stage_index_to_group_rank: dict[int, int] = {
i: i % self.group_size for i in range(self.num_stages)
}
@ -261,11 +261,11 @@ class _PipelineStageBase(ABC):
def _create_grad_send_info(
self,
args_recv_info: tuple,
) -> list[Optional[int]]:
) -> list[int | None]:
"""
Create a list of stage indices to send gradients to.
"""
grad_send_info: list[Optional[int]] = []
grad_send_info: list[int | None] = []

def map_recv_to_send(a):
# Note: we send gradients back to previous stage as long as in
@ -288,7 +288,7 @@ class _PipelineStageBase(ABC):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
) -> tuple[Any, ...]:
raise NotImplementedError

@ -388,7 +388,7 @@ class _PipelineStageBase(ABC):
return self.bwd_cache.pop(mb_index)

def set_local_bwd_input(
self, next_stage_bwd_outputs: tuple[Optional[torch.Tensor], ...], mb_index: int
self, next_stage_bwd_outputs: tuple[torch.Tensor | None, ...], mb_index: int
) -> None:
"""
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
@ -588,7 +588,7 @@ class _PipelineStageBase(ABC):
backward_type,
bwd_kwargs: dict,
last_backward: bool = False,
) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]:
) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None]:
"""
Whether using PP with FSDP, DDP, or replicate there are some runtime differences between the last backward step and the
other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
@ -600,7 +600,7 @@ class _PipelineStageBase(ABC):
backward_type,
) -> Callable[
[],
tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]],
tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None],
]:
if backward_type == "full":
return lambda: (
@ -663,7 +663,7 @@ class _PipelineStageBase(ABC):
self,
fwd_chunk_id: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
save_forward_output: bool = True,
):
"""
@ -779,7 +779,7 @@ class _PipelineStageBase(ABC):
"input_values": input_values,
}

grads_input: tuple[Optional[torch.Tensor], ...] = ()
grads_input: tuple[torch.Tensor | None, ...] = ()

# Custom backward function
if self.dw_builder:
@ -1019,7 +1019,7 @@ class _PipelineStage(_PipelineStageBase):
stage_index: int,
pipe_info: PipeInfo,
device: torch.device,
group: Optional[dist.ProcessGroup] = None,
group: dist.ProcessGroup | None = None,
):
"""
Create a pipeline stage given a stage_module to be wrapped by this stage
@ -1086,7 +1086,7 @@ class _PipelineStage(_PipelineStageBase):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
) -> tuple[Any, ...]:
"""
Create send/recv infrastructures for activations (during forward)
@ -1183,7 +1183,7 @@ class _PipelineStage(_PipelineStageBase):
def find_dst_rank(
self,
user: fx.Node,
) -> Optional[int]:
) -> int | None:
"""
Find the destination rank of a `user` node.
If the `user` is not a submod, `None` may be returned.
@ -1293,7 +1293,7 @@ def build_stage(
stage_index: int,
pipe_info: PipeInfo,
device: torch.device,
group: Optional[dist.ProcessGroup] = None,
group: dist.ProcessGroup | None = None,
) -> _PipelineStage:
"""
Create a pipeline stage given a stage_module to be wrapped by this stage
@ -1347,14 +1347,14 @@ class PipelineStage(_PipelineStageBase):
stage_index: int,
num_stages: int,
device: torch.device,
input_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
group: Optional[dist.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
input_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
output_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
group: dist.ProcessGroup | None = None,
dw_builder: Callable[[], Callable[..., None]] | None = None,
):
super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
self.inputs: Optional[list[torch.Tensor]] = None
self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None
self.inputs: list[torch.Tensor] | None = None
self.inputs_meta: tuple[torch.Tensor, ...] | None = None
# Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it
# might be breaking for existing users.
if input_args is None:
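Manual stage construction with the public class above (hedged: assumes an initialized process group, one stage per rank):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage

rank = dist.get_rank()
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # this rank's slice of the model
stage = PipelineStage(
    block, stage_index=rank, num_stages=dist.get_world_size(),
    device=torch.device("cpu"),
)
```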
@ -1410,7 +1410,7 @@ class PipelineStage(_PipelineStageBase):
def _shape_inference(
self,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
):
if kwargs is None:
kwargs = {}
@ -1522,7 +1522,7 @@ class PipelineStage(_PipelineStageBase):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
) -> tuple[Any, ...]:
# TODO move self.device to an argument from step API (from its input tensors)?
assert num_microbatches is not None, "TODO fix num_microbatches"

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional, Union

import torch

@ -22,14 +21,14 @@ class _remote_device:
and "cuda:1", just represent local devices.
"""

def __init__(self, remote_device: Union[str, torch.device]):
def __init__(self, remote_device: str | torch.device):
PARSE_ERROR = (
f"Could not parse remote_device: {remote_device}. The valid format is "
"'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
)
self._worker_name = None
self._rank = None
self._device: Optional[Union[str, int, torch.device]] = None
self._device: str | int | torch.device | None = None

if isinstance(remote_device, torch.device):
self._device = remote_device
@ -81,11 +80,11 @@ class _remote_device:
except Exception:
return False

def worker_name(self) -> Optional[str]:
def worker_name(self) -> str | None:
"""Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
return self._worker_name

def rank(self) -> Optional[int]:
def rank(self) -> int | None:
"""
Returns the rank of remote worker representing the remote device.
Returns ``None`` if no rank is available.

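The three accepted spellings from the parser above, exercised directly (private helper):

```python
from torch.distributed.remote_device import _remote_device

d1 = _remote_device("worker0/cuda:0")  # '<workername>/<device>'
d2 = _remote_device("rank:0/cpu")      # 'rank:<rank>/<device>'
d3 = _remote_device("cuda:1")          # '<device>' only: a local device
assert d1.worker_name() == "worker0"
assert d2.rank() == 0
assert d3.worker_name() is None
```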
@ -11,7 +11,6 @@ import os
import sys
from collections.abc import Callable, Iterator
from datetime import timedelta
from typing import Optional

from torch.distributed import FileStore, Store, TCPStore

@ -71,7 +70,7 @@ def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool:
return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"


def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
def _rendezvous_helper(url: str, rank: int, world_size_opt: int | None, **kwargs):
result = urlparse(url)
if world_size_opt is None:
world_size = -1

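The first thing `_rendezvous_helper` does is parse the URL; a standalone sketch of that step:

```python
from urllib.parse import parse_qs, urlparse

result = urlparse("tcp://127.0.0.1:29500?rank=0&world_size=2")
query = {k: v[0] for k, v in parse_qs(result.query).items()}
assert result.scheme == "tcp" and query["world_size"] == "2"
```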
@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
from typing import Union

import torch

@ -89,10 +89,10 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None,
devices: Optional[list[DeviceType]] = None,
_transports: Optional[list] = None,
_channels: Optional[list] = None,
device_maps: dict[str, dict[DeviceType, DeviceType]] | None = None,
devices: list[DeviceType] | None = None,
_transports: list | None = None,
_channels: list | None = None,
):
full_device_maps = (
{}

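A hedged construction of the options object above, mapping this worker's device 0 to device 1 on a peer named "worker1" (names illustrative; requires an RPC-enabled build):

```python
from torch.distributed.rpc import TensorPipeRpcBackendOptions

opts = TensorPipeRpcBackendOptions(
    num_worker_threads=16,
    device_maps={"worker1": {0: 1}},  # local device 0 -> remote device 1
)
```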
@ -375,7 +375,6 @@ import uuid
from argparse import ArgumentParser, REMAINDER
from collections.abc import Callable
from importlib import metadata
from typing import Optional, Union

import torch
from torch.distributed.argparse_util import check_env, env
@ -798,7 +797,7 @@ def get_use_env(args) -> bool:
return args.use_env


def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
def _get_logs_specs_class(logs_specs_name: str | None) -> type[LogsSpecs]:
"""
Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
Provides plugin mechanism to provide custom implementation of LogsSpecs.
@ -827,7 +826,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
return logs_specs_cls


def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
def config_from_args(args) -> tuple[LaunchConfig, Callable | str, list[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
if not (0 < min_nodes <= max_nodes):
@ -871,7 +870,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str

rdzv_endpoint = get_rdzv_endpoint(args)

ranks: Optional[set[int]] = None
ranks: set[int] | None = None
if args.local_ranks_filter:
try:
ranks = set(map(int, args.local_ranks_filter.split(",")))
@ -920,7 +919,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
)

with_python = not args.no_python
cmd: Union[Callable, str]
cmd: Callable | str
cmd_args = []
use_env = get_use_env(args)
if args.run_path:

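`config_from_args` backs the `torchrun` entry point; an equivalent programmatic invocation (hedged: `train.py` is a placeholder script name):

```python
from torch.distributed.run import main as torchrun_main

torchrun_main([
    "--nnodes=1",
    "--nproc_per_node=2",
    "--local-ranks-filter=0",  # feeds the args.local_ranks_filter parsed above
    "train.py",
])
```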
@ -5,7 +5,7 @@ import copy
import inspect
import warnings
from collections.abc import Callable, Sequence
from typing import Any, cast, Optional
from typing import Any, cast
from typing_extensions import deprecated

import torch
@ -74,7 +74,7 @@ class _ToTorchTensor(torch.autograd.Function):
def forward( # type: ignore[override]
ctx,
input: "DTensor",
grad_placements: Optional[Sequence[Placement]],
grad_placements: Sequence[Placement] | None,
):
ctx.dtensor_spec = input._spec
ctx.grad_placements = grad_placements
@ -135,8 +135,8 @@ class _FromTorchTensor(torch.autograd.Function):
device_mesh: DeviceMesh,
placements: tuple[Placement, ...],
run_check: bool,
shape: Optional[torch.Size] = None,
stride: Optional[tuple[int, ...]] = None,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
) -> "DTensor":
ctx.previous_placement = placements
ctx.previous_device_mesh = device_mesh
@ -356,12 +356,12 @@ class DTensor(torch.Tensor):
@staticmethod
def from_local(
local_tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
run_check: bool = False,
shape: Optional[torch.Size] = None,
stride: Optional[tuple[int, ...]] = None,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
) -> "DTensor":
"""
Create a :class:`DTensor` from a local torch.Tensor on each rank
@ -445,7 +445,7 @@ class DTensor(torch.Tensor):
)

def to_local(
self, *, grad_placements: Optional[Sequence[Placement]] = None
self, *, grad_placements: Sequence[Placement] | None = None
) -> torch.Tensor:
"""
Get the local tensor of this DTensor on its current rank. For sharding it returns
@ -483,12 +483,12 @@ class DTensor(torch.Tensor):

def redistribute(
self,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
forward_dtype: Optional[torch.dtype] = None,
backward_dtype: Optional[torch.dtype] = None,
forward_dtype: torch.dtype | None = None,
backward_dtype: torch.dtype | None = None,
) -> "DTensor":
"""
``redistribute`` performs necessary collective operations that redistribute the current
@ -565,7 +565,7 @@ class DTensor(torch.Tensor):
)

def full_tensor(
self, *, grad_placements: Optional[Sequence[Placement]] = None
self, *, grad_placements: Sequence[Placement] | None = None
) -> torch.Tensor:
"""
Return the full tensor of this DTensor. It will perform necessary collectives
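A hedged round trip through the four methods above (run under `torchrun --nproc_per_node=2`; shapes assume that 2-rank mesh):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

mesh = init_device_mesh("cpu", (2,))
local = torch.randn(4, 4)
dt = DTensor.from_local(local, mesh, [Shard(0)])   # global shape (8, 4)
shard = dt.to_local()                              # this rank's (4, 4) piece
replicated = dt.redistribute(mesh, [Replicate()])  # all-gather across the mesh
full = dt.full_tensor()                            # plain (8, 4) torch.Tensor
```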
@ -688,10 +688,10 @@ class DTensor(torch.Tensor):

def distribute_tensor(
tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> DTensor:
"""
Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according
@ -850,7 +850,7 @@ def distribute_tensor(
def _shard_tensor(
full_tensor: torch.Tensor,
placements: Sequence[Shard],
device_mesh: Optional[DeviceMesh] = None,
device_mesh: DeviceMesh | None = None,
) -> "DTensor":
"""
Locally shards a full tensor based on indicated sharding arrangement, and
@ -886,10 +886,10 @@ def _shard_tensor(

def distribute_module(
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
device_mesh: DeviceMesh | None = None,
partition_fn: Callable[[str, nn.Module, DeviceMesh], None] | None = None,
input_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None,
output_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None,
) -> nn.Module:
"""
This function expose three functions to control the parameters/inputs/outputs of the module:
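A hedged sketch combining `distribute_tensor` and `distribute_module` from above; the row-wise `partition_fn` is illustrative (same 2-rank setup as before):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_module, distribute_tensor

mesh = init_device_mesh("cpu", (2,))

def partition_fn(name, module, device_mesh):
    # Shard every direct parameter of this submodule row-wise.
    for pname, param in list(module.named_parameters(recurse=False)):
        dparam = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
        module.register_parameter(pname, dparam)

sharded = distribute_module(nn.Linear(8, 8), mesh, partition_fn=partition_fn)
```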
@ -1042,8 +1042,8 @@ def distribute_module(
def _dtensor_init_helper( # type: ignore[no-untyped-def]
init_op,
size: torch.Size,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
**kwargs,
) -> DTensor:
# if device_mesh is None, use the one from mesh resources
@ -1108,11 +1108,11 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]

def ones( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
@ -1151,11 +1151,11 @@ def ones( # type: ignore[no-untyped-def]

def empty( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
@ -1196,11 +1196,11 @@ def full( # type: ignore[no-untyped-def]
size,
fill_value,
*,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
@ -1242,10 +1242,10 @@ def full( # type: ignore[no-untyped-def]
def rand( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a uniform distribution
@ -1286,10 +1286,10 @@ def rand( # type: ignore[no-untyped-def]
def randn( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a normal distribution
@ -1330,10 +1330,10 @@ def randn( # type: ignore[no-untyped-def]
def zeros( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 0.

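The factory functions above build sharded tensors without materializing the global tensor on every rank; a hedged sketch on a 2-rank mesh:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, zeros

mesh = init_device_mesh("cpu", (2,))
dz = zeros(8, 4, device_mesh=mesh, placements=[Shard(0)])
assert dz.to_local().shape == (4, 4)  # each rank allocates only its shard
```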
@ -74,7 +74,7 @@ def mesh_scatter(
async_op: bool = False,
*,
group_src: int = 0,
) -> Optional[Work]:
) -> Work | None:
"""
scatter a list of tensors to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e
@ -135,7 +135,7 @@ def mesh_broadcast(
async_op: bool = False,
*,
group_src: int = 0,
) -> Optional[Work]:
) -> Work | None:
"""
broadcast the tensor to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e

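A hedged sketch of the (private) broadcast helper above; rank 0 of mesh dim 0 acts as the source:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._collective_utils import mesh_broadcast

mesh = init_device_mesh("cpu", (2,))
t = torch.arange(4.0)  # overwritten in place on non-source ranks
mesh_broadcast(t, mesh, mesh_dim=0)
```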
@ -3,7 +3,7 @@ import contextlib
import logging
import warnings
from collections.abc import Sequence
from typing import cast, Optional
from typing import cast

import torch
import torch.distributed as dist
@ -482,7 +482,7 @@ class OpDispatcher:
kwargs_schema: dict[str, object] = {}
local_args: list[object] = []
local_kwargs: dict[str, object] = {}
compute_mesh: Optional[DeviceMesh] = None
compute_mesh: DeviceMesh | None = None

for arg in args_list:
if isinstance(arg, dtensor.DTensor):

@ -2,7 +2,7 @@ import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional
from typing import Any, cast, NamedTuple

import torch
from torch.distributed.device_mesh import DeviceMesh
@ -71,7 +71,7 @@ class DTensorSpec:
placements: tuple[Placement, ...]

# tensor meta will only be set during sharding propagation
tensor_meta: Optional[TensorMeta] = None
tensor_meta: TensorMeta | None = None

# When a tensor dimension is sharded across multiple mesh axes,
# `shard_order` specifies the sequence in which these shardings are applied.
@ -206,7 +206,7 @@ class DTensorSpec:
@staticmethod
def _maybe_convert_StridedShard_to_shard_order(
placements: tuple[Placement, ...], mesh: DeviceMesh
) -> Optional[ShardOrder]:
) -> ShardOrder | None:
"""
Try to convert _StridedShard placements to ShardOrder.

@ -441,7 +441,7 @@ class DTensorSpec:
@staticmethod
def format_shard_order_str(
placements: tuple[Placement, ...],
shard_order: Optional[ShardOrder] = None,
shard_order: ShardOrder | None = None,
) -> str:
"""
Format DTensor sharding information as a human-readable string.
@ -617,7 +617,7 @@ class DTensorSpec:
mesh: DeviceMesh,
dim_map: list[int],
sums: list[int],
tensor_meta: Optional[TensorMeta] = None,
tensor_meta: TensorMeta | None = None,
) -> "DTensorSpec":
"""
Construct a DTensorSpec from dim_map list and pending sum.
@ -669,7 +669,7 @@ class DTensorSpec:
return any(placement.is_shard() for placement in self.placements)

def shallow_copy_with_tensor_meta(
self, tensor_meta: Optional[TensorMeta]
self, tensor_meta: TensorMeta | None
) -> "DTensorSpec":
"""
Shallow copy the DTensorSpec with a new tensor_meta.

@ -26,7 +26,7 @@ These schema definitions enable the DTensor system to:
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Optional, Union
from typing import Any, Optional
from typing_extensions import deprecated

import torch
@ -60,11 +60,11 @@ except ImportError:
ArgsType = tuple[object, ...]
KwargsType = dict[str, object]

PlacementList = list[Optional[Placement]]
PlacementList = list[Placement | None]

# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should
# be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
OutputSpecType = Optional[DTensorSpec | Sequence[DTensorSpec | None]]


def _rebuild_tensor_from_dtensor_meta(arg) -> object:
@ -109,8 +109,8 @@ class OpSpec:

# output_specs and input_specs are related: for this op, given these input_specs,
# this is the way the output would look
output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]
input_specs: Optional[Sequence[DTensorSpec]] = None
output_specs: DTensorSpec | tuple[DTensorSpec | None, ...]
input_specs: Sequence[DTensorSpec] | None = None

"""
redistribute_cost tells how expensive it is to redistribute a given input into the
@ -138,7 +138,7 @@ class OpSpec:
K, # cost of redistributing tensor_a from 'Shard(0)'
],
"""
redistribute_cost: Optional[list[list[float]]] = None
redistribute_cost: list[list[float]] | None = None

@cached_property
def output_spec(self) -> DTensorSpec:
@ -301,7 +301,7 @@ class RuntimeSchemaInfo:
# Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = 100
# This static_kwargkey records static kwarg names which would affect sharding prop
static_kwargkey: Optional[list[str]] = None
static_kwargkey: list[str] | None = None
# each op can decide if it wants to use pytree flatten/unflatten during operator
# eager execution, by default we don't need to do flatten/unflatten, only if the
# op indicate it needs to, this is to accelerate eager performance.
@ -331,9 +331,9 @@ class OpSchema:
args_schema: ArgsType
kwargs_schema: KwargsType

schema_info: Optional[RuntimeSchemaInfo] = None
schema_info: RuntimeSchemaInfo | None = None

_comparison_key: Optional[tuple[object, ...]] = None
_comparison_key: tuple[object, ...] | None = None

@property
def args_spec(self) -> tuple[DTensorSpec, ...]:
@ -560,7 +560,7 @@ class OutputSharding:
# specifies the output sharding pattern
output_spec: OutputSpecType
# schema for redistribution if needed
redistribute_schema: Optional[OpSchema] = None
redistribute_schema: OpSchema | None = None
# flag indicating if inputs need redistribution
needs_redistribute: bool = False
# flag to use values from `redistribute_schema`
@ -596,7 +596,7 @@ class OpInfo:
flat_args_schema: list[object]
local_args: Sequence[object]
local_kwargs: dict[str, object]
args_tree_spec: Optional[TreeSpec] = None
args_tree_spec: TreeSpec | None = None

# the output sharding info
output_sharding: Optional[OutputSharding] = None
output_sharding: OutputSharding | None = None

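Editor's note: the dataclass fields rewritten above (`OpSpec`, `OpSchema`, `OutputSharding`, `OpInfo`) behave identically under `dataclasses` with the PEP 604 spelling. A small runnable sketch with a stand-in spec class (hypothetical names, assuming Python 3.10+):

from dataclasses import dataclass

class FakeSpec:  # stand-in for the real DTensorSpec, for illustration only
    pass

@dataclass
class OpSpecLike:
    output_specs: FakeSpec | tuple[FakeSpec | None, ...]
    input_specs: tuple[FakeSpec, ...] | None = None
    redistribute_cost: list[list[float]] | None = None

spec = OpSpecLike(output_specs=FakeSpec())
assert spec.input_specs is None and spec.redistribute_cost is None
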
@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import string
from typing import cast, Optional
from typing import cast

import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
@ -44,7 +44,7 @@ def einop_rule(
op_schema: OpSchema,
*,
linearity: bool = False,
enforce_sharding: Optional[dict[str, int]] = None,
enforce_sharding: dict[str, int] | None = None,
) -> OutputSharding:
"""
Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.

@ -1,14 +1,13 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
data: torch.Tensor | None = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0


@ -4,7 +4,7 @@ import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import cast, Optional, Union
from typing import cast, Union

import torch
from torch.distributed.device_mesh import DeviceMesh
@ -47,7 +47,7 @@ class Reduction(Enum):

@dataclass(frozen=True)
class NormReduction:
norm_type: Union[int, float, str]
norm_type: int | float | str


ReductionOpType = Union[NormReduction, str]
@ -71,9 +71,9 @@ class _NormPartial(Partial):
similarly for inf and -inf norm. For 0-norm, the reduction op is sum.
"""

norm_type: Union[int, float, str] = 2
norm_type: int | float | str = 2

def __init__(self, norm_type: Union[int, float, str] = 2):
def __init__(self, norm_type: int | float | str = 2):
reduce_op = None
if norm_type in (float("inf"), "inf"):
reduce_op = "max"
@ -164,7 +164,7 @@ class _NormPartial(Partial):
return 1 + hash(self.norm_type)


def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]:
def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None:
if dims_arg is None:
return None
dims = cast(list[int], as_list(dims_arg))
@ -1081,7 +1081,7 @@ def _common_norm_backward_strategy(
out_tuple_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
# args for OpSpec
output_specs_list: list[Optional[DTensorSpec]] = []
output_specs_list: list[DTensorSpec | None] = []
input_specs_list: list[DTensorSpec] = []
redistribute_costs = []


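Editor's note: `_NormPartial.norm_type` above becomes `int | float | str`. On Python 3.10+ such a union is a real runtime object that `isinstance` accepts directly; a sketch (hypothetical helper name):

def check_norm_type(norm_type: int | float | str) -> None:
    # PEP 604 unions can be passed straight to isinstance (Python 3.10+)
    if not isinstance(norm_type, int | float | str):
        raise TypeError(f"unsupported norm_type: {norm_type!r}")

check_norm_type(2)
check_norm_type(float("inf"))
check_norm_type("inf")
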
@ -2,8 +2,6 @@
# implement matrix related ops for distributed tensor


from typing import Optional

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
@ -708,7 +706,7 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate
) = op_schema.args_schema
return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
has_attn_bias = attn_bias_strategy is not None
debug_attn_mask_sharding: Optional[Placement] = (
debug_attn_mask_sharding: Placement | None = (
Replicate() if return_debug_mask else None
)

@ -1073,7 +1071,7 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
)

def valid_grouped_mm_strides(
input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...]
input_specs: list[DTensorSpec], output_specs: tuple[DTensorSpec | None, ...]
) -> bool:
# 1. compute the local-tensor shape/strides given this sharding proposal
# 2. apply the logic from the groped_mm meta function

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence
from typing import cast, Optional
from typing import cast

import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@ -493,7 +493,7 @@ def common_pointwise_strategy(
followed_strategy: OpStrategy,
followed_strategy_index: int,
linearity: int = -1,
scalar_tensor_idx: Optional[int] = None,
scalar_tensor_idx: int | None = None,
) -> OpStrategy:
"""
Common strategy for pointwise operations.
@ -713,11 +713,11 @@ def list_pointwise_strategy(

def args_tuple_strategies(
args_schema: tuple[object, ...],
) -> list[Optional[TupleStrategy]]:
) -> list[TupleStrategy | None]:
first_arg = args_schema[0]
assert isinstance(first_arg, TupleStrategy)
strategy_len = len(first_arg.children)
tuple_strategies: list[Optional[TupleStrategy]] = []
tuple_strategies: list[TupleStrategy | None] = []
for arg_idx, arg in enumerate(args_schema):
if isinstance(arg, TupleStrategy):
# every tuple strategy should have the same length
@ -743,7 +743,7 @@ def list_pointwise_strategy(

for child_idx, child_strtgy in enumerate(follow_strategy.children):
assert isinstance(child_strtgy, OpStrategy)
args_schema: list[Optional[OpStrategy]] = [
args_schema: list[OpStrategy | None] = [
cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None
for arg_strategy in args_strategies
]

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence, Sized
from typing import cast, Optional
from typing import cast

import torch
from torch._prims_common import IntLike
@ -721,7 +721,7 @@ def _derive_follow_placements_from_tuple_strategy(
# current replicate, just follow new placement
return new_placement

follow_placements: Optional[list[Placement]] = None
follow_placements: list[Placement] | None = None
mesh = tuple_strategy.child_mesh(0)
for arg_strategy in tuple_strategy.children:
if not isinstance(arg_strategy, OpStrategy):
@ -887,7 +887,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding:
if not isinstance(indices_spec, DTensorSpec):
raise AssertionError(f"Expected DTensorSpec, got {type(indices_spec)}")

all_indices_spec: list[Optional[DTensorSpec]] = [
all_indices_spec: list[DTensorSpec | None] = [
indices_spec if dim == i else None for i in range(values_spec.ndim)
]

@ -934,7 +934,7 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType:
op_strategy = OpStrategy([])
# 1. `indices` should all be replicated first.
indices_redistribute_costs = []
new_indices_spec: list[Optional[DTensorSpec]] = []
new_indices_spec: list[DTensorSpec | None] = []
for indices_spec_child in indices_spec.children:
if not isinstance(indices_spec_child, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(indices_spec_child)}")
@ -1044,7 +1044,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}")
if not isinstance(multi_indices_spec, list):
raise AssertionError(f"Expected list, got {type(multi_indices_spec)}")
multi_indices_spec = cast(list[Optional[DTensorSpec]], multi_indices_spec)
multi_indices_spec = cast(list[DTensorSpec | None], multi_indices_spec)
valid_indices_spec: list[tuple[int, DTensorSpec]] = [
(i, a) for i, a in enumerate(multi_indices_spec) if a is not None
]

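Editor's note: in the `prop_index` hunk above, the new union sits inside a `cast()` call, where it is evaluated at runtime rather than as an annotation. `from __future__ import annotations` would not defer it, so this spelling genuinely requires Python 3.10+; a sketch with a stand-in class:

from typing import cast

class FakeSpec:  # stand-in for DTensorSpec
    pass

specs: list[object] = [FakeSpec(), None]
# Evaluated eagerly: list[FakeSpec | None] must be constructible at runtime,
# which needs Python 3.10+ regardless of any __future__ import.
typed_specs = cast(list[FakeSpec | None], specs)
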
@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import cast, Optional, Union
from typing import cast

import torch
from torch import Tensor
@ -216,7 +216,7 @@ def expand(input_shape: Shape, shape: Shape) -> DimMap:
return tuple(mapping)


def normalize_sizes(sizes: Union[Shape, tuple[Shape]]) -> Shape:
def normalize_sizes(sizes: Shape | tuple[Shape]) -> Shape:
if isinstance(sizes[0], int):
return cast(Shape, sizes)
elif len(sizes) == 1:
@ -428,7 +428,7 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
return tuple(dimmap)


def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap:
def dim_squeeze(shape: Shape, dim: int | None = None) -> DimMap:
# FIXME: this is wrong when dim=None and one of the dimensions
# equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could
# end up as squeeze(tensor(1)) if we have 4 devices; this would lead to
@ -457,7 +457,7 @@ def dim_view_as_real(shape: Shape) -> DimMap:
return tuple(results)


def dim_reduction(ndim: int, dim_or_dims: Optional[DimsType], keepdim: bool) -> DimMap:
def dim_reduction(ndim: int, dim_or_dims: DimsType | None, keepdim: bool) -> DimMap:
"""
General fallback for reduction ops where Partial() does not apply.

@ -542,7 +542,7 @@ def propagate_shape_and_sharding(

def maybe_get_shard_mesh_dim_and_placement(
input_dim: InputDim,
) -> tuple[Optional[int], Optional[Shard]]:
) -> tuple[int | None, Shard | None]:
# if input_dim is sharded, return the mesh_dim and shard placement
for i, placement in enumerate(input_src_placements):
if isinstance(placement, Shard) and placement.dim == input_dim.input_dim:
@ -556,7 +556,7 @@ def propagate_shape_and_sharding(
# 1 and 2 doesn't require the info of whether current input is sharded.
# 3 requires that info, to decide whether we can error out. Maybe we can refactor
# to make this function purely "theoretical".
def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
def get_in_dim_to_shard(cmd: DimSpec) -> InputDim | None:
if isinstance(cmd, InputDim):
return cmd
elif isinstance(cmd, Flatten):
@ -692,7 +692,7 @@ def propagate_shape_and_sharding(
def register_op_strategy_map(
aten_op_overload: torch._ops.OpOverload,
local_op_name: Callable[..., torch.Tensor],
schema_info: Optional[RuntimeSchemaInfo] = None,
schema_info: RuntimeSchemaInfo | None = None,
strict_view: bool = False,
) -> None:
"""

@ -4,7 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing import cast, TypeVar
from typing_extensions import ParamSpec

import torch
@ -36,8 +36,8 @@ _P = ParamSpec("_P")

# convenient wrapper to register sharding propagation rules
def register_prop_rule(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
op: torch._ops.OpOverload | list[torch._ops.OpOverload],
schema_info: RuntimeSchemaInfo | None = None,
) -> Callable[
[Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding]
]:
@ -127,9 +127,9 @@ def replicate_op_strategy(op_schema: OpSchema) -> StrategyType:


def as_list(
x: Union[list[object], object],
x: list[object] | object,
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
) -> list[object] | torch.fx.immutable_collections.immutable_list: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.
@ -295,9 +295,10 @@ def expand_to_full_mesh_op_strategy(
*,
input_index: int = 1,
inplace_op: bool = False,
is_valid_strategy_cb: Optional[
Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool]
] = None,
is_valid_strategy_cb: Callable[
[list[DTensorSpec], tuple[DTensorSpec | None, ...]], bool
]
| None = None,
) -> OpStrategy:
"""
Convenience function to allow writing a sharding strategy considering only a single mesh dimension,
@ -332,7 +333,7 @@ def expand_to_full_mesh_op_strategy(

all_strategies = []
for strategy_comb in strategy_combs:
spec_list: list[Optional[DTensorSpec]] = []
spec_list: list[DTensorSpec | None] = []
for specs in zip(*strategy_comb):
if specs[0] is not None:
# TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
@ -354,7 +355,7 @@ def expand_to_full_mesh_op_strategy(
# input_spec matches the first argument's runtime sharding, otherwise we skip
continue

output_specs: tuple[Optional[DTensorSpec], ...]
output_specs: tuple[DTensorSpec | None, ...]
if input_index > 1:
output_specs = tuple(spec_list[:input_index])
else:

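Editor's note: the `expand_to_full_mesh_op_strategy` hunk above trades a wrapped `Optional[Callable[...]]` for a `Callable[...] | None` split across four lines. A type alias (hypothetical name `StrategyFilter`, stand-in spec type) is one way such signatures can stay flat; a sketch:

from collections.abc import Callable
from typing import TypeAlias

class FakeSpec:  # stand-in for DTensorSpec
    pass

StrategyFilter: TypeAlias = Callable[
    [list[FakeSpec], tuple[FakeSpec | None, ...]], bool
]

def expand_strategy(*, is_valid_strategy_cb: StrategyFilter | None = None) -> bool:
    # accept every strategy when no filter callback is supplied
    return is_valid_strategy_cb is None or is_valid_strategy_cb([], ())
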
@ -3,7 +3,7 @@
import contextlib
import warnings
from logging import getLogger
from typing import Optional, Union
from typing import Optional

import torch
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
@ -171,7 +171,7 @@ class _RNGStateTracker:
self._use_distribute_region = value

def _distribute_region(
self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
self, spec: DTensorSpec, generator: torch.Generator | None = None
):
pass

@ -237,7 +237,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):

@contextlib.contextmanager
def _distribute_region(
self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
self, spec: DTensorSpec, generator: torch.Generator | None = None
):
if generator is not None:
# This is a little hacky, but for any user-passed generator, we store its state under a unique key,
@ -327,7 +327,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
mesh = spec.mesh
# note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
# case. Replace the custom logic with dim_map once we support it.
dim_map: list[Union[int, list[int]]] = [-1] * spec.ndim
dim_map: list[int | list[int]] = [-1] * spec.ndim
for i, placement in enumerate(spec.placements):
if isinstance(placement, Shard):
shard_dim = placement.dim

@ -8,7 +8,7 @@ import weakref
from collections import defaultdict
from collections.abc import Sequence
from functools import cache
from typing import cast, NamedTuple, Optional
from typing import cast, NamedTuple

import torch
import torch.distributed._functional_collectives as funcol
@ -88,7 +88,7 @@ class DTensorRedistributePlanner:
class DistState:
placements: tuple[Placement, ...]
tensor_dim_to_mesh_dim: ShardOrder
_hash: Optional[int] = dataclasses.field(
_hash: int | None = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

@ -161,7 +161,7 @@ class DTensorRedistributePlanner:
mesh: DeviceMesh,
transform_infos: Sequence[_TransformInfo],
src_placement: tuple[Placement, ...],
src_shard_order: Optional[ShardOrder] = None,
src_shard_order: ShardOrder | None = None,
) -> str:
"""
Generate a string representation of the sequence of state transitions
@ -646,7 +646,7 @@ class DTensorRedistributePlanner:
def _gen_transform_infos_non_cached(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
use_graph_based_transform: Optional[bool] = None,
use_graph_based_transform: bool | None = None,
) -> list[_TransformInfo]:
transform_infos: list[_TransformInfo] = []
device_mesh = src_spec.device_mesh
@ -678,7 +678,7 @@ def _gen_transform_infos_non_cached(
def _gen_transform_infos(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
use_graph_based_transform: Optional[bool] = None,
use_graph_based_transform: bool | None = None,
) -> list[_TransformInfo]:
return _gen_transform_infos_non_cached(
src_spec, dst_spec, use_graph_based_transform
@ -692,7 +692,7 @@ def redistribute_local_tensor(
*,
async_op: bool = False,
is_backward: bool = False,
use_graph_based_transform: Optional[bool] = None,
use_graph_based_transform: bool | None = None,
) -> torch.Tensor:
"""
This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to
@ -846,8 +846,8 @@ class Redistribute(torch.autograd.Function):
device_mesh: DeviceMesh,
placements: tuple[Placement, ...],
async_op: bool = False,
forward_dtype: Optional[torch.dtype] = None,
backward_dtype: Optional[torch.dtype] = None,
forward_dtype: torch.dtype | None = None,
backward_dtype: torch.dtype | None = None,
):
ctx.async_op = async_op
ctx.backward_dtype = backward_dtype

@ -4,7 +4,7 @@ import threading
from collections.abc import Callable, Sequence
from functools import lru_cache
from itertools import chain
from typing import cast, Optional, Union
from typing import cast

import torch
from torch._guards import detect_fake_mode
@ -69,9 +69,7 @@ class ShardingPropagator:
)
# op map to save indices of shape (and stride) args which may need to be
# modified in sharding prop
self.op_to_shape_and_stride_idx: dict[
OpOverload, Union[int, tuple[int, int]]
] = {
self.op_to_shape_and_stride_idx: dict[OpOverload, int | tuple[int, int]] = {
# new factory ops
aten.new_empty.default: 1,
aten.new_full.default: 1,
@ -91,7 +89,7 @@ class ShardingPropagator:
self,
op_overload: OpOverload,
rule_func: Callable[[OpSchema], OutputSharding],
schema_info: Optional[RuntimeSchemaInfo] = None,
schema_info: RuntimeSchemaInfo | None = None,
):
"""
Register a sharding propagation rule for an operator.
@ -104,7 +102,7 @@ class ShardingPropagator:
self,
op_overload: OpOverload,
strategy_func: Callable[[OpSchema], StrategyType],
schema_info: Optional[RuntimeSchemaInfo] = None,
schema_info: RuntimeSchemaInfo | None = None,
):
"""
Register a :class:`OpStrategy` generator for an operator.
@ -157,7 +155,7 @@ class ShardingPropagator:

def _propagate_tensor_meta_non_cached(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
) -> None | TensorMeta | Sequence[TensorMeta | None]:
"""
Propagate the tensor metadata, it could either return a TensorMeta
or a list/tuple of TensorMetas
@ -191,7 +189,7 @@ class ShardingPropagator:
)

elif isinstance(fake_out, (tuple, list)):
tensor_meta_list: list[Optional[TensorMeta]] = []
tensor_meta_list: list[TensorMeta | None] = []
for fake_out_item in fake_out:
if isinstance(fake_out_item, torch.Tensor):
tensor_meta_list.append(
@ -215,7 +213,7 @@ class ShardingPropagator:
@lru_cache # noqa: B019
def _propagate_tensor_meta(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
) -> None | TensorMeta | Sequence[TensorMeta | None]:
"""
Cached version of _propagate_tensor_meta_non_cached
This is a private API. Use propagate_tensor_meta instead.
@ -224,7 +222,7 @@ class ShardingPropagator:

def propagate_tensor_meta(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
) -> None | TensorMeta | Sequence[TensorMeta | None]:
"""
Propagate the tensor metadata, it could either return a TensorMeta
or a list/tuple of TensorMetas. This is a public API that should be
@ -239,7 +237,7 @@ class ShardingPropagator:
self,
op: OpOverload,
output_specs: OutputSpecType,
output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
output_tensor_meta: None | TensorMeta | Sequence[TensorMeta | None],
) -> OutputSpecType:
"""
Wrap the output_specs with the tensor metadata from the output.
@ -260,7 +258,7 @@ class ShardingPropagator:
)
return output_specs.shallow_copy_with_tensor_meta(output_tensor_meta)
elif isinstance(output_specs, (tuple, list)):
new_specs: list[Optional[DTensorSpec]] = []
new_specs: list[DTensorSpec | None] = []
if not isinstance(output_tensor_meta, (tuple, list)) or len(
output_specs
) != len(output_tensor_meta):
@ -587,7 +585,7 @@ class ShardingPropagator:
)

def _select_strategy(
self, strategy: OpStrategy, op_schema: Optional[OpSchema] = None
self, strategy: OpStrategy, op_schema: OpSchema | None = None
) -> OpSpec:
if len(strategy.strategies) == 1:
# short cut with only one possible OpSpec

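Editor's note: in the `ShardingPropagator` hunk above, a three-line dict annotation collapses to `dict[OpOverload, int | tuple[int, int]]`. A sketch of the same shape, with string keys standing in for `OpOverload` (hypothetical entries, for illustration only):

# values: a single arg index, or a (shape_idx, stride_idx) pair
shape_and_stride_idx: dict[str, int | tuple[int, int]] = {
    "aten.new_empty.default": 1,
    "aten.as_strided.default": (1, 2),  # hypothetical entry
}

for op, idx in shape_and_stride_idx.items():
    shape_idx = idx if isinstance(idx, int) else idx[0]
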
@ -1,7 +1,7 @@
import threading
from collections import defaultdict
from collections.abc import Sequence
from typing import cast, Optional
from typing import cast

import torch
import torch.distributed._functional_collectives as funcol
@ -159,7 +159,7 @@ def compute_local_shape_and_global_offset(
def _compute_local_shape_and_global_offset(
global_shape: ShapeType,
mesh_shape: ShapeType,
my_coordinate: Optional[list[int]],
my_coordinate: list[int] | None,
placements: Sequence[Placement],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""

@ -5,7 +5,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.p

import argparse
import os
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
@ -55,7 +55,7 @@ class CommDebugModeExample:
self.device_type = get_device_type()

def _MLP_model_setup(
self, model_type: type, parallelize_plan: Union[None, dict] = None
self, model_type: type, parallelize_plan: None | dict = None
) -> tuple[nn.Module, torch.Tensor]:
"""
Creates MLP or MLPStacked model for examples

|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -27,8 +26,8 @@ def get_device_type() -> str:
|
||||
@lru_cache
|
||||
def create_block_mask_cached(
|
||||
score_mod: _mask_mod_signature,
|
||||
B: Optional[int],
|
||||
H: Optional[int],
|
||||
B: int | None,
|
||||
H: int | None,
|
||||
M: int,
|
||||
N: int,
|
||||
device: str = "cuda",
|
||||
|
||||
@ -7,7 +7,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from functools import partial
|
||||
from typing import Any, cast, Optional, Protocol, TypeAlias
|
||||
from typing import Any, cast, Protocol, TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -140,8 +140,8 @@ class _SDPAMerger:
|
||||
|
||||
def __init__(self, convert_to_f32: bool, seq_dim: int):
|
||||
self._seq_dim = seq_dim
|
||||
self._out: Optional[torch.Tensor] = None
|
||||
self._lse: Optional[torch.Tensor] = None
|
||||
self._out: torch.Tensor | None = None
|
||||
self._lse: torch.Tensor | None = None
|
||||
self._should_lse_squeeze = False
|
||||
self._convert_to_f32 = convert_to_f32
|
||||
self._out_dtype = torch.float32
|
||||
@ -250,7 +250,7 @@ class _AllToAllRotater(_RingRotater):
|
||||
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
|
||||
self._pg = pg
|
||||
self._seq_dim = seq_dim
|
||||
self._buffer: Optional[torch.Tensor] = None
|
||||
self._buffer: torch.Tensor | None = None
|
||||
|
||||
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
|
||||
curr_buffer = curr_buffer.contiguous()
|
||||
@ -272,7 +272,7 @@ class _AllGatherRotater(_RingRotater):
|
||||
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
|
||||
self._pg = pg
|
||||
self._seq_dim = seq_dim
|
||||
self._aggregated_buffer: Optional[torch.Tensor] = None
|
||||
self._aggregated_buffer: torch.Tensor | None = None
|
||||
self._idx = 0
|
||||
|
||||
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
|
||||
@ -293,7 +293,7 @@ class _AllGatherRotater(_RingRotater):
|
||||
|
||||
|
||||
def _create_rotater(
|
||||
pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None
|
||||
pg: dist.ProcessGroup, seq_dim: int, method: _RotateMethod | None = None
|
||||
) -> _RingRotater:
|
||||
if method is None:
|
||||
method = _cp_options.rotate_method
|
||||
@ -655,7 +655,7 @@ def _scaled_dot_product_ring_flash_attention(
|
||||
is_causal: bool = False,
|
||||
return_debug_mask: bool = False,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
scale: float | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if return_debug_mask:
|
||||
raise NotImplementedError("return_debug_mask is not supported yet")
|
||||
@ -681,12 +681,12 @@ def _scaled_dot_product_ring_efficient_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[torch.Tensor] = None,
|
||||
attn_bias: torch.Tensor | None = None,
|
||||
compute_log_sumexp: bool = True,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
scale: float | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if attn_bias is not None:
|
||||
raise NotImplementedError("attn_bias is not supported yet")
|
||||
@ -718,13 +718,13 @@ def _scaled_dot_product_ring_cudnn_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[torch.Tensor] = None,
|
||||
attn_bias: torch.Tensor | None = None,
|
||||
compute_log_sumexp: bool = True,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
return_debug_mask: bool = False,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
scale: float | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if attn_bias is not None:
|
||||
raise NotImplementedError("attn_bias is not supported yet")
|
||||
@ -769,7 +769,7 @@ def _scaled_dot_product_ring_flash_attention_backward(
|
||||
philox_seed: torch.Tensor,
|
||||
philox_offset: torch.Tensor,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
scale: float | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
# TODO: remove this hardcoding
|
||||
seq_dim = 2
|
||||
@ -812,7 +812,7 @@ def _scaled_dot_product_ring_efficient_attention_backward(
|
||||
grad_input_mask: tuple[bool, ...],
|
||||
is_causal: bool = False,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
scale: float | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
# TODO: remove this hardcoding
|
||||
seq_dim = 2
|
||||
@ -856,7 +856,7 @@ def _scaled_dot_product_ring_cudnn_attention_backward(
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
*,
|
||||
scale: Optional[float] = None,
|
||||
scale: float | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
# TODO: remove this hardcoding
|
||||
seq_dim = 2
|
||||
@ -938,8 +938,8 @@ exitsing_custom_ops = DTensor._op_dispatcher._custom_op_handlers
|
||||
|
||||
ArgsType = tuple[Any, ...]
|
||||
KwargsType = dict[str, Any]
|
||||
InputFnType = Callable[[Optional[nn.Module], ArgsType, KwargsType, DeviceMesh], Any]
|
||||
OutputFnType = Callable[[Optional[nn.Module], Any, Any, DeviceMesh], Any]
|
||||
InputFnType = Callable[[nn.Module | None, ArgsType, KwargsType, DeviceMesh], Any]
|
||||
OutputFnType = Callable[[nn.Module | None, Any, Any, DeviceMesh], Any]
|
||||
|
||||
_replaced_functions: dict[Callable, tuple[str, Callable]] = {}
|
||||
|
||||
@ -1039,7 +1039,7 @@ def _context_parallel_buffers(
|
||||
mesh: DeviceMesh,
|
||||
buffers: list[torch.Tensor | BlockMask],
|
||||
buffer_seq_dims: list[int],
|
||||
load_balancer: Optional[_LoadBalancer] = None,
|
||||
load_balancer: _LoadBalancer | None = None,
|
||||
) -> list[torch.Tensor | BlockMask]:
|
||||
"""
|
||||
Shard the buffers along the sequence dimensions according to CP rules.
|
||||
@ -1136,7 +1136,7 @@ def _create_cp_block_mask(
|
||||
Q_LEN: int,
|
||||
KV_LEN: int,
|
||||
device_mesh: DeviceMesh,
|
||||
load_balancer: Optional[_LoadBalancer] = None,
|
||||
load_balancer: _LoadBalancer | None = None,
|
||||
) -> BlockMask:
|
||||
"""
|
||||
Creates a specialized BlockMask for Context Parallel FlexAttention.
|
||||
@ -1197,7 +1197,7 @@ def _create_cp_block_mask(
|
||||
rank: int,
|
||||
block_size: int,
|
||||
local_q_size: int,
|
||||
qkv_rearrange_indices: Optional[torch.Tensor] = None,
|
||||
qkv_rearrange_indices: torch.Tensor | None = None,
|
||||
) -> _mask_mod_signature:
|
||||
assert qkv_rearrange_indices is None or qkv_rearrange_indices.ndim == 2, (
|
||||
"load balance index expects shape (1, seq_len) or (B, seq_len) "
|
||||
@ -1301,7 +1301,7 @@ class _ContextParallel(ParallelStyle):
|
||||
raise ValueError(f"Unknown attention type: {self.attention_type}")
|
||||
|
||||
def flex_input_fn(
|
||||
self, module: Optional[nn.Module], args: Any, kwargs: Any, mesh: DeviceMesh
|
||||
self, module: nn.Module | None, args: Any, kwargs: Any, mesh: DeviceMesh
|
||||
) -> Any:
|
||||
args_list = list(args)
|
||||
for idx, name in enumerate(
|
||||
@ -1329,7 +1329,7 @@ class _ContextParallel(ParallelStyle):
|
||||
|
||||
def sdpa_input_fn(
|
||||
self,
|
||||
module: Optional[nn.Module],
|
||||
module: nn.Module | None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
mesh: DeviceMesh,
|
||||
@ -1351,7 +1351,7 @@ class _ContextParallel(ParallelStyle):
|
||||
return new_args, new_kwargs
|
||||
|
||||
def sdpa_output_fn(
|
||||
self, module: Optional[nn.Module], inputs: Any, outputs: Any, mesh: DeviceMesh
|
||||
self, module: nn.Module | None, inputs: Any, outputs: Any, mesh: DeviceMesh
|
||||
) -> Any:
|
||||
new_outputs = []
|
||||
for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
|
||||
@ -1373,7 +1373,7 @@ def _context_parallel_shard(
|
||||
mesh: DeviceMesh,
|
||||
buffers: CPBufferContainer,
|
||||
seq_dims: CPBufferSeqDims,
|
||||
load_balancer: Optional[_LoadBalancer] = None,
|
||||
load_balancer: _LoadBalancer | None = None,
|
||||
) -> list[torch.Tensor | BlockMask]:
|
||||
"""
|
||||
Shard the buffers along the specified sequence dimensions (`seq_dims`), so that each
|
||||
@ -1464,9 +1464,9 @@ def _disable_context_parallel_dispatcher() -> None:
|
||||
def context_parallel(
|
||||
mesh: DeviceMesh,
|
||||
*,
|
||||
buffers: Optional[list[torch.Tensor]] = None,
|
||||
buffer_seq_dims: Optional[list[int]] = None,
|
||||
no_restore_buffers: Optional[set[torch.Tensor]] = None,
|
||||
buffers: list[torch.Tensor] | None = None,
|
||||
buffer_seq_dims: list[int] | None = None,
|
||||
no_restore_buffers: set[torch.Tensor] | None = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
|
||||
@ -1554,7 +1554,7 @@ def context_parallel_unshard(
|
||||
mesh: DeviceMesh,
|
||||
buffers: list[torch.Tensor],
|
||||
seq_dims: list[int],
|
||||
load_balancer: Optional[_LoadBalancer] = None,
|
||||
load_balancer: _LoadBalancer | None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Unshard the tensors (e.g., output) that are sharded due to context parallelism.
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# for different load-balancing strategies in tensor sharding.
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -12,7 +11,7 @@ from torch.nn.attention.flex_attention import BlockMask
|
||||
# make it private since it's still a prototype
|
||||
class _LoadBalancer(ABC):
|
||||
@abstractmethod
|
||||
def _generate_indices(self, restore: bool = False) -> Optional[Tensor]:
|
||||
def _generate_indices(self, restore: bool = False) -> Tensor | None:
|
||||
"""
|
||||
Generate indices for load balancing.
|
||||
Args:
|
||||
@ -478,7 +477,7 @@ class _PTRRLoadBalancer(_LoadBalancer):
|
||||
|
||||
def _create_default_load_balancer(
|
||||
seq_length: int, world_size: int, device: str | torch.device
|
||||
) -> Optional[_LoadBalancer]:
|
||||
) -> _LoadBalancer | None:
|
||||
from ._attention import _cp_options
|
||||
|
||||
if _cp_options.enable_load_balance:
|
||||
|
||||
@ -24,11 +24,11 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
|
||||
|
||||
|
||||
def local_map(
|
||||
func: Optional[Callable] = None,
|
||||
func: Callable | None = None,
|
||||
out_placements: OutputPlacements = None,
|
||||
in_placements: InputPlacements = None,
|
||||
in_grad_placements: InputPlacements = None,
|
||||
device_mesh: Optional[DeviceMesh] = None,
|
||||
device_mesh: DeviceMesh | None = None,
|
||||
*,
|
||||
redistribute_inputs: bool = False,
|
||||
):
|
||||
@ -163,7 +163,7 @@ def _local_map_wrapped(
|
||||
out_placements: OutputPlacements,
|
||||
in_placements: InputPlacements,
|
||||
in_grad_placements: InputPlacements,
|
||||
device_mesh: Optional[DeviceMesh],
|
||||
device_mesh: DeviceMesh | None,
|
||||
redistribute_inputs: bool,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverload
|
||||
@ -21,7 +20,7 @@ from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
|
||||
__all__ = ["register_sharding"]
|
||||
|
||||
|
||||
def register_sharding(op: Union[OpOverload, list[OpOverload]]):
|
||||
def register_sharding(op: OpOverload | list[OpOverload]):
|
||||
"""
|
||||
:meth:`register_sharding` is an experimental API that allows users to register sharding
|
||||
strategies for an operator when the tensor inputs and outputs are DTensor.
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
import copy
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
@ -273,7 +273,7 @@ def _create_placement_strategy(
|
||||
node: Node,
|
||||
mesh: DeviceMesh,
|
||||
placements: tuple[Placement, ...],
|
||||
input_specs: Optional[Sequence[DTensorSpec]] = None,
|
||||
input_specs: Sequence[DTensorSpec] | None = None,
|
||||
) -> OpSpec:
|
||||
"""
|
||||
Util function to construct an OpSpec for a given node.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import no_type_check, Optional
|
||||
from typing import no_type_check
|
||||
|
||||
import torch
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
@ -21,7 +21,7 @@ def sync_grad_hook(grad, *, device_handle=None, compute_stream=None):
|
||||
|
||||
def _flatten_tensor(
|
||||
tensor: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, Optional[DTensorSpec]]:
|
||||
) -> tuple[torch.Tensor, DTensorSpec | None]:
|
||||
if isinstance(tensor, DTensor):
|
||||
tensor._local_tensor.requires_grad_()
|
||||
return tensor._local_tensor, tensor._spec
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import warnings
|
||||
from fnmatch import fnmatch
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -14,10 +13,10 @@ __all__ = ["parallelize_module"]
|
||||
|
||||
def parallelize_module( # type: ignore[return]
|
||||
module: nn.Module,
|
||||
device_mesh: Optional[DeviceMesh] = None,
|
||||
parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
|
||||
device_mesh: DeviceMesh | None = None,
|
||||
parallelize_plan: ParallelStyle | dict[str, ParallelStyle] | None = None,
|
||||
*,
|
||||
src_data_rank: Optional[int] = 0,
|
||||
src_data_rank: int | None = 0,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.distributed.tensor.parallel._data_parallel_utils import (
|
||||
@ -48,7 +48,7 @@ def _reconstruct_dtensor(module: nn.Module, _input: Any):
|
||||
|
||||
|
||||
def _localize_dtensor(
|
||||
module: nn.Module, *_: Any, ignored_params: Optional[set[nn.Parameter]] = None
|
||||
module: nn.Module, *_: Any, ignored_params: set[nn.Parameter] | None = None
|
||||
):
|
||||
"""
|
||||
Convert DTensor parameters to local tensors
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -297,7 +297,7 @@ def _pre_load_state_dict(
|
||||
|
||||
def _all_gather_dtensor(
|
||||
tensor: DTensor,
|
||||
parent_mesh: Optional[DeviceMesh],
|
||||
parent_mesh: DeviceMesh | None,
|
||||
) -> torch.Tensor:
|
||||
"""All gather a DTensor in its FSDP dimension and return the local tensor."""
|
||||
assert parent_mesh == tensor.device_mesh
|
||||
@ -336,7 +336,7 @@ class DTensorExtensions(FSDPExtensions):
|
||||
def pre_flatten_transform(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, Optional[Any]]:
|
||||
) -> tuple[torch.Tensor, Any | None]:
|
||||
return _flatten_tensor(tensor)
|
||||
|
||||
def post_unflatten_transform(
|
||||
@ -365,7 +365,7 @@ class DTensorExtensions(FSDPExtensions):
|
||||
world_size: int,
|
||||
num_devices_per_node: int,
|
||||
pg: dist.ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device: torch.device | None = None,
|
||||
) -> torch.Tensor:
|
||||
return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)
|
||||
|
||||
@ -386,6 +386,6 @@ class DTensorExtensions(FSDPExtensions):
|
||||
def all_gather_dtensor(
|
||||
self,
|
||||
tensor: DTensor,
|
||||
parent_mesh: Optional[DeviceMesh],
|
||||
parent_mesh: DeviceMesh | None,
|
||||
) -> torch.Tensor:
|
||||
return _all_gather_dtensor(tensor, parent_mesh)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
|
||||
@ -14,7 +14,7 @@ __all__ = [
|
||||
def input_reshard(
|
||||
module: torch.nn.Module,
|
||||
tp_device_mesh: DeviceMesh,
|
||||
input_reshard_dim: Optional[int] = None,
|
||||
input_reshard_dim: int | None = None,
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation.
|
||||
@ -42,7 +42,7 @@ def input_reshard(
|
||||
if input_reshard_dim is None:
|
||||
return module
|
||||
|
||||
cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None
|
||||
cx: torch.autograd.graph.saved_tensors_hooks | None = None
|
||||
|
||||
def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: tuple[Any, ...]) -> None:
|
||||
saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import contextlib
|
||||
from typing import cast, Optional
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
@ -201,8 +201,8 @@ def _log_softmax_backward_handler(
|
||||
def _nll_loss_forward(
|
||||
x: Tensor,
|
||||
target: Tensor,
|
||||
weight: Optional[Tensor],
|
||||
local_weight: Optional[Tensor],
|
||||
weight: Tensor | None,
|
||||
local_weight: Tensor | None,
|
||||
reduction: int,
|
||||
ignore_index: int,
|
||||
input_shape: torch.Size,
|
||||
@ -356,7 +356,7 @@ def _nll_loss_and_log_softmax_backward(
|
||||
grad_output: Tensor,
|
||||
x: Tensor,
|
||||
target: Tensor,
|
||||
weight: Optional[Tensor],
|
||||
weight: Tensor | None,
|
||||
reduction: int,
|
||||
ignore_index: int,
|
||||
total_weight: Tensor,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -36,7 +36,7 @@ class ParallelStyle(ABC):
|
||||
flexibility for different kind of style implementations.
|
||||
"""
|
||||
|
||||
src_data_rank: Optional[int] = 0
|
||||
src_data_rank: int | None = 0
|
||||
|
||||
@abstractmethod
|
||||
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...
|
||||
@ -82,8 +82,8 @@ class ColwiseParallel(ParallelStyle):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Optional[Placement] = None,
|
||||
output_layouts: Optional[Placement] = None,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
@ -212,8 +212,8 @@ class RowwiseParallel(ParallelStyle):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Optional[Placement] = None,
|
||||
output_layouts: Optional[Placement] = None,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
@ -473,14 +473,10 @@ class PrepareModuleInput(ParallelStyle):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Optional[
|
||||
Union[Placement, tuple[Optional[Placement], ...]]
|
||||
] = None,
|
||||
desired_input_layouts: Optional[
|
||||
Union[Placement, tuple[Optional[Placement], ...]]
|
||||
] = None,
|
||||
input_kwarg_layouts: Optional[dict[str, Placement]] = None,
|
||||
desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
|
||||
input_layouts: Placement | tuple[Placement | None, ...] | None = None,
|
||||
desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None,
|
||||
input_kwarg_layouts: dict[str, Placement] | None = None,
|
||||
desired_input_kwarg_layouts: dict[str, Placement] | None = None,
|
||||
use_local_output: bool = False,
|
||||
):
|
||||
self.input_layouts = (
|
||||
@ -513,8 +509,8 @@ class PrepareModuleInput(ParallelStyle):
|
||||
self,
|
||||
input: Any,
|
||||
mesh: DeviceMesh,
|
||||
input_layout: Optional[Placement],
|
||||
desired_layout: Optional[Placement],
|
||||
input_layout: Placement | None,
|
||||
desired_layout: Placement | None,
|
||||
):
|
||||
if input_layout is not None:
|
||||
if isinstance(input, DTensor):
|
||||
@ -637,8 +633,8 @@ class PrepareModuleOutput(ParallelStyle):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
output_layouts: Union[Placement, tuple[Optional[Placement], ...]],
|
||||
desired_output_layouts: Union[Placement, tuple[Placement, ...]],
|
||||
output_layouts: Placement | tuple[Placement | None, ...],
|
||||
desired_output_layouts: Placement | tuple[Placement, ...],
|
||||
use_local_output: bool = True,
|
||||
):
|
||||
self.output_layouts = (
|
||||
@ -768,17 +764,13 @@ class PrepareModuleInputOutput(ParallelStyle):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Optional[
|
||||
Union[Placement, tuple[Optional[Placement], ...]]
|
||||
] = None,
|
||||
desired_input_layouts: Optional[
|
||||
Union[Placement, tuple[Optional[Placement], ...]]
|
||||
] = None,
|
||||
input_kwarg_layouts: Optional[dict[str, Placement]] = None,
|
||||
desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
|
||||
input_layouts: Placement | tuple[Placement | None, ...] | None = None,
|
||||
desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None,
|
||||
input_kwarg_layouts: dict[str, Placement] | None = None,
|
||||
desired_input_kwarg_layouts: dict[str, Placement] | None = None,
|
||||
use_local_input: bool = False,
|
||||
output_layouts: Union[Placement, tuple[Optional[Placement], ...]],
|
||||
desired_output_layouts: Union[Placement, tuple[Placement, ...]],
|
||||
output_layouts: Placement | tuple[Placement | None, ...],
|
||||
desired_output_layouts: Placement | tuple[Placement, ...],
|
||||
use_local_output: bool = True,
|
||||
):
|
||||
self.prepare_module_input = PrepareModuleInput(
|
||||
|
||||
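Editor's note: `PrepareModuleInput` and `PrepareModuleInputOutput` above flatten the nested `Optional[Union[...]]` spelling. The two forms should compare equal at runtime on Python 3.10+; a sketch with a stand-in placement class:

from typing import Optional, Union

class FakePlacement:  # stand-in for Placement
    pass

Nested = Optional[Union[FakePlacement, tuple[Optional[FakePlacement], ...]]]
Flat = FakePlacement | tuple[FakePlacement | None, ...] | None

# both flatten to Union[FakePlacement, tuple[FakePlacement | None, ...], None]
assert Nested == Flat
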
@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

from dataclasses import dataclass, field
from typing import cast, Optional
from typing import cast

import torch
import torch._C
@ -129,7 +129,7 @@ class Shard(torch._C._distributed.Shard):
curr_local_size: int,
num_chunks: int,
rank: int,
) -> tuple[int, Optional[int]]:
) -> tuple[int, int | None]:
return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)

@staticmethod
@ -151,7 +151,7 @@ class Shard(torch._C._distributed.Shard):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
"""
shard and scatter a tensor on a mesh dimension (use coordinate
@ -203,7 +203,7 @@ class Shard(torch._C._distributed.Shard):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
shard_placement = cls(dim)
return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank)
@ -566,7 +566,7 @@ class _StridedShard(torch._C._distributed.StridedShard, Shard):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
split_factor: int = 1,
) -> torch.Tensor:
strided_shard_placement = cls(dim=dim, split_factor=split_factor)
@ -689,7 +689,7 @@ class _StridedShard(torch._C._distributed.StridedShard, Shard):
curr_local_size: int,
num_chunks: int,
rank: int,
) -> tuple[int, Optional[int]]:
) -> tuple[int, int | None]:
# indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed
# so that we can reuse self._split_tensor which splits on self.dim
shape = [1] * self.dim + [curr_local_size]
@ -742,7 +742,7 @@ class Replicate(torch._C._distributed.Replicate):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
"""
Replicate (broadcast) a torch.Tensor on a mesh dimension (use
@ -765,7 +765,7 @@ class Replicate(torch._C._distributed.Replicate):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank)

@ -859,7 +859,7 @@ class MaskPartial(Partial):
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)

# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_shape: torch.Size | None = None
offset_dim: int = 0

def __init__(

@ -44,7 +44,7 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str,


def _cast_forward_inputs(
dtype: Optional[torch.dtype],
dtype: torch.dtype | None,
*args: Any,
**kwargs: Any,
) -> tuple[Any, Any]:
@ -257,7 +257,7 @@ def _apply_to_tensors(fn, container):

def _to_kwargs(
inputs: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
kwargs: dict[str, Any] | None,
target_device: torch.device,
use_side_stream_for_tensor_copies: bool,
) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: