mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-03 23:45:05 +08:00
[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh (#165492)
With the concatenate API, we can directly combine the two meshes instead of fetching the SPMD mesh from the root mesh.

Differential Revision: [D85409698](https://our.internmc.facebook.com/intern/diff/D85409698)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165492
Approved by: https://github.com/fegin
ghstack dependencies: #163358
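For context, here is a minimal sketch of the idea (not part of the commit), assuming four ranks launched with `torchrun --nproc_per_node=4` on CPU/gloo and an illustrative 2x2 ("dp", "tp") layout; `DeviceMesh._concatenate` and `_get_root_mesh` are the private helpers this diff touches:

```python
# Illustrative sketch only, not part of this commit.  Assumes four ranks
# launched via `torchrun --nproc_per_node=4 concat_demo.py` on CPU (gloo);
# the 2x2 ("dp", "tp") layout and all names are made up for the example.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


def main() -> None:
    # Root 2D mesh, as a user would set up for FSDP (dp) + TP.
    root = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
    dp_mesh, tp_mesh = root["dp"], root["tp"]

    # Old path: walk back to the shared root mesh and slice the 2D SPMD mesh
    # out of it, which only works if both submeshes share the same root.
    spmd_old = dp_mesh._get_root_mesh()[("dp", "tp")]

    # New path (this PR): combine the two submeshes directly with the private
    # _concatenate API, with no detour through the root mesh.
    spmd_new = DeviceMesh._concatenate([dp_mesh, tp_mesh])

    if dist.get_rank() == 0:
        print("via root:", spmd_old, "\nvia concatenate:", spmd_new)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Because the concatenate path does not need a shared root, the diff below also drops the same-parent-mesh assertion and only checks that both meshes are not None.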
@@ -10,9 +10,9 @@ import torch
 import torch.nn as nn
 from torch._prims_common import make_contiguous_strides_for
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor, Replicate, Shard
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
-from torch.distributed.tensor.device_mesh import _mesh_resources
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
 
 from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
@@ -289,22 +289,12 @@ class FSDPParam:
         if self.is_dtensor:
             self._tp_spec = cast(DTensor, param)._spec
             dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
-            dp_global_mesh = dp_mesh._get_root_mesh() if dp_mesh is not None else None
-            tp_global_mesh = tp_mesh._get_root_mesh() if tp_mesh is not None else None
-            if dp_global_mesh != tp_global_mesh or (
-                dp_global_mesh is None or tp_global_mesh is None
-            ):
+            if dp_mesh is None or tp_mesh is None:
                 raise AssertionError(
-                    "FSDP requires the DP and model parallel TP/EP mesh to have the same parent mesh but got: \n"
-                    f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}"
+                    "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n"
+                    f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}"
                 )
-            name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
-            if dp_mesh.mesh_dim_names is None:
-                raise AssertionError(name_dims_error)
-            if tp_mesh.mesh_dim_names is None:
-                raise AssertionError(name_dims_error)
-            submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
-            self._spmd_mesh = dp_global_mesh[submesh_names]
+            self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh])
             if len(self._tp_spec.placements) > 2:
                 raise NotImplementedError(
                     f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}"
@@ -843,7 +833,7 @@ class FSDPParam:
             raise AssertionError("Expected mesh_dim_names to not be None")
         shard_dim_name = mesh.mesh_dim_names[-1]
 
-        root_mesh = _mesh_resources.get_root_mesh(mesh)
+        root_mesh = mesh._get_root_mesh()
         return root_mesh[shard_dim_name]
 
     def _assert_in_states(self, *states: ShardedState) -> None:
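As a footnote to the last hunk, a hedged sketch of the new root-mesh lookup: a named submesh now reaches its root via the private `DeviceMesh._get_root_mesh()` method instead of `_mesh_resources.get_root_mesh()`. The `shard_mesh_from` helper, the 2x2 shape, and the ("dp", "tp") names are invented for illustration; run under `torchrun --nproc_per_node=4` on CPU:

```python
# Illustrative sketch only, not part of this commit; mirrors the lookup
# pattern in the last hunk.  Run via `torchrun --nproc_per_node=4 root_demo.py`.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


def shard_mesh_from(mesh: DeviceMesh) -> DeviceMesh:
    # Hypothetical helper: pick the last named dim of `mesh` and return the
    # root mesh's 1D submesh along that dim, as the updated FSDP code does.
    if mesh.mesh_dim_names is None:
        raise AssertionError("Expected mesh_dim_names to not be None")
    shard_dim_name = mesh.mesh_dim_names[-1]
    root_mesh = mesh._get_root_mesh()  # method form used after this change
    return root_mesh[shard_dim_name]


def main() -> None:
    root = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
    print(dist.get_rank(), shard_mesh_from(root["tp"]))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```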