Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh and instead compute it on the fly when the end user needs it, replacing all of its internal usages with `_layout` and the newly introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied or reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
Committed by: PyTorch MergeBot
Parent: 99b32a6750
Commit: d61a9b88cf
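Before the diff hunks, a minimal standalone sketch of the idea described in the commit message: each mesh keeps only a layout (sizes and strides), and a single 1D rank map is shared by the root mesh and all of its children, so the n-D mesh tensor can be recomputed on the fly instead of being stored per mesh. The helper name `toy_layout_to_mesh` and the bare size/stride tuples are illustrative only; they are not the actual `_MeshLayout`/`_rank_map` API.

    import torch

    def toy_layout_to_mesh(rank_map: torch.Tensor, sizes, strides) -> torch.Tensor:
        # Re-materialize a mesh tensor by strided indexing into the shared 1D rank map.
        return rank_map.as_strided(sizes, strides)

    # One rank map shared by the root mesh and every submesh (no copies needed).
    rank_map = torch.arange(8, dtype=torch.int)

    root = toy_layout_to_mesh(rank_map, (2, 4), (4, 1))  # 2x4 root mesh
    column = toy_layout_to_mesh(rank_map, (2,), (4,))    # one column of the root mesh
    print(root)    # tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int32)
    print(column)  # tensor([0, 4], dtype=torch.int32)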
@@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
     def test_remap_to_tensor(self):
         """Test the remap_to_tensor method for various scenarios."""
         # Test 1: Consecutive ranks, full world - should return logical groups directly
-        original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
+        original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
         layout1 = _Layout((2, 2), (2, 1))  # row-major 2x2
         result1 = layout1.remap_to_tensor(original_mesh)
         expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
         self.assertEqual(result1, expected1)
 
         # Test 2: Non-consecutive ranks - should map to actual ranks
-        original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
+        original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
         layout2 = _Layout((2, 2), (2, 1))
         result2 = layout2.remap_to_tensor(original_mesh)
         expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
         self.assertEqual(result5, expected5)
 
         # Test 6: Tensor Cute representation of a 2D mesh
-        original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
+        original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
         layout6 = _Layout((2, 2), (1, 2))  # column-major style
         result6 = layout6.remap_to_tensor(original_mesh)
         expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
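To see why Test 6 expects [[[0, 1], [2, 3]]], the column-major layout can be reproduced with a plain as_strided over the 1D rank map. This is only an illustration of the indexing, not the actual `_Layout.remap_to_tensor` implementation, which also prepends a group dimension derived from the complement layout:

    import torch

    rank_map = torch.tensor([0, 2, 1, 3], dtype=torch.int)
    # sizes (2, 2) with strides (1, 2): element (i, j) reads rank_map[i * 1 + j * 2]
    remapped = rank_map.as_strided((2, 2), (1, 2))
    print(remapped)               # tensor([[0, 1], [2, 3]], dtype=torch.int32)
    print(remapped.unsqueeze(0))  # with the leading group dim of size 1, as in expected6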
@@ -301,10 +301,7 @@ class _MeshLayout(Layout):
         ranks = self.all_ranks_from_zero()
         return len(ranks) == len(set(ranks))
 
-    def remap_to_tensor(
-        self,
-        mesh_tensor: torch.Tensor,
-    ) -> torch.Tensor:
+    def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
         """
         Leverage layout as an index for mesh tensor that re-maps the indexes after layout
         transformation to actual device ranks.
@@ -316,10 +313,7 @@ class _MeshLayout(Layout):
         can be treated as a view or subset of mesh tensor, we do need to use the actual view or
         sub-tensor for DeviceMesh and its backend creation.
 
-        The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
-        shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
-        and reconstruct the mesh tensor with the shape of the layout when accessed by users.
-        #TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
+        The shape of the `rank_map` must be 1D and contiguous.
 
         Examples:
 
@@ -336,18 +330,18 @@ class _MeshLayout(Layout):
         Return: [[[10,30],[20,40]]]
 
         Args:
-            mesh_tensor: The concrete mesh tensor with actual device ranks
+            rank_map: The concrete mesh tensor with actual device ranks
 
         Returns:
-            torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
+            torch.Tensor: A tensor representing the actual device allocation from rank_map
         """
-        complement_layout = self.complement(mesh_tensor.numel())
+        assert rank_map.ndim == 1
+        assert rank_map.is_contiguous()
+        assert rank_map.numel() >= self.cosize()
 
-        return (
-            mesh_tensor.flatten()
-            .as_strided(
-                flatten(complement_layout.sizes) + flatten(self.sizes),
-                flatten(complement_layout.strides) + flatten(self.strides),
-            )
-            .reshape(-1, *(self[i].numel() for i in range(len(self))))
-        )
+        complement_layout = self.complement(rank_map.numel())
+
+        return rank_map.as_strided(
+            flatten(complement_layout.sizes) + flatten(self.sizes),
+            flatten(complement_layout.strides) + flatten(self.strides),
+        ).reshape(-1, *self.top_level_sizes)
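The new body of remap_to_tensor can be read as: prepend the complement layout (everything in the world that this layout does not cover), perform one strided gather over the shared 1D rank map, and reshape so that the leading dimension enumerates the groups. A rough standalone illustration for a toy 2x4 world, with the complement hand-computed rather than derived via `self.complement()` and `flatten()` as the real code does:

    import torch

    # World: 8 ranks, stored once as a 1D rank map.
    rank_map = torch.arange(8, dtype=torch.int)

    # Layout of one mesh dimension of a 2x4 mesh: size 2, stride 4 (the "outer" dim).
    sizes, strides = (2,), (4,)

    # Hand-computed complement for this toy case: the remaining 4 positions, stride 1.
    comp_sizes, comp_strides = (4,), (1,)

    groups = rank_map.as_strided(comp_sizes + sizes, comp_strides + strides).reshape(-1, 2)
    print(groups)
    # tensor([[0, 4],
    #         [1, 5],
    #         [2, 6],
    #         [3, 7]], dtype=torch.int32)  -> one process group per row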
@@ -173,7 +173,7 @@ else:
         """
 
         _device_type: str
-        _mesh: torch.Tensor
+        _rank_map: torch.Tensor
         _mesh_dim_names: Optional[tuple[str, ...]]
         _layout: _MeshLayout
         _root_mesh: Optional["DeviceMesh"] = None
@@ -190,46 +190,49 @@ else:
             _init_backend: bool = True,
             _rank: Optional[int] = None,
             _layout: Optional[_MeshLayout] = None,
             _root_mesh: Optional["DeviceMesh"] = None,
         ) -> None:
             self._device_type = device_type
             if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                 raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
-            self._mesh = (
+            mesh_tensor = (
                 mesh.detach().to(dtype=torch.int).contiguous()
                 if isinstance(mesh, torch.Tensor)
                 else torch.tensor(mesh, device="cpu", dtype=torch.int)
             )
+            self._rank_map = (
+                _root_mesh._rank_map
+                if _root_mesh is not None
+                else mesh_tensor.flatten()
+            )
             self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
-            if backend_override is None:
-                backend_override = ((None, None),) * self.mesh.ndim
-            elif len(backend_override) != self.mesh.ndim:
-                raise ValueError(
-                    f"backend_override should have the same length as the number of mesh dimensions, "
-                    f"but got {len(backend_override)} and {self.mesh.ndim}."
-                )
             # Internal bookkeeping for the device mesh.
             self._layout = (
                 _layout
                 if _layout
-                else _MeshLayout(self.mesh.size(), self.mesh.stride())
+                else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
             )
             self._root_mesh = _root_mesh
             assert self._layout.check_non_overlap(), (
                 "Please use a non-overlapping layout when creating a DeviceMesh."
             )
             # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
-            assert self._layout.top_level_sizes == self.mesh.size(), (
+            assert self._layout.top_level_sizes == mesh_tensor.size(), (
                 "Please use a valid layout when creating a DeviceMesh."
-                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
+                f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
             )
-            # private field to pre-generate DeviceMesh's hash
-            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
-            self._thread_id = None
-            # Initialize instance-specific flatten mapping
-            self._flatten_mapping = {}
+
+            if backend_override is None:
+                backend_override = ((None, None),) * len(self._layout)
+            elif len(backend_override) != len(self._layout):
+                raise ValueError(
+                    f"backend_override should have the same length as the number of mesh dimensions, "
+                    f"but got {len(backend_override)} and {len(self._layout)}."
+                )
+
             # Skip process group initialization if xla device or init backend is False
             # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
+            self._thread_id = None
             if device_type != "xla":
                 # always try to create default (world) pg, even if it is not initialized
                 # already. The world pg is used for device mesh identity (rank) on each
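A minimal sketch of the aliasing behaviour the new constructor sets up: a mesh created with `_root_mesh` passed in reuses the root's 1D rank map instead of materializing its own tensor. The `ToyMesh` class below is hypothetical and mimics only the `_rank_map` handling shown above, not the full DeviceMesh constructor:

    import torch
    from typing import Optional

    class ToyMesh:
        def __init__(self, mesh: torch.Tensor, root: Optional["ToyMesh"] = None) -> None:
            # Same pattern as the diff: share the root's rank map, else flatten the input.
            self._rank_map = root._rank_map if root is not None else mesh.flatten()

    root = ToyMesh(torch.arange(8, dtype=torch.int).reshape(2, 4))
    child = ToyMesh(torch.tensor([[0, 4]], dtype=torch.int), root=root)

    # The child aliases the root's storage: no copy or reallocation happened.
    print(child._rank_map.data_ptr() == root._rank_map.data_ptr())  # True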
@@ -252,6 +255,11 @@ else:
                     rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
                 )
 
+            # private field to pre-generate DeviceMesh's hash
+            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+            # Initialize instance-specific flatten mapping
+            self._flatten_mapping = {}
+
         @property
         def device_type(self) -> str:
             """Returns the device type of the mesh."""
@@ -260,7 +268,17 @@ else:
         @property
         def mesh(self) -> torch.Tensor:
             """Returns the tensor representing the layout of devices."""
-            return self._mesh
+            full_mesh = self._layout.remap_to_tensor(self._rank_map)
+            if full_mesh.size(0) == 1:
+                return full_mesh[0]
+            my_coords = (full_mesh == get_rank()).nonzero()
+            if my_coords.size(0) > 0:
+                return full_mesh[my_coords[0, 0]]
+            raise RuntimeError(
+                "In order to get the mesh Tensor of a DeviceMesh it needs to "
+                "either have all its original dimensions (e.g., no slicing) "
+                "or it needs to contain the local rank"
+            )
 
         @property
         def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
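Since the mesh tensor is now computed on the fly, the property above returns the remapped tensor directly when it still contains a single group, and otherwise selects the group that contains the local rank. A rough standalone illustration of that selection logic, with toy tensors and `local_rank` standing in for `get_rank()`:

    import torch

    # Toy stand-in: a sliced layout remapped over a 2x4 world gives one row per group.
    full_mesh = torch.arange(8, dtype=torch.int).reshape(2, 4)
    local_rank = 6  # pretend get_rank() returned 6

    if full_mesh.size(0) == 1:
        mesh = full_mesh[0]                   # unsliced mesh: single group
    else:
        my_coords = (full_mesh == local_rank).nonzero()
        mesh = full_mesh[my_coords[0, 0]]     # the group containing the local rank
    print(mesh)  # tensor([4, 5, 6, 7], dtype=torch.int32)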
@@ -275,9 +293,9 @@ else:
                 init_process_group()
 
             world_size = get_world_size()
-            if self.mesh.numel() > world_size:
+            if self._layout.numel() > world_size:
                 raise RuntimeError(
-                    f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
+                    f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
                 )
 
             # ONLY set the device if the current device is not initialized, if user already
@@ -328,8 +346,8 @@ else:
             default_group = _get_default_group()
 
             if (
-                self.mesh.ndim == 1
-                and self.mesh.numel() == get_world_size()
+                len(self._layout) == 1
+                and self._layout.numel() == get_world_size()
                 and backend_override[0] == (None, None)
             ):
                 # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@@ -348,11 +366,11 @@ else:
                     dim_group_names.append(dim_group.group_name)
             else:
                 # create sub pgs base on the mesh argument specified
-                for dim in range(self.mesh.ndim):
+                for dim in range(len(self._layout)):
                     # swap the current dim to the last dim
                     # then reshape to flatten out other dims
-                    pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
-                        -1, self.mesh.size(dim)
+                    pg_ranks_by_dim = (
+                        self._layout[dim].nest().remap_to_tensor(self._rank_map)
                     )
                     backend, pg_options = backend_override[dim]
                     # We need to explicitly pass in timeout when specified in option, otherwise
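The old and new code paths in this hunk compute the same per-dimension process-group ranks; only the source changes, from the stored 2D mesh tensor to a strided view over the shared rank map. A small equivalence check for a toy 2x4 mesh (raw strides are used here in place of the real `_layout[dim].nest().remap_to_tensor()` call):

    import torch

    mesh = torch.arange(8, dtype=torch.int).reshape(2, 4)   # old: stored 2D mesh tensor
    rank_map = mesh.flatten()                                # new: shared 1D rank map
    dim = 0

    # Old path: swap the target dim to the end, then flatten out the other dims.
    old_groups = mesh.swapdims(-1, dim).reshape(-1, mesh.size(dim))

    # New path, illustrated with raw sizes/strides: gather along the dim's stride
    # for each complement position.
    new_groups = rank_map.as_strided((4, 2), (1, 4))

    print(torch.equal(old_groups, new_groups))  # True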
@@ -448,14 +466,14 @@ else:
 
         def __repr__(self) -> str:
             device_mesh_repr = (
-                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
+                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
                 if self._mesh_dim_names
-                else f"{tuple(self._mesh.shape)}"
+                else f"{self._layout.top_level_sizes}"
             )
-            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
+            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
             # We only print the mesh tensor if the debug mode is turned on.
             if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
-                device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
+                device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
             return f"{device_mesh_repr})"
 
         def __hash__(self):
@@ -465,7 +483,7 @@ else:
             self._hash = hash(
                 (
                     self._flatten_mesh_list,
-                    self._mesh.shape,
+                    self._layout,
                     self._device_type,
                     self._mesh_dim_names,
                     self._thread_id,
@@ -481,7 +499,7 @@ else:
                 return False
             return (
                 self._flatten_mesh_list == other._flatten_mesh_list
-                and self._mesh.shape == other._mesh.shape
+                and self._layout == other._layout
                 and self._device_type == other._device_type
                 and self._mesh_dim_names == other._mesh_dim_names
                 and self._thread_id == other._thread_id
@@ -573,16 +591,16 @@ else:
             if not hasattr(self, "_dim_group_names"):
                 raise RuntimeError("DeviceMesh process groups not initialized!")
 
-            if self.mesh.ndim > 1 and mesh_dim is None:
+            if len(self._layout) > 1 and mesh_dim is None:
                 raise RuntimeError(
-                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
+                    f"Found the DeviceMesh have {len(self._layout)} dimensions",
                     "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                     "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
                     "please use `get_all_groups()` instead.",
                 )
 
             # Quick return if the current device_mesh is a 1D mesh.
-            if self.mesh.ndim == 1 and mesh_dim is None:
+            if len(self._layout) == 1 and mesh_dim is None:
                 return not_none(_resolve_process_group(self._dim_group_names[0]))
 
             root_mesh = self._get_root_mesh()
@@ -608,7 +626,7 @@ else:
             Returns:
                 A list of :class:`ProcessGroup` object.
             """
-            return [self.get_group(i) for i in range(self.mesh.ndim)]
+            return [self.get_group(i) for i in range(len(self._layout))]
 
         def _create_sub_mesh(
             self,
@@ -635,9 +653,7 @@ else:
                 ]
             )
             cur_rank = self.get_rank()
-            pg_ranks_by_dim = layout.remap_to_tensor(
-                root_mesh.mesh,
-            )
+            pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
             res_submesh = DeviceMesh._create_mesh_from_ranks(
                 self._device_type,
                 pg_ranks_by_dim,
@@ -692,9 +708,7 @@ else:
             cur_rank = root_mesh.get_rank()
             # Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
             # new_group api to avoid potential hang.
-            pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
-                root_mesh.mesh,
-            )
+            pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
             res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
                 root_mesh._device_type,
                 pg_ranks_by_dim.flatten(
@@ -833,9 +847,7 @@ else:
             """
             mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
             layout = self._layout[mesh_dim]
-            pg_ranks_by_dim = layout.remap_to_tensor(
-                self.mesh,
-            )
+            pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
             cur_rank = self.get_rank()
             res_submeshes = []
             for mesh_1d in pg_ranks_by_dim:
@@ -896,6 +908,7 @@ else:
                     backend_override=backend_override,
                     _init_backend=_init_backend,
                     _layout=_layout,
+                    _root_mesh=_root_mesh,
                 )
                 if cur_rank in mesh_nd:
                     res_mesh = mesh
@@ -904,8 +917,6 @@ else:
                     f"Current rank {cur_rank} not found in any mesh, "
                     f"input {pg_ranks_by_dim} does not contain all ranks in the world"
                 )
-            if _root_mesh is not None:
-                res_mesh._root_mesh = _root_mesh
             return res_mesh
 
         @staticmethod
@@ -1004,15 +1015,17 @@ else:
             return device_mesh
 
         def size(self, mesh_dim: Optional[int] = None) -> int:
-            return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
+            if mesh_dim is not None:
+                return self._layout[mesh_dim].numel()
+            return self._layout.numel()
 
         @property
         def ndim(self) -> int:
-            return self.mesh.ndim
+            return len(self._layout)
 
         @property
         def shape(self) -> tuple[int, ...]:
-            return tuple(self.mesh.shape)
+            return self._layout.top_level_sizes
 
         def get_rank(self) -> int:
             """
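With the stored mesh tensor gone, size, ndim, and shape reduce to bookkeeping on the layout's sizes. A trivial sketch of what that bookkeeping amounts to, with a plain tuple standing in for `_MeshLayout`:

    import math

    sizes = (3, 2, 4)                   # toy top-level sizes of a 3x2x4 mesh

    ndim = len(sizes)                   # what DeviceMesh.ndim now reports
    shape = sizes                       # what DeviceMesh.shape now reports
    total = math.prod(sizes)            # DeviceMesh.size() with no mesh_dim
    per_dim = sizes[1]                  # DeviceMesh.size(mesh_dim=1)
    print(ndim, shape, total, per_dim)  # 3 (3, 2, 4) 24 2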
@@ -1051,7 +1064,7 @@ else:
             """
             if self.ndim > 1 and mesh_dim is None:
                 raise RuntimeError(
-                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
+                    f"Found the DeviceMesh have {len(self._layout)} dimensions",
                     "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                 )
             elif mesh_dim is None:
@@ -1115,9 +1128,7 @@ else:
             root_mesh = self._get_root_mesh()
             cur_rank = self.get_rank()
             unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
-            pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
-                root_mesh.mesh,
-            )
+            pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
             unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
             unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
             res_mesh = DeviceMesh._create_mesh_from_ranks(
@@ -1141,7 +1152,7 @@ else:
                     tuple(unflattened_layout.strides[dim : dim + unflatten_length]),  # type: ignore[index]
                 )
                 unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
-                    root_mesh.mesh,
+                    root_mesh._rank_map
                 )
                 unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
                     self.device_type,