Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[FSDP][state_dict] Expose optimizer state_dict config (#105949)
The optimizer state_dict configs are not exposed. This PR exposes the two dataclasses.

Differential Revision: [D47766024](https://our.internmc.facebook.com/intern/diff/D47766024/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105949
Approved by: https://github.com/rohan-varma
Committed by: PyTorch MergeBot
Parent: 63e9b5481d
Commit: 7ba513b6e4
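For orientation (not part of the diff below), here is a minimal sketch of how the newly exposed optimizer state_dict configs are typically used together with `FullyShardedDataParallel.set_state_dict_type`. The `gather_full_state` helper and its `model`/`optim` arguments are placeholders for this sketch and assume an already-initialized distributed process group.

```python
import torch
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def gather_full_state(model: FSDP, optim: torch.optim.Optimizer):
    """Sketch: configure and collect full (unsharded) model/optimizer state.

    Assumes ``model`` is already wrapped with FSDP inside an initialized
    distributed process group.
    """
    # Configure model and optimizer state_dict behavior in one call. After this
    # PR, FullOptimStateDictConfig is importable from the package root instead
    # of torch.distributed.fsdp.api.
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        optim_state_dict_config=FullOptimStateDictConfig(
            offload_to_cpu=True, rank0_only=True
        ),
    )
    # Both configs apply when gathering state (materialized on rank 0 only).
    model_state = model.state_dict()
    optim_state = FSDP.optim_state_dict(model, optim)
    return model_state, optim_state
```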
docs/source/fsdp.rst
@@ -18,29 +18,29 @@ FullyShardedDataParallel
 .. autoclass:: torch.distributed.fsdp.CPUOffload
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.StateDictConfig
+.. autoclass:: torch.distributed.fsdp.StateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.FullStateDictConfig
+.. autoclass:: torch.distributed.fsdp.FullStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.ShardedStateDictConfig
+.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.LocalStateDictConfig
+.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.OptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.FullOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.ShardedOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.LocalOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig
     :members:
 
-.. autoclass:: torch.distributed.fsdp.api.StateDictSettings
+.. autoclass:: torch.distributed.fsdp.StateDictSettings
     :members:
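This docs change works because (together with the `__init__.py` change below) the shorter public path re-exports the very same classes defined in `torch.distributed.fsdp.api`, so Sphinx's `autoclass` resolves them either way. A quick, hedged sanity check:

```python
import torch.distributed.fsdp as fsdp
import torch.distributed.fsdp.api as fsdp_api

# The public package path and the internal api module refer to the same class
# objects, so documenting the shorter path does not change any behavior.
assert fsdp.StateDictConfig is fsdp_api.StateDictConfig
assert fsdp.OptimStateDictConfig is fsdp_api.OptimStateDictConfig
```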
torch/distributed/fsdp/__init__.py
@@ -2,12 +2,18 @@ from .flat_param import FlatParameter
 from .fully_sharded_data_parallel import (
     BackwardPrefetch,
     CPUOffload,
+    FullOptimStateDictConfig,
     FullStateDictConfig,
     FullyShardedDataParallel,
+    LocalOptimStateDictConfig,
     LocalStateDictConfig,
     MixedPrecision,
+    OptimStateDictConfig,
     OptimStateKeyType,
+    ShardedOptimStateDictConfig,
     ShardedStateDictConfig,
     ShardingStrategy,
     StateDictConfig,
+    StateDictSettings,
     StateDictType,
 )
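With the expanded re-export list above, the optimizer state_dict configs can be imported directly from the package root. A short sketch of what this enables (only the import path is new; the classes themselves already existed in `torch.distributed.fsdp.api`):

```python
# After this change these names resolve from the package root rather than from
# the internal torch.distributed.fsdp.api / fully_sharded_data_parallel modules.
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    LocalOptimStateDictConfig,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    StateDictSettings,
)

# The per-state_dict-type optimizer configs all derive from the base config.
assert issubclass(FullOptimStateDictConfig, OptimStateDictConfig)
assert issubclass(ShardedOptimStateDictConfig, OptimStateDictConfig)
assert issubclass(LocalOptimStateDictConfig, OptimStateDictConfig)
```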
torch/distributed/fsdp/api.py
@@ -284,8 +284,8 @@ class StateDictConfig:
             values to CPU, and if ``False``, then FSDP keeps them on GPU.
             (Default: ``False``)
         use_dtensor (bool): If ``True``, then FSDP saves the state dict values
-            as ``DTensor``, and if ``False``, then FSDP saves them as
-            ``ShardedTensor``. (Default: ``False``)
+            as ``DTensor`` if the value is sharded, and if ``False``, then FSDP
+            saves them as ``ShardedTensor``. (Default: ``False``)
     """
 
     offload_to_cpu: bool = False
torch/distributed/fsdp/api.py
@@ -353,8 +353,8 @@ class OptimStateDictConfig:
             original device (which is GPU unless parameter CPU offloading is
             enabled). (Default: ``True``)
         use_dtensor (bool): If ``True``, then FSDP saves the state dict values
-            as ``DTensor``, and if ``False``, then FSDP saves them as
-            ``ShardedTensor``. (Default: ``False``)
+            as ``DTensor`` if the value is sharded, and if ``False``, then FSDP
+            saves them as ``ShardedTensor``. (Default: ``False``)
     """
 
     # TODO: actually use this flag in the _optim_utils.py
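The reworded docstring clarifies that `use_dtensor` only affects values that are actually sharded. Below is a minimal, hedged sketch of toggling it on the sharded configs, assuming `use_dtensor` is accepted as a dataclass field exactly as documented in this commit (some PyTorch versions expose it as a private `_use_dtensor` attribute instead, so treat this purely as an illustration):

```python
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)


def use_dtensor_state_dicts(model: FSDP) -> None:
    """Sketch: save sharded state dict values as DTensor instead of ShardedTensor.

    Assumes ``model`` is FSDP-wrapped in an initialized process group and that
    ``use_dtensor`` is a public field, as documented in this commit.
    """
    FSDP.set_state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        # Per the clarified docstring, the flag only matters for values that
        # are actually sharded; unsharded values are unaffected.
        state_dict_config=ShardedStateDictConfig(use_dtensor=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(use_dtensor=True),
    )
```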