mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[FSDP2] provide public API to share cuda streams across roots (#165024)
For pipeline parallelism, we can have multiple FSDP roots (chunks):

```
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

We can call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across these roots; this avoids inter-stream memory fragmentation:

```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model[0], model[1]])
```

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
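For context, a minimal self-contained sketch of the usage above (not part of this commit; assumes a CUDA device, an already-initialized process group, e.g. launched via `torchrun`, and hypothetical `chunk0`/`chunk1` modules):

```
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Hypothetical pipeline-parallel chunks; any nn.Module works here.
chunk0 = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
chunk1 = nn.Sequential(nn.Linear(64, 64), nn.ReLU())

# Sharding each chunk separately makes each chunk its own FSDP root.
fully_shard(chunk0)
fully_shard(chunk1)

# Share the all-gather / reduce-scatter / all-reduce streams across roots
# so memory is not fragmented across per-root stream pools.
share_comm_ctx([chunk0, chunk1])
```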
Committed by: PyTorch MergeBot
Parent: 9b6be53326
Commit: 6918f17114
@@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
 .. autoclass:: CPUOffloadPolicy
     :members:
 ```
+
+```{eval-rst}
+.. autofunction:: share_comm_ctx
+```
@@ -6,7 +6,7 @@ import functools
 import itertools
 import unittest
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
 
 import torch
@@ -24,6 +24,11 @@ from torch.distributed.fsdp import (
     fully_shard,
     OffloadPolicy,
     register_fsdp_forward_method,
+    share_comm_ctx,
 )
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
+    foreach_all_gather,
+    foreach_reduce,
+)
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
@@ -39,6 +44,8 @@ from torch.testing._internal.common_fsdp import (
     MLP,
     MLPStack,
     patch_all_gather,
+    patch_foreach_all_gather,
+    patch_foreach_reduce,
     patch_reduce_scatter,
 )
 from torch.testing._internal.common_utils import (
@@ -1487,6 +1494,116 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
         check_sharded_parity(self, ref_model, model)
 
 
+class TestFullyShardShareCommContext(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(torch.get_device_module(device_type).device_count(), 2)
+
+    @skip_if_lt_x_gpu(2)
+    def test_share_comm_context(self):
+        torch.manual_seed(42)
+        n_layers = 3
+        lin_dim = 16
+        model = nn.Sequential(
+            *[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)]
+        )
+        ref_model = copy.deepcopy(model).to(device_type)
+        for layer in model:
+            fully_shard(layer)
+            layer._get_fsdp_state()._lazy_init()
+        share_comm_ctx(list(model))
+
+        torch.manual_seed(42 + self.rank + 1)
+        inp = torch.randn(4, 3, lin_dim, device=device_type.type)
+        ref_loss = ref_model(inp).sum()
+
+        all_gather_streams = set()
+        reduce_scatter_streams = set()
+
+        from torch.distributed.fsdp._fully_shard._fsdp_api import (
+            AllGather,
+            ReduceScatter,
+        )
+        from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+        orig_foreach_all_gather = foreach_all_gather
+
+        def foreach_all_gather_with_assert(
+            fsdp_params: list[FSDPParam],
+            group: dist.ProcessGroup,
+            async_op: bool,
+            all_gather_copy_in_stream: torch.Stream,
+            all_gather_stream: torch.Stream,
+            device: torch.device,
+            all_gather_comm: AllGather,
+        ):
+            nonlocal all_gather_streams
+            all_gather_streams.add(all_gather_stream)
+            return orig_foreach_all_gather(
+                fsdp_params,
+                group,
+                async_op,
+                all_gather_copy_in_stream,
+                all_gather_stream,
+                device,
+                all_gather_comm,
+            )
+
+        orig_foreach_reduce = foreach_reduce
+
+        @torch.no_grad()
+        def foreach_reduce_with_assert(
+            fsdp_params: list[FSDPParam],
+            unsharded_grads: list[torch.Tensor],
+            reduce_scatter_group: dist.ProcessGroup,
+            reduce_scatter_stream: torch.Stream,
+            reduce_scatter_comm: ReduceScatter,
+            orig_dtype: Optional[torch.dtype],
+            reduce_dtype: Optional[torch.dtype],
+            device: torch.device,
+            gradient_divide_factor: Optional[float],
+            all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
+            all_reduce_stream: torch.Stream,
+            all_reduce_grads: bool,
+            partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
+            all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
+            force_sum_reduction_for_comms: bool = False,
+        ):
+            nonlocal reduce_scatter_streams
+            reduce_scatter_streams.add(reduce_scatter_stream)
+            return orig_foreach_reduce(
+                fsdp_params,
+                unsharded_grads,
+                reduce_scatter_group,
+                reduce_scatter_stream,
+                reduce_scatter_comm,
+                orig_dtype,
+                reduce_dtype,
+                device,
+                gradient_divide_factor,
+                all_reduce_group,
+                all_reduce_stream,
+                all_reduce_grads,
+                partial_reduce_output,
+                all_reduce_hook,
+                force_sum_reduction_for_comms,
+            )
+
+        with (
+            patch_foreach_all_gather(foreach_all_gather_with_assert),
+            patch_foreach_reduce(foreach_reduce_with_assert),
+        ):
+            loss = model(inp).sum()
+            self.assertEqual(ref_loss, loss)
+            ref_loss.backward()
+            loss.backward()
+        for param in ref_model.parameters():
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
+        self.assertEqual(len(all_gather_streams), 1)
+        self.assertEqual(len(reduce_scatter_streams), 1)
+        check_sharded_parity(self, ref_model, model)
+
+
 class TestFullyShardWorldSize1(FSDPTest):
     @property
     def world_size(self) -> int:
@@ -6,6 +6,7 @@ from ._fully_shard import (
     MixedPrecisionPolicy,
     OffloadPolicy,
     register_fsdp_forward_method,
+    share_comm_ctx,
     UnshardHandle,
 )
 from .fully_sharded_data_parallel import (
@@ -54,6 +55,7 @@ __all__ = [
     "OffloadPolicy",
     "register_fsdp_forward_method",
     "UnshardHandle",
+    "share_comm_ctx",
 ]
 
 # Set namespace for exposed private names
@@ -64,3 +66,4 @@ MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
 OffloadPolicy.__module__ = "torch.distributed.fsdp"
 register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
 UnshardHandle.__module__ = "torch.distributed.fsdp"
+share_comm_ctx.__module__ = "torch.distributed.fsdp"
@@ -3,6 +3,7 @@ from ._fully_shard import (
     FSDPModule,
     fully_shard,
     register_fsdp_forward_method,
+    share_comm_ctx,
     UnshardHandle,
 )
 
@@ -15,4 +16,5 @@ __all__ = [
     "OffloadPolicy",
     "register_fsdp_forward_method",
     "UnshardHandle",
+    "share_comm_ctx",
 ]
@@ -39,6 +39,7 @@ __all__ = [
     "register_fsdp_forward_method",
     "get_cls_to_fsdp_cls",
     "disable_fsdp_module_new_init",
+    "share_comm_ctx",
 ]
 
 
@@ -711,6 +712,34 @@ def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
     )
 
 
+def share_comm_ctx(modules: list[FSDPModule]) -> None:
+    """
+    Share CUDA streams across multiple FSDPModules.
+
+    Example usage:
+        from torch.distributed.fsdp import share_comm_ctx
+        share_comm_ctx([fsdp_model_1, fsdp_model_2, ...])
+
+    For Pipeline Parallelism (PP), each model chunk is an FSDP root. We want
+    to share the CUDA streams for all-gather, reduce-scatter, and all-reduce
+    across these roots. This avoids inter-stream memory fragmentation.
+
+    Args:
+        modules (List[FSDPModule]): modules to share CUDA streams across
+    """
+    if len(modules) == 0:
+        return
+    for module in modules:
+        if not isinstance(module, FSDPModule):
+            raise ValueError(f"Expects list of FSDPModules but got {module}")
+    fsdp_states = [module._get_fsdp_state() for module in modules]
+    comm_ctx = fsdp_states[0]._comm_ctx
+    for fsdp_state in fsdp_states[1:]:
+        fsdp_state._comm_ctx = comm_ctx
+        if fsdp_param_group := fsdp_state._fsdp_param_group:
+            fsdp_param_group.comm_ctx = comm_ctx
+
+
 def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
     for module in modules:
         if not isinstance(module, FSDPModule):
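To make the effect of `share_comm_ctx` concrete, a small hedged sketch (not part of this commit; assumes a CUDA device and an already-initialized process group): after the call, every root's FSDP state references the same communication context, so lazy init allocates a single set of all-gather/reduce-scatter/all-reduce streams rather than one set per root.

```
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Two stand-in chunks; fully_shard swaps their class in place, so the list
# elements are FSDPModules afterwards.
chunks = [nn.Linear(16, 16), nn.Linear(16, 16)]
for chunk in chunks:
    fully_shard(chunk)
share_comm_ctx(chunks)

# All roots now alias one communication context (a private attribute,
# inspected here only for illustration).
states = [chunk._get_fsdp_state() for chunk in chunks]
assert all(state._comm_ctx is states[0]._comm_ctx for state in states)
```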
@@ -997,6 +997,42 @@ def patch_all_gather(new_all_gather_into_tensor: Callable):
         dist.all_gather_into_tensor = orig_all_gather
 
 
+@contextlib.contextmanager
+def patch_foreach_all_gather(new_foreach_all_gather: Callable):
+    orig_foreach_all_gather = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+        new_foreach_all_gather
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+            orig_foreach_all_gather
+        )
+
+
+@contextlib.contextmanager
+def patch_foreach_reduce(new_foreach_reduce: Callable):
+    orig_foreach_reduce = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+        new_foreach_reduce
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+            orig_foreach_reduce
+        )
+
+
 @contextlib.contextmanager
 def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
     orig_reduce_scatter = dist.reduce_scatter_tensor