[FSDP2] provide public API to share cuda streams across roots (#165024)

for pipeline parallel, we can have multiple FSDP roots (chunks)
```
model = nn.Sequential([chunk0, chunk1])
fully_shard(model.chunk0)
fully_shard(model.chunk1)
```

we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model.chunk0, model.chunk1])
```

unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
This commit is contained in:
Wei Feng
2025-10-13 14:03:57 -07:00
committed by PyTorch MergeBot
parent 9b6be53326
commit 6918f17114
6 changed files with 192 additions and 1 deletions

View File

@ -997,6 +997,42 @@ def patch_all_gather(new_all_gather_into_tensor: Callable):
dist.all_gather_into_tensor = orig_all_gather
@contextlib.contextmanager
def patch_foreach_all_gather(new_foreach_all_gather: Callable):
orig_foreach_all_gather = (
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
)
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
new_foreach_all_gather
)
try:
yield
finally:
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
orig_foreach_all_gather
)
@contextlib.contextmanager
def patch_foreach_reduce(new_foreach_reduce: Callable):
orig_foreach_foreach_reduce = (
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
)
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
new_foreach_reduce
)
try:
yield
finally:
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
orig_foreach_foreach_reduce
)
@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
orig_reduce_scatter = dist.reduce_scatter_tensor