!237 Add test cases for zigzag ring attn
Merge pull request !237 from lynn/master
205 tests/unit/flow/model/test_zigzag_ring_flash_attn_varlen_func.py Normal file
@@ -0,0 +1,205 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

"""
torchrun --nproc_per_node=4 test_zigzag_ring_flash_attn_varlen_func.py
"""
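
# Distributed consistency test: the zigzag ring flash-attention (varlen) implementation is
# compared against a full-sequence torch_npu.npu_fusion_attention reference, for both the
# forward outputs / softmax statistics and the gradients of q, k and v.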

import random

import torch

try:
    import torch_npu  # noqa: F401

    is_npu_available = True
except ImportError:
    print("Failed to import torch_npu.")
    is_npu_available = False
import torch.distributed as dist

from openmind.flow.model.context_parallel.zigzag_ring_flash_attn_varlen import (
    zigzag_ring_flash_attn_varlen_func,
    flatten_softmax,
    get_sub_seq_lens,
)


def extract_softmax_value(softmax_value, cu_seqlens):
    # Split a flattened softmax statistic (max or sum) back into one tensor per sequence.
    values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        value = softmax_value[start:end]
        values.append(value)
    return values


def set_seed(rank, seed=42):
    # Seed the Python, PyTorch and NPU RNGs; the seed is offset by rank.
    seed = rank + seed
    random.seed(seed)
    torch.manual_seed(seed)
    if is_npu_available:
        torch.npu.manual_seed(seed)
        torch.npu.manual_seed_all(seed)


def log(msg, a, rank0_only=False):
    # Print the max/mean absolute value of `a`, either on rank 0 only or once per rank.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if rank0_only:
        if rank == 0:
            print(
                f"{msg}: " f"max {a.abs().max().item():.3g}, " f"mean {a.abs().mean().item():.3g}",
                flush=True,
            )
        return

    for i in range(world_size):
        if i == rank:
            if rank == 0:
                print(f"{msg}:")
            print(
                f"[{rank}] " f"max {a.abs().max().item():.3g}, " f"mean {a.abs().mean().item():.3g}",
                flush=True,
            )
        dist.barrier()


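# Zigzag sharding for context parallelism: each sequence is split into 2 * world_size
# chunks along the token dimension, and rank r keeps chunk r together with chunk
# (2 * world_size - 1 - r), so every rank holds one "early" and one "late" chunk and the
# causal-attention workload stays balanced across ranks. For example, with world_size = 2
# a sequence [c0, c1, c2, c3] is sharded as rank 0 -> [c0, c3] and rank 1 -> [c1, c2].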
def extract_local(value, cu_seqlens, rank, world_size):
    local_values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        local_value = value[start:end].chunk(2 * world_size, dim=0)
        local_values.extend(
            [
                local_value[rank].detach().clone(),
                local_value[2 * world_size - 1 - rank].detach().clone(),
            ]
        )
    return torch.cat(local_values, dim=0).contiguous()


if __name__ == "__main__":
    dist.init_process_group("hccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"npu:{rank}")

    nheads = 5
    d = 128
    dropout_p = 0
    causal = True

    cu_seqlens = [0, 120, 1248, 4232]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    sub_seq_lens = get_sub_seq_lens(cu_seqlens)
    total_length = cu_seqlens[-1]

    assert torch.all(cu_seqlens_tensor % world_size == 0)
    assert d % 8 == 0

    # Identical full-length inputs on every rank: generate locally, then broadcast from rank 0.
    q = torch.randn(total_length, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(total_length, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(total_length, nheads, d, device=device, dtype=dtype, requires_grad=True)
    dist.broadcast(q, src=0)
    dist.broadcast(k, src=0)
    dist.broadcast(v, src=0)

    dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

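    # Per-rank metadata: each sequence contributes 1 / world_size of its tokens to a rank,
    # so the local cumulative sequence lengths are the global ones divided by world_size.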
    local_cu_seqlens_tensor = cu_seqlens_tensor // world_size
    local_sub_seq_lens = get_sub_seq_lens(local_cu_seqlens_tensor)

    # Zigzag-shard the inputs and the upstream gradient; only the local slices require grad.
    local_q = extract_local(q, cu_seqlens, rank, world_size)
    local_k = extract_local(k, cu_seqlens, rank, world_size)
    local_v = extract_local(v, cu_seqlens, rank, world_size)
    local_q.requires_grad = True
    local_k.requires_grad = True
    local_v.requires_grad = True
    local_dout = extract_local(dout, cu_seqlens, rank, world_size)

    dist.barrier()
    if rank == 0:
        print(">>> forward:")

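    # Reference: full-sequence varlen causal attention on every rank via the fused NPU kernel;
    # here the 2048 x 2048 upper-triangular mask serves as the compressed causal mask that
    # sparse_mode=3 expects.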
    attn_mask = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
    out, softmax_max, softmax_sum, _, _, _, _ = torch_npu.npu_fusion_attention(
        q,
        k,
        v,
        head_num=q.shape[1],
        input_layout="TND",
        atten_mask=attn_mask,
        scale=d ** (-0.5),
        actual_seq_qlen=tuple(cu_seqlens_tensor[1:].cpu().numpy().tolist()),
        actual_seq_kvlen=tuple(cu_seqlens_tensor[1:].cpu().numpy().tolist()),
        sparse_mode=3,
        keep_prob=1.0 - dropout_p,
    )

    # Shard the reference output like the inputs, and split its softmax statistics per sequence.
    local_out = extract_local(out, cu_seqlens, rank, world_size)

    softmax_max = flatten_softmax(softmax_max, sub_seq_lens)
    local_softmax_max_list = extract_softmax_value(softmax_max, cu_seqlens)
    softmax_sum = flatten_softmax(softmax_sum, sub_seq_lens)
    local_softmax_sum_list = extract_softmax_value(softmax_sum, cu_seqlens)

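    # Zigzag ring flash attention on the local shards; its outputs should match the
    # sharded reference.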
    ring_out, ring_softmax_max, ring_softmax_sum = zigzag_ring_flash_attn_varlen_func(
        local_q,
        local_k,
        local_v,
        local_cu_seqlens_tensor,
        dropout_p=dropout_p,
        causal=causal,
    )

    ring_softmax_max = flatten_softmax(ring_softmax_max, local_sub_seq_lens)
    ring_softmax_max_list = extract_softmax_value(ring_softmax_max, local_cu_seqlens_tensor)
    ring_softmax_sum = flatten_softmax(ring_softmax_sum, local_sub_seq_lens)
    ring_softmax_sum_list = extract_softmax_value(ring_softmax_sum, local_cu_seqlens_tensor)

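    # Forward check: the reference softmax statistics are re-sharded in the same zigzag
    # pattern per sequence before being compared with the ring results.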
log("out diff", local_out - ring_out)
|
||||
for i, (lsm, ring_lsm) in enumerate(zip(local_softmax_max_list, ring_softmax_max_list)):
|
||||
local_lsm = lsm.chunk(2 * world_size, dim=0)
|
||||
local_lsm = torch.cat([local_lsm[rank], local_lsm[2 * world_size - 1 - rank]], dim=0)
|
||||
log(f"softmax max diff {i}", local_lsm - ring_lsm)
|
||||
for i, (lss, ring_lss) in enumerate(zip(local_softmax_sum_list, ring_softmax_sum_list)):
|
||||
local_lss = lss.chunk(2 * world_size, dim=0)
|
||||
local_lss = torch.cat([local_lss[rank], local_lss[2 * world_size - 1 - rank]], dim=0)
|
||||
log(f"softmax sum diff {i}", local_lss - ring_lss)
|
||||
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
print(">>> backward:")
|
||||
|
||||
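    # Backward check: gradients from the full-sequence reference (sharded with extract_local)
    # are compared against the gradients produced by the ring backward on the local shards.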
    out.backward(dout)
    dq = q.grad
    dk = k.grad
    dv = v.grad
    local_dq = extract_local(dq, cu_seqlens, rank, world_size)
    local_dk = extract_local(dk, cu_seqlens, rank, world_size)
    local_dv = extract_local(dv, cu_seqlens, rank, world_size)

    ring_out.backward(local_dout)
    ring_dq = local_q.grad
    ring_dk = local_k.grad
    ring_dv = local_v.grad

    log("dq diff", local_dq - ring_dq)
    log("dk diff", local_dk - ring_dk)
    log("dv diff", local_dv - ring_dv)

    dist.destroy_process_group()