[FSDP2] Provide a public API to share CUDA streams across roots (#165024)

For pipeline parallelism, we can have multiple FSDP roots (one per model chunk):
```
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

We can call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across these roots. This avoids inter-stream memory fragmentation:
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model[0], model[1]])
```
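
Putting the two snippets together, here is a minimal end-to-end sketch. It assumes `torch.distributed` is already initialized with an accelerator backend (e.g. launched via `torchrun`); the chunk definitions and sizes are illustrative placeholders, not code from this PR:
```
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Illustrative pipeline-parallel model chunks (placeholders).
chunk0 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
chunk1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
model = nn.Sequential(chunk0, chunk1)

# Sharding each chunk separately makes each one its own FSDP root.
fully_shard(model[0])
fully_shard(model[1])

# Reuse a single set of all-gather / reduce-scatter / all-reduce streams
# across the roots instead of allocating one set per root.
share_comm_ctx([model[0], model[1]])
```
In a real pipeline-parallel setup the chunks would be the per-stage submodules produced by the pipeline split.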

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
Author: Wei Feng
Date: 2025-10-13 14:03:57 -07:00
Committed by: PyTorch MergeBot
Parent: 9b6be53326
Commit: 6918f17114
6 changed files with 192 additions and 1 deletion

View File

@@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
 .. autoclass:: CPUOffloadPolicy
     :members:
 ```
+
+```{eval-rst}
+.. autofunction:: share_comm_ctx
+```

View File

@@ -6,7 +6,7 @@ import functools
 import itertools
 import unittest
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union

 import torch
@@ -24,6 +24,11 @@ from torch.distributed.fsdp import (
     fully_shard,
     OffloadPolicy,
     register_fsdp_forward_method,
+    share_comm_ctx,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
+    foreach_all_gather,
+    foreach_reduce,
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
@@ -39,6 +44,8 @@ from torch.testing._internal.common_fsdp import (
     MLP,
     MLPStack,
     patch_all_gather,
+    patch_foreach_all_gather,
+    patch_foreach_reduce,
     patch_reduce_scatter,
 )
 from torch.testing._internal.common_utils import (
@@ -1487,6 +1494,116 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
         check_sharded_parity(self, ref_model, model)


+class TestFullyShardShareCommContext(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(torch.get_device_module(device_type).device_count(), 2)
+
+    @skip_if_lt_x_gpu(2)
+    def test_share_comm_context(self):
+        torch.manual_seed(42)
+        n_layers = 3
+        lin_dim = 16
+        model = nn.Sequential(
+            *[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)]
+        )
+        ref_model = copy.deepcopy(model).to(device_type)
+        for layer in model:
+            fully_shard(layer)
+            layer._get_fsdp_state()._lazy_init()
+        share_comm_ctx(list(model))
+        torch.manual_seed(42 + self.rank + 1)
+        inp = torch.randn(4, 3, lin_dim, device=device_type.type)
+        ref_loss = ref_model(inp).sum()
+
+        all_gather_streams = set()
+        reduce_scatter_streams = set()
+
+        from torch.distributed.fsdp._fully_shard._fsdp_api import (
+            AllGather,
+            ReduceScatter,
+        )
+        from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+        orig_foreach_all_gather = foreach_all_gather
+
+        def foreach_all_gather_with_assert(
+            fsdp_params: list[FSDPParam],
+            group: dist.ProcessGroup,
+            async_op: bool,
+            all_gather_copy_in_stream: torch.Stream,
+            all_gather_stream: torch.Stream,
+            device: torch.device,
+            all_gather_comm: AllGather,
+        ):
+            nonlocal all_gather_streams
+            all_gather_streams.add(all_gather_stream)
+            return orig_foreach_all_gather(
+                fsdp_params,
+                group,
+                async_op,
+                all_gather_copy_in_stream,
+                all_gather_stream,
+                device,
+                all_gather_comm,
+            )
+
+        orig_foreach_reduce = foreach_reduce
+
+        @torch.no_grad()
+        def foreach_reduce_with_assert(
+            fsdp_params: list[FSDPParam],
+            unsharded_grads: list[torch.Tensor],
+            reduce_scatter_group: dist.ProcessGroup,
+            reduce_scatter_stream: torch.Stream,
+            reduce_scatter_comm: ReduceScatter,
+            orig_dtype: Optional[torch.dtype],
+            reduce_dtype: Optional[torch.dtype],
+            device: torch.device,
+            gradient_divide_factor: Optional[float],
+            all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
+            all_reduce_stream: torch.Stream,
+            all_reduce_grads: bool,
+            partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
+            all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
+            force_sum_reduction_for_comms: bool = False,
+        ):
+            nonlocal reduce_scatter_streams
+            reduce_scatter_streams.add(reduce_scatter_stream)
+            return orig_foreach_reduce(
+                fsdp_params,
+                unsharded_grads,
+                reduce_scatter_group,
+                reduce_scatter_stream,
+                reduce_scatter_comm,
+                orig_dtype,
+                reduce_dtype,
+                device,
+                gradient_divide_factor,
+                all_reduce_group,
+                all_reduce_stream,
+                all_reduce_grads,
+                partial_reduce_output,
+                all_reduce_hook,
+                force_sum_reduction_for_comms,
+            )
+
+        with (
+            patch_foreach_all_gather(foreach_all_gather_with_assert),
+            patch_foreach_reduce(foreach_reduce_with_assert),
+        ):
+            loss = model(inp).sum()
+            self.assertEqual(ref_loss, loss)
+            ref_loss.backward()
+            loss.backward()
+        for param in ref_model.parameters():
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
+        self.assertEqual(len(all_gather_streams), 1)
+        self.assertEqual(len(reduce_scatter_streams), 1)
+        check_sharded_parity(self, ref_model, model)
+
+
 class TestFullyShardWorldSize1(FSDPTest):
     @property
     def world_size(self) -> int:

View File

@@ -6,6 +6,7 @@ from ._fully_shard import (
     MixedPrecisionPolicy,
     OffloadPolicy,
     register_fsdp_forward_method,
+    share_comm_ctx,
     UnshardHandle,
 )
 from .fully_sharded_data_parallel import (
@@ -54,6 +55,7 @@ __all__ = [
     "OffloadPolicy",
     "register_fsdp_forward_method",
     "UnshardHandle",
+    "share_comm_ctx",
 ]

 # Set namespace for exposed private names
@@ -64,3 +66,4 @@ MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
 OffloadPolicy.__module__ = "torch.distributed.fsdp"
 register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
 UnshardHandle.__module__ = "torch.distributed.fsdp"
+share_comm_ctx.__module__ = "torch.distributed.fsdp"

View File

@@ -3,6 +3,7 @@ from ._fully_shard import (
     FSDPModule,
     fully_shard,
     register_fsdp_forward_method,
+    share_comm_ctx,
     UnshardHandle,
 )
@@ -15,4 +16,5 @@ __all__ = [
     "OffloadPolicy",
     "register_fsdp_forward_method",
     "UnshardHandle",
+    "share_comm_ctx",
 ]

View File

@@ -39,6 +39,7 @@ __all__ = [
     "register_fsdp_forward_method",
     "get_cls_to_fsdp_cls",
     "disable_fsdp_module_new_init",
+    "share_comm_ctx",
 ]
@@ -711,6 +712,34 @@ def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
         )


+def share_comm_ctx(modules: list[FSDPModule]) -> None:
+    """
+    Share CUDA streams across multiple FSDPModules.
+
+    Example usage:
+        from torch.distributed.fsdp import share_comm_ctx
+        share_comm_ctx([fsdp_model_1, fsdp_model_2, ...])
+
+    For Pipeline Parallelism (PP), each model chunk is an FSDP root. We want
+    to share the CUDA streams for all-gather, reduce-scatter, and all-reduce
+    across roots. This avoids inter-stream memory fragmentation.
+
+    Args:
+        modules (list[FSDPModule]): modules to share CUDA streams across
+    """
+    if len(modules) == 0:
+        return
+    for module in modules:
+        if not isinstance(module, FSDPModule):
+            raise ValueError(f"Expects list of FSDPModules but got {module}")
+    fsdp_states = [module._get_fsdp_state() for module in modules]
+    comm_ctx = fsdp_states[0]._comm_ctx
+    for fsdp_state in fsdp_states[1:]:
+        fsdp_state._comm_ctx = comm_ctx
+        if fsdp_param_group := fsdp_state._fsdp_param_group:
+            fsdp_param_group.comm_ctx = comm_ctx
+
+
 def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
     for module in modules:
         if not isinstance(module, FSDPModule):

View File

@@ -997,6 +997,42 @@ def patch_all_gather(new_all_gather_into_tensor: Callable):
         dist.all_gather_into_tensor = orig_all_gather


+@contextlib.contextmanager
+def patch_foreach_all_gather(new_foreach_all_gather: Callable):
+    orig_foreach_all_gather = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+        new_foreach_all_gather
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+            orig_foreach_all_gather
+        )
+
+
+@contextlib.contextmanager
+def patch_foreach_reduce(new_foreach_reduce: Callable):
+    orig_foreach_foreach_reduce = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+        new_foreach_reduce
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+            orig_foreach_foreach_reduce
+        )
+
+
 @contextlib.contextmanager
 def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
     orig_reduce_scatter = dist.reduce_scatter_tensor