[DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)

The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh and instead compute it on the fly when the end user needs it, replacing all of its internal usages with `_layout` and the newly introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it never needs to be copied or reallocated.
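
As a rough illustration of the idea (a hypothetical sketch, not the actual DeviceMesh internals; the shared tensor shows up as `_rank_map` in the diff below): a single shared 1D tensor of global ranks, plus per-mesh layout metadata (sizes and strides), is enough to materialize the explicit mesh tensor on demand via `as_strided`, without allocating a separate tensor per mesh.

```python
import torch

# Hypothetical sketch: a root mesh over 8 ranks laid out as 2x4, plus a child mesh
# that keeps only the last dimension. Both views are built from the same shared
# 1D rank tensor; only the (sizes, strides) metadata differs, so nothing is copied.
rank_map = torch.arange(8, dtype=torch.int)  # shared by the root mesh and all children

root_mesh = rank_map.as_strided((2, 4), (4, 1))   # sizes (2, 4), strides (4, 1)
child_mesh = rank_map.as_strided((4,), (1,))      # sizes (4,),   strides (1,)

print(root_mesh)   # [[0, 1, 2, 3],
                   #  [4, 5, 6, 7]]
print(child_mesh)  # [0, 1, 2, 3]
```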

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
Luca Wehrstedt
2025-10-16 13:54:18 +00:00
committed by PyTorch MergeBot
parent 99b32a6750
commit d61a9b88cf
3 changed files with 81 additions and 76 deletions

View File

@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
self.assertEqual(result5, expected5)
# Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)

View File

@ -301,10 +301,7 @@ class _MeshLayout(Layout):
ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks))
def remap_to_tensor(
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
"""
Leverage layout as an index for mesh tensor that re-maps the indexes after layout
transformation to actual device ranks.
@ -316,10 +313,7 @@ class _MeshLayout(Layout):
can be treated as a view or subset of mesh tensor, we do need to use the actual view or
sub-tensor for DeviceMesh and its backend creation.
The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
and reconstruct the mesh tensor with the shape of the layout when accessed by users.
#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
The shape of the `rank_map` must be 1D and contiguous.
Examples:
@ -336,18 +330,18 @@ class _MeshLayout(Layout):
Return: [[[10,30],[20,40]]]
Args:
mesh_tensor: The concrete mesh tensor with actual device ranks
rank_map: The concrete mesh tensor with actual device ranks
Returns:
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
torch.Tensor: A tensor representing the actual device allocation from rank_map
"""
complement_layout = self.complement(mesh_tensor.numel())
assert rank_map.ndim == 1
assert rank_map.is_contiguous()
assert rank_map.numel() >= self.cosize()
return (
mesh_tensor.flatten()
.as_strided(
complement_layout = self.complement(rank_map.numel())
return rank_map.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
)
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)
).reshape(-1, *self.top_level_sizes)
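
For reference, a minimal standalone sketch of the strided remapping described in the docstring above, reproducing its "[[[10,30],[20,40]]]" example (assuming the layout covers the whole rank map, so the complement only contributes a leading group dimension of size 1):

```python
import torch

# Standalone illustration (not the DeviceMesh code itself): remap a 1D rank map
# through a column-major 2x2 layout, i.e. sizes (2, 2) with strides (1, 2).
rank_map = torch.tensor([10, 20, 30, 40], dtype=torch.int)
sizes, strides = (2, 2), (1, 2)

# The leading size-1 dimension stands in for the complement (group) dimension.
remapped = rank_map.as_strided((1, *sizes), (rank_map.numel(), *strides))
print(remapped)
# tensor([[[10, 30],
#          [20, 40]]], dtype=torch.int32)
```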

View File

@ -173,7 +173,7 @@ else:
"""
_device_type: str
_mesh: torch.Tensor
_rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None
@ -190,46 +190,49 @@ else:
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
self._mesh = (
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
elif len(backend_override) != self.mesh.ndim:
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {self.mesh.ndim}."
self._rank_map = (
_root_mesh._rank_map
if _root_mesh is not None
else mesh_tensor.flatten()
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(self.mesh.size(), self.mesh.stride())
else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
)
self._root_mesh = _root_mesh
assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == self.mesh.size(), (
assert self._layout.top_level_sizes == mesh_tensor.size(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
elif len(backend_override) != len(self._layout):
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {len(self._layout)}."
)
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
@ -252,6 +255,11 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""
@ -260,7 +268,17 @@ else:
@property
def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices."""
return self._mesh
full_mesh = self._layout.remap_to_tensor(self._rank_map)
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
if my_coords.size(0) > 0:
return full_mesh[my_coords[0, 0]]
raise RuntimeError(
"In order to get the mesh Tensor of a DeviceMesh it needs to "
"either have all its original dimensions (e.g., no slicing) "
"or it needs to contain the local rank"
)
@property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
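
The reconstruction in the new `mesh` property has one subtlety worth noting: for a sliced mesh, the remapped tensor contains one submesh per group, and the property returns the group containing the calling rank (raising otherwise). A hypothetical standalone sketch of that selection, with `my_rank` standing in for `get_rank()`:

```python
import torch

# Hypothetical example: two 2x2 groups, as produced by remap_to_tensor on a sliced mesh.
full_mesh = torch.tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=torch.int)
my_rank = 5  # stand-in for get_rank()

coords = (full_mesh == my_rank).nonzero()
assert coords.size(0) > 0, "local rank is not contained in this sliced mesh"
print(full_mesh[coords[0, 0]])
# tensor([[4, 5],
#         [6, 7]], dtype=torch.int32)
```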
@ -275,9 +293,9 @@ else:
init_process_group()
world_size = get_world_size()
if self.mesh.numel() > world_size:
if self._layout.numel() > world_size:
raise RuntimeError(
f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
)
# ONLY set the device if the current device is not initialized, if user already
@ -328,8 +346,8 @@ else:
default_group = _get_default_group()
if (
self.mesh.ndim == 1
and self.mesh.numel() == get_world_size()
len(self._layout) == 1
and self._layout.numel() == get_world_size()
and backend_override[0] == (None, None)
):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@ -348,11 +366,11 @@ else:
dim_group_names.append(dim_group.group_name)
else:
# create sub pgs base on the mesh argument specified
for dim in range(self.mesh.ndim):
for dim in range(len(self._layout)):
# swap the current dim to the last dim
# then reshape to flatten out other dims
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
-1, self.mesh.size(dim)
pg_ranks_by_dim = (
self._layout[dim].nest().remap_to_tensor(self._rank_map)
)
backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise
@ -448,14 +466,14 @@ else:
def __repr__(self) -> str:
device_mesh_repr = (
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
if self._mesh_dim_names
else f"{tuple(self._mesh.shape)}"
else f"{self._layout.top_level_sizes}"
)
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
# We only print the mesh tensor if the debug mode is turned on.
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
return f"{device_mesh_repr})"
def __hash__(self):
@ -465,7 +483,7 @@ else:
self._hash = hash(
(
self._flatten_mesh_list,
self._mesh.shape,
self._layout,
self._device_type,
self._mesh_dim_names,
self._thread_id,
@ -481,7 +499,7 @@ else:
return False
return (
self._flatten_mesh_list == other._flatten_mesh_list
and self._mesh.shape == other._mesh.shape
and self._layout == other._layout
and self._device_type == other._device_type
and self._mesh_dim_names == other._mesh_dim_names
and self._thread_id == other._thread_id
@ -573,16 +591,16 @@ else:
if not hasattr(self, "_dim_group_names"):
raise RuntimeError("DeviceMesh process groups not initialized!")
if self.mesh.ndim > 1 and mesh_dim is None:
if len(self._layout) > 1 and mesh_dim is None:
raise RuntimeError(
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
f"Found the DeviceMesh have {len(self._layout)} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
"If you want to get the list of all the ProcessGroups in the DeviceMesh,"
"please use `get_all_groups()` instead.",
)
# Quick return if the current device_mesh is a 1D mesh.
if self.mesh.ndim == 1 and mesh_dim is None:
if len(self._layout) == 1 and mesh_dim is None:
return not_none(_resolve_process_group(self._dim_group_names[0]))
root_mesh = self._get_root_mesh()
@ -608,7 +626,7 @@ else:
Returns:
A list of :class:`ProcessGroup` object.
"""
return [self.get_group(i) for i in range(self.mesh.ndim)]
return [self.get_group(i) for i in range(len(self._layout))]
def _create_sub_mesh(
self,
@ -635,9 +653,7 @@ else:
]
)
cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(
root_mesh.mesh,
)
pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
res_submesh = DeviceMesh._create_mesh_from_ranks(
self._device_type,
pg_ranks_by_dim,
@ -692,9 +708,7 @@ else:
cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
root_mesh.mesh,
)
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh._device_type,
pg_ranks_by_dim.flatten(
@ -833,9 +847,7 @@ else:
"""
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(
self.mesh,
)
pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
cur_rank = self.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
@ -896,6 +908,7 @@ else:
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
_root_mesh=_root_mesh,
)
if cur_rank in mesh_nd:
res_mesh = mesh
@ -904,8 +917,6 @@ else:
f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world"
)
if _root_mesh is not None:
res_mesh._root_mesh = _root_mesh
return res_mesh
@staticmethod
@ -1004,15 +1015,17 @@ else:
return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int:
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
if mesh_dim is not None:
return self._layout[mesh_dim].numel()
return self._layout.numel()
@property
def ndim(self) -> int:
return self.mesh.ndim
return len(self._layout)
@property
def shape(self) -> tuple[int, ...]:
return tuple(self.mesh.shape)
return self._layout.top_level_sizes
def get_rank(self) -> int:
"""
@ -1051,7 +1064,7 @@ else:
"""
if self.ndim > 1 and mesh_dim is None:
raise RuntimeError(
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
f"Found the DeviceMesh have {len(self._layout)} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
)
elif mesh_dim is None:
@ -1115,9 +1128,7 @@ else:
root_mesh = self._get_root_mesh()
cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
root_mesh.mesh,
)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks(
@ -1141,7 +1152,7 @@ else:
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
)
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh.mesh,
root_mesh._rank_map
)
unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
self.device_type,