[2D] Enable 2D DTensor state_dict for FSDP + TP (#110846)
This PR adds a `chunk_dtensor()` method to fsdp/_fsdp_extensions.py and the actual implementation of `chunk_dtensor()` in tensor/parallel/fsdp.py. This enables FSDP to return a 2D DTensor state_dict when composing FSDP with TP.

cc @fegin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110846
Approved by: https://github.com/fegin, https://github.com/wanchaol
ghstack dependencies: #110831
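For context, a minimal sketch of the 2D composition this state_dict support targets. The sketch is not part of the PR; it assumes a world size of 8 with an initialized process group, and the public init_device_mesh / parallelize_module / FSDP APIs of recent PyTorch releases. The toy Linear model and the (2, 4) mesh shape are made up for illustration:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# 2D mesh: outer "dp" dim for FSDP, inner "tp" dim for tensor parallelism.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

model = nn.Linear(16, 16, bias=False)
# Shard the layer column-wise across the TP submesh.
model = parallelize_module(model, mesh_2d["tp"], ColwiseParallel())
# Wrap with FSDP on the DP submesh; mesh_2d is the parent mesh of both.
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)

# With this PR, the sharded state_dict values are 2D DTensors,
# sharded across both the "dp" and "tp" mesh dimensions.
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
state_dict = model.state_dict()

Under this composition, each entry in the sharded state_dict is a DTensor laid out over both mesh dimensions, which distributed checkpointing code can save and reshard directly.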
@@ -18,6 +18,7 @@ from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
 )
 from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor.device_mesh import mesh_resources

 from torch.distributed.distributed_c10d import _get_pg_default_device
 from torch.distributed.fsdp._common_utils import (
@@ -289,6 +290,14 @@ def _full_pre_state_dict_hook(
     is supported in ``nn.Module``, this hook will be registered as a hook in
     ``nn.Module``.
     """
+    if getattr(fsdp_state, "_device_mesh", False):
+        parent_mesh = mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
+        if parent_mesh:
+            raise RuntimeError(
+                f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
+                "We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
+            )
+
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
         module,
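The hunk above guards the FULL_STATE_DICT path: when FSDP's device_mesh has a parent mesh (i.e., it is a submesh of a 2D FSDP + TP mesh), the hook raises and points users at SHARDED_STATE_DICT. On the sharded path, the new chunk_dtensor() extension turns each rank-local shard into a 2D DTensor. Conceptually that amounts to something like the following sketch; this is not the PR's implementation, chunk_to_2d_dtensor is a hypothetical name, and the dim-0 placements assume a TP plan that shards dim 0 (e.g., ColwiseParallel on a Linear weight):

import torch
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._tensor.device_mesh import DeviceMesh

def chunk_to_2d_dtensor(local_shard: torch.Tensor, parent_mesh: DeviceMesh) -> DTensor:
    # Hypothetical helper (not the PR's code): view the rank-local tensor as
    # one piece of a global parameter sharded on dim 0 across both mesh
    # dimensions -- FSDP's shard on "dp" and TP's shard on "tp".
    return DTensor.from_local(
        local_shard,
        device_mesh=parent_mesh,
        placements=[Shard(0), Shard(0)],
        run_check=False,  # shards are constructed locally; skip verification
    )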