mirror of https://github.com/pytorch/pytorch.git
[DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor: it is a public API that still carries some "legacy" concepts we'd love to get rid of, such as an explicit/materialized `mesh` Tensor. In other languages the solution would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it. The new private constructor essentially takes only `_layout`, `_global_rank_permutation` (named `_rank_map` in the diff below), and `mesh_dim_names`. With such a constructor we can simplify many callsites and get rid of the `_create_mesh_from_ranks` helper method. That is a good thing, because the helper instantiated many DeviceMeshes in a for loop, which always felt unnecessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
committed by PyTorch MergeBot
parent a214371008
commit 99097b6d89
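The commit message describes approximating a private constructor overload in Python. Below is a minimal sketch of that pattern in isolation; the `_Mesh` class and the `_sizes`/`_rank_map` names are illustrative stand-ins, not the real DeviceMesh API. The public positional argument becomes optional, the private parameters are keyword-only and underscore-prefixed, and `__init__` enforces that exactly one of the two argument sets is used.

```python
# Minimal sketch (not the real DeviceMesh) of the "private constructor overload"
# pattern the PR approximates: the public positional argument becomes optional,
# and keyword-only, underscore-prefixed parameters form the private overload.
# The two paths are made mutually exclusive inside __init__.
from typing import Optional, Sequence


class _Mesh:
    def __init__(
        self,
        mesh: Optional[Sequence[int]] = None,      # public path
        *,
        _sizes: Optional[tuple[int, ...]] = None,  # private path
        _rank_map: Optional[list[int]] = None,     # private path
    ) -> None:
        if mesh is not None:
            if _sizes is not None or _rank_map is not None:
                raise TypeError("Cannot mix the public `mesh` argument with private args")
            _sizes = (len(mesh),)
            _rank_map = list(mesh)
        elif _sizes is None or _rank_map is None:
            raise TypeError("`mesh` is required unless both private args are given")
        self._sizes, self._rank_map = _sizes, _rank_map


public = _Mesh([0, 1, 2, 3])                           # public "overload"
private = _Mesh(_sizes=(4,), _rank_map=[0, 1, 2, 3])   # private "overload"
```

The real constructor in the diff below has the same shape: `mesh` stays public and optional, while `_layout` and `_rank_map` form the keyword-only private path.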
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
import os
import threading
import warnings
@@ -12,7 +11,7 @@ from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int
from torch.distributed._pycute import is_int, suffix_product
from torch.utils._typing_utils import not_none
@@ -183,45 +182,52 @@ else:
def __init__(
self,
device_type: str,
mesh: Union[torch.Tensor, "ArrayLike"],
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
self._rank_map = (
_root_mesh._rank_map
if _root_mesh is not None
else mesh_tensor.flatten()
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
)
self._root_mesh = _root_mesh
assert self._layout.check_non_overlap(), (
if mesh is not None:
if _layout is not None or _rank_map is not None:
raise TypeError(
"Cannot provide _layout and/or _rank_map if passing explicit mesh"
)
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
_layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
_rank_map = mesh_tensor.flatten()
else:
if _layout is None or _rank_map is None:
raise TypeError(
"The mesh argument is required except for PRIVATE USAGE ONLY!"
)

assert _layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == mesh_tensor.size(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
assert _rank_map.is_contiguous(), "The rank map must be contiguous"
assert _rank_map.numel() >= _layout.cosize(), (
f"The rank map contains {_rank_map.numel()} element, "
f"which isn't large enough for layout {_layout}"
)

self._device_type = device_type
self._layout = _layout
self._rank_map = _rank_map
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
self._root_mesh = _root_mesh

if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
elif len(backend_override) != len(self._layout):
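The new validation and bookkeeping above boil down to one internal representation: a 1-D `_rank_map` plus a `_layout` holding sizes and strides, from which an explicit mesh tensor can be recovered. Below is a small illustration with plain torch, using `torch.as_strided` as a stand-in for the `_MeshLayout` bookkeeping; it is not the actual `_MeshLayout` API.

```python
# Illustration only: how a flat rank map plus (sizes, strides) can stand in for a
# materialized N-D mesh tensor. torch.as_strided is used here as a stand-in for
# what the _MeshLayout bookkeeping achieves; it is not the actual _MeshLayout API.
import torch

mesh = torch.arange(8, dtype=torch.int).view(2, 4)   # explicit/materialized mesh
sizes, strides = mesh.size(), mesh.stride()           # what the public path records
rank_map = mesh.flatten()                              # the 1-D rank map

# The same N-D view is recoverable from the flat map and the layout metadata,
# so sub-meshes can be described by new (sizes, strides) over the same rank_map.
reconstructed = rank_map.as_strided(sizes, strides)
assert torch.equal(reconstructed, mesh)
```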
@@ -652,16 +658,13 @@
not_none(flatten_mesh._mesh_dim_names).index(name)
]
)
cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
res_submesh = DeviceMesh._create_mesh_from_ranks(
res_submesh = DeviceMesh(
self._device_type,
pg_ranks_by_dim,
cur_rank,
submesh_dim_names,
_init_backend=False,
_layout=layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=submesh_dim_names,
_root_mesh=root_mesh,
_init_backend=False,
)
res_submesh._dim_group_names = slice_dim_group_name
return res_submesh
@@ -705,20 +708,13 @@
f"Please specify another valid mesh_dim_name."
)

cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
res_flattened_mesh = DeviceMesh(
root_mesh._device_type,
pg_ranks_by_dim.flatten(
start_dim=1
), # this is needed for flatten non-contiguous mesh dims.
cur_rank,
(mesh_dim_name,),
(backend_override,),
_layout=flattened_mesh_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=(mesh_dim_name,),
_root_mesh=root_mesh,
backend_override=(backend_override,),
)
root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
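The comment about starting from the root mesh refers to a documented `ProcessGroup` constraint: `torch.distributed.new_group` must be called by every rank for every subgroup, even by ranks that will not be members of it, or ranks may hang. Below is a hedged, self-contained sketch of that calling pattern, meant to be launched with `torchrun`; the pairwise grouping is purely illustrative, not DeviceMesh internals.

```python
# Hedged sketch of the ProcessGroup constraint the comment above refers to:
# torch.distributed.new_group must be called by *every* rank for *every*
# subgroup, even by ranks that are not members of that subgroup, otherwise
# ranks can hang waiting for each other. Intended to be launched with torchrun;
# the group shapes here are illustrative, not DeviceMesh internals.
import torch.distributed as dist


def build_pairwise_groups(world_size: int) -> dist.ProcessGroup:
    my_rank = dist.get_rank()
    my_group = None
    # Every rank iterates over *all* subgroups and calls new_group for each one...
    for start in range(0, world_size, 2):
        ranks = list(range(start, min(start + 2, world_size)))
        pg = dist.new_group(ranks=ranks)
        # ...but only keeps the group it actually belongs to.
        if my_rank in ranks:
            my_group = pg
    return my_group


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    group = build_pairwise_groups(dist.get_world_size())
    dist.barrier()
    dist.destroy_process_group()
```

The removed `_create_mesh_from_ranks` helper satisfied the same constraint by looping over every rank-row; this PR keeps the all-ranks-participate behavior while dropping the per-row DeviceMesh construction.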
@@ -866,59 +862,6 @@

return res_submeshes

@staticmethod
def _create_mesh_from_ranks(
device_type: str,
pg_ranks_by_dim: torch.Tensor,
cur_rank: int,
mesh_dim_names: tuple[str, ...],
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> "DeviceMesh":
"""
Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
the constraint of ProcessGroup API that all ranks have to call the PG creation API
even if the rank is not in that PG.
We will create a potentially very large number of DeviceMesh objects
(e.g., on 1024 GPUs with TP=2, this could be up to 512 DeviceMeshes), only to throw
them all away except when the mesh contains the current rank.

#TODO: Further refactor this method once we relax the ProcessGroup API constraint.

Args:
device_type: The device type of the mesh.
pg_ranks_by_dim: all ranks within the worlds organized by dimensions.
cur_rank: The current global rank in the mesh.
mesh_dim_names: Mesh dimension names.
backend_override: Optional backend override for the mesh.
_init_backend: Whether to initialize the backend of the mesh.
_layout: Optional layout for the mesh.

Returns:
The DeviceMesh containing the current rank.
"""
res_mesh = None
for mesh_nd in pg_ranks_by_dim:
mesh = DeviceMesh(
device_type,
mesh_nd,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
_root_mesh=_root_mesh,
)
if cur_rank in mesh_nd:
res_mesh = mesh
if res_mesh is None:
raise RuntimeError(
f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world"
)
return res_mesh

@staticmethod
def from_group(
group: Union[ProcessGroup, list[ProcessGroup]],
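The removed docstring above spells out the cost: one DeviceMesh per row of `pg_ranks_by_dim`, all but one discarded (on its own example, 1024 GPUs with TP=2 means up to 512 throwaway meshes). Below is a toy sketch of that "build every row, keep one" selection on plain tensors, together with the direct computation the row selection reduces to; it is illustrative only, with no DeviceMesh or process groups involved.

```python
# Toy illustration (plain tensors, no DeviceMesh / process groups) of the pattern
# the removed helper implemented: walk every row of pg_ranks_by_dim, build an
# object per row, and keep only the row containing the current rank.
import torch

world_size, tp = 1024, 2
pg_ranks_by_dim = torch.arange(world_size).view(-1, tp)  # 512 rows of 2 ranks
cur_rank = 773

# Old-style selection: materialize something for every row, discard all but one.
selected = None
for mesh_nd in pg_ranks_by_dim:
    candidate = mesh_nd.tolist()          # stand-in for constructing a DeviceMesh
    if cur_rank in mesh_nd:
        selected = candidate

# The row containing cur_rank can instead be computed directly from the layout.
assert selected == pg_ranks_by_dim[cur_rank // tp].tolist()
```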
@@ -1126,19 +1069,16 @@
] = ((None, None),),
) -> "DeviceMesh":
root_mesh = self._get_root_mesh()
cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks(
res_mesh = DeviceMesh(
self.device_type,
pg_ranks_by_dim,
cur_rank,
tuple(unflattened_mesh_dim_names),
_init_backend=False,
_layout=unflattened_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=tuple(unflattened_mesh_dim_names),
_root_mesh=root_mesh,
_init_backend=False,
)

# If original mesh has initiated its backend, we need to initialize the backend
@@ -1151,14 +1091,11 @@
tuple(unflattened_layout.sizes[dim : dim + unflatten_length]), # type: ignore[index]
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
)
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh._rank_map
)
unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
unflatten_submesh = DeviceMesh(
self.device_type,
unflatten_pg_ranks_by_dim,
cur_rank,
mesh_dim_names,
_layout=unflatten_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
)
dim_group_names = []
@@ -1360,13 +1297,15 @@
"If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
)

# Always initialize the mesh's tensor on CPU, regardless of what the
layout = _MeshLayout(tuple(mesh_shape), suffix_product(mesh_shape))
# Always initialize the (identity) rank map on CPU, regardless of what the
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
rank_map = torch.arange(layout.numel(), dtype=torch.int)
device_mesh = DeviceMesh(
device_type=device_type,
mesh=mesh,
_layout=layout,
_rank_map=rank_map,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override_tuple,
)
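In the new `init_device_mesh` path above, the identity rank map plus row-major strides from `suffix_product` are assumed to describe exactly the mesh that `torch.arange(...).view(mesh_shape)` used to materialize. Below is a hedged check of that equivalence, with a local `suffix_product` stand-in mirroring what the `torch.distributed._pycute` helper is assumed to compute.

```python
# Hedged sketch of what the new init path above is assumed to compute: row-major
# ("suffix product") strides for mesh_shape, plus an identity rank map, describe
# the same mesh that torch.arange(prod).view(mesh_shape) used to materialize.
# suffix_product below is a local stand-in, not the private pycute helper itself.
import math
import torch


def suffix_product(shape: tuple[int, ...]) -> tuple[int, ...]:
    # stride[i] = product of all dimensions to the right of i (row-major strides)
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


mesh_shape = (2, 3, 4)
strides = suffix_product(mesh_shape)                              # (12, 4, 1)
rank_map = torch.arange(math.prod(mesh_shape), dtype=torch.int)   # identity rank map

old_style_mesh = rank_map.view(mesh_shape)    # what init_device_mesh used to build
assert old_style_mesh.stride() == strides     # layout metadata matches the old tensor
assert torch.equal(rank_map.as_strided(mesh_shape, strides), old_style_mesh)
```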