Files
pytorch/torch/distributed/_tensor/device_mesh.py
Shintaro Iwasaki 95575f0a5f [DTensor] Fix _get_or_create_default_group() (#96961)
Summary:
This PR fixes `_get_or_create_default_group()` of `DeviceMesh`. When the `mesh` of the first created `DeviceMesh` is not `[0, 1, 2, ..., WORLD_SIZE - 1]` and `is_initialized() == False`, the function raises a spurious assertion error. This PR fixes the issue by removing those assertions.

 ---

More specifically, `_get_or_create_default_group()` has 4 checks:

1. `DeviceMesh must include every process in WORLD`
2. `DeviceMesh cannot have duplicate values`
3. `DeviceMesh ranks must start from 0`
4. `DeviceMesh should have all ranks of WORLD`

Checks 1, 3, and 4 are not satisfied when `self.mesh` is not `[0, 1, 2, ..., WORLD_SIZE - 1]`.

Check 2 is valid, but it is already performed in `__init__()`, so it does not need to be repeated in this function.
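
For illustration, here is a minimal sketch (not part of the PR; it assumes a 4-rank job launched via `torchrun`) of the pattern that previously tripped the removed assertions and now works:

```
# runs in SPMD fashion on every rank; there is no explicit init_process_group()
# call, so DeviceMesh creates the default group via _get_or_create_default_group()
from torch.distributed._tensor import DeviceMesh

# mesh is not [0, 1, ..., WORLD_SIZE - 1]; before this PR the first DeviceMesh
# construction asserted, after this PR it succeeds (ranks 0/1 simply get no
# coordinate on this mesh)
mesh = DeviceMesh("cuda", [2, 3])
```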

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44098849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96961
Approved by: https://github.com/wanchaol
2023-03-17 15:52:19 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from typing import List, Optional, Sequence, TypeVar, Union
import torch
from torch.distributed.distributed_c10d import (
_get_default_group,
all_gather,
all_to_all,
broadcast,
get_global_rank,
get_rank,
get_world_size,
GroupMember,
init_process_group,
is_initialized,
new_group,
ProcessGroup,
reduce_scatter,
ReduceOp,
scatter,
Work,
)
import torch.distributed.distributed_c10d as c10d
import torch.distributed._functional_collectives as funcol
_global_device_mesh: Optional["DeviceMesh"] = None
def get_global_device_mesh() -> "DeviceMesh":
global _global_device_mesh
assert _global_device_mesh is not None, "Could not get a default device mesh!"
return _global_device_mesh
def set_global_device_mesh(mesh: Optional["DeviceMesh"]) -> None:
global _global_device_mesh
_global_device_mesh = mesh
# We want a type for "can be passed to torch.as_tensor()";
# this is a recursive sequence type, which isn't fully supported
# yet in python. This construct simulates that up to depth 7.
T = TypeVar("T")
_L = Union[T, Sequence[T]]
NDIntList = _L[_L[_L[_L[_L[_L[_L[int]]]]]]]
MeshExprT = Union[
torch.Tensor,
NDIntList,
]
class DeviceMesh:
"""
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-dimensional array, and each value of the n-dimensional
    array is the global rank of the default process group.
    DeviceMesh can be used to describe the layout of devices across the cluster,
    and serves as a proxy for communication among the device lists within the cluster.
    We use the default ProcessGroup in this DeviceMesh class to implement proper
    communications. Note that we also add collective wrappers in this class, which
    decouples the details of the communication backend from the underlying
    DTensor implementation.
DeviceMesh can be used as a context manager.
Args:
device_type (str): device type of the mesh. Currently supports: cpu, cuda.
        mesh (ndarray): a multi-dimensional array or an integer tensor that
            describes the layout of devices; the ids are the global ranks of the
            default process group.
dim_groups (List[ProcessGroup], optional): The ProcessGroup used per mesh
dimension.
Returns:
A :class:`DeviceMesh` object
    Example (2 hosts with 4 GPUs each):
```
# The following program runs on each process/rank in SPMD manner.
# initialized default world
torch.distributed.init_process_group(backend="nccl", world_size=8)
# initialize device mesh as (2, 4) to represent the topology
# of cross-host(dim 0), and within-host (dim 1)
mesh = DeviceMesh(device_type="cuda",
mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
])
```
    A reduction over the first dimension of mesh will reduce across
    columns (0, 4), ..., and (3, 7); a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
"""
device_type: str
mesh: torch.Tensor
_backend: str
def __init__(
self,
device_type: str,
mesh: MeshExprT,
dim_groups: Optional[List[ProcessGroup]] = None,
) -> None:
self.device_type = device_type
self.mesh = (
mesh.detach()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, dtype=torch.int)
)
default_pg = self._get_or_create_default_group()
self._backend = default_pg._get_backend_name()
        # TODO: if the user wants to pass pg_options, offer a way to do it
        # check the default pg backend; it should support device_type
if device_type == "cpu":
assert (
self._backend == "gloo" or self._backend == "threaded"
), f"ProcessGroup backend: {self._backend} not supporting CPU!"
elif device_type == "cuda":
if self._backend == "gloo":
warnings.warn(
"We recommend using nccl backend for cuda device type, gloo backend might only have partial support!"
)
            assert self._backend in ("gloo", "nccl", "threaded")
else:
raise RuntimeError(
f"DeviceMesh only support cpu or cuda device type, but got {device_type}"
)
world_size = get_world_size()
if self.mesh.numel() > world_size:
raise RuntimeError(
f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
)
unique_mesh_values = self.mesh.unique(sorted=True)
if unique_mesh_values.numel() != self.mesh.numel():
raise RuntimeError(
f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
)
# coordinates of this rank on the mesh
rank_coords = (self.mesh == get_rank()).nonzero()
assert rank_coords.size(0) in (0, 1)
self._coordinate_on_dim: Optional[List[int]] = (
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
        # groups created by dimension, each dimension should have exactly
        # one valid process group per rank
self._dim_groups: List[ProcessGroup] = []
if dim_groups is not None:
            # if the user hand-creates dimension-based groups,
            # we just take them and use them for communication
if not isinstance(dim_groups, list):
raise RuntimeError(
"dim_groups expected to be Optional[List[ProcessGroup]]"
)
for group in dim_groups:
if not isinstance(group, ProcessGroup):
raise RuntimeError(
f"found object in dim_groups that is not a ProcessGroup: {group}"
)
if self.get_rank() in self.mesh:
if len(dim_groups) != self.mesh.ndim:
raise RuntimeError(
f"length of dim_groups ({len(dim_groups)}) expected to be equal to mesh.ndim ({self.mesh.ndim})"
)
else:
if len(dim_groups) != 0:
raise RuntimeError(
f"length of dim_groups ({len(dim_groups)}) expected to be equal to 0 on rank {self.get_rank()} "
f"for mesh {self.mesh}"
)
self._dim_groups = dim_groups
return
        if self.mesh.ndim == 1 and len(unique_mesh_values) == world_size:
            # if the mesh is the same as world_pg, we just append the default
            # pg to the first dim groups, as new_group cannot have the exact
            # same ranks as world
            self._dim_groups.append(default_pg)
else:
            # create sub pgs based on the mesh argument specified
            # handle multi-dim mesh, create subgroups by
            # looping over the pg_ranks_by_dim for each dim
for dim in range(self.mesh.ndim):
# swap the current dim to the last dim
# then reshape to flatten out other dims
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
-1, self.mesh.size(dim)
)
# multi-dim mesh, create subgroups by
# looping over the pg_ranks for each dim
# and append the groups
for dim_mesh in pg_ranks_by_dim:
subgroup_ranks = dim_mesh.tolist()
                    # call new_group regardless of whether the current rank
                    # is in the pg or not; all ranks are required to
                    # participate in subgroup construction
new_subgroup = new_group(
ranks=subgroup_ranks, backend=self._backend
)
                    # only add to dim_groups if the current rank is in the subgroup
if self.get_rank() in subgroup_ranks:
if len(self._dim_groups) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
f"in {subgroup_ranks}!"
)
self._dim_groups.append(new_subgroup)
def _get_or_create_default_group(self):
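        # Lazily create the default (world) process group if the caller has not
        # initialized one yet, then return it. Mesh-content validation (e.g. the
        # duplicate check) is done in __init__, so no extra asserts are needed here.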
if not is_initialized():
_backend = "gloo" if self.device_type == "cpu" else "nccl"
init_process_group(backend=_backend)
return _get_default_group()
def __enter__(self) -> "DeviceMesh":
# set global device_mesh to this instance
set_global_device_mesh(self)
return self
# pyre-fixme[2]: Parameter must be annotated.
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
# unset global device mesh
set_global_device_mesh(None)
def __repr__(self) -> str:
return f"DeviceMesh:({self.mesh.tolist()})"
def __hash__(self):
return hash((self.mesh, id(self)))
def __eq__(self, other: object) -> bool:
if not isinstance(other, DeviceMesh):
return False
if id(self) == id(other):
return True
return self.mesh.equal(other.mesh)
def get_dim_groups(self) -> List[ProcessGroup]:
return self._dim_groups
# pyre-fixme[3]: Return type must be annotated.
def size(self, dim: int = 0):
return self.mesh.size(dim)
@property
def ndim(self) -> int:
return self.mesh.ndim
def backend(self) -> str:
return self._backend
def get_rank(self) -> int:
return get_rank()
def get_coordinate(self) -> Optional[List[int]]:
"""
        Return the coordinates of this rank along all dimensions of the mesh.
        If this rank is not part of the mesh, return None.
"""
return self._coordinate_on_dim if self._coordinate_on_dim else None
def scatter(
self,
output: torch.Tensor,
scatter_list: List[torch.Tensor],
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
        scatter a list of tensors to a device mesh dimension. We by default
        use the first rank of the mesh dimension as the source of truth, i.e.,
        for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
        scatter the tensor list on rank 0 to ranks 0/1, and the tensor list on
        rank 2 to ranks 2/3.
Args:
output (torch.Tensor): the tensor to receive the scattered list.
scatter_list (List[torch.Tensor]): the tensor list to be scattered.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to scatter on; we by default choose the first rank on the
                mesh dimension as the source of truth.
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if output.is_meta:
return None
dim_group = self._dim_groups[mesh_dim]
        # src needs to be a global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
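        # only the source rank supplies the scatter_list; the other ranks must
        # pass None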
if src_for_dim == get_rank():
fut = scatter(
output,
scatter_list=scatter_list,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
else:
fut = scatter(
output,
scatter_list=None,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
return fut
def broadcast(
self,
tensor: torch.Tensor,
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
        broadcast the tensor to a device mesh dimension. We by default
        use the first rank of the mesh dimension as the source of truth, i.e.,
        for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
        broadcast the tensor on rank 0 to ranks 0/1, and the tensor on rank 2
        to ranks 2/3.
Args:
tensor (torch.Tensor): tensor to broadcast.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to broadcast on; we by default choose the first rank on the
                mesh dimension as the source of truth.
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if tensor.is_meta:
return None
dim_group = self._dim_groups[mesh_dim]
        # src needs to be a global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
def all_gather(
self,
tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
all_gather the tensor on each rank to the tensor_list on a
device mesh dimension.
Args:
tensor_list (List[torch.Tensor]): The gathered tensor list.
tensor (torch.Tensor): tensor to be gathered on each rank.
            mesh_dim (int, optional): indicate which mesh dimension we want
                to gather on.
Returns:
A :class:`Work` object
"""
dim_group = self._dim_groups[mesh_dim]
return all_gather(tensor_list, tensor, group=dim_group, async_op=async_op)
def all_reduce(
self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM, # type: ignore[assignment]
mesh_dim: int = 0,
async_op: bool = False,
) -> torch.Tensor:
"""
all_reduce the tensor on each rank on a device mesh dimension, and
return an output tensor on each rank after all_reduce.
Args:
tensor (torch.Tensor): tensor to be all_reduced on each rank.
            op (:class:`torch.distributed.distributed_c10d.ReduceOp`, optional):
                the reduction op of all_reduce (i.e. ReduceOp.SUM)
mesh_dim (int, optional): indicate which mesh dimension we want
to reduce on.
Returns:
A :class:`torch.Tensor` object
"""
dim_group = self._dim_groups[mesh_dim]
op_name: str = op.name # type: ignore[attr-defined]
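        # delegate to the functional collective; it accepts a (DeviceMesh, dim)
        # tuple as the group and handles synchronization of the result itself,
        # which is why no Work handle is returned here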
return funcol.all_reduce(tensor, reduceOp=op_name, group=(self, mesh_dim,))
def reduce_scatter(
self,
output: torch.Tensor,
input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM, # type: ignore[assignment]
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
reduce the input_list on each rank on a device mesh dimension, and scatter
the results to the output tensor on each rank.
Args:
output (torch.Tensor): tensor to receive the scattered result.
            input_list (List[torch.Tensor]): tensor list to be reduced and
                scattered on each rank.
            op (:class:`torch.distributed.distributed_c10d.ReduceOp`, optional):
                the reduction op of reduce_scatter (i.e. ReduceOp.SUM)
mesh_dim (int, optional): indicate which mesh dimension we want
to scatter on.
Returns:
A :class:`Work` object
"""
if self._backend == "nccl":
dim_group = self._dim_groups[mesh_dim]
fut = reduce_scatter(
output, input_list, op=op, group=dim_group, async_op=async_op
)
elif self._backend == "gloo":
# it's gloo, which does not have reduce_scatter
# we have to do all_reduce + scatter
warnings.warn(
"ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!"
)
my_coordinate = self.get_coordinate()
# TODO: what should happen if rank is not in the mesh?
# see issue https://github.com/pytorch/tau/pull/492
assert (
my_coordinate is not None
), "Rank if not part of mesh" # TODO: figure out behavior here
fut = None
flattened_list = []
offset_list = []
offset = 0
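            # flatten and concatenate the inputs so a single all_reduce covers
            # the whole list; offset_list records where each rank's shard starts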
for input in input_list:
offset_list.append(offset)
offset += input.numel()
flattened_list.append(input.flatten())
# all reduce since gloo does not support reduce_scatter
flat_tensor = torch.cat(flattened_list).clone(
memory_format=torch.contiguous_format
)
dim_group = self._dim_groups[mesh_dim]
fut = c10d.all_reduce(flat_tensor, op=op, group=dim_group, async_op=async_op)
            # copy this rank's shard of the reduced flat tensor into the output
output_offset = offset_list[my_coordinate[mesh_dim]]
output.copy_(
flat_tensor[output_offset : output_offset + output.numel()].view(
output.shape
)
)
else:
raise RuntimeError(
f"backend {self._backend} does not support reduce_scatter!"
)
return fut
# TODO: test uneven split on GLOO and NCCL
def all_to_all(
self,
output_tensor_list: List[torch.Tensor],
input_tensor_list: List[torch.Tensor],
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
dim_group = self._dim_groups[mesh_dim]
work = None
# no direct dist.all_to_all support on 'gloo' so we manually do scatters
if self.backend() == "gloo":
            # TODO: pull in the handling of the uneven case from #492
dim_group_size = get_world_size(dim_group)
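            # emulate all_to_all with one scatter per source rank: on iteration i,
            # the rank at index i in this dim group scatters its input_tensor_list,
            # and every rank receives its piece into output_tensor_list[i]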
for i in range(dim_group_size):
                # src needs to be a global rank
src_for_dim = i
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, i)
work = scatter(
output_tensor_list[i],
input_tensor_list if self.get_rank() == src_for_dim else [],
group=dim_group,
src=src_for_dim,
async_op=async_op,
)
elif self.backend() == "nccl":
work = all_to_all(
output_tensor_list,
input_tensor_list,
dim_group,
async_op=async_op,
)
else:
raise RuntimeError(
f"DeviceMesh does not support all-to-all collective operations on {self.backend()} backend."
)
return work