Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[DeviceMesh][ez] Extract the pg creation as a util function (#163930)
This just extracts common logic into a util function, because we will use it many times in the upcoming stack of DeviceMesh refactoring PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163930
Approved by: https://github.com/fegin
ghstack dependencies: #163212, #163288, #163928
```diff
@@ -160,6 +160,13 @@ else:
                 )
                 if cur_rank in mesh_nd:
                     res_submesh = submesh
+            res_submesh = DeviceMesh._create_mesh_from_ranks(
+                device_mesh.device_type,
+                pg_ranks_by_dim,
+                cur_rank,
+                submesh_dim_names,
+                _init_backend=False,
+            )
 
             res_submesh._dim_group_names = slice_dim_group_name  # type: ignore[possibly-undefined, has-type]
             self.child_to_root_mapping[res_submesh] = device_mesh
```
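For context, a minimal usage sketch of the slicing path that reaches this call site. The device type, mesh shape, and dim names below are illustrative assumptions, not taken from this commit; slicing reuses the root mesh's process groups, which is why the helper is invoked here with `_init_backend=False`.

```python
# Illustrative sketch (assumes an already-initialized 8-rank job): slicing a mesh
# goes through create_sub_mesh above and does not create new process groups.
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # sub-mesh; backend init is skipped (_init_backend=False)
```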
```diff
@@ -228,16 +235,13 @@ else:
             ).reshape(-1, flattened_mesh_dim_size)
 
             cur_rank = root_mesh.get_rank()
-            for mesh_nd in pg_ranks_by_dim:
-                # need to init backend here since the flattened pg doesn't exist in root mesh.
-                flattened_mesh = DeviceMesh(
-                    root_mesh.device_type,
-                    mesh_nd,
-                    mesh_dim_names=(mesh_dim_name,),
-                    backend_override=(backend_override,),
-                )
-                if cur_rank in mesh_nd:
-                    res_flattened_mesh = flattened_mesh
+            res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
+                root_mesh.device_type,
+                pg_ranks_by_dim,
+                cur_rank,
+                (mesh_dim_name,),
+                (backend_override,),
+            )
             self.child_to_root_mapping[res_flattened_mesh] = root_mesh  # type: ignore[possibly-undefined]
             self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
                 res_flattened_mesh  # type: ignore[possibly-undefined]
```
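A hedged illustration of the `pg_ranks_by_dim` tensor that the `.reshape(-1, flattened_mesh_dim_size)` context line above produces and the helper then consumes. The world size and group size are assumed values, not from this diff.

```python
# Example of the rank layout handed to _create_mesh_from_ranks, assuming an
# 8-rank world flattened into groups of 4 (values are illustrative).
import torch

world_size, flattened_mesh_dim_size = 8, 4
pg_ranks_by_dim = torch.arange(world_size).reshape(-1, flattened_mesh_dim_size)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
# Each row is one candidate group; every rank iterates over all rows but keeps
# only the DeviceMesh whose row contains its own rank.
```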
```diff
@@ -856,6 +860,55 @@ else:
             """
             return [self.get_group(i) for i in range(self.mesh.ndim)]
 
+        @staticmethod
+        def _create_mesh_from_ranks(
+            device_type: str,
+            pg_ranks_by_dim: torch.Tensor,
+            cur_rank: int,
+            mesh_dim_names: tuple[str, ...],
+            backend_override: Optional[tuple[BackendConfig, ...]] = None,
+            _init_backend: bool = True,
+        ) -> "DeviceMesh":
+            """
+            Helper method to create a DeviceMesh from the tensor `pg_ranks_by_dim`. This is
+            needed because of the ProcessGroup API constraint that all ranks have to call the
+            PG creation API even if a rank is not in that PG.
+            We will create a potentially very large number of DeviceMesh objects
+            (e.g., on 1024 GPUs with TP=2, this could be up to 512 DeviceMeshes), only to throw
+            them all away except for the mesh that contains the current rank.
+
+            # TODO: Further refactor this method once we relax the ProcessGroup API constraint.
+
+            Args:
+                device_type: The device type of the mesh.
+                pg_ranks_by_dim: All ranks within the world, organized by mesh dimension.
+                cur_rank: The current global rank in the mesh.
+                mesh_dim_names: Mesh dimension names.
+                backend_override: Optional backend override for the mesh.
+                _init_backend: Whether to initialize the backend of the mesh.
+
+            Returns:
+                The DeviceMesh containing the current rank.
+            """
+            res_mesh = None
+            for mesh_nd in pg_ranks_by_dim:
+                mesh = DeviceMesh(
+                    device_type,
+                    mesh_nd,
+                    mesh_dim_names=mesh_dim_names,
+                    backend_override=backend_override,
+                    _init_backend=_init_backend,
+                )
+                if cur_rank in mesh_nd:
+                    res_mesh = mesh
+            if res_mesh is None:
+                raise RuntimeError(
+                    f"Current rank {cur_rank} not found in any mesh, "
+                    f"input {pg_ranks_by_dim} does not contain all ranks in the world"
+                )
+            return res_mesh
+
         @staticmethod
         def from_group(
             group: Union[ProcessGroup, list[ProcessGroup]],
```
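The loop-and-discard shape of this helper mirrors the underlying `torch.distributed.new_group` contract. Below is a minimal sketch of that contract, assuming an already-initialized 4-rank world; the particular group split is illustrative and not taken from this commit.

```python
# Sketch of the ProcessGroup constraint the helper works around: new_group() must be
# called by every rank for every subgroup, even subgroups the rank is not a member of,
# and each rank keeps only the group that contains it.
import torch.distributed as dist

rank = dist.get_rank()
my_group = None
for ranks in ([0, 1], [2, 3]):  # assumed 4-rank world split into two groups
    pg = dist.new_group(ranks=ranks)  # collective call: all ranks must participate
    if rank in ranks:
        my_group = pg
```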