Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)"
This reverts commit b0985144b59db8fb20964829b5e0a9d2f9a3f0d6. Reverted https://github.com/pytorch/pytorch/pull/163213 on behalf of https://github.com/yangw-dev because it caused an internal test failure ([comment](https://github.com/pytorch/pytorch/pull/163213#issuecomment-3363414435))
@@ -440,7 +440,6 @@ class DeviceMeshTestNDim(DTensorTestBase):
ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
# ep_mesh is considered different from mesh_2d["TP"]
self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout)
self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
@@ -455,7 +454,6 @@ class DeviceMeshTestNDim(DTensorTestBase):
)
# another_mesh is considered the same as ep_mesh
self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
self.assertEqual(ep_mesh._layout, another_mesh._layout)
self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
@@ -541,6 +539,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
mesh_dim_names=("dp_replicate", "dp_shard"),
)

# self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
for mesh_dim_group, ref_mesh_dim_group in zip(
dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
):
@@ -801,10 +800,6 @@ class TestDeviceMeshGetItem(DTensorTestBase):
# Test slicing out 1D mesh from a sub-2D mesh.
shard_mesh = hsdp_mesh_2["Shard"]
self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx])
replicate_mesh = hsdp_mesh_2["Replicate"]
self.assertEqual(
replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx]
)

@with_comms
def test_cache_and_reuse_submesh_slice_result(self):
@@ -878,17 +873,12 @@ class TestDeviceMeshGetItem(DTensorTestBase):
flattened_dp_cp_mesh = dp_cp_mesh._flatten()
self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh)
self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
self.assertEqual(flattened_dp_cp_mesh.get_group().group_desc, "mesh_dp_cp")
root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
self.assertEqual(root_mesh, mesh_3d)
flatten_mesh_layout = _mesh_resources.root_to_flatten_mapping[root_mesh][
flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
"dp_cp"
]._layout
self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout)
self.assertEqual(
flattened_dp_cp_mesh._layout.global_ranks(8),
[[0, 2, 4, 6], [1, 3, 5, 7]],
)
]
self.assertEqual(flatten_mesh_root_dims, (0, 1))
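The "dp_cp" grouping asserted above can be reproduced with plain tensor operations. The snippet below is a minimal illustrative sketch (not part of the test file), assuming a contiguous 2x2x2 mesh over ranks 0-7 with dims ("dp", "cp", "tp"):

```python
import torch

# Minimal sketch: ranks 0-7 laid out row-major on a (dp=2, cp=2, tp=2) mesh.
mesh = torch.arange(8).reshape(2, 2, 2)

# Flattening ("dp", "cp") leaves "tp" as the remaining dim: move it to the front,
# then merge the flattened dims into one trailing dim of size 4.
groups = mesh.permute(2, 0, 1).reshape(-1, 4)
print(groups.tolist())  # [[0, 2, 4, 6], [1, 3, 5, 7]] -> matches the expected global_ranks above
```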

ref_pg_count = _world.group_count
# Calling flatten again should not create a new pg.
@@ -903,19 +893,10 @@ class TestDeviceMeshGetItem(DTensorTestBase):
self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
self.assertEqual(root_mesh, mesh_3d)
flatten_mesh_root_layout = _mesh_resources.root_to_flatten_mapping[root_mesh][
flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
"dp_tp"
]._layout
self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout)
self.assertEqual(
flattened_dp_tp_mesh._layout.global_ranks(8),
[[0, 1, 4, 5], [2, 3, 6, 7]],
)
with self.assertRaisesRegex(
NotImplementedError,
"Currently, this only allows slicing out a contiguous flattened dim",
):
mesh_3d["dp_tp", "cp"]
]
self.assertEqual(flatten_mesh_root_dims, (0, 2))

# Test flatten with a flattened mesh_dim_name
cp_tp_mesh = mesh_3d["cp", "tp"]
@@ -1556,50 +1537,6 @@ class CuTeLayoutTest(TestCase):
layout8 = _Layout((3, 2), (2, 3))
self.assertTrue(layout8.check_non_overlap())

def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1)

# Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
self.assertEqual(result2, expected2)

# Test 4: 1D layout with consecutive ranks
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
layout4 = _Layout((4,), (1,))
result4 = layout4.remap_to_tensor(original_mesh)
expected4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.int)
self.assertEqual(result4, expected4)

# Test 5: Complex strided layout with non-consecutive ranks
original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int)
layout5 = _Layout((2, 2), (2, 1))
result5 = layout5.remap_to_tensor(original_mesh)
expected5 = torch.tensor([[[5, 10], [15, 20]]], dtype=torch.int)
self.assertEqual(result5, expected5)

# Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result6, expected6)

# Test 7: Layout with different stride pattern
original_mesh = torch.tensor([0, 2, 1, 4], dtype=torch.int)
layout7 = _Layout((2, 2), (1, 2)) # column-major style
result7 = layout7.remap_to_tensor(original_mesh)
expected7 = torch.tensor([[[0, 1], [2, 4]]], dtype=torch.int)
self.assertEqual(result7, expected7)


if __name__ == "__main__":
run_tests()
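The expectations in these tests follow from how a CuTe-style (sizes, strides) layout maps a logical coordinate to a flat index: the coordinate is dotted with the strides. The helper below is an illustrative stand-in for that mapping, not the `_Layout`/`_MeshLayout` implementation in `torch.distributed._pycute`:

```python
from itertools import product

def layout_indices(sizes, strides):
    # Enumerate flat indices of a (sizes, strides) layout, iterating coordinates
    # with the last dimension fastest (illustrative sketch only).
    return [
        sum(c * s for c, s in zip(coord, strides))
        for coord in product(*(range(n) for n in sizes))
    ]

# Row-major 2x2 layout from Test 1: coordinates map to indices 0, 1, 2, 3.
print(layout_indices((2, 2), (2, 1)))  # [0, 1, 2, 3]

# Column-major style layout from Test 6: the same coordinates map to 0, 2, 1, 3.
# Indexing the flattened mesh [0, 2, 1, 3] with these indices gives [0, 1, 2, 3],
# which reshapes to the expected [[[0, 1], [2, 3]]].
print(layout_indices((2, 2), (1, 2)))  # [0, 2, 1, 3]
```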

@@ -7,7 +7,6 @@ from collections.abc import Iterator
from dataclasses import dataclass
from itertools import product

import torch
from torch.distributed._pycute import (
coalesce,
complement,
@@ -244,54 +243,3 @@ class _MeshLayout(Layout):
"""
ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks))
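A layout is non-overlapping when no two coordinates land on the same flat index, which is exactly what the length comparison above checks. A minimal stand-alone sketch of the same rule (not the `_MeshLayout` code itself):

```python
from itertools import product

def check_non_overlap(sizes, strides):
    # A layout overlaps iff two coordinates produce the same flat index.
    idxs = [
        sum(c * s for c, s in zip(coord, strides))
        for coord in product(*(range(n) for n in sizes))
    ]
    return len(idxs) == len(set(idxs))

print(check_non_overlap((3, 2), (2, 3)))  # True: indices {0, 2, 3, 4, 5, 7} are all distinct
print(check_non_overlap((2, 2), (1, 1)))  # False: coordinates (0, 1) and (1, 0) both map to 1
```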

def remap_to_tensor(
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
"""
Leverage layout as an index for mesh tensor that re-maps the indexes after layout
transformation to actual device ranks.

With this method, the cute layout serves as the backend of indices bookkeeping for the
mesh tensor when it comes to flatten, unflatten and slicing operations. The actual mesh
tensor still represents the actual device assignment and ranks. We need this function
to specify device allocation and create backend for a mesh. Although any transform of mesh tensors
can be treated as a view or subset of mesh tensor, we do need to use the actual view or
sub-tensor for DeviceMesh and its backend creation.

The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
and reconstruct the mesh tensor with the shape of the layout when accessed by users.
#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.

Examples:

Case 1 - Consecutive ranks, full world:
original_mesh_tensor = [[0,1],[2,3]] # 2x2 mesh, ranks 0-3
world_size = 4
layout = Layout(2:2)
Return: [[0,2],[1,3]]

Case 2 - Non-consecutive ranks:
original_mesh_tensor = [[10,20],[30,40]] # custom rank assignment
world_size = 4
layout = Layout(2:2)
Return: [[[10,30],[20,40]]]

Args:
mesh_tensor: The concrete mesh tensor with actual device ranks

Returns:
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
"""
complement_layout = self.complement(mesh_tensor.numel())

return (
mesh_tensor.flatten()
.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
)
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)
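The as_strided trick above can be spelled out for docstring Case 1. The sketch below hand-writes the complement that `self.complement(mesh_tensor.numel())` would compute for a 1-D layout with size 2 and stride 2 (CuTe `2:2`) in a world of 4 ranks; it illustrates the idea rather than calling the actual method:

```python
import torch

mesh_tensor = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)

layout_sizes, layout_strides = (2,), (2,)  # the layout being remapped (CuTe "2:2")
comp_sizes, comp_strides = (2,), (1,)      # assumed complement w.r.t. world size 4

remapped = (
    mesh_tensor.flatten()
    .as_strided(comp_sizes + layout_sizes, comp_strides + layout_strides)
    .reshape(-1, 2)
)
print(remapped.tolist())  # [[0, 2], [1, 3]], i.e. "Return: [[0,2],[1,3]]" from Case 1
```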

@@ -6,13 +6,12 @@ import os
import threading
import warnings
from collections.abc import Iterator
from functools import reduce
from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union

import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int
from torch.utils._typing_utils import not_none


@@ -67,16 +66,17 @@ else:
)

BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]]
torch.serialization.add_safe_globals([_MeshLayout])

class _MeshEnv(threading.local):
def __init__(self) -> None:
self.mesh_stack: list[DeviceMesh] = []
# TODO: Move the bookkeeping maps from _MeshEnv to DeviceMesh.
self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {}
self.mesh_dim_group_options: dict[int, BackendConfig] = {}
# Record flatten mesh name to its flattened mesh in root mesh.
self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {}
# Record flatten mesh name to its mesh dim index in root mesh.
self.flatten_name_to_root_dims: dict[
DeviceMesh, dict[str, tuple[int, ...]]
] = {}
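To make the three bookkeeping maps concrete, this is roughly what they would hold after slicing `mesh_3d["dp", "cp"]` and flattening it into `"dp_cp"` on a hypothetical 2x2x2 root mesh with dims `("dp", "cp", "tp")`. Keys are written as strings purely for readability; the real maps are keyed by `DeviceMesh` objects:

```python
# Illustrative contents only; not produced by running the code above.
child_to_root_mapping = {
    "mesh_3d['dp', 'cp']": "mesh_3d",
    "flattened 'dp_cp' mesh": "mesh_3d",
}
root_to_flatten_mapping = {
    "mesh_3d": {"dp_cp": "flattened 1-D DeviceMesh of size 4"},
}
flatten_name_to_root_dims = {
    "mesh_3d": {"dp_cp": (0, 1)},  # "dp_cp" was built from root dims 0 and 1
}
```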

def get_current_mesh(self) -> "DeviceMesh":
if len(self.mesh_stack) == 0:
@@ -86,51 +86,108 @@ else:
def create_sub_mesh(
self,
device_mesh: "DeviceMesh",
layout: _MeshLayout,
submesh_dim_names: tuple[str, ...],
submesh_dims: list[tuple[int, ...]],
) -> "DeviceMesh":
root_mesh = self.get_root_mesh(device_mesh)
slice_dim_group_name = []
for name in submesh_dim_names:
if name in not_none(device_mesh.mesh_dim_names):
slice_dim_group_name.append(
device_mesh._dim_group_names[ # type: ignore[has-type]
not_none(device_mesh.mesh_dim_names).index(name)
# Get the submesh dim size from the submesh_dims.
# For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want
# to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2].
# If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2].
slice_dim_size = [
reduce(
lambda x, y: x * device_mesh.mesh.size(y),
mesh_dim,
1,
)
for mesh_dim in submesh_dims
]
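The comment above works through the size arithmetic; the sketch below reproduces it for a hypothetical mesh_shape of (2, 2, 2) with dims ("dp", "cp", "tp"), using a plain tuple in place of `device_mesh.mesh.size()`:

```python
from functools import reduce

mesh_shape = (2, 2, 2)

def slice_dim_size(submesh_dims):
    # Multiply the root mesh sizes covered by each (possibly flattened) slice dim.
    return [
        reduce(lambda acc, dim: acc * mesh_shape[dim], mesh_dim, 1)
        for mesh_dim in submesh_dims
    ]

print(slice_dim_size([(0, 1), (2,)]))  # [4, 2] -> mesh["dp_cp"]
print(slice_dim_size([(0,), (1,)]))    # [2, 2] -> mesh["dp", "cp"]
```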

mesh_tensor = device_mesh.mesh
# slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims.
slice_dim_idx = []
slice_dim_group_name = []
# keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
# flattened mesh tensor.
num_dims_flatten = 0
for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names):
# Currently, this only allows slicing out a contiguous flattened dim.
# TODO: we need to handle reconstructing a non-contiguous flattened dim.
if len(mesh_dim_indices) > 1:
# We need to move the start_dim and end_dim to the left if some dims are already flattened.
mesh_tensor = mesh_tensor.flatten(
start_dim=mesh_dim_indices[0] - num_dims_flatten,
end_dim=mesh_dim_indices[-1] - num_dims_flatten,
)
# If some dims are already flattened, we need to adjust the slice_dim_idx accordingly.
# For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened,
# then the final slice_dim_idx should be [0, 1, 2].
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
num_dims_flatten += len(mesh_dim_indices) - 1
slice_dim_group_name.append(
self.root_to_flatten_mapping[device_mesh][
mesh_dim_name
]._dim_group_names[0] # type: ignore[has-type]
)
else:
# If device_mesh is not root_mesh, we already throw error in _get_slice_mesh_layout
# Since we will deprecate the slicing of flattened dim_name from root mesh soon,
# we don't want to optimize the code furthermore.
flatten_mesh = self.root_to_flatten_mapping[device_mesh][name]
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
slice_dim_group_name.append(
flatten_mesh._dim_group_names[ # type: ignore[has-type]
not_none(flatten_mesh.mesh_dim_names).index(name)
]
device_mesh._dim_group_names[mesh_dim_indices[0]] # type: ignore[has-type]
)

# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
for idx in slice_dim_idx:
if idx not in mesh_dims_remained_idx:
raise NotImplementedError(
"Currently, this only allows slicing out a contiguous flattened dim."
)
mesh_dims_remained_idx.remove(idx)

# pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]
# This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with
# the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank.
pg_ranks_by_dim = mesh_tensor.permute(
*mesh_dims_remained_idx, *slice_dim_idx
).reshape(-1, *slice_dim_size)
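The permute + reshape above, followed by keeping the slice that contains the current rank, can be seen on a small example. A minimal sketch assuming a 2x2x2 mesh with dims ("dp", "cp", "tp") and a slice of ("dp", "cp"), so slice_dim_idx = [0, 1], the remaining dim index is [2], and slice_dim_size = [2, 2]:

```python
import torch

mesh_tensor = torch.arange(8).reshape(2, 2, 2)
mesh_dims_remained_idx, slice_dim_idx, slice_dim_size = [2], [0, 1], [2, 2]

pg_ranks_by_dim = mesh_tensor.permute(
    *mesh_dims_remained_idx, *slice_dim_idx
).reshape(-1, *slice_dim_size)
print(pg_ranks_by_dim.tolist())  # [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]

# Each rank keeps the submesh slice that contains it, e.g. rank 5:
cur_rank = 5
submesh = next(m for m in pg_ranks_by_dim if cur_rank in m)
print(submesh.tolist())  # [[1, 3], [5, 7]]
```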

cur_rank = device_mesh.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(
root_mesh.mesh,
for mesh_nd in pg_ranks_by_dim:
submesh = DeviceMesh(
device_mesh.device_type,
mesh_nd,
mesh_dim_names=submesh_dim_names,
_init_backend=False,
)
if cur_rank in mesh_nd:
res_submesh = submesh
res_submesh = DeviceMesh._create_mesh_from_ranks(
device_mesh.device_type,
pg_ranks_by_dim,
cur_rank,
submesh_dim_names,
_init_backend=False,
_layout=layout,
)
res_submesh._dim_group_names = slice_dim_group_name
self.child_to_root_mapping[res_submesh] = root_mesh

res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined, has-type]
self.child_to_root_mapping[res_submesh] = device_mesh

return res_submesh

def create_flatten_mesh(
self,
device_mesh: "DeviceMesh",
mesh_dim_name: Optional[str] = None,
backend_override: BackendConfig = (None, None),
backend_override: BackendConfig = (
None,
None,
),
) -> "DeviceMesh":
root_mesh = self.get_root_mesh(device_mesh)
root_mesh = _mesh_resources.get_root_mesh(device_mesh)

flatten_dims_in_root = [
not_none(root_mesh.mesh_dim_names).index(flatten_mesh_dim_name)
for flatten_mesh_dim_name in not_none(device_mesh.mesh_dim_names)
]

if not mesh_dim_name:
mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names))
@@ -142,6 +199,7 @@ else:
return device_mesh

# Check whether the mesh_dim_name for flattened mesh is valid.
self.flatten_name_to_root_dims.setdefault(root_mesh, {})
invalid_dim_names = not_none(root_mesh.mesh_dim_names)
if mesh_dim_name in invalid_dim_names:
raise ValueError(
@@ -150,43 +208,47 @@ else:
f"Please specify another valid mesh_dim_name.",
)

flattened_mesh_layout = device_mesh._layout.coalesce()
# Quick return if the flatten mesh has been created before.
if (
root_mesh in self.root_to_flatten_mapping
and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
):
if (
flattened_mesh_layout
== self.root_to_flatten_mapping[root_mesh][mesh_dim_name]._layout
tuple(flatten_dims_in_root)
== self.flatten_name_to_root_dims[root_mesh][mesh_dim_name]
):
return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
else:
raise ValueError(
f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, "
f"Please specify another valid mesh_dim_name."
f"Please specify another valid mesh_dim_name.",
)

flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())

remained_dims_in_root = list(range(root_mesh.mesh.ndim))
for flatten_dim_in_root in flatten_dims_in_root:
remained_dims_in_root.remove(flatten_dim_in_root)

pg_ranks_by_dim = root_mesh.mesh.permute(
*remained_dims_in_root, *flatten_dims_in_root
).reshape(-1, flattened_mesh_dim_size)
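For non-adjacent flatten dims this permute + reshape is what produces the interleaved groups asserted earlier for "dp_tp". A minimal sketch assuming a 2x2x2 root mesh with dims ("dp", "cp", "tp"), so flatten_dims_in_root = [0, 2] and remained_dims_in_root = [1]:

```python
import torch

root_mesh = torch.arange(8).reshape(2, 2, 2)
remained_dims_in_root, flatten_dims_in_root = [1], [0, 2]
flattened_mesh_dim_size = 4

pg_ranks_by_dim = root_mesh.permute(
    *remained_dims_in_root, *flatten_dims_in_root
).reshape(-1, flattened_mesh_dim_size)
print(pg_ranks_by_dim.tolist())  # [[0, 1, 4, 5], [2, 3, 6, 7]] -> matches the "dp_tp" expectation
```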

cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
root_mesh.mesh,
)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh.device_type,
pg_ranks_by_dim.flatten(
start_dim=1
), # this is needed for flatten non-contiguous mesh dims.
pg_ranks_by_dim,
cur_rank,
(mesh_dim_name,),
(backend_override,),
_layout=device_mesh._layout.coalesce(),
)
self.child_to_root_mapping[res_flattened_mesh] = root_mesh
self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
res_flattened_mesh
res_flattened_mesh # type: ignore[possibly-undefined]
)
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
flatten_dims_in_root
) # type: ignore[possibly-undefined]

return res_flattened_mesh

@@ -248,35 +310,27 @@ else:
) -> None:
self.mesh_dim_group_options[dim] = (backend, pg_options)

def _get_slice_mesh_layout(self, device_mesh, mesh_dim_names) -> _MeshLayout:
def _get_slice_mesh_dims(
self, device_mesh, mesh_dim_names
) -> list[tuple[int, ...]]:
"""
Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
If valid, return dim indexes of the slice mesh in the device mesh.
"""
slice_from_root = True
if device_mesh != self.get_root_mesh(device_mesh):
warnings.warn(
"You are attempting to slice a submesh from another submesh. While we support this operation, "
"it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
"If not, this may result in some ranks receiving the submesh while others encounter errors."
)
slice_from_root = False

# The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names
# or its flattened mesh's mesh_dim_names if it's root_mesh.
flatten_name_to_root_layout = (
{
key: mesh._layout
for key, mesh in self.root_to_flatten_mapping.setdefault(
device_mesh, {}
).items()
}
if slice_from_root
else {}
)
# or its flattened mesh's mesh_dim_names.
self.flatten_name_to_root_dims.setdefault(device_mesh, {})
flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh]
valid_mesh_dim_names = [
*device_mesh.mesh_dim_names,
*flatten_name_to_root_layout,
*flatten_name_to_root_dims,
]

if not all(
@@ -288,51 +342,30 @@ else:
f"Valid mesh_dim_names are {valid_mesh_dim_names}."
)

layout_sliced = []
for name in mesh_dim_names:
if name in device_mesh.mesh_dim_names:
layout_sliced.append(
device_mesh._layout[device_mesh.mesh_dim_names.index(name)]
)
elif name in flatten_name_to_root_layout:
layout_sliced.append(flatten_name_to_root_layout[name])

sliced_sizes = tuple(l.sizes for l in layout_sliced)
sliced_strides = tuple(l.strides for l in layout_sliced)

# The check below is from DeviceMesh's implementation before adopting CuTe layout for internal
# bookkeeping and it can be removed but we need to define what is the expected behavior.
# TODO: Remove the below check and define the expected behavior.
# Validate the order of the slice mesh dim indices.
# This needs to be in ascending order.
pre_stride = -1
for stride in reversed(sliced_strides):
# Note that with CuTe layout, we can support slicing flattened non-contiguous mesh dims with no problem.
# But this will make this behavior complicated so we decided to not support it for now.
if not is_int(stride):
raise NotImplementedError(
"Currently, this only allows slicing out a contiguous flattened dim."
)
if stride < pre_stride:
curr_idx = -1
slice_mesh_dims = []
for mesh_dim_name in mesh_dim_names:
if mesh_dim_name in flatten_name_to_root_dims:
mesh_indices = flatten_name_to_root_dims[mesh_dim_name]
# TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx
# should be mesh_indices[0] once we support non-contiguous slicing with flatten dim.
next_idx = mesh_indices[-1]
slice_mesh_dims.append(mesh_indices)
else:
next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name)
slice_mesh_dims.append((next_idx,))
if next_idx <= curr_idx:
raise KeyError(
f"Invalid mesh_dim_names {mesh_dim_names} specified. "
f"Found mesh dim indices to slice: {slice_mesh_dims}. "
"Mesh dim indices should be in ascending order."
)
pre_stride = stride
curr_idx = next_idx

# When users sliced dim_names outside from current mesh, we will check whether
# there is layout overlap.
# TODO: Eventually we will just directly throw error here because
# we will deprecate the slicing of flattened dim_name from root mesh.
layout_sliced = _MeshLayout(sliced_sizes, sliced_strides)
if not layout_sliced.check_non_overlap():
raise RuntimeError(
f"Slicing overlapping dim_names {mesh_dim_names} is not allowed."
)
return slice_mesh_dims

return layout_sliced
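The ascending-order rule enforced above can be isolated into a few lines. The sketch below uses hypothetical dim names on a ("dp", "cp", "tp") mesh: ("dp", "tp") resolves to indices (0, 2) and passes, while ("tp", "dp") resolves to (2, 0) and is rejected:

```python
def validate_slice_order(root_dim_names, mesh_dim_names):
    # Minimal stand-in for the ordering check above (flattened dims omitted).
    curr_idx = -1
    slice_mesh_dims = []
    for name in mesh_dim_names:
        next_idx = root_dim_names.index(name)
        slice_mesh_dims.append((next_idx,))
        if next_idx <= curr_idx:
            raise KeyError(
                f"Invalid mesh_dim_names {mesh_dim_names}: indices {slice_mesh_dims} "
                "should be in ascending order."
            )
        curr_idx = next_idx
    return slice_mesh_dims

print(validate_slice_order(("dp", "cp", "tp"), ("dp", "tp")))  # [(0,), (2,)]
# validate_slice_order(("dp", "cp", "tp"), ("tp", "dp")) raises KeyError
```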

# TODO: to make this use case by other components public API in the future.
def _get_all_submeshes(
self, device_mesh: "DeviceMesh", mesh_dim_name: str
) -> list["DeviceMesh"]:
@@ -340,10 +373,10 @@ else:
Return all the submeshes of a given mesh dimension of the device mesh.
"""
mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name)
layout = device_mesh._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(
device_mesh.mesh,
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
-1, device_mesh.mesh.size(mesh_dim)
)
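The swapdims + reshape path restored above enumerates, for every index of the other dims, the 1-D group along the requested dim. A minimal sketch assuming a hypothetical 2x4 mesh with dims ("dp", "tp") and mesh_dim = 0 ("dp"):

```python
import torch

mesh = torch.arange(8).reshape(2, 4)  # dims ("dp", "tp")
mesh_dim = 0

pg_ranks_by_dim = mesh.swapdims(-1, mesh_dim).reshape(-1, mesh.size(mesh_dim))
print(pg_ranks_by_dim.tolist())  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> one "dp" group per "tp" index
```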

cur_rank = device_mesh.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
@@ -353,7 +386,7 @@ else:
mesh_dim_names=(mesh_dim_name,),
_init_backend=False,
)
submesh._dim_group_names = ( # type: ignore[has-type]
submesh._dim_group_names = (
[device_mesh._dim_group_names[mesh_dim]] # type: ignore[has-type]
if cur_rank in mesh_1d
else []
@@ -421,11 +454,9 @@ else:
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
"""

# TODO: to make existing public fields private and add some methods/properties for bc.
device_type: str
mesh: torch.Tensor
mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout

def __init__(
self,
@@ -436,7 +467,6 @@ else:
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
) -> None:
self.device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
@@ -449,20 +479,6 @@ else:
self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(self.mesh.size(), self.mesh.stride())
)
assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.numel() == self.mesh.numel(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
)
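When no explicit `_layout` is passed, the layout is derived from the mesh tensor's size and stride and then checked for consistency. The sketch below mirrors the numel check with plain tuples; it does not construct a real `_MeshLayout`:

```python
import math

import torch

mesh = torch.arange(8, dtype=torch.int).reshape(2, 4)
sizes, strides = tuple(mesh.size()), tuple(mesh.stride())

print(sizes, strides)                    # (2, 4) (4, 1)
print(math.prod(sizes) == mesh.numel())  # True: the derived layout spans exactly the mesh
```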

# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
@@ -772,7 +788,7 @@ else:
if mesh_dim_names == self.mesh_dim_names:
return self
else:
sliced_mesh_layout = _mesh_resources._get_slice_mesh_layout(
slice_mesh_dims = _mesh_resources._get_slice_mesh_dims(
self, mesh_dim_names
)
# When using FakeTensorMode to trace the model, `create_sub_mesh()` will
@@ -786,7 +802,7 @@ else:
# TODO: compiler + device_mesh slicing.
with torch._subclasses.fake_tensor.unset_fake_temporarily():
submesh = _mesh_resources.create_sub_mesh(
self, sliced_mesh_layout, mesh_dim_names
self, mesh_dim_names, slice_mesh_dims
)
return submesh

@@ -852,7 +868,6 @@ else:
mesh_dim_names: tuple[str, ...],
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_layout: Optional[_MeshLayout] = None,
) -> "DeviceMesh":
"""
Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
@@ -884,7 +899,6 @@ else:
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
)
if cur_rank in mesh_nd:
res_mesh = mesh
@@ -970,12 +984,8 @@ else:
raise ValueError(
"Must pass mesh_dim_names if passing multiple ProcessGroups"
)
# When init a DeviceMesh with multiple ProcessGroups directly, we need to make sure
# the mesh tensor is contiguous. Otherwise, the layout we inferred from the mesh tensor
# will have larger span than the actual tensor. This is just internal implementation detail
# and does not affect user facing behavior.
mesh = (
mesh.detach().to(dtype=torch.int, device="cpu").contiguous()
mesh.detach().to(dtype=torch.int, device="cpu")
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
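The `.contiguous()` call above matters because a non-contiguous mesh tensor has strides that span more storage than its own numel, so a layout inferred from (size, stride) would reach indices outside the tensor. A minimal sketch of that failure mode:

```python
import torch

full = torch.arange(8, dtype=torch.int).reshape(2, 4)
mesh = full[:, ::2]                   # non-contiguous view: [[0, 2], [4, 6]]

print(mesh.size(), mesh.stride())     # torch.Size([2, 2]) (4, 2)
span = sum((s - 1) * st for s, st in zip(mesh.size(), mesh.stride())) + 1
print(span, mesh.numel())             # 7 vs 4: the inferred layout overshoots the tensor

contig = mesh.contiguous()
print(contig.size(), contig.stride()) # torch.Size([2, 2]) (2, 1): span now equals numel
```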