Compare commits

...

1 Commit

Author: Yuanyuan Chen
SHA1: 59642e6a24
Message: Use Python 3.10 typing (Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>)
Date: 2025-11-15 12:37:16 +08:00
97 changed files with 824 additions and 901 deletions
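The change is mechanical: every `typing.Optional[X]` becomes `X | None` and every `typing.Union[A, B]` becomes `A | B` (PEP 604), which assumes Python 3.10+ at runtime since these annotations are evaluated when the signatures are defined. A minimal before/after sketch of the pattern, using invented names (`describe`, `describe_new`) rather than code from the diff:

# Before: Python < 3.10 spelling, needs imports from typing.
from typing import Optional, Union

def describe(tag: Optional[str] = None, count: Union[int, float] = 0) -> str:
    return f"{tag or 'unnamed'}: {count}"

# After: PEP 604 unions (Python 3.10+); the typing imports become unnecessary.
def describe_new(tag: str | None = None, count: int | float = 0) -> str:
    return f"{tag or 'unnamed'}: {count}"

assert describe() == describe_new() == "unnamed: 0"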

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator
from enum import auto, Enum
from functools import partial
from typing import Any, Optional
from typing import Any
import torch
import torch.nn as nn
@ -251,7 +251,7 @@ def apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=checkpoint_wrapper,
check_fn=lambda _: True,
auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
auto_wrap_policy: Callable[[nn.Module, bool, int], bool] | None = None,
):
"""
Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration.

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools
from typing import Optional
import torch
import torch.distributed as dist
@ -136,7 +135,7 @@ def _low_precision_hook(
prec: torch.dtype,
state: LowPrecisionState,
grad: torch.Tensor,
output: Optional[torch.Tensor],
output: torch.Tensor | None,
):
if grad.dtype != prec:
grad.data = grad.data.to(prec)
@ -151,7 +150,7 @@ def _low_precision_hook(
def fp16_compress_hook(
state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None
):
r"""
Implement FSDP communication hook for a simple gradient compression approach.
@ -172,7 +171,7 @@ def fp16_compress_hook(
def bf16_compress_hook(
state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None
):
r"""
Implement FSDP communication hook for a simple gradient compression approach.

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import weakref
from collections.abc import Callable
from typing import Any, Optional
from typing import Any
import torch
import torch.distributed as dist
@ -47,7 +47,7 @@ def _perform_local_step(
# expects `None` in a list position to indicate that the corresponding
# parameter should not be updated
num_local_optim_params = len(zero.optim.param_groups[0]["params"])
gradients: list[Optional[torch.Tensor]] = [
gradients: list[torch.Tensor | None] = [
_NO_PARAM_UPDATE for _ in range(num_local_optim_params)
]
assert bucket_index in overlap_info.offsets, (

View File

@ -2,7 +2,7 @@
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple
import torch
import torch.distributed as dist
@ -228,9 +228,9 @@ class Join:
def __exit__(
self,
type: Optional[type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
type: type[BaseException] | None,
value: BaseException | None,
traceback: TracebackType | None,
):
r"""
Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.

View File

@ -2,7 +2,6 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -23,7 +22,7 @@ class ModelAverager(ABC):
will be used. (default: ``None``)
"""
def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
def __init__(self, process_group: dist.ProcessGroup | None = None):
self.process_group = (
process_group if process_group is not None else _not_none(dist.group.WORLD)
)
@ -88,7 +87,7 @@ class PeriodicModelAverager(ModelAverager):
"""
def __init__(
self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
self, period, warmup_steps=0, process_group: dist.ProcessGroup | None = None
):
super().__init__(process_group)
if warmup_steps < 0:
@ -108,9 +107,7 @@ class PeriodicModelAverager(ModelAverager):
def average_parameters(
self,
params: Union[
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
):
"""
Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``.

View File

@ -4,7 +4,6 @@ import logging
import warnings
from collections import OrderedDict
from collections.abc import Iterable
from typing import Union
import torch
import torch.distributed as dist
@ -160,9 +159,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
def average_parameters(
self,
params: Union[
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
):
"""
Averages parameters or parameter groups of an optimizer.

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from collections.abc import Iterable, Iterator
from typing import Union
import torch
import torch.distributed as dist
@ -51,10 +50,7 @@ def average_parameters(
def get_params_to_average(
params: Union[
Iterable[torch.nn.Parameter],
Iterable[dict[str, torch.nn.Parameter]],
],
params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
):
"""
Return a list of parameters that need to be averaged.
@ -83,9 +79,7 @@ def get_params_to_average(
def average_parameters_or_parameter_groups(
params: Union[
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
process_group: ProcessGroup,
):
"""Averages parameters of a model or parameter groups of an optimizer."""

View File

@ -13,7 +13,7 @@ import importlib
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
@ -37,19 +37,19 @@ T = TypeVar("T")
@dataclass
class SyncPayload(Generic[T]):
stage_name: Optional[str]
stage_name: str | None
success: bool
payload: T
exception: Optional[Exception] = None
exception: Exception | None = None
def broadcast(
data_or_fn: Union[T, Callable[[], T]],
data_or_fn: T | Callable[[], T],
*,
success: bool = True,
stage_name: Optional[str] = None,
stage_name: str | None = None,
rank: int = 0,
pg: Optional[dist.ProcessGroup] = None,
pg: dist.ProcessGroup | None = None,
) -> T:
"""
Broadcasts the data payload from rank 0 to all other ranks.
@ -79,8 +79,8 @@ def broadcast(
"Data or Function is expected to be None if not successful"
)
payload: Optional[T] = None
exception: Optional[Exception] = None
payload: T | None = None
exception: Exception | None = None
# if no pg is passed then execute if rank is 0
if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
# determine if it is an executable function or data payload only
@ -124,9 +124,9 @@ def broadcast(
def all_gather(
data_or_fn: Union[T, Callable[[], T]],
stage_name: Optional[str] = None,
pg: Optional[dist.ProcessGroup] = None,
data_or_fn: T | Callable[[], T],
stage_name: str | None = None,
pg: dist.ProcessGroup | None = None,
) -> list[T]:
"""
A simple all_gather primitive with basic synchronization guard logic,
@ -144,8 +144,8 @@ def all_gather(
Example usage:
>> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
"""
payload: Optional[T] = None
exception: Optional[Exception] = None
payload: T | None = None
exception: Exception | None = None
success = True
# determine if it is an executable function or data payload only
if callable(data_or_fn):
@ -247,7 +247,7 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
raise AssertionError("ranks should all be positive")
if len(set(ranks)) != len(ranks):
raise AssertionError("ranks should not contain duplicates")
curr: Optional[Union[int, range]] = None
curr: int | range | None = None
ranges = []
while ranks:
x = ranks.pop(0)
@ -345,9 +345,7 @@ def _desync_table_str(tag: str, value_ranks: dict[Any, set[int]]) -> str:
return str(f"{headers}\n{row_str}")
def _check_rng_sync(
generator: torch.Generator, group: dist.ProcessGroup
) -> Optional[str]:
def _check_rng_sync(generator: torch.Generator, group: dist.ProcessGroup) -> str | None:
value_ranks, value_header = _check_rng_sync_internal(generator, group)
log_str = None
if len(value_ranks) > 1:
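The hunks above mix TypeVars into the new unions (`stage_name: str | None`, `data_or_fn: T | Callable[[], T]`); the `|` operator composes with `TypeVar` and `Callable` the same way `Optional`/`Union` did. A self-contained sketch of the same shapes with hypothetical names (`Payload`, `materialize`), not this module's actual API:

from collections.abc import Callable
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")

@dataclass
class Payload(Generic[T]):
    # Was: Optional[T] / Optional[Exception]
    value: T | None = None
    error: Exception | None = None

def materialize(data_or_fn: T | Callable[[], T]) -> T:
    # Accept either a ready value or a zero-argument factory,
    # mirroring the `data_or_fn` convention in the diff above.
    return data_or_fn() if callable(data_or_fn) else data_or_fn

assert materialize(3) == 3
assert materialize(lambda: "ok") == "ok"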

View File

@ -1,5 +1,4 @@
from datetime import timedelta
from typing import Optional
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
@ -19,7 +18,7 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
try:
from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
default_pg_nccl_timeout: timedelta | None = _DEFAULT_PG_NCCL_TIMEOUT
except ImportError:
# if C++ NCCL support is not compiled, we don't have access to the default nccl value.
# if anyone is actually trying to use nccl in this state, it should error.
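Worth noting for this file: `default_pg_nccl_timeout: timedelta | None = ...` is a module-level annotated assignment, so the annotation is evaluated at import time rather than deferred; on Python 3.10+ the `|` operator on types builds a `types.UnionType`. A small illustration with a hypothetical `default_timeout` name:

import types
from datetime import timedelta

# Evaluated immediately on import (no postponed evaluation assumed here):
default_timeout: timedelta | None = None

assert isinstance(timedelta | None, types.UnionType)
# Union objects built with | also work directly in isinstance checks:
assert isinstance(None, timedelta | None)
assert isinstance(timedelta(seconds=30), timedelta | None)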

View File

@ -65,7 +65,7 @@ else:
"DeviceMesh requires numpy >= 1.21 to be installed for type checking"
)
BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]]
BackendConfig = tuple[str | None, C10dBackend.Options | None]
torch.serialization.add_safe_globals([_MeshLayout])
class _MeshEnv(threading.local):
@ -175,7 +175,7 @@ else:
_device_type: str
_rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
_mesh_dim_names: tuple[str, ...] | None
_layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None
# Record flatten mesh name to its flattened mesh in root mesh.
@ -184,14 +184,14 @@ else:
def __init__(
self,
device_type: str,
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
mesh: Union[torch.Tensor, "ArrayLike"] | None = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[tuple[BackendConfig, ...]] = None,
mesh_dim_names: tuple[str, ...] | None = None,
backend_override: tuple[BackendConfig, ...] | None = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_rank_map: Optional[torch.Tensor] = None,
_rank: int | None = None,
_layout: _MeshLayout | None = None,
_rank_map: torch.Tensor | None = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
# no-op in OSS, logs API usage metrics in meta-internal runs
@ -292,7 +292,7 @@ else:
raise AssertionError(
f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
)
self._coordinate_on_dim: Optional[list[int]] = (
self._coordinate_on_dim: list[int] | None = (
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
@ -317,7 +317,7 @@ else:
)
@property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
def mesh_dim_names(self) -> tuple[str, ...] | None:
"""Returns the names of mesh dimensions."""
return self._mesh_dim_names
@ -378,7 +378,7 @@ else:
rank_map: torch.Tensor,
dim_name: str,
backend_override: BackendConfig,
) -> Optional[str]:
) -> str | None:
# Generate a 2D global mesh tensor for the current dim for PG creation.
pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
backend, pg_options = backend_override
@ -471,7 +471,7 @@ else:
def _init_process_groups(
layout: _MeshLayout,
rank_map: torch.Tensor,
mesh_dim_names: Optional[tuple[str, ...]],
mesh_dim_names: tuple[str, ...] | None,
backend_override: tuple[BackendConfig, ...],
) -> list[str]:
# group_name associated with each mesh dimension, each
@ -543,9 +543,7 @@ else:
and self._thread_id == other._thread_id
)
def __getitem__(
self, mesh_dim_names: Union[str, tuple[str, ...]]
) -> "DeviceMesh":
def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh":
"""
Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh.
The submesh created consists of the dimensions and the communicators indicated by
@ -613,7 +611,7 @@ else:
submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names)
return submesh
def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
def get_group(self, mesh_dim: int | str | None = None) -> ProcessGroup:
"""
Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.
@ -705,7 +703,7 @@ else:
def _create_flatten_mesh(
self,
mesh_dim_name: Optional[str] = None,
mesh_dim_name: str | None = None,
backend_override: BackendConfig = (None, None),
) -> "DeviceMesh":
root_mesh = self._get_root_mesh()
@ -754,7 +752,7 @@ else:
return res_flattened_mesh
def _get_root_mesh_dim(self) -> Optional[int]:
def _get_root_mesh_dim(self) -> int | None:
"""
Returns the index of the mesh dim in the root mesh.
The device_mesh passed in needs to be sliced out from the root mesh
@ -893,11 +891,11 @@ else:
@staticmethod
def from_group(
group: Union[ProcessGroup, list[ProcessGroup]],
group: ProcessGroup | list[ProcessGroup],
device_type: str,
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
mesh: Union[torch.Tensor, "ArrayLike"] | None = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
mesh_dim_names: tuple[str, ...] | None = None,
) -> "DeviceMesh":
"""
Constructs a :class:`DeviceMesh` with ``device_type`` from an
@ -986,7 +984,7 @@ else:
device_mesh._dim_group_names = [group.group_name for group in groups]
return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int:
def size(self, mesh_dim: int | None = None) -> int:
if mesh_dim is not None:
return self._layout[mesh_dim].numel()
return self._layout.numel()
@ -1005,7 +1003,7 @@ else:
"""
return get_rank()
def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
def get_local_rank(self, mesh_dim: int | str | None = None) -> int:
"""
Returns the local rank of the given mesh_dim of the DeviceMesh.
@ -1049,7 +1047,7 @@ else:
)
return not_none(get_rank(mesh_dim_group))
def get_coordinate(self) -> Optional[list[int]]:
def get_coordinate(self) -> list[int] | None:
"""
Return the relative indices of this rank relative to all
dimensions of the mesh. If this rank is not part of the mesh, return None.
@ -1058,10 +1056,11 @@ else:
def _flatten(
self,
mesh_dim_name: Optional[str] = None,
backend_override: Union[
None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
] = None,
mesh_dim_name: str | None = None,
backend_override: None
| str
| C10dBackend.Options
| tuple[str, C10dBackend.Options] = None,
) -> "DeviceMesh":
"""
Returns a 1D DeviceMesh by flattening the current DeviceMesh.
@ -1095,7 +1094,7 @@ else:
mesh_sizes: tuple[int, ...],
mesh_dim_names: tuple[str, ...],
backend_override: tuple[
tuple[Optional[str], Optional[C10dBackend.Options]], ...
tuple[str | None, C10dBackend.Options | None], ...
] = ((None, None),),
) -> "DeviceMesh":
inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))
@ -1140,15 +1139,13 @@ else:
def _unflatten(
self,
dim: Union[int, str],
dim: int | str,
mesh_sizes: tuple[int, ...],
mesh_dim_names: tuple[str, ...],
backend_override: Optional[
dict[
str,
Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
]
] = None,
backend_override: dict[
str, str | C10dBackend.Options | tuple[str, C10dBackend.Options]
]
| None = None,
) -> "DeviceMesh":
"""
Returns a DeviceMesh by unflattening the current DeviceMesh.
@ -1239,11 +1236,11 @@ else:
def _normalize_backend_override(
backend_override: dict[
Union[int, str],
Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
int | str,
str | C10dBackend.Options | tuple[str, C10dBackend.Options],
],
ndim: int,
mesh_dim_names: Optional[tuple[str, ...]] = None,
mesh_dim_names: tuple[str, ...] | None = None,
) -> Iterator[BackendConfig]:
if mesh_dim_names is None:
mesh_dim_names = ()
@ -1278,13 +1275,11 @@ else:
device_type: str,
mesh_shape: tuple[int, ...],
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[
dict[
Union[int, str],
Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
]
] = None,
mesh_dim_names: tuple[str, ...] | None = None,
backend_override: dict[
int | str, str | C10dBackend.Options | tuple[str, C10dBackend.Options]
]
| None = None,
) -> DeviceMesh:
"""
Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
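The `init_device_mesh` signature above shows how deeply nested `Optional[dict[Union[...], Union[...]]]` annotations flatten into one left-to-right chain, with a trailing `| None` standing in for the outer `Optional`. A simplified sketch with invented aliases (`OldOverride`, `NewOverride`):

from typing import Optional, Union

# Before: nested wrappers read inside-out.
OldOverride = Optional[dict[Union[int, str], str]]

# After: the same type as one chain, with "| None" replacing Optional (3.10+).
NewOverride = dict[int | str, str] | None

# Optional[X] has always meant Union[X, None], so the spellings are equivalent:
assert Optional[int] == (int | None)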

View File

@ -17,7 +17,7 @@ import warnings
from collections import namedtuple
from collections.abc import Callable
from datetime import timedelta
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, TYPE_CHECKING
from typing_extensions import deprecated
import torch
@ -309,7 +309,7 @@ class Backend(str): # noqa: SLOT000
name,
func,
extended_api=False,
devices: Optional[Union[str, list[str]]] = None,
devices: str | list[str] | None = None,
) -> None:
"""
Register a new backend with the given name and instantiating function.
@ -504,10 +504,10 @@ class P2POp:
self,
op: Callable,
tensor: torch.Tensor,
peer: Optional[int] = None,
group: Optional[ProcessGroup] = None,
peer: int | None = None,
group: ProcessGroup | None = None,
tag: int = 0,
group_peer: Optional[int] = None,
group_peer: int | None = None,
):
"""Init."""
self.op = op
@ -523,10 +523,10 @@ class P2POp:
cls,
op: Callable,
tensor: torch.Tensor,
peer: Optional[int] = None,
group: Optional[ProcessGroup] = None,
peer: int | None = None,
group: ProcessGroup | None = None,
tag: int = 0,
group_peer: Optional[int] = None,
group_peer: int | None = None,
):
"""Create and return a new instance of the class."""
_check_op(op)
@ -566,9 +566,9 @@ class _CollOp:
self,
op: Callable,
tensor: torch.Tensor,
dst_tensor: Optional[torch.Tensor] = None,
redop: Optional[ReduceOp] = None,
root: Optional[int] = None,
dst_tensor: torch.Tensor | None = None,
redop: ReduceOp | None = None,
root: int | None = None,
):
self.op = op
self.tensor = tensor
@ -587,7 +587,7 @@ _pg_backend_config: dict[ProcessGroup, str] = {}
_group_count = 0
_tags_to_pg: dict[str, list[ProcessGroup]] = {}
_pg_to_tag: dict[ProcessGroup, str] = {}
_backend: Optional[str] = None
_backend: str | None = None
class _World:
@ -605,7 +605,7 @@ class _World:
self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {}
@property
def default_pg(self) -> Optional[ProcessGroup]:
def default_pg(self) -> ProcessGroup | None:
"""
Process group that includes all ranks of the cluster.
@ -730,11 +730,11 @@ class _WorldMeta(type):
# Points to the default PG once initialized.
@property
def WORLD(cls) -> Optional[ProcessGroup]:
def WORLD(cls) -> ProcessGroup | None:
return _world.default_pg
@WORLD.setter
def WORLD(cls, pg: Optional[ProcessGroup]):
def WORLD(cls, pg: ProcessGroup | None):
_world.default_pg = pg
@ -772,12 +772,12 @@ def _check_valid_timeout(timeout: Any) -> None:
# Default process group state
_default_pg_init_method: Optional[str] = None
_default_pg_init_method: str | None = None
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str:
def _get_object_coll_device(group: ProcessGroup | None = None) -> str:
"""
.. note:: This is an internal helper and does not have backward
compatibility, please use with caution.
@ -843,7 +843,7 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str:
return devices[0].type
def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device:
def _get_pg_default_device(group: ProcessGroup | None = None) -> torch.device:
"""
.. note:: This method will be deprecated, it only stays for
backward-compatibility reasons. Alternatives:
@ -923,7 +923,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
return rv
def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]:
def _device_capability(group: ProcessGroup | None = None) -> list[str]:
"""
Return the device type(s) supported by ``group``.
@ -1007,7 +1007,7 @@ def _store_based_barrier(
)
def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool:
def _rank_not_in_group(group: ProcessGroup | None) -> bool:
"""Check if the current process's rank is not in a given group."""
if group is None:
return False
@ -1089,7 +1089,7 @@ def _get_global_rank(group, rank) -> int:
return get_global_rank(group, rank)
def get_process_group_ranks(group: Optional[ProcessGroup]) -> list[int]:
def get_process_group_ranks(group: ProcessGroup | None) -> list[int]:
"""
Get all ranks associated with ``group``.
@ -1148,7 +1148,7 @@ def _check_tensor_list(param, param_name) -> None:
)
def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup:
def _group_or_default_group(group: ProcessGroup | None = None) -> ProcessGroup:
if group is None or group is GroupMember.WORLD:
group = _get_default_group()
return group
@ -1156,8 +1156,8 @@ def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGrou
def _canonicalize_group_rank(
group: ProcessGroup,
global_rank: Optional[int] = None,
group_rank: Optional[int] = None,
global_rank: int | None = None,
group_rank: int | None = None,
return_global: bool = False,
) -> int:
"""
@ -1361,7 +1361,7 @@ def _update_default_pg(pg) -> None:
torch._C._distributed_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
def get_backend_config(group: ProcessGroup | None = None) -> str:
"""
Return the backend configuration of the given process group.
@ -1381,7 +1381,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
return str(not_none(backend_config))
def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
def get_backend(group: ProcessGroup | None = None) -> Backend:
"""
Return the backend of the given process group.
@ -1407,7 +1407,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
return Backend(not_none(pg_store)[0])
def get_default_backend_for_device(device: Union[str, torch.device]) -> str:
def get_default_backend_for_device(device: str | torch.device) -> str:
"""
Return the default backend for the given device.
@ -1441,7 +1441,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int:
return -1
def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]:
def _get_pg_config(group: ProcessGroup | None = None) -> dict[str, Any]:
"""
Return the pg configuration of the given process group.
@ -1473,7 +1473,7 @@ def get_pg_count() -> int:
return _world.group_count
def get_node_local_rank(fallback_rank: Optional[int] = None) -> int:
def get_node_local_rank(fallback_rank: int | None = None) -> int:
"""
Return the local rank of the current process relative to the node.
@ -1526,7 +1526,7 @@ def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None:
backend._add_ephemeral_timeout(timeout)
def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
def _set_pg_timeout(timeout: timedelta, group: ProcessGroup | None = None) -> None:
"""
Set the timeout for the given process group when users want to use a different timeout instead of
default values.
@ -1575,16 +1575,16 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
@_exception_logger
@_time_logger
def init_process_group(
backend: Optional[str] = None,
init_method: Optional[str] = None,
timeout: Optional[timedelta] = None,
backend: str | None = None,
init_method: str | None = None,
timeout: timedelta | None = None,
world_size: int = -1,
rank: int = -1,
store: Optional[Store] = None,
store: Store | None = None,
group_name: str = "",
pg_options: Optional[Any] = None,
device_id: Optional[Union[torch.device, int]] = None,
_ranks: Optional[list[int]] = None,
pg_options: Any | None = None,
device_id: torch.device | int | None = None,
_ranks: list[int] | None = None,
) -> None:
"""
Initialize the default distributed process group.
@ -2216,7 +2216,7 @@ def _new_process_group_helper(
return pg, prefix_store
def destroy_process_group(group: Optional[ProcessGroup] = None):
def destroy_process_group(group: ProcessGroup | None = None):
"""
Destroy a given process group, and deinitialize the distributed package.
@ -2305,7 +2305,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
_unregister_process_group(pg.group_name)
def _abort_process_group(group: Optional[ProcessGroup] = None):
def _abort_process_group(group: ProcessGroup | None = None):
"""
Abort a given process group. If group.WORLD (i.e. `None`) is given, all
process groups including the default one will be aborted.
@ -2397,7 +2397,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
_unregister_process_group(pg.group_name)
def get_rank(group: Optional[ProcessGroup] = None) -> int:
def get_rank(group: ProcessGroup | None = None) -> int:
"""
Return the rank of the current process in the provided ``group``, default otherwise.
@ -2424,7 +2424,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int:
return get_group_rank(group, default_pg.rank())
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
def get_world_size(group: ProcessGroup | None = None) -> int:
"""
Return the number of processes in the current process group.
@ -2445,11 +2445,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:
def isend(
tensor: torch.Tensor,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
dst: int | None = None,
group: ProcessGroup | None = None,
tag: int = 0,
group_dst: Optional[int] = None,
) -> Optional[Work]:
group_dst: int | None = None,
) -> Work | None:
"""
Send a tensor asynchronously.
@ -2490,11 +2490,11 @@ def isend(
def irecv(
tensor: torch.Tensor,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
src: int | None = None,
group: ProcessGroup | None = None,
tag: int = 0,
group_src: Optional[int] = None,
) -> Optional[Work]:
group_src: int | None = None,
) -> Work | None:
"""
Receives a tensor asynchronously.
@ -2536,10 +2536,10 @@ def irecv(
@_exception_logger
def send(
tensor: torch.Tensor,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
dst: int | None = None,
group: ProcessGroup | None = None,
tag: int = 0,
group_dst: Optional[int] = None,
group_dst: int | None = None,
) -> None:
"""
Send a tensor synchronously.
@ -2568,10 +2568,10 @@ def send(
@_exception_logger
def recv(
tensor: torch.Tensor,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
src: int | None = None,
group: ProcessGroup | None = None,
tag: int = 0,
group_src: Optional[int] = None,
group_src: int | None = None,
) -> int:
"""
Receives a tensor synchronously.
@ -2623,7 +2623,7 @@ class _CoalescingManager:
def __init__(self) -> None:
self.works: list[Work] = []
def append(self, work: Optional[Work] = None):
def append(self, work: Work | None = None):
if work:
self.works.append(work)
@ -2634,8 +2634,8 @@ class _CoalescingManager:
@contextlib.contextmanager
def _coalescing_manager(
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group: ProcessGroup | None = None,
device: torch.device | None = None,
async_ops: bool = False,
):
"""
@ -2731,13 +2731,13 @@ def _coalescing_manager(
class _TimeEstimator:
def __init__(self) -> None:
self.estimated_time: Optional[float] = None
self.estimated_time: float | None = None
@contextlib.contextmanager
def _time_estimator(
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group: ProcessGroup | None = None,
device: torch.device | None = None,
):
"""
Context manager used to estimate time of collectives.
@ -2862,10 +2862,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
@_exception_logger
def broadcast(
tensor: torch.Tensor,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
src: int | None = None,
group: ProcessGroup | None = None,
async_op: bool = False,
group_src: Optional[int] = None,
group_src: int | None = None,
):
"""
Broadcasts the tensor to the whole group.
@ -3084,11 +3084,11 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
@_exception_logger
def reduce(
tensor: torch.Tensor,
dst: Optional[int] = None,
dst: int | None = None,
op=ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
group: ProcessGroup | None = None,
async_op: bool = False,
group_dst: Optional[int] = None,
group_dst: int | None = None,
):
"""
Reduces the tensor data across all machines.
@ -3268,10 +3268,10 @@ def all_gather_object(object_list, obj, group=None):
@_exception_logger
def gather_object(
obj: Any,
object_gather_list: Optional[list[Any]] = None,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
group_dst: Optional[int] = None,
object_gather_list: list[Any] | None = None,
dst: int | None = None,
group: ProcessGroup | None = None,
group_dst: int | None = None,
):
"""
Gathers picklable objects from the whole group in a single process.
@ -3399,10 +3399,10 @@ def gather_object(
@_exception_logger
def send_object_list(
object_list: list[Any],
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_dst: Optional[int] = None,
dst: int | None = None,
group: ProcessGroup | None = None,
device: torch.device | None = None,
group_dst: int | None = None,
use_batch: bool = False,
):
"""
@ -3517,10 +3517,10 @@ def send_object_list(
@_exception_logger
def recv_object_list(
object_list: list[Any],
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_src: Optional[int] = None,
src: int | None = None,
group: ProcessGroup | None = None,
device: torch.device | None = None,
group_src: int | None = None,
use_batch: bool = False,
):
"""
@ -3659,10 +3659,10 @@ def recv_object_list(
@_exception_logger
def broadcast_object_list(
object_list: list[Any],
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_src: Optional[int] = None,
src: int | None = None,
group: ProcessGroup | None = None,
device: torch.device | None = None,
group_src: int | None = None,
):
"""
Broadcasts picklable objects in ``object_list`` to the whole group.
@ -3791,10 +3791,10 @@ def broadcast_object_list(
@_exception_logger
def scatter_object_list(
scatter_object_output_list: list[Any],
scatter_object_input_list: Optional[list[Any]] = None,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
group_src: Optional[int] = None,
scatter_object_input_list: list[Any] | None = None,
src: int | None = None,
group: ProcessGroup | None = None,
group_src: int | None = None,
):
"""
Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
@ -4265,11 +4265,11 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
@_exception_logger
def gather(
tensor: torch.Tensor,
gather_list: Optional[list[torch.Tensor]] = None,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
gather_list: list[torch.Tensor] | None = None,
dst: int | None = None,
group: ProcessGroup | None = None,
async_op: bool = False,
group_dst: Optional[int] = None,
group_dst: int | None = None,
):
"""
Gathers a list of tensors in a single process.
@ -4348,11 +4348,11 @@ def gather(
@_exception_logger
def scatter(
tensor: torch.Tensor,
scatter_list: Optional[list[torch.Tensor]] = None,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
scatter_list: list[torch.Tensor] | None = None,
src: int | None = None,
group: ProcessGroup | None = None,
async_op: bool = False,
group_src: Optional[int] = None,
group_src: int | None = None,
):
"""
Scatters a list of tensors to all processes in a group.
@ -4895,7 +4895,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
@_exception_logger
def barrier(
group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None
group: ProcessGroup | None = GroupMember.WORLD, async_op=False, device_ids=None
):
"""
Synchronize all processes.
@ -4967,7 +4967,7 @@ def barrier(
def monitored_barrier(
group: Optional[ProcessGroup] = GroupMember.WORLD,
group: ProcessGroup | None = GroupMember.WORLD,
timeout=None,
wait_all_ranks=False,
):
@ -5104,7 +5104,7 @@ def _process_group_name(ranks, use_hashed_name):
return pg_name
def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
def _get_backend_from_str(backend: str | None = None) -> Backend:
# Default to the same backend as the global process group
# if backend is not specified.
if not backend:
@ -5124,12 +5124,12 @@ def _is_safe_to_split() -> bool:
@_time_logger
def split_group(
parent_pg: Optional[ProcessGroup] = None,
split_ranks: Optional[list] = None,
timeout: Optional[timedelta] = None,
pg_options: Optional[Any] = None,
group_desc: Optional[str] = None,
) -> Optional[ProcessGroup]:
parent_pg: ProcessGroup | None = None,
split_ranks: list | None = None,
timeout: timedelta | None = None,
pg_options: Any | None = None,
group_desc: str | None = None,
) -> ProcessGroup | None:
"""
Create a new process group split from the given parent process group.
@ -5290,7 +5290,7 @@ def new_group(
pg_options=None,
use_local_synchronization=False,
group_desc=None,
device_id: Optional[torch.device] = None,
device_id: torch.device | None = None,
):
"""
Create a new distributed group.
@ -5380,7 +5380,7 @@ def _new_group_with_tag(
pg_tag=None,
use_local_synchronization=False,
group_desc=None,
device_id: Optional[torch.device] = None,
device_id: torch.device | None = None,
):
"""
Variant of ``new_group`` that exposes tag creation.
@ -5693,7 +5693,7 @@ def new_subgroups_by_enumeration(
return cur_subgroup, subgroups
def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]:
def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> ProcessGroup | None:
if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
tag = f"user:{tag}"
@ -5765,9 +5765,9 @@ SHRINK_ABORT = 0x01
@_time_logger
def shrink_group(
ranks_to_exclude: list[int],
group: Optional[ProcessGroup] = None,
group: ProcessGroup | None = None,
shrink_flags: int = SHRINK_DEFAULT,
pg_options: Optional[Any] = None,
pg_options: Any | None = None,
) -> ProcessGroup:
"""
Shrinks a process group by excluding specified ranks.
@ -5857,7 +5857,7 @@ def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> N
)
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
def _prepare_shrink_target_group(group: ProcessGroup | None) -> dict:
"""Prepare and validate the target group for shrinking."""
target_pg = group if group is not None else _get_default_group()
@ -6107,7 +6107,7 @@ def _create_shrunk_process_group(
return new_pg
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
def _destroy_all_other_groups(exclude_group: ProcessGroup | None = None) -> None:
"""
Destroy all process groups except the excluded group and clean up all global state.
@ -6223,9 +6223,9 @@ def _update_process_group_global_state(
store: Store,
group_name: str,
backend_config: str,
rank_mapping: Optional[dict[int, int]] = None,
pg_tag: Optional[str] = None,
user_tag: Optional[str] = None,
rank_mapping: dict[int, int] | None = None,
pg_tag: str | None = None,
user_tag: str | None = None,
) -> None:
"""
Update all global state dictionaries for a process group.

View File

@ -19,7 +19,7 @@ from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional, Union
from typing import Any
import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util
@ -89,19 +89,19 @@ class WorkerSpec:
role: str
local_world_size: int
rdzv_handler: rdzv.RendezvousHandler
fn: Optional[Callable] = None
fn: Callable | None = None
# TODO @kiuk - make entrypoint a required field
entrypoint: Union[Callable, str, None] = None
entrypoint: Callable | str | None = None
args: tuple = ()
max_restarts: int = 3
monitor_interval: float = 0.1
master_port: Optional[int] = None
master_addr: Optional[str] = None
local_addr: Optional[str] = None
master_port: int | None = None
master_addr: str | None = None
local_addr: str | None = None
event_log_handler: str = "null"
numa_options: Optional[NumaOptions] = None
duplicate_stdout_filters: Optional[list[str]] = None
duplicate_stderr_filters: Optional[list[str]] = None
numa_options: NumaOptions | None = None
duplicate_stdout_filters: list[str] | None = None
duplicate_stderr_filters: list[str] | None = None
virtual_local_rank: bool = False
def __post_init__(self):
@ -807,11 +807,11 @@ class SimpleElasticAgent(ElasticAgent):
self,
state: str,
source: EventSource,
worker: Optional[Worker] = None,
raw_error: Optional[str] = None,
duration_ms: Optional[float] = None,
exit_code: Optional[int] = None,
worker_pid: Optional[int] = None,
worker: Worker | None = None,
raw_error: str | None = None,
duration_ms: float | None = None,
exit_code: int | None = None,
worker_pid: int | None = None,
) -> Event:
wg = self._worker_group
spec = wg.spec

View File

@ -15,7 +15,7 @@ import socket
import time
import uuid
from string import Template
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
@ -152,16 +152,16 @@ class LocalElasticAgent(SimpleElasticAgent):
logs_specs: LogsSpecs,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_line_prefix_template: Optional[str] = None,
log_line_prefix_template: str | None = None,
):
super().__init__(spec, exit_barrier_timeout)
self._start_method = start_method
self._pcontext: Optional[PContext] = None
self._pcontext: PContext | None = None
self._rdzv_handler = spec.rdzv_handler
self._log_line_prefix_template = log_line_prefix_template
self._worker_watchdog: Optional[timer.FileTimerServer] = None
self._worker_watchdog: timer.FileTimerServer | None = None
self._logs_specs = logs_specs
self._health_check_server: Optional[HealthCheckServer] = None
self._health_check_server: HealthCheckServer | None = None
def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
@ -244,7 +244,7 @@ class LocalElasticAgent(SimpleElasticAgent):
def _log_watchdog_event(
self,
name: str,
request: Optional[timer.FileTimerRequest],
request: timer.FileTimerRequest | None,
) -> None:
wg = self._worker_group
spec = wg.spec
@ -297,7 +297,7 @@ class LocalElasticAgent(SimpleElasticAgent):
args: dict[int, tuple] = {}
envs: dict[int, dict[str, str]] = {}
log_line_prefixes: Optional[dict[int, str]] = (
log_line_prefixes: dict[int, str] | None = (
{} if self._log_line_prefix_template else None
)
for worker in worker_group.workers:

View File

@ -86,10 +86,10 @@ def construct_and_record_rdzv_event(
node_state: NodeState,
name: str = "",
hostname: str = "",
pid: Optional[int] = None,
pid: int | None = None,
master_endpoint: str = "",
local_id: Optional[int] = None,
rank: Optional[int] = None,
local_id: int | None = None,
rank: int | None = None,
) -> None:
"""
Initialize rendezvous event object and record its operations.

View File

@ -10,7 +10,7 @@
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Optional, Union
from typing import Union
__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
@ -95,8 +95,8 @@ class RdzvEvent:
pid: int
node_state: NodeState
master_endpoint: str = ""
rank: Optional[int] = None
local_id: Optional[int] = None
rank: int | None = None
local_id: int | None = None
error_trace: str = ""
def __str__(self):

View File

@ -158,7 +158,7 @@ from .api import ( # noqa: F401
)
def initialize_metrics(cfg: Optional[MetricsConfig] = None):
def initialize_metrics(cfg: MetricsConfig | None = None):
pass

View File

@ -11,7 +11,6 @@ import abc
import time
from collections import namedtuple
from functools import wraps
from typing import Optional
from typing_extensions import deprecated
@ -37,7 +36,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
class MetricsConfig:
__slots__ = ["params"]
def __init__(self, params: Optional[dict[str, str]] = None):
def __init__(self, params: dict[str, str] | None = None):
self.params = params
if self.params is None:
self.params = {}
@ -77,7 +76,7 @@ _default_metrics_handler: MetricHandler = NullMetricHandler()
# pyre-fixme[9]: group has type `str`; used as `None`.
def configure(handler: MetricHandler, group: Optional[str] = None):
def configure(handler: MetricHandler, group: str | None = None):
if group is None:
global _default_metrics_handler
# pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used

View File

@ -102,15 +102,15 @@ __all__ = [
def start_processes(
name: str,
entrypoint: Union[Callable, str],
entrypoint: Callable | str,
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
log_line_prefixes: dict[int, str] | None = None,
start_method: str = "spawn",
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
numa_options: NumaOptions | None = None,
duplicate_stdout_filters: list[str] | None = None,
duplicate_stderr_filters: list[str] | None = None,
) -> PContext:
"""
Start ``n`` copies of ``entrypoint`` processes with the provided options.

View File

@ -25,7 +25,7 @@ from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Optional, TextIO, Union
from typing import Any, TextIO, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@ -73,7 +73,7 @@ class SignalException(Exception):
self.sigval = sigval
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
def _terminate_process_handler(signum: int, frame: FrameType | None) -> None:
"""Termination handler that raises exceptions on the main process.
When the process receives a death signal (SIGTERM, SIGINT), this termination handler will
@ -156,9 +156,7 @@ class Std(IntFlag):
)
def to_map(
val_or_map: Union[Std, dict[int, Std]], local_world_size: int
) -> dict[int, Std]:
def to_map(val_or_map: Std | dict[int, Std], local_world_size: int) -> dict[int, Std]:
"""
Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This method is a convenience
@ -216,10 +214,10 @@ class LogsSpecs(ABC):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
log_dir: str | None = None,
redirects: Std | dict[int, Std] = Std.NONE,
tee: Std | dict[int, Std] = Std.NONE,
local_ranks_filter: set[int] | None = None,
) -> None:
self._root_log_dir = log_dir
self._redirects = redirects
@ -254,10 +252,10 @@ class DefaultLogsSpecs(LogsSpecs):
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, dict[int, Std]] = Std.NONE,
tee: Union[Std, dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[set[int]] = None,
log_dir: str | None = None,
redirects: Std | dict[int, Std] = Std.NONE,
tee: Std | dict[int, Std] = Std.NONE,
local_ranks_filter: set[int] | None = None,
) -> None:
if log_dir != os.devnull:
if not log_dir:
@ -275,7 +273,7 @@ class DefaultLogsSpecs(LogsSpecs):
def root_log_dir(self) -> str:
return str(self._root_log_dir)
def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
def _make_log_dir(self, log_dir: str | None, rdzv_run_id: str):
base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
os.makedirs(base_log_dir, exist_ok=True)
dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
@ -465,13 +463,13 @@ class PContext(abc.ABC):
def __init__(
self,
name: str,
entrypoint: Union[Callable, str],
entrypoint: Callable | str,
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
log_line_prefixes: dict[int, str] | None = None,
duplicate_stdout_filters: list[str] | None = None,
duplicate_stderr_filters: list[str] | None = None,
):
self.name = name
# validate that all mappings have the same number of keys and
@ -491,8 +489,8 @@ class PContext(abc.ABC):
self.stderrs = logs_dest.stderrs
self.error_files = logs_dest.error_files
self.nprocs = nprocs
self.filtered_stdout: Optional[TextIO] = None
self.filtered_stderr: Optional[TextIO] = None
self.filtered_stdout: TextIO | None = None
self.filtered_stderr: TextIO | None = None
self._tail_logs = [
TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes),
@ -582,7 +580,7 @@ class PContext(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def _poll(self) -> Optional[RunProcsResult]:
def _poll(self) -> RunProcsResult | None:
"""
Poll the run status of the processes running under this context.
This method follows an "all-or-nothing" policy and returns
@ -592,7 +590,7 @@ class PContext(abc.ABC):
"""
raise NotImplementedError
def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
def wait(self, timeout: float = -1, period: float = 1) -> RunProcsResult | None:
"""
Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
for the processes to be done. Returns ``None`` if the processes are still running
@ -646,9 +644,7 @@ class PContext(abc.ABC):
"""
raise NotImplementedError
def close(
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
) -> None:
def close(self, death_sig: signal.Signals | None = None, timeout: int = 30) -> None:
r"""
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
@ -685,7 +681,7 @@ def _wrap(
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
ret_vals: dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
numa_options: Optional[NumaOptions],
numa_options: NumaOptions | None,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
args_ = args[local_rank]
@ -721,10 +717,10 @@ class MultiprocessContext(PContext):
envs: dict[int, dict[str, str]],
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
log_line_prefixes: dict[int, str] | None = None,
numa_options: NumaOptions | None = None,
duplicate_stdout_filters: list[str] | None = None,
duplicate_stderr_filters: list[str] | None = None,
):
super().__init__(
name,
@ -746,12 +742,12 @@ class MultiprocessContext(PContext):
# see comments in ``join()`` for what this is
self._return_values: dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
self._pc: mp.ProcessContext | None = None
# Note: set method should ONLY be invoked for the use case when all processes finished
# successfully. If any process died on event.wait() calling set() method will deadlock.
self._worker_finished_event = mp.get_context(self.start_method).Event()
self._numa_options: Optional[NumaOptions] = numa_options
self._numa_options: NumaOptions | None = numa_options
def _start(self):
if self._pc:
@ -780,7 +776,7 @@ class MultiprocessContext(PContext):
def _is_done(self) -> bool:
return len(self._return_values) == self.nprocs
def _poll(self) -> Optional[RunProcsResult]:
def _poll(self) -> RunProcsResult | None:
assert self._pc is not None # assertion for mypy type checker
try:
@ -910,10 +906,10 @@ class SubprocessContext(PContext):
args: dict[int, tuple],
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
log_line_prefixes: dict[int, str] | None = None,
numa_options: NumaOptions | None = None,
duplicate_stdout_filters: list[str] | None = None,
duplicate_stderr_filters: list[str] | None = None,
):
super().__init__(
name,
@ -930,7 +926,7 @@ class SubprocessContext(PContext):
self._running_local_ranks: set[int] = set(range(self.nprocs))
self._failures: dict[int, ProcessFailure] = {}
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
self._numa_options: Optional[NumaOptions] = numa_options
self._numa_options: NumaOptions | None = numa_options
def _start(self):
if self.subprocess_handlers:
@ -965,7 +961,7 @@ class SubprocessContext(PContext):
)
# else: --> succeeded; nothing to do
def _poll(self) -> Optional[RunProcsResult]:
def _poll(self) -> RunProcsResult | None:
done_local_ranks: set[int] = set()
self._capture_process_failures(done_local_ranks)

View File

@ -312,8 +312,8 @@ class ChildFailedError(Exception):
def record(
fn: Callable[_P, _R], error_handler: Optional[ErrorHandler] = None
) -> Callable[_P, Union[_R, None]]:
fn: Callable[_P, _R], error_handler: ErrorHandler | None = None
) -> Callable[_P, _R | None]:
"""
Syntactic sugar to record errors/exceptions that happened in the decorated
function using the provided ``error_handler``.
@ -353,7 +353,7 @@ def record(
if not error_handler:
error_handler = get_error_handler()
def wrap(f: Callable[_P, _R]) -> Callable[_P, Union[_R, None]]:
def wrap(f: Callable[_P, _R]) -> Callable[_P, _R | None]:
@wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs):
assert error_handler is not None # assertion for mypy type checker
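`record` above combines the new unions with `ParamSpec`: the wrapper keeps the original parameters and widens the return type to `_R | None`. A runnable sketch of the same shape, using a hypothetical `swallow_errors` decorator rather than torchelastic's `record`:

from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")

def swallow_errors(fn: Callable[_P, _R]) -> Callable[_P, _R | None]:
    """Return a wrapper with the same parameters that yields None on failure."""
    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R | None:
        try:
            return fn(*args, **kwargs)
        except Exception:
            return None
    return wrapper

@swallow_errors
def parse_port(raw: str) -> int:
    return int(raw)

assert parse_port("29500") == 29500
assert parse_port("not-a-port") is None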

View File

@ -13,7 +13,7 @@ import os
import time
import traceback
import warnings
from typing import Any, Optional
from typing import Any
__all__ = ["ErrorHandler"]
@ -33,7 +33,7 @@ class ErrorHandler:
Subclasses should override ``initialize()`` and ``record_exception()``.
"""
def _get_error_file_path(self) -> Optional[str]:
def _get_error_file_path(self) -> str | None:
"""
Return the error file path.

View File

@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
@ -21,7 +20,7 @@ def get_subprocess_handler(
stdout: str,
stderr: str,
local_rank_id: int,
numa_options: Optional[NumaOptions] = None,
numa_options: NumaOptions | None = None,
) -> SubprocessHandler:
return SubprocessHandler(
entrypoint=entrypoint,

View File

@ -9,7 +9,7 @@ import os
import signal
import sys
from subprocess import Popen
from typing import Any, Optional
from typing import Any
from torch.numa.binding import maybe_wrap_command_args_with_numa_binding, NumaOptions
@ -38,10 +38,10 @@ class SubprocessHandler:
entrypoint: str,
args: tuple,
env: dict[str, str],
stdout: Optional[str],
stderr: Optional[str],
stdout: str | None,
stderr: str | None,
local_rank_id: int,
numa_options: Optional[NumaOptions],
numa_options: NumaOptions | None,
):
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
@ -76,7 +76,7 @@ class SubprocessHandler:
**kwargs,
)
def close(self, death_sig: Optional[signal.Signals] = None) -> None:
def close(self, death_sig: signal.Signals | None = None) -> None:
if not death_sig:
death_sig = _get_default_signal()
if IS_WINDOWS:

View File

@ -13,7 +13,7 @@ import time
from collections.abc import Callable
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Optional, TextIO, TYPE_CHECKING
from typing import TextIO, TYPE_CHECKING
if TYPE_CHECKING:
@ -30,7 +30,7 @@ def tail_logfile(
dst: TextIO,
finished: Event,
interval_sec: float,
log_line_filter: Optional[Callable[[str], bool]] = None,
log_line_filter: Callable[[str], bool] | None = None,
):
while not os.path.exists(file):
if finished.is_set():
@ -98,7 +98,7 @@ class TailLog:
name: str,
log_files: dict[int, str],
dst: TextIO,
log_line_prefixes: Optional[dict[int, str]] = None,
log_line_prefixes: dict[int, str] | None = None,
interval_sec: float = 0.1,
log_line_filter: Callable[[str], bool] = (lambda _: True),
):

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional
from typing import Any
"""
@ -65,11 +65,11 @@ class Client:
raise EtcdStubError
def write(
self, key: str, value: Any, ttl: Optional[int] = None, **kwargs: Any
self, key: str, value: Any, ttl: int | None = None, **kwargs: Any
) -> None:
raise EtcdStubError
def test_and_set(
self, key: str, value: Any, prev_value: Any, ttl: Optional[int] = None
self, key: str, value: Any, prev_value: Any, ttl: int | None = None
) -> None:
raise EtcdStubError

View File

@ -9,7 +9,7 @@ import socket
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar
from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port
@ -72,8 +72,8 @@ class RendezvousStoreInfo:
def build(
rank: int,
store: Store,
local_addr: Optional[str],
server_port: Optional[int] = None,
local_addr: str | None,
server_port: int | None = None,
) -> "RendezvousStoreInfo":
"""Factory method, finds unused new port on rank0 host and addr/port info with all ranks.
@ -137,7 +137,7 @@ class RendezvousInfo:
return self._world_size
@property
def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
def bootstrap_store_info(self) -> RendezvousStoreInfo | None:
"""Store information that can used by trainer code to bootstrap distributed comms."""
return self._bootstrap_store_info
@ -265,7 +265,7 @@ class RendezvousParameters:
run_id: str,
min_nodes: int,
max_nodes: int,
local_addr: Optional[str] = None,
local_addr: str | None = None,
**kwargs,
):
if not backend:
@ -293,7 +293,7 @@ class RendezvousParameters:
"""Return the value for ``key`` if ``key`` exists, else ``default``."""
return self.config.get(key, default)
def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
def get_as_bool(self, key: str, default: bool | None = None) -> bool | None:
"""Return the value for ``key`` as a ``bool``."""
value = self.get(key, default)
if value is None or isinstance(value, bool):
@ -312,7 +312,7 @@ class RendezvousParameters:
f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
)
def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
def get_as_int(self, key: str, default: int | None = None) -> int | None:
"""Return the value for ``key`` as an ``int``."""
value = self.get(key, default)
if value is None:
@ -350,7 +350,7 @@ class RendezvousHandlerRegistry:
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")
current_creator: Optional[RendezvousHandlerCreator]
current_creator: RendezvousHandlerCreator | None
try:
current_creator = self._registry[backend]
except KeyError:

View File

@ -11,7 +11,7 @@ import os
import tempfile
from base64 import b64decode, b64encode
from datetime import timedelta
from typing import Any, cast, Optional
from typing import Any, cast
from torch.distributed import FileStore, Store, TCPStore
from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
@ -70,15 +70,15 @@ class C10dRendezvousBackend(RendezvousBackend):
"""See base class."""
return "c10d"
def get_state(self) -> Optional[tuple[bytes, Token]]:
def get_state(self) -> tuple[bytes, Token] | None:
"""See base class."""
base64_state: bytes = self._call_store("get", self._key)
return self._decode_state(base64_state)
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[tuple[bytes, Token, bool]]:
self, state: bytes, token: Token | None = None
) -> tuple[bytes, Token, bool] | None:
"""See base class."""
base64_state_str: str = b64encode(state).decode()
@ -117,7 +117,7 @@ class C10dRendezvousBackend(RendezvousBackend):
"The connection to the C10d store has failed. See inner exception for details."
) from exc
def _decode_state(self, base64_state: bytes) -> Optional[tuple[bytes, Token]]:
def _decode_state(self, base64_state: bytes) -> tuple[bytes, Token] | None:
if base64_state == self._NULL_SENTINEL.encode():
return None

View File

@ -18,7 +18,7 @@ from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Optional
from typing import Any
import torch.distributed as dist
from torch.distributed import Store
@ -68,7 +68,7 @@ class RendezvousBackend(ABC):
"""Get the name of the backend."""
@abstractmethod
def get_state(self) -> Optional[tuple[bytes, Token]]:
def get_state(self) -> tuple[bytes, Token] | None:
"""Get the rendezvous state.
Returns:
@ -84,8 +84,8 @@ class RendezvousBackend(ABC):
@abstractmethod
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[tuple[bytes, Token, bool]]:
self, state: bytes, token: Token | None = None
) -> tuple[bytes, Token, bool] | None:
"""Set the rendezvous state.
The new rendezvous state is set conditionally:
@ -154,10 +154,10 @@ class RendezvousTimeout:
def __init__(
self,
join: Optional[timedelta] = None,
last_call: Optional[timedelta] = None,
close: Optional[timedelta] = None,
heartbeat: Optional[timedelta] = None,
join: timedelta | None = None,
last_call: timedelta | None = None,
close: timedelta | None = None,
heartbeat: timedelta | None = None,
) -> None:
self._set_timeouts(
join=join, last_call=last_call, close=close, heartbeat=heartbeat
@ -183,7 +183,7 @@ class RendezvousTimeout:
"""Get the keep-alive heartbeat timeout."""
return self._heartbeat
def _set_timeouts(self, **timeouts: Optional[timedelta]):
def _set_timeouts(self, **timeouts: timedelta | None):
for name, timeout in timeouts.items():
if timeout is None:
timeout = self._DEFAULT_TIMEOUTS[name]
@ -258,7 +258,7 @@ class _NodeDescGenerator:
# An integer that is incremented with each call to generate().
self._local_id = 0
def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
def generate(self, local_addr: str | None = None) -> _NodeDesc:
# This method can be called by multiple threads concurrently; therefore,
# we must increment the integer atomically.
with self._lock:
@ -297,7 +297,7 @@ class _RendezvousState:
round: int
complete: bool
deadline: Optional[datetime]
deadline: datetime | None
closed: bool
participants: dict[_NodeDesc, int]
wait_list: set[_NodeDesc]
@ -345,7 +345,7 @@ class _RendezvousStateHolder(ABC):
"""Get the local state."""
@abstractmethod
def sync(self) -> Optional[bool]:
def sync(self) -> bool | None:
"""Read or writes the latest state.
Returns:
@ -408,13 +408,13 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
"""See base class."""
return self._state
def sync(self) -> Optional[bool]:
def sync(self) -> bool | None:
"""See base class."""
state_bits: Optional[bytes] = None
state_bits: bytes | None = None
token = None
has_set: Optional[bool]
has_set: bool | None
if self._dirty:
has_set = False
@ -574,7 +574,7 @@ class _RendezvousOpExecutor(ABC):
self,
state_handler: Callable[[_RendezvousContext, float], _Action],
deadline: float,
update_deadline: Optional[Callable[[timedelta], float]] = None,
update_deadline: Callable[[timedelta], float] | None = None,
) -> None:
"""Execute a rendezvous operation.
@ -638,7 +638,7 @@ class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
self,
state_handler: Callable[[_RendezvousContext, float], _Action],
deadline: float,
update_deadline: Optional[Callable[[timedelta], float]] = None,
update_deadline: Callable[[timedelta], float] | None = None,
) -> None:
"""See base class."""
action = None
@ -1006,7 +1006,7 @@ class DynamicRendezvousHandler(RendezvousHandler):
_state_holder: _RendezvousStateHolder
_op_executor: _RendezvousOpExecutor
_heartbeat_lock: threading.Lock
_keep_alive_timer: Optional[_PeriodicTimer]
_keep_alive_timer: _PeriodicTimer | None
@classmethod
def from_backend(
@ -1016,8 +1016,8 @@ class DynamicRendezvousHandler(RendezvousHandler):
backend: RendezvousBackend,
min_nodes: int,
max_nodes: int,
local_addr: Optional[str] = None,
timeout: Optional[RendezvousTimeout] = None,
local_addr: str | None = None,
timeout: RendezvousTimeout | None = None,
keep_alive_interval: int = 5,
keep_alive_max_attempt: int = 3,
):
@ -1102,15 +1102,15 @@ class DynamicRendezvousHandler(RendezvousHandler):
self._keep_alive_timer = None
# Cached shared store server reference
self._shared_tcp_store_server: Optional[dist.Store] = None
self._shared_tcp_store_server: dist.Store | None = None
self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None
self._bootstrap_store_info: RendezvousStoreInfo | None = None
def _record(
self,
message: str,
node_state: NodeState = NodeState.RUNNING,
rank: Optional[int] = None,
rank: int | None = None,
) -> None:
construct_and_record_rdzv_event(
name=f"{self.__class__.__name__}.{get_method_name()}",
@ -1379,7 +1379,7 @@ class DynamicRendezvousHandler(RendezvousHandler):
return time.monotonic() + timeout.total_seconds()
def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]:
def _get_timeout(params: RendezvousParameters, key: str) -> timedelta | None:
timeout = params.get_as_int(key + "_timeout")
if timeout is None:
return None

View File

@ -12,7 +12,6 @@ import logging
import sys
import threading
import time
from typing import Optional
try:
@ -153,7 +152,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
+--------------------------------------------+--------------------------+
"""
def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]):
def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: str | None):
"""
Args:
rdzv_impl: the implementation of the rendezvous
@ -542,7 +541,7 @@ class EtcdRendezvous:
# When reaching min workers, or changing state to frozen, we'll set
# the active_version node to be ephemeral.
set_ttl: Optional[int] = None
set_ttl: int | None = None
if len(state["participants"]) == self._num_max_workers:
state["status"] = "frozen"
state["keep_alives"] = []

View File

@ -7,7 +7,7 @@
import binascii
from base64 import b64decode, b64encode
from typing import cast, Optional
from typing import cast
import urllib3.exceptions # type: ignore[import]
@ -49,8 +49,8 @@ class EtcdRendezvousBackend(RendezvousBackend):
self,
client: etcd.Client,
run_id: str,
key_prefix: Optional[str] = None,
ttl: Optional[int] = None,
key_prefix: str | None = None,
ttl: int | None = None,
) -> None:
if not run_id:
raise ValueError("The run id must be a non-empty string.")
@ -72,7 +72,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
"""See base class."""
return "etcd-v2"
def get_state(self) -> Optional[tuple[bytes, Token]]:
def get_state(self) -> tuple[bytes, Token] | None:
"""See base class."""
try:
result = self._client.read(self._key)
@ -86,8 +86,8 @@ class EtcdRendezvousBackend(RendezvousBackend):
return self._decode_state(result)
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[tuple[bytes, Token, bool]]:
self, state: bytes, token: Token | None = None
) -> tuple[bytes, Token, bool] | None:
"""See base class."""
base64_state = b64encode(state).decode()

View File

@ -15,7 +15,7 @@ import socket
import subprocess
import tempfile
import time
from typing import Optional, TextIO, Union
from typing import TextIO
try:
@ -64,7 +64,7 @@ def find_free_port():
raise RuntimeError("Failed to create a socket")
def stop_etcd(subprocess, data_dir: Optional[str] = None):
def stop_etcd(subprocess, data_dir: str | None = None):
if subprocess and subprocess.poll() is None:
logger.info("stopping etcd server")
subprocess.terminate()
@ -107,7 +107,7 @@ class EtcdServer:
etcd_binary_path: path of etcd server binary (see above for fallback path)
"""
def __init__(self, data_dir: Optional[str] = None):
def __init__(self, data_dir: str | None = None):
self._port = -1
self._host = "localhost"
@ -123,7 +123,7 @@ class EtcdServer:
data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
)
self._etcd_cmd = None
self._etcd_proc: Optional[subprocess.Popen] = None
self._etcd_proc: subprocess.Popen | None = None
def _get_etcd_server_process(self) -> subprocess.Popen:
if not self._etcd_proc:
@ -149,7 +149,7 @@ class EtcdServer:
self,
timeout: int = 60,
num_retries: int = 3,
stderr: Union[int, TextIO, None] = None,
stderr: int | TextIO | None = None,
) -> None:
"""
Start the server and wait for it to be ready. When this function returns, the server is ready to take requests.
@ -185,7 +185,7 @@ class EtcdServer:
atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
def _start(
self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
self, data_dir: str, timeout: int = 60, stderr: int | TextIO | None = None
) -> None:
sock = find_free_port()
sock_peer = find_free_port()
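As context for the ``stderr: int | TextIO | None`` parameter above, a hedged lifecycle sketch for ``EtcdServer`` (``stop()`` and ``get_endpoint()`` are assumed companion methods; they are not shown in this hunk):

import subprocess

server = EtcdServer()                    # data_dir defaults to a temporary directory
server.start(timeout=60, num_retries=3, stderr=subprocess.DEVNULL)
try:
    endpoint = server.get_endpoint()     # assumed accessor returning "host:port"
    ...                                  # run a rendezvous against the embedded etcd
finally:
    server.stop()                        # assumed counterpart of start()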

View File

@ -9,7 +9,6 @@ import datetime
import random
import time
from base64 import b64decode, b64encode
from typing import Optional
# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
from torch.distributed import Store
@ -40,7 +39,7 @@ class EtcdStore(Store):
etcd_client,
etcd_store_prefix,
# Default timeout same as in c10d/Store.hpp
timeout: Optional[datetime.timedelta] = None,
timeout: datetime.timedelta | None = None,
):
super().__init__() # required for pybind trampoline.
@ -121,7 +120,7 @@ class EtcdStore(Store):
except etcd.EtcdCompareFailed:
cas_delay()
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
def wait(self, keys, override_timeout: datetime.timedelta | None = None):
"""
Wait until all of the keys are published, or until timeout.

View File

@ -9,7 +9,7 @@
import datetime
import logging
from typing import cast, Optional
from typing import cast
from torch.distributed import PrefixStore, Store, TCPStore
from torch.distributed.elastic.rendezvous import (
@ -51,7 +51,7 @@ class StaticTCPRendezvous(RendezvousHandler):
self.world_size = world_size
self.run_id = run_id
self.timeout = datetime.timedelta(seconds=timeout)
self._store: Optional[Store] = None
self._store: Store | None = None
def get_backend(self) -> str:
return "static"

View File

@ -14,7 +14,7 @@ import weakref
from collections.abc import Callable
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Optional, Union
from typing import Any
__all__ = ["parse_rendezvous_endpoint"]
@ -44,7 +44,7 @@ def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
"<key1>=<value1>,...,<keyN>=<valueN>."
)
value: Optional[str]
value: str | None
if values:
value = values[0].strip()
else:
@ -58,7 +58,7 @@ def _parse_rendezvous_config(config_str: str) -> dict[str, str]:
return config
def _try_parse_port(port_str: str) -> Optional[int]:
def _try_parse_port(port_str: str) -> int | None:
"""Try to extract the port number from ``port_str``."""
if port_str and re.match(r"^[0-9]{1,5}$", port_str):
return int(port_str)
@ -66,7 +66,7 @@ def _try_parse_port(port_str: str) -> Optional[int]:
def parse_rendezvous_endpoint(
endpoint: Optional[str], default_port: int
endpoint: str | None, default_port: int
) -> tuple[str, int]:
"""Extract the hostname and the port number from a rendezvous endpoint.
@ -166,7 +166,7 @@ def _matches_machine_hostname(host: str) -> bool:
return False
def _delay(seconds: Union[float, tuple[float, float]]) -> None:
def _delay(seconds: float | tuple[float, float]) -> None:
"""Suspend the current thread for ``seconds``.
Args:
@ -200,9 +200,9 @@ class _PeriodicTimer:
kwargs: dict[str, Any]
stop_event: Event
_name: Optional[str]
_thread: Optional[Thread]
_finalizer: Optional[weakref.finalize]
_name: str | None
_thread: Thread | None
_finalizer: weakref.finalize | None
# The context that is shared between the timer and the background thread.
_ctx: _Context
@ -227,7 +227,7 @@ class _PeriodicTimer:
self._finalizer = None
@property
def name(self) -> Optional[str]:
def name(self) -> str | None:
"""Get the name of the timer."""
return self._name
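For ``parse_rendezvous_endpoint`` earlier in this file, a short sketch of the intended behavior (inferred from the docstring, not re-verified):

host, port = parse_rendezvous_endpoint("node0.example.com:29500", default_port=29400)
# host == "node0.example.com", port == 29500

host, port = parse_rendezvous_endpoint(None, default_port=29400)
# With no endpoint given, default_port (29400) is used and the host falls back to a
# local default (assumed to be "localhost").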

View File

@ -10,7 +10,7 @@ import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Optional
from typing import Any
__all__ = [
@ -130,7 +130,7 @@ class TimerServer(abc.ABC):
self._request_queue = request_queue
self._max_interval = max_interval
self._daemon = daemon
self._watchdog_thread: Optional[threading.Thread] = None
self._watchdog_thread: threading.Thread | None = None
self._stop_signaled = False
@abc.abstractmethod
@ -234,7 +234,7 @@ class TimerServer(abc.ABC):
logger.info("No watchdog thread running, doing nothing")
_timer_client: Optional[TimerClient] = None
_timer_client: TimerClient | None = None
def configure(timer_client: TimerClient):
@ -247,9 +247,7 @@ def configure(timer_client: TimerClient):
@contextmanager
def expires(
after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
def expires(after: float, scope: str | None = None, client: TimerClient | None = None):
"""
Acquires a countdown timer that expires in ``after`` seconds from now,
unless the code-block that it wraps is finished within the timeframe.
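A hedged usage sketch of ``expires`` (a ``TimerClient`` must be registered through ``configure`` first; the client class and workload below are assumptions):

# configure(LocalTimerClient(mp_queue))       # registration done once at startup

with expires(after=60, scope="allreduce"):
    run_all_reduce()                          # hypothetical workload; must finish within 60s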

View File

@ -14,7 +14,7 @@ import sys
import threading
import time
from collections.abc import Callable
from typing import Optional, TypeVar
from typing import TypeVar
from typing_extensions import ParamSpec
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
@ -131,7 +131,7 @@ class FileTimerClient(TimerClient):
self.signal = signal
@_retry(max_retries=10, sleep_time=0.1)
def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
def _open_non_blocking(self) -> io.TextIOWrapper | None:
# The server may have crashed or may not have started yet.
# In that case, calling open() in blocking mode blocks the client.
# To avoid this issue, open it in non-blocking mode, and an OSError will
@ -200,7 +200,7 @@ class FileTimerServer:
run_id: str,
max_interval: float = 10,
daemon: bool = True,
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None,
log_event: Callable[[str, FileTimerRequest | None], None] | None = None,
) -> None:
self._file_path = file_path
self._run_id = run_id
@ -208,7 +208,7 @@ class FileTimerServer:
self._daemon = daemon
self._timers: dict[tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None
self._watchdog_thread: threading.Thread | None = None
self._is_client_started = False
if os.path.exists(self._file_path):

View File

@ -8,7 +8,7 @@
import math
from collections.abc import Iterator, Sized
from typing import cast, Optional, TypeVar
from typing import cast, TypeVar
import torch
from torch.utils.data import Dataset
@ -44,8 +44,8 @@ class ElasticDistributedSampler(DistributedSampler[T]):
def __init__(
self,
dataset: Dataset[T],
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
num_replicas: int | None = None,
rank: int | None = None,
start_index: int = 0,
):
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
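A hedged construction sketch for ``ElasticDistributedSampler`` with the new ``int | None`` defaults (``train_dataset`` is a placeholder; leaving ``num_replicas``/``rank`` as ``None`` is assumed to defer to the initialized process group, following ``DistributedSampler`` semantics):

from torch.utils.data import DataLoader

sampler = ElasticDistributedSampler(train_dataset, start_index=0)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)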

View File

@ -10,7 +10,6 @@ import datetime
import os
import socket
from contextlib import closing
from typing import Optional
import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
@ -35,7 +34,7 @@ def create_c10d_store(
timeout: float = (60 * 10), # 10 min
wait_for_workers: bool = True,
retries=3,
use_libuv: Optional[bool] = None,
use_libuv: bool | None = None,
):
if use_libuv is not None:
logger.warning(

View File

@ -10,12 +10,11 @@ import inspect
import logging
import os
import warnings
from typing import Optional
from torch.distributed.elastic.utils.log_level import get_log_level
def get_logger(name: Optional[str] = None) -> logging.Logger:
def get_logger(name: str | None = None) -> logging.Logger:
"""
Utility function to set up a simple logger that writes
to stderr. The log level is fetched from the LOGLEVEL
@ -32,13 +31,13 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
return _setup_logger(name or _derive_module_name(depth=2))
def _setup_logger(name: Optional[str] = None) -> logging.Logger:
def _setup_logger(name: str | None = None) -> logging.Logger:
logger = logging.getLogger(name)
logger.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
return logger
def _derive_module_name(depth: int = 1) -> Optional[str]:
def _derive_module_name(depth: int = 1) -> str | None:
"""
Derives the name of the caller module from the stack frames.
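A short usage sketch of ``get_logger`` with the optional-name signature shown above:

logger = get_logger(__name__)          # explicit logger name
logger = get_logger()                  # name=None: derived from the caller's module
logger.info("rendezvous complete")     # level comes from the LOGLEVEL environment variable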

View File

@ -10,7 +10,6 @@
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from datetime import timedelta
from typing import Optional
import torch
@ -109,7 +108,7 @@ def _try_detecting_missing_ranks(
rank: int,
rank_decoder: Callable[[int], str],
trace_timeout: float,
) -> Optional[Iterable[str]]:
) -> Iterable[str] | None:
store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")
def _find_missing_ranks():
@ -169,8 +168,8 @@ def barrier(
world_size: int,
key_prefix: str,
barrier_timeout: float = 300,
rank: Optional[int] = None,
rank_tracing_decoder: Optional[Callable[[int], str]] = None,
rank: int | None = None,
rank_tracing_decoder: Callable[[int], str] | None = None,
trace_timeout: float = 10,
) -> None:
"""

View File

@ -11,7 +11,7 @@ import sys
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any
import torch
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
@ -90,7 +90,7 @@ class LaunchConfig:
min_nodes: int
max_nodes: int
nproc_per_node: int
logs_specs: Optional[LogsSpecs] = None
logs_specs: LogsSpecs | None = None
run_id: str = ""
role: str = "default_role"
rdzv_endpoint: str = ""
@ -100,14 +100,14 @@ class LaunchConfig:
max_restarts: int = 3
monitor_interval: float = 0.1
start_method: str = "spawn"
log_line_prefix_template: Optional[str] = None
log_line_prefix_template: str | None = None
metrics_cfg: dict[str, str] = field(default_factory=dict)
local_addr: Optional[str] = None
local_addr: str | None = None
event_log_handler: str = "null"
numa_options: Optional[NumaOptions] = None
numa_options: NumaOptions | None = None
signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
duplicate_stdout_filters: Optional[list[str]] = None
duplicate_stderr_filters: Optional[list[str]] = None
duplicate_stdout_filters: list[str] | None = None
duplicate_stderr_filters: list[str] | None = None
virtual_local_rank: bool = False
def __post_init__(self):
@ -161,7 +161,7 @@ class elastic_launch:
def __init__(
self,
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
entrypoint: Callable | str | None,
):
self._config = config
self._entrypoint = entrypoint
@ -170,9 +170,7 @@ class elastic_launch:
return launch_agent(self._config, self._entrypoint, list(args))
def _get_entrypoint_name(
entrypoint: Union[Callable, str, None], args: list[Any]
) -> str:
def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str:
"""Retrieve entrypoint name with the rule:
1. If entrypoint is a function, use ``entrypoint.__qualname__``.
2. If entrypoint is a string, check its value:
@ -194,7 +192,7 @@ def _get_entrypoint_name(
def _get_addr_and_port(
rdzv_parameters: RendezvousParameters,
) -> tuple[Optional[str], Optional[int]]:
) -> tuple[str | None, int | None]:
if rdzv_parameters.backend != "static":
return (None, None)
endpoint = rdzv_parameters.endpoint
@ -213,7 +211,7 @@ def _get_addr_and_port(
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
entrypoint: Callable | str | None,
args: list[Any],
) -> dict[int, Any]:
if not config.run_id:
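For the launcher API above, a hedged end-to-end sketch (the trainer function, its arguments, and the ``rdzv_backend`` field are illustrative assumptions):

def trainer(a, b):                       # hypothetical user entrypoint
    return a + b

config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    rdzv_endpoint="localhost:29400",
    rdzv_backend="c10d",                 # field not shown in this hunk; assumed to exist
)
# launch_agent returns dict[int, Any]: local rank -> the entrypoint's return value.
results = elastic_launch(config, trainer)(3, 4)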

View File

@ -5,7 +5,7 @@ import io
import sys
import types
from collections.abc import Callable, Iterator, Mapping
from typing import Any, Optional, TypeVar, Union
from typing import Any, TypeVar, Union
from typing_extensions import Self
import torch
@ -122,8 +122,8 @@ class _RemoteModule(nn.Module):
self,
remote_device: str,
module_cls: type[nn.Module],
args: Optional[tuple] = None,
kwargs: Optional[dict[str, Any]] = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
_module_interface_cls: Any = None,
):
"""
@ -310,32 +310,32 @@ class _RemoteModule(nn.Module):
)
def register_buffer(
self, name: str, tensor: Optional[Tensor], persistent: bool = True
self, name: str, tensor: Tensor | None, persistent: bool = True
) -> None:
_raise_not_supported(self.register_buffer.__name__)
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
def register_parameter(self, name: str, param: Parameter | None) -> None:
_raise_not_supported(self.register_parameter.__name__)
def add_module(self, name: str, module: Optional[Module]) -> None:
def add_module(self, name: str, module: Module | None) -> None:
_raise_not_supported(self.add_module.__name__)
def apply(self, fn: Callable[[Module], None]) -> Self: # type: ignore[return]
_raise_not_supported(self.apply.__name__)
def cuda(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
def cuda(self, device: int | device | None = None) -> Self: # type: ignore[return]
_raise_not_supported(self.cuda.__name__)
def ipu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
def ipu(self, device: int | device | None = None) -> Self: # type: ignore[return]
_raise_not_supported(self.ipu.__name__)
def xpu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
def xpu(self, device: int | device | None = None) -> Self: # type: ignore[return]
_raise_not_supported(self.xpu.__name__)
def cpu(self) -> Self: # type: ignore[return]
_raise_not_supported(self.cpu.__name__)
def type(self, dst_type: Union[dtype, str]) -> Self: # type: ignore[return]
def type(self, dst_type: dtype | str) -> Self: # type: ignore[return]
_raise_not_supported(self.type.__name__)
def float(self) -> Self: # type: ignore[return]
@ -355,19 +355,16 @@ class _RemoteModule(nn.Module):
def register_backward_hook( # type: ignore[return]
self,
hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]],
hook: Callable[[Module, _grad_t, _grad_t], None | _grad_t],
# pyrefly: ignore [bad-return]
) -> RemovableHandle:
_raise_not_supported(self.register_backward_hook.__name__)
def register_forward_pre_hook( # type: ignore[return]
self,
hook: Union[
Callable[[T, tuple[Any, ...]], Optional[Any]],
Callable[
[T, tuple[Any, ...], dict[str, Any]],
Optional[tuple[Any, dict[str, Any]]],
],
hook: Callable[[T, tuple[Any, ...]], Any | None]
| Callable[
[T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None
],
prepend: bool = False,
with_kwargs: bool = False,
@ -377,10 +374,8 @@ class _RemoteModule(nn.Module):
def register_forward_hook( # type: ignore[return, override]
self,
hook: Union[
Callable[[T, tuple[Any, ...], Any], Optional[Any]],
Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]],
],
hook: Callable[[T, tuple[Any, ...], Any], Any | None]
| Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None],
prepend: bool = False,
with_kwargs: bool = False,
# pyrefly: ignore [bad-return]
@ -435,7 +430,7 @@ class _RemoteModule(nn.Module):
def named_modules(
self,
memo: Optional[set[Module]] = None,
memo: set[Module] | None = None,
prefix: str = "",
remove_duplicate: bool = True,
):
@ -694,8 +689,8 @@ class RemoteModule(_RemoteModule):
self,
remote_device: str,
module_cls: type[nn.Module],
args: Optional[tuple] = None,
kwargs: Optional[dict[str, Any]] = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
):
super().__init__(remote_device, module_cls, args, kwargs)

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -53,7 +52,7 @@ class _FunctionalAdadelta:
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -70,7 +69,7 @@ class _FunctionalAdagrad:
"step": torch.tensor(0.0),
}
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -68,7 +67,7 @@ class _FunctionalAdam:
# param group as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
def step_param(self, param: Tensor, grad: Tensor | None):
"""
Similar to step, but operates on a single parameter and optionally a
gradient tensor.
@ -128,7 +127,7 @@ class _FunctionalAdam:
found_inf=None,
)
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []
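The functional optimizers in these hunks share one calling convention: ``step`` takes an explicit ``list[Tensor | None]`` aligned with the parameter list, where ``None`` means "skip this parameter". A hedged sketch (the constructor arguments are an assumption, as they are not shown in this hunk):

import torch

params = [torch.randn(4, requires_grad=True) for _ in range(3)]
opt = _FunctionalAdam(params, lr=1e-3)          # constructor args are an assumption

# One slot per parameter; None leaves that parameter untouched.
grads: list[torch.Tensor | None] = [torch.randn(4), None, torch.randn(4)]
opt.step(grads)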

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -64,7 +63,7 @@ class _FunctionalAdamax:
# param group as it's not a common use case.
self.param_group = {"params": params}
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -68,7 +67,7 @@ class _FunctionalAdamW:
# param group as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
def step_param(self, param: Tensor, grad: Tensor | None):
params_with_grad = []
grads = []
exp_avgs = []
@ -129,7 +128,7 @@ class _FunctionalAdamW:
has_complex=has_complex,
)
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -57,7 +56,7 @@ class _FunctionalRMSprop:
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -51,7 +50,7 @@ class _FunctionalRprop:
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.optim._functional as F
@ -56,7 +55,7 @@ class _FunctionalSGD:
# param group as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
def step_param(self, param: Tensor, grad: Tensor | None):
"""Similar to self.step, but operates on a single parameter and
its gradient.
"""
@ -67,7 +66,7 @@ class _FunctionalSGD:
dampening = self.defaults["dampening"]
lr = self.defaults["lr"]
params = [param]
momentum_buffer_list: list[Optional[Tensor]] = []
momentum_buffer_list: list[Tensor | None] = []
grads = []
has_sparse_grad = False
@ -106,11 +105,11 @@ class _FunctionalSGD:
if momentum_buffer is not None:
state["momentum_buffer"] = momentum_buffer
def step(self, gradients: list[Optional[Tensor]]):
def step(self, gradients: list[Tensor | None]):
params = self.param_group["params"]
params_with_grad = []
grads = []
momentum_buffer_list: list[Optional[Tensor]] = []
momentum_buffer_list: list[Tensor | None] = []
lr = self.defaults["lr"]
weight_decay = self.defaults["weight_decay"]
momentum = self.defaults["momentum"]

View File

@ -2,7 +2,7 @@ import logging
import warnings
from collections.abc import Callable, Collection, Mapping
from copy import deepcopy
from typing import Any, Optional, overload, Union
from typing import Any, overload
import torch
import torch.nn as nn
@ -62,10 +62,10 @@ class _NamedOptimizer(optim.Optimizer):
def __init__(
self,
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
named_parameters: Mapping[str, torch.Tensor | ShardedTensor],
optimizer_class: optim.Optimizer,
param_groups: Optional[Collection[Mapping[str, Any]]] = None,
module: Optional[nn.Module] = None,
param_groups: Collection[Mapping[str, Any]] | None = None,
module: nn.Module | None = None,
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
) -> None:
@ -152,7 +152,7 @@ class _NamedOptimizer(optim.Optimizer):
@overload
def step(self, closure: Callable[[], float]) -> float: ...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
def step(self, closure: Callable[[], float] | None = None) -> float | None:
"""
Perform a single optimization step.

View File

@ -2,7 +2,6 @@
import logging
from collections import defaultdict
from threading import Lock
from typing import Optional
import torch
import torch.distributed.autograd as dist_autograd
@ -51,7 +50,7 @@ class _ScriptLocalOptimizer(nn.Module):
def step(self, autograd_ctx_id: int):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
# apply functional optimizer step with a list of gradients
grads: list[Optional[Tensor]] = [
grads: list[Tensor | None] = [
all_local_grads[p] if p in all_local_grads else None # noqa: SIM401
for p in self._local_params
]

View File

@ -13,7 +13,7 @@ import io
import logging
from collections.abc import Callable
from itertools import chain
from typing import Any, Optional, Union
from typing import Any
import torch
import torch.distributed as dist
@ -173,7 +173,7 @@ class _DDPBucketAssignment:
# DDP guarantees all parameters in the bucket have the same device
# pyrefly: ignore [read-only]
self.device: torch.device = self.parameters[0].device
self.tensor: Optional[torch.Tensor] = None
self.tensor: torch.Tensor | None = None
class _OverlapStatus(enum.IntEnum):
@ -252,7 +252,7 @@ class _OverlapInfo:
# Group Ranks
self.assigned_ranks_per_bucket: list[set[int]] = []
self.num_bucket_assignments: int = 0
self.total_size: Optional[int] = None
self.total_size: int | None = None
# Modified per iteration
self.broadcast_handles: list[Any] = []
@ -377,7 +377,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self,
params,
optimizer_class: type[Optimizer],
process_group: Optional[Any] = None,
process_group: Any | None = None,
parameters_as_bucket_view: bool = False,
overlap_with_ddp: bool = False,
**defaults: Any,
@ -649,7 +649,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _partition_parameters(
self,
params_per_rank: Optional[list[list[torch.Tensor]]] = None,
params_per_rank: list[list[torch.Tensor]] | None = None,
) -> list[list[dict]]:
r"""
Partitions parameters across distributed data parallel ranks.
@ -869,7 +869,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _get_min_index(
self,
values: list[int],
disallowed_indices: Optional[set[int]] = None,
disallowed_indices: set[int] | None = None,
) -> int:
r"""
Return ``values.index(min(values))``, except only uses one pass.
@ -1036,10 +1036,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _local_step(
self,
gradients: Optional[list[Optional[torch.Tensor]]] = None,
closure: Optional[Callable[[], float]] = None,
gradients: list[torch.Tensor | None] | None = None,
closure: Callable[[], float] | None = None,
**kwargs: Any,
) -> Optional[float]:
) -> float | None:
r"""
Perform a single optimizer step without syncing parameters across ranks.
@ -1111,9 +1111,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
# pyrefly: ignore [bad-override]
def step(
self,
closure: Optional[Callable[[], float]] = None,
closure: Callable[[], float] | None = None,
**kwargs: Any,
) -> Optional[float]:
) -> float | None:
r"""
Perform a single optimizer step and sync parameters across all ranks.
@ -1403,7 +1403,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _verify_and_init_params(
self,
params: Any,
) -> Union[list[torch.Tensor], list[dict]]:
) -> list[torch.Tensor] | list[dict]:
r"""
Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.
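A hedged usage sketch of ``ZeroRedundancyOptimizer`` matching the ``step``/``optimizer_class`` signatures above (process-group initialization and ``rank`` are assumed to exist in the caller):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(torch.nn.Linear(10, 10).to(rank), device_ids=[rank])
opt = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.SGD,
    lr=0.01,
)
model(torch.randn(20, 10).to(rank)).sum().backward()
opt.step()       # local step on this rank's shard, then updated parameters are synced across ranks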

View File

@ -8,7 +8,7 @@ from collections.abc import Callable
from enum import Enum
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Optional, Union
from typing import Any, Union
import torch
import torch.fx as fx
@ -165,7 +165,7 @@ def _insert_stage_symbolic_backward(
# We will only emit backward operations for nodes that can contribute
# to the specified loss value.
live_nodes = {loss_node: None}
val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
val_to_grad: dict[fx.Node, fx.Node | None] = {loss_node: None}
def assign_or_accumulate_grad(forward_node, grad_value):
if forward_node in val_to_grad and forward_node.op != "placeholder":
@ -186,7 +186,7 @@ def _insert_stage_symbolic_backward(
fx.node.map_arg(node.args, add_to_live_nodes)
fx.node.map_arg(node.kwargs, add_to_live_nodes)
if node.op == "call_module":
output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
output_grads: tuple[fx.Node | None, ...] | fx.Node | None
if node in tuples:
stage_output = tuples[node]
output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
@ -680,11 +680,10 @@ class Pipe(torch.nn.Module):
def _from_traced(
mod: torch.nn.Module,
exported_program: ExportedProgram,
multi_use_param_spec: Optional[MultiUseParamSpec] = None,
multi_use_param_spec: MultiUseParamSpec | None = None,
output_loss_value_spec=None,
split_policy: Optional[
Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
] = None,
split_policy: Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
| None = None,
):
"""
Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
@ -1012,7 +1011,7 @@ class Pipe(torch.nn.Module):
def _trace_with_export(
mod: torch.nn.Module,
example_args: tuple[Any, ...],
example_kwargs: Optional[dict[str, Any]] = None,
example_kwargs: dict[str, Any] | None = None,
) -> ExportedProgram:
logger.info("Tracing model ...")
try:
@ -1032,8 +1031,8 @@ class Pipe(torch.nn.Module):
def from_tracing(
mod: torch.nn.Module,
example_args: tuple[Any, ...],
example_kwargs: Optional[dict[str, Any]] = None,
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
example_kwargs: dict[str, Any] | None = None,
split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None,
):
# If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
# stages instead of TRANSMIT'ting it
@ -1120,7 +1119,7 @@ class Pipe(torch.nn.Module):
self,
stage_index: int,
device: torch.device,
group: Optional[ProcessGroup] = None,
group: ProcessGroup | None = None,
) -> _PipelineStage:
"""
Create a `PipelineStage` given a stage index and distributed group.
@ -1209,9 +1208,9 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
def pipeline(
module: torch.nn.Module,
mb_args: tuple[Any, ...],
mb_kwargs: Optional[dict[str, Any]] = None,
split_spec: Optional[dict[str, SplitPoint]] = None,
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
mb_kwargs: dict[str, Any] | None = None,
split_spec: dict[str, SplitPoint] | None = None,
split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None,
) -> Pipe:
"""
Split a module based on a specification.
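A hedged sketch tying ``pipeline`` to ``build_stage`` from the hunks above (the model, the split-point name, and the ``SplitPoint`` member are illustrative assumptions):

pipe = pipeline(
    module=model,
    mb_args=(example_microbatch,),                    # one microbatch of example inputs
    split_spec={"layers.4": SplitPoint.BEGINNING},    # member name is an assumption
)
stage = pipe.build_stage(stage_index=rank, device=device)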

View File

@ -3,7 +3,7 @@
import collections
import logging
from collections.abc import Iterator
from typing import Any, Optional, Union
from typing import Any
import torch
from torch.autograd.graph import GradientEdge, Node
@ -15,7 +15,7 @@ from ._debug import map_debug_info
logger = logging.getLogger(__name__)
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
"""
Get the grad function or grad accumulator for a tensor.
@ -142,10 +142,10 @@ def get_param_groups(
def stage_backward_input(
stage_outputs_or_loss: list[torch.Tensor],
output_grads: Optional[list[torch.Tensor]],
output_grads: list[torch.Tensor] | None,
input_values: list[torch.Tensor],
weights: Iterator[Parameter],
) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]:
) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]]]:
"""
Compute the gradients for only the stage inputs with
respect to the stage outputs (if non-last stage) or loss (if last stage)
@ -225,10 +225,10 @@ def stage_backward_input(
def stage_backward_weight(
weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False
) -> tuple[Optional[torch.Tensor], ...]:
) -> tuple[torch.Tensor | None, ...]:
# map weights to param_group_weights
grad_acc_to_weight = {}
weight_grads: list[Optional[torch.Tensor]] = []
weight_grads: list[torch.Tensor | None] = []
for index, weight in enumerate(weights):
grad_acc = _get_grad_fn_or_grad_acc(weight)
grad_acc_to_weight[grad_acc] = weight, index
@ -282,8 +282,8 @@ def stage_backward(
stage_output,
output_grads,
input_values,
outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used
) -> tuple[Optional[torch.Tensor], ...]:
outputs_with_grads_idxs: list[int] | None = None, # deprecated, not used
) -> tuple[torch.Tensor | None, ...]:
"""
This is a helper function to:
1. compute the gradients for the stage inputs, and
@ -303,7 +303,7 @@ def stage_backward(
# stage_output may be a composite datatype like dict. Extract all individual
# tensor values here
stage_output_tensors: list[torch.Tensor] = []
output_grad_tensors: list[Optional[torch.Tensor]] = []
output_grad_tensors: list[torch.Tensor | None] = []
def extract_tensors_with_grads(
output_val,
@ -363,7 +363,7 @@ def stage_backward(
)
# Extract gradients wrt the input values
grad_inputs: list[Optional[torch.Tensor]] = []
grad_inputs: list[torch.Tensor | None] = []
for val in input_values:
if isinstance(val, torch.Tensor):
grad_inputs.append(val.grad)

View File

@ -10,7 +10,7 @@ visualize_schedule(ops, "test.png")
"""
import collections
from typing import NamedTuple, Optional, Union
from typing import NamedTuple
from unittest import mock
from torch.distributed.pipelining.schedules import (
@ -32,13 +32,13 @@ class OpKey(NamedTuple):
def get_schedule_ops(
schedule: Union[str, type[_PipelineSchedule]],
schedule: str | type[_PipelineSchedule],
pp_degree: int,
num_microbatches: int,
num_stages_per_rank: Optional[int] = None,
num_stages_per_rank: int | None = None,
add_spacing: bool = False,
with_comms: bool = False,
) -> list[list[Optional[_Action]]]:
) -> list[list[_Action | None]]:
"""
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
where each inner list represents a rank and each element in the inner list represents an action.
@ -86,7 +86,7 @@ def get_schedule_ops(
assert schedule_instance.pipeline_order is not None
# Convert to List[List[_Action]]
all_actions: list[list[Optional[_Action]]] = []
all_actions: list[list[_Action | None]] = []
if with_comms:
runtime = _PipelineScheduleRuntime(stages, num_microbatches)
runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order)
@ -136,8 +136,8 @@ action_type_to_color_mapping = {
def add_schedule_op_spacing(
schedule: list[list[Optional[_Action]]],
) -> list[list[Optional[_Action]]]:
schedule: list[list[_Action | None]],
) -> list[list[_Action | None]]:
"""
Add spacing to the schedule based on dependencies between ranks.
@ -169,7 +169,7 @@ def add_schedule_op_spacing(
)
num_ranks = len(schedule)
spaced_schedule: list[list[Optional[_Action]]] = [[] for _ in range(num_ranks)]
spaced_schedule: list[list[_Action | None]] = [[] for _ in range(num_ranks)]
rank_ops = [collections.deque(ops) for ops in schedule]
# Track completion times: (stage_index, action_type, microbatch_index) -> completion_time
@ -331,8 +331,8 @@ def add_schedule_op_spacing(
def visualize_schedule(
schedule: list[list[Optional[_Action]]],
filename: Optional[str] = None,
schedule: list[list[_Action | None]],
filename: str | None = None,
) -> None:
"""
Visualize the schedule using matplotlib.
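Following the module docstring above, a hedged sketch of generating and rendering a schedule (the schedule-name string is an assumption):

ops = get_schedule_ops("1F1B", pp_degree=4, num_microbatches=8, add_spacing=True)
visualize_schedule(ops, "1f1b_schedule.png")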

View File

@ -3,7 +3,6 @@
import logging
from dataclasses import dataclass
from typing import Union
import torch
from torch import fx
@ -76,8 +75,8 @@ def validate_tensor_metadata(desc, expected, given):
def validate_tensors_metadata(
desc,
expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...],
actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...],
):
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(

View File

@ -3,7 +3,7 @@
import logging
import operator
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any
import torch
from torch.fx.node import map_aggregate
@ -307,10 +307,10 @@ def _shard_dict_of_args(
def split_args_kwargs_into_chunks(
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
kwargs: dict[str, Any] | None,
chunks: int,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
) -> tuple[list[tuple], list[dict]]:
"""
Given a sequence of args and kwargs, split them into a number of chunks
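A hedged sketch of ``split_args_kwargs_into_chunks`` (with no chunk specs given, tensors are assumed to be split along dim 0):

import torch

args = (torch.randn(8, 16),)
kwargs = {"mask": torch.ones(8, 16)}
arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, chunks=4)
assert len(arg_chunks) == len(kwarg_chunks) == 4      # 4 microbatches of size 2 each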

View File

@ -11,7 +11,7 @@ from collections import Counter, defaultdict
from collections.abc import Callable
from enum import Enum
from functools import lru_cache
from typing import Any, cast, NamedTuple, Optional, Protocol, Union
from typing import Any, cast, NamedTuple, Protocol
import torch
import torch.distributed as dist
@ -131,8 +131,8 @@ _action_regex = re.compile(
class _Action(NamedTuple):
stage_index: int
computation_type: _ComputationType
microbatch_index: Optional[int] = None
sub_actions: Optional[tuple["_Action", ...]] = None
microbatch_index: int | None = None
sub_actions: tuple["_Action", ...] | None = None
def __str__(self):
return self.__repr__()
@ -220,8 +220,8 @@ def _get_profiler_function_name(action: _Action) -> str:
def _format_pipeline_order(
pipeline_order: dict[int, list[Optional[_Action]]],
error_step_number: Optional[int] = None,
pipeline_order: dict[int, list[_Action | None]],
error_step_number: int | None = None,
) -> str:
"""
Formats the pipeline order in a timestep (row) x rank (column) grid of actions
@ -286,10 +286,10 @@ class _PipelineSchedule(ABC):
def __init__(
self,
n_microbatches: int,
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable[..., torch.Tensor] | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
):
# From arguments
@ -360,10 +360,10 @@ class _PipelineSchedule(ABC):
@abstractmethod
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -382,7 +382,7 @@ class _PipelineSchedule(ABC):
self,
*args,
target=None,
losses: Optional[list] = None,
losses: list | None = None,
return_outputs=True,
**kwargs,
):
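For the ``step`` entry point whose signature is shown above, a hedged usage sketch with a concrete single-stage schedule (``stage``, ``inputs``, and ``targets`` are assumed to exist; ``is_first``/``is_last`` are stage properties referenced elsewhere in this file):

import torch

schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=torch.nn.functional.cross_entropy)

if stage.is_first:
    schedule.step(inputs)                      # whole batch; chunked into microbatches internally
elif stage.is_last:
    losses: list = []
    schedule.step(target=targets, losses=losses)
else:
    schedule.step()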
@ -399,7 +399,7 @@ class _PipelineSchedule(ABC):
"""
raise NotImplementedError
def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs):
def eval(self, *args, target=None, losses: list | None = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
@ -421,10 +421,10 @@ class _PipelineSchedule(ABC):
def _check_inputs(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
) -> tuple[list, list]:
"""
Pre-process/check inputs
@ -463,7 +463,7 @@ class _PipelineSchedule(ABC):
def _split_inputs(
self,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
):
"""
Splits a full-batch input into chunks (i.e. microbatches) and returns
@ -494,9 +494,7 @@ class _PipelineSchedule(ABC):
)
def _batch_p2p(
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> list[dist.Work]:
def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]:
"""
Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
"""
@ -508,7 +506,7 @@ def _batch_p2p(
def _sorted_batch_p2p(
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
p2p_ops: list[dist.P2POp], desc: str | None = None
) -> dict[int, list[dist.Work]]:
"""
Sorts the list of P2P ops by the peer rank, and then calls
@ -557,10 +555,10 @@ class PipelineScheduleSingle(_PipelineSchedule):
self,
stage: _PipelineStageBase,
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
):
# Init parent
@ -584,7 +582,7 @@ class PipelineScheduleSingle(_PipelineSchedule):
or equal to the number of stages ({self._num_stages})."
)
self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
self.pipeline_order: dict[int, list[_Action | None]] | None = (
self._get_pipeline_order()
)
@ -608,7 +606,7 @@ or equal to the number of stages ({self._num_stages})."
self,
*args,
target=None,
losses: Optional[list] = None,
losses: list | None = None,
return_outputs: bool = True,
**kwargs,
):
@ -656,7 +654,7 @@ or equal to the number of stages ({self._num_stages})."
else:
return None
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
"""
Returns the pipeline execution order as a schedule IR.
@ -683,10 +681,10 @@ class _ScheduleForwardOnly(PipelineScheduleSingle):
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -734,10 +732,10 @@ class ScheduleGPipe(PipelineScheduleSingle):
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -812,7 +810,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
"""
Returns the pipeline order for GPipe schedule.
@ -822,7 +820,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
pp_group_size = self._num_stages
for rank in range(pp_group_size):
actions: list[Optional[_Action]] = []
actions: list[_Action | None] = []
# 1. Initial delay based on rank position
warmup_delay = rank
@ -853,10 +851,10 @@ class Schedule1F1B(PipelineScheduleSingle):
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -995,7 +993,7 @@ class Schedule1F1B(PipelineScheduleSingle):
self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
"""
Returns the pipeline order for 1F1B schedule.
@ -1005,7 +1003,7 @@ class Schedule1F1B(PipelineScheduleSingle):
pp_group_size = self._num_stages
for rank in range(pp_group_size):
actions: list[Optional[_Action]] = []
actions: list[_Action | None] = []
# 1. Warmup phase: initial delay based on rank
actions.extend([None] * rank)
@ -1069,13 +1067,13 @@ def _requires_reduce_grad(action_type: _ComputationType) -> bool:
def _add_reduce_grad(
actions: list[Optional[_Action]], n_microbatches: int
) -> list[Optional[_Action]]:
actions: list[_Action | None], n_microbatches: int
) -> list[_Action | None]:
"""
REDUCE_GRAD refers to gradient reduction performed jointly across minibatches.
reduce_grad frees memory, so we want to schedule it just after the last "backward"-like stage.
"""
actions_with_reduce_grad: list[Optional[_Action]] = []
actions_with_reduce_grad: list[_Action | None] = []
cnt: dict[int, int] = defaultdict(int)
def _leaf_action(a, to_schedule):
@ -1102,7 +1100,7 @@ def _add_reduce_grad(
def _add_unshard_reshard(
compute_actions: list[Optional[_Action]],
compute_actions: list[_Action | None],
max_active_stages: int = 3,
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP.
@ -1117,9 +1115,7 @@ def _add_unshard_reshard(
(to account for having one f and one b active, and something else prefetching?)
"""
def next_stage_indices(
count: int, next_actions: list[Optional[_Action]]
) -> list[int]:
def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]:
"""Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
seen: set[int] = set()
ret: list[int] = []
@ -1187,7 +1183,7 @@ def _add_unshard_reshard(
def _merge_bw(
compute_actions: list[Optional[_Action]],
compute_actions: list[_Action | None],
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
(note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
@ -1259,9 +1255,7 @@ def _add_send_recv(
recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
return send, recv
def _ready_to_schedule(
action: Optional[_Action], prev_actions: set[_Action]
) -> bool:
def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool:
"""We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
This helps ensure a sane (non-hanging) ordering of sends and recvs.
But it also means we might not be able to schedule our next compute action yet.
@ -1343,7 +1337,7 @@ def _add_send_recv(
def _validate_schedule(
actions: dict[int, list[Optional[_Action]]],
actions: dict[int, list[_Action | None]],
pp_group_size: int,
num_stages: int,
num_microbatches: int,
@ -1479,11 +1473,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
use_full_backward: Optional[bool] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
use_full_backward: bool | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -1516,7 +1510,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self._should_compute_loss = lambda stage: stage.is_last and has_loss
# This will be set during init of derived schedules
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
# When using a custom backward function, we may or may not need autograd to be used
# for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
@ -1559,7 +1553,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self._stages_backward_initialized = True
def _validate_and_set_stage_mapping(
self, actions: dict[int, list[Optional[_Action]]]
self, actions: dict[int, list[_Action | None]]
) -> None:
"""
Allocates the stage index to rank mapping which is needed for communication
@ -1600,7 +1594,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self,
*args,
target=None,
losses: Optional[list] = None,
losses: list | None = None,
return_outputs: bool = True,
**kwargs,
):
@ -1657,10 +1651,10 @@ class PipelineScheduleMulti(_PipelineSchedule):
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -1851,10 +1845,10 @@ class _PipelineContext:
def __init__(
self,
schedule_ref: _PipelineSchedule,
arg_mbs: Optional[list[tuple]] = None,
kwarg_mbs: Optional[list[dict]] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list[tuple] | None = None,
kwarg_mbs: list[dict] | None = None,
target_mbs: list | None = None,
losses: list | None = None,
):
self.schedule_ref = schedule_ref
self.arg_mbs = arg_mbs
@ -1931,7 +1925,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
def _prepare_schedule_with_comms(
self,
actions: dict[int, list[Optional[_Action]]],
actions: dict[int, list[_Action | None]],
format: str = "compute_only",
):
"""
@ -2042,10 +2036,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
arg_mbs: list | None = None,
kwarg_mbs: list | None = None,
target_mbs: list | None = None,
losses: list | None = None,
return_outputs: bool = True,
):
"""
@ -2306,8 +2300,8 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Union[Callable, _Loss]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | _Loss | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2323,7 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
# ========================================================================
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
@ -2340,7 +2334,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
# Store the list of operations used for that rank
# Pre-padding, rank starts with no-ops based on the warmup.
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[_Action | None] = [None for _ in range(rank)]
for stage_index in stage_indices:
rank_ops.extend(
@ -2380,7 +2374,7 @@ def _get_1f1b_rank_ops(
# Store the list of operations used for that rank
# Pre-padding, rank starts with no-ops based on the warmup.
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[_Action | None] = [None for _ in range(rank)]
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
# Formula:
@ -2520,10 +2514,10 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2551,7 +2545,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2559,7 +2553,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
def get_rank_warmup_ops(rank):
# Warms up operations for last stage
warmups_ops_last_stage = (
@ -2634,10 +2628,10 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2667,7 +2661,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2682,7 +2676,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
def get_rank_warmup_ops(rank):
# Warms up operations for last stage
warmups_ops_last_stage = (
@ -2760,7 +2754,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
return False
seen_ops: set[tuple[int, _ComputationType, int]] = set()
result: dict[int, list[Optional[_Action]]] = {}
result: dict[int, list[_Action | None]] = {}
next_pointer: dict[int, int] = {}
bubbles_added: dict[int, int] = {}
total_bubbles_added = 0
@ -2833,10 +2827,10 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -2872,7 +2866,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2880,11 +2874,11 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
# max(2 * self.pp_group_size - 1, ...) ensures the number of microbatches is at least
# as large as the number of microbatches needed to fully utilize the pipeline
n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[_Action | None] = [None for _ in range(rank)]
# Forward and backward action counts for stage chunk 0 and chunk 1
f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
@ -3011,10 +3005,10 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
loss_fn: Callable | None = None,
args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
@ -3055,7 +3049,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[_Action | None]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -3063,8 +3057,8 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._prepare_schedule_with_comms(self.pipeline_order)
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
actions: list[Optional[_Action]] = []
def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
actions: list[_Action | None] = []
counters: dict[
tuple[int, _ComputationType], int
] = {} # (stage_index, computation_type) -> mb_index
@ -3273,12 +3267,12 @@ def _simulate_comms_compute(
_prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}
def add_to_schedule(rank: int, action: Optional[_Action]):
def add_to_schedule(rank: int, action: _Action | None):
_schedule[rank].append(action)
if action is not None:
_prev_ops_rank[rank].add(action)
def _ready_to_schedule(action: Optional[_Action]) -> bool:
def _ready_to_schedule(action: _Action | None) -> bool:
if action is None:
return True
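The schedule changes above are purely syntactic: every `Optional[X]` becomes `X | None` and every `Union[A, B]` becomes `A | B` (PEP 604), which is what the "Use Python 3.10 typing" commit message refers to. A minimal before/after sketch with illustrative names (`f_old`/`f_new` are stand-ins, not functions from this patch); the new form needs Python 3.10+ at runtime, or `from __future__ import annotations` on older interpreters when it only appears in annotations:

# Before: unions spelled via the typing module, as on the removed lines above
from collections.abc import Callable
from typing import Optional, Union

def f_old(arg_mbs: Optional[list] = None, loss_fn: Optional[Union[Callable, str]] = None): ...

# After: built-in union syntax, as on the added lines above; no typing import needed here
def f_new(arg_mbs: list | None = None, loss_fn: Callable | str | None = None): ...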

View File

@ -4,7 +4,7 @@ import logging
import operator
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, cast, Optional, Union
from typing import Any, cast, Union
import torch
import torch.distributed as dist
@ -99,7 +99,7 @@ InputInfo = Union[_RecvInfo, _RootArgPlaceholder]
def _make_tensor_from_meta(
example: Union[torch.Tensor, FakeTensor],
example: torch.Tensor | FakeTensor,
device: torch.device,
) -> torch.Tensor:
"""
@ -126,8 +126,8 @@ class _PipelineStageBase(ABC):
stage_index: int,
num_stages: int,
device: torch.device,
group: Optional[dist.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
group: dist.ProcessGroup | None = None,
dw_builder: Callable[[], Callable[..., None]] | None = None,
):
"""
Args:
@ -176,11 +176,11 @@ class _PipelineStageBase(ABC):
)
# Run time states
self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None
self._outputs_meta: tuple[torch.Tensor, ...] | None = None
# map microbatch ID to list of forward tensor args
self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {}
# map microbatch ID to list of backward grad tensor args
self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {}
self.bwd_cache: dict[int, tuple[torch.Tensor | None, ...]] = {}
# Caching chunk outputs for final output merge or reduction
self.output_chunks: list[Any] = []
@ -196,10 +196,10 @@ class _PipelineStageBase(ABC):
# Backward infra will created lazily
self.grad_recv_info: dict = {}
self.grad_send_info: Optional[list] = None
self.grad_send_info: list | None = None
# To be populated later by the Schedule
self.chunks: Optional[int] = None
self.chunks: int | None = None
self.stage_index_to_group_rank: dict[int, int] = {
i: i % self.group_size for i in range(self.num_stages)
}
@ -261,11 +261,11 @@ class _PipelineStageBase(ABC):
def _create_grad_send_info(
self,
args_recv_info: tuple,
) -> list[Optional[int]]:
) -> list[int | None]:
"""
Create a list of stage indices to send gradients to.
"""
grad_send_info: list[Optional[int]] = []
grad_send_info: list[int | None] = []
def map_recv_to_send(a):
# Note: we send gradients back to previous stage as long as in
@ -288,7 +288,7 @@ class _PipelineStageBase(ABC):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
) -> tuple[Any, ...]:
raise NotImplementedError
@ -388,7 +388,7 @@ class _PipelineStageBase(ABC):
return self.bwd_cache.pop(mb_index)
def set_local_bwd_input(
self, next_stage_bwd_outputs: tuple[Optional[torch.Tensor], ...], mb_index: int
self, next_stage_bwd_outputs: tuple[torch.Tensor | None, ...], mb_index: int
) -> None:
"""
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
@ -588,7 +588,7 @@ class _PipelineStageBase(ABC):
backward_type,
bwd_kwargs: dict,
last_backward: bool = False,
) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]:
) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None]:
"""
When using PP with FSDP, DDP, or replicate, there are some runtime differences between the last backward step and the
other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
@ -600,7 +600,7 @@ class _PipelineStageBase(ABC):
backward_type,
) -> Callable[
[],
tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]],
tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None],
]:
if backward_type == "full":
return lambda: (
@ -663,7 +663,7 @@ class _PipelineStageBase(ABC):
self,
fwd_chunk_id: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
save_forward_output: bool = True,
):
"""
@ -779,7 +779,7 @@ class _PipelineStageBase(ABC):
"input_values": input_values,
}
grads_input: tuple[Optional[torch.Tensor], ...] = ()
grads_input: tuple[torch.Tensor | None, ...] = ()
# Custom backward function
if self.dw_builder:
@ -1019,7 +1019,7 @@ class _PipelineStage(_PipelineStageBase):
stage_index: int,
pipe_info: PipeInfo,
device: torch.device,
group: Optional[dist.ProcessGroup] = None,
group: dist.ProcessGroup | None = None,
):
"""
Create a pipeline stage given a stage_module to be wrapped by this stage
@ -1086,7 +1086,7 @@ class _PipelineStage(_PipelineStageBase):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
) -> tuple[Any, ...]:
"""
Create send/recv infrastructures for activations (during forward)
@ -1183,7 +1183,7 @@ class _PipelineStage(_PipelineStageBase):
def find_dst_rank(
self,
user: fx.Node,
) -> Optional[int]:
) -> int | None:
"""
Find the destination rank of a `user` node.
If the `user` is not a submod, `None` may be returned.
@ -1293,7 +1293,7 @@ def build_stage(
stage_index: int,
pipe_info: PipeInfo,
device: torch.device,
group: Optional[dist.ProcessGroup] = None,
group: dist.ProcessGroup | None = None,
) -> _PipelineStage:
"""
Create a pipeline stage given a stage_module to be wrapped by this stage
@ -1347,14 +1347,14 @@ class PipelineStage(_PipelineStageBase):
stage_index: int,
num_stages: int,
device: torch.device,
input_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None,
group: Optional[dist.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
input_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
output_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
group: dist.ProcessGroup | None = None,
dw_builder: Callable[[], Callable[..., None]] | None = None,
):
super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
self.inputs: Optional[list[torch.Tensor]] = None
self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None
self.inputs: list[torch.Tensor] | None = None
self.inputs_meta: tuple[torch.Tensor, ...] | None = None
# Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it
# might be breaking for existing users.
if input_args is None:
@ -1410,7 +1410,7 @@ class PipelineStage(_PipelineStageBase):
def _shape_inference(
self,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
):
if kwargs is None:
kwargs = {}
@ -1522,7 +1522,7 @@ class PipelineStage(_PipelineStageBase):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
) -> tuple[Any, ...]:
# TODO move self.device to an argument from step API (from its input tensors)?
assert num_microbatches is not None, "TODO fix num_microbatches"
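For context, the `PipelineStage` constructor whose signature is updated above is used roughly as follows. This is a hedged sketch: the 2-rank layout, the `nn.Linear` model slice, and the device choice are placeholders, and `torch.distributed.init_process_group` is assumed to have run already.

import torch
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage

rank, world_size = 0, 2                    # placeholders for dist.get_rank()/get_world_size()
device = torch.device("cuda", rank)
submodule = nn.Linear(16, 16).to(device)   # this rank's slice of the full model
stage = PipelineStage(
    submodule,
    stage_index=rank,
    num_stages=world_size,
    device=device,
)
# The stage is then handed to one of the schedule classes from the previous file,
# e.g. ScheduleGPipe(stage, n_microbatches=4, loss_fn=...) for this single-stage-per-rank case.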

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
@ -22,14 +21,14 @@ class _remote_device:
and "cuda:1", just represent local devices.
"""
def __init__(self, remote_device: Union[str, torch.device]):
def __init__(self, remote_device: str | torch.device):
PARSE_ERROR = (
f"Could not parse remote_device: {remote_device}. The valid format is "
"'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
)
self._worker_name = None
self._rank = None
self._device: Optional[Union[str, int, torch.device]] = None
self._device: str | int | torch.device | None = None
if isinstance(remote_device, torch.device):
self._device = remote_device
@ -81,11 +80,11 @@ class _remote_device:
except Exception:
return False
def worker_name(self) -> Optional[str]:
def worker_name(self) -> str | None:
"""Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
return self._worker_name
def rank(self) -> Optional[int]:
def rank(self) -> int | None:
"""
Returns the rank of the remote worker representing the remote device.
Returns ``None`` if no rank is available.
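The accepted string forms are exactly the ones listed in the `PARSE_ERROR` message above. A hedged sketch; `_remote_device` is private and the import path below is an assumption based on the module layout:

from torch.distributed.remote_device import _remote_device  # assumed import path

d = _remote_device("rank:1/cuda:0")
print(d.rank())         # 1
print(d.worker_name())  # None (the rank-based form carries no worker name)

d = _remote_device("trainer0/cpu")
print(d.worker_name())  # 'trainer0'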

View File

@ -11,7 +11,6 @@ import os
import sys
from collections.abc import Callable, Iterator
from datetime import timedelta
from typing import Optional
from torch.distributed import FileStore, Store, TCPStore
@ -71,7 +70,7 @@ def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool:
return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"
def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
def _rendezvous_helper(url: str, rank: int, world_size_opt: int | None, **kwargs):
result = urlparse(url)
if world_size_opt is None:
world_size = -1

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
from typing import Union
import torch
@ -89,10 +89,10 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None,
devices: Optional[list[DeviceType]] = None,
_transports: Optional[list] = None,
_channels: Optional[list] = None,
device_maps: dict[str, dict[DeviceType, DeviceType]] | None = None,
devices: list[DeviceType] | None = None,
_transports: list | None = None,
_channels: list | None = None,
):
full_device_maps = (
{}
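A hedged usage sketch for the options class above; the worker names, world size, and the cuda:0 to cuda:1 mapping are placeholders:

import torch.distributed.rpc as rpc

options = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=16,
    device_maps={"worker1": {0: 1}},  # tensors on this worker's cuda:0 land on worker1's cuda:1
)
rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)
rpc.shutdown()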

View File

@ -375,7 +375,6 @@ import uuid
from argparse import ArgumentParser, REMAINDER
from collections.abc import Callable
from importlib import metadata
from typing import Optional, Union
import torch
from torch.distributed.argparse_util import check_env, env
@ -798,7 +797,7 @@ def get_use_env(args) -> bool:
return args.use_env
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
def _get_logs_specs_class(logs_specs_name: str | None) -> type[LogsSpecs]:
"""
Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
Provides plugin mechanism to provide custom implementation of LogsSpecs.
@ -827,7 +826,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
return logs_specs_cls
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
def config_from_args(args) -> tuple[LaunchConfig, Callable | str, list[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
if not (0 < min_nodes <= max_nodes):
@ -871,7 +870,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
rdzv_endpoint = get_rdzv_endpoint(args)
ranks: Optional[set[int]] = None
ranks: set[int] | None = None
if args.local_ranks_filter:
try:
ranks = set(map(int, args.local_ranks_filter.split(",")))
@ -920,7 +919,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
)
with_python = not args.no_python
cmd: Union[Callable, str]
cmd: Callable | str
cmd_args = []
use_env = get_use_env(args)
if args.run_path:

View File

@ -5,7 +5,7 @@ import copy
import inspect
import warnings
from collections.abc import Callable, Sequence
from typing import Any, cast, Optional
from typing import Any, cast
from typing_extensions import deprecated
import torch
@ -74,7 +74,7 @@ class _ToTorchTensor(torch.autograd.Function):
def forward( # type: ignore[override]
ctx,
input: "DTensor",
grad_placements: Optional[Sequence[Placement]],
grad_placements: Sequence[Placement] | None,
):
ctx.dtensor_spec = input._spec
ctx.grad_placements = grad_placements
@ -135,8 +135,8 @@ class _FromTorchTensor(torch.autograd.Function):
device_mesh: DeviceMesh,
placements: tuple[Placement, ...],
run_check: bool,
shape: Optional[torch.Size] = None,
stride: Optional[tuple[int, ...]] = None,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
) -> "DTensor":
ctx.previous_placement = placements
ctx.previous_device_mesh = device_mesh
@ -356,12 +356,12 @@ class DTensor(torch.Tensor):
@staticmethod
def from_local(
local_tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
run_check: bool = False,
shape: Optional[torch.Size] = None,
stride: Optional[tuple[int, ...]] = None,
shape: torch.Size | None = None,
stride: tuple[int, ...] | None = None,
) -> "DTensor":
"""
Create a :class:`DTensor` from a local torch.Tensor on each rank
@ -445,7 +445,7 @@ class DTensor(torch.Tensor):
)
def to_local(
self, *, grad_placements: Optional[Sequence[Placement]] = None
self, *, grad_placements: Sequence[Placement] | None = None
) -> torch.Tensor:
"""
Get the local tensor of this DTensor on its current rank. For sharding it returns
@ -483,12 +483,12 @@ class DTensor(torch.Tensor):
def redistribute(
self,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
forward_dtype: Optional[torch.dtype] = None,
backward_dtype: Optional[torch.dtype] = None,
forward_dtype: torch.dtype | None = None,
backward_dtype: torch.dtype | None = None,
) -> "DTensor":
"""
``redistribute`` performs necessary collective operations that redistribute the current
@ -565,7 +565,7 @@ class DTensor(torch.Tensor):
)
def full_tensor(
self, *, grad_placements: Optional[Sequence[Placement]] = None
self, *, grad_placements: Sequence[Placement] | None = None
) -> torch.Tensor:
"""
Return the full tensor of this DTensor. It will perform necessary collectives
@ -688,10 +688,10 @@ class DTensor(torch.Tensor):
def distribute_tensor(
tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> DTensor:
"""
Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according
@ -850,7 +850,7 @@ def distribute_tensor(
def _shard_tensor(
full_tensor: torch.Tensor,
placements: Sequence[Shard],
device_mesh: Optional[DeviceMesh] = None,
device_mesh: DeviceMesh | None = None,
) -> "DTensor":
"""
Locally shards a full tensor based on indicated sharding arrangement, and
@ -886,10 +886,10 @@ def _shard_tensor(
def distribute_module(
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
device_mesh: DeviceMesh | None = None,
partition_fn: Callable[[str, nn.Module, DeviceMesh], None] | None = None,
input_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None,
output_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None,
) -> nn.Module:
"""
This function exposes three functions to control the parameters/inputs/outputs of the module:
@ -1042,8 +1042,8 @@ def distribute_module(
def _dtensor_init_helper( # type: ignore[no-untyped-def]
init_op,
size: torch.Size,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
**kwargs,
) -> DTensor:
# if device_mesh is None, use the one from mesh resources
@ -1108,11 +1108,11 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
def ones( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
@ -1151,11 +1151,11 @@ def ones( # type: ignore[no-untyped-def]
def empty( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
@ -1196,11 +1196,11 @@ def full( # type: ignore[no-untyped-def]
size,
fill_value,
*,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
@ -1242,10 +1242,10 @@ def full( # type: ignore[no-untyped-def]
def rand( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a uniform distribution
@ -1286,10 +1286,10 @@ def rand( # type: ignore[no-untyped-def]
def randn( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a normal distribution
@ -1330,10 +1330,10 @@ def randn( # type: ignore[no-untyped-def]
def zeros( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 0.
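The factory and conversion APIs whose signatures change above fit together roughly as follows; a hedged sketch assuming four GPUs and an already-initialized process group:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh("cuda", (4,))
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])  # shard rows across the 4 ranks
dt = dt.redistribute(mesh, [Replicate()])                    # all-gather back into replicas
local = dt.to_local()                                        # plain torch.Tensor held by this rank
full = dt.full_tensor()                                      # collective returning the global tensor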

View File

@ -74,7 +74,7 @@ def mesh_scatter(
async_op: bool = False,
*,
group_src: int = 0,
) -> Optional[Work]:
) -> Work | None:
"""
scatter a list of tensors to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e
@ -135,7 +135,7 @@ def mesh_broadcast(
async_op: bool = False,
*,
group_src: int = 0,
) -> Optional[Work]:
) -> Work | None:
"""
broadcast the tensor to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e

View File

@ -3,7 +3,7 @@ import contextlib
import logging
import warnings
from collections.abc import Sequence
from typing import cast, Optional
from typing import cast
import torch
import torch.distributed as dist
@ -482,7 +482,7 @@ class OpDispatcher:
kwargs_schema: dict[str, object] = {}
local_args: list[object] = []
local_kwargs: dict[str, object] = {}
compute_mesh: Optional[DeviceMesh] = None
compute_mesh: DeviceMesh | None = None
for arg in args_list:
if isinstance(arg, dtensor.DTensor):

View File

@ -2,7 +2,7 @@ import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional
from typing import Any, cast, NamedTuple
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -71,7 +71,7 @@ class DTensorSpec:
placements: tuple[Placement, ...]
# tensor meta will only be set during sharding propagation
tensor_meta: Optional[TensorMeta] = None
tensor_meta: TensorMeta | None = None
# When a tensor dimension is sharded across multiple mesh axes,
# `shard_order` specifies the sequence in which these shardings are applied.
@ -206,7 +206,7 @@ class DTensorSpec:
@staticmethod
def _maybe_convert_StridedShard_to_shard_order(
placements: tuple[Placement, ...], mesh: DeviceMesh
) -> Optional[ShardOrder]:
) -> ShardOrder | None:
"""
Try to convert _StridedShard placements to ShardOrder.
@ -441,7 +441,7 @@ class DTensorSpec:
@staticmethod
def format_shard_order_str(
placements: tuple[Placement, ...],
shard_order: Optional[ShardOrder] = None,
shard_order: ShardOrder | None = None,
) -> str:
"""
Format DTensor sharding information as a human-readable string.
@ -617,7 +617,7 @@ class DTensorSpec:
mesh: DeviceMesh,
dim_map: list[int],
sums: list[int],
tensor_meta: Optional[TensorMeta] = None,
tensor_meta: TensorMeta | None = None,
) -> "DTensorSpec":
"""
Construct a DTensorSpec from dim_map list and pending sum.
@ -669,7 +669,7 @@ class DTensorSpec:
return any(placement.is_shard() for placement in self.placements)
def shallow_copy_with_tensor_meta(
self, tensor_meta: Optional[TensorMeta]
self, tensor_meta: TensorMeta | None
) -> "DTensorSpec":
"""
Shallow copy the DTensorSpec with a new tensor_meta.

View File

@ -26,7 +26,7 @@ These schema definitions enable the DTensor system to:
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Optional, Union
from typing import Any, Optional
from typing_extensions import deprecated
import torch
@ -60,11 +60,11 @@ except ImportError:
ArgsType = tuple[object, ...]
KwargsType = dict[str, object]
PlacementList = list[Optional[Placement]]
PlacementList = list[Placement | None]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should
# be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
OutputSpecType = Optional[DTensorSpec | Sequence[DTensorSpec | None]]
def _rebuild_tensor_from_dtensor_meta(arg) -> object:
@ -109,8 +109,8 @@ class OpSpec:
# output_specs and input_specs are related: for this op, given these input_specs,
# this is the way the output would look
output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]
input_specs: Optional[Sequence[DTensorSpec]] = None
output_specs: DTensorSpec | tuple[DTensorSpec | None, ...]
input_specs: Sequence[DTensorSpec] | None = None
"""
redistribute_cost tells how expensive it is to redistribute a given input into the
@ -138,7 +138,7 @@ class OpSpec:
K, # cost of redistributing tensor_a from 'Shard(0)'
],
"""
redistribute_cost: Optional[list[list[float]]] = None
redistribute_cost: list[list[float]] | None = None
@cached_property
def output_spec(self) -> DTensorSpec:
@ -301,7 +301,7 @@ class RuntimeSchemaInfo:
# Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = 100
# This static_kwargkey records static kwarg names which would affect sharding prop
static_kwargkey: Optional[list[str]] = None
static_kwargkey: list[str] | None = None
# each op can decide if it wants to use pytree flatten/unflatten during operator
# eager execution, by default we don't need to do flatten/unflatten, only if the
# op indicate it needs to, this is to accelerate eager performance.
@ -331,9 +331,9 @@ class OpSchema:
args_schema: ArgsType
kwargs_schema: KwargsType
schema_info: Optional[RuntimeSchemaInfo] = None
schema_info: RuntimeSchemaInfo | None = None
_comparison_key: Optional[tuple[object, ...]] = None
_comparison_key: tuple[object, ...] | None = None
@property
def args_spec(self) -> tuple[DTensorSpec, ...]:
@ -560,7 +560,7 @@ class OutputSharding:
# specifies the output sharding pattern
output_spec: OutputSpecType
# schema for redistribution if needed
redistribute_schema: Optional[OpSchema] = None
redistribute_schema: OpSchema | None = None
# flag indicating if inputs need redistribution
needs_redistribute: bool = False
# flag to use values from `redistribute_schema`
@ -596,7 +596,7 @@ class OpInfo:
flat_args_schema: list[object]
local_args: Sequence[object]
local_kwargs: dict[str, object]
args_tree_spec: Optional[TreeSpec] = None
args_tree_spec: TreeSpec | None = None
# the output sharding info
output_sharding: Optional[OutputSharding] = None
output_sharding: OutputSharding | None = None

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import string
from typing import cast, Optional
from typing import cast
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
@ -44,7 +44,7 @@ def einop_rule(
op_schema: OpSchema,
*,
linearity: bool = False,
enforce_sharding: Optional[dict[str, int]] = None,
enforce_sharding: dict[str, int] | None = None,
) -> OutputSharding:
"""
Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
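To make the rule concrete: for an `mk,kn->mn` contraction where the first operand is sharded along `m` and the second is replicated, the propagated output sharding keeps `m` sharded. A hedged end-to-end illustration through the public DTensor API (`matmul` follows this einsum pattern); the 4-GPU mesh is an assumption:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh("cuda", (4,))
a = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])      # "mk", sharded on m
b = distribute_tensor(torch.randn(16, 32), mesh, [Replicate()])  # "kn", replicated
c = torch.matmul(a, b)                                           # "mn"
print(c.placements)  # expected: (Shard(dim=0),)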

View File

@ -1,14 +1,13 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
data: torch.Tensor | None = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0

View File

@ -4,7 +4,7 @@ import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import cast, Optional, Union
from typing import cast, Union
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -47,7 +47,7 @@ class Reduction(Enum):
@dataclass(frozen=True)
class NormReduction:
norm_type: Union[int, float, str]
norm_type: int | float | str
ReductionOpType = Union[NormReduction, str]
@ -71,9 +71,9 @@ class _NormPartial(Partial):
similarly for inf and -inf norm. For 0-norm, the reduction op is sum.
"""
norm_type: Union[int, float, str] = 2
norm_type: int | float | str = 2
def __init__(self, norm_type: Union[int, float, str] = 2):
def __init__(self, norm_type: int | float | str = 2):
reduce_op = None
if norm_type in (float("inf"), "inf"):
reduce_op = "max"
@ -164,7 +164,7 @@ class _NormPartial(Partial):
return 1 + hash(self.norm_type)
def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]:
def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None:
if dims_arg is None:
return None
dims = cast(list[int], as_list(dims_arg))
@ -1081,7 +1081,7 @@ def _common_norm_backward_strategy(
out_tuple_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
# args for OpSpec
output_specs_list: list[Optional[DTensorSpec]] = []
output_specs_list: list[DTensorSpec | None] = []
input_specs_list: list[DTensorSpec] = []
redistribute_costs = []

View File

@ -2,8 +2,6 @@
# implement matrix related ops for distributed tensor
from typing import Optional
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
@ -708,7 +706,7 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate
) = op_schema.args_schema
return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
has_attn_bias = attn_bias_strategy is not None
debug_attn_mask_sharding: Optional[Placement] = (
debug_attn_mask_sharding: Placement | None = (
Replicate() if return_debug_mask else None
)
@ -1073,7 +1071,7 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
)
def valid_grouped_mm_strides(
input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...]
input_specs: list[DTensorSpec], output_specs: tuple[DTensorSpec | None, ...]
) -> bool:
# 1. compute the local-tensor shape/strides given this sharding proposal
# 2. apply the logic from the grouped_mm meta function

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence
from typing import cast, Optional
from typing import cast
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@ -493,7 +493,7 @@ def common_pointwise_strategy(
followed_strategy: OpStrategy,
followed_strategy_index: int,
linearity: int = -1,
scalar_tensor_idx: Optional[int] = None,
scalar_tensor_idx: int | None = None,
) -> OpStrategy:
"""
Common strategy for pointwise operations.
@ -713,11 +713,11 @@ def list_pointwise_strategy(
def args_tuple_strategies(
args_schema: tuple[object, ...],
) -> list[Optional[TupleStrategy]]:
) -> list[TupleStrategy | None]:
first_arg = args_schema[0]
assert isinstance(first_arg, TupleStrategy)
strategy_len = len(first_arg.children)
tuple_strategies: list[Optional[TupleStrategy]] = []
tuple_strategies: list[TupleStrategy | None] = []
for arg_idx, arg in enumerate(args_schema):
if isinstance(arg, TupleStrategy):
# every tuple strategy should have the same length
@ -743,7 +743,7 @@ def list_pointwise_strategy(
for child_idx, child_strtgy in enumerate(follow_strategy.children):
assert isinstance(child_strtgy, OpStrategy)
args_schema: list[Optional[OpStrategy]] = [
args_schema: list[OpStrategy | None] = [
cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None
for arg_strategy in args_strategies
]

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence, Sized
from typing import cast, Optional
from typing import cast
import torch
from torch._prims_common import IntLike
@ -721,7 +721,7 @@ def _derive_follow_placements_from_tuple_strategy(
# current replicate, just follow new placement
return new_placement
follow_placements: Optional[list[Placement]] = None
follow_placements: list[Placement] | None = None
mesh = tuple_strategy.child_mesh(0)
for arg_strategy in tuple_strategy.children:
if not isinstance(arg_strategy, OpStrategy):
@ -887,7 +887,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding:
if not isinstance(indices_spec, DTensorSpec):
raise AssertionError(f"Expected DTensorSpec, got {type(indices_spec)}")
all_indices_spec: list[Optional[DTensorSpec]] = [
all_indices_spec: list[DTensorSpec | None] = [
indices_spec if dim == i else None for i in range(values_spec.ndim)
]
@ -934,7 +934,7 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType:
op_strategy = OpStrategy([])
# 1. `indices` should all be replicated first.
indices_redistribute_costs = []
new_indices_spec: list[Optional[DTensorSpec]] = []
new_indices_spec: list[DTensorSpec | None] = []
for indices_spec_child in indices_spec.children:
if not isinstance(indices_spec_child, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(indices_spec_child)}")
@ -1044,7 +1044,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}")
if not isinstance(multi_indices_spec, list):
raise AssertionError(f"Expected list, got {type(multi_indices_spec)}")
multi_indices_spec = cast(list[Optional[DTensorSpec]], multi_indices_spec)
multi_indices_spec = cast(list[DTensorSpec | None], multi_indices_spec)
valid_indices_spec: list[tuple[int, DTensorSpec]] = [
(i, a) for i, a in enumerate(multi_indices_spec) if a is not None
]

View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import cast, Optional, Union
from typing import cast
import torch
from torch import Tensor
@ -216,7 +216,7 @@ def expand(input_shape: Shape, shape: Shape) -> DimMap:
return tuple(mapping)
def normalize_sizes(sizes: Union[Shape, tuple[Shape]]) -> Shape:
def normalize_sizes(sizes: Shape | tuple[Shape]) -> Shape:
if isinstance(sizes[0], int):
return cast(Shape, sizes)
elif len(sizes) == 1:
@ -428,7 +428,7 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
return tuple(dimmap)
def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap:
def dim_squeeze(shape: Shape, dim: int | None = None) -> DimMap:
# FIXME: this is wrong when dim=None and one of the dimensions
# equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could
# end up as squeeze(tensor(1)) if we have 4 devices; this would lead to
@ -457,7 +457,7 @@ def dim_view_as_real(shape: Shape) -> DimMap:
return tuple(results)
def dim_reduction(ndim: int, dim_or_dims: Optional[DimsType], keepdim: bool) -> DimMap:
def dim_reduction(ndim: int, dim_or_dims: DimsType | None, keepdim: bool) -> DimMap:
"""
General fallback for reduction ops where Partial() does not apply.
@ -542,7 +542,7 @@ def propagate_shape_and_sharding(
def maybe_get_shard_mesh_dim_and_placement(
input_dim: InputDim,
) -> tuple[Optional[int], Optional[Shard]]:
) -> tuple[int | None, Shard | None]:
# if input_dim is sharded, return the mesh_dim and shard placement
for i, placement in enumerate(input_src_placements):
if isinstance(placement, Shard) and placement.dim == input_dim.input_dim:
@ -556,7 +556,7 @@ def propagate_shape_and_sharding(
# 1 and 2 don't require the info of whether the current input is sharded.
# 3 requires that info, to decide whether we can error out. Maybe we can refactor
# to make this function purely "theoretical".
def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
def get_in_dim_to_shard(cmd: DimSpec) -> InputDim | None:
if isinstance(cmd, InputDim):
return cmd
elif isinstance(cmd, Flatten):
@ -692,7 +692,7 @@ def propagate_shape_and_sharding(
def register_op_strategy_map(
aten_op_overload: torch._ops.OpOverload,
local_op_name: Callable[..., torch.Tensor],
schema_info: Optional[RuntimeSchemaInfo] = None,
schema_info: RuntimeSchemaInfo | None = None,
strict_view: bool = False,
) -> None:
"""

View File

@ -4,7 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing import cast, TypeVar
from typing_extensions import ParamSpec
import torch
@ -36,8 +36,8 @@ _P = ParamSpec("_P")
# convenient wrapper to register sharding propagation rules
def register_prop_rule(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
op: torch._ops.OpOverload | list[torch._ops.OpOverload],
schema_info: RuntimeSchemaInfo | None = None,
) -> Callable[
[Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding]
]:
@ -127,9 +127,9 @@ def replicate_op_strategy(op_schema: OpSchema) -> StrategyType:
def as_list(
x: Union[list[object], object],
x: list[object] | object,
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
) -> list[object] | torch.fx.immutable_collections.immutable_list: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.
@ -295,9 +295,10 @@ def expand_to_full_mesh_op_strategy(
*,
input_index: int = 1,
inplace_op: bool = False,
is_valid_strategy_cb: Optional[
Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool]
] = None,
is_valid_strategy_cb: Callable[
[list[DTensorSpec], tuple[DTensorSpec | None, ...]], bool
]
| None = None,
) -> OpStrategy:
"""
Convenience function to allow writing a sharding strategy considering only a single mesh dimension,
@ -332,7 +333,7 @@ def expand_to_full_mesh_op_strategy(
all_strategies = []
for strategy_comb in strategy_combs:
spec_list: list[Optional[DTensorSpec]] = []
spec_list: list[DTensorSpec | None] = []
for specs in zip(*strategy_comb):
if specs[0] is not None:
# TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
@ -354,7 +355,7 @@ def expand_to_full_mesh_op_strategy(
# input_spec matches the first argument's runtime sharding, otherwise we skip
continue
output_specs: tuple[Optional[DTensorSpec], ...]
output_specs: tuple[DTensorSpec | None, ...]
if input_index > 1:
output_specs = tuple(spec_list[:input_index])
else:

View File

@ -3,7 +3,7 @@
import contextlib
import warnings
from logging import getLogger
from typing import Optional, Union
from typing import Optional
import torch
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
@ -171,7 +171,7 @@ class _RNGStateTracker:
self._use_distribute_region = value
def _distribute_region(
self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
self, spec: DTensorSpec, generator: torch.Generator | None = None
):
pass
@ -237,7 +237,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
@contextlib.contextmanager
def _distribute_region(
self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
self, spec: DTensorSpec, generator: torch.Generator | None = None
):
if generator is not None:
# This is a little hacky, but for any user-passed generator, we store its state under a unique key,
@ -327,7 +327,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
mesh = spec.mesh
# note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
# case. Replace the custom logic with dim_map once we support it.
dim_map: list[Union[int, list[int]]] = [-1] * spec.ndim
dim_map: list[int | list[int]] = [-1] * spec.ndim
for i, placement in enumerate(spec.placements):
if isinstance(placement, Shard):
shard_dim = placement.dim

View File

@ -8,7 +8,7 @@ import weakref
from collections import defaultdict
from collections.abc import Sequence
from functools import cache
from typing import cast, NamedTuple, Optional
from typing import cast, NamedTuple
import torch
import torch.distributed._functional_collectives as funcol
@ -88,7 +88,7 @@ class DTensorRedistributePlanner:
class DistState:
placements: tuple[Placement, ...]
tensor_dim_to_mesh_dim: ShardOrder
_hash: Optional[int] = dataclasses.field(
_hash: int | None = dataclasses.field(
default=None, init=False, repr=False, compare=False
)
@ -161,7 +161,7 @@ class DTensorRedistributePlanner:
mesh: DeviceMesh,
transform_infos: Sequence[_TransformInfo],
src_placement: tuple[Placement, ...],
src_shard_order: Optional[ShardOrder] = None,
src_shard_order: ShardOrder | None = None,
) -> str:
"""
Generate a string representation of the sequence of state transitions
@ -646,7 +646,7 @@ class DTensorRedistributePlanner:
def _gen_transform_infos_non_cached(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
use_graph_based_transform: Optional[bool] = None,
use_graph_based_transform: bool | None = None,
) -> list[_TransformInfo]:
transform_infos: list[_TransformInfo] = []
device_mesh = src_spec.device_mesh
@ -678,7 +678,7 @@ def _gen_transform_infos_non_cached(
def _gen_transform_infos(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
use_graph_based_transform: Optional[bool] = None,
use_graph_based_transform: bool | None = None,
) -> list[_TransformInfo]:
return _gen_transform_infos_non_cached(
src_spec, dst_spec, use_graph_based_transform
@ -692,7 +692,7 @@ def redistribute_local_tensor(
*,
async_op: bool = False,
is_backward: bool = False,
use_graph_based_transform: Optional[bool] = None,
use_graph_based_transform: bool | None = None,
) -> torch.Tensor:
"""
This redistributes the local tensor (torch.Tensor) from the current DTensorSpec to
@ -846,8 +846,8 @@ class Redistribute(torch.autograd.Function):
device_mesh: DeviceMesh,
placements: tuple[Placement, ...],
async_op: bool = False,
forward_dtype: Optional[torch.dtype] = None,
backward_dtype: Optional[torch.dtype] = None,
forward_dtype: torch.dtype | None = None,
backward_dtype: torch.dtype | None = None,
):
ctx.async_op = async_op
ctx.backward_dtype = backward_dtype

View File

@ -4,7 +4,7 @@ import threading
from collections.abc import Callable, Sequence
from functools import lru_cache
from itertools import chain
from typing import cast, Optional, Union
from typing import cast
import torch
from torch._guards import detect_fake_mode
@ -69,9 +69,7 @@ class ShardingPropagator:
)
# op map to save indices of shape (and stride) args which may need to be
# modified in sharding prop
self.op_to_shape_and_stride_idx: dict[
OpOverload, Union[int, tuple[int, int]]
] = {
self.op_to_shape_and_stride_idx: dict[OpOverload, int | tuple[int, int]] = {
# new factory ops
aten.new_empty.default: 1,
aten.new_full.default: 1,
@ -91,7 +89,7 @@ class ShardingPropagator:
self,
op_overload: OpOverload,
rule_func: Callable[[OpSchema], OutputSharding],
schema_info: Optional[RuntimeSchemaInfo] = None,
schema_info: RuntimeSchemaInfo | None = None,
):
"""
Register a sharding propagation rule for an operator.
@ -104,7 +102,7 @@ class ShardingPropagator:
self,
op_overload: OpOverload,
strategy_func: Callable[[OpSchema], StrategyType],
schema_info: Optional[RuntimeSchemaInfo] = None,
schema_info: RuntimeSchemaInfo | None = None,
):
"""
Register a :class:`OpStrategy` generator for an operator.
@ -157,7 +155,7 @@ class ShardingPropagator:
def _propagate_tensor_meta_non_cached(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
) -> None | TensorMeta | Sequence[TensorMeta | None]:
"""
Propagate the tensor metadata; it could either return a TensorMeta
or a list/tuple of TensorMetas
@ -191,7 +189,7 @@ class ShardingPropagator:
)
elif isinstance(fake_out, (tuple, list)):
tensor_meta_list: list[Optional[TensorMeta]] = []
tensor_meta_list: list[TensorMeta | None] = []
for fake_out_item in fake_out:
if isinstance(fake_out_item, torch.Tensor):
tensor_meta_list.append(
@ -215,7 +213,7 @@ class ShardingPropagator:
@lru_cache # noqa: B019
def _propagate_tensor_meta(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
) -> None | TensorMeta | Sequence[TensorMeta | None]:
"""
Cached version of _propagate_tensor_meta_non_cached
This is a private API. Use propagate_tensor_meta instead.
@ -224,7 +222,7 @@ class ShardingPropagator:
def propagate_tensor_meta(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
) -> None | TensorMeta | Sequence[TensorMeta | None]:
"""
Propagate the tensor metadata; it could either return a TensorMeta
or a list/tuple of TensorMetas. This is a public API that should be
@ -239,7 +237,7 @@ class ShardingPropagator:
self,
op: OpOverload,
output_specs: OutputSpecType,
output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
output_tensor_meta: None | TensorMeta | Sequence[TensorMeta | None],
) -> OutputSpecType:
"""
Wrap the output_specs with the tensor metadata from the output.
@ -260,7 +258,7 @@ class ShardingPropagator:
)
return output_specs.shallow_copy_with_tensor_meta(output_tensor_meta)
elif isinstance(output_specs, (tuple, list)):
new_specs: list[Optional[DTensorSpec]] = []
new_specs: list[DTensorSpec | None] = []
if not isinstance(output_tensor_meta, (tuple, list)) or len(
output_specs
) != len(output_tensor_meta):
@ -587,7 +585,7 @@ class ShardingPropagator:
)
def _select_strategy(
self, strategy: OpStrategy, op_schema: Optional[OpSchema] = None
self, strategy: OpStrategy, op_schema: OpSchema | None = None
) -> OpSpec:
if len(strategy.strategies) == 1:
# short cut with only one possible OpSpec

View File

@ -1,7 +1,7 @@
import threading
from collections import defaultdict
from collections.abc import Sequence
from typing import cast, Optional
from typing import cast
import torch
import torch.distributed._functional_collectives as funcol
@ -159,7 +159,7 @@ def compute_local_shape_and_global_offset(
def _compute_local_shape_and_global_offset(
global_shape: ShapeType,
mesh_shape: ShapeType,
my_coordinate: Optional[list[int]],
my_coordinate: list[int] | None,
placements: Sequence[Placement],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""

View File

@ -5,7 +5,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.p
import argparse
import os
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
@ -55,7 +55,7 @@ class CommDebugModeExample:
self.device_type = get_device_type()
def _MLP_model_setup(
self, model_type: type, parallelize_plan: Union[None, dict] = None
self, model_type: type, parallelize_plan: None | dict = None
) -> tuple[nn.Module, torch.Tensor]:
"""
Creates MLP or MLPStacked model for examples

View File

@ -5,7 +5,6 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 flex_attention_cp.py
import os
from functools import lru_cache
from typing import Optional
import torch
import torch.distributed as dist
@ -27,8 +26,8 @@ def get_device_type() -> str:
@lru_cache
def create_block_mask_cached(
score_mod: _mask_mod_signature,
B: Optional[int],
H: Optional[int],
B: int | None,
H: int | None,
M: int,
N: int,
device: str = "cuda",

View File

@ -7,7 +7,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from enum import auto, Enum
from functools import partial
from typing import Any, cast, Optional, Protocol, TypeAlias
from typing import Any, cast, Protocol, TypeAlias
import torch
import torch.distributed as dist
@ -140,8 +140,8 @@ class _SDPAMerger:
def __init__(self, convert_to_f32: bool, seq_dim: int):
self._seq_dim = seq_dim
self._out: Optional[torch.Tensor] = None
self._lse: Optional[torch.Tensor] = None
self._out: torch.Tensor | None = None
self._lse: torch.Tensor | None = None
self._should_lse_squeeze = False
self._convert_to_f32 = convert_to_f32
self._out_dtype = torch.float32
@ -250,7 +250,7 @@ class _AllToAllRotater(_RingRotater):
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
self._pg = pg
self._seq_dim = seq_dim
self._buffer: Optional[torch.Tensor] = None
self._buffer: torch.Tensor | None = None
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
curr_buffer = curr_buffer.contiguous()
@ -272,7 +272,7 @@ class _AllGatherRotater(_RingRotater):
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
self._pg = pg
self._seq_dim = seq_dim
self._aggregated_buffer: Optional[torch.Tensor] = None
self._aggregated_buffer: torch.Tensor | None = None
self._idx = 0
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
@ -293,7 +293,7 @@ class _AllGatherRotater(_RingRotater):
def _create_rotater(
pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None
pg: dist.ProcessGroup, seq_dim: int, method: _RotateMethod | None = None
) -> _RingRotater:
if method is None:
method = _cp_options.rotate_method
@ -655,7 +655,7 @@ def _scaled_dot_product_ring_flash_attention(
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
scale: float | None = None,
) -> tuple[torch.Tensor, ...]:
if return_debug_mask:
raise NotImplementedError("return_debug_mask is not supported yet")
@ -681,12 +681,12 @@ def _scaled_dot_product_ring_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
attn_bias: torch.Tensor | None = None,
compute_log_sumexp: bool = True,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
scale: float | None = None,
) -> tuple[torch.Tensor, ...]:
if attn_bias is not None:
raise NotImplementedError("attn_bias is not supported yet")
@ -718,13 +718,13 @@ def _scaled_dot_product_ring_cudnn_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
attn_bias: torch.Tensor | None = None,
compute_log_sumexp: bool = True,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
scale: float | None = None,
) -> tuple[torch.Tensor, ...]:
if attn_bias is not None:
raise NotImplementedError("attn_bias is not supported yet")
@ -769,7 +769,7 @@ def _scaled_dot_product_ring_flash_attention_backward(
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: Optional[float] = None,
scale: float | None = None,
) -> tuple[torch.Tensor, ...]:
# TODO: remove this hardcoding
seq_dim = 2
@ -812,7 +812,7 @@ def _scaled_dot_product_ring_efficient_attention_backward(
grad_input_mask: tuple[bool, ...],
is_causal: bool = False,
*,
scale: Optional[float] = None,
scale: float | None = None,
) -> tuple[torch.Tensor, ...]:
# TODO: remove this hardcoding
seq_dim = 2
@ -856,7 +856,7 @@ def _scaled_dot_product_ring_cudnn_attention_backward(
dropout_p: float,
is_causal: bool,
*,
scale: Optional[float] = None,
scale: float | None = None,
) -> tuple[torch.Tensor, ...]:
# TODO: remove this hardcoding
seq_dim = 2
@ -938,8 +938,8 @@ exitsing_custom_ops = DTensor._op_dispatcher._custom_op_handlers
ArgsType = tuple[Any, ...]
KwargsType = dict[str, Any]
InputFnType = Callable[[Optional[nn.Module], ArgsType, KwargsType, DeviceMesh], Any]
OutputFnType = Callable[[Optional[nn.Module], Any, Any, DeviceMesh], Any]
InputFnType = Callable[[nn.Module | None, ArgsType, KwargsType, DeviceMesh], Any]
OutputFnType = Callable[[nn.Module | None, Any, Any, DeviceMesh], Any]
_replaced_functions: dict[Callable, tuple[str, Callable]] = {}
@ -1039,7 +1039,7 @@ def _context_parallel_buffers(
mesh: DeviceMesh,
buffers: list[torch.Tensor | BlockMask],
buffer_seq_dims: list[int],
load_balancer: Optional[_LoadBalancer] = None,
load_balancer: _LoadBalancer | None = None,
) -> list[torch.Tensor | BlockMask]:
"""
Shard the buffers along the sequence dimensions according to CP rules.
@ -1136,7 +1136,7 @@ def _create_cp_block_mask(
Q_LEN: int,
KV_LEN: int,
device_mesh: DeviceMesh,
load_balancer: Optional[_LoadBalancer] = None,
load_balancer: _LoadBalancer | None = None,
) -> BlockMask:
"""
Creates a specialized BlockMask for Context Parallel FlexAttention.
@ -1197,7 +1197,7 @@ def _create_cp_block_mask(
rank: int,
block_size: int,
local_q_size: int,
qkv_rearrange_indices: Optional[torch.Tensor] = None,
qkv_rearrange_indices: torch.Tensor | None = None,
) -> _mask_mod_signature:
assert qkv_rearrange_indices is None or qkv_rearrange_indices.ndim == 2, (
"load balance index expects shape (1, seq_len) or (B, seq_len) "
@ -1301,7 +1301,7 @@ class _ContextParallel(ParallelStyle):
raise ValueError(f"Unknown attention type: {self.attention_type}")
def flex_input_fn(
self, module: Optional[nn.Module], args: Any, kwargs: Any, mesh: DeviceMesh
self, module: nn.Module | None, args: Any, kwargs: Any, mesh: DeviceMesh
) -> Any:
args_list = list(args)
for idx, name in enumerate(
@ -1329,7 +1329,7 @@ class _ContextParallel(ParallelStyle):
def sdpa_input_fn(
self,
module: Optional[nn.Module],
module: nn.Module | None,
args: tuple[Any, ...],
kwargs: dict[str, Any],
mesh: DeviceMesh,
@ -1351,7 +1351,7 @@ class _ContextParallel(ParallelStyle):
return new_args, new_kwargs
def sdpa_output_fn(
self, module: Optional[nn.Module], inputs: Any, outputs: Any, mesh: DeviceMesh
self, module: nn.Module | None, inputs: Any, outputs: Any, mesh: DeviceMesh
) -> Any:
new_outputs = []
for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
@ -1373,7 +1373,7 @@ def _context_parallel_shard(
mesh: DeviceMesh,
buffers: CPBufferContainer,
seq_dims: CPBufferSeqDims,
load_balancer: Optional[_LoadBalancer] = None,
load_balancer: _LoadBalancer | None = None,
) -> list[torch.Tensor | BlockMask]:
"""
Shard the buffers along the specified sequence dimensions (`seq_dims`), so that each
@ -1464,9 +1464,9 @@ def _disable_context_parallel_dispatcher() -> None:
def context_parallel(
mesh: DeviceMesh,
*,
buffers: Optional[list[torch.Tensor]] = None,
buffer_seq_dims: Optional[list[int]] = None,
no_restore_buffers: Optional[set[torch.Tensor]] = None,
buffers: list[torch.Tensor] | None = None,
buffer_seq_dims: list[int] | None = None,
no_restore_buffers: set[torch.Tensor] | None = None,
) -> Generator[None, None, None]:
"""
@ -1554,7 +1554,7 @@ def context_parallel_unshard(
mesh: DeviceMesh,
buffers: list[torch.Tensor],
seq_dims: list[int],
load_balancer: Optional[_LoadBalancer] = None,
load_balancer: _LoadBalancer | None = None,
) -> list[torch.Tensor]:
"""
Unshard the tensors (e.g., output) that are sharded due to context parallelism.
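
For orientation, the `context_parallel` context manager and `context_parallel_unshard` whose signatures appear above are the user-facing entry points of this experimental API. A heavily hedged usage sketch; the import paths, the 2-rank mesh, the tensor shapes, and an already-initialized process group (e.g. via torchrun) are assumptions for illustration, not taken from this diff:

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard

mesh = init_device_mesh("cuda", (2,))  # assumed: 2 ranks, 1D mesh

# (batch, heads, seq, head_dim); seq is dim 2, matching buffer_seq_dims below.
q = torch.randn(4, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
    # Inside the context each rank holds only its sequence shard of q/k/v,
    # and SDPA is expected to route to the ring variants defined above.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Recover the full-sequence output on every rank.
(out,) = context_parallel_unshard(mesh, [out], [2])
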

View File

@ -2,7 +2,6 @@
# for different load-balancing strategies in tensor sharding.
import functools
from abc import ABC, abstractmethod
from typing import Optional
import torch
from torch import Tensor
@ -12,7 +11,7 @@ from torch.nn.attention.flex_attention import BlockMask
# make it private since it's still a prototype
class _LoadBalancer(ABC):
@abstractmethod
def _generate_indices(self, restore: bool = False) -> Optional[Tensor]:
def _generate_indices(self, restore: bool = False) -> Tensor | None:
"""
Generate indices for load balancing.
Args:
@ -478,7 +477,7 @@ class _PTRRLoadBalancer(_LoadBalancer):
def _create_default_load_balancer(
seq_length: int, world_size: int, device: str | torch.device
) -> Optional[_LoadBalancer]:
) -> _LoadBalancer | None:
from ._attention import _cp_options
if _cp_options.enable_load_balance:

View File

@ -24,11 +24,11 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
def local_map(
func: Optional[Callable] = None,
func: Callable | None = None,
out_placements: OutputPlacements = None,
in_placements: InputPlacements = None,
in_grad_placements: InputPlacements = None,
device_mesh: Optional[DeviceMesh] = None,
device_mesh: DeviceMesh | None = None,
*,
redistribute_inputs: bool = False,
):
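
`local_map` wraps a function written against plain torch.Tensor so that it can be applied to DTensor arguments; the placement arguments describe how locals map to and from the distributed tensors. A sketch under the usual assumptions (initialized process group, 1D mesh); the wrapped function and the placements are invented for illustration:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard
from torch.distributed.tensor.experimental import local_map

mesh = init_device_mesh("cuda", (2,))

def local_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Sees only the local shards; DTensor un/re-wrapping is handled by local_map.
    return x * y

dist_mul = local_map(
    local_mul,
    out_placements=[Shard(0)],               # result stays sharded on dim 0
    in_placements=([Shard(0)], [Shard(0)]),  # both inputs expected sharded on dim 0
    device_mesh=mesh,
)

a = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
c = dist_mul(a, b)  # a DTensor with Shard(0) placement
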
@ -163,7 +163,7 @@ def _local_map_wrapped(
out_placements: OutputPlacements,
in_placements: InputPlacements,
in_grad_placements: InputPlacements,
device_mesh: Optional[DeviceMesh],
device_mesh: DeviceMesh | None,
redistribute_inputs: bool,
*args,
**kwargs,

View File

@ -2,7 +2,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Callable, Sequence
from functools import partial
from typing import Union
import torch
from torch._ops import OpOverload
@ -21,7 +20,7 @@ from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
__all__ = ["register_sharding"]
def register_sharding(op: Union[OpOverload, list[OpOverload]]):
def register_sharding(op: OpOverload | list[OpOverload]):
"""
:meth:`register_sharding` is an experimental API that allows users to register sharding
strategies for an operator when the tensor inputs and outputs are DTensor.
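
Since the diff only shows the signature, note that `register_sharding` accepts either a single OpOverload or a list of them (hence the union above) plus a callback enumerating acceptable placement combinations. The sketch below is an assumption-laden illustration: the callback arguments and the returned list of (output placements, input placements) pairs follow the pattern this API is usually shown with, but should be checked against the current docstring before use:

import torch
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import register_sharding

aten = torch.ops.aten

@register_sharding(aten.softmax.default)  # a list of overloads is also accepted
def custom_softmax_sharding(x, dim, half_to_float):
    softmax_dim = dim if dim >= 0 else dim + x.ndim
    # Each entry pairs output placements with per-input placements
    # (None for non-tensor inputs such as dim and half_to_float).
    acceptable = [([Replicate()], [Replicate(), None, None])]
    for d in range(x.ndim):
        if d != softmax_dim:  # any sharding that does not split the softmax dim
            acceptable.append(([Shard(d)], [Shard(d), None, None]))
    return acceptable
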

View File

@ -2,7 +2,7 @@
import copy
import operator
from collections.abc import Sequence
from typing import Any, cast, Optional
from typing import Any, cast
import torch
from torch._subclasses.fake_tensor import FakeTensor
@ -273,7 +273,7 @@ def _create_placement_strategy(
node: Node,
mesh: DeviceMesh,
placements: tuple[Placement, ...],
input_specs: Optional[Sequence[DTensorSpec]] = None,
input_specs: Sequence[DTensorSpec] | None = None,
) -> OpSpec:
"""
Util function to construct an OpSpec for a given node.

View File

@ -1,5 +1,5 @@
from functools import partial
from typing import no_type_check, Optional
from typing import no_type_check
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
@ -21,7 +21,7 @@ def sync_grad_hook(grad, *, device_handle=None, compute_stream=None):
def _flatten_tensor(
tensor: torch.Tensor,
) -> tuple[torch.Tensor, Optional[DTensorSpec]]:
) -> tuple[torch.Tensor, DTensorSpec | None]:
if isinstance(tensor, DTensor):
tensor._local_tensor.requires_grad_()
return tensor._local_tensor, tensor._spec

View File

@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Optional, Union
import torch
import torch.nn as nn
@ -14,10 +13,10 @@ __all__ = ["parallelize_module"]
def parallelize_module( # type: ignore[return]
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
device_mesh: DeviceMesh | None = None,
parallelize_plan: ParallelStyle | dict[str, ParallelStyle] | None = None,
*,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> nn.Module:
"""
Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
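
For reference, the optional `device_mesh` and `parallelize_plan` arguments above are typically filled in as follows; the module structure, mesh size, and single-process setup are assumptions for illustration (a dict plan is keyed by fully-qualified module name, and the fnmatch import visible in the hunk above indicates wildcard patterns are accepted):

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

tp_mesh = init_device_mesh("cuda", (2,))  # assumed 2-way tensor parallelism

class FeedForward(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden)
        self.w2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))

model = FeedForward().cuda()
model = parallelize_module(
    model,
    device_mesh=tp_mesh,
    parallelize_plan={"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)
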

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Optional
from typing import Any
import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (
@ -48,7 +48,7 @@ def _reconstruct_dtensor(module: nn.Module, _input: Any):
def _localize_dtensor(
module: nn.Module, *_: Any, ignored_params: Optional[set[nn.Parameter]] = None
module: nn.Module, *_: Any, ignored_params: set[nn.Parameter] | None = None
):
"""
Convert DTensor parameters to local tensors

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import copy
from typing import Any, cast, Optional
from typing import Any, cast
import torch
import torch.distributed as dist
@ -297,7 +297,7 @@ def _pre_load_state_dict(
def _all_gather_dtensor(
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
parent_mesh: DeviceMesh | None,
) -> torch.Tensor:
"""All gather a DTensor in its FSDP dimension and return the local tensor."""
assert parent_mesh == tensor.device_mesh
@ -336,7 +336,7 @@ class DTensorExtensions(FSDPExtensions):
def pre_flatten_transform(
self,
tensor: torch.Tensor,
) -> tuple[torch.Tensor, Optional[Any]]:
) -> tuple[torch.Tensor, Any | None]:
return _flatten_tensor(tensor)
def post_unflatten_transform(
@ -365,7 +365,7 @@ class DTensorExtensions(FSDPExtensions):
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
device: Optional[torch.device] = None,
device: torch.device | None = None,
) -> torch.Tensor:
return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)
@ -386,6 +386,6 @@ class DTensorExtensions(FSDPExtensions):
def all_gather_dtensor(
self,
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
parent_mesh: DeviceMesh | None,
) -> torch.Tensor:
return _all_gather_dtensor(tensor, parent_mesh)

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional
from typing import Any
import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
@ -14,7 +14,7 @@ __all__ = [
def input_reshard(
module: torch.nn.Module,
tp_device_mesh: DeviceMesh,
input_reshard_dim: Optional[int] = None,
input_reshard_dim: int | None = None,
) -> torch.nn.Module:
"""
Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation.
@ -42,7 +42,7 @@ def input_reshard(
if input_reshard_dim is None:
return module
cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None
cx: torch.autograd.graph.saved_tensors_hooks | None = None
def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: tuple[Any, ...]) -> None:
saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Optional
from typing import cast
import torch
import torch._prims_common as utils
@ -201,8 +201,8 @@ def _log_softmax_backward_handler(
def _nll_loss_forward(
x: Tensor,
target: Tensor,
weight: Optional[Tensor],
local_weight: Optional[Tensor],
weight: Tensor | None,
local_weight: Tensor | None,
reduction: int,
ignore_index: int,
input_shape: torch.Size,
@ -356,7 +356,7 @@ def _nll_loss_and_log_softmax_backward(
grad_output: Tensor,
x: Tensor,
target: Tensor,
weight: Optional[Tensor],
weight: Tensor | None,
reduction: int,
ignore_index: int,
total_weight: Tensor,

View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Optional, Union
from typing import Any
import torch
import torch.nn as nn
@ -36,7 +36,7 @@ class ParallelStyle(ABC):
flexibility for different kind of style implementations.
"""
src_data_rank: Optional[int] = 0
src_data_rank: int | None = 0
@abstractmethod
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...
@ -82,8 +82,8 @@ class ColwiseParallel(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
):
super().__init__()
@ -212,8 +212,8 @@ class RowwiseParallel(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
):
super().__init__()
@ -473,14 +473,10 @@ class PrepareModuleInput(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[
Union[Placement, tuple[Optional[Placement], ...]]
] = None,
desired_input_layouts: Optional[
Union[Placement, tuple[Optional[Placement], ...]]
] = None,
input_kwarg_layouts: Optional[dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
input_layouts: Placement | tuple[Placement | None, ...] | None = None,
desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None,
input_kwarg_layouts: dict[str, Placement] | None = None,
desired_input_kwarg_layouts: dict[str, Placement] | None = None,
use_local_output: bool = False,
):
self.input_layouts = (
@ -513,8 +509,8 @@ class PrepareModuleInput(ParallelStyle):
self,
input: Any,
mesh: DeviceMesh,
input_layout: Optional[Placement],
desired_layout: Optional[Placement],
input_layout: Placement | None,
desired_layout: Placement | None,
):
if input_layout is not None:
if isinstance(input, DTensor):
@ -637,8 +633,8 @@ class PrepareModuleOutput(ParallelStyle):
def __init__(
self,
*,
output_layouts: Union[Placement, tuple[Optional[Placement], ...]],
desired_output_layouts: Union[Placement, tuple[Placement, ...]],
output_layouts: Placement | tuple[Placement | None, ...],
desired_output_layouts: Placement | tuple[Placement, ...],
use_local_output: bool = True,
):
self.output_layouts = (
@ -768,17 +764,13 @@ class PrepareModuleInputOutput(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[
Union[Placement, tuple[Optional[Placement], ...]]
] = None,
desired_input_layouts: Optional[
Union[Placement, tuple[Optional[Placement], ...]]
] = None,
input_kwarg_layouts: Optional[dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
input_layouts: Placement | tuple[Placement | None, ...] | None = None,
desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None,
input_kwarg_layouts: dict[str, Placement] | None = None,
desired_input_kwarg_layouts: dict[str, Placement] | None = None,
use_local_input: bool = False,
output_layouts: Union[Placement, tuple[Optional[Placement], ...]],
desired_output_layouts: Union[Placement, tuple[Placement, ...]],
output_layouts: Placement | tuple[Placement | None, ...],
desired_output_layouts: Placement | tuple[Placement, ...],
use_local_output: bool = True,
):
self.prepare_module_input = PrepareModuleInput(

View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass, field
from typing import cast, Optional
from typing import cast
import torch
import torch._C
@ -129,7 +129,7 @@ class Shard(torch._C._distributed.Shard):
curr_local_size: int,
num_chunks: int,
rank: int,
) -> tuple[int, Optional[int]]:
) -> tuple[int, int | None]:
return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)
@staticmethod
@ -151,7 +151,7 @@ class Shard(torch._C._distributed.Shard):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
"""
shard and scatter a tensor on a mesh dimension (use coordinate
@ -203,7 +203,7 @@ class Shard(torch._C._distributed.Shard):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
shard_placement = cls(dim)
return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank)
@ -566,7 +566,7 @@ class _StridedShard(torch._C._distributed.StridedShard, Shard):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
split_factor: int = 1,
) -> torch.Tensor:
strided_shard_placement = cls(dim=dim, split_factor=split_factor)
@ -689,7 +689,7 @@ class _StridedShard(torch._C._distributed.StridedShard, Shard):
curr_local_size: int,
num_chunks: int,
rank: int,
) -> tuple[int, Optional[int]]:
) -> tuple[int, int | None]:
# indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed
# so that we can reuse self._split_tensor which splits on self.dim
shape = [1] * self.dim + [curr_local_size]
@ -742,7 +742,7 @@ class Replicate(torch._C._distributed.Replicate):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
"""
Replicate (broadcast) a torch.Tensor on a mesh dimension (use
@ -765,7 +765,7 @@ class Replicate(torch._C._distributed.Replicate):
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
src_data_rank: Optional[int] = 0,
src_data_rank: int | None = 0,
) -> torch.Tensor:
return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank)
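
The `Shard` and `Replicate` placements whose internals are touched above are normally reached through the public DTensor API. A small sketch (mesh shape and tensor sizes are assumptions; requires an initialized process group):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (2, 2))  # assumed 2x2 device mesh
w = torch.randn(128, 64)

# Shard rows across the first mesh dimension, replicate across the second.
dw = distribute_tensor(w, mesh, [Shard(0), Replicate()])

local = dw.to_local()    # this rank's 64x64 shard
full = dw.full_tensor()  # all-gathered back to 128x64 on every rank
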
@ -859,7 +859,7 @@ class MaskPartial(Partial):
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_shape: torch.Size | None = None
offset_dim: int = 0
def __init__(

View File

@ -44,7 +44,7 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str,
def _cast_forward_inputs(
dtype: Optional[torch.dtype],
dtype: torch.dtype | None,
*args: Any,
**kwargs: Any,
) -> tuple[Any, Any]:
@ -257,7 +257,7 @@ def _apply_to_tensors(fn, container):
def _to_kwargs(
inputs: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
kwargs: dict[str, Any] | None,
target_device: torch.device,
use_side_stream_for_tensor_copies: bool,
) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: