Compare commits

...

2 Commits

Author SHA1 Message Date
1bcdc4f671 a quick fix for not import dist
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
2024-04-22 14:14:22 -07:00
09c15b0c10 [WIP] enforce process group for dp from input
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
2024-04-22 14:04:33 -07:00

View File

@ -31,6 +31,7 @@ def fully_shard(
mesh: Optional[DeviceMesh] = None,
reshard_after_forward: Union[bool, int] = True,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
process_group: Optional[torch.distributed.ProcessGroup] = None
):
"""
Shard module parameters across data parallel workers.
@ -101,6 +102,8 @@ def fully_shard(
raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
elif mesh.ndim == 1:
mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
if process_group is not None:
mesh_info.shard_process_group = process_group
else:
mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)