[FSDP2] provide public API to share cuda streams across roots (#165024)

For pipeline parallelism, we can have multiple FSDP roots (one per pipeline chunk):
```
from collections import OrderedDict
model = nn.Sequential(OrderedDict([("chunk0", chunk0), ("chunk1", chunk1)]))
fully_shard(model.chunk0)
fully_shard(model.chunk1)
```

We can call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across those roots, which avoids inter-stream memory fragmentation:
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model.chunk0, model.chunk1])
```
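Putting the two snippets together, a minimal end-to-end sketch could look like the following. This is illustrative only: it assumes the default process group and a CUDA device are already initialized, and the `nn.Linear` chunks stand in for real pipeline stages (they are not part of this PR).
```python
from collections import OrderedDict

import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Two pipeline chunks, each of which becomes its own FSDP root.
# The layers here are placeholders for real pipeline stages.
chunk0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
chunk1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
model = nn.Sequential(OrderedDict([("chunk0", chunk0), ("chunk1", chunk1)]))

# Shard each chunk independently -> two FSDP roots under one parent module.
fully_shard(model.chunk0)
fully_shard(model.chunk1)

# Share the all-gather / reduce-scatter / all-reduce CUDA streams across the
# roots so their communication allocations come from the same streams rather
# than from separate per-root streams.
share_comm_ctx([model.chunk0, model.chunk1])
```
Since the CUDA caching allocator keeps memory segregated per stream, giving every root its own communication streams can strand blocks in per-root pools; sharing the streams lets all roots draw from the same pools, which is the fragmentation the PR description refers to.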

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
@@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
.. autoclass:: CPUOffloadPolicy
:members:
```
```{eval-rst}
.. autofunction:: share_comm_ctx
```