pytorch/torch/distributed/_mesh_layout.py
Luca Wehrstedt d61a9b88cf [DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
2025-10-16 17:01:44 +00:00


"""
Definition of CuTe-inspired layouts for DeviceMesh internal bookkeeping, and functions to manipulate them.
"""

import math
from collections.abc import Iterator
from dataclasses import dataclass
from itertools import product

import torch
from torch.distributed._pycute import (
    coalesce,
    complement,
    composition,
    flatten,
    IntTuple,
    is_int,
    is_tuple,
    Layout,
    suffix_product,
)


@dataclass(frozen=True, init=True)
class _MeshLayout(Layout):
    """
    Utility class for representing an integer layout, borrowing ideas from CuTe Layout Algebra.
    See https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html for more details.

    Each layout is represented as a list of sizes and strides. We use it for mechanical bookkeeping
    of integers such as the ranks in an SPMD mesh, and of the transformations applied on top of them.

    Many layout methods, such as coalesce, composition and complement, are borrowed from pycute:
    https://github.com/NVIDIA/cutlass/blob/6dd13d42784ee5bfa232d2441e6b9a021c5c6290/python/pycute/layout.py#L137,L257

    Note that this is a CuTe-inspired layout: CuTe linearizes coordinates co-lexicographically
    (the first coordinate varies fastest), while PyTorch uses lexicographic order (the last
    coordinate varies fastest). So although the CuTe documentation can still be referenced,
    the implementation differs from PyCute's.
    """

    # pyrefly: ignore # bad-override
    shape: IntTuple
    # pyrefly: ignore # bad-override
    stride: IntTuple

    def __post_init__(self) -> None:
        if not is_tuple(self.shape) and not is_int(self.shape):
            raise TypeError(f"shape must be a tuple or int, got {type(self.shape)}")
        if not is_tuple(self.stride) and not is_int(self.stride):
            raise TypeError(f"stride must be a tuple or int, got {type(self.stride)}")
        if (
            is_tuple(self.shape)
            and is_tuple(self.stride)
            and len(flatten(self.shape)) != len(flatten(self.stride))
        ):
            raise ValueError(
                f"sizes {len(flatten(self.shape))} and "
                f"strides {len(flatten(self.stride))} must have the same length"
            )

    @property
    def sizes(self) -> IntTuple:
        return self.shape

    @property
    def strides(self) -> IntTuple:
        return self.stride

    @property
    def sizes_and_strides(self) -> Iterator[tuple[int, int]]:
        return zip(flatten(self.shape), flatten(self.stride))

    @property
    def top_level_sizes(self) -> tuple[int, ...]:
        return tuple(self[i].numel() for i in range(len(self)))

    def numel(self) -> int:
        return math.prod(flatten(self.shape))

    # operator [] (get-i like tuples)
    def __getitem__(self, i: int) -> "_MeshLayout":
        layout = super().__getitem__(i)
        return _MeshLayout(layout.shape, layout.stride)

    def nest(self) -> "_MeshLayout":
        return _MeshLayout((self.shape,), (self.stride,))

    def coalesce(self) -> "_MeshLayout":
        """
        A layout is represented by (sizes):(strides), e.g. (3,2):(4,2).
        Two consecutive dimensions can be "merged" into one if their strides are
        contiguous, i.e., the inner stride * inner size equals the outer stride.
        coalesce() performs all such merges.
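
The merge rule can be sketched in pure Python for flat (non-nested) layouts, walking from the innermost dimension outwards as PyTorch's lexicographic order requires. `coalesce_flat` and `ranks` are hypothetical helpers written for illustration only; this is not pycute's `coalesce`, which also accepts nested IntTuples.

```python
from itertools import product


def ranks(sizes, strides):
    # Lexicographic enumeration, as in all_ranks_from_zero.
    return [sum(c * s for c, s in zip(coord, strides))
            for coord in product(*(range(n) for n in sizes))]


def coalesce_flat(sizes, strides):
    # Hypothetical simplified coalesce for flat layouts in lexicographic order.
    out_sizes, out_strides = [], []
    # Walk from the innermost (rightmost) dimension outwards.
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size == 1:
            continue  # size-1 dims generate no extra ranks
        if out_sizes and out_strides[-1] * out_sizes[-1] == stride:
            out_sizes[-1] *= size  # contiguous with the inner dim: merge
        else:
            out_sizes.append(size)
            out_strides.append(stride)
    if not out_sizes:
        return (1,), (0,)  # every dim had size 1
    return tuple(reversed(out_sizes)), tuple(reversed(out_strides))


for layout in [((3, 2), (2, 1)), ((3, 2), (4, 1))]:
    merged = coalesce_flat(*layout)
    # Coalescing changes the shape but never the ranks that are generated.
    assert ranks(*merged) == ranks(*layout)
    print(layout, '->', merged)
# ((3, 2), (2, 1)) -> ((6,), (1,))
# ((3, 2), (4, 1)) -> ((3, 2), (4, 1))
```

The assertion in the loop captures the invariant that makes coalescing safe: merging only changes how the layout is described, not which ranks it produces or in what order.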

        Example 1 (simple): (3,2):(2,1)
            - inner dimension: stride=1, size=2
            - outer dimension: stride = inner_stride * inner_size = 1 * 2 = 2
            → coalesced = (6:1)  # acts like a flat 1D array of length 6

        Example 2 (non-coalescible): (3,2):(4,1)
            - inner dimension: stride=1, size=2 → 1 * 2 = 2
            - outer dimension: stride=4, mismatch (≠ 2)
            → cannot merge; the result stays (3,2):(4,1)
        """
        layout = coalesce(self)
        return _MeshLayout(layout.shape, layout.stride)

    def composition(self, layout: "_MeshLayout") -> "_MeshLayout":
        """
        By-dimension composition allows one layout to "select from" or "filter through" another layout.
        Think of it as function composition between two layouts:
        (self ∘ layout)(input) = self(layout(input)). This function is a wrapper around pycute's composition.

        Mental model for the composition logic:
            - The LEFT layout (self) defines the "output space", i.e., which indices are possible.
            - The RIGHT layout (the `layout` parameter) acts as a "selector" that picks specific indices.
            - The composition only generates indices that the left layout could originally produce,
              but the right layout determines which of those indices are picked.
            - The stride of the composed layout will not be smaller than the stride of the right layout,
              because when picking indices the composition follows at least the right layout's stride
              to move forward.

        Example:
            self = (6,2):(2,1)  # sizes=(6,2), strides=(2,1)
            layout = (3:2)      # sizes=(3,), stride=(2,)
            self ∘ layout = (3:2)
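
The function-composition view can be checked numerically: evaluate the right layout at each of its indices, feed the result into the left layout, and compare with the single layout (3:2). This is a sketch for flat layouts with lexicographic (rightmost-fastest) index decoding; `layout_eval` and `compose_brute_force` are hypothetical helpers, and the code does not implement pycute's algebraic composition, it only verifies the Example by brute force.

```python
import math


def layout_eval(sizes, strides, idx):
    # Decode idx into lexicographic coordinates (rightmost fastest),
    # then take the dot product with the strides.
    offset = 0
    for size, stride in zip(reversed(sizes), reversed(strides)):
        idx, coord = divmod(idx, size)
        offset += coord * stride
    return offset


def compose_brute_force(a, b):
    # (A ∘ B)(i) = A(B(i)) for every i in B's domain.
    b_sizes, b_strides = b
    return [layout_eval(*a, layout_eval(b_sizes, b_strides, i))
            for i in range(math.prod(b_sizes))]


self_layout = ((6, 2), (2, 1))
selector = ((3,), (2,))
print(compose_brute_force(self_layout, selector))  # [0, 2, 4]
# The single layout (3:2) produces exactly the same indices.
print([layout_eval((3,), (2,), i) for i in range(3)])  # [0, 2, 4]
```

Here the left layout is the identity on [0, 12), so the composition simply inherits the selector's stride of 2.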

        Returns:
            The composed layout (self ∘ layout).
        """
        result = composition(self, layout)
        return _MeshLayout(result.shape, result.stride)

    def complement(self, world_size: int) -> "_MeshLayout":
        """
        Compute the "complement layout" relative to a given world_size.
        The complement supplies the "missing" factor: repeating self along the layout
        complement(self, world_size) covers the complete world_size. We use ⊗ to denote
        this repeat operation.

        Example:
            self = (4:1)  # size=4, stride=1
            world_size = 8
            Then:
                missing factor = 8 / 4 = 2
                complement(self, 8) = (2:4)
            Together they form:
                (4:1) ⊗ (2:4) = (2,4):(4,1)
            with the complement as the outer dimension, which covers
            world_size = 2 * 4 = 8 ranks, as required.
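
A pure-Python sketch of the complement for flat layouts, modeled on the approach of the classic pycute `complement` (sort the dimensions by stride and fill the gaps between them), under the assumption of positive strides and an injective layout. With it, the complement of (4:1) with respect to 16 is (4:4), which reproduces the offsets [0, 4, 8, 12] used in the global_ranks Example, and the complement of (4:1) with respect to 8 is (2:4). `complement_flat` and `ranks` are hypothetical helpers.

```python
from itertools import product


def ranks(sizes, strides):
    # Lexicographic enumeration, as in all_ranks_from_zero.
    return [sum(c * s for c, s in zip(coord, strides))
            for coord in product(*(range(n) for n in sizes))]


def complement_flat(sizes, strides, world_size):
    # Assumes positive strides and an injective layout whose strides divide evenly.
    gaps = []  # (stride, size) pairs making up the complement
    current = 1  # everything below `current` is already covered
    # Visit dims from the smallest stride to the largest.
    for stride, size in sorted(zip(strides, sizes)):
        if size == 1:
            continue
        assert stride >= current and stride % current == 0, 'not a clean tiling'
        gaps.append((current, stride // current))  # fill the hole below this dim
        current = size * stride
    gaps.append((current, -(-world_size // current)))  # ceil(world_size / current)
    # Drop trivial gaps; list the largest stride first (lexicographic order).
    gaps = sorted((g for g in gaps if g[1] > 1), reverse=True)
    return tuple(sz for _, sz in gaps), tuple(st for st, _ in gaps)


comp = complement_flat((4,), (1,), 8)
print(comp)  # ((2,), (4,))
# Complement as the outer dim, (4:1) as the inner dim: ranks 0..7, each exactly once.
assert ranks(comp[0] + (4,), comp[1] + (1,)) == list(range(8))
print(complement_flat((4,), (1,), 16))  # ((4,), (4,))
```

Putting the complement on the outside is the same ordering remap_to_tensor uses when it concatenates the complement's sizes and strides in front of the layout's own.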

        In distributed terms, complement() is often used to derive the "other"
        rank grouping when splitting processes into 2D meshes.
        For a visualized explanation, see https://x.com/ezyang/status/1962364978393981433/
        """
        layout = complement(self, world_size)
        return _MeshLayout(layout.shape, layout.stride)

    def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
        """
        Unflatten a single dimension in the layout by splitting it into multiple dimensions.
        The dimension at position `dim` is replaced by new dimensions with the specified sizes.

        Args:
            dim (int): The index of the dimension to unflatten. Must be a valid dimension index.
            unflatten_sizes (tuple[int, ...]): The sizes of the new dimensions that replace
                the original dimension at `dim`. Their product must equal the size of the
                original dimension at `dim`.

        Returns:
            _MeshLayout: A new layout with the specified dimension unflattened.
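
For the common case where the dimension being split is a single (size, stride) pair, the composition performed by this method reduces to multiplying the original stride by the row-major strides of `unflatten_sizes`. The sketch assumes flat tuples; `unflatten_flat` and `suffix_product_flat` are hypothetical helpers, and the latter is only assumed to match `suffix_product` on flat inputs, consistent with the Example below.

```python
import math


def suffix_product_flat(sizes):
    # Hypothetical stand-in for suffix_product on a flat tuple:
    # row-major strides, e.g. (2, 3, 4) -> (12, 4, 1).
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))


def unflatten_flat(sizes, strides, dim, new_sizes):
    # Only handles a target dim that is a single (size, stride) pair.
    if not 0 <= dim < len(sizes):
        raise ValueError(f'dim {dim} is out of range for {len(sizes)} dims')
    if math.prod(new_sizes) != sizes[dim]:
        raise ValueError(f'prod({new_sizes}) != {sizes[dim]}')
    # Every new sub-dimension steps in multiples of the original stride.
    new_strides = tuple(strides[dim] * s for s in suffix_product_flat(new_sizes))
    return (
        tuple(sizes[:dim]) + tuple(new_sizes) + tuple(sizes[dim + 1:]),
        tuple(strides[:dim]) + new_strides + tuple(strides[dim + 1:]),
    )


print(unflatten_flat((8,), (1,), 0, (2, 2, 2)))   # ((2, 2, 2), (4, 2, 1))
print(unflatten_flat((4, 8), (8, 1), 0, (2, 2)))  # ((2, 2, 8), (16, 8, 1))
```

The second call shows why the original stride matters: splitting the outer dimension of a 4x8 mesh keeps both halves stepping over whole rows of 8 ranks.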

        Example:
            Original: sizes=(8,), strides=(1,)           # 8 ranks in 1D
            Call: unflatten(0, (2, 2, 2))                # create a 3D topology
            Result: sizes=(2, 2, 2), strides=(4, 2, 1)   # 2*2*2 unflattened topology
        """
        # Check that dim is within valid range
        if dim < 0 or dim >= len(self):
            raise ValueError(
                f"dim {dim} is out of range for layout with {len(self)} dimensions. "
                f"Expected dim to be in range [0, {len(self) - 1}]."
            )

        # Check that the product of unflatten_sizes equals the original dimension size
        original_size = self[dim].numel()
        unflatten_product = math.prod(unflatten_sizes)
        if unflatten_product != original_size:
            raise ValueError(
                f"The product of unflatten_sizes {unflatten_sizes} is {unflatten_product}, "
                f"but the original dimension at dim={dim} has size {original_size}. "
                f"These must be equal for unflatten to work correctly."
            )

        sizes = list(self.sizes)  # type: ignore[arg-type]
        strides = list(self.strides)  # type: ignore[arg-type]
        unflatten_layout = self[dim].composition(
            _MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
        )
        sizes[dim : dim + 1] = list(unflatten_layout.sizes)  # type: ignore[arg-type]
        strides[dim : dim + 1] = list(unflatten_layout.strides)  # type: ignore[arg-type]
        return _MeshLayout(tuple(sizes), tuple(strides))

    def all_ranks_from_zero(self) -> list[int]:
        """
        Compute all ranks specified by the layout, starting from zero.

        How it works:
            1. Enumerate every possible coordinate (like a nested for-loop).
               If sizes = (2, 3), we get the following coordinates:
                   (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
            2. For each coordinate, compute a linear rank index as:
                   rank = sum(coord[i] * strides[i] for i in range(ndim))
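
The two steps can be sketched in pure Python. `ranks_lex` mirrors the method body for flat layouts, while `ranks_colex` (a hypothetical helper) enumerates coordinates in CuTe's co-lexicographic order, where the first coordinate varies fastest, to show why this layout's results differ from PyCute's even for the same sizes and strides.

```python
from itertools import product


def ranks_lex(sizes, strides):
    # Step 1: product() yields coordinates lexicographically (last dim fastest).
    # Step 2: each coordinate becomes its dot product with the strides.
    return [sum(c * s for c, s in zip(coord, strides))
            for coord in product(*(range(n) for n in sizes))]


def ranks_colex(sizes, strides):
    # CuTe-style co-lexicographic order: the FIRST dim varies fastest.
    coords = (tuple(reversed(rc))
              for rc in product(*(range(n) for n in reversed(sizes))))
    return [sum(c * s for c, s in zip(coord, strides)) for coord in coords]


print(ranks_lex((2, 3), (3, 1)))    # [0, 1, 2, 3, 4, 5]  (Example A)
print(ranks_lex((2, 3), (1, 2)))    # [0, 2, 4, 1, 3, 5]  (Example B)
print(ranks_colex((2, 3), (3, 1)))  # [0, 3, 1, 4, 2, 5]
```

Both orders visit the same set of ranks; only the order differs, and with it which rank lands at which position of a mesh.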

        Example A:
            sizes = (2, 3)    # 2 rows, 3 cols
            strides = (3, 1)  # row-major layout
            coords = (0,0) -> 0*3 + 0*1 = 0
                     (0,1) -> 0*3 + 1*1 = 1
                     (0,2) -> 0*3 + 2*1 = 2
                     (1,0) -> 1*3 + 0*1 = 3
                     (1,1) -> 1*3 + 1*1 = 4
                     (1,2) -> 1*3 + 2*1 = 5
            result = [0, 1, 2, 3, 4, 5]

        Example B:
            sizes = (2, 3)
            strides = (1, 2)  # non-standard / strided layout
            coords = (0,0) -> 0*1 + 0*2 = 0
                     (0,1) -> 0*1 + 1*2 = 2
                     (0,2) -> 0*1 + 2*2 = 4
                     (1,0) -> 1*1 + 0*2 = 1
                     (1,1) -> 1*1 + 1*2 = 3
                     (1,2) -> 1*1 + 2*2 = 5
            result = [0, 2, 4, 1, 3, 5]
        """
        return [
            sum(c * s for c, s in zip(coord, flatten(self.strides)))
            for coord in product(*(range(s) for s in flatten(self.sizes)))
        ]

    def global_ranks(self, world_size: int) -> list[list[int]]:
        """
        Build the global ranks specified by the layout via a two-level rank composition.

        The offsets generated by complement(self, world_size) tile copies of the layout
        across the whole world_size. Each global rank is offset + rank, taken over the
        Cartesian product of these offsets and the ranks generated by the layout itself.
        The result is a list of lists with one sublist per offset, i.e., one rank group
        per copy of the layout. These rank groups are used to build the communicators
        underlying the layout for the given `world_size`.
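
A pure-Python sketch of the two-level composition. To stay self-contained it takes the complement as an explicit argument rather than computing it; the complements used here are assumptions stated in the comments: (4:4) for the layout (4:1) in a world of 16, which matches the offsets [0, 4, 8, 12] in the Example below, and (4:1) for the layout (4:4). `global_ranks_flat` and `ranks_from_zero` are hypothetical helpers.

```python
from itertools import product


def ranks_from_zero(sizes, strides):
    # Lexicographic enumeration, as in all_ranks_from_zero.
    return [sum(c * s for c, s in zip(coord, strides))
            for coord in product(*(range(n) for n in sizes))]


def global_ranks_flat(layout, complement_layout):
    # Both arguments are flat (sizes, strides) pairs; the complement w.r.t.
    # world_size is passed in instead of being computed here.
    return [[offset + rank for rank in ranks_from_zero(*layout)]
            for offset in ranks_from_zero(*complement_layout)]


# Row groups: layout (4:1) in a world of 16, whose complement is (4:4).
rows = global_ranks_flat(((4,), (1,)), ((4,), (4,)))
# Column groups: layout (4:4) in a world of 16, whose complement is (4:1).
cols = global_ranks_flat(((4,), (4,)), ((4,), (1,)))
print(rows)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(cols)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
```

For a 4x4 mesh these are the two families of process groups: the rows from the Example, and the columns, where the roles of layout and complement are swapped.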

        Example:
            world_size = 16
            self.size = 4
            self.stride = 1
            ranks = [0, 1, 2, 3]
            offsets = [0, 4, 8, 12]
            result = [
                [0+0, 0+1, 0+2, 0+3],      # → [0, 1, 2, 3]
                [4+0, 4+1, 4+2, 4+3],      # → [4, 5, 6, 7]
                [8+0, 8+1, 8+2, 8+3],      # → [8, 9, 10, 11]
                [12+0, 12+1, 12+2, 12+3],  # → [12, 13, 14, 15]
            ]
        """
        return [
            [offset + rank for rank in self.all_ranks_from_zero()]
            for offset in self.complement(world_size).all_ranks_from_zero()
        ]

    def check_non_overlap(self) -> bool:
        """
        Check whether the ranks generated by the layout overlap. Return False if any rank
        is generated more than once, and True otherwise.

        The layout is expected to be injective, i.e., apart from index 0, the indices
        generated by each dimension of the layout must not overlap.

        Example 1 - Valid (no overlap):
            Layout: sizes=(2,3), strides=(6,1)
            - Dim 1: stride=1, span=3*1=3, covers indices [0,1,2]
            - Dim 0: stride=6, span=2*6=12, covers indices [0,6]
            → No overlap since 6 > 3

        Example 2 - Invalid (overlap):
            Layout: sizes=(2,3), strides=(2,1)
            - Dim 1: stride=1, span=3*1=3, covers indices [0,1,2]
            - Dim 0: stride=2, span=2*2=4, covers indices [0,2]
            → Overlap! stride=2 < span=3, so index 2 is generated twice,
              by coordinates (0,2) and (1,0)

        Example 3 - Invalid (overlap):
            Layout: sizes=(4,2), strides=(1,1)
            - Dim 0: stride=1, span=4*1=4, covers indices [0,1,2,3]
            - Dim 1: stride=1, span=2*1=2, covers indices [0,1]
            → Overlap! Both dims have the same stride, so indices 1, 2 and 3
              are each generated twice
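
The method detects overlap by brute force. The sketch below contrasts that with a cheaper stride-versus-span test in the spirit of the injectivity rule above. The span test is only a sufficient condition: it rejects the layout (3,2):(2,3), whose ranks [0, 3, 2, 5, 4, 7] are all distinct. Both helper names are hypothetical.

```python
from itertools import product


def ranks_from_zero(sizes, strides):
    # Lexicographic enumeration, as in all_ranks_from_zero.
    return [sum(c * s for c, s in zip(coord, strides))
            for coord in product(*(range(n) for n in sizes))]


def non_overlap_brute_force(sizes, strides):
    # Same idea as check_non_overlap: no rank may be generated twice.
    ranks = ranks_from_zero(sizes, strides)
    return len(ranks) == len(set(ranks))


def non_overlap_by_span(sizes, strides):
    # Hypothetical fast check: sorted by stride, each stride must reach past
    # the span (size * stride) of the previous dim. Sufficient, not necessary.
    reach = 1
    for stride, size in sorted(zip(strides, sizes)):
        if size == 1:
            continue
        if stride < reach:
            return False
        reach = size * stride
    return True


print(non_overlap_brute_force((2, 3), (6, 1)))  # True   (Example 1)
print(non_overlap_brute_force((2, 3), (2, 1)))  # False  (Example 2)
print(non_overlap_brute_force((4, 2), (1, 1)))  # False  (Example 3)
print(non_overlap_brute_force((3, 2), (2, 3)), non_overlap_by_span((3, 2), (2, 3)))  # True False
```

The brute-force version costs one pass over all numel() ranks, which is cheap for mesh-sized layouts and never gives a false negative.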

        Returns:
            bool: True if no overlap, False if overlap is detected.
        """
        ranks = self.all_ranks_from_zero()
        return len(ranks) == len(set(ranks))

    def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
        """
        Use the layout as an index into the mesh tensor, re-mapping the indices produced by
        layout transformations to actual device ranks.

        With this method, the CuTe layout is the backend for index bookkeeping of the mesh
        tensor in flatten, unflatten and slicing operations, while the actual mesh tensor
        still holds the real device assignment and ranks. We need this function to specify
        device allocation and to create the backend for a mesh: although any transformation
        of the mesh tensor can be treated as a view or subset of it, DeviceMesh and its
        backend creation need the actual view or sub-tensor.

        `rank_map` must be 1D, contiguous, and contain at least cosize() elements.

        Examples:
            Case 1 - Consecutive ranks, full world:
                original_mesh_tensor = [[0,1],[2,3]]  # 2x2 mesh, ranks 0-3
                rank_map = [0,1,2,3]                  # flattened mesh, world_size = 4
                layout = Layout(2:2)
                Return: [[0,2],[1,3]]

            Case 2 - Non-consecutive ranks:
                original_mesh_tensor = [[10,20],[30,40]]  # custom rank assignment
                rank_map = [10,20,30,40]                  # flattened mesh, world_size = 4
                layout = Layout(2:2)
                Return: [[10,30],[20,40]]
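
Case 1 and Case 2 can be reproduced without torch by emulating `Tensor.as_strided` on a Python list: concatenate the complement's sizes and strides in front of the layout's, gather `rank_map` at each resulting offset, and group the result into rows. `remap_flat` is a hypothetical helper that takes the complement explicitly (here (2:1), assumed to be the complement of (2:2) with respect to 4) and only handles a single top-level layout dimension.

```python
from itertools import product


def remap_flat(rank_map, layout, complement_layout):
    # Gather rank_map[offset] for every coordinate of the concatenated
    # (complement dims + layout dims), like as_strided on a 1D buffer.
    sizes = complement_layout[0] + layout[0]
    strides = complement_layout[1] + layout[1]
    flat = [rank_map[sum(c * s for c, s in zip(coord, strides))]
            for coord in product(*(range(n) for n in sizes))]
    # Group into rows of numel(layout): one row per copy of the layout.
    width = 1
    for n in layout[0]:
        width *= n
    return [flat[i:i + width] for i in range(0, len(flat), width)]


# Layout (2:2) over 4 ranks; its complement w.r.t. 4 is (2:1).
print(remap_flat([0, 1, 2, 3], ((2,), (2,)), ((2,), (1,))))      # [[0, 2], [1, 3]]
print(remap_flat([10, 20, 30, 40], ((2,), (2,)), ((2,), (1,))))  # [[10, 30], [20, 40]]
```

The layout only ever produces positions into `rank_map`; swapping in a different rank assignment, as in Case 2, changes the values but not the grouping.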

        Args:
            rank_map: The concrete mesh tensor with actual device ranks, flattened to 1D.

        Returns:
            torch.Tensor: A tensor representing the actual device allocation from rank_map,
            with shape (-1, *top_level_sizes), i.e., one entry along dim 0 per offset of the
            complement layout.
        """
        assert rank_map.ndim == 1
        assert rank_map.is_contiguous()
        assert rank_map.numel() >= self.cosize()

        complement_layout = self.complement(rank_map.numel())

        return rank_map.as_strided(
            flatten(complement_layout.sizes) + flatten(self.sizes),
            flatten(complement_layout.strides) + flatten(self.strides),
        ).reshape(-1, *self.top_level_sizes)