# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import os
import threading
import warnings
from collections.abc import Iterator
from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union

import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int, suffix_product
from torch.utils._typing_utils import not_none


__all__ = ["init_device_mesh", "DeviceMesh"]


if not is_available():
    import sys

    # We need to create the stubs when distributed is not available.
    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
    # since it would try to import ``torch.distributed.device_mesh`` or
    # ``torch.distributed.init_device_mesh`` but cannot find them.

    class _DeviceMeshStub:
        pass

    def _init_device_mesh_stub():
        pass

    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
    sys.modules[
        "torch.distributed.device_mesh"
    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]


else:
    from torch._C._distributed_c10d import Backend as C10dBackend
    from torch.distributed.distributed_c10d import (
        _get_default_group,
        _resolve_process_group,
        get_backend,
        get_process_group_ranks,
        get_rank,
        get_world_size,
        init_process_group,
        is_initialized,
        new_group,
        ProcessGroup,
        split_group,
    )

    logger = logging.getLogger(__name__)

    # only import numpy typing when type checking
    if TYPE_CHECKING:
        try:
            from numpy.typing import ArrayLike
        except ImportError:
            logger.warning(
                "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
            )

    BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]]
    torch.serialization.add_safe_globals([_MeshLayout])
    class _MeshEnv(threading.local):
        def __init__(self) -> None:
            self.mesh_stack: list[DeviceMesh] = []

        def get_current_mesh(self) -> "DeviceMesh":
            if len(self.mesh_stack) == 0:
                raise RuntimeError("No device mesh is currently active!")
            return self.mesh_stack[-1]

        # TODO: to remove it once we move all use cases into new API.
        def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh":
            # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself.
            # A root mesh is not created through slicing.
            # We consider the root mesh of a root mesh to be itself.
            # We keep this function for backward compatibility.
            warnings.warn(
                "This get_root_mesh API will be deprecated soon. "
                "Please use `get_root_mesh` inside DeviceMesh instead."
            )
            if not device_mesh:
                return device_mesh
            return device_mesh._get_root_mesh()

        @staticmethod
        def num_devices_per_host(device_type: str) -> int:
            return _get_device_handle(device_type).device_count()

        @staticmethod
        def num_hosts(device_type: str) -> int:
            # ProcessGroup can't tell us this info so we have to infer it, assume
            # homogeneous hardware for now
            return get_world_size() // _MeshEnv.num_devices_per_host(device_type)

        # TODO: to remove it once we move all use cases into new API.
        # We keep this API for backward compatibility.
        def _get_all_submeshes(
            self, device_mesh: "DeviceMesh", mesh_dim_name: str
        ) -> list["DeviceMesh"]:
            warnings.warn(
                "This _get_all_submeshes API will be deprecated soon. "
                "Please use `_get_all_submeshes` inside DeviceMesh instead."
            )
            return device_mesh._get_all_submeshes(mesh_dim_name)

    _mesh_resources: _MeshEnv = _MeshEnv()
    def _get_device_handle(device_type: str = "cuda"):
        """
        Get the module corresponding to the device_type which is cuda or cuda-like device.
        For example, when the device_type is cuda, the module `torch.cuda` is returned.
        Return None when there is no corresponding module for device_type, otherwise
        return the corresponding module.
        """
        return getattr(torch, device_type, None)
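    # Usage sketch (illustrative): the handle is resolved by plain attribute lookup on
    # `torch`, so for example:
    #
    #   _get_device_handle("cuda") is torch.cuda   # True
    #   _get_device_handle("tpu") is None          # True, since there is no torch.tpu module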
    class DeviceMesh:
        """
        DeviceMesh represents a mesh of devices, where the layout of devices could be
        represented as an n-d array, and each value of the n-d array
        is the global id of the default process group ranks.

        DeviceMesh could be used to setup the N dimensional device connections across the cluster,
        and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on
        each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user has
        already selected (i.e. if the user calls `torch.cuda.set_device` before the DeviceMesh initialization),
        and will select/set the device for the current process if the user does not set the device
        beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.

        DeviceMesh can also be used as a context manager when used together with DTensor APIs.

        .. note::
            DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
            is running on all processes/ranks in the cluster. Therefore, users need to make sure the
            `mesh` array (which describes the layout of devices) is identical across all ranks.
            An inconsistent `mesh` will lead to a silent hang.

        Args:
            device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
            mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
                of devices, where the IDs are global IDs of the default process group.
            _rank (int): (experimental/internal)
                The global rank of the current process. If not provided, it will
                be inferred from the default process group.

        Returns:
            DeviceMesh: A :class:`DeviceMesh` object representing the device layout.

        The following program runs on each process/rank in an SPMD manner. In this example, we have 2
        hosts with 4 GPUs each.
        A reduction over the first dimension of mesh will reduce across
        columns (0, 4), .. and (3, 7), a reduction over the second dimension
        of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

        Example::

            >>> # xdoctest: +SKIP("no rank")
            >>> from torch.distributed.device_mesh import DeviceMesh
            >>>
            >>> # Initialize device mesh as (2, 4) to represent the topology
            >>> # of cross-host(dim 0), and within-host (dim 1).
            >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
        """

        _device_type: str
        _rank_map: torch.Tensor
        _mesh_dim_names: Optional[tuple[str, ...]]
        _layout: _MeshLayout
        _root_mesh: Optional["DeviceMesh"] = None
        # Record flatten mesh name to its flattened mesh in root mesh.
        _flatten_mapping: dict[str, "DeviceMesh"]
        def __init__(
            self,
            device_type: str,
            mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
            *,
            mesh_dim_names: Optional[tuple[str, ...]] = None,
            backend_override: Optional[tuple[BackendConfig, ...]] = None,
            _init_backend: bool = True,
            _rank: Optional[int] = None,
            _layout: Optional[_MeshLayout] = None,
            _rank_map: Optional[torch.Tensor] = None,
            _root_mesh: Optional["DeviceMesh"] = None,
        ) -> None:
            if mesh is not None:
                if _layout is not None or _rank_map is not None:
                    raise TypeError(
                        "Cannot provide _layout and/or _rank_map if passing explicit mesh"
                    )
                if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                    raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
                mesh_tensor = (
                    mesh.detach().to(dtype=torch.int).contiguous()
                    if isinstance(mesh, torch.Tensor)
                    else torch.tensor(mesh, device="cpu", dtype=torch.int)
                )
                _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
                _rank_map = mesh_tensor.flatten()
            else:
                if _layout is None or _rank_map is None:
                    raise TypeError(
                        "The mesh argument is required except for PRIVATE USAGE ONLY!"
                    )

            assert _layout.check_non_overlap(), (
                "Please use a non-overlapping layout when creating a DeviceMesh."
            )
            assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
            assert _rank_map.is_contiguous(), "The rank map must be contiguous"
            assert _rank_map.numel() >= _layout.cosize(), (
                f"The rank map contains {_rank_map.numel()} elements, "
                f"which isn't large enough for layout {_layout}"
            )

            self._device_type = device_type
            self._layout = _layout
            self._rank_map = _rank_map
            self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
            self._root_mesh = _root_mesh

            if backend_override is None:
                backend_override = ((None, None),) * len(self._layout)
            elif len(backend_override) != len(self._layout):
                raise ValueError(
                    f"backend_override should have the same length as the number of mesh dimensions, "
                    f"but got {len(backend_override)} and {len(self._layout)}."
                )

            # Skip process group initialization if xla device or init backend is False
            # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
            self._thread_id = None
            if device_type != "xla":
                # always try to create default (world) pg, even if it is not initialized
                # already. The world pg is used for device mesh identity (rank) on each
                # process (we need to know if the current global rank is in the mesh or not).
                if _init_backend:
                    self._setup_world_group_and_device()
                    self._dim_group_names = self._init_process_groups(
                        self._layout,
                        self._rank_map,
                        self._mesh_dim_names,
                        backend_override,
                    )

                if is_initialized() and get_backend() == "threaded":
                    # pyrefly: ignore # bad-assignment
                    self._thread_id = threading.get_ident()

                if _rank is None:
                    _rank = get_rank()

                # calculate the coordinates of the current global rank on the mesh
                rank_coords = (self.mesh == _rank).nonzero()
                assert rank_coords.size(0) in (0, 1)
                self._coordinate_on_dim: Optional[list[int]] = (
                    rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
                )

            # private field to pre-generate DeviceMesh's hash
            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
            # Initialize instance-specific flatten mapping
            self._flatten_mapping = {}
        @property
        def device_type(self) -> str:
            """Returns the device type of the mesh."""
            return self._device_type

        @property
        def mesh(self) -> torch.Tensor:
            """Returns the tensor representing the layout of devices."""
            full_mesh = self._layout.remap_to_tensor(self._rank_map)
            if full_mesh.size(0) == 1:
                return full_mesh[0]
            my_coords = (full_mesh == get_rank()).nonzero()
            if my_coords.size(0) > 0:
                return full_mesh[my_coords[0, 0]]
            raise RuntimeError(
                "In order to get the mesh Tensor of a DeviceMesh it needs to "
                "either have all its original dimensions (e.g., no slicing) "
                "or it needs to contain the local rank"
            )

        @property
        def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
            """Returns the names of mesh dimensions."""
            return self._mesh_dim_names
        def _setup_world_group_and_device(self):
            default_initialized = is_initialized()
            # TODO: think about how to allow pg options to be passed to world group
            # or mesh dimension groups
            if not default_initialized:
                init_process_group()

            world_size = get_world_size()
            if self._layout.numel() > world_size:
                raise RuntimeError(
                    f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
                )

            # ONLY set the device if the current device is not initialized; if the user already
            # set the device before DeviceMesh init, we respect the user's choice.
            device_handle = _get_device_handle(self._device_type)
            if device_handle and not device_handle.is_initialized():
                # auto set the cuda/cuda-like device only if the user has not set it; if there's a LOCAL_RANK
                # env variable from launchers, we use it to set the device.
                if "LOCAL_RANK" in os.environ:
                    local_rank = int(os.environ["LOCAL_RANK"])
                    logger.info(
                        "Setting default device for the current process based on LOCAL_RANK=%s",
                        local_rank,
                    )
                    device_handle.set_device(local_rank)
                else:
                    warnings.warn(
                        "It seems like you did not set/select the default device for the current process before the DeviceMesh "
                        "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. "
                        "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
                        "the underlying communicator (i.e. NCCL) can be initialized properly. "
                        "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
                        "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. "
                    )
                    # heuristic to set the current cuda/cuda-like device based on the number of gpu devices available in each host
                    # NOTE: This device selection would only work for homogeneous hardware.
                    num_devices_per_host = device_handle.device_count()
                    if (
                        world_size > num_devices_per_host
                        and world_size % num_devices_per_host != 0
                    ):
                        raise RuntimeError(
                            f"DeviceMesh only supports homogeneous hardware, but found "
                            f"{world_size} ranks and {num_devices_per_host} {self._device_type} devices!"
                        )
                    device_handle.set_device(get_rank() % num_devices_per_host)

            return _get_default_group()
        @staticmethod
        def _init_process_groups(
            layout: _MeshLayout,
            rank_map: torch.Tensor,
            mesh_dim_names: Optional[tuple[str, ...]],
            backend_override: tuple[BackendConfig, ...],
        ) -> list[str]:
            # group_name associated with each mesh dimension, each
            # mesh dimension should have one sub-group per rank
            #
            dim_group_names: list[str] = []
            default_group = _get_default_group()

            if (
                len(layout) == 1
                and layout.numel() == get_world_size()
                and backend_override[0] == (None, None)
            ):
                # Append the default pg to the first dim groups only if the default pg is compatible with the mesh's device type.
                # Otherwise, create a new pg.
                ranks = list(range(get_world_size()))
                dim_group = (
                    new_group(
                        backend="cpu:gloo,cuda:nccl",
                        ranks=ranks,
                        group_desc="mesh_default",
                    )
                    if torch.cuda.is_available()
                    and get_backend(default_group) == "gloo"
                    else default_group
                )
                dim_group_names.append(dim_group.group_name)
            else:
                # create sub pgs based on the mesh argument specified
                for dim in range(len(layout)):
                    # swap the current dim to the last dim
                    # then reshape to flatten out other dims
                    pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
                    backend, pg_options = backend_override[dim]
                    # We need to explicitly pass in timeout when specified in option, otherwise
                    # the default timeout will be used to override the timeout set in option.
                    # TODO: remove this once we have fixed inside c10d level.
                    timeout = pg_options._timeout if pg_options else None

                    # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
                    # of the subgroups would be `mesh_dp` and `mesh_tp`.
                    # If the mesh does not have mesh_dim_names, then the group description of the
                    # subgroups would be `mesh_dim_0` and `mesh_dim_1`.
                    group_desc = (
                        f"mesh_{mesh_dim_names[dim]}"
                        if mesh_dim_names
                        else f"mesh_dim_{dim}"
                    )

                    # If bound_device_id exists, it means the nccl communicator has been eagerly initialized
                    # so that we can use `split_group` to create subgroups through `ncclCommSplit`.
                    # In this case, we only need to make one API call (`split_group`) for the subgroup creation
                    # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per rank to create
                    # all the subgroups.
                    # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
                    # number of API calls is equal to the number of subgroups for each mesh dimension. In a 2 * 4
                    # mesh, we need to make 2 + 4 = 6 API calls per rank to create all the subgroups.
                    dim_group = None
                    has_split_group = False
                    if (
                        (
                            bound_device_id := getattr(
                                default_group, "bound_device_id", None
                            )
                        )
                        is not None
                        and torch.cuda.is_available()
                        and (
                            backend is None
                            or default_group._get_backend(torch.device("cuda")).name()
                            == backend
                        )
                    ):
                        dim_group = split_group(
                            parent_pg=default_group,
                            timeout=timeout,
                            pg_options=pg_options,
                            split_ranks=pg_ranks_by_dim.tolist(),
                            group_desc=group_desc,
                        )
                        has_split_group = True

                    # If the subgroup has already been created through `split_group`, we simply loop over `pg_ranks_by_dim`
                    # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
                    # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
                    # along with appending information to the `dim_group_names` list whenever necessary.
                    for dim_mesh in pg_ranks_by_dim:
                        subgroup_ranks = dim_mesh.tolist()

                        # We temporarily revert the reuse subgroup, since it breaks two internal tests.
                        # Temporarily reverting to resolve test timeout while root-causing.
                        # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
                        # pyrefly: ignore # unbound-name
                        if bound_device_id is None or not has_split_group:
                            dim_group = new_group(
                                ranks=subgroup_ranks,
                                timeout=timeout,
                                backend=backend,
                                pg_options=pg_options,
                                group_desc=group_desc,
                            )

                        # only add to dim_groups if the current rank is in the subgroup
                        if get_rank() in subgroup_ranks:
                            if len(dim_group_names) > dim:
                                raise RuntimeError(
                                    f"Each device mesh dimension should get only one process group, but got {get_rank()} "
                                    f"in {subgroup_ranks}!"
                                )
                            dim_group_names.append(dim_group.group_name)  # type: ignore[union-attr]
            return dim_group_names
        def _get_root_mesh(self) -> "DeviceMesh":
            return self._root_mesh if self._root_mesh else self

        def __enter__(self) -> "DeviceMesh":
            # set this mesh as the current mesh in mesh env
            _mesh_resources.mesh_stack.append(self)
            return self

        # pyre-fixme[2]: Parameter must be annotated.
        def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
            # pop this mesh from mesh env
            _mesh_resources.mesh_stack.pop()
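        # Usage sketch (illustrative): entering the mesh as a context manager pushes it onto
        # the thread-local mesh stack, so code that calls `_mesh_resources.get_current_mesh()`
        # picks it up implicitly. Assumes an already-initialized 8-rank job.
        #
        #   >>> # xdoctest: +SKIP("no rank")
        #   >>> mesh = init_device_mesh("cuda", mesh_shape=(8,))
        #   >>> with mesh:
        #   ...     cur = _mesh_resources.get_current_mesh()  # returns `mesh`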
        def __repr__(self) -> str:
            device_mesh_repr = (
                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
                if self._mesh_dim_names
                else f"{self._layout.top_level_sizes}"
            )
            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
            # We only print the mesh tensor if the debug mode is turned on.
            if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
                device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
            return f"{device_mesh_repr})"

        def __hash__(self):
            # lazily compute hash
            self._hash = getattr(self, "_hash", None)
            if not self._hash:
                self._hash = hash(
                    (
                        self._flatten_mesh_list,
                        self._layout,
                        self._device_type,
                        self._mesh_dim_names,
                        self._thread_id,
                        self._root_mesh,
                    )
                )
            return self._hash

        def __eq__(self, other: object) -> bool:
            if self is other:
                return True
            if not isinstance(other, DeviceMesh):
                return False
            return (
                self._flatten_mesh_list == other._flatten_mesh_list
                and self._layout == other._layout
                and self._device_type == other._device_type
                and self._mesh_dim_names == other._mesh_dim_names
                and self._thread_id == other._thread_id
                and self._root_mesh == other._root_mesh
            )
        def __getitem__(
            self, mesh_dim_names: Union[str, tuple[str, ...]]
        ) -> "DeviceMesh":
            """
            Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh.
            The submesh created consists of the dimensions and the communicators indicated by
            ``mesh_dim_names``

            Args:
                mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the
                    mesh dimension of the DeviceMesh to create the submesh for.
            Returns:
                A :class:`DeviceMesh` object

            The following program runs on each process/rank in an SPMD manner in a world size of 8.
            In the first example:
                Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]).
                Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]).
                Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]).
                Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]).
                Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]).
                Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]).

            In the second example:
                Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]).
                Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]).
                Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]).
                Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]).

            Example::

                >>> # xdoctest: +SKIP("no rank")
                >>> from torch.distributed.device_mesh import DeviceMesh
                >>>
                >>> # Initialize a 2D device mesh as (2, 4) to represent the topology
                >>> # of cross-host(dim 0), and within-host (dim 1).
                >>> mesh_2d = init_device_mesh(device_type="cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
                >>> tp_mesh = mesh_2d["tp"]
                >>> dp_mesh = mesh_2d["dp"]
                >>>
                >>> # Initialize a 3D mesh.
                >>> mesh_3d = init_device_mesh(device_type="cuda", mesh_shape=(2, 2, 2), mesh_dim_names=("dp", "pp", "cp"))
                >>> # The order of the mesh_dim_names provided determines the order of dimensions in the submesh.
                >>> dp_cp_mesh = mesh_3d["dp", "cp"]
                >>> cp_dp_mesh = mesh_3d["cp", "dp"]
            """
            if not self._mesh_dim_names:
                raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")

            mesh_dim_names = (
                (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
            )

            if mesh_dim_names == self._mesh_dim_names:
                return self
            else:
                sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names)
                # When using FakeTensorMode to trace the model, `_create_sub_mesh()` will
                # fail as it will require a real tensor to manipulate.
                # `unset_fake_temporarily()` will allow us to materialize the tensors
                # within `_create_sub_mesh`, which should not affect modeling.
                #
                # Note that this should be orthogonal to torch.compile(). But whether
                # we can compile device_mesh `slicing` (no graph break) is not verified
                # yet and needs a follow-up.
                # TODO: compiler + device_mesh slicing.
                with torch._subclasses.fake_tensor.unset_fake_temporarily():
                    submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names)
                return submesh
        def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
            """
            Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
            DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.

            Args:
                mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
                    of the mesh dimension. Default is None.

            Returns:
                A :class:`ProcessGroup` object.
            """
            if not hasattr(self, "_dim_group_names"):
                raise RuntimeError("DeviceMesh process groups not initialized!")

            if len(self._layout) > 1 and mesh_dim is None:
                raise RuntimeError(
                    f"Found the DeviceMesh has {len(self._layout)} dimensions",
                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                    "If you want to get the list of all the ProcessGroups in the DeviceMesh, "
                    "please use `get_all_groups()` instead.",
                )

            # Quick return if the current device_mesh is a 1D mesh.
            if len(self._layout) == 1 and mesh_dim is None:
                return not_none(_resolve_process_group(self._dim_group_names[0]))

            root_mesh = self._get_root_mesh()
            root_to_flatten_mapping = root_mesh._flatten_mapping
            if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
                dim_group_name = root_to_flatten_mapping[
                    mesh_dim  # type: ignore[index]
                ]._dim_group_names[0]
                return not_none(_resolve_process_group(dim_group_name))
            else:
                mesh_dim = (
                    self._get_mesh_dim_by_name(mesh_dim)
                    if isinstance(mesh_dim, str)
                    else mesh_dim
                )
                assert isinstance(mesh_dim, int)
                return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))
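        # Usage sketch (illustrative, assuming a (2, 4) mesh named ("dp", "tp") on 8 ranks):
        #
        #   >>> # xdoctest: +SKIP("no rank")
        #   >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
        #   >>> tp_group = mesh_2d.get_group("tp")  # ProcessGroup for this rank's "tp" dimension
        #   >>> dp_group = mesh_2d.get_group(0)     # the dimension can also be addressed by index
        #   >>> groups = mesh_2d.get_all_groups()   # one ProcessGroup per mesh dimension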
        def get_all_groups(self) -> list[ProcessGroup]:
            """
            Returns a list of ProcessGroups for all mesh dimensions.

            Returns:
                A list of :class:`ProcessGroup` objects.
            """
            return [self.get_group(i) for i in range(len(self._layout))]
        def _create_sub_mesh(
            self,
            layout: _MeshLayout,
            submesh_dim_names: tuple[str, ...],
        ) -> "DeviceMesh":
            root_mesh = self._get_root_mesh()
            slice_dim_group_name = []
            for name in submesh_dim_names:
                if name in not_none(self._mesh_dim_names):
                    slice_dim_group_name.append(
                        self._dim_group_names[  # type: ignore[has-type]
                            not_none(self._mesh_dim_names).index(name)
                        ]
                    )
                else:
                    # If device_mesh is not root_mesh, we already raise an error in _get_slice_mesh_layout.
                    # Since we will deprecate the slicing of a flattened dim_name from the root mesh soon,
                    # we don't want to optimize this code further.
                    flatten_mesh = self._flatten_mapping[name]
                    slice_dim_group_name.append(
                        flatten_mesh._dim_group_names[  # type: ignore[has-type]
                            not_none(flatten_mesh._mesh_dim_names).index(name)
                        ]
                    )
            res_submesh = DeviceMesh(
                self._device_type,
                _layout=layout,
                _rank_map=root_mesh._rank_map,
                mesh_dim_names=submesh_dim_names,
                _root_mesh=root_mesh,
                _init_backend=False,
            )
            res_submesh._dim_group_names = slice_dim_group_name
            return res_submesh
        def _create_flatten_mesh(
            self,
            mesh_dim_name: Optional[str] = None,
            backend_override: BackendConfig = (None, None),
        ) -> "DeviceMesh":
            root_mesh = self._get_root_mesh()

            if not mesh_dim_name:
                mesh_dim_name = "_".join(not_none(self._mesh_dim_names))

            # Flattening a 1D device mesh into its original mesh_dim_name returns the mesh itself.
            if self.ndim == 1 and mesh_dim_name in not_none(self._mesh_dim_names):
                return self

            # Check whether the mesh_dim_name for the flattened mesh is valid.
            invalid_dim_names = not_none(root_mesh._mesh_dim_names)
            if mesh_dim_name in invalid_dim_names:
                raise ValueError(
                    f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ",
                    f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. "
                    f"Please specify another valid mesh_dim_name.",
                )

            flattened_mesh_layout = self._layout.coalesce()
            if len(flattened_mesh_layout) > 1:
                flattened_mesh_layout = flattened_mesh_layout.nest()
            # Quick return if the flattened mesh has been created before.
            if mesh_dim_name in root_mesh._flatten_mapping:
                if (
                    flattened_mesh_layout
                    == root_mesh._flatten_mapping[mesh_dim_name]._layout
                ):
                    return root_mesh._flatten_mapping[mesh_dim_name]
                else:
                    raise ValueError(
                        f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before. "
                        f"Please specify another valid mesh_dim_name."
                    )

            res_flattened_mesh = DeviceMesh(
                root_mesh._device_type,
                _layout=flattened_mesh_layout,
                _rank_map=root_mesh._rank_map,
                mesh_dim_names=(mesh_dim_name,),
                _root_mesh=root_mesh,
                backend_override=(backend_override,),
            )
            root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh

            return res_flattened_mesh
        def _get_root_mesh_dim(self) -> Optional[int]:
            """
            Returns the index of the mesh dim in the root mesh.
            The device_mesh passed in needs to be sliced out from the root mesh
            or submesh of the root mesh.
            """
            root_mesh = self._get_root_mesh()
            child_mesh_dim_names = self._mesh_dim_names
            if root_mesh and child_mesh_dim_names:
                assert len(child_mesh_dim_names) == 1, (
                    "The submesh can only be a 1D mesh."
                )
                child_mesh_dim_name = child_mesh_dim_names[0]
                return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
            return None

        def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
            if self._mesh_dim_names is None or len(self._mesh_dim_names) == 0:
                raise KeyError(
                    "No `mesh_dim_names` found.",
                )
            if mesh_dim_name not in self._mesh_dim_names:
                raise KeyError(
                    f"Mesh dimension '{mesh_dim_name}' does not exist.",
                    f"Available mesh dimensions are: mesh_dim_names={self._mesh_dim_names}",
                )
            return not_none(self._mesh_dim_names.index(mesh_dim_name))
        def _get_slice_mesh_layout(
            self, mesh_dim_names: tuple[str, ...]
        ) -> _MeshLayout:
            """
            Validate whether the given mesh_dim_names can be used to slice the device_mesh.
            If valid, return the layout of the sliced mesh.
            """
            slice_from_root = True
            if self != self._get_root_mesh():
                warnings.warn(
                    "You are attempting to slice a submesh from another submesh. While we support this operation, "
                    "it is the user's responsibility to ensure that the submesh is consistently sliced across all ranks. "
                    "If not, this may result in some ranks receiving the submesh while others encounter errors."
                )
                slice_from_root = False

            # The slice mesh_dim_names should consist of either the current device_mesh's mesh_dim_names
            # or its flattened mesh's mesh_dim_names if it is the root mesh.
            flatten_name_to_root_layout = (
                {
                    key: mesh._layout
                    for key, mesh in self._get_root_mesh()._flatten_mapping.items()
                }
                if slice_from_root
                else {}
            )
            valid_mesh_dim_names = [
                *not_none(self._mesh_dim_names),
                *flatten_name_to_root_layout,
            ]

            if not all(
                mesh_dim_name in valid_mesh_dim_names
                for mesh_dim_name in mesh_dim_names
            ):
                raise KeyError(
                    f"Invalid mesh_dim_names {mesh_dim_names} specified. "
                    f"Valid mesh_dim_names are {valid_mesh_dim_names}."
                )

            layout_sliced = []
            for name in mesh_dim_names:
                if name in not_none(self._mesh_dim_names):
                    layout_sliced.append(
                        self._layout[not_none(self._mesh_dim_names).index(name)]
                    )
                elif name in flatten_name_to_root_layout:
                    warnings.warn(
                        "Slicing a flattened dim from root mesh will be deprecated in PT 2.11. "
                        "Users need to bookkeep the flattened mesh directly. "
                    )
                    layout_sliced.append(flatten_name_to_root_layout[name])

            sliced_sizes = tuple(l.sizes for l in layout_sliced)
            sliced_strides = tuple(l.strides for l in layout_sliced)

            # The check below is from DeviceMesh's implementation before adopting CuTe layout for internal
            # bookkeeping. It can be removed, but first we need to define what the expected behavior is.
            # TODO: Remove the below check and define the expected behavior.
            # Validate the order of the slice mesh dim indices.
            # This needs to be in ascending order.
            pre_stride = -1
            for stride in reversed(sliced_strides):
                # Note that with CuTe layout, we can support slicing flattened non-contiguous mesh dims with no problem.
                # But this would complicate the behavior, so we decided not to support it for now.
                if not is_int(stride):
                    raise NotImplementedError(
                        "Currently, this only allows slicing out a contiguous flattened dim."
                    )
                if stride < pre_stride:
                    raise KeyError(
                        f"Invalid mesh_dim_names {mesh_dim_names} specified. "
                        "Mesh dim indices should be in ascending order."
                    )
                pre_stride = stride

            # When users slice dim_names outside the current mesh, we check whether
            # there is layout overlap.
            # TODO: Eventually we will just directly throw an error here because
            # we will deprecate the slicing of a flattened dim_name from the root mesh.
            layout_sliced = _MeshLayout(sliced_sizes, sliced_strides)
            if not layout_sliced.check_non_overlap():
                raise RuntimeError(
                    f"Slicing overlapping dim_names {mesh_dim_names} is not allowed."
                )

            return layout_sliced
        # TODO: make this a public API for use by other components in the future.
        def _get_all_submeshes(self, mesh_dim_name: str) -> list["DeviceMesh"]:
            """
            Return all the submeshes of a given mesh dimension of the device mesh.
            """
            mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
            layout = self._layout[mesh_dim]
            pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
            cur_rank = self.get_rank()
            res_submeshes = []
            for mesh_1d in pg_ranks_by_dim:
                submesh = DeviceMesh(
                    self._device_type,
                    mesh_1d,
                    mesh_dim_names=(mesh_dim_name,),
                    _init_backend=False,
                )
                submesh._dim_group_names = (  # type: ignore[has-type]
                    [self._dim_group_names[mesh_dim]]  # type: ignore[has-type]
                    if cur_rank in mesh_1d
                    else []
                )
                res_submeshes.append(submesh)

            return res_submeshes
        @staticmethod
        def from_group(
            group: Union[ProcessGroup, list[ProcessGroup]],
            device_type: str,
            mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
            *,
            mesh_dim_names: Optional[tuple[str, ...]] = None,
        ) -> "DeviceMesh":
            """
            Constructs a :class:`DeviceMesh` with ``device_type`` from an
            existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`.

            The constructed device mesh has a number of dimensions equal to the
            number of groups passed. For example, if a single process group is passed in,
            the resulting DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in,
            the resulting DeviceMesh is a 2D mesh.

            If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments
            are required. The order of the process groups passed in determines the topology of
            the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh.
            The ``mesh`` tensor passed in must have the same number of dimensions as the number of process
            groups passed in, and the order of the dimensions in the ``mesh`` tensor must match the order
            of the process groups passed in.

            Args:
                group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup
                    or a list of existing ProcessGroups.
                device_type (str): The device type of the mesh. Currently supports: "cpu",
                    "cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0",
                    is not allowed.
                mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an
                    integer tensor describing the layout of devices, where the IDs are global IDs
                    of the default process group. Default is None.
                mesh_dim_names (tuple[str], optional): A tuple of mesh dimension names to assign
                    to each dimension of the multi-dimensional array describing the layout of devices.
                    Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names`
                    must be unique. Default is None.

            Returns:
                DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
            """

            # 1D scenario
            if isinstance(group, ProcessGroup):
                group_ranks = get_process_group_ranks(group)
                if (
                    isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks
                ) or (
                    mesh is not None
                    and not isinstance(mesh, torch.Tensor)
                    and mesh != group_ranks
                ):
                    raise ValueError(
                        f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}"
                    )
                mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int)
                device_mesh = DeviceMesh(
                    device_type,
                    mesh,
                    mesh_dim_names=mesh_dim_names,
                    _init_backend=False,
                )
                device_mesh._dim_group_names = [group.group_name]
                return device_mesh

            # nD scenario
            groups = list(group)
            if len(groups) == 0:
                raise ValueError("Expects at least one ProcessGroup to be passed")
            if mesh is None:
                raise ValueError("Must pass mesh if passing multiple ProcessGroups")
            if mesh_dim_names is None:
                raise ValueError(
                    "Must pass mesh_dim_names if passing multiple ProcessGroups"
                )
            # When initializing a DeviceMesh with multiple ProcessGroups directly, we need to make sure
            # the mesh tensor is contiguous. Otherwise, the layout we infer from the mesh tensor
            # will have a larger span than the actual tensor. This is just an internal implementation detail
            # and does not affect user-facing behavior.
            mesh = (
                mesh.detach().to(dtype=torch.int, device="cpu")
                if isinstance(mesh, torch.Tensor)
                else torch.tensor(mesh, device="cpu", dtype=torch.int)
            )
            if mesh.ndim != len(groups):
                raise ValueError(
                    "Expects mesh with ndim equal to number of ProcessGroups but got "
                    f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups"
                )
            device_mesh = DeviceMesh(
                device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
            )
            device_mesh._dim_group_names = [group.group_name for group in groups]
            return device_mesh
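        # Usage sketch (illustrative): wrapping process groups that were created outside of
        # DeviceMesh. The 2D case assumes `dp_group` and `tp_group` were built elsewhere
        # (e.g. via `new_group`) and that the global mesh tensor matches their rank layout.
        #
        #   >>> # xdoctest: +SKIP("no rank")
        #   >>> world_mesh = DeviceMesh.from_group(_get_default_group(), "cuda")
        #   >>> mesh_2d = DeviceMesh.from_group(
        #   ...     [dp_group, tp_group],
        #   ...     "cuda",
        #   ...     mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
        #   ...     mesh_dim_names=("dp", "tp"),
        #   ... )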
        def size(self, mesh_dim: Optional[int] = None) -> int:
            if mesh_dim is not None:
                return self._layout[mesh_dim].numel()
            return self._layout.numel()

        @property
        def ndim(self) -> int:
            return len(self._layout)

        @property
        def shape(self) -> tuple[int, ...]:
            return self._layout.top_level_sizes

        def get_rank(self) -> int:
            """
            Returns the current global rank.
            """
            return get_rank()

        def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
            """
            Returns the local rank of the given mesh_dim of the DeviceMesh.

            Args:
                mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
                    of the mesh dimension. Default is None.

            Returns:
                An integer denoting the local rank.

            The following program runs on each process/rank in an SPMD manner. In this example, we have 2
            hosts with 4 GPUs each.
            Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
            Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.

            Example::

                >>> # xdoctest: +SKIP("no rank")
                >>> from torch.distributed.device_mesh import DeviceMesh
                >>>
                >>> # Initialize device mesh as (2, 4) to represent the topology
                >>> # of cross-host(dim 0), and within-host (dim 1).
                >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
            """
            if self.ndim > 1 and mesh_dim is None:
                raise RuntimeError(
                    f"Found the DeviceMesh has {len(self._layout)} dimensions",
                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                )
            elif mesh_dim is None:
                mesh_dim = 0

            mesh_dim_group = not_none(self.get_group(mesh_dim))
            assert isinstance(mesh_dim_group, ProcessGroup), (
                "We expect ProcessGroup before calling `get_rank`!"
            )
            return not_none(get_rank(mesh_dim_group))

        def get_coordinate(self) -> Optional[list[int]]:
            """
            Return the coordinates of this rank along all dimensions of the mesh.
            If this rank is not part of the mesh, return None.
            """
            return self._coordinate_on_dim if self._coordinate_on_dim else None
        def _flatten(
            self,
            mesh_dim_name: Optional[str] = None,
            backend_override: Union[
                None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
            ] = None,
        ) -> "DeviceMesh":
            """
            Returns a 1D DeviceMesh by flattening the current DeviceMesh.

            If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the
            given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh
            DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling
            mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 2, 4, 6], mesh_dim_names=("dp_cp",))
            on rank 0, 2, 4, 6 and a 1D submesh DeviceMesh([1, 3, 5, 7], mesh_dim_names=("dp_cp",)) on rank 1, 3, 5, 7.

            After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the
            existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"].
            """
            if not self._mesh_dim_names:
                raise RuntimeError(
                    "Cannot flatten a DeviceMesh without mesh_dim_names!"
                )

            if backend_override is not None:
                (backend_override_tuple,) = _normalize_backend_override(
                    {0: backend_override}, 1
                )
            else:
                backend_override_tuple = (None, None)

            return self._create_flatten_mesh(mesh_dim_name, backend_override_tuple)
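        # Usage sketch (illustrative, assuming an 8-rank (2, 2, 2) mesh named ("dp", "cp", "tp")):
        #
        #   >>> # xdoctest: +SKIP("no rank")
        #   >>> mesh_3d = init_device_mesh("cuda", mesh_shape=(2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))
        #   >>> dp_cp_mesh = mesh_3d["dp", "cp"]._flatten()  # named "dp_cp" by default
        #   >>> same_mesh = mesh_3d["dp_cp"]                 # the flattened dim can then be sliced from the root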
        def _create_unflatten_mesh(
            self,
            dim: int,
            mesh_sizes: tuple[int, ...],
            mesh_dim_names: tuple[str, ...],
            backend_override: tuple[
                tuple[Optional[str], Optional[C10dBackend.Options]], ...
            ] = ((None, None),),
        ) -> "DeviceMesh":
            inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))

            if inner_layout.numel() != self._layout[dim].numel():
                raise ValueError(
                    f"The product of {mesh_sizes=} is {inner_layout.numel()}, "
                    f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. "
                    f"These must be equal for unflatten to work correctly."
                )

            partial_layout = self._layout[dim].composition(inner_layout)
            unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
            unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
            unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)

            root_mesh = self._get_root_mesh()
            res_mesh = DeviceMesh(
                self.device_type,
                _layout=unflattened_layout,
                _rank_map=root_mesh._rank_map,
                mesh_dim_names=tuple(unflattened_mesh_dim_names),
                _root_mesh=root_mesh,
                _init_backend=False,
            )

            # If the original mesh has initialized its backend, we need to initialize the backend
            # of the unflattened dims as well.
            # TODO: To make backend init more efficient with cute layout representation and support
            # per dim backend init.
            if hasattr(self, "_dim_group_names"):
                dim_group_names = self._dim_group_names.copy()
                dim_group_names[dim : dim + 1] = self._init_process_groups(
                    partial_layout,
                    root_mesh._rank_map,
                    mesh_dim_names,
                    backend_override,
                )
                res_mesh._dim_group_names = dim_group_names

            return res_mesh
        def _unflatten(
            self,
            dim: Union[int, str],
            mesh_sizes: tuple[int, ...],
            mesh_dim_names: tuple[str, ...],
            backend_override: Optional[
                dict[
                    str,
                    Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
                ]
            ] = None,
        ) -> "DeviceMesh":
            """
            Returns a DeviceMesh by unflattening the current DeviceMesh.

            This API can be used to unflatten an N-D DeviceMesh into an (N-1+len(mesh_sizes))-D mesh or submesh.
            The dim is the dimension to be unflattened, which can be either a string or an integer.

            mesh_sizes is a tuple which specifies the shape that the given dim is unflattened into.
            mesh_dim_names is a tuple of strings which specifies the names of the dimensions the mesh is unflattened into.
            Its length must match the length of mesh_sizes.

            For example, if we have a 1D mesh DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=("world",)),
            calling mesh_1d._unflatten(0, (2, 2, 2), ("dp", "cp", "tp")) will create a 3D mesh
            DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")).

            Note that after calling the unflatten, there is no access to the unflattened dimensions in mesh_1d;
            one can only use the newly unflattened mesh to slice out the unflattened mesh dims.
            """
            if isinstance(dim, int) and dim >= self.ndim:
                raise ValueError(
                    f"dim {dim} specified in `_unflatten` is out of range {self.ndim}"
                )
            elif isinstance(dim, str) and dim not in not_none(self.mesh_dim_names):
                raise ValueError(
                    f"dim {dim} specified in `_unflatten` is not in {self.mesh_dim_names}"
                )

            if len(mesh_sizes) != len(mesh_dim_names):
                raise RuntimeError(
                    "mesh_dim_names must have same length as mesh_sizes in _unflatten!"
                )

            if isinstance(dim, str):
                dim = not_none(self.mesh_dim_names).index(dim)

            if backend_override is not None:
                backend_override_tuple = tuple(
                    _normalize_backend_override(
                        backend_override,  # type: ignore[arg-type]
                        len(mesh_sizes),
                        mesh_dim_names,
                    )
                )
            else:
                backend_override_tuple = ((None, None),) * len(mesh_dim_names)

            return self._create_unflatten_mesh(
                dim,
                mesh_sizes,
                mesh_dim_names,
                backend_override_tuple,
            )
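        # Usage sketch (illustrative, assuming an 8-rank job):
        #
        #   >>> # xdoctest: +SKIP("no rank")
        #   >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("world",))
        #   >>> mesh_3d = mesh_1d._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
        #   >>> dp_mesh = mesh_3d["dp"]  # the new dims are sliceable from the unflattened mesh only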
    def _normalize_backend_override(
        backend_override: dict[
            Union[int, str],
            Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
        ],
        ndim: int,
        mesh_dim_names: Optional[tuple[str, ...]] = None,
    ) -> Iterator[BackendConfig]:
        if mesh_dim_names is None:
            mesh_dim_names = ()
        for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names):
            if dim_name is not None and dim_name in backend_override:
                if dim_idx in backend_override:
                    raise RuntimeError(
                        f"Found redundant dim index {dim_idx} and "
                        f"name {dim_name} in backend_override"
                    )
                val = backend_override.pop(dim_name)
            elif dim_idx in backend_override:
                val = backend_override.pop(dim_idx)
            else:
                yield (None, None)
                continue

            if isinstance(val, str):
                yield (val, None)
            elif isinstance(val, C10dBackend.Options):
                yield (None, val)
            else:
                yield val

        if backend_override:
            raise RuntimeError(
                f"Found invalid keys in backend_override: got {list(backend_override.keys())}, "
                f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}"
            )
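    # Behavior sketch (illustrative): `_normalize_backend_override` expands a per-dim override
    # dict into one (backend, options) pair per mesh dimension, accepting either dim indices or
    # dim names as keys. For example:
    #
    #   >>> list(_normalize_backend_override({"tp": "gloo"}, 2, ("dp", "tp")))
    #   [(None, None), ('gloo', None)]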
    def init_device_mesh(
        device_type: str,
        mesh_shape: tuple[int, ...],
        *,
        mesh_dim_names: Optional[tuple[str, ...]] = None,
        backend_override: Optional[
            dict[
                Union[int, str],
                Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
            ]
        ] = None,
    ) -> DeviceMesh:
        """
        Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.

        This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
        If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.

        .. note::
            `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
            runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
            describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.

        .. note::
            If no process group is found, init_device_mesh will initialize distributed process group/groups
            required for distributed communications behind the scenes.

        Args:
            device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu".
                Passing in a device type with a GPU index, such as "cuda:0", is not allowed.
            mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
                describing the layout of devices.
            mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
                of the multi-dimensional array describing the layout of devices. Its length must match the length
                of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
            backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of
                the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a
                dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name
                of the backend and its options, or just one of these two components (in which case the other will be
                set to its default value).

        Returns:
            DeviceMesh: A :class:`DeviceMesh` object representing the device layout.

        Example::

            >>> # xdoctest: +SKIP("no rank")
            >>> from torch.distributed.device_mesh import init_device_mesh
            >>>
            >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
            >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

        """
        if mesh_dim_names is not None:
            if len(set(mesh_dim_names)) != len(mesh_dim_names):
                raise RuntimeError(
                    "Each mesh_dim_name must be unique.",
                    f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
                )

            if len(mesh_shape) != len(mesh_dim_names):
                raise RuntimeError(
                    "mesh_shape and mesh_dim_names should have same length!",
                    f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
                )

        if backend_override is not None:
            backend_override_tuple = tuple(
                _normalize_backend_override(
                    backend_override, len(mesh_shape), mesh_dim_names
                )
            )
        else:
            backend_override_tuple = None

        # assume valid device types are all letters
        if device_type and not device_type.isalpha():
            raise RuntimeError(
                f"Device type with index is not supported but got {device_type}. ",
                "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
            )

        layout = _MeshLayout(tuple(mesh_shape), suffix_product(tuple(mesh_shape)))
        # Always initialize the (identity) rank map on CPU, regardless of what the
        # external device type has been set to be (e.g. meta)
        with torch.device("cpu"):
            rank_map = torch.arange(layout.numel(), dtype=torch.int)
        device_mesh = DeviceMesh(
            device_type=device_type,
            _layout=layout,
            _rank_map=rank_map,
            mesh_dim_names=mesh_dim_names,
            backend_override=backend_override_tuple,
        )

        return device_mesh
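    # Usage sketch (illustrative): overriding the backend for a single mesh dimension,
    # assuming the "gloo" backend is available in this build. Keys of `backend_override`
    # may be dimension names or indices, per the docstring above.
    #
    #   >>> # xdoctest: +SKIP("no rank")
    #   >>> mesh_2d = init_device_mesh(
    #   ...     "cuda",
    #   ...     mesh_shape=(2, 4),
    #   ...     mesh_dim_names=("dp", "tp"),
    #   ...     backend_override={"dp": "gloo"},
    #   ... )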