[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh (#165492)

With the concatenate API, we can combine the DP and TP meshes directly instead of slicing the SPMD mesh out of the root mesh.
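As a rough sketch of the difference, assuming a 2x2 ("dp", "tp") mesh on 4 CUDA ranks (the shape, device type, and dim names are illustrative only; DeviceMesh._concatenate is the private helper from #163358 and is used here exactly as in the diff below):

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# Assumed 2x2 layout: outer "dp" dim for FSDP, inner "tp" dim for TP.
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = world_mesh["dp"], world_mesh["tp"]

# Before: both submeshes had to share a root mesh, which was re-sliced with
# the concatenated dim names to recover the 2D SPMD mesh.
spmd_mesh_old = dp_mesh._get_root_mesh()[
    dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
]

# After: combine the two submeshes directly.
spmd_mesh_new = DeviceMesh._concatenate([dp_mesh, tp_mesh])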

Differential Revision: [D85409698](https://our.internmc.facebook.com/intern/diff/D85409698)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165492
Approved by: https://github.com/fegin
ghstack dependencies: #163358
Author: fduwjj
Date: 2025-10-26 21:03:43 -07:00
Committed by: PyTorch MergeBot
Parent: e214af6ae8
Commit: f2c81635c8


@@ -10,9 +10,9 @@ import torch
 import torch.nn as nn
 from torch._prims_common import make_contiguous_strides_for
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor, Replicate, Shard
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
-from torch.distributed.tensor.device_mesh import _mesh_resources
 from torch.distributed.tensor.placement_types import _StridedShard, Placement

 from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
@@ -289,22 +289,12 @@ class FSDPParam:
         if self.is_dtensor:
             self._tp_spec = cast(DTensor, param)._spec
             dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
-            dp_global_mesh = dp_mesh._get_root_mesh() if dp_mesh is not None else None
-            tp_global_mesh = tp_mesh._get_root_mesh() if tp_mesh is not None else None
-            if dp_global_mesh != tp_global_mesh or (
-                dp_global_mesh is None or tp_global_mesh is None
-            ):
+            if dp_mesh is None or tp_mesh is None:
                 raise AssertionError(
-                    "FSDP requires the DP and model parallel TP/EP mesh to have the same parent mesh but got: \n"
-                    f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}"
+                    "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n"
+                    f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}"
                 )
-            name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
-            if dp_mesh.mesh_dim_names is None:
-                raise AssertionError(name_dims_error)
-            if tp_mesh.mesh_dim_names is None:
-                raise AssertionError(name_dims_error)
-            submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
-            self._spmd_mesh = dp_global_mesh[submesh_names]
+            self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh])
             if len(self._tp_spec.placements) > 2:
                 raise NotImplementedError(
                     f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}"
@@ -843,7 +833,7 @@ class FSDPParam:
             raise AssertionError("Expected mesh_dim_names to not be None")
         shard_dim_name = mesh.mesh_dim_names[-1]

-        root_mesh = _mesh_resources.get_root_mesh(mesh)
+        root_mesh = mesh._get_root_mesh()
         return root_mesh[shard_dim_name]

     def _assert_in_states(self, *states: ShardedState) -> None:
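The last hunk swaps the _mesh_resources.get_root_mesh(mesh) helper for the equivalent mesh._get_root_mesh() method. A minimal sketch of the lookup it performs, with the mesh shape and dim names assumed as in the example above:

from torch.distributed.device_mesh import init_device_mesh

# Assumed 2D mesh; as in the diff, the last dim name identifies the shard dim.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
shard_dim_name = mesh.mesh_dim_names[-1]

# _get_root_mesh() resolves the root mesh a submesh was sliced from
# (or the mesh itself if it is already the root).
root_mesh = mesh._get_root_mesh()
shard_mesh = root_mesh[shard_dim_name]  # 1D "tp" submesh used for sharding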