[FSDP2] provide public API to share cuda streams across roots (#165024)
For pipeline parallelism, we can have multiple FSDP roots (chunks):

```
model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])
```

We can call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across those roots, which avoids inter-stream memory fragmentation:

```
from torch.distributed.fsdp import share_comm_ctx

share_comm_ctx([model[0], model[1]])
```

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
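As a minimal end-to-end sketch of that workflow, assuming a process group is already initialized: the chunk modules and sizes below are placeholders, and the only calls taken from this PR are `fully_shard` and `share_comm_ctx`.

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Placeholder pipeline chunks; real chunks would be the per-stage submodules.
chunk0 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
chunk1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
model = nn.Sequential(chunk0, chunk1)

# Each chunk is sharded separately, so each one becomes its own FSDP root.
fully_shard(model[0])
fully_shard(model[1])

# Point every root at the same all-gather / reduce-scatter / all-reduce
# streams instead of letting each root allocate its own.
share_comm_ctx([model[0], model[1]])
```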
Committed by: PyTorch MergeBot
Parent: 9b6be53326
Commit: 6918f17114
````diff
@@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
 .. autoclass:: CPUOffloadPolicy
     :members:
 ```
+
+```{eval-rst}
+.. autofunction:: share_comm_ctx
+```
````
```diff
@@ -6,7 +6,7 @@ import functools
 import itertools
 import unittest
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
 
 import torch
@@ -24,6 +24,11 @@ from torch.distributed.fsdp import (
     fully_shard,
     OffloadPolicy,
     register_fsdp_forward_method,
+    share_comm_ctx,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
+    foreach_all_gather,
+    foreach_reduce,
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
@@ -39,6 +44,8 @@ from torch.testing._internal.common_fsdp import (
     MLP,
     MLPStack,
     patch_all_gather,
+    patch_foreach_all_gather,
+    patch_foreach_reduce,
     patch_reduce_scatter,
 )
 from torch.testing._internal.common_utils import (
@@ -1487,6 +1494,116 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
         check_sharded_parity(self, ref_model, model)
 
 
+class TestFullyShardShareCommContext(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(torch.get_device_module(device_type).device_count(), 2)
+
+    @skip_if_lt_x_gpu(2)
+    def test_share_comm_context(self):
+        torch.manual_seed(42)
+        n_layers = 3
+        lin_dim = 16
+        model = nn.Sequential(
+            *[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)]
+        )
+        ref_model = copy.deepcopy(model).to(device_type)
+        for layer in model:
+            fully_shard(layer)
+            layer._get_fsdp_state()._lazy_init()
+        share_comm_ctx(list(model))
+
+        torch.manual_seed(42 + self.rank + 1)
+        inp = torch.randn(4, 3, lin_dim, device=device_type.type)
+        ref_loss = ref_model(inp).sum()
+
+        all_gather_streams = set()
+        reduce_scatter_streams = set()
+
+        from torch.distributed.fsdp._fully_shard._fsdp_api import (
+            AllGather,
+            ReduceScatter,
+        )
+        from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+        orig_foreach_all_gather = foreach_all_gather
+
+        def foreach_all_gather_with_assert(
+            fsdp_params: list[FSDPParam],
+            group: dist.ProcessGroup,
+            async_op: bool,
+            all_gather_copy_in_stream: torch.Stream,
+            all_gather_stream: torch.Stream,
+            device: torch.device,
+            all_gather_comm: AllGather,
+        ):
+            nonlocal all_gather_streams
+            all_gather_streams.add(all_gather_stream)
+            return orig_foreach_all_gather(
+                fsdp_params,
+                group,
+                async_op,
+                all_gather_copy_in_stream,
+                all_gather_stream,
+                device,
+                all_gather_comm,
+            )
+
+        orig_foreach_reduce = foreach_reduce
+
+        @torch.no_grad()
+        def foreach_reduce_with_assert(
+            fsdp_params: list[FSDPParam],
+            unsharded_grads: list[torch.Tensor],
+            reduce_scatter_group: dist.ProcessGroup,
+            reduce_scatter_stream: torch.Stream,
+            reduce_scatter_comm: ReduceScatter,
+            orig_dtype: Optional[torch.dtype],
+            reduce_dtype: Optional[torch.dtype],
+            device: torch.device,
+            gradient_divide_factor: Optional[float],
+            all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
+            all_reduce_stream: torch.Stream,
+            all_reduce_grads: bool,
+            partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
+            all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
+            force_sum_reduction_for_comms: bool = False,
+        ):
+            nonlocal reduce_scatter_streams
+            reduce_scatter_streams.add(reduce_scatter_stream)
+            return orig_foreach_reduce(
+                fsdp_params,
+                unsharded_grads,
+                reduce_scatter_group,
+                reduce_scatter_stream,
+                reduce_scatter_comm,
+                orig_dtype,
+                reduce_dtype,
+                device,
+                gradient_divide_factor,
+                all_reduce_group,
+                all_reduce_stream,
+                all_reduce_grads,
+                partial_reduce_output,
+                all_reduce_hook,
+                force_sum_reduction_for_comms,
+            )
+
+        with (
+            patch_foreach_all_gather(foreach_all_gather_with_assert),
+            patch_foreach_reduce(foreach_reduce_with_assert),
+        ):
+            loss = model(inp).sum()
+            self.assertEqual(ref_loss, loss)
+            ref_loss.backward()
+            loss.backward()
+        for param in ref_model.parameters():
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
+        self.assertEqual(len(all_gather_streams), 1)
+        self.assertEqual(len(reduce_scatter_streams), 1)
+        check_sharded_parity(self, ref_model, model)
+
+
 class TestFullyShardWorldSize1(FSDPTest):
     @property
     def world_size(self) -> int:
```
```diff
@@ -6,6 +6,7 @@ from ._fully_shard import (
     MixedPrecisionPolicy,
     OffloadPolicy,
     register_fsdp_forward_method,
+    share_comm_ctx,
     UnshardHandle,
 )
 from .fully_sharded_data_parallel import (
@@ -54,6 +55,7 @@ __all__ = [
     "OffloadPolicy",
     "register_fsdp_forward_method",
     "UnshardHandle",
+    "share_comm_ctx",
 ]
 
 # Set namespace for exposed private names
@@ -64,3 +66,4 @@ MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
 OffloadPolicy.__module__ = "torch.distributed.fsdp"
 register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
 UnshardHandle.__module__ = "torch.distributed.fsdp"
+share_comm_ctx.__module__ = "torch.distributed.fsdp"
```
```diff
@@ -3,6 +3,7 @@ from ._fully_shard import (
     FSDPModule,
     fully_shard,
     register_fsdp_forward_method,
+    share_comm_ctx,
     UnshardHandle,
 )
 
@@ -15,4 +16,5 @@ __all__ = [
     "OffloadPolicy",
     "register_fsdp_forward_method",
     "UnshardHandle",
+    "share_comm_ctx",
 ]
```
```diff
@@ -39,6 +39,7 @@ __all__ = [
     "register_fsdp_forward_method",
     "get_cls_to_fsdp_cls",
     "disable_fsdp_module_new_init",
+    "share_comm_ctx",
 ]
 
 
@@ -711,6 +712,34 @@ def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
     )
 
 
+def share_comm_ctx(modules: list[FSDPModule]) -> None:
+    """
+    Share CUDA streams across multiple FSDPModules.
+
+    Example usage:
+        from torch.distributed.fsdp import share_comm_ctx
+        share_comm_ctx([fsdp_model_1, fsdp_model_2, ...])
+
+    For pipeline parallelism (PP), each model chunk is an FSDP root. We want
+    to share the CUDA streams for all-gather, reduce-scatter, and all-reduce
+    across those roots, which avoids inter-stream memory fragmentation.
+
+    Args:
+        modules (List[FSDPModule]): modules that should share CUDA streams
+    """
+    if len(modules) == 0:
+        return
+    for module in modules:
+        if not isinstance(module, FSDPModule):
+            raise ValueError(f"Expects list of FSDPModules but got {module}")
+    fsdp_states = [module._get_fsdp_state() for module in modules]
+    comm_ctx = fsdp_states[0]._comm_ctx
+    for fsdp_state in fsdp_states[1:]:
+        fsdp_state._comm_ctx = comm_ctx
+        if fsdp_param_group := fsdp_state._fsdp_param_group:
+            fsdp_param_group.comm_ctx = comm_ctx
+
+
 def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
     for module in modules:
         if not isinstance(module, FSDPModule):
```
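Since `share_comm_ctx` works by aliasing every root's `_comm_ctx` to the first root's, a quick sanity check after calling it is to assert object identity across the roots. This sketch leans on the private `_get_fsdp_state()` and `_comm_ctx` attributes shown in the hunk above, so treat it as a debugging aid rather than a supported API:

```python
def assert_shared_comm_ctx(roots) -> None:
    # After share_comm_ctx(roots), every root's FSDP state should reference
    # the same communication context object (and therefore the same streams).
    states = [module._get_fsdp_state() for module in roots]
    first_ctx = states[0]._comm_ctx
    for state in states[1:]:
        assert state._comm_ctx is first_ctx, "communication context is not shared"
```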
```diff
@@ -997,6 +997,42 @@ def patch_all_gather(new_all_gather_into_tensor: Callable):
         dist.all_gather_into_tensor = orig_all_gather
 
 
+@contextlib.contextmanager
+def patch_foreach_all_gather(new_foreach_all_gather: Callable):
+    orig_foreach_all_gather = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+        new_foreach_all_gather
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+            orig_foreach_all_gather
+        )
+
+
+@contextlib.contextmanager
+def patch_foreach_reduce(new_foreach_reduce: Callable):
+    orig_foreach_foreach_reduce = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+        new_foreach_reduce
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+            orig_foreach_foreach_reduce
+        )
+
+
 @contextlib.contextmanager
 def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
     orig_reduce_scatter = dist.reduce_scatter_tensor
```