mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[DeviceMesh] Simplify unflatten method (#165556)"
This reverts commit 86fd4fc23e697e275d37c36e3cbe521f156434fd.
Reverted https://github.com/pytorch/pytorch/pull/165556 on behalf of https://github.com/malfet because it appears to have broken a serialization test; see aba8c43594/1
([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
This commit is contained in:
@ -9,7 +9,6 @@ from itertools import product
|
||||
|
||||
import torch
|
||||
from torch.distributed._pycute import (
|
||||
as_tuple,
|
||||
coalesce,
|
||||
complement,
|
||||
composition,
|
||||
@ -18,6 +17,7 @@ from torch.distributed._pycute import (
|
||||
is_int,
|
||||
is_tuple,
|
||||
Layout,
|
||||
suffix_product,
|
||||
)
|
||||
|
||||
|
||||
@ -79,11 +79,6 @@ class _MeshLayout(Layout):
|
||||
|
||||
# # operator [] (get-i like tuples)
|
||||
def __getitem__(self, i: int) -> "_MeshLayout":
    """
    Return the ``i``-th mode of this layout as a new ``_MeshLayout``.

    Negative indices are accepted as long as they lie within ``[-len(self), len(self) - 1]``.

    Raises:
        IndexError: if ``i`` falls outside that range.
    """
    ndim = len(self)
    # Guard clause: reject anything outside the tuple-style index range up front.
    if not (-ndim <= i < ndim):
        raise IndexError(
            f"Dim {i} is out of range for layout with {len(self)} dimensions. "
            f"Expected dim to be in range [{-len(self)}, {len(self) - 1}]."
        )
    # The parent class performs the actual selection; re-wrap its shape/stride
    # so callers get a _MeshLayout back rather than whatever the base returns.
    selected = super().__getitem__(i)
    return _MeshLayout(selected.shape, selected.stride)
|
||||
|
||||
@ -161,11 +156,50 @@ class _MeshLayout(Layout):
|
||||
layout = complement(self, world_size)
|
||||
return _MeshLayout(layout.shape, layout.stride)
|
||||
|
||||
def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
|
||||
sizes = list(as_tuple(self.sizes))
|
||||
strides = list(as_tuple(self.strides))
|
||||
sizes[start:end] = list(as_tuple(layout.sizes))
|
||||
strides[start:end] = list(as_tuple(layout.strides))
|
||||
def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
    """
    Unflatten a single dimension in the layout by splitting it into multiple dimensions.
    It takes a dimension at position `dim` and splits it into multiple new dimensions
    with the specified sizes.

    Args:
        dim (int): The index of the dimension to unflatten. Must satisfy
            ``0 <= dim < len(self)``; negative indices are rejected.
        unflatten_sizes (tuple[int, ...]): The new sizes for the dimensions that will replace
            the original dimension at `dim`. The product of these sizes must equal the size
            of the original dimension at `dim`.

    Returns:
        _MeshLayout: A new layout with the specified dimension unflattened.

    Raises:
        ValueError: If `dim` is out of range, or if the product of `unflatten_sizes`
            does not equal the size of the dimension being split.

    Example:
        Original: sizes=(8,), strides=(1,)  # 8 ranks in 1D
        Call: unflatten(0, (2, 2, 2))  # Create 3D topology
        Result: sizes=(2, 2, 2), strides=(4, 2, 1)  # 2*2*2 unflattened topology
    """

    def _as_list(value: object) -> list:
        # `sizes` / `strides` are IntTuple-typed, so a bare int (rather than a tuple)
        # is a possible value; calling list() on it directly would raise TypeError.
        return list(value) if is_tuple(value) else [value]

    # Check that dim is within valid range
    ndim = len(self)
    if dim < 0 or dim >= ndim:
        raise ValueError(
            f"dim {dim} is out of range for layout with {ndim} dimensions. "
            f"Expected dim to be in range [0, {ndim - 1}]."
        )

    # Check that the product of unflatten_sizes equals the original dimension size
    dim_layout = self[dim]
    original_size = dim_layout.numel()
    unflatten_product = math.prod(unflatten_sizes)
    if unflatten_product != original_size:
        raise ValueError(
            f"The product of unflatten_sizes {unflatten_sizes} is {unflatten_product}, "
            f"but the original dimension at dim={dim} has size {original_size}. "
            "These must be equal for unflatten to work correctly."
        )

    # Compose the selected dimension with a layout built from the requested sizes
    # and their suffix-product strides; the result supplies the replacement
    # sizes/strides for position `dim`.
    unflatten_layout = dim_layout.composition(
        _MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
    )

    sizes = _as_list(self.sizes)
    strides = _as_list(self.strides)
    # Slice assignment swaps the single entry at `dim` for the new entries.
    sizes[dim : dim + 1] = _as_list(unflatten_layout.sizes)
    strides[dim : dim + 1] = _as_list(unflatten_layout.strides)
    return _MeshLayout(tuple(sizes), tuple(strides))
|
||||
|
||||
def all_ranks_from_zero(self) -> list[int]:
|
||||
|
@ -31,7 +31,6 @@
|
||||
#################################################################################################
|
||||
|
||||
from .int_tuple import (
|
||||
as_tuple,
|
||||
crd2crd,
|
||||
crd2idx,
|
||||
elem_scale,
|
||||
|
@ -54,12 +54,6 @@ def is_tuple(x: object) -> TypeIs[tuple]:
|
||||
return isinstance(x, tuple)
|
||||
|
||||
|
||||
def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]:
    """Return ``x`` wrapped in a one-element tuple when ``is_int(x)`` holds, otherwise ``x`` itself."""
    # Anything that is not an int is passed through untouched (no copy, no validation).
    return (x,) if is_int(x) else x
|
||||
|
||||
|
||||
def flatten(t: IntTuple) -> tuple[int, ...]:
|
||||
if is_tuple(t):
|
||||
if len(t) == 0:
|
||||
|
@ -245,12 +245,7 @@ else:
|
||||
# process (we need to know if the current global rank is in the mesh or not).
|
||||
if _init_backend:
|
||||
self._setup_world_group_and_device()
|
||||
self._dim_group_names = self._init_process_groups(
|
||||
self._layout,
|
||||
self._rank_map,
|
||||
self._mesh_dim_names,
|
||||
backend_override,
|
||||
)
|
||||
self._init_process_groups(backend_override)
|
||||
|
||||
if is_initialized() and get_backend() == "threaded":
|
||||
# pyrefly: ignore # bad-assignment
|
||||
@ -346,13 +341,10 @@ else:
|
||||
|
||||
return _get_default_group()
|
||||
|
||||
@staticmethod
|
||||
def _init_process_groups(
|
||||
layout: _MeshLayout,
|
||||
rank_map: torch.Tensor,
|
||||
mesh_dim_names: Optional[tuple[str, ...]],
|
||||
self,
|
||||
backend_override: tuple[BackendConfig, ...],
|
||||
) -> list[str]:
|
||||
):
|
||||
# group_name associated with each mesh dimension, each
|
||||
# mesh dimension should have one sub-group per rank
|
||||
#
|
||||
@ -360,8 +352,8 @@ else:
|
||||
default_group = _get_default_group()
|
||||
|
||||
if (
|
||||
len(layout) == 1
|
||||
and layout.numel() == get_world_size()
|
||||
len(self._layout) == 1
|
||||
and self._layout.numel() == get_world_size()
|
||||
and backend_override[0] == (None, None)
|
||||
):
|
||||
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
|
||||
@ -380,10 +372,12 @@ else:
|
||||
dim_group_names.append(dim_group.group_name)
|
||||
else:
|
||||
# create sub pgs based on the mesh argument specified
|
||||
for dim in range(len(layout)):
|
||||
for dim in range(len(self._layout)):
|
||||
# swap the current dim to the last dim
|
||||
# then reshape to flatten out other dims
|
||||
pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
|
||||
pg_ranks_by_dim = (
|
||||
self._layout[dim].nest().remap_to_tensor(self._rank_map)
|
||||
)
|
||||
backend, pg_options = backend_override[dim]
|
||||
# We need to explicitly pass in timeout when specified in option, otherwise
|
||||
# the default timeout will be used to override the timeout set in option.
|
||||
@ -395,8 +389,8 @@ else:
|
||||
# If the mesh doesn't have mesh_dim_names, then the group description of the
|
||||
# subgroup would be `mesh_dim_0` and `mesh_dim_1`.
|
||||
group_desc = (
|
||||
f"mesh_{mesh_dim_names[dim]}"
|
||||
if mesh_dim_names
|
||||
f"mesh_{self._mesh_dim_names[dim]}"
|
||||
if self._mesh_dim_names
|
||||
else f"mesh_dim_{dim}"
|
||||
)
|
||||
|
||||
@ -454,14 +448,14 @@ else:
|
||||
)
|
||||
|
||||
# only add to dim_groups if the current rank in the subgroup
|
||||
if get_rank() in subgroup_ranks:
|
||||
if self.get_rank() in subgroup_ranks:
|
||||
if len(dim_group_names) > dim:
|
||||
raise RuntimeError(
|
||||
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
|
||||
f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
|
||||
f"in {subgroup_ranks}!"
|
||||
)
|
||||
dim_group_names.append(dim_group.group_name) # type: ignore[union-attr]
|
||||
return dim_group_names
|
||||
self._dim_group_names = dim_group_names
|
||||
|
||||
def _get_root_mesh(self) -> "DeviceMesh":
|
||||
return self._root_mesh if self._root_mesh else self
|
||||
@ -1074,21 +1068,10 @@ else:
|
||||
tuple[Optional[str], Optional[C10dBackend.Options]], ...
|
||||
] = ((None, None),),
|
||||
) -> "DeviceMesh":
|
||||
inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))
|
||||
|
||||
if inner_layout.numel() != self._layout[dim].numel():
|
||||
raise ValueError(
|
||||
f"The product of {mesh_sizes=} is {inner_layout.numel()}, "
|
||||
f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. "
|
||||
f"These must be equal for unflatten to work correctly."
|
||||
)
|
||||
|
||||
partial_layout = self._layout[dim].composition(inner_layout)
|
||||
unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
|
||||
root_mesh = self._get_root_mesh()
|
||||
unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
|
||||
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
|
||||
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
|
||||
|
||||
root_mesh = self._get_root_mesh()
|
||||
res_mesh = DeviceMesh(
|
||||
self.device_type,
|
||||
_layout=unflattened_layout,
|
||||
@ -1103,13 +1086,30 @@ else:
|
||||
# TODO: To make backend init more efficient with cute layout representation and support
|
||||
# per dim backend init.
|
||||
if hasattr(self, "_dim_group_names"):
|
||||
dim_group_names = self._dim_group_names.copy()
|
||||
dim_group_names[dim : dim + 1] = self._init_process_groups(
|
||||
partial_layout,
|
||||
root_mesh._rank_map,
|
||||
mesh_dim_names,
|
||||
backend_override,
|
||||
unflatten_length = len(mesh_sizes)
|
||||
unflatten_layout = _MeshLayout(
|
||||
tuple(unflattened_layout.sizes[dim : dim + unflatten_length]), # type: ignore[index]
|
||||
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
|
||||
)
|
||||
unflatten_submesh = DeviceMesh(
|
||||
self.device_type,
|
||||
_layout=unflatten_layout,
|
||||
_rank_map=root_mesh._rank_map,
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
backend_override=backend_override,
|
||||
)
|
||||
dim_group_names = []
|
||||
for idx in range(0, res_mesh.ndim):
|
||||
if idx < dim:
|
||||
dim_group_names.append(self._dim_group_names[idx])
|
||||
elif idx >= dim + unflatten_length:
|
||||
dim_group_names.append(
|
||||
self._dim_group_names[idx - unflatten_length + 1]
|
||||
)
|
||||
else:
|
||||
dim_group_names.append(
|
||||
unflatten_submesh._dim_group_names[idx - dim]
|
||||
)
|
||||
res_mesh._dim_group_names = dim_group_names
|
||||
|
||||
return res_mesh
|
||||
|
Reference in New Issue
Block a user