[DeviceMesh][ez] Extract the pg creation as a util function (#163930)

This just extracts the common PG-creation logic into a util function, because we will use it many times in the following stack of DeviceMesh refactoring.
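For context, the ProcessGroup constraint the helper works around: subgroup creation is collective, so every rank in the world must call the creation API for every subgroup, including subgroups it is not a member of. A minimal sketch of that behavior, assuming an already-initialized 4-rank world and an illustrative [0, 1] / [2, 3] split (neither is from this PR):

    import torch.distributed as dist

    # Assumes dist.init_process_group() has already run with world_size=4.
    # ALL four ranks must execute BOTH new_group calls, even though each
    # rank is a member of only one of the resulting subgroups.
    group_a = dist.new_group(ranks=[0, 1])
    group_b = dist.new_group(ranks=[2, 3])

    # Each rank keeps only the group it belongs to; for non-member ranks,
    # new_group returns a placeholder that must not be used for collectives.
    my_group = group_a if dist.get_rank() in (0, 1) else group_b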

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163930
Approved by: https://github.com/fegin
ghstack dependencies: #163212, #163288, #163928
fduwjj
2025-09-26 10:26:59 -07:00
committed by PyTorch MergeBot
parent c257570e6c
commit a60c6ed99f

@@ -160,6 +160,13 @@ else:
-                )
-                if cur_rank in mesh_nd:
-                    res_submesh = submesh
+            res_submesh = DeviceMesh._create_mesh_from_ranks(
+                device_mesh.device_type,
+                pg_ranks_by_dim,
+                cur_rank,
+                submesh_dim_names,
+                _init_backend=False,
+            )
             res_submesh._dim_group_names = slice_dim_group_name  # type: ignore[possibly-undefined, has-type]
             self.child_to_root_mapping[res_submesh] = device_mesh
@@ -228,16 +235,13 @@ else:
             ).reshape(-1, flattened_mesh_dim_size)
             cur_rank = root_mesh.get_rank()
-            for mesh_nd in pg_ranks_by_dim:
-                # need to init backend here since the flattened pg doesn't exist in root mesh.
-                flattened_mesh = DeviceMesh(
-                    root_mesh.device_type,
-                    mesh_nd,
-                    mesh_dim_names=(mesh_dim_name,),
-                    backend_override=(backend_override,),
-                )
-                if cur_rank in mesh_nd:
-                    res_flattened_mesh = flattened_mesh
+            res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
+                root_mesh.device_type,
+                pg_ranks_by_dim,
+                cur_rank,
+                (mesh_dim_name,),
+                (backend_override,),
+            )
             self.child_to_root_mapping[res_flattened_mesh] = root_mesh  # type: ignore[possibly-undefined]
             self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
                 res_flattened_mesh  # type: ignore[possibly-undefined]
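Taken together, the two call sites above now differ only in how the backend is handled; a condensed sketch of the two patterns, with arguments abbreviated from the hunks above:

    # Slice path: the sub-PGs already exist in the root mesh, so backend
    # init is skipped and the pre-existing dim group names are attached after.
    res_submesh = DeviceMesh._create_mesh_from_ranks(
        device_mesh.device_type, pg_ranks_by_dim, cur_rank,
        submesh_dim_names, _init_backend=False,
    )

    # Flatten path: the flattened PG does not exist in the root mesh yet,
    # so the default _init_backend=True creates it, honoring any override.
    res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
        root_mesh.device_type, pg_ranks_by_dim, cur_rank,
        (mesh_dim_name,), (backend_override,),
    )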
@@ -856,6 +860,55 @@ else:
         """
         return [self.get_group(i) for i in range(self.mesh.ndim)]
 
+    @staticmethod
+    def _create_mesh_from_ranks(
+        device_type: str,
+        pg_ranks_by_dim: torch.Tensor,
+        cur_rank: int,
+        mesh_dim_names: tuple[str, ...],
+        backend_override: Optional[tuple[BackendConfig, ...]] = None,
+        _init_backend: bool = True,
+    ) -> "DeviceMesh":
+        """
+        Helper method to create a DeviceMesh from the tensor `pg_ranks_by_dim`. This is
+        needed because of the ProcessGroup API constraint that all ranks must call the
+        PG creation API even when a rank is not in that PG.
+        We will create a potentially very large number of DeviceMesh objects
+        (e.g., on 1024 GPUs with TP=2, this could be up to 512 DeviceMeshes), only to
+        throw them all away except the one containing the current rank.
+        TODO: Further refactor this method once we relax the ProcessGroup API constraint.
+
+        Args:
+            device_type: The device type of the mesh.
+            pg_ranks_by_dim: All ranks within the world, organized by dimensions.
+            cur_rank: The current global rank in the mesh.
+            mesh_dim_names: Mesh dimension names.
+            backend_override: Optional backend override for the mesh.
+            _init_backend: Whether to initialize the backend of the mesh.
+
+        Returns:
+            The DeviceMesh containing the current rank.
+        """
+        res_mesh = None
+        for mesh_nd in pg_ranks_by_dim:
+            # Every rank constructs every candidate mesh (and so joins every
+            # PG creation call), per the ProcessGroup constraint above.
+            mesh = DeviceMesh(
+                device_type,
+                mesh_nd,
+                mesh_dim_names=mesh_dim_names,
+                backend_override=backend_override,
+                _init_backend=_init_backend,
+            )
+            if cur_rank in mesh_nd:
+                res_mesh = mesh
+        if res_mesh is None:
+            raise RuntimeError(
+                f"Current rank {cur_rank} not found in any mesh, "
+                f"input {pg_ranks_by_dim} does not contain all ranks in the world"
+            )
+        return res_mesh
+
     @staticmethod
     def from_group(
         group: Union[ProcessGroup, list[ProcessGroup]],
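For illustration, a hypothetical caller could drive the new helper as below; the 8-rank world, the TP=2 split, and the hard-coded rank are assumptions for this sketch (a real caller would use torch.distributed.get_rank()):

    import torch
    from torch.distributed.device_mesh import DeviceMesh

    world_size = 8
    # Organize all ranks into 4 groups of 2 (TP=2), mirroring the
    # `.reshape(-1, flattened_mesh_dim_size)` pattern in the flatten path.
    pg_ranks_by_dim = torch.arange(world_size).reshape(-1, 2)
    cur_rank = 5  # torch.distributed.get_rank() in a real run

    # Every rank iterates over all 4 candidate meshes internally (the
    # ProcessGroup constraint), but only the mesh containing rank 5,
    # i.e. ranks [4, 5], is returned.
    mesh = DeviceMesh._create_mesh_from_ranks(
        "cuda",
        pg_ranks_by_dim,
        cur_rank,
        ("tp",),
        _init_backend=False,  # assume the sub-PGs already exist, as in the slice path
    )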