Revert "[DeviceMesh] Simplify unflatten method (#165556)"

This reverts commit 86fd4fc23e697e275d37c36e3cbe521f156434fd.

Reverted https://github.com/pytorch/pytorch/pull/165556 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
PyTorch MergeBot
2025-10-16 20:41:34 +00:00
parent aead9270f5
commit 431c13cf61
4 changed files with 84 additions and 57 deletions

View File

@@ -9,7 +9,6 @@ from itertools import product
 import torch
 from torch.distributed._pycute import (
-    as_tuple,
     coalesce,
     complement,
     composition,
@@ -18,6 +17,7 @@ from torch.distributed._pycute import (
     is_int,
     is_tuple,
     Layout,
+    suffix_product,
 )
@@ -79,11 +79,6 @@ class _MeshLayout(Layout):
     # operator [] (get-i like tuples)
     def __getitem__(self, i: int) -> "_MeshLayout":
-        if i < -len(self) or i >= len(self):
-            raise IndexError(
-                f"Dim {i} is out of range for layout with {len(self)} dimensions. "
-                f"Expected dim to be in range [{-len(self)}, {len(self) - 1}]."
-            )
         layout = super().__getitem__(i)
         return _MeshLayout(layout.shape, layout.stride)
@@ -161,11 +156,50 @@ class _MeshLayout(Layout):
         layout = complement(self, world_size)
         return _MeshLayout(layout.shape, layout.stride)

-    def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
-        sizes = list(as_tuple(self.sizes))
-        strides = list(as_tuple(self.strides))
-        sizes[start:end] = list(as_tuple(layout.sizes))
-        strides[start:end] = list(as_tuple(layout.strides))
-        return _MeshLayout(tuple(sizes), tuple(strides))
+    def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
+        """
+        Unflatten a single dimension in the layout by splitting it into multiple dimensions.
+        It takes a dimension at position `dim` and splits it into multiple new dimensions
+        with the specified sizes.
+
+        Args:
+            dim (int): The index of the dimension to unflatten. Must be a valid dimension index.
+            unflatten_sizes (tuple[int, ...]): The new sizes for the dimensions that will replace
+                the original dimension at `dim`. The product of these sizes must equal the size
+                of the original dimension at `dim`.
+
+        Returns:
+            _MeshLayout: A new layout with the specified dimension unflattened.
+
+        Example:
+            Original: sizes=(8,), strides=(1,)   # 8 ranks in 1D
+            Call:     unflatten(0, (2, 2, 2))    # Create 3D topology
+            Result:   sizes=(2, 2, 2), strides=(4, 2, 1)  # 2*2*2 unflattened topology
+        """
+        # Check that dim is within valid range
+        if dim < 0 or dim >= len(self):
+            raise ValueError(
+                f"dim {dim} is out of range for layout with {len(self)} dimensions. "
+                f"Expected dim to be in range [0, {len(self) - 1}]."
+            )
+        # Check that the product of unflatten_sizes equals the original dimension size
+        original_size = self[dim].numel()
+        unflatten_product = math.prod(unflatten_sizes)
+        if unflatten_product != original_size:
+            raise ValueError(
+                f"The product of unflatten_sizes {unflatten_sizes} is {unflatten_product}, "
+                f"but the original dimension at dim={dim} has size {original_size}. "
+                f"These must be equal for unflatten to work correctly."
+            )
+        sizes = list(self.sizes)  # type: ignore[arg-type]
+        strides = list(self.strides)  # type: ignore[arg-type]
+        unflatten_layout = self[dim].composition(
+            _MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
+        )
+        sizes[dim : dim + 1] = list(unflatten_layout.sizes)  # type: ignore[arg-type]
+        strides[dim : dim + 1] = list(unflatten_layout.strides)  # type: ignore[arg-type]
+        return _MeshLayout(tuple(sizes), tuple(strides))
+
     def all_ranks_from_zero(self) -> list[int]:
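Editor's note: for intuition about the restored `unflatten`, the size/stride arithmetic reduces, in the contiguous case of the docstring example, to scaling the original dimension's stride by a suffix product of the new sizes. Below is a minimal standalone sketch of that arithmetic, assuming the simple contiguous case only; it is plain Python, not the PyTorch implementation, and `suffix_product`/`unflatten_dim` here are local stand-ins (the real method goes through the pycute `composition` to also handle non-trivial strides).

import math

def suffix_product(sizes):
    # (2, 2, 2) -> (4, 2, 1): each entry is the product of the sizes to its right.
    out = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        out[i] = out[i + 1] * sizes[i + 1]
    return tuple(out)

def unflatten_dim(sizes, strides, dim, new_sizes):
    # Split sizes[dim] into new_sizes; each new dimension's stride is the old
    # stride scaled by the product of the new sizes to its right.
    assert math.prod(new_sizes) == sizes[dim]
    new_strides = tuple(strides[dim] * s for s in suffix_product(new_sizes))
    return (
        sizes[:dim] + tuple(new_sizes) + sizes[dim + 1:],
        strides[:dim] + new_strides + strides[dim + 1:],
    )

# The docstring example: 8 contiguous ranks unflattened into a 2 x 2 x 2 cube.
assert unflatten_dim((8,), (1,), 0, (2, 2, 2)) == ((2, 2, 2), (4, 2, 1))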

View File

@@ -31,7 +31,6 @@
 #################################################################################################

 from .int_tuple import (
-    as_tuple,
     crd2crd,
     crd2idx,
     elem_scale,
View File

@@ -54,12 +54,6 @@ def is_tuple(x: object) -> TypeIs[tuple]:
     return isinstance(x, tuple)


-def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]:
-    if is_int(x):
-        return (x,)
-    return x
-
-
 def flatten(t: IntTuple) -> tuple[int, ...]:
     if is_tuple(t):
         if len(t) == 0:
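Editor's note: the removed `as_tuple` helper only normalized a bare int into a one-element tuple and passed tuples through; its apparent callers in this stack were the `splice` method being removed above, which is why the revert can drop both together. A one-liner equivalent, for reference:

# Behavior of the removed helper (illustrative, mirroring the diff above):
def as_tuple(x):
    return (x,) if isinstance(x, int) else x

assert as_tuple(3) == (3,)
assert as_tuple((2, (2, 2))) == (2, (2, 2))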

View File

@@ -245,12 +245,7 @@ else:
             # process (we need to know if the current global rank is in the mesh or not).
             if _init_backend:
                 self._setup_world_group_and_device()
-                self._dim_group_names = self._init_process_groups(
-                    self._layout,
-                    self._rank_map,
-                    self._mesh_dim_names,
-                    backend_override,
-                )
+                self._init_process_groups(backend_override)

             if is_initialized() and get_backend() == "threaded":
                 # pyrefly: ignore  # bad-assignment
@@ -346,13 +341,10 @@ else:
             return _get_default_group()

-        @staticmethod
         def _init_process_groups(
-            layout: _MeshLayout,
-            rank_map: torch.Tensor,
-            mesh_dim_names: Optional[tuple[str, ...]],
+            self,
             backend_override: tuple[BackendConfig, ...],
-        ) -> list[str]:
+        ):
             # group_name associated with each mesh dimension, each
             # mesh dimension should have one sub-group per rank
             #
@@ -360,8 +352,8 @@ else:
             default_group = _get_default_group()
             if (
-                len(layout) == 1
-                and layout.numel() == get_world_size()
+                len(self._layout) == 1
+                and self._layout.numel() == get_world_size()
                 and backend_override[0] == (None, None)
             ):
                 # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@@ -380,10 +372,12 @@ else:
                     dim_group_names.append(dim_group.group_name)
             else:
                 # create sub pgs base on the mesh argument specified
-                for dim in range(len(layout)):
+                for dim in range(len(self._layout)):
                     # swap the current dim to the last dim
                     # then reshape to flatten out other dims
-                    pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
+                    pg_ranks_by_dim = (
+                        self._layout[dim].nest().remap_to_tensor(self._rank_map)
+                    )
                     backend, pg_options = backend_override[dim]
                     # We need to explicitly pass in timeout when specified in option, otherwise
                     # the default timeout will be used to override the timeout set in option.
@@ -395,8 +389,8 @@ else:
                     # If the mesh doesn't not have a mesh_dim_names, then the group description of the
                     # subgroup would be `mesh_dim_0` and `mesh_dim_1`.
                     group_desc = (
-                        f"mesh_{mesh_dim_names[dim]}"
-                        if mesh_dim_names
+                        f"mesh_{self._mesh_dim_names[dim]}"
+                        if self._mesh_dim_names
                         else f"mesh_dim_{dim}"
                     )
@ -454,14 +448,14 @@ else:
)
# only add to dim_groups if the current rank in the subgroup
if get_rank() in subgroup_ranks:
if self.get_rank() in subgroup_ranks:
if len(dim_group_names) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
f"in {subgroup_ranks}!"
)
dim_group_names.append(dim_group.group_name) # type: ignore[union-attr]
return dim_group_names
self._dim_group_names = dim_group_names
def _get_root_mesh(self) -> "DeviceMesh":
return self._root_mesh if self._root_mesh else self
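Editor's note: to see what `_init_process_groups` derives from the layout, here is a hypothetical self-contained sketch (not the DeviceMesh API; `groups_for_dim` is an invented name) of the per-dimension rank groups for a row-major 2 x 4 mesh over ranks 0..7. Each group fixes all coordinates except the one being grouped, which is what `pg_ranks_by_dim` enumerates row by row.

import itertools

sizes, strides = (2, 4), (4, 1)  # row-major 2 x 4 layout over ranks 0..7

def groups_for_dim(sizes, strides, dim):
    # Fix every coordinate except `dim`; the free coordinate enumerates one group.
    other = [range(s) for i, s in enumerate(sizes) if i != dim]
    groups = []
    for fixed in itertools.product(*other):
        coords = list(fixed)
        group = []
        for k in range(sizes[dim]):
            full = coords[:dim] + [k] + coords[dim:]
            group.append(sum(c * st for c, st in zip(full, strides)))
        groups.append(group)
    return groups

assert groups_for_dim(sizes, strides, 0) == [[0, 4], [1, 5], [2, 6], [3, 7]]
assert groups_for_dim(sizes, strides, 1) == [[0, 1, 2, 3], [4, 5, 6, 7]]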
@@ -1074,21 +1068,10 @@ else:
                 tuple[Optional[str], Optional[C10dBackend.Options]], ...
             ] = ((None, None),),
         ) -> "DeviceMesh":
-            inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))
-
-            if inner_layout.numel() != self._layout[dim].numel():
-                raise ValueError(
-                    f"The product of {mesh_sizes=} is {inner_layout.numel()}, "
-                    f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. "
-                    f"These must be equal for unflatten to work correctly."
-                )
-
-            partial_layout = self._layout[dim].composition(inner_layout)
-            unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
-            root_mesh = self._get_root_mesh()
+            unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
             unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
             unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
+            root_mesh = self._get_root_mesh()
             res_mesh = DeviceMesh(
                 self.device_type,
                 _layout=unflattened_layout,
@@ -1103,13 +1086,30 @@ else:
             # TODO: To make backend init more efficient with cute layout representation and support
             # per dim backend init.
             if hasattr(self, "_dim_group_names"):
-                dim_group_names = self._dim_group_names.copy()
-                dim_group_names[dim : dim + 1] = self._init_process_groups(
-                    partial_layout,
-                    root_mesh._rank_map,
-                    mesh_dim_names,
-                    backend_override,
-                )
+                unflatten_length = len(mesh_sizes)
+                unflatten_layout = _MeshLayout(
+                    tuple(unflattened_layout.sizes[dim : dim + unflatten_length]),  # type: ignore[index]
+                    tuple(unflattened_layout.strides[dim : dim + unflatten_length]),  # type: ignore[index]
+                )
+                unflatten_submesh = DeviceMesh(
+                    self.device_type,
+                    _layout=unflatten_layout,
+                    _rank_map=root_mesh._rank_map,
+                    mesh_dim_names=mesh_dim_names,
+                    backend_override=backend_override,
+                )
+                dim_group_names = []
+                for idx in range(0, res_mesh.ndim):
+                    if idx < dim:
+                        dim_group_names.append(self._dim_group_names[idx])
+                    elif idx >= dim + unflatten_length:
+                        dim_group_names.append(
+                            self._dim_group_names[idx - unflatten_length + 1]
+                        )
+                    else:
+                        dim_group_names.append(
+                            unflatten_submesh._dim_group_names[idx - dim]
+                        )
                 res_mesh._dim_group_names = dim_group_names
             return res_mesh
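Editor's note: the restored stitching loop can be read in isolation. Dimensions left of the split keep the parent mesh's process groups, the unflattened dimensions take the helper submesh's groups, and dimensions to the right shift back by `unflatten_length - 1`, since the unflattened dims replace exactly one parent dim. A toy sketch with invented group names:

# Standalone sketch of the group-name stitching above (pg_* names are illustrative):
def stitch(parent_names, submesh_names, dim, unflatten_length, ndim):
    out = []
    for idx in range(ndim):
        if idx < dim:
            out.append(parent_names[idx])          # dims before the split
        elif idx >= dim + unflatten_length:
            out.append(parent_names[idx - unflatten_length + 1])  # shifted tail
        else:
            out.append(submesh_names[idx - dim])   # the unflattened dims
    return out

# Unflattening dim 1 of a 3D mesh into 2 dims yields a 4D result:
assert stitch(["pg_a", "pg_b", "pg_c"], ["pg_x", "pg_y"], 1, 2, 4) == [
    "pg_a", "pg_x", "pg_y", "pg_c"
]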