mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[FSDP2] provide public API to share cuda streams across roots (#165024)
for pipeline parallel, we can have multiple FSDP roots (chunks) ``` model = nn.Sequential([chunk0, chunk1]) fully_shard(model.chunk0) fully_shard(model.chunk1) ``` we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation ``` from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([model.chunk0, model.chunk1]) ``` unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context` Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024 Approved by: https://github.com/mori360
This commit is contained in:
committed by
PyTorch MergeBot
parent
9b6be53326
commit
6918f17114
@ -997,6 +997,42 @@ def patch_all_gather(new_all_gather_into_tensor: Callable):
|
||||
dist.all_gather_into_tensor = orig_all_gather
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_foreach_all_gather(new_foreach_all_gather: Callable):
|
||||
orig_foreach_all_gather = (
|
||||
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
|
||||
)
|
||||
dist.barrier()
|
||||
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
|
||||
new_foreach_all_gather
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
dist.barrier()
|
||||
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
|
||||
orig_foreach_all_gather
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_foreach_reduce(new_foreach_reduce: Callable):
|
||||
orig_foreach_foreach_reduce = (
|
||||
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
|
||||
)
|
||||
dist.barrier()
|
||||
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
|
||||
new_foreach_reduce
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
dist.barrier()
|
||||
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
|
||||
orig_foreach_foreach_reduce
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
Reference in New Issue
Block a user