FSDP originally uses `init_from_local_shards()` to create a ShardedTensor for `sharded_state_dict()`, and we have seen non-trivial overhead from it when the number of tensors is large. Using `_init_from_local_shards_and_global_metadata()` instead significantly reduces the overhead, since every rank can build the global metadata locally rather than exchanging and validating it across ranks. For a model with ~250 tensors in the state_dict trained with 16 GPUs, the original `sharded_state_dict()` takes ~1.7 seconds and this PR reduces the overhead to ~0.6 seconds.
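For context, a minimal sketch of the two construction paths is below. It is not the FSDP code from this PR; the import paths, the metadata helpers, and especially the signature of the private `_init_from_local_shards_and_global_metadata()` classmethod are assumptions based on the `torch.distributed._shard` APIs around the time of this PR.

```python
# Illustrative sketch only (not the FSDP implementation); assumes the
# torch.distributed._shard APIs circa this PR, including a private constructor.
import torch
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import (
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata


def build_sharded_tensor(local_shard: torch.Tensor, rank: int, world_size: int) -> ShardedTensor:
    """Wrap this rank's 1-D shard into a ShardedTensor of size world_size * shard_len."""
    shard_len = local_shard.numel()
    my_md = ShardMetadata(
        shard_offsets=[rank * shard_len],
        shard_sizes=[shard_len],
        placement=f"rank:{rank}/{local_shard.device}",
    )
    local_shards = [Shard(local_shard, my_md)]

    # Original path: the public torch.distributed._shard.sharded_tensor.init_from_local_shards()
    # gathers and validates shard metadata across ranks, which is where the
    # per-tensor overhead comes from.

    # Path used by this PR: every rank builds the global metadata locally, so no
    # communication is needed per tensor (private API; signature is an assumption).
    shards_md = [
        ShardMetadata(
            shard_offsets=[r * shard_len],
            shard_sizes=[shard_len],
            placement=f"rank:{r}/{local_shard.device}",
        )
        for r in range(world_size)
    ]
    global_md = ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=torch.Size([world_size * shard_len]),
        tensor_properties=TensorProperties(dtype=local_shard.dtype),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=global_md
    )
```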
Differential Revision: [D38452170](https://our.internmc.facebook.com/intern/diff/D38452170/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82911
Approved by: https://github.com/awgu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77356
Implement ShardedTensor compatible sharded_state_dict() and load_sharded_state_dict().
Algorithm overview:
`sharded_state_dict()`:
1. Call `summon_full_params()`.
2. For each unflattened, non-sharded parameter:
   2.1 Call `chunk()` to get the local shard of the parameter.
   2.2 Create a ShardedTensor from the local shard (see the sketch after this list).
3. Replace the tensor in the state_dict with the newly created ShardedTensor.
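The per-parameter work in steps 2-3 can be sketched as below. This is a simplified illustration, not the code in this PR: it assumes a 1-D chunking along dim 0 that divides evenly across ranks, and the import paths and the helper name `replace_with_sharded_tensor` are assumptions.

```python
# Simplified sketch of steps 2.1-3 for one parameter (not the actual FSDP code).
# Assumes full_param.size(0) is evenly divisible by world_size.
import torch
from torch.distributed._shard.sharded_tensor import Shard, init_from_local_shards
from torch.distributed._shard.sharding_spec import ShardMetadata


def replace_with_sharded_tensor(state_dict, name, full_param, rank, world_size):
    # 2.1 chunk() the full (unflattened, non-sharded) parameter and keep this rank's piece.
    local_chunk = full_param.chunk(world_size, dim=0)[rank].clone()

    # 2.2 Describe where the local chunk lives in the global tensor and wrap it
    #     into a ShardedTensor (the public constructor exchanges metadata across ranks).
    shard_md = ShardMetadata(
        shard_offsets=[rank * local_chunk.size(0)] + [0] * (full_param.dim() - 1),
        shard_sizes=list(local_chunk.size()),
        placement=f"rank:{rank}/{local_chunk.device}",
    )
    sharded = init_from_local_shards([Shard(local_chunk, shard_md)], *full_param.size())

    # 3. Replace the full tensor in the state_dict with the ShardedTensor.
    state_dict[name] = sharded
```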
`load_sharded_state_dict()`:
1. For each unflattened, sharded parameter (ShardedTensor) in the given state_dict:
   1.1 Pop it out of the state_dict.
   1.2 All-gather its local shards to reconstruct the unflattened, non-sharded parameter (see the sketch after this list).
2. Create a FlatParameter from the unflattened, non-sharded parameters.
3. Shard the newly created FlatParameter.
4. Insert the new FlatParameter into the state_dict.
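Step 1 can be sketched roughly as below for a single entry. This is an illustration, not this PR's implementation: it assumes every rank holds exactly one equally sized shard along dim 0, and the helper name `_unshard_entry` is hypothetical. Steps 2-4 (rebuilding and re-sharding the FlatParameter) are FSDP-internal and not shown.

```python
# Rough sketch of step 1 for a single state_dict entry (not the actual FSDP code).
# Assumes each rank holds exactly one equally sized shard along dim 0.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor


def _unshard_entry(state_dict, key, world_size):  # hypothetical helper name
    # 1.1 Pop the ShardedTensor out of the given state_dict.
    sharded: ShardedTensor = state_dict.pop(key)
    local_shard = sharded.local_shards()[0].tensor.contiguous()

    # 1.2 All-gather the per-rank shards and concatenate them back into the
    #     unflattened, non-sharded parameter.
    gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard)
    return torch.cat(gathered, dim=0)
```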
Differential Revision: [D36284983](https://our.internmc.facebook.com/intern/diff/D36284983/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36284983/)!
Approved by: https://github.com/zhaojuanmao