[2D] Enable 2D DTensor state_dict for FSDP + TP (#110846)

This PR adds a `chunk_dtensor()` method to `fsdp/_fsdp_extensions.py`, with the actual implementation of `chunk_dtensor()` in `tensor/parallel/fsdp.py`. This enables FSDP to return a 2D DTensor state_dict when composing FSDP with TP.

cc. @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110846
Approved by: https://github.com/fegin, https://github.com/wanchaol
ghstack dependencies: #110831
Author: wz337
Date: 2023-10-09 00:03:14 -07:00
Committed by: PyTorch MergeBot
Parent: 0bd4ce728b
Commit: 6c136c3302
5 changed files with 152 additions and 15 deletions
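
For illustration, a minimal sketch of how the 2D composition this PR targets might be exercised. This is not code from the PR: the mesh shape, dim names, and model are hypothetical, and it assumes the era's DeviceMesh/TP APIs (`init_device_mesh`, dim-name mesh slicing, `parallelize_module`) plus FSDP's `device_mesh` argument from the dependent PR #110831.

```python
# Hypothetical end-to-end sketch (not part of this PR): TP on the inner mesh
# dimension, FSDP on the outer one, then a sharded state_dict of 2D DTensors.
import torch.nn as nn
# Import path assumed for this point in the stack; later releases expose it publicly.
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Illustrative 2D mesh on 4 GPUs: "dp" for FSDP sharding, "tp" for tensor parallelism.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Linear(16, 16).cuda()
# Shard the linear layer column-wise over the "tp" sub-mesh.
model = parallelize_module(model, mesh_2d["tp"], ColwiseParallel())
# Wrap with FSDP over the "dp" sub-mesh; the device_mesh kwarg is assumed from #110831.
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)

# With this PR, SHARDED_STATE_DICT values come back as 2D DTensors;
# FULL_STATE_DICT raises for 2D FSDP + TP (see the hook change in the diff below).
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = model.state_dict()
```

The sharded path is the one the diff below enforces: on a 2D mesh, the full pre-state-dict hook rejects FULL_STATE_DICT and points users at SHARDED_STATE_DICT.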


@@ -18,6 +18,7 @@ from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
 )
 from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor.device_mesh import mesh_resources
 from torch.distributed.distributed_c10d import _get_pg_default_device
 from torch.distributed.fsdp._common_utils import (
@@ -289,6 +290,14 @@ def _full_pre_state_dict_hook(
     is supported in ``nn.Module``, this hook will be registered as a hook in
     ``nn.Module``.
     """
+    if getattr(fsdp_state, "_device_mesh", False):
+        parent_mesh = mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
+        if parent_mesh:
+            raise RuntimeError(
+                f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
+                "We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
+            )
+
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
         module,