[DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)

The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
This commit is contained in:
Luca Wehrstedt
2025-10-16 13:54:18 +00:00
committed by PyTorch MergeBot
parent 99b32a6750
commit d61a9b88cf
3 changed files with 81 additions and 76 deletions

View File

@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
def test_remap_to_tensor(self): def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios.""" """Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly # Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int) original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2 layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh) result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int) expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1) self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks # Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int) original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1)) layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh) result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int) expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
self.assertEqual(result5, expected5) self.assertEqual(result5, expected5)
# Test 6: Tensor Cute representation of a 2D mesh # Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int) original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh) result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int) expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)

View File

@ -301,10 +301,7 @@ class _MeshLayout(Layout):
ranks = self.all_ranks_from_zero() ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks)) return len(ranks) == len(set(ranks))
def remap_to_tensor( def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
""" """
Leverage layout as an index for mesh tensor that re-maps the indexes after layout Leverage layout as an index for mesh tensor that re-maps the indexes after layout
transformation to actual device ranks. transformation to actual device ranks.
@ -316,10 +313,7 @@ class _MeshLayout(Layout):
can be treated as a view or subset of mesh tensor, we do need to use the actual view or can be treated as a view or subset of mesh tensor, we do need to use the actual view or
sub-tensor for DeviceMesh and its backend creation. sub-tensor for DeviceMesh and its backend creation.
The shape of the `mesh_tensor` can be any size because users can define a device mesh with any The shape of the `rank_map` must be 1D and contiguous.
shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
and reconstruct the mesh tensor with the shape of the layout when accessed by users.
#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
Examples: Examples:
@ -336,18 +330,18 @@ class _MeshLayout(Layout):
Return: [[[10,30],[20,40]]] Return: [[[10,30],[20,40]]]
Args: Args:
mesh_tensor: The concrete mesh tensor with actual device ranks rank_map: The concrete mesh tensor with actual device ranks
Returns: Returns:
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor torch.Tensor: A tensor representing the actual device allocation from rank_map
""" """
complement_layout = self.complement(mesh_tensor.numel()) assert rank_map.ndim == 1
assert rank_map.is_contiguous()
assert rank_map.numel() >= self.cosize()
return ( complement_layout = self.complement(rank_map.numel())
mesh_tensor.flatten()
.as_strided( return rank_map.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes), flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides), flatten(complement_layout.strides) + flatten(self.strides),
) ).reshape(-1, *self.top_level_sizes)
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)

View File

@ -173,7 +173,7 @@ else:
""" """
_device_type: str _device_type: str
_mesh: torch.Tensor _rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]] _mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout _layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None _root_mesh: Optional["DeviceMesh"] = None
@ -190,46 +190,49 @@ else:
_init_backend: bool = True, _init_backend: bool = True,
_rank: Optional[int] = None, _rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None, _layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None: ) -> None:
self._device_type = device_type self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
self._mesh = ( mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous() mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor) if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int) else torch.tensor(mesh, device="cpu", dtype=torch.int)
) )
self._rank_map = (
_root_mesh._rank_map
if _root_mesh is not None
else mesh_tensor.flatten()
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
elif len(backend_override) != self.mesh.ndim:
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {self.mesh.ndim}."
)
# Internal bookkeeping for the device mesh. # Internal bookkeeping for the device mesh.
self._layout = ( self._layout = (
_layout _layout
if _layout if _layout
else _MeshLayout(self.mesh.size(), self.mesh.stride()) else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
) )
self._root_mesh = _root_mesh
assert self._layout.check_non_overlap(), ( assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh." "Please use a non-overlapping layout when creating a DeviceMesh."
) )
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here. # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == self.mesh.size(), ( assert self._layout.top_level_sizes == mesh_tensor.size(), (
"Please use a valid layout when creating a DeviceMesh." "Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}." f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
) )
# private field to pre-generate DeviceMesh's hash if backend_override is None:
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) backend_override = ((None, None),) * len(self._layout)
self._thread_id = None elif len(backend_override) != len(self._layout):
# Initialize instance-specific flatten mapping raise ValueError(
self._flatten_mapping = {} f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {len(self._layout)}."
)
# Skip process group initialization if xla device or init backend is False # Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend. # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla": if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized # always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each # already. The world pg is used for device mesh identity (rank) on each
@ -252,6 +255,11 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
) )
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property @property
def device_type(self) -> str: def device_type(self) -> str:
"""Returns the device type of the mesh.""" """Returns the device type of the mesh."""
@ -260,7 +268,17 @@ else:
@property @property
def mesh(self) -> torch.Tensor: def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices.""" """Returns the tensor representing the layout of devices."""
return self._mesh full_mesh = self._layout.remap_to_tensor(self._rank_map)
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
if my_coords.size(0) > 0:
return full_mesh[my_coords[0, 0]]
raise RuntimeError(
"In order to get the mesh Tensor of a DeviceMesh it needs to "
"either have all its original dimensions (e.g., no slicing) "
"or it needs to contain the local rank"
)
@property @property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]: def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
@ -275,9 +293,9 @@ else:
init_process_group() init_process_group()
world_size = get_world_size() world_size = get_world_size()
if self.mesh.numel() > world_size: if self._layout.numel() > world_size:
raise RuntimeError( raise RuntimeError(
f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
) )
# ONLY set the device if the current device is not initialized, if user already # ONLY set the device if the current device is not initialized, if user already
@ -328,8 +346,8 @@ else:
default_group = _get_default_group() default_group = _get_default_group()
if ( if (
self.mesh.ndim == 1 len(self._layout) == 1
and self.mesh.numel() == get_world_size() and self._layout.numel() == get_world_size()
and backend_override[0] == (None, None) and backend_override[0] == (None, None)
): ):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`. # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@ -348,11 +366,11 @@ else:
dim_group_names.append(dim_group.group_name) dim_group_names.append(dim_group.group_name)
else: else:
# create sub pgs base on the mesh argument specified # create sub pgs base on the mesh argument specified
for dim in range(self.mesh.ndim): for dim in range(len(self._layout)):
# swap the current dim to the last dim # swap the current dim to the last dim
# then reshape to flatten out other dims # then reshape to flatten out other dims
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( pg_ranks_by_dim = (
-1, self.mesh.size(dim) self._layout[dim].nest().remap_to_tensor(self._rank_map)
) )
backend, pg_options = backend_override[dim] backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise # We need to explicitly pass in timeout when specified in option, otherwise
@ -448,14 +466,14 @@ else:
def __repr__(self) -> str: def __repr__(self) -> str:
device_mesh_repr = ( device_mesh_repr = (
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})" f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
if self._mesh_dim_names if self._mesh_dim_names
else f"{tuple(self._mesh.shape)}" else f"{self._layout.top_level_sizes}"
) )
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}" device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
# We only print the mesh tensor if the debug mode is turned on. # We only print the mesh tensor if the debug mode is turned on.
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
device_mesh_repr += f", Mesh: {self._mesh.tolist()}" device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
return f"{device_mesh_repr})" return f"{device_mesh_repr})"
def __hash__(self): def __hash__(self):
@ -465,7 +483,7 @@ else:
self._hash = hash( self._hash = hash(
( (
self._flatten_mesh_list, self._flatten_mesh_list,
self._mesh.shape, self._layout,
self._device_type, self._device_type,
self._mesh_dim_names, self._mesh_dim_names,
self._thread_id, self._thread_id,
@ -481,7 +499,7 @@ else:
return False return False
return ( return (
self._flatten_mesh_list == other._flatten_mesh_list self._flatten_mesh_list == other._flatten_mesh_list
and self._mesh.shape == other._mesh.shape and self._layout == other._layout
and self._device_type == other._device_type and self._device_type == other._device_type
and self._mesh_dim_names == other._mesh_dim_names and self._mesh_dim_names == other._mesh_dim_names
and self._thread_id == other._thread_id and self._thread_id == other._thread_id
@ -573,16 +591,16 @@ else:
if not hasattr(self, "_dim_group_names"): if not hasattr(self, "_dim_group_names"):
raise RuntimeError("DeviceMesh process groups not initialized!") raise RuntimeError("DeviceMesh process groups not initialized!")
if self.mesh.ndim > 1 and mesh_dim is None: if len(self._layout) > 1 and mesh_dim is None:
raise RuntimeError( raise RuntimeError(
f"Found the DeviceMesh have {self.mesh.ndim} dimensions", f"Found the DeviceMesh have {len(self._layout)} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
"If you want to get the list of all the ProcessGroups in the DeviceMesh," "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
"please use `get_all_groups()` instead.", "please use `get_all_groups()` instead.",
) )
# Quick return if the current device_mesh is a 1D mesh. # Quick return if the current device_mesh is a 1D mesh.
if self.mesh.ndim == 1 and mesh_dim is None: if len(self._layout) == 1 and mesh_dim is None:
return not_none(_resolve_process_group(self._dim_group_names[0])) return not_none(_resolve_process_group(self._dim_group_names[0]))
root_mesh = self._get_root_mesh() root_mesh = self._get_root_mesh()
@ -608,7 +626,7 @@ else:
Returns: Returns:
A list of :class:`ProcessGroup` object. A list of :class:`ProcessGroup` object.
""" """
return [self.get_group(i) for i in range(self.mesh.ndim)] return [self.get_group(i) for i in range(len(self._layout))]
def _create_sub_mesh( def _create_sub_mesh(
self, self,
@ -635,9 +653,7 @@ else:
] ]
) )
cur_rank = self.get_rank() cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor( pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
root_mesh.mesh,
)
res_submesh = DeviceMesh._create_mesh_from_ranks( res_submesh = DeviceMesh._create_mesh_from_ranks(
self._device_type, self._device_type,
pg_ranks_by_dim, pg_ranks_by_dim,
@ -692,9 +708,7 @@ else:
cur_rank = root_mesh.get_rank() cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the # Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang. # new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor( pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
root_mesh.mesh,
)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks( res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh._device_type, root_mesh._device_type,
pg_ranks_by_dim.flatten( pg_ranks_by_dim.flatten(
@ -833,9 +847,7 @@ else:
""" """
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name) mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim] layout = self._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor( pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
self.mesh,
)
cur_rank = self.get_rank() cur_rank = self.get_rank()
res_submeshes = [] res_submeshes = []
for mesh_1d in pg_ranks_by_dim: for mesh_1d in pg_ranks_by_dim:
@ -896,6 +908,7 @@ else:
backend_override=backend_override, backend_override=backend_override,
_init_backend=_init_backend, _init_backend=_init_backend,
_layout=_layout, _layout=_layout,
_root_mesh=_root_mesh,
) )
if cur_rank in mesh_nd: if cur_rank in mesh_nd:
res_mesh = mesh res_mesh = mesh
@ -904,8 +917,6 @@ else:
f"Current rank {cur_rank} not found in any mesh, " f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world" f"input {pg_ranks_by_dim} does not contain all ranks in the world"
) )
if _root_mesh is not None:
res_mesh._root_mesh = _root_mesh
return res_mesh return res_mesh
@staticmethod @staticmethod
@ -1004,15 +1015,17 @@ else:
return device_mesh return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int: def size(self, mesh_dim: Optional[int] = None) -> int:
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) if mesh_dim is not None:
return self._layout[mesh_dim].numel()
return self._layout.numel()
@property @property
def ndim(self) -> int: def ndim(self) -> int:
return self.mesh.ndim return len(self._layout)
@property @property
def shape(self) -> tuple[int, ...]: def shape(self) -> tuple[int, ...]:
return tuple(self.mesh.shape) return self._layout.top_level_sizes
def get_rank(self) -> int: def get_rank(self) -> int:
""" """
@ -1051,7 +1064,7 @@ else:
""" """
if self.ndim > 1 and mesh_dim is None: if self.ndim > 1 and mesh_dim is None:
raise RuntimeError( raise RuntimeError(
f"Found the DeviceMesh have {self.mesh.ndim} dimensions", f"Found the DeviceMesh have {len(self._layout)} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
) )
elif mesh_dim is None: elif mesh_dim is None:
@ -1115,9 +1128,7 @@ else:
root_mesh = self._get_root_mesh() root_mesh = self._get_root_mesh()
cur_rank = self.get_rank() cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes) unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor( pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
root_mesh.mesh,
)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names)) unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names) unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks( res_mesh = DeviceMesh._create_mesh_from_ranks(
@ -1141,7 +1152,7 @@ else:
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index] tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
) )
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor( unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh.mesh, root_mesh._rank_map
) )
unflatten_submesh = DeviceMesh._create_mesh_from_ranks( unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
self.device_type, self.device_type,