[FSDP][state_dict] Expose optimizer state_dict config (#105949)

The optimizer state_dict configs are not exposed. This PR exposes the two dataclasses.
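
For context, a minimal, hedged sketch of one way the newly exposed classes can be used end to end. The single-process CPU/gloo setup and the tiny model are illustrative scaffolding only, not part of this change; the import paths are the ones added to torch/distributed/fsdp/__init__.py in this PR.

```python
# Illustrative only: a single-process CPU/gloo demo of the newly exposed
# optimizer state_dict configs; real FSDP runs typically use multi-GPU NCCL.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = FSDP(torch.nn.Linear(8, 8))
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 8)).sum().backward()
optim.step()

# Configure both the model and the optimizer state_dict behavior with the
# now-public config dataclasses instead of importing from torch.distributed.fsdp.api.
FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(
        offload_to_cpu=True, rank0_only=True
    ),
)
osd = FSDP.optim_state_dict(model, optim)  # honors FullOptimStateDictConfig
dist.destroy_process_group()
```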

Differential Revision: [D47766024](https://our.internmc.facebook.com/intern/diff/D47766024/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105949
Approved by: https://github.com/rohan-varma
Author: Chien-Chin Huang
Date: 2023-08-18 09:45:12 -07:00
Committed by: PyTorch MergeBot
parent 63e9b5481d
commit 7ba513b6e4
3 changed files with 19 additions and 13 deletions


@@ -18,29 +18,29 @@ FullyShardedDataParallel
 .. autoclass:: torch.distributed.fsdp.CPUOffload
     :members:
-.. autoclass:: torch.distributed.fsdp.api.StateDictConfig
+.. autoclass:: torch.distributed.fsdp.StateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.FullStateDictConfig
+.. autoclass:: torch.distributed.fsdp.FullStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.ShardedStateDictConfig
+.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.LocalStateDictConfig
+.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.OptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.FullOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.ShardedOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.LocalOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig
     :members:
-.. autoclass:: torch.distributed.fsdp.api.StateDictSettings
+.. autoclass:: torch.distributed.fsdp.StateDictSettings
     :members:


@@ -2,12 +2,18 @@ from .flat_param import FlatParameter
 from .fully_sharded_data_parallel import (
     BackwardPrefetch,
     CPUOffload,
+    FullOptimStateDictConfig,
     FullStateDictConfig,
     FullyShardedDataParallel,
+    LocalOptimStateDictConfig,
     LocalStateDictConfig,
     MixedPrecision,
+    OptimStateDictConfig,
     OptimStateKeyType,
+    ShardedOptimStateDictConfig,
+    ShardedStateDictConfig,
     ShardingStrategy,
     StateDictConfig,
+    StateDictSettings,
     StateDictType,
 )


@@ -284,8 +284,8 @@ class StateDictConfig:
             values to CPU, and if ``False``, then FSDP keeps them on GPU.
             (Default: ``False``)
         use_dtensor (bool): If ``True``, then FSDP saves the state dict values
-            as ``DTensor``, and if ``False``, then FSDP saves them as
-            ``ShardedTensor``. (Default: ``False``)
+            as ``DTensor`` if the value is sharded, and if ``False``, then FSDP
+            saves them as ``ShardedTensor``. (Default: ``False``)
     """
     offload_to_cpu: bool = False
@@ -353,8 +353,8 @@ class OptimStateDictConfig:
             original device (which is GPU unless parameter CPU offloading is
             enabled). (Default: ``True``)
         use_dtensor (bool): If ``True``, then FSDP saves the state dict values
-            as ``DTensor``, and if ``False``, then FSDP saves them as
-            ``ShardedTensor``. (Default: ``False``)
+            as ``DTensor`` if the value is sharded, and if ``False``, then FSDP
+            saves them as ``ShardedTensor``. (Default: ``False``)
     """
     # TODO: actually use this flag in the _optim_utils.py
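
For completeness, a small hedged illustration of reading these settings back through the public API (assumes the FSDP-wrapped model configured in the sketch near the top of this commit message):

```python
# Hedged illustration: StateDictSettings is one of the classes exposed above,
# and FSDP.get_state_dict_type returns it for an already-configured module.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

settings = FSDP.get_state_dict_type(model)  # a StateDictSettings instance
print(settings.state_dict_type)             # e.g. StateDictType.FULL_STATE_DICT
print(settings.optim_state_dict_config.offload_to_cpu)
```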