[dtensor][cp][experiment] add CP experimental API to choose rotate method (#142093)

**Summary**
This PR adds a new experimental API `set_rotate_method` for Context Parallel. It lets users choose the communication method used for kv shard rotation: all-to-all or all-gather.
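
A minimal usage sketch (the import path follows the test diff in this commit; any string other than the two supported values raises `NotImplementedError`):

```python
from torch.distributed._tensor.experimental._attention import set_rotate_method

# Rotate kv shards with all-to-all; "allgather" remains the default when this
# API is never called.
set_rotate_method("alltoall")
```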

**Test**
`pytest test/distributed/_tensor/test_attention.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142093
Approved by: https://github.com/fegin
Author: Xilun Wu
Date: 2024-12-09 17:37:48 -08:00
Committed by: PyTorch MergeBot
Parent: eb84788fee
Commit: bce07deb96
2 changed files with 50 additions and 4 deletions

View File

@@ -15,6 +15,7 @@ from torch.distributed._tensor.experimental._attention import (
     _RotateMethod,
     context_parallel,
     context_parallel_unshard,
+    set_rotate_method,
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import parallelize_module
@@ -48,6 +49,12 @@ if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
     backends.append(SDPBackend.EFFICIENT_ATTENTION)
 
+rotater_enum_to_str = {
+    _RotateMethod.ALL_GATHER: "allgather",
+    _RotateMethod.ALL_TO_ALL: "alltoall",
+}  # mapping from _RotateMethod enum to string
 
 
 class RingAttentionTest(DTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -76,7 +83,8 @@ class RingAttentionTest(DTensorTestBase):
         load_balance: bool,
         rotater: _RotateMethod,
     ) -> None:
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
         dtype = torch.bfloat16
         bs = 8
@@ -230,7 +238,8 @@ class RingAttentionTest(DTensorTestBase):
         self, is_causal: bool, rotater: _RotateMethod
     ) -> None:
         _cp_options.enable_load_balance = is_causal
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(
             self.device_type,
             torch.arange(0, self.world_size),
@@ -314,7 +323,8 @@ class RingAttentionTest(DTensorTestBase):
     @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
     @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL])
     def test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None:
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(
             self.device_type,
             torch.arange(0, self.world_size),

View File

@@ -31,7 +31,7 @@ from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shar
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
-__all__ = ["context_parallel"]
+__all__ = ["context_parallel", "set_rotate_method"]
 
 
 class _CausalBehavior(Enum):
@@ -1284,6 +1284,15 @@
 ) -> List[torch.Tensor]:
     """
     Unshard the tensors (e.g., output) that are sharded due to context parallelism.
+
+    Args:
+        mesh (:class:`DeviceMesh`): the device mesh for the context parallelism.
+        buffers (List[torch.Tensor]): the buffers to be unsharded.
+        seq_dims (List[int]): the sequence dimensions of ``buffers``. This list
+            must have the same length as ``buffers``.
+
+    Returns:
+        List[torch.Tensor]: the unsharded buffers.
     """
     sharder = (
         _RoundRobinLoadBalancer
@@ -1291,3 +1300,30 @@
         else _SequentialSharder
     )
     return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]
+
+
+def set_rotate_method(rotate_method: str) -> None:
+    """
+    Context Parallel SDPA requires the rotation of kv shards. Users can call this
+    API to specify which rotation method to use. "alltoall" shuffles the kv shards
+    using the all-to-all collective, while "allgather" gathers the kv shards using
+    the all-gather collective after the first sub-SDPA computation. If this API has
+    not been called, the default rotate method is "allgather".
+
+    Args:
+        rotate_method (str): the rotate method to use. Currently only "allgather"
+            and "alltoall" are supported; any other string raises an error.
+
+    Returns:
+        None
+    """
+    if rotate_method == "allgather":
+        _cp_options.rotate_method = _RotateMethod.ALL_GATHER
+    elif rotate_method == "alltoall":
+        _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL
+    else:
+        raise NotImplementedError(
+            "Context Parallel does not support "
+            f"using {rotate_method} for kv shards rotation"
+        )