Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-11 22:34:53 +08:00)
[dtensor][cp][experiment] add CP experimental API to choose rotate method (#142093)
**Summary**
This PR adds a new experimental API, `set_rotate_method`, for Context Parallel. The API lets users choose the communication method used for shard rotation: all-to-all or all-gather.

**Test**
`pytest test/distributed/_tensor/test_attention.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142093
Approved by: https://github.com/fegin
Committed by: PyTorch MergeBot
Parent: eb84788fee
Commit: bce07deb96
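As a quick illustration of the new API (not part of the commit itself), the sketch below sets the rotation method through `set_rotate_method` and checks the internal option that the tests in this PR assert against. The import path mirrors the test file's imports; `_cp_options` and `_RotateMethod` are internal names and may change in later releases.

```python
# Minimal sketch: switch the Context Parallel kv-shard rotation method.
# Import path follows the test diff below; _cp_options and _RotateMethod are
# internal names used by the tests.
from torch.distributed._tensor.experimental._attention import (
    _cp_options,
    _RotateMethod,
    set_rotate_method,
)

# The default is all-gather; opt into all-to-all for kv shard rotation.
set_rotate_method("alltoall")
assert _cp_options.rotate_method == _RotateMethod.ALL_TO_ALL

# Switch back to the default.
set_rotate_method("allgather")
assert _cp_options.rotate_method == _RotateMethod.ALL_GATHER
```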
Changes to `test/distributed/_tensor/test_attention.py` (the test named in the summary); the parametrized Ring Attention tests now go through the public API instead of assigning the internal option directly:

```diff
@@ -15,6 +15,7 @@ from torch.distributed._tensor.experimental._attention import (
     _RotateMethod,
     context_parallel,
     context_parallel_unshard,
+    set_rotate_method,
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import parallelize_module
@@ -48,6 +49,12 @@ if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
     backends.append(SDPBackend.EFFICIENT_ATTENTION)
 
 
+rotater_enum_to_str = {
+    _RotateMethod.ALL_GATHER: "allgather",
+    _RotateMethod.ALL_TO_ALL: "alltoall",
+}  # mapping from _RotateMethod enum to string
+
+
 class RingAttentionTest(DTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -76,7 +83,8 @@ class RingAttentionTest(DTensorTestBase):
         load_balance: bool,
         rotater: _RotateMethod,
     ) -> None:
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
         dtype = torch.bfloat16
         bs = 8
@@ -230,7 +238,8 @@ class RingAttentionTest(DTensorTestBase):
         self, is_causal: bool, rotater: _RotateMethod
     ) -> None:
         _cp_options.enable_load_balance = is_causal
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(
             self.device_type,
             torch.arange(0, self.world_size),
@@ -314,7 +323,8 @@ class RingAttentionTest(DTensorTestBase):
     @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
     @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL])
     def test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None:
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(
             self.device_type,
             torch.arange(0, self.world_size),
```
Changes to the Context Parallel attention module (where `context_parallel`, `context_parallel_unshard`, and the new `set_rotate_method` live):

```diff
@@ -31,7 +31,7 @@ from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
 
-__all__ = ["context_parallel"]
+__all__ = ["context_parallel", "set_rotate_method"]
 
 
 class _CausalBehavior(Enum):
@@ -1284,6 +1284,15 @@ def context_parallel_unshard(
 ) -> List[torch.Tensor]:
     """
     Unshard the tensors (e.g., output) that are sharded due to context parallelism.
+
+    Args:
+        mesh (:class:`DeviceMesh`): the device mesh for the context parallelism.
+        buffers (List[torch.Tensor]): the buffers to be unsharded.
+        seq_dims (List[int]): the sequence dimensions of ``buffers``. This list
+            must have the same length as ``buffers``.
+
+    Returns:
+        List[torch.Tensor]: the unsharded buffers.
     """
     sharder = (
         _RoundRobinLoadBalancer
@@ -1291,3 +1300,30 @@ def context_parallel_unshard(
         else _SequentialSharder
     )
     return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]
+
+
+def set_rotate_method(rotate_method: str) -> None:
+    """
+    Context Parallel SDPA requires the rotation of kv shards. Users can call this
+    API to specify which rotation method to use. "alltoall" shuffles the kv shards
+    using all-to-all collective. While "allgather" gathers the kv shards using
+    all-gather collective after the first sub-SDPA computation. If this API has not
+    been called, the default rotate method is "allgather".
+
+    Args:
+        rotate_method (str): the rotate method to use. Currently only supports
+        "allgather" and "alltoall". If a different string other than these two
+        is passed in, the function will raise an error.
+
+    Returns:
+        None
+    """
+    if rotate_method == "allgather":
+        _cp_options.rotate_method = _RotateMethod.ALL_GATHER
+    elif rotate_method == "alltoall":
+        _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL
+    else:
+        raise NotImplementedError(
+            "Context Parallel does not support "
+            f"using {rotate_method} for kv shards rotation"
+        )
```
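To show how the pieces fit together, here is a hypothetical end-to-end sketch (not part of the PR). Only `set_rotate_method` and `context_parallel_unshard(mesh, buffers, seq_dims)` come from the diff above; the 1-D mesh setup and the `buffers`/`buffer_seq_dims` keyword arguments to `context_parallel` are assumptions, and the code presumes an already initialized distributed process group with one GPU per rank.

```python
# Hypothetical end-to-end sketch (assumptions noted in the lead-in): rotate kv
# shards with all-to-all, run SDPA under context parallelism, then unshard.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor.experimental._attention import (
    context_parallel,
    context_parallel_unshard,
    set_rotate_method,
)


def cp_sdpa_demo(world_size: int) -> torch.Tensor:
    # One context-parallel rank per GPU; assumes torch.distributed is initialized.
    mesh = init_device_mesh("cuda", (world_size,))

    # Pick all-to-all rotation before entering the context; omitting this call
    # leaves the default, "allgather".
    set_rotate_method("alltoall")

    # (batch, heads, seq, head_dim); the sequence dimension is dim 2.
    q, k, v = (
        torch.randn(8, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    )

    # Shard q/k/v along the sequence dimension for the duration of the context
    # (keyword names buffers/buffer_seq_dims are assumed here).
    with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # The output is still sharded along the sequence dimension; recover the
    # full-sequence tensor with the API documented in the diff above.
    (full_out,) = context_parallel_unshard(mesh, [out], [2])
    return full_out
```

Per the docstring in the diff, "allgather" issues an all-gather of the kv shards after the first sub-SDPA step, while "alltoall" shuffles the shards between ranks with all-to-all; which is faster is likely to depend on the interconnect and problem size, which is why the PR exposes the choice as a runtime knob.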