diff --git a/docs/source/fsdp.rst b/docs/source/fsdp.rst
index c1d2b4e6431a..41883e3c6ed2 100644
--- a/docs/source/fsdp.rst
+++ b/docs/source/fsdp.rst
@@ -18,29 +18,29 @@ FullyShardedDataParallel
 .. autoclass:: torch.distributed.fsdp.CPUOffload
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.StateDictConfig
+.. autoclass:: torch.distributed.fsdp.StateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.FullStateDictConfig
+.. autoclass:: torch.distributed.fsdp.FullStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.ShardedStateDictConfig
+.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.LocalStateDictConfig
+.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.OptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.FullOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.ShardedOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.LocalOptimStateDictConfig
+.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig
   :members:
 
-.. autoclass:: torch.distributed.fsdp.api.StateDictSettings
+.. autoclass:: torch.distributed.fsdp.StateDictSettings
   :members:
diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py
index b1bffdb25a0e..3bb19d5d9182 100644
--- a/torch/distributed/fsdp/__init__.py
+++ b/torch/distributed/fsdp/__init__.py
@@ -2,12 +2,18 @@ from .flat_param import FlatParameter
 from .fully_sharded_data_parallel import (
     BackwardPrefetch,
     CPUOffload,
+    FullOptimStateDictConfig,
     FullStateDictConfig,
     FullyShardedDataParallel,
+    LocalOptimStateDictConfig,
     LocalStateDictConfig,
     MixedPrecision,
+    OptimStateDictConfig,
     OptimStateKeyType,
+    ShardedOptimStateDictConfig,
     ShardedStateDictConfig,
     ShardingStrategy,
+    StateDictConfig,
+    StateDictSettings,
     StateDictType,
 )
diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py
index 1216c0ccdd1d..3eb106e96c1d 100644
--- a/torch/distributed/fsdp/api.py
+++ b/torch/distributed/fsdp/api.py
@@ -284,8 +284,8 @@ class StateDictConfig:
             values to CPU, and if ``False``, then FSDP keeps them on GPU.
             (Default: ``False``)
         use_dtensor (bool): If ``True``, then FSDP saves the state dict values
-            as ``DTensor``, and if ``False``, then FSDP saves them as
-            ``ShardedTensor``. (Default: ``False``)
+            as ``DTensor`` if the value is sharded, and if ``False``, then FSDP
+            saves them as ``ShardedTensor``. (Default: ``False``)
     """
 
     offload_to_cpu: bool = False
@@ -353,8 +353,8 @@ class OptimStateDictConfig:
             original device (which is GPU unless parameter CPU offloading is
             enabled). (Default: ``True``)
         use_dtensor (bool): If ``True``, then FSDP saves the state dict values
-            as ``DTensor``, and if ``False``, then FSDP saves them as
-            ``ShardedTensor``. (Default: ``False``)
+            as ``DTensor`` if the value is sharded, and if ``False``, then FSDP
+            saves them as ``ShardedTensor``. (Default: ``False``)
     """
 
     # TODO: actually use this flag in the _optim_utils.py
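
Usage sketch (not part of the diff): with these names re-exported from the
package root, the state dict configs can be imported from
``torch.distributed.fsdp`` directly instead of ``torch.distributed.fsdp.api``.
The snippet below is a minimal illustration only; it assumes a process group
is already initialized and that ``model`` is an ``nn.Module`` placed on the
current CUDA device.

    import torch.nn as nn
    from torch.distributed.fsdp import (
        FullOptimStateDictConfig,
        FullStateDictConfig,
        FullyShardedDataParallel as FSDP,
        StateDictType,
    )

    def save_full_state_dict(model: nn.Module):
        # Wrap the module with FSDP (uses the default process group).
        fsdp_model = FSDP(model)
        # Gather a full (unsharded) state dict, offloaded to CPU and
        # materialized only on rank 0.
        state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(
            fsdp_model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg
        ):
            return fsdp_model.state_dict()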