https://github.com/pytorch/pytorch/pull/163617 removed the if/else statement that checks whether the input buffers have the batch dimension. This PR fixes the issue and adds a test. In the future, we should explicitly ask users to unsqueeze the batch dimension; that would be a BC-breaking change to the existing contract, but implicitly inferring whether the batch dimension exists is not safe. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792 Approved by: https://github.com/XilunWu
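As context for the contract discussed above, the sketch below shows a caller explicitly unsqueezing the batch dimension before handing buffers to context_parallel(), rather than relying on the context manager to infer whether one is present. The mesh shape, tensor shapes, and the work done inside the context are illustrative assumptions, not code from the PR.

    # Minimal sketch (assumed shapes, a 2-rank "cp" mesh, process group already initialized).
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel

    mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))

    seq_len, dim = 128, 16
    q = torch.randn(seq_len, dim, device="cuda")  # buffer without a batch dimension

    # Explicitly add the batch dimension: (seq_len, dim) -> (1, seq_len, dim).
    q = q.unsqueeze(0)

    with context_parallel(mesh, buffers=[q], buffer_seq_dims=[1]):
        ...  # run attention here; q is sharded along its sequence dimension

The mirrored test file follows below.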
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
import random
import unittest
from typing import Any, Callable, ClassVar, Optional

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental._attention import (
    _CausalBehavior,
    _context_parallel_shard,
    _ContextParallel,
    _cp_options,
    _disable_context_parallel_dispatcher,
    _enable_context_parallel_dispatcher,
    _is_causal_behavior,
    _RotateMethod,
    context_parallel,
    context_parallel_unshard,
    set_rotate_method,
)
from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
from torch.distributed.tensor.experimental._load_balancer import (
    _HeadTailLoadBalancer,
    _LoadBalancer,
    _PerDocumentHeadTailLoadBalancer,
    _PTRRLoadBalancer,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
    _mask_mod_signature,
    AuxOutput,
    AuxRequest,
    BlockMask,
    create_block_mask,
    flex_attention,
)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
    backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
    backends.append(SDPBackend.CUDNN_ATTENTION)

rotater_enum_to_str = {
    _RotateMethod.ALL_GATHER: "allgather",
    _RotateMethod.ALL_TO_ALL: "alltoall",
}  # mapping from _RotateMethod enum to string

class SDPAWrapper(torch.nn.Module):
    def __init__(self, compiled: bool, backend: SDPBackend) -> None:
        super().__init__()
        if compiled:
            self.sdpa = torch.compile(
                F.scaled_dot_product_attention,
                fullgraph=True,
                backend="aot_eager",
            )
        else:
            self.sdpa = F.scaled_dot_product_attention
        self.backend = backend

    def forward(self, *args: object, **kwargs: object) -> torch.Tensor:
        with sdpa_kernel(self.backend):
            return self.sdpa(*args, **kwargs)

class RingAttentionTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @property
    def destroy_pg_upon_exit(self) -> bool:
        return False

    @skip_if_lt_x_gpu(2)
    @skipIfRocm  # Missing _c10d_functional_autograd::all_to_all_single
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Does not support flash nor efficient attention",
    )
    @with_comms
    def test_ring_attention_sdpa(self) -> None:
        self.run_subtests(
            {
                "is_causal": [True, False],
                "compiled": [True, False],
                "backend": backends,
                "load_balance": [True, False],
                "rotater": [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER],
                "test_forward_only": [True, False],
                "use_context": [True, False],
            },
            self._test_ring_attention_sdpa,
        )

    def _ring_attention_sdpa(
        self,
        cp_q: torch.Tensor,
        cp_k: torch.Tensor,
        cp_v: torch.Tensor,
        *,
        fn_eval: Callable,
        mesh: DeviceMesh,
        seq_dim: int,
        is_causal: bool,
        compiled: bool,
        backend: SDPBackend,
        rotater: _RotateMethod,
        test_forward_only: bool,
        load_balance: bool,
        use_context: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if not use_context:
            cp_plan = _ContextParallel(
                seq_dim=seq_dim,
                attention_type=_ContextParallel.AttentionType.SDPA,
            )
            attention = SDPAWrapper(compiled=compiled, backend=backend)
            attention = parallelize_module(attention, mesh, cp_plan)
            if load_balance:
                seq_len = cp_q.size(seq_dim)
                load_balancer = _HeadTailLoadBalancer(seq_len, mesh.size(), cp_q.device)
            else:
                load_balancer = None
            cp_q, cp_k, cp_v = _context_parallel_shard(
                mesh, (cp_q, cp_k, cp_v), (seq_dim,) * 3, load_balancer=load_balancer
            )
            _enable_context_parallel_dispatcher()
        else:
            # Theoretically, context_parallel() should not be used to shard
            # parameters because resize_ is not allowed when requires_grad is
            # True. But requires_grad of cp_q, cp_k, and cp_v is False at this
            # point, so we can use context_parallel() to shard q, k, and v here.
            # In practice, context_parallel() should only be used to shard the
            # model inputs (the batch).

            _cp_options.enable_load_balance = load_balance
            cp_context = context_parallel(
                mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(seq_dim,) * 3
            )
            cp_context.__enter__()

            # NOTE: This demonstrates that monkey patching is not fully reliable.
            # If we use SDPAWrapper directly, the monkey patching dispatch mode
            # does not function correctly. To ensure proper behavior,
            # F.scaled_dot_product_attention must be referenced within the
            # context_parallel() scope.
            attention = F.scaled_dot_product_attention
            if compiled:
                attention = torch.compile(
                    attention, fullgraph=True, backend="aot_eager"
                )

        for target in [cp_q, cp_k, cp_v]:
            target.requires_grad = True

        with CommDebugMode() as comm_mode:
            with sdpa_kernel(backend):
                cp_out = fn_eval(
                    attention,
                    cp_q,
                    cp_k,
                    cp_v,
                    is_causal=is_causal,
                )

        if not compiled and rotater == _RotateMethod.ALL_TO_ALL:
            # Compiler and CommDebugMode do not work well together.
            expect_all2all_count = (
                self.world_size - 1
                if test_forward_only
                else self.world_size * 3 - 2
            )
            self.assertDictEqual(
                comm_mode.get_comm_counts(),
                {c10d_functional.all_to_all_single: expect_all2all_count},
            )
        cp_dq, cp_dk, cp_dv = cp_q.grad, cp_k.grad, cp_v.grad
        for target in [cp_q, cp_k, cp_v]:
            target.requires_grad = False

        if not use_context:
            _disable_context_parallel_dispatcher()
        else:
            cp_context.__exit__(None, None, None)

        return cp_out, cp_dq, cp_dk, cp_dv

    def _test_ring_attention_sdpa(
        self,
        is_causal: bool,
        compiled: bool,
        backend: SDPBackend,
        load_balance: bool,
        rotater: _RotateMethod,
        test_forward_only: bool,
        use_context: bool,
    ) -> None:
        def fn_eval(fn, *args, **kwargs):
            if test_forward_only:
                with torch.no_grad():
                    return fn(*args, **kwargs)
            else:
                out = fn(*args, **kwargs)
                out.sum().backward()
                return out

        if load_balance and not is_causal:
            return

        set_rotate_method(rotater_enum_to_str[rotater])
        self.assertEqual(_cp_options.rotate_method, rotater)
        device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
        dtype = torch.bfloat16
        bs = 8
        seq_length = 1024
        seq_dim = 2
        dim = 32
        nheads = 8
        torch.manual_seed(10)
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        q, k, v = [
            torch.rand(
                (bs, nheads, seq_length * self.world_size, dim),
                device=self.device_type,
                dtype=dtype,
                requires_grad=True,
            )
            for _ in range(3)
        ]

        # Ensure all ranks have the same initialization data.
        with torch.no_grad():
            dist.broadcast(q, src=0)
            dist.broadcast(k, src=0)
            dist.broadcast(v, src=0)

        with sdpa_kernel(backend):
            out = fn_eval(F.scaled_dot_product_attention, q, k, v, is_causal=is_causal)

        cp_q, cp_k, cp_v = [target.detach().clone() for target in [q, k, v]]
        cp_out, cp_dq, cp_dk, cp_dv = self._ring_attention_sdpa(
            cp_q,
            cp_k,
            cp_v,
            fn_eval=fn_eval,
            mesh=device_mesh,
            seq_dim=seq_dim,
            is_causal=is_causal,
            compiled=compiled,
            backend=backend,
            rotater=rotater,
            test_forward_only=test_forward_only,
            load_balance=load_balance,
            use_context=use_context,
        )

        # Due to numerical error, we need to choose different atol for different
        # attention kernels
        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [seq_dim])
        atol = (
            2e-06
            if backend == SDPBackend.EFFICIENT_ATTENTION
            else 8e-3 * self.world_size
        )
        rtol = (
            1e-05
            if backend == SDPBackend.EFFICIENT_ATTENTION
            else 1e-3 * self.world_size
        )
        torch.testing.assert_close(out, cp_out, atol=atol, rtol=rtol)

        if test_forward_only:
            return

        cp_dq, cp_dk, cp_dv = context_parallel_unshard(
            device_mesh,
            [cp_dq, cp_dk, cp_dv],
            [seq_dim] * 3,
        )
        torch.testing.assert_close(q.grad, cp_dq, atol=atol, rtol=rtol)
        torch.testing.assert_close(k.grad, cp_dk, atol=atol, rtol=rtol)
        torch.testing.assert_close(v.grad, cp_dv, atol=atol, rtol=rtol)

    def test_is_causal_behavior(self) -> None:
        _cp_options.enable_load_balance = False
        self.assertEqual(
            _is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
            _CausalBehavior.NOT_IS_CAUSAL,
        )

        ranks = [
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.SKIP],
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
        ]
        for rank, iters in enumerate(ranks):
            for i, behavior in enumerate(iters):
                self.assertEqual(
                    _is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
                    behavior,
                )

        _cp_options.enable_load_balance = True
        ranks = [
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
        ]
        for rank, iters in enumerate(ranks):
            for i, behavior in enumerate(iters):
                self.assertEqual(
                    _is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
                    behavior,
                )

# Compile the flex_attention function
compiled_flex_attention = torch.compile(flex_attention, dynamic=False, fullgraph=True)
compiled_create_block_mask = torch.compile(
    create_block_mask, dynamic=False, fullgraph=True
)


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# copied from https://github.com/meta-pytorch/attention-gym/blob/main/attn_gym/masks/document_mask.py
def generate_random_lengths(total_length, num_documents) -> list[int]:
    # Initialize all lengths to 1 to ensure each document has at least one token
    lengths = [1] * num_documents
    remaining_length = total_length - num_documents

    # Randomly distribute the remaining length
    for _ in range(remaining_length):
        index = random.randint(0, num_documents - 1)
        lengths[index] += 1

    return lengths

def generate_random_lengths_in_chunks(
    total_length, num_documents, chunk_size
) -> list[int]:
    # Generate a list of random document lengths such that each document contains
    # some number of chunks of size `chunk_size`. This means each document's length
    # must be a multiple of `chunk_size`. In addition, the lengths of all the
    # documents sum up to `total_length`.
    num_chunks = total_length // chunk_size
    assert total_length % chunk_size == 0 and num_chunks >= num_documents

    num_chunks_per_document = [1] * num_documents
    remaining_chunks = num_chunks - num_documents
    # Randomly distribute the remaining chunks
    for _ in range(remaining_chunks):
        index = random.randint(0, num_documents - 1)  # document_id
        num_chunks_per_document[index] += 1

    return [n * chunk_size for n in num_chunks_per_document]

def length_to_offsets(lengths: list[list[int]], device: str | torch.device) -> Tensor:
    """Converts a list of lengths to a list of offsets.

    Args:
        lengths: A list, one entry per batch, of per-document lengths.
        device: Device on which to create the offsets tensor.
    """
    offsets = [[0] + lengths_in_batch for lengths_in_batch in lengths]
    offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    offsets = torch.cumsum(offsets, dim=-1)
    return offsets

def _offsets_to_doc_ids_tensor(offsets):
    doc_ids = []
    device = offsets.device
    for batch_idx in range(offsets.size(0)):
        counts = offsets[batch_idx][1:] - offsets[batch_idx][:-1]
        doc_id = torch.repeat_interleave(
            torch.arange(len(counts), device=device, dtype=torch.int32), counts
        )
        doc_ids.append(doc_id)

    return torch.stack(doc_ids)

def generate_doc_mask_mod(
    mask_mod: _mask_mod_signature, offsets: Tensor
) -> _mask_mod_signature:
    """Generates mask mods that apply to inputs to flex attention in the sequence
    stacked format.

    Args:
        mask_mod: The mask mod to apply to the documents.
        offsets: A tensor of shape (num_documents + 1) containing the cumulative
            counts of document tokens. For example, for 3 documents of length
            2, 4, and 3, offsets = [0, 2, 6, 9].

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form 1 large sequence. We
        then use masking to ensure that the attention scores are only applied to
        tokens within the same document.
    """
    document_id = _offsets_to_doc_ids_tensor(offsets)

    def doc_mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_id[b][q_idx] == document_id[b][kv_idx]
        q_logical = q_idx - offsets[b, document_id[b, q_idx]]
        kv_logical = kv_idx - offsets[b, document_id[b, kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask

    return doc_mask_mod

class FlexAttentionWrapper(torch.nn.Module):
    _flex_attn: ClassVar[Callable] = torch.compile(flex_attention)

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, *args: object, **kwargs: object
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, AuxOutput]
    ):
        return FlexAttentionWrapper._flex_attn(*args, **kwargs)

class CPFlexAttentionTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    def _test_cp_flex_attention(
        self,
        *,
        qkv_size: int,
        B: int = 1,
        block_mask,
        lb_type: str,
        document_lengths: Optional[list[list[int]]] = None,
    ) -> None:
        torch.use_deterministic_algorithms(True)
        torch.cuda.manual_seed(1234)

        dtype = torch.float32
        bs = B if B > 1 else 8
        dim = 32
        nheads = 8
        seq_dim = 2
        lb = self._get_load_balancer(
            lb_type,
            {
                "seq_length": qkv_size,
                "document_lengths": document_lengths,
                "block_mask": block_mask,
            },
        )

        qkv = [
            torch.rand(
                (bs, nheads, qkv_size, dim),
                device=self.device_type,
                dtype=dtype,
                requires_grad=True,
            )
            for _ in range(3)
        ]

        expect_out, expect_aux = compiled_flex_attention(
            *qkv, block_mask=block_mask, return_aux=AuxRequest(lse=True)
        )
        expect_out.sum().backward()

        # Prepare the required global vars for CP+Flex:
        device_mesh = init_device_mesh(
            device_type=self.device_type,
            mesh_shape=(self.world_size,),
            mesh_dim_names=("cp",),
        )

        flex_attention_wrapper_module = FlexAttentionWrapper()
        cp_plan = _ContextParallel(
            seq_dim=seq_dim,
            attention_type=_ContextParallel.AttentionType.FLEX,
        )
        parallelize_module(
            flex_attention_wrapper_module,
            device_mesh,
            cp_plan,
        )

        *cp_qkv, cp_block_mask = _context_parallel_shard(
            device_mesh,
            [t.detach().clone() for t in qkv] + [block_mask],
            [seq_dim] * 4,
            load_balancer=lb,
        )
        for t in cp_qkv:
            t.requires_grad = True

        cp_out, cp_aux = flex_attention_wrapper_module(
            *cp_qkv,
            block_mask=cp_block_mask,
            return_aux=AuxRequest(lse=True),
        )

        # backward run
        cp_out.sum().backward()

        atol = 2e-06
        rtol = 1e-05
        # unshard the output
        cp_out, cp_lse = context_parallel_unshard(
            device_mesh,
            buffers=[cp_out, cp_aux.lse],
            seq_dims=[seq_dim] * 2,
            load_balancer=lb,
        )
        torch.testing.assert_close(cp_out, expect_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(cp_lse, expect_aux.lse, atol=atol, rtol=rtol)

        # unshard the gradients
        cp_qkv_grad = context_parallel_unshard(
            device_mesh,
            buffers=[t.grad for t in cp_qkv],
            seq_dims=[seq_dim] * 3,
            load_balancer=lb,
        )

        qkv_grad = [t.grad for t in qkv]
        for grad, cp_grad in zip(qkv_grad, cp_qkv_grad):
            torch.testing.assert_close(grad, cp_grad, atol=atol, rtol=rtol)

    def _get_load_balancer(
        self, lb_type: str, kwargs: dict[str, Any]
    ) -> Optional[_LoadBalancer]:
        seq_length = kwargs["seq_length"]
        document_lengths = kwargs["document_lengths"]
        block_mask = kwargs["block_mask"]

        # generate load balancer
        if lb_type == "None":
            load_balancer = None  # no load-balance
        elif lb_type == "_HeadTailLoadBalancer":
            assert isinstance(seq_length, int)
            load_balancer = _HeadTailLoadBalancer(
                seq_length, self.world_size, torch.device(self.device_type)
            )
        elif lb_type == "_PerDocumentHeadTailLoadBalancer":
            assert isinstance(document_lengths, list)
            load_balancer = _PerDocumentHeadTailLoadBalancer(
                document_lengths, self.world_size, torch.device(self.device_type)
            )
        elif lb_type == "_PTRRLoadBalancer":
            assert isinstance(block_mask, BlockMask)
            load_balancer = _PTRRLoadBalancer(
                block_mask,
                self.world_size,
            )
        else:
            raise ValueError(f"load_balancer type {lb_type} is not supported!")

        return load_balancer

    @skip_if_lt_x_gpu(2)
    @with_comms
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    def test_cp_flex_attention_causal_mask(self) -> None:
        seq_length_list = [256 * self.world_size, 2048]
        load_balance_type_list = [
            "None",
            "_HeadTailLoadBalancer",
            "_PTRRLoadBalancer",
        ]

        # NOTE: Each (seq_len, load_balance_type) tuple introduces 2
        # create_block_mask compilations: 1 for single-rank flex_attention and 1 for
        # CP flex_attention. In order to avoid the "exceeds_recompile_limit" error,
        # we need to increase the cache_size_limit to 2 * num_of_sub_test_runs which
        # will be the total number of compilations in our test case.
        torch._dynamo.config.cache_size_limit = (len(seq_length_list) + 1) * (
            1 + len(load_balance_type_list)
        )

        for qkv_size, lb_type in itertools.product(
            seq_length_list, load_balance_type_list
        ):
            block_mask = compiled_create_block_mask(
                causal_mask,
                B=1,
                H=1,
                Q_LEN=qkv_size,
                KV_LEN=qkv_size,
                device=self.device_type,
            )
            self._test_cp_flex_attention(
                qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
            )

        # NOTE: Context Parallel should not be used for small attentions (block_size < 128)
        qkv_size = 64 * self.world_size
        block_mask = compiled_create_block_mask(
            causal_mask,
            B=1,
            H=1,
            Q_LEN=qkv_size,
            KV_LEN=qkv_size,
            device=self.device_type,
        )

        for lb_type in ["None", "_HeadTailLoadBalancer"]:
            with self.assertRaisesRegex(
                NotImplementedError,
                f"Q_LEN {qkv_size} is not divisible",
            ):
                self._test_cp_flex_attention(
                    qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
                )

        for lb_type in ["_PTRRLoadBalancer"]:
            with self.assertRaisesRegex(
                NotImplementedError,
                "must be divisible by group_size",
            ):
                self._test_cp_flex_attention(
                    qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
                )

    # TODO: merge with the above test
    @skip_if_lt_x_gpu(2)
    @with_comms
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    def test_cp_flex_attention_document_mask(self) -> None:
        random.seed(10)

        # parameters for testing
        doc_count = 28
        batch_size_list = [2, 4, 8]
        max_seq_len_list = [
            256 * self.world_size,
            2048,
            # 128 * self.world_size  # NOTE: Mismatched elements: 8 / 131072 (0.0%),
        ]
        load_balance_type = [
            "None",
            "_HeadTailLoadBalancer",
            "_PerDocumentHeadTailLoadBalancer",
            "_PTRRLoadBalancer",
        ]

        # NOTE: Each (batch_size, seq_len, load_balance_type) tuple introduces 2
        # create_block_mask compilations: 1 for single-rank flex_attention and 1 for
        # CP flex_attention. In order to avoid the "exceeds_recompile_limit" error,
        # we need to increase the cache_size_limit to 2 * num_of_sub_test_runs which
        # will be the total number of compilations in our test case.
        torch._dynamo.config.cache_size_limit = (
            2 * len(batch_size_list) * len(max_seq_len_list) * len(load_balance_type)
        )

        # TODO: change this for-loop to run_subtests
        # Use a for-loop instead of run_subtests because we need to initialize the mask
        # for each subtest. This can be baked into self._test_cp_flex_attention as
        # a str argument denoting mask type.
        for batch_size, max_seq_len, lb_type in itertools.product(
            batch_size_list,
            max_seq_len_list,
            load_balance_type,
        ):
            # initialize document mask
            lengths = [
                (
                    generate_random_lengths_in_chunks(
                        max_seq_len, doc_count, chunk_size=2 * self.world_size
                    )
                    if lb_type == "_PerDocumentHeadTailLoadBalancer"
                    else generate_random_lengths(max_seq_len, doc_count)
                )
                for _ in range(batch_size)
            ]
            offsets = length_to_offsets(lengths, self.device_type)
            document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
            block_mask = compiled_create_block_mask(
                document_causal_mask,
                B=batch_size,
                H=1,
                Q_LEN=max_seq_len,
                KV_LEN=max_seq_len,
                device=self.device_type,
            )

            self._test_cp_flex_attention(
                qkv_size=max_seq_len,
                B=batch_size,
                lb_type=lb_type,
                block_mask=block_mask,
                document_lengths=lengths,
            )

class TestCPCustomOps(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    @with_comms
    def test_flex_cp_custom_op(self) -> None:
        mesh = init_device_mesh(
            device_type=self.device_type,
            mesh_shape=(self.world_size,),
            mesh_dim_names=("cp",),
        )
        examples_k_v = [
            (
                torch.randn(8, 8, 8, 8, device=self.device_type),
                torch.randn(8, 8, 8, 8, device=self.device_type),
                2,
                c10d._get_process_group_name(mesh.get_group()),
            ),
            (
                torch.randn(8, 8, 8, 8, device=self.device_type, requires_grad=True),
                torch.randn(8, 8, 8, 8, device=self.device_type, requires_grad=True),
                2,
                c10d._get_process_group_name(mesh.get_group()),
            ),
        ]
        for example in examples_k_v:
            torch.library.opcheck(flex_cp_allgather, example)

class TestSharding(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    @with_comms
    def test_context_parallel_shard(self) -> None:
        B = 4
        seq_len = 32

        device_mesh = init_device_mesh(
            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
        )
        freqs_cis = torch.arange(0, seq_len, device=self.device_type)
        q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
        k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
        v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)

        load_balancer = _HeadTailLoadBalancer(
            seq_len, self.world_size, torch.device(self.device_type)
        )
        freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
            device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
        )
        self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
        chunks = freqs_cis.chunk(self.world_size * 2)
        self.assertEqual(
            freqs_cis_shard,
            torch.cat(
                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
            ),
        )


if __name__ == "__main__":
    run_tests()