[FSDP2] provide public API to share cuda streams across roots (#165024)

For pipeline parallelism, a model can have multiple FSDP roots (one per chunk):
```
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

Calling `share_comm_ctx` shares the all-gather, reduce-scatter, and all-reduce CUDA streams across these roots, which avoids inter-stream memory fragmentation:
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model[0], model[1]])
```
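
For a fuller picture, here is a minimal end-to-end sketch of the intended pipeline-parallel usage. The chunk modules and dimensions are placeholders, and the snippet assumes a default process group is already initialized; `fully_shard` and `share_comm_ctx` are used exactly as above.
```
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Placeholder chunks standing in for real pipeline stages; assumes
# torch.distributed is already initialized (e.g. via init_process_group).
chunk0 = nn.Linear(16, 16)
chunk1 = nn.Linear(16, 16)
model = nn.Sequential(chunk0, chunk1)

# Each chunk becomes its own FSDP root.
for chunk in model:
    fully_shard(chunk)

# Share the all-gather / reduce-scatter / all-reduce streams across the roots.
share_comm_ctx(list(model))
```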

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`
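
As a quick sanity check, the roots can be verified to point at the same communication context object. This is a sketch relying on the private `_get_fsdp_state()._comm_ctx` attributes that `share_comm_ctx` rewires (see the implementation diff below), not a public API.
```
# All FSDP roots should now share one communication context
# (private attributes, shown for illustration only).
states = [chunk._get_fsdp_state() for chunk in model]
assert all(state._comm_ctx is states[0]._comm_ctx for state in states)
```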


Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
Author: Wei Feng
Date: 2025-10-13 14:03:57 -07:00
Committed by: PyTorch MergeBot
Parent: 9b6be53326
Commit: 6918f17114
6 changed files with 192 additions and 1 deletion

View File

@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
.. autoclass:: CPUOffloadPolicy
:members:
```
```{eval-rst}
.. autofunction:: share_comm_ctx
```

View File

@ -6,7 +6,7 @@ import functools
import itertools
import unittest
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union
import torch
@ -24,6 +24,11 @@ from torch.distributed.fsdp import (
fully_shard,
OffloadPolicy,
register_fsdp_forward_method,
share_comm_ctx,
)
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
foreach_all_gather,
foreach_reduce,
)
from torch.distributed.tensor import DTensor, init_device_mesh, Shard
from torch.distributed.tensor.debug import CommDebugMode
@ -39,6 +44,8 @@ from torch.testing._internal.common_fsdp import (
MLP,
MLPStack,
patch_all_gather,
patch_foreach_all_gather,
patch_foreach_reduce,
patch_reduce_scatter,
)
from torch.testing._internal.common_utils import (
@ -1487,6 +1494,116 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
check_sharded_parity(self, ref_model, model)
class TestFullyShardShareCommContext(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_share_comm_context(self):
torch.manual_seed(42)
n_layers = 3
lin_dim = 16
model = nn.Sequential(
*[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)]
)
ref_model = copy.deepcopy(model).to(device_type)
for layer in model:
fully_shard(layer)
layer._get_fsdp_state()._lazy_init()
share_comm_ctx(list(model))
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn(4, 3, lin_dim, device=device_type.type)
ref_loss = ref_model(inp).sum()
all_gather_streams = set()
reduce_scatter_streams = set()
from torch.distributed.fsdp._fully_shard._fsdp_api import (
AllGather,
ReduceScatter,
)
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
orig_foreach_all_gather = foreach_all_gather
def foreach_all_gather_with_assert(
fsdp_params: list[FSDPParam],
group: dist.ProcessGroup,
async_op: bool,
all_gather_copy_in_stream: torch.Stream,
all_gather_stream: torch.Stream,
device: torch.device,
all_gather_comm: AllGather,
):
nonlocal all_gather_streams
all_gather_streams.add(all_gather_stream)
return orig_foreach_all_gather(
fsdp_params,
group,
async_op,
all_gather_copy_in_stream,
all_gather_stream,
device,
all_gather_comm,
)
orig_foreach_reduce = foreach_reduce
@torch.no_grad()
def foreach_reduce_with_assert(
fsdp_params: list[FSDPParam],
unsharded_grads: list[torch.Tensor],
reduce_scatter_group: dist.ProcessGroup,
reduce_scatter_stream: torch.Stream,
reduce_scatter_comm: ReduceScatter,
orig_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
device: torch.device,
gradient_divide_factor: Optional[float],
all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP
all_reduce_stream: torch.Stream,
all_reduce_grads: bool,
partial_reduce_output: Optional[torch.Tensor], # only used for HSDP
all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
force_sum_reduction_for_comms: bool = False,
):
nonlocal reduce_scatter_streams
reduce_scatter_streams.add(reduce_scatter_stream)
return orig_foreach_reduce(
fsdp_params,
unsharded_grads,
reduce_scatter_group,
reduce_scatter_stream,
reduce_scatter_comm,
orig_dtype,
reduce_dtype,
device,
gradient_divide_factor,
all_reduce_group,
all_reduce_stream,
all_reduce_grads,
partial_reduce_output,
all_reduce_hook,
force_sum_reduction_for_comms,
)
with (
patch_foreach_all_gather(foreach_all_gather_with_assert),
patch_foreach_reduce(foreach_reduce_with_assert),
):
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)
ref_loss.backward()
loss.backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
self.assertEqual(len(all_gather_streams), 1)
self.assertEqual(len(reduce_scatter_streams), 1)
check_sharded_parity(self, ref_model, model)
class TestFullyShardWorldSize1(FSDPTest):
@property
def world_size(self) -> int:

View File

@ -6,6 +6,7 @@ from ._fully_shard import (
MixedPrecisionPolicy,
OffloadPolicy,
register_fsdp_forward_method,
share_comm_ctx,
UnshardHandle,
)
from .fully_sharded_data_parallel import (
@ -54,6 +55,7 @@ __all__ = [
"OffloadPolicy",
"register_fsdp_forward_method",
"UnshardHandle",
"share_comm_ctx",
]
# Set namespace for exposed private names
@ -64,3 +66,4 @@ MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
OffloadPolicy.__module__ = "torch.distributed.fsdp"
register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
UnshardHandle.__module__ = "torch.distributed.fsdp"
share_comm_ctx.__module__ = "torch.distributed.fsdp"

View File

@ -3,6 +3,7 @@ from ._fully_shard import (
FSDPModule,
fully_shard,
register_fsdp_forward_method,
share_comm_ctx,
UnshardHandle,
)
@ -15,4 +16,5 @@ __all__ = [
"OffloadPolicy",
"register_fsdp_forward_method",
"UnshardHandle",
"share_comm_ctx",
]

View File

@ -39,6 +39,7 @@ __all__ = [
"register_fsdp_forward_method",
"get_cls_to_fsdp_cls",
"disable_fsdp_module_new_init",
"share_comm_ctx",
]
@ -711,6 +712,34 @@ def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
)
def share_comm_ctx(modules: list[FSDPModule]) -> None:
"""
Share CUDA streams across multiple FSDPModules.
Example usage:
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([fsdp_model_1, fsdp_model_2, ...])
For Pipeline Parallelism (PP), each model chunk is an FSDP root. We want
to share the CUDA streams for all-gather, reduce-scatter, and all-reduce
across these roots. This avoids inter-stream memory fragmentation.
Args:
modules (List[FSDPModule]): modules to share CUDA streams across
"""
if len(modules) == 0:
return
for module in modules:
if not isinstance(module, FSDPModule):
raise ValueError(f"Expects list of FSDPModules but got {module}")
fsdp_states = [module._get_fsdp_state() for module in modules]
comm_ctx = fsdp_states[0]._comm_ctx
for fsdp_state in fsdp_states[1:]:
fsdp_state._comm_ctx = comm_ctx
if fsdp_param_group := fsdp_state._fsdp_param_group:
fsdp_param_group.comm_ctx = comm_ctx
def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
for module in modules:
if not isinstance(module, FSDPModule):

View File

@ -997,6 +997,42 @@ def patch_all_gather(new_all_gather_into_tensor: Callable):
dist.all_gather_into_tensor = orig_all_gather
@contextlib.contextmanager
def patch_foreach_all_gather(new_foreach_all_gather: Callable):
orig_foreach_all_gather = (
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
)
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
new_foreach_all_gather
)
try:
yield
finally:
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
orig_foreach_all_gather
)
@contextlib.contextmanager
def patch_foreach_reduce(new_foreach_reduce: Callable):
orig_foreach_foreach_reduce = (
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
)
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
new_foreach_reduce
)
try:
yield
finally:
dist.barrier()
torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
orig_foreach_foreach_reduce
)
@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
orig_reduce_scatter = dist.reduce_scatter_tensor