[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)

This PR:
1) Adds a device_mesh kwarg to FSDP and removes init_device_mesh() from _runtime_utils.py, since the device_mesh is now passed in by the user as a kwarg.
2) Changes the use_dtensor flag on state_dict_config and optim_state_dict_config to be private. If a device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, the _use_dtensor flag is set to False and the model/optim state dict returns a ShardedTensor state_dict (see the usage sketch below).
3) Updates _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.
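
As a rough usage sketch (not taken from this PR's tests): the device_mesh kwarg and the DTensor sharded state dict are what this PR adds, but the import paths, the 2x2 mesh layout, and the HSDP setup below are assumptions and may differ across releases.

```python
# Hypothetical setup: 4 ranks arranged as a 2x2 mesh (replicate x shard).
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Outer mesh dim replicates (inter-node), inner mesh dim shards (intra-node).
mesh_2d = DeviceMesh("cuda", torch.arange(4).reshape(2, 2))

fsdp_model = FSDP(
    torch.nn.Linear(16, 16).cuda(),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh_2d,  # the new kwarg added by this PR
)

# Because a device_mesh was passed, the private _use_dtensor flag is enabled
# internally, so the sharded state dict below holds 2D DTensors rather than
# ShardedTensors.
FSDP.set_state_dict_type(
    fsdp_model,
    StateDictType.SHARDED_STATE_DICT,
    state_dict_config=ShardedStateDictConfig(),
    optim_state_dict_config=ShardedOptimStateDictConfig(),
)
sharded_sd = fsdp_model.state_dict()
```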

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
wz337
2023-09-01 00:15:00 +00:00
committed by PyTorch MergeBot
parent a29b9101fa
commit cc220e45a8
11 changed files with 466 additions and 132 deletions


@@ -1,4 +1,5 @@
 import contextlib
+import copy
 import logging
 import math
 import warnings
@@ -538,7 +539,7 @@ def _sharded_post_state_dict_hook(
     def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
         param = state_dict[fqn]
-        if not fsdp_state._state_dict_config.use_dtensor:
+        if not fsdp_state._state_dict_config._use_dtensor:
             sharded_tensor = _ext_chunk_tensor(
                 tensor=param,
                 rank=fsdp_state.rank,
@@ -601,7 +602,7 @@ def _sharded_pre_load_state_dict_hook(
         fqn_from_global_root = f"{prefix}{fqn}"
         param = state_dict.pop(fqn_from_global_root)
-        if not fsdp_state._state_dict_config.use_dtensor:
+        if not fsdp_state._state_dict_config._use_dtensor:
             # All-gather the param (ShardedTensor)
             param, shards = _ext_pre_load_state_dict_transform(param)
@@ -652,9 +653,11 @@ def _sharded_pre_load_state_dict_hook(
         else:
             if param.device != fsdp_state._device_mesh.device_type:
                 param = param.to(fsdp_state._device_mesh.device_type)
+            placements = list(copy.deepcopy(param.placements))
+            placements[-1] = Replicate()
             param = param.redistribute(
-                device_mesh=param.device_mesh, placements=[Replicate()]
+                device_mesh=param.device_mesh,
+                placements=placements,
             )
         state_dict[fqn_from_global_root] = param.to_local()
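
For context on the last hunk: with HSDP the state-dict DTensors live on a 2D mesh, so only the last (sharded) mesh dim is flipped to Replicate() before to_local(), while the outer replication placement is kept. A minimal standalone sketch of that pattern, assuming a 4-rank 2x2 mesh and an already-initialized process group:

```python
import copy
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

# Assumed: dist.init_process_group(...) has already run on 4 ranks.
mesh_2d = DeviceMesh("cuda", torch.arange(4).reshape(2, 2))

# Shape of an HSDP state-dict entry: replicated on the outer mesh dim,
# sharded on tensor dim 0 along the inner mesh dim.
dtensor = distribute_tensor(torch.randn(8, 4), mesh_2d, [Replicate(), Shard(0)])

# Mirror the hunk above: keep the outer placement, replicate only the last one.
placements = list(copy.deepcopy(dtensor.placements))
placements[-1] = Replicate()
full = dtensor.redistribute(device_mesh=dtensor.device_mesh, placements=placements)

local = full.to_local()  # the unsharded (8, 4) local tensor on every rank
```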