[FSDP2] provide public API to share cuda streams across roots (#165024)
For pipeline parallelism, a model can have multiple FSDP roots (one per chunk):

```python
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

Calling `share_comm_ctx` shares the all-gather, reduce-scatter, and all-reduce CUDA streams across these roots, which avoids inter-stream memory fragmentation:

```python
from torch.distributed.fsdp import share_comm_ctx

share_comm_ctx([model[0], model[1]])
```

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
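For readers who want to try the new API end to end, here is a minimal, self-contained sketch built on the snippets above. The two `nn.Linear` modules standing in for pipeline chunks, the layer sizes, and the NCCL/`torchrun` launch setup are illustrative assumptions and not part of this PR; only `fully_shard` and `share_comm_ctx` come from the PR description.

```python
# Hedged sketch: chunk modules, sizes, and launch/setup are assumptions;
# only fully_shard and share_comm_ctx come from this PR's description.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx


def main() -> None:
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    # Two pipeline chunks, each sharded as its own FSDP root.
    model = nn.Sequential(
        nn.Linear(1024, 1024, device="cuda"),
        nn.Linear(1024, 1024, device="cuda"),
    )
    fully_shard(model[0])
    fully_shard(model[1])

    # Share the all-gather / reduce-scatter / all-reduce CUDA streams
    # across both roots to avoid inter-stream memory fragmentation.
    share_comm_ctx([model[0], model[1]])

    out = model(torch.randn(8, 1024, device="cuda"))
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

A launch command such as `torchrun --nproc-per-node=2 example.py` on a multi-GPU host would exercise the sketch; the command itself is an assumption of this example, not taken from the PR.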
Committed by: PyTorch MergeBot
Parent: 9b6be53326
Commit: 6918f17114
@@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
 .. autoclass:: CPUOffloadPolicy
     :members:
 ```
+
+```{eval-rst}
+.. autofunction:: share_comm_ctx
+```