[DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)

The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, a public API that still carries some "legacy" concepts we'd love to get rid of, such as an explicit, materialized `mesh` Tensor.

In other languages the solution would be to add a private overload of the constructor. Python doesn't natively support constructor overloads, but in this PR I managed to build something that approximates one.

This new private constructor basically only takes `_layout`, `_rank_map`, and `mesh_dim_names`.

With such a constructor we can simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing, because that helper instantiated many DeviceMeshes in a for loop only to throw all of them away except the one containing the current rank, which always felt unnecessary.
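
The pattern boils down to making `mesh` optional and validating that callers use either the public argument or the private keyword-only arguments, never both. Here is a minimal, self-contained sketch of that idea (not the actual DeviceMesh code: the `Mesh` class, the tuple-based layout, and the error messages are simplified stand-ins):

```python
from typing import Optional, Sequence

import torch


class Mesh:
    """Minimal sketch of the "private constructor overload" pattern."""

    def __init__(
        self,
        device_type: str,
        mesh: Optional[Sequence[int]] = None,  # public path
        *,
        _layout: Optional[tuple[int, ...]] = None,  # private path
        _rank_map: Optional[torch.Tensor] = None,  # private path
    ) -> None:
        if mesh is not None:
            # Public callers pass an explicit mesh; the internals are derived from it.
            if _layout is not None or _rank_map is not None:
                raise TypeError("Cannot mix the public `mesh` argument with private arguments")
            mesh_tensor = torch.tensor(mesh, device="cpu", dtype=torch.int)
            _layout = tuple(mesh_tensor.size())
            _rank_map = mesh_tensor.flatten()
        elif _layout is None or _rank_map is None:
            # Internal callers must supply both private fields instead of a mesh.
            raise TypeError("`mesh` is required unless both _layout and _rank_map are given")
        self._device_type = device_type
        self._layout = _layout
        self._rank_map = _rank_map


# Public usage: an explicit mesh, as before.
public = Mesh("cpu", mesh=[0, 1, 2, 3])
# Internal usage: skip the materialized mesh and pass the internals directly.
private = Mesh("cpu", _layout=(4,), _rank_map=torch.arange(4, dtype=torch.int))
```

Because the private arguments are keyword-only and underscore-prefixed, the public signature stays backwards compatible while internal callsites can bypass the materialized mesh tensor.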

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
Author: Luca Wehrstedt
Date: 2025-10-16 15:39:38 +00:00
Committed by: PyTorch MergeBot
parent a214371008
commit 99097b6d89

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
import os
import threading
import warnings
@@ -12,7 +11,7 @@ from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int
from torch.distributed._pycute import is_int, suffix_product
from torch.utils._typing_utils import not_none
@@ -183,45 +182,52 @@ else:
def __init__(
self,
device_type: str,
mesh: Union[torch.Tensor, "ArrayLike"],
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
self._rank_map = (
_root_mesh._rank_map
if _root_mesh is not None
else mesh_tensor.flatten()
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
)
self._root_mesh = _root_mesh
assert self._layout.check_non_overlap(), (
if mesh is not None:
if _layout is not None or _rank_map is not None:
raise TypeError(
"Cannot provide _layout and/or _rank_map if passing explicit mesh"
)
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
_layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
_rank_map = mesh_tensor.flatten()
else:
if _layout is None or _rank_map is None:
raise TypeError(
"The mesh argument is required except for PRIVATE USAGE ONLY!"
)
assert _layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == mesh_tensor.size(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
assert _rank_map.is_contiguous(), "The rank map must be contiguous"
assert _rank_map.numel() >= _layout.cosize(), (
f"The rank map contains {_rank_map.numel()} element, "
f"which isn't large enough for layout {_layout}"
)
self._device_type = device_type
self._layout = _layout
self._rank_map = _rank_map
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
self._root_mesh = _root_mesh
if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
elif len(backend_override) != len(self._layout):
@@ -652,16 +658,13 @@ else:
not_none(flatten_mesh._mesh_dim_names).index(name)
]
)
cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
res_submesh = DeviceMesh._create_mesh_from_ranks(
res_submesh = DeviceMesh(
self._device_type,
pg_ranks_by_dim,
cur_rank,
submesh_dim_names,
_init_backend=False,
_layout=layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=submesh_dim_names,
_root_mesh=root_mesh,
_init_backend=False,
)
res_submesh._dim_group_names = slice_dim_group_name
return res_submesh
@@ -705,20 +708,13 @@ else:
f"Please specify another valid mesh_dim_name."
)
cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
res_flattened_mesh = DeviceMesh(
root_mesh._device_type,
pg_ranks_by_dim.flatten(
start_dim=1
), # this is needed for flatten non-contiguous mesh dims.
cur_rank,
(mesh_dim_name,),
(backend_override,),
_layout=flattened_mesh_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=(mesh_dim_name,),
_root_mesh=root_mesh,
backend_override=(backend_override,),
)
root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
@@ -866,59 +862,6 @@ else:
return res_submeshes
@staticmethod
def _create_mesh_from_ranks(
device_type: str,
pg_ranks_by_dim: torch.Tensor,
cur_rank: int,
mesh_dim_names: tuple[str, ...],
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> "DeviceMesh":
"""
Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
the constraint of ProcessGroup API that all ranks have to call the PG creation API
even if the rank is not in that PG.
We will create a potentially very large number of DeviceMesh objects
(e.g., on 1024 GPUs with TP=2, this could be up to 512 DeviceMeshes), only to throw
them all away except when the mesh contains the current rank.
#TODO: Further refactor this method once we relax the ProcessGroup API constraint.
Args:
device_type: The device type of the mesh.
pg_ranks_by_dim: all ranks within the worlds organized by dimensions.
cur_rank: The current global rank in the mesh.
mesh_dim_names: Mesh dimension names.
backend_override: Optional backend override for the mesh.
_init_backend: Whether to initialize the backend of the mesh.
_layout: Optional layout for the mesh.
Returns:
The DeviceMesh containing the current rank.
"""
res_mesh = None
for mesh_nd in pg_ranks_by_dim:
mesh = DeviceMesh(
device_type,
mesh_nd,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
_root_mesh=_root_mesh,
)
if cur_rank in mesh_nd:
res_mesh = mesh
if res_mesh is None:
raise RuntimeError(
f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world"
)
return res_mesh
@staticmethod
def from_group(
group: Union[ProcessGroup, list[ProcessGroup]],
@@ -1126,19 +1069,16 @@ else:
] = ((None, None),),
) -> "DeviceMesh":
root_mesh = self._get_root_mesh()
cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks(
res_mesh = DeviceMesh(
self.device_type,
pg_ranks_by_dim,
cur_rank,
tuple(unflattened_mesh_dim_names),
_init_backend=False,
_layout=unflattened_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=tuple(unflattened_mesh_dim_names),
_root_mesh=root_mesh,
_init_backend=False,
)
# If original mesh has initiated its backend, we need to initialize the backend
@@ -1151,14 +1091,11 @@ else:
tuple(unflattened_layout.sizes[dim : dim + unflatten_length]), # type: ignore[index]
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
)
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh._rank_map
)
unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
unflatten_submesh = DeviceMesh(
self.device_type,
unflatten_pg_ranks_by_dim,
cur_rank,
mesh_dim_names,
_layout=unflatten_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
)
dim_group_names = []
@@ -1360,13 +1297,15 @@ else:
"If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
)
# Always initialize the mesh's tensor on CPU, regardless of what the
layout = _MeshLayout(tuple(mesh_shape), suffix_product(mesh_shape))
# Always initialize the (identity) rank map on CPU, regardless of what the
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
rank_map = torch.arange(layout.numel(), dtype=torch.int)
device_mesh = DeviceMesh(
device_type=device_type,
mesh=mesh,
_layout=layout,
_rank_map=rank_map,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override_tuple,
)
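
For the last hunk, here is a small worked example of why a layout (sizes plus strides) together with an identity rank map carries the same information as the old materialized mesh tensor. This is illustration only: the local `suffix_product` below is a stand-in for `torch.distributed._pycute.suffix_product` (assumed to produce row-major strides), and `as_strided` stands in for the mapping that `_MeshLayout` performs internally:

```python
import math

import torch


def suffix_product(shape):
    # Stand-in for torch.distributed._pycute.suffix_product: exclusive suffix
    # products, i.e. row-major (C-contiguous) strides; (2, 4) -> (4, 1).
    out = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        out[i] = out[i + 1] * shape[i + 1]
    return tuple(out)


mesh_shape = (2, 4)

# Old path (removed by this PR): materialize the full mesh tensor up front.
old_mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)

# New path: keep only the layout (sizes + strides) and a flat identity rank map.
sizes, strides = tuple(mesh_shape), suffix_product(mesh_shape)
rank_map = torch.arange(math.prod(mesh_shape), dtype=torch.int)

# Viewing the rank map through the layout reproduces the old mesh tensor.
assert torch.equal(rank_map.as_strided(sizes, strides), old_mesh)
```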