mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:

1) Adds a device_mesh kwarg to FSDP and removes init_device_mesh() from _runtime_utils.py, since the device mesh is now passed in by the user as a kwarg.
2) Makes the use_dtensor flag on state_dict_config and optim_state_dict_config private. If a device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict; otherwise the flag is set to False and a ShardedTensor state_dict is returned.
3) Updates _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to support returning a 2D DTensor state_dict for HSDP.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
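As a rough illustration of the behavior described above (not code from this PR), the sketch below shows how a user might pass a 2D device mesh to FSDP and retrieve a DTensor-based sharded state_dict. The mesh shape, the torchrun launch, and the init_device_mesh import path are assumptions and vary across releases.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh  # torch.distributed._tensor in older releases
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)

# Run under torchrun; a 2 x 4 mesh is an illustrative 2-node, 4-GPU-per-node layout.
dist.init_process_group("nccl")
mesh_2d = init_device_mesh("cuda", (2, 4))

model = FSDP(
    torch.nn.Linear(16, 16).cuda(),
    device_mesh=mesh_2d,                              # new kwarg from this PR
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # HSDP: replicate across mesh dim 0, shard across dim 1
)

# With a device_mesh present, FSDP flips the (now private) _use_dtensor flag and the
# sharded state_dict holds 2D DTensors; without it, the same call returns ShardedTensors.
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig())
sharded_sd = model.state_dict()
```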
@@ -1,4 +1,5 @@
 import contextlib
+import copy
 import logging
 import math
 import warnings
@@ -538,7 +539,7 @@ def _sharded_post_state_dict_hook(
 
     def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
         param = state_dict[fqn]
-        if not fsdp_state._state_dict_config.use_dtensor:
+        if not fsdp_state._state_dict_config._use_dtensor:
             sharded_tensor = _ext_chunk_tensor(
                 tensor=param,
                 rank=fsdp_state.rank,
@@ -601,7 +602,7 @@ def _sharded_pre_load_state_dict_hook(
         fqn_from_global_root = f"{prefix}{fqn}"
         param = state_dict.pop(fqn_from_global_root)
 
-        if not fsdp_state._state_dict_config.use_dtensor:
+        if not fsdp_state._state_dict_config._use_dtensor:
             # All-gather the param (ShardedTensor)
             param, shards = _ext_pre_load_state_dict_transform(param)
 
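Both hunks above gate on the now-private _use_dtensor flag, which decides whether the sharded state_dict carries ShardedTensors or DTensors. A small, hedged way to check which flavor came back (describe_state_dict is a hypothetical helper, not part of FSDP; only the two public tensor types are assumed):

```python
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor  # torch.distributed.tensor in newer releases


def describe_state_dict(sharded_sd):
    """Print whether each entry is a DTensor, a ShardedTensor, or something else."""
    for fqn, value in sharded_sd.items():
        if isinstance(value, DTensor):
            kind = f"DTensor(placements={value.placements})"
        elif isinstance(value, ShardedTensor):
            kind = "ShardedTensor"
        else:
            kind = type(value).__name__
        print(f"{fqn}: {kind}")
```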
@@ -652,9 +653,11 @@ def _sharded_pre_load_state_dict_hook(
         else:
             if param.device != fsdp_state._device_mesh.device_type:
                 param = param.to(fsdp_state._device_mesh.device_type)
 
+            placements = list(copy.deepcopy(param.placements))
+            placements[-1] = Replicate()
             param = param.redistribute(
-                device_mesh=param.device_mesh, placements=[Replicate()]
+                device_mesh=param.device_mesh,
+                placements=placements,
             )
             state_dict[fqn_from_global_root] = param.to_local()
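The last hunk is the HSDP-specific part of the change: a parameter loaded as a DTensor on a 2D (replicate, shard) mesh should only be all-gathered along its final, sharded mesh dimension, so the code copies the existing placements and flips just the last one to Replicate() instead of passing a 1D [Replicate()]. Below is a hedged, standalone sketch of that pattern using public DTensor APIs; the mesh shape, tensor size, and import paths are assumptions for illustration.

```python
import copy

import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh  # torch.distributed._tensor in older releases

# Run under torchrun with 4 ranks; mesh dim 0 replicates, dim 1 shards (HSDP-style).
dist.init_process_group("nccl")
mesh_2d = init_device_mesh("cuda", (2, 2))

# A 2D DTensor like the ones HSDP produces: replicated across nodes, sharded within a node.
param = distribute_tensor(torch.randn(8, 8), mesh_2d, placements=[Replicate(), Shard(0)])

# Keep every leading placement as-is and un-shard only the last mesh dimension,
# mirroring the placements[-1] = Replicate() change in the hunk above.
placements = list(copy.deepcopy(param.placements))
placements[-1] = Replicate()
gathered = param.redistribute(device_mesh=param.device_mesh, placements=placements)
local = gathered.to_local()  # each rank now holds the full, unsharded tensor
```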