mirror of https://github.com/pytorch/pytorch.git
Summary: This PR fixes `_get_or_create_default_group()` of `DeviceMesh`. When the `mesh` of the first created `DeviceMesh` is not `[0, 1, 2, ..., WORLD_SIZE - 1]` and `is_initialized() == False`, it wrongly asserts. This PR fixes the issue by removing those assertions.

More specifically, `_get_or_create_default_group()` has 4 checks:

1. `DeviceMesh must include every process in WORLD`
2. `DeviceMesh cannot have duplicate values`
3. `DeviceMesh ranks must start from 0`
4. `DeviceMesh should have all ranks of WORLD`

Checks 1, 3, and 4 are not satisfied when `self.mesh` is not `[0, 1, 2, ..., WORLD_SIZE - 1]`. Check 2 is valid, but it is already performed in `__init__()`, so there is no need to repeat it in this function.

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44098849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96961
Approved by: https://github.com/wanchaol
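A minimal sketch of the case this PR unblocks (hypothetical repro, not taken from the PR; assumes a 4-process launch and the `DeviceMesh` import path of this package at this commit):

```
# Before this fix, constructing a DeviceMesh whose mesh is not
# [0, 1, ..., WORLD_SIZE - 1] while is_initialized() == False tripped the
# assertions in _get_or_create_default_group(); with the fix, the default
# group is initialized and construction proceeds.
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh

assert not dist.is_initialized()
# a 1-d mesh that neither starts at 0 nor covers all of WORLD,
# violating former checks 1, 3, and 4
mesh = DeviceMesh("cpu", [2, 1, 3])
```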
528 lines
19 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from typing import List, Optional, Sequence, TypeVar, Union

import torch
from torch.distributed.distributed_c10d import (
    _get_default_group,
    all_gather,
    all_to_all,
    broadcast,
    get_global_rank,
    get_rank,
    get_world_size,
    GroupMember,
    init_process_group,
    is_initialized,
    new_group,
    ProcessGroup,
    reduce_scatter,
    ReduceOp,
    scatter,
    Work,
)

import torch.distributed.distributed_c10d as c10d
import torch.distributed._functional_collectives as funcol

_global_device_mesh: Optional["DeviceMesh"] = None


def get_global_device_mesh() -> "DeviceMesh":
    global _global_device_mesh
    assert _global_device_mesh is not None, "Could not get a default device mesh!"
    return _global_device_mesh


def set_global_device_mesh(mesh: Optional["DeviceMesh"]) -> None:
    global _global_device_mesh
    _global_device_mesh = mesh


# We want a type for "can be passed to torch.as_tensor()";
# this is a recursive sequence type, which isn't fully supported
# yet in Python. This construct simulates that up to depth 7.
T = TypeVar("T")
_L = Union[T, Sequence[T]]
NDIntList = _L[_L[_L[_L[_L[_L[_L[int]]]]]]]

MeshExprT = Union[
    torch.Tensor,
    NDIntList,
]


class DeviceMesh:
    """
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-d dimensional array, and each value of the n-d dimensional
    array is the global id of the default process group ranks.

    DeviceMesh can be used to describe the layout of devices across the cluster,
    and serves as a proxy for communication among the device lists within the cluster.

    We use the default ProcessGroup in this DeviceMesh class to implement proper
    communications. Note that we also add collective wrappers in this class. This is
    used to decouple the detailed communication backend from the underlying
    DTensor implementation.

    DeviceMesh can be used as a context manager.

    Args:
        device_type (str): device type of the mesh. Currently supports: cpu, cuda.
        mesh (ndarray): could be a multi-dimensional array or an integer tensor that
            describes the layout of devices, where the ids are global ids of the
            default process group.
        dim_groups (List[ProcessGroup], optional): The ProcessGroup used per mesh
            dimension.

    Returns:
        A :class:`DeviceMesh` object

    Example (2 hosts with 4 GPUs each):
    ```
    # The following program runs on each process/rank in SPMD manner.
    # initialize the default world
    torch.distributed.init_process_group(backend="nccl", world_size=8)
    # initialize device mesh as (2, 4) to represent the topology
    # of cross-host (dim 0) and within-host (dim 1)
    mesh = DeviceMesh(device_type="cuda",
                      mesh=[
                          [0, 1, 2, 3],
                          [4, 5, 6, 7]
                      ])
    ```
    A reduction over the first dimension of mesh reduces across
    columns (0, 4), ..., (3, 7); a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
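
    DeviceMesh installs itself as the global mesh when used as a context
    manager (a minimal sketch, reusing the `mesh` created above):
    ```
    with mesh:
        # within this block, get_global_device_mesh() returns `mesh`
        ...
    # on exit, the global device mesh is unset again
    ```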
    """

    device_type: str
    mesh: torch.Tensor
    _backend: str

    def __init__(
        self,
        device_type: str,
        mesh: MeshExprT,
        dim_groups: Optional[List[ProcessGroup]] = None,
    ) -> None:
        self.device_type = device_type
        self.mesh = (
            mesh.detach()
            if isinstance(mesh, torch.Tensor)
            else torch.tensor(mesh, dtype=torch.int)
        )
        default_pg = self._get_or_create_default_group()
        self._backend = default_pg._get_backend_name()
        # TODO: if users want to pass pg_options, offer a way to do it
        # check default pg backend, should support device_type
        if device_type == "cpu":
            assert (
                self._backend == "gloo" or self._backend == "threaded"
            ), f"ProcessGroup backend: {self._backend} not supporting CPU!"
        elif device_type == "cuda":
            if self._backend == "gloo":
                warnings.warn(
                    "We recommend using nccl backend for cuda device type, gloo backend might only have partial support!"
                )
            assert (
                self._backend == "gloo"
                or self._backend == "nccl"
                or self._backend == "threaded"
            )
        else:
            raise RuntimeError(
                f"DeviceMesh only supports cpu or cuda device type, but got {device_type}"
            )

        world_size = get_world_size()
        if self.mesh.numel() > world_size:
            raise RuntimeError(
                f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
            )

        unique_mesh_values = self.mesh.unique(sorted=True)
        if unique_mesh_values.numel() != self.mesh.numel():
            raise RuntimeError(
                f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
            )

        # coordinates of this rank on the mesh
        rank_coords = (self.mesh == get_rank()).nonzero()
        assert rank_coords.size(0) in (0, 1)
        self._coordinate_on_dim: Optional[List[int]] = (
            rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
        )

        # groups created by dimension, each dimension should have exactly
        # one valid process group per rank
        self._dim_groups: List[ProcessGroup] = []
        if dim_groups is not None:
            # if the user hands us dimension-based groups,
            # we just take them and use them for communication
            if not isinstance(dim_groups, list):
                raise RuntimeError(
                    "dim_groups expected to be Optional[List[ProcessGroup]]"
                )

            for group in dim_groups:
                if not isinstance(group, ProcessGroup):
                    raise RuntimeError(
                        f"found object in dim_groups that is not a ProcessGroup: {group}"
                    )

            if self.get_rank() in self.mesh:
                if len(dim_groups) != self.mesh.ndim:
                    raise RuntimeError(
                        f"length of dim_groups ({len(dim_groups)}) expected to be equal to mesh.ndim ({self.mesh.ndim})"
                    )
            else:
                if len(dim_groups) != 0:
                    raise RuntimeError(
                        f"length of dim_groups ({len(dim_groups)}) expected to be equal to 0 on rank {self.get_rank()} "
                        f"for mesh {self.mesh}"
                    )

            self._dim_groups = dim_groups
            return

        if self.mesh.ndim == 1 and len(unique_mesh_values) == world_size:
            # if the mesh is the same as world_pg, we just append the default
            # pg to the first dim groups, as new_group cannot have the exact
            # same ranks as world
            self._dim_groups.append(default_pg)
        else:
            # create sub pgs based on the mesh argument specified
            # handle multi-dim mesh, create subgroups by
            # looping over the pg_ranks_by_dim for each dim
            for dim in range(self.mesh.ndim):
                # swap the current dim to the last dim
                # then reshape to flatten out other dims
                pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                    -1, self.mesh.size(dim)
                )

                # multi-dim mesh, create subgroups by
                # looping over the pg_ranks for each dim
                # and appending the groups
                for dim_mesh in pg_ranks_by_dim:
                    subgroup_ranks = dim_mesh.tolist()
                    # call new_group regardless of whether the current rank is
                    # in the pg or not, it's required that all ranks participate
                    # in subgroup construction
                    new_subgroup = new_group(
                        ranks=subgroup_ranks, backend=self._backend
                    )
                    # only add to dim_groups if the current rank is in the subgroup
                    if self.get_rank() in subgroup_ranks:
                        if len(self._dim_groups) > dim:
                            raise RuntimeError(
                                f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                f"in {subgroup_ranks}!"
                            )
                        self._dim_groups.append(new_subgroup)

    def _get_or_create_default_group(self):
        if not is_initialized():
            _backend = "gloo" if self.device_type == "cpu" else "nccl"
            init_process_group(backend=_backend)
        return _get_default_group()

    def __enter__(self) -> "DeviceMesh":
        # set global device_mesh to this instance
        set_global_device_mesh(self)
        return self

    # pyre-fixme[2]: Parameter must be annotated.
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # unset the global device mesh
        set_global_device_mesh(None)

    def __repr__(self) -> str:
        return f"DeviceMesh:({self.mesh.tolist()})"

    def __hash__(self):
        return hash((self.mesh, id(self)))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DeviceMesh):
            return False
        if id(self) == id(other):
            return True
        return self.mesh.equal(other.mesh)

    def get_dim_groups(self) -> List[ProcessGroup]:
        return self._dim_groups

    # pyre-fixme[3]: Return type must be annotated.
    def size(self, dim: int = 0):
        return self.mesh.size(dim)

    @property
    def ndim(self) -> int:
        return self.mesh.ndim

    def backend(self) -> str:
        return self._backend

    def get_rank(self) -> int:
        return get_rank()

    def get_coordinate(self) -> Optional[List[int]]:
        """
        Return the coordinates of this rank relative to all
        dimensions of the mesh. If this rank is not part of the mesh, return None.
        """
        return self._coordinate_on_dim if self._coordinate_on_dim else None

    def scatter(
        self,
        output: torch.Tensor,
        scatter_list: List[torch.Tensor],
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        Scatter a list of tensors to a device mesh dimension. We by default
        use the first rank of the mesh dimension as the source of truth, i.e.
        for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
        scatter the tensor list on rank 0 to ranks 0/1, and the tensor list on
        rank 2 to ranks 2/3.

        Args:
            output (torch.Tensor): the tensor to receive the scattered list.
            scatter_list (List[torch.Tensor]): the tensor list to be scattered.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to scatter on, we by default choose the first rank on the
                mesh dimension as source of truth.

        Returns:
            A :class:`Work` object
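
        A minimal usage sketch (assuming a 2x2 CPU mesh `[[0, 1], [2, 3]]` on
        a gloo backend; illustrative only):
        ```
        out = torch.empty(4)
        # each rank along mesh_dim 1 receives one tensor of the list; the
        # list is only consumed on the first rank of that dimension
        mesh.scatter(out, [torch.zeros(4), torch.ones(4)], mesh_dim=1)
        ```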
        """
        # TODO: Ideally we should use the meta tensor way
        # (to register a meta kernel for the collective op)
        # so that it would avoid the communication. Need to
        # remove the check below once that is done.
        if output.is_meta:
            return None
        dim_group = self._dim_groups[mesh_dim]
        # src needs to be a global rank
        src_for_dim = 0
        if dim_group is not GroupMember.WORLD:
            src_for_dim = get_global_rank(dim_group, 0)

        if src_for_dim == get_rank():
            fut = scatter(
                output,
                scatter_list=scatter_list,
                src=src_for_dim,
                group=dim_group,
                async_op=async_op,
            )
        else:
            fut = scatter(
                output,
                scatter_list=None,
                src=src_for_dim,
                group=dim_group,
                async_op=async_op,
            )

        return fut

    def broadcast(
        self,
        tensor: torch.Tensor,
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        Broadcast the tensor to a device mesh dimension. We by default
        use the first rank of the mesh dimension as the source of truth, i.e.
        for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
        broadcast the tensor on rank 0 to ranks 0/1, and the tensor on rank 2
        to ranks 2/3.

        Args:
            tensor (torch.Tensor): tensor to broadcast.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to broadcast on, we by default choose the first rank on the
                mesh dimension as source of truth.

        Returns:
            A :class:`Work` object
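
        A minimal usage sketch (assuming a 2x2 CPU mesh `[[0, 1], [2, 3]]` on
        a gloo backend; illustrative only):
        ```
        t = torch.arange(4, dtype=torch.float)
        # afterwards, every rank in the mesh_dim 1 group holds the values
        # broadcast from the first rank of its row
        mesh.broadcast(t, mesh_dim=1)
        ```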
        """
        # TODO: Ideally we should use the meta tensor way
        # (to register a meta kernel for the collective op)
        # so that it would avoid the communication. Need to
        # remove the check below once that is done.
        if tensor.is_meta:
            return None
        dim_group = self._dim_groups[mesh_dim]
        # src needs to be a global rank
        src_for_dim = 0
        if dim_group is not GroupMember.WORLD:
            src_for_dim = get_global_rank(dim_group, 0)

        return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)

    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        all_gather the tensor on each rank to the tensor_list on a
        device mesh dimension.

        Args:
            tensor_list (List[torch.Tensor]): the gathered tensor list.
            tensor (torch.Tensor): tensor to be gathered on each rank.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to gather on.

        Returns:
            A :class:`Work` object
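
        A minimal usage sketch (assuming a 2x2 CPU mesh `[[0, 1], [2, 3]]` on
        a gloo backend; illustrative only):
        ```
        local = torch.ones(4) * mesh.get_rank()
        gathered = [torch.empty(4) for _ in range(mesh.size(1))]
        # afterwards, gathered[i] holds the tensor from the i-th rank of
        # this rank's mesh_dim 1 group
        mesh.all_gather(gathered, local, mesh_dim=1)
        ```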
        """
        dim_group = self._dim_groups[mesh_dim]
        return all_gather(tensor_list, tensor, group=dim_group, async_op=async_op)

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> torch.Tensor:
        """
        all_reduce the tensor on each rank on a device mesh dimension, and
        return an output tensor on each rank after all_reduce.

        Args:
            tensor (torch.Tensor): tensor to be all_reduced on each rank.
            op (:class:`torch.distributed.distributed_c10d.ReduceOp`, optional):
                the reduction op of all_reduce (i.e. ReduceOp.SUM)
            mesh_dim (int, optional): indicate which mesh dimension we want
                to reduce on.

        Returns:
            A :class:`torch.Tensor` object
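
        A minimal usage sketch (assuming a 2x2 CPU mesh `[[0, 1], [2, 3]]` on
        a gloo backend; illustrative only):
        ```
        # every rank in the mesh_dim 0 group receives the elementwise sum
        reduced = mesh.all_reduce(torch.ones(4), op=ReduceOp.SUM, mesh_dim=0)
        ```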
        """
        dim_group = self._dim_groups[mesh_dim]
        op_name: str = op.name  # type: ignore[attr-defined]
        return funcol.all_reduce(tensor, reduceOp=op_name, group=(self, mesh_dim,))

    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
        """
        Reduce the input_list on each rank on a device mesh dimension, and
        scatter the results to the output tensor on each rank.

        Args:
            output (torch.Tensor): tensor to receive the scattered result.
            input_list (List[torch.Tensor]): tensor list to be reduced and
                scattered on each rank.
            op (:class:`torch.distributed.distributed_c10d.ReduceOp`, optional):
                the reduction op of reduce_scatter (i.e. ReduceOp.SUM)
            mesh_dim (int, optional): indicate which mesh dimension we want
                to scatter on.

        Returns:
            A :class:`Work` object
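
        A minimal usage sketch (assuming a 2x2 CPU mesh `[[0, 1], [2, 3]]` on
        a gloo backend; illustrative only):
        ```
        out = torch.empty(4)
        # reduce both inputs across mesh_dim 1; each rank keeps the shard
        # matching its coordinate on that dimension
        mesh.reduce_scatter(out, [torch.ones(4), torch.ones(4)], mesh_dim=1)
        ```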
        """
        if self._backend == "nccl":
            dim_group = self._dim_groups[mesh_dim]
            fut = reduce_scatter(
                output, input_list, op=op, group=dim_group, async_op=async_op
            )
        elif self._backend == "gloo":
            # it's gloo, which does not have reduce_scatter
            # we have to do all_reduce + scatter
            warnings.warn(
                "ProcessGroupGloo does not support reduce_scatter, falling back to all_reduce!"
            )
            my_coordinate = self.get_coordinate()
            # TODO: what should happen if rank is not in the mesh?
            # see issue https://github.com/pytorch/tau/pull/492
            assert (
                my_coordinate is not None
            ), "Rank is not part of the mesh"  # TODO: figure out behavior here
            fut = None
            flattened_list = []
            offset_list = []

            offset = 0
            for input in input_list:
                offset_list.append(offset)
                offset += input.numel()
                flattened_list.append(input.flatten())

            # all_reduce since gloo does not support reduce_scatter
            flat_tensor = torch.cat(flattened_list).clone(
                memory_format=torch.contiguous_format
            )
            dim_group = self._dim_groups[mesh_dim]
            fut = c10d.all_reduce(flat_tensor, op=op, group=dim_group, async_op=async_op)

            # scatter the tensor
            output_offset = offset_list[my_coordinate[mesh_dim]]
            output.copy_(
                flat_tensor[output_offset : output_offset + output.numel()].view(
                    output.shape
                )
            )
        else:
            raise RuntimeError(
                f"backend {self._backend} does not support reduce_scatter!"
            )
        return fut

    # TODO: test uneven split on GLOO and NCCL
    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
        mesh_dim: int = 0,
        async_op: bool = False,
    ) -> Optional[Work]:
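        """
        all_to_all the input tensor list to the output tensor list on a device
        mesh dimension: output_tensor_list[i] on each rank is filled with the
        tensor contributed by the i-th rank of the mesh dimension. Since gloo
        has no native all_to_all, it is emulated below with one scatter per
        rank of the dimension.

        Args:
            output_tensor_list (List[torch.Tensor]): tensors to receive, one
                per rank in the mesh dimension.
            input_tensor_list (List[torch.Tensor]): tensors to send, one per
                rank in the mesh dimension.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to exchange on.

        Returns:
            A :class:`Work` object
        """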
        dim_group = self._dim_groups[mesh_dim]

        work = None
        # no direct dist.all_to_all support on 'gloo' so we manually do scatters
        if self.backend() == "gloo":
            # TODO: pull in the handling of the uneven case from #492
            dim_group_size = get_world_size(dim_group)
            for i in range(dim_group_size):
                # src needs to be a global rank
                src_for_dim = i
                if dim_group is not GroupMember.WORLD:
                    src_for_dim = get_global_rank(dim_group, i)

                work = scatter(
                    output_tensor_list[i],
                    input_tensor_list if self.get_rank() == src_for_dim else [],
                    group=dim_group,
                    src=src_for_dim,
                    async_op=async_op,
                )
        elif self.backend() == "nccl":
            work = all_to_all(
                output_tensor_list,
                input_tensor_list,
                dim_group,
                async_op=async_op,
            )
        else:
            raise RuntimeError(
                f"DeviceMesh does not support all-to-all collective operations on {self.backend()} backend."
            )
        return work