Revert "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)"

This reverts commit b0985144b59db8fb20964829b5e0a9d2f9a3f0d6.

Reverted https://github.com/pytorch/pytorch/pull/163213 on behalf of https://github.com/yangw-dev due to caused internal test failure ([comment](https://github.com/pytorch/pytorch/pull/163213#issuecomment-3363414435))
PyTorch MergeBot
2025-10-02 22:22:24 +00:00
parent bdc0a421d7
commit 22e219d996
3 changed files with 141 additions and 246 deletions

View File

@@ -440,7 +440,6 @@ class DeviceMeshTestNDim(DTensorTestBase):
ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
# ep_mesh is considered different from mesh_2d["TP"]
self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout)
self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
@@ -455,7 +454,6 @@ class DeviceMeshTestNDim(DTensorTestBase):
)
# another_mesh is considered the same as ep_mesh
self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
self.assertEqual(ep_mesh._layout, another_mesh._layout)
self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
@@ -541,6 +539,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
mesh_dim_names=("dp_replicate", "dp_shard"),
)
# self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
for mesh_dim_group, ref_mesh_dim_group in zip(
dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
):
@@ -801,10 +800,6 @@ class TestDeviceMeshGetItem(DTensorTestBase):
# Test slicing out 1D mesh from a sub-2D mesh.
shard_mesh = hsdp_mesh_2["Shard"]
self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx])
replicate_mesh = hsdp_mesh_2["Replicate"]
self.assertEqual(
replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx]
)
@with_comms
def test_cache_and_reuse_submesh_slice_result(self):
@@ -878,17 +873,12 @@ class TestDeviceMeshGetItem(DTensorTestBase):
flattened_dp_cp_mesh = dp_cp_mesh._flatten()
self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh)
self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
self.assertEqual(flattened_dp_cp_mesh.get_group().group_desc, "mesh_dp_cp")
root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
self.assertEqual(root_mesh, mesh_3d)
flatten_mesh_layout = _mesh_resources.root_to_flatten_mapping[root_mesh][
flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
"dp_cp"
]._layout
self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout)
self.assertEqual(
flattened_dp_cp_mesh._layout.global_ranks(8),
[[0, 2, 4, 6], [1, 3, 5, 7]],
)
]
self.assertEqual(flatten_mesh_root_dims, (0, 1))
ref_pg_count = _world.group_count
# Calling flatten again should not create a new pg.
@@ -903,19 +893,10 @@ class TestDeviceMeshGetItem(DTensorTestBase):
self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
self.assertEqual(root_mesh, mesh_3d)
flatten_mesh_root_layout = _mesh_resources.root_to_flatten_mapping[root_mesh][
flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
"dp_tp"
]._layout
self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout)
self.assertEqual(
flattened_dp_tp_mesh._layout.global_ranks(8),
[[0, 1, 4, 5], [2, 3, 6, 7]],
)
with self.assertRaisesRegex(
NotImplementedError,
"Currently, this only allows slicing out a contiguous flattened dim",
):
mesh_3d["dp_tp", "cp"]
]
self.assertEqual(flatten_mesh_root_dims, (0, 2))
# Test flatten with a flattened mesh_dim_name
cp_tp_mesh = mesh_3d["cp", "tp"]
@@ -1556,50 +1537,6 @@ class CuTeLayoutTest(TestCase):
layout8 = _Layout((3, 2), (2, 3))
self.assertTrue(layout8.check_non_overlap())
def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
self.assertEqual(result2, expected2)
# Test 4: 1D layout with consecutive ranks
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
layout4 = _Layout((4,), (1,))
result4 = layout4.remap_to_tensor(original_mesh)
expected4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.int)
self.assertEqual(result4, expected4)
# Test 5: Complex strided layout with non-consecutive ranks
original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int)
layout5 = _Layout((2, 2), (2, 1))
result5 = layout5.remap_to_tensor(original_mesh)
expected5 = torch.tensor([[[5, 10], [15, 20]]], dtype=torch.int)
self.assertEqual(result5, expected5)
# Test 6: CuTe-style tensor representation of a 2D mesh
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result6, expected6)
# Test 7: Layout with different stride pattern
original_mesh = torch.tensor([0, 2, 1, 4], dtype=torch.int)
layout7 = _Layout((2, 2), (1, 2)) # column-major style
result7 = layout7.remap_to_tensor(original_mesh)
expected7 = torch.tensor([[[0, 1], [2, 4]]], dtype=torch.int)
self.assertEqual(result7, expected7)
if __name__ == "__main__":
run_tests()
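# Illustrative sketch (not part of the diff): a plain-Python version of the
# check_non_overlap tests above. It assumes the standard CuTe layout function
# rank(coord) = sum(coord[i] * stride[i]) and the "all ranks are unique" rule
# used by _Layout; it does not call the internal API itself.
from itertools import product

def non_overlapping(sizes, strides):
    # Enumerate every coordinate, map it through the strides, and check uniqueness.
    ranks = [sum(i * s for i, s in zip(coord, strides))
             for coord in product(*(range(n) for n in sizes))]
    return len(ranks) == len(set(ranks))

assert non_overlapping((3, 2), (2, 3))      # layout8 above: ranks {0, 2, 3, 4, 5, 7}
assert not non_overlapping((2, 2), (1, 1))  # stride (1, 1) maps two coords to rank 1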

View File

@@ -7,7 +7,6 @@ from collections.abc import Iterator
from dataclasses import dataclass
from itertools import product
import torch
from torch.distributed._pycute import (
coalesce,
complement,
@@ -244,54 +243,3 @@ class _MeshLayout(Layout):
"""
ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks))
def remap_to_tensor(
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
"""
Use the layout as an index into the mesh tensor, re-mapping the indices produced by a layout
transformation to actual device ranks.
With this method, the CuTe layout serves as the backend for index bookkeeping on the mesh
tensor during flatten, unflatten, and slicing operations, while the mesh tensor itself still
represents the actual device assignment and ranks. We need this function to specify the device
allocation and to create the backend for a mesh: although any transform of the mesh tensor can
be treated as a view or subset of it, DeviceMesh and its backend creation require the actual
view or sub-tensor.
The `mesh_tensor` can have any shape, since users can define a device mesh with arbitrary
shapes. We could further refactor the code so that only a 1D mesh tensor is stored internally
and the mesh tensor is reconstructed with the layout's shape when accessed by users.
# TODO: Only support a 1D mesh tensor stored internally and reconstruct the mesh tensor via the layout.
Examples:
Case 1 - Consecutive ranks, full world:
original_mesh_tensor = [[0,1],[2,3]] # 2x2 mesh, ranks 0-3
world_size = 4
layout = Layout(2:2)
Return: [[0,2],[1,3]]
Case 2 - Non-consecutive ranks:
original_mesh_tensor = [[10,20],[30,40]] # custom rank assignment
world_size = 4
layout = Layout(2:2)
Return: [[[10,30],[20,40]]]
Args:
mesh_tensor: The concrete mesh tensor with actual device ranks
Returns:
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
"""
complement_layout = self.complement(mesh_tensor.numel())
return (
mesh_tensor.flatten()
.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
)
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)

View File

@@ -6,13 +6,12 @@ import os
import threading
import warnings
from collections.abc import Iterator
from functools import reduce
from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int
from torch.utils._typing_utils import not_none
@@ -67,16 +66,17 @@ else:
)
BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]]
torch.serialization.add_safe_globals([_MeshLayout])
class _MeshEnv(threading.local):
def __init__(self) -> None:
self.mesh_stack: list[DeviceMesh] = []
# TODO: Move the bookkeeping maps from _MeshEnv to DeviceMesh.
self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {}
self.mesh_dim_group_options: dict[int, BackendConfig] = {}
# Record the mapping from a flattened mesh name to its flattened mesh in the root mesh.
self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {}
# Record the mapping from a flattened mesh name to its mesh dim indices in the root mesh.
self.flatten_name_to_root_dims: dict[
DeviceMesh, dict[str, tuple[int, ...]]
] = {}
def get_current_mesh(self) -> "DeviceMesh":
if len(self.mesh_stack) == 0:
@@ -86,51 +86,108 @@ else:
def create_sub_mesh(
self,
device_mesh: "DeviceMesh",
layout: _MeshLayout,
submesh_dim_names: tuple[str, ...],
submesh_dims: list[tuple[int, ...]],
) -> "DeviceMesh":
root_mesh = self.get_root_mesh(device_mesh)
slice_dim_group_name = []
for name in submesh_dim_names:
if name in not_none(device_mesh.mesh_dim_names):
slice_dim_group_name.append(
device_mesh._dim_group_names[ # type: ignore[has-type]
not_none(device_mesh.mesh_dim_names).index(name)
# Get the size of each sliced submesh dim from submesh_dims.
# For example, if we have a 3D mesh with mesh_shape (2, 2, 2) and mesh_dim_names ("dp", "cp", "tp") and we want
# to slice out mesh["dp_cp", "tp"], then submesh_dims = [(0, 1), (2,)] and slice_dim_size = [2 * 2, 2] = [4, 2].
# If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and slice_dim_size = [2, 2].
slice_dim_size = [
reduce(
lambda x, y: x * device_mesh.mesh.size(y),
mesh_dim,
1,
)
for mesh_dim in submesh_dims
]
mesh_tensor = device_mesh.mesh
# slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims.
slice_dim_idx = []
slice_dim_group_name = []
# keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
# flattened mesh tensor.
num_dims_flatten = 0
for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names):
# Currently, this only allows slicing out a contiguous flattened dim.
# TODO: we need to handle reconstructing a non-contiguous flattened dim.
if len(mesh_dim_indices) > 1:
# We need to move the start_dim and end_dim to the left if some dims are already flattened.
mesh_tensor = mesh_tensor.flatten(
start_dim=mesh_dim_indices[0] - num_dims_flatten,
end_dim=mesh_dim_indices[-1] - num_dims_flatten,
)
# If some dims are already flattened, we need to adjust the slice_dim_idx accordingly.
# For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened,
# then the final slice_dim_idx should be [0, 1, 2].
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
num_dims_flatten += len(mesh_dim_indices) - 1
slice_dim_group_name.append(
self.root_to_flatten_mapping[device_mesh][
mesh_dim_name
]._dim_group_names[0] # type: ignore[has-type]
)
else:
# If device_mesh is not root_mesh, we already throw an error in _get_slice_mesh_layout.
# Since we will deprecate slicing a flattened dim_name from the root mesh soon,
# we don't want to optimize this code further.
flatten_mesh = self.root_to_flatten_mapping[device_mesh][name]
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
slice_dim_group_name.append(
flatten_mesh._dim_group_names[ # type: ignore[has-type]
not_none(flatten_mesh.mesh_dim_names).index(name)
]
device_mesh._dim_group_names[mesh_dim_indices[0]] # type: ignore[has-type]
)
# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
for idx in slice_dim_idx:
if idx not in mesh_dims_remained_idx:
raise NotImplementedError(
"Currently, this only allows slicing out a contiguous flattened dim."
)
mesh_dims_remained_idx.remove(idx)
# pg_ranks_by_dim has shape [number of submeshes, *slice_dim_size], i.e. one entry per combination of the
# remaining (non-sliced) mesh dims, each holding a tensor of submesh shape with the pg ranks of that submesh.
# From this, we can extract the submesh mesh tensor that contains the current rank.
pg_ranks_by_dim = mesh_tensor.permute(
*mesh_dims_remained_idx, *slice_dim_idx
).reshape(-1, *slice_dim_size)
cur_rank = device_mesh.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(
root_mesh.mesh,
for mesh_nd in pg_ranks_by_dim:
submesh = DeviceMesh(
device_mesh.device_type,
mesh_nd,
mesh_dim_names=submesh_dim_names,
_init_backend=False,
)
if cur_rank in mesh_nd:
res_submesh = submesh
res_submesh = DeviceMesh._create_mesh_from_ranks(
device_mesh.device_type,
pg_ranks_by_dim,
cur_rank,
submesh_dim_names,
_init_backend=False,
_layout=layout,
)
res_submesh._dim_group_names = slice_dim_group_name
self.child_to_root_mapping[res_submesh] = root_mesh
res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined, has-type]
self.child_to_root_mapping[res_submesh] = device_mesh
return res_submesh
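# Illustrative sketch (not part of the diff): the flatten/permute/reshape arithmetic of
# create_sub_mesh above, run with concrete numbers. Assumes an 8-rank world arranged as a
# (2, 2, 2) mesh named ("dp", "cp", "tp") and a slice of mesh["dp_cp"], i.e.
# submesh_dims = [(0, 1)] and slice_dim_size = [4].
import torch

mesh = torch.arange(8, dtype=torch.int).reshape(2, 2, 2)   # ranks 0..7
flattened = mesh.flatten(start_dim=0, end_dim=1)           # "dp" and "cp" collapsed, shape (4, 2)
# remaining dim ("tp", now index 1) goes first, the sliced dim (index 0) goes last:
pg_ranks_by_dim = flattened.permute(1, 0).reshape(-1, 4)
# pg_ranks_by_dim == tensor([[0, 2, 4, 6], [1, 3, 5, 7]], dtype=torch.int32):
# one row of "dp_cp" group ranks per "tp" coordinate; the row containing the current
# rank becomes the sliced submesh's mesh tensor.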
def create_flatten_mesh(
self,
device_mesh: "DeviceMesh",
mesh_dim_name: Optional[str] = None,
backend_override: BackendConfig = (None, None),
backend_override: BackendConfig = (
None,
None,
),
) -> "DeviceMesh":
root_mesh = self.get_root_mesh(device_mesh)
root_mesh = _mesh_resources.get_root_mesh(device_mesh)
flatten_dims_in_root = [
not_none(root_mesh.mesh_dim_names).index(flatten_mesh_dim_name)
for flatten_mesh_dim_name in not_none(device_mesh.mesh_dim_names)
]
if not mesh_dim_name:
mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names))
@@ -142,6 +199,7 @@ else:
return device_mesh
# Check whether the mesh_dim_name for flattened mesh is valid.
self.flatten_name_to_root_dims.setdefault(root_mesh, {})
invalid_dim_names = not_none(root_mesh.mesh_dim_names)
if mesh_dim_name in invalid_dim_names:
raise ValueError(
@@ -150,43 +208,47 @@ else:
f"Please specify another valid mesh_dim_name.",
)
flattened_mesh_layout = device_mesh._layout.coalesce()
# Quick return if the flatten mesh has been created before.
if (
root_mesh in self.root_to_flatten_mapping
and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
):
if (
flattened_mesh_layout
== self.root_to_flatten_mapping[root_mesh][mesh_dim_name]._layout
tuple(flatten_dims_in_root)
== self.flatten_name_to_root_dims[root_mesh][mesh_dim_name]
):
return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
else:
raise ValueError(
f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, "
f"Please specify another valid mesh_dim_name."
f"Please specify another valid mesh_dim_name.",
)
flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())
remained_dims_in_root = list(range(root_mesh.mesh.ndim))
for flatten_dim_in_root in flatten_dims_in_root:
remained_dims_in_root.remove(flatten_dim_in_root)
pg_ranks_by_dim = root_mesh.mesh.permute(
*remained_dims_in_root, *flatten_dims_in_root
).reshape(-1, flattened_mesh_dim_size)
cur_rank = root_mesh.get_rank()
# Due to the limitation of the ProcessGroup API, we need to start from the root mesh so that all ranks call
# the new_group API, avoiding a potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
root_mesh.mesh,
)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh.device_type,
pg_ranks_by_dim.flatten(
start_dim=1
), # this is needed for flatten non-contiguous mesh dims.
pg_ranks_by_dim,
cur_rank,
(mesh_dim_name,),
(backend_override,),
_layout=device_mesh._layout.coalesce(),
)
self.child_to_root_mapping[res_flattened_mesh] = root_mesh
self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
res_flattened_mesh
res_flattened_mesh # type: ignore[possibly-undefined]
)
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
flatten_dims_in_root
) # type: ignore[possibly-undefined]
return res_flattened_mesh
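# Illustrative sketch (not part of the diff): the permute/reshape in create_flatten_mesh
# above with concrete numbers, flattening the non-adjacent dims ("dp", "tp") of a
# (2, 2, 2) root mesh named ("dp", "cp", "tp").
import torch

root = torch.arange(8, dtype=torch.int).reshape(2, 2, 2)
flatten_dims_in_root = [0, 2]       # "dp" and "tp"
remained_dims_in_root = [1]         # "cp"
pg_ranks_by_dim = root.permute(1, 0, 2).reshape(-1, 2 * 2)
# pg_ranks_by_dim == tensor([[0, 1, 4, 5], [2, 3, 6, 7]], dtype=torch.int32):
# one flattened "dp_tp" rank group per "cp" coordinate. Every rank iterates over all of
# these rows when creating process groups, which is why we start from the root mesh.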
@@ -248,35 +310,27 @@ else:
) -> None:
self.mesh_dim_group_options[dim] = (backend, pg_options)
def _get_slice_mesh_layout(self, device_mesh, mesh_dim_names) -> _MeshLayout:
def _get_slice_mesh_dims(
self, device_mesh, mesh_dim_names
) -> list[tuple[int, ...]]:
"""
Validate whether mesh_dim_names is valid for slicing the given device_mesh.
If valid, return the dim indices of the sliced mesh in the device mesh.
"""
slice_from_root = True
if device_mesh != self.get_root_mesh(device_mesh):
warnings.warn(
"You are attempting to slice a submesh from another submesh. While we support this operation, "
"it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
"If not, this may result in some ranks receiving the submesh while others encounter errors."
)
slice_from_root = False
# The slice mesh_dim_names should consist of either the device_mesh's mesh_dim_names
# or its flattened mesh's mesh_dim_names if it's root_mesh.
flatten_name_to_root_layout = (
{
key: mesh._layout
for key, mesh in self.root_to_flatten_mapping.setdefault(
device_mesh, {}
).items()
}
if slice_from_root
else {}
)
# or its flattened mesh's mesh_dim_names.
self.flatten_name_to_root_dims.setdefault(device_mesh, {})
flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh]
valid_mesh_dim_names = [
*device_mesh.mesh_dim_names,
*flatten_name_to_root_layout,
*flatten_name_to_root_dims,
]
if not all(
@@ -288,51 +342,30 @@ else:
f"Valid mesh_dim_names are {valid_mesh_dim_names}."
)
layout_sliced = []
for name in mesh_dim_names:
if name in device_mesh.mesh_dim_names:
layout_sliced.append(
device_mesh._layout[device_mesh.mesh_dim_names.index(name)]
)
elif name in flatten_name_to_root_layout:
layout_sliced.append(flatten_name_to_root_layout[name])
sliced_sizes = tuple(l.sizes for l in layout_sliced)
sliced_strides = tuple(l.strides for l in layout_sliced)
# The check below is from DeviceMesh's implementation before adopting the CuTe layout for internal
# bookkeeping; it could be removed, but we first need to define the expected behavior.
# TODO: Remove the check below and define the expected behavior.
# Validate the order of the sliced mesh dim indices.
# These need to be in ascending order.
pre_stride = -1
for stride in reversed(sliced_strides):
# Note that with the CuTe layout, we could support slicing flattened non-contiguous mesh dims with no problem,
# but this would complicate the behavior, so we decided not to support it for now.
if not is_int(stride):
raise NotImplementedError(
"Currently, this only allows slicing out a contiguous flattened dim."
)
if stride < pre_stride:
curr_idx = -1
slice_mesh_dims = []
for mesh_dim_name in mesh_dim_names:
if mesh_dim_name in flatten_name_to_root_dims:
mesh_indices = flatten_name_to_root_dims[mesh_dim_name]
# TODO: this doesn't allow non-contiguous slicing with a flattened dim yet. next_idx
# should be mesh_indices[0] once we support non-contiguous slicing with a flattened dim.
next_idx = mesh_indices[-1]
slice_mesh_dims.append(mesh_indices)
else:
next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name)
slice_mesh_dims.append((next_idx,))
if next_idx <= curr_idx:
raise KeyError(
f"Invalid mesh_dim_names {mesh_dim_names} specified. "
f"Found mesh dim indices to slice: {slice_mesh_dims}. "
"Mesh dim indices should be in ascending order."
)
pre_stride = stride
curr_idx = next_idx
# When users slice dim_names from outside the current mesh, we check whether
# the layouts overlap.
# TODO: Eventually we will just throw an error here directly, because
# we will deprecate slicing a flattened dim_name from the root mesh.
layout_sliced = _MeshLayout(sliced_sizes, sliced_strides)
if not layout_sliced.check_non_overlap():
raise RuntimeError(
f"Slicing overlapping dim_names {mesh_dim_names} is not allowed."
)
return slice_mesh_dims
return layout_sliced
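# Illustrative sketch (not part of the diff): the ascending-order validation above, run by
# hand for a (2, 2, 2) mesh named ("dp", "cp", "tp") whose "dp" and "cp" dims have been
# flattened into "dp_cp" (so flatten_name_to_root_dims == {"dp_cp": (0, 1)}).
DIM_NAMES = ("dp", "cp", "tp")
FLATTEN_NAME_TO_ROOT_DIMS = {"dp_cp": (0, 1)}

def slice_dims(mesh_dim_names):
    curr_idx, slice_mesh_dims = -1, []
    for name in mesh_dim_names:
        if name in FLATTEN_NAME_TO_ROOT_DIMS:
            mesh_indices = FLATTEN_NAME_TO_ROOT_DIMS[name]
            next_idx = mesh_indices[-1]       # contiguous flattened dims only
            slice_mesh_dims.append(mesh_indices)
        else:
            next_idx = DIM_NAMES.index(name)
            slice_mesh_dims.append((next_idx,))
        if next_idx <= curr_idx:
            raise KeyError(f"Invalid mesh_dim_names {mesh_dim_names} specified.")
        curr_idx = next_idx
    return slice_mesh_dims

assert slice_dims(("dp_cp", "tp")) == [(0, 1), (2,)]   # ascending: 1 then 2 -> ok
# slice_dims(("tp", "dp")) raises KeyError: the indices go 2 then 0, which is descending.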
# TODO: make this a public API in the future so that other components can use it.
def _get_all_submeshes(
self, device_mesh: "DeviceMesh", mesh_dim_name: str
) -> list["DeviceMesh"]:
@@ -340,10 +373,10 @@ else:
Return all the submeshes of a given mesh dimension of the device mesh.
"""
mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name)
layout = device_mesh._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(
device_mesh.mesh,
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
-1, device_mesh.mesh.size(mesh_dim)
)
cur_rank = device_mesh.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
@@ -353,7 +386,7 @@ else:
mesh_dim_names=(mesh_dim_name,),
_init_backend=False,
)
submesh._dim_group_names = ( # type: ignore[has-type]
submesh._dim_group_names = (
[device_mesh._dim_group_names[mesh_dim]] # type: ignore[has-type]
if cur_rank in mesh_1d
else []
@@ -421,11 +454,9 @@ else:
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
"""
# TODO: make existing public fields private and add some methods/properties for backward compatibility.
device_type: str
mesh: torch.Tensor
mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
def __init__(
self,
@@ -436,7 +467,6 @@ else:
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
) -> None:
self.device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
@@ -449,20 +479,6 @@ else:
self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(self.mesh.size(), self.mesh.stride())
)
assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing a flattened dim from the root mesh, we don't check the stride here.
assert self._layout.numel() == self.mesh.numel(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
@@ -772,7 +788,7 @@ else:
if mesh_dim_names == self.mesh_dim_names:
return self
else:
sliced_mesh_layout = _mesh_resources._get_slice_mesh_layout(
slice_mesh_dims = _mesh_resources._get_slice_mesh_dims(
self, mesh_dim_names
)
# When using FakeTensorMode to trace the model, `create_sub_mesh()` will
@@ -786,7 +802,7 @@ else:
# TODO: compiler + device_mesh slicing.
with torch._subclasses.fake_tensor.unset_fake_temporarily():
submesh = _mesh_resources.create_sub_mesh(
self, sliced_mesh_layout, mesh_dim_names
self, mesh_dim_names, slice_mesh_dims
)
return submesh
@@ -852,7 +868,6 @@ else:
mesh_dim_names: tuple[str, ...],
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_layout: Optional[_MeshLayout] = None,
) -> "DeviceMesh":
"""
Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
@@ -884,7 +899,6 @@ else:
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
)
if cur_rank in mesh_nd:
res_mesh = mesh
@@ -970,12 +984,8 @@ else:
raise ValueError(
"Must pass mesh_dim_names if passing multiple ProcessGroups"
)
# When initializing a DeviceMesh with multiple ProcessGroups directly, we need to make sure
# the mesh tensor is contiguous. Otherwise, the layout inferred from the mesh tensor
# will have a larger span than the actual tensor. This is just an internal implementation detail
# and does not affect user-facing behavior.
mesh = (
mesh.detach().to(dtype=torch.int, device="cpu").contiguous()
mesh.detach().to(dtype=torch.int, device="cpu")
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)