torch.distributed.fsdp.fully_shard
==================================

PyTorch FSDP2 (``fully_shard``)
-------------------------------

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
that targets performant eager-mode execution while using per-parameter
sharding for improved usability.

- If you are new to FSDP, we recommend starting with FSDP2 because of its
  improved usability. See `TorchTitan
  <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for code
  examples.
- If you are currently using FSDP1, consider evaluating the following
  differences to see if you should switch to FSDP2:

Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):

- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
  sharding representation than FSDP1's flat-parameter sharding, while
  preserving similar throughput. More specifically, FSDP2 chunks each
  parameter on dim-0 across the data parallel workers (using
  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
  group of tensors together, which makes it hard to reason about what data is
  present on each worker and to reshard to a different parallelism.
  Per-parameter sharding provides a more intuitive user experience, relaxes
  constraints around frozen parameters, and allows for communication-free
  (sharded) state dicts, which otherwise require all-gathers in FSDP1. (See
  the first sketch after this list.)
- FSDP2 implements a different memory management approach for its
  multi-stream usage that avoids ``torch.Tensor.record_stream``. This ensures
  deterministic and expected memory usage and does not require blocking the
  CPU as FSDP1's ``limit_all_gathers=True`` does.
- FSDP2 exposes APIs for manual control over prefetching and collective
  scheduling, giving power users more customization. See the methods on
  ``FSDPModule`` below for details.
- FSDP2 simplifies some of the API surface: for example, FSDP2 does not
  directly support full state dicts. Instead, users can reshard the sharded
  state dicts containing ``DTensor``\ s to full state dicts themselves using
  ``DTensor`` APIs like ``DTensor.full_tensor()`` or by using higher-level
  APIs like `PyTorch Distributed Checkpoint
  <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_'s
  distributed state dict APIs (see the second sketch after this list). Also,
  some other arguments have been removed; see `here
  <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
  details.
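
As a first sketch of the difference in sharding layout, here is a minimal,
single-process analogy using plain ``torch.chunk`` (this illustrates the
layout only, not the FSDP2 internals):

.. code-block:: python

    import torch

    world_size = 4
    weight = torch.randn(10, 8)  # e.g. an ``nn.Linear`` weight

    # FSDP2-style: chunk each parameter individually on dim-0, so each rank
    # holds a recognizable row-slice of every parameter. (FSDP2 pads the
    # uneven last chunk internally so collectives see equal sizes.)
    shards = torch.chunk(weight, world_size, dim=0)
    print([tuple(s.shape) for s in shards])  # [(3, 8), (3, 8), (3, 8), (1, 8)]

    # FSDP1-style (conceptually): flatten and concatenate a group of
    # parameters, then chunk the 1D buffer; parameter boundaries are lost.
    bias = torch.randn(8)
    flat = torch.cat([weight.reshape(-1), bias])
    flat_shards = torch.chunk(flat, world_size, dim=0)
    print([tuple(s.shape) for s in flat_shards])  # [(22,), (22,), (22,), (22,)]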
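
As a second sketch, going from FSDP2's communication-free sharded state dict
to a full state dict via ``DTensor.full_tensor()``, assuming the script runs
under ``torchrun`` with the default process group initialized:

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard

    model = nn.Linear(16, 16)
    fully_shard(model)

    sharded_sd = model.state_dict()  # values are DTensors; no communication
    # full_tensor() all-gathers each sharded parameter into a plain tensor.
    full_sd = {k: v.full_tensor() for k, v in sharded_sd.items()}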

If you are onboarding FSDP for the first time or if any of the above appeals
to your use case, we recommend that you consider using FSDP2.

See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
on system design and implementation.

.. note::
   ``torch.distributed.fsdp.fully_shard`` is currently in a prototype state
   and under development. The core API is unlikely to change, but we may make
   API changes if necessary.

.. currentmodule:: torch.distributed.fsdp

The frontend API is ``fully_shard``, which can be called on a ``module``:

.. autofunction:: fully_shard
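
A minimal usage sketch, assuming the script runs under ``torchrun`` with the
default process group initialized (the model here is a placeholder):

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard

    model = nn.Sequential(
        nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)
    )
    # Apply ``fully_shard`` bottom-up: each call groups the parameters in
    # its subtree into one all-gather/reduce-scatter bucket.
    for submodule in model:
        if isinstance(submodule, nn.Linear):
            fully_shard(submodule)
    fully_shard(model)  # the root call picks up any remaining parameters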

Calling ``fully_shard(module)`` dynamically constructs a new class that
subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
constructs a new class ``FSDPLinear`` and changes ``linear``'s type to this.
Otherwise, ``fully_shard`` does not change the module structure or the
parameters' fully qualified names. The class ``FSDPModule`` provides some
FSDP-specific methods on the module.
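
For example (the printed class name is illustrative; this assumes an
initialized process group as above):

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule, fully_shard

    linear = nn.Linear(16, 16)
    fully_shard(linear)
    print(type(linear))  # a dynamically created subclass, e.g. FSDPLinear
    assert isinstance(linear, nn.Linear) and isinstance(linear, FSDPModule)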

.. autoclass:: FSDPModule
   :members:
   :member-order: bysource
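
For instance, here is a hedged sketch of the manual-control methods,
assuming ``layers`` is a list of modules that were each wrapped with
``fully_shard``:

.. code-block:: python

    # Explicit forward prefetching: while layer i runs forward, all-gather
    # the parameters of the next two layers.
    for i, layer in enumerate(layers):
        layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])

    # Manual collective scheduling: launch the all-gather early, wait on it
    # before use, and free the unsharded parameters afterward.
    handle = layers[0].unshard(async_op=True)  # returns an UnshardHandle
    handle.wait()
    # ... computation that needs layers[0]'s unsharded parameters ...
    layers[0].reshard()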

.. autoclass:: UnshardHandle
   :members:

.. autofunction:: register_fsdp_forward_method
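
A minimal sketch of registering a non-``forward`` method (here a hypothetical
``generate``) so that FSDP all-gathers parameters before it runs, just as for
``forward``:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method

    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

        @torch.no_grad()
        def generate(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x).argmax(dim=-1)

    model = TinyModel()
    fully_shard(model)
    # Without this, calling ``generate`` would see the sharded parameters.
    register_fsdp_forward_method(model, "generate")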

.. autoclass:: MixedPrecisionPolicy
   :members:
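
A sketch of enabling mixed precision, computing in bfloat16 while reducing
gradients in float32 (the model is a placeholder):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    model = nn.Linear(1024, 1024)
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # dtype of unsharded params for compute
        reduce_dtype=torch.float32,  # dtype for gradient reduce-scatter
    )
    fully_shard(model, mp_policy=mp_policy)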

.. autoclass:: OffloadPolicy
   :members:

.. autoclass:: CPUOffloadPolicy
   :members:
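
For example, a sketch of CPU offloading, which keeps sharded parameters,
gradients, and optimizer states on CPU (the model is a placeholder):

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

    model = nn.Linear(1024, 1024)
    fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))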
|