Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"

This reverts commit cc220e45a80d7c01a4a58b0f386ca07236a6927a.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
Author: PyTorch MergeBot
Date:   2023-09-01 01:26:28 +00:00
Parent: 8289ad8e5e
Commit: ab5b4c4419
11 changed files with 133 additions and 467 deletions

@@ -1,5 +1,4 @@
 import contextlib
-import copy
 import logging
 import math
 import warnings
@@ -539,7 +538,7 @@ def _sharded_post_state_dict_hook(
     def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
         param = state_dict[fqn]
-        if not fsdp_state._state_dict_config._use_dtensor:
+        if not fsdp_state._state_dict_config.use_dtensor:
             sharded_tensor = _ext_chunk_tensor(
                 tensor=param,
                 rank=fsdp_state.rank,
@@ -602,7 +601,7 @@ def _sharded_pre_load_state_dict_hook(
         fqn_from_global_root = f"{prefix}{fqn}"
         param = state_dict.pop(fqn_from_global_root)
-        if not fsdp_state._state_dict_config._use_dtensor:
+        if not fsdp_state._state_dict_config.use_dtensor:
             # All-gather the param (ShardedTensor)
             param, shards = _ext_pre_load_state_dict_transform(param)
@@ -653,11 +652,9 @@ def _sharded_pre_load_state_dict_hook(
         else:
             if param.device != fsdp_state._device_mesh.device_type:
                 param = param.to(fsdp_state._device_mesh.device_type)
-            placements = list(copy.deepcopy(param.placements))
-            placements[-1] = Replicate()
             param = param.redistribute(
-                device_mesh=param.device_mesh,
-                placements=placements,
+                device_mesh=param.device_mesh, placements=[Replicate()]
             )
             state_dict[fqn_from_global_root] = param.to_local()
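
For context, the code path touched by the last hunk redistributes a loaded DTensor to a fully replicated layout and then materializes it with to_local(). The standalone sketch below is not part of this commit; it only illustrates that Shard -> Replicate -> to_local() round-trip with the public DTensor API of the same era. The gloo backend, mesh shape, and tensor contents are arbitrary choices, and it assumes being launched with torchrun (e.g. --nproc_per_node=2).

# Hypothetical sketch of the redistribute-to-Replicate round-trip used by the
# sharded pre-load hook above. Run under torchrun so the process group
# environment variables are set.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

# Shard the same global tensor along dim 0 across all ranks.
full = torch.arange(16, dtype=torch.float32).reshape(4, 4)
dtensor = distribute_tensor(full, mesh, placements=[Shard(0)])

# Redistribute to a replicated placement, then read back the unsharded copy,
# mirroring `param.redistribute(..., placements=[Replicate()]).to_local()`.
replicated = dtensor.redistribute(device_mesh=mesh, placements=[Replicate()])
assert torch.equal(replicated.to_local(), full)

dist.destroy_process_group()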