!237 Add test cases for zigzag ring attn
Merge pull request !237 from lynn/master
tests/unit/flow/model/test_zigzag_ring_flash_attn_varlen_func.py | 205 (new file)
@@ -0,0 +1,205 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

"""
torchrun --nproc_per_node=4 test_zigzag_ring_flash_attn_varlen_func.py
"""

import random

import torch

try:
    import torch_npu  # noqa: F401

    is_npu_available = True
except ImportError:
    print("Failed to import torch_npu.")
    is_npu_available = False
import torch.distributed as dist

from openmind.flow.model.context_parallel.zigzag_ring_flash_attn_varlen import (
    zigzag_ring_flash_attn_varlen_func,
    flatten_softmax,
    get_sub_seq_lens,
)


def extract_softmax_value(softmax_value, cu_seqlens):
    values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        value = softmax_value[start:end]
        values.append(value)
    return values


def set_seed(rank, seed=42):
    seed = rank + seed
    random.seed(seed)
    torch.manual_seed(seed)
    if is_npu_available:
        torch.npu.manual_seed(seed)
        torch.npu.manual_seed_all(seed)


def log(msg, a, rank0_only=False):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if rank0_only:
        if rank == 0:
            print(
                f"{msg}: max {a.abs().max().item():.3g}, mean {a.abs().mean().item():.3g}",
                flush=True,
            )
        return

    for i in range(world_size):
        if i == rank:
            if rank == 0:
                print(f"{msg}:")
            print(
                f"[{rank}] max {a.abs().max().item():.3g}, mean {a.abs().mean().item():.3g}",
                flush=True,
            )
        dist.barrier()

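
# Zigzag sharding: each sequence is cut into 2 * world_size equal chunks and
# rank r keeps chunk r plus chunk (2 * world_size - 1 - r). For example, with
# world_size=4, rank 0 holds chunks 0 and 7 while rank 3 holds chunks 3 and 4.
# Pairing an early chunk with a late one is the usual trick for balancing the
# cost of causal attention across ranks.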
def extract_local(value, cu_seqlens, rank, world_size):
    local_values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        local_value = value[start:end].chunk(2 * world_size, dim=0)
        local_values.extend(
            [
                local_value[rank].detach().clone(),
                local_value[2 * world_size - 1 - rank].detach().clone(),
            ]
        )
    return torch.cat(local_values, dim=0).contiguous()

if __name__ == "__main__":
    dist.init_process_group("hccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"npu:{rank}")

    nheads = 5
    d = 128
    dropout_p = 0
    causal = True

    cu_seqlens = [0, 120, 1248, 4232]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    sub_seq_lens = get_sub_seq_lens(cu_seqlens)
    total_length = cu_seqlens[-1]

    assert torch.all(cu_seqlens_tensor % world_size == 0)
    assert d % 8 == 0

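    # Inputs are generated per rank and then broadcast from rank 0, so every
    # rank works on identical q/k/v and dout tensors.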
    q = torch.randn(total_length, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(total_length, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(total_length, nheads, d, device=device, dtype=dtype, requires_grad=True)
    dist.broadcast(q, src=0)
    dist.broadcast(k, src=0)
    dist.broadcast(v, src=0)

    dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    local_cu_seqlens_tensor = cu_seqlens_tensor // world_size
    local_sub_seq_lens = get_sub_seq_lens(local_cu_seqlens_tensor)

    local_q = extract_local(q, cu_seqlens, rank, world_size)
    local_k = extract_local(k, cu_seqlens, rank, world_size)
    local_v = extract_local(v, cu_seqlens, rank, world_size)
    local_q.requires_grad = True
    local_k.requires_grad = True
    local_v.requires_grad = True
    local_dout = extract_local(dout, cu_seqlens, rank, world_size)

    dist.barrier()
    if rank == 0:
        print(">>> forward:")

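    # Reference path: run the fused single-device attention over the full
    # packed batch. The 2048x2048 upper-triangular atten_mask combined with
    # sparse_mode=3 is the causal-mask convention this kernel is given for
    # the varlen "TND" layout used here.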
    attn_mask = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
    out, softmax_max, softmax_sum, _, _, _, _ = torch_npu.npu_fusion_attention(
        q,
        k,
        v,
        head_num=q.shape[1],
        input_layout="TND",
        atten_mask=attn_mask,
        scale=d ** (-0.5),
        actual_seq_qlen=tuple(cu_seqlens_tensor[1:].cpu().numpy().tolist()),
        actual_seq_kvlen=tuple(cu_seqlens_tensor[1:].cpu().numpy().tolist()),
        sparse_mode=3,
        keep_prob=1.0 - dropout_p,
    )

    local_out = extract_local(out, cu_seqlens, rank, world_size)

    softmax_max = flatten_softmax(softmax_max, sub_seq_lens)
    local_softmax_max_list = extract_softmax_value(softmax_max, cu_seqlens)
    softmax_sum = flatten_softmax(softmax_sum, sub_seq_lens)
    local_softmax_sum_list = extract_softmax_value(softmax_sum, cu_seqlens)

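    # Distributed path: the same attention computed by the zigzag ring
    # implementation on this rank's shards. Each rank holds an equal share of
    # every sequence, so the local cu_seqlens are the global ones divided by
    # world_size.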
    ring_out, ring_softmax_max, ring_softmax_sum = zigzag_ring_flash_attn_varlen_func(
        local_q,
        local_k,
        local_v,
        local_cu_seqlens_tensor,
        dropout_p=dropout_p,
        causal=causal,
    )

    ring_softmax_max = flatten_softmax(ring_softmax_max, local_sub_seq_lens)
    ring_softmax_max_list = extract_softmax_value(ring_softmax_max, local_cu_seqlens_tensor)
    ring_softmax_sum = flatten_softmax(ring_softmax_sum, local_sub_seq_lens)
    ring_softmax_sum_list = extract_softmax_value(ring_softmax_sum, local_cu_seqlens_tensor)

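    # Compare forward results. The reference softmax max/sum are re-chunked
    # with the same zigzag pairing (chunk r and chunk 2 * world_size - 1 - r)
    # before being diffed against the ring outputs.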
log("out diff", local_out - ring_out)
|
||||||
|
for i, (lsm, ring_lsm) in enumerate(zip(local_softmax_max_list, ring_softmax_max_list)):
|
||||||
|
local_lsm = lsm.chunk(2 * world_size, dim=0)
|
||||||
|
local_lsm = torch.cat([local_lsm[rank], local_lsm[2 * world_size - 1 - rank]], dim=0)
|
||||||
|
log(f"softmax max diff {i}", local_lsm - ring_lsm)
|
||||||
|
for i, (lss, ring_lss) in enumerate(zip(local_softmax_sum_list, ring_softmax_sum_list)):
|
||||||
|
local_lss = lss.chunk(2 * world_size, dim=0)
|
||||||
|
local_lss = torch.cat([local_lss[rank], local_lss[2 * world_size - 1 - rank]], dim=0)
|
||||||
|
log(f"softmax sum diff {i}", local_lss - ring_lss)
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
if rank == 0:
|
||||||
|
print(">>> backward:")
|
||||||
|
|
||||||
|
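    # Backward check: push the same broadcast dout through both graphs and
    # compare the zigzag-sharded full-graph gradients against the ring
    # gradients.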
    out.backward(dout)
    dq = q.grad
    dk = k.grad
    dv = v.grad
    local_dq = extract_local(dq, cu_seqlens, rank, world_size)
    local_dk = extract_local(dk, cu_seqlens, rank, world_size)
    local_dv = extract_local(dv, cu_seqlens, rank, world_size)

    ring_out.backward(local_dout)
    ring_dq = local_q.grad
    ring_dk = local_k.grad
    ring_dv = local_v.grad

    log("dq diff", local_dq - ring_dq)
    log("dk diff", local_dk - ring_dk)
    log("dv diff", local_dv - ring_dv)

    dist.destroy_process_group()