# test/test-sdpa.py
import torch
import torch.nn as nn
from typing import Optional, Tuple

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times for grouped-query attention:
    (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
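
# Sanity check (added for illustration, not part of the original test): repeat_kv
# should be equivalent to torch.repeat_interleave along the head dimension.
_kv = torch.randn(1, 2, 3, 4)
assert torch.equal(repeat_kv(_kv, 2), torch.repeat_interleave(_kv, 2, dim=1))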

def sdpa_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_key_value_groups: int,  # replaces module.num_key_value_groups
    attention_mask: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    # Expand the KV heads for grouped-query attention
    # (replaces module.num_key_value_groups)
    if num_key_value_groups != 1:
        key = repeat_kv(key, num_key_value_groups)
        value = repeat_kv(value, num_key_value_groups)

    # Slice the mask to the key length and keep it as an additive float mask
    # (0 = attend, large negative = masked); SDPA accepts this directly, whereas
    # casting an additive mask to bool would invert its meaning.
    causal_mask = attention_mask
    if attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # Ensure the inputs are contiguous (avoids issues with some SDPA backends)
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # Infer causality automatically: causal only when the query has more than one
    # position and no explicit mask is given
    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Handle JIT tracing, where is_causal may arrive as a tensor
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    # Compute the scaling factor if not provided
    if scaling is None:
        head_dim = query.size(-1)
        scaling = 1.0 / (head_dim ** 0.5)

    # Call PyTorch's built-in SDPA
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
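
# Minimal CPU smoke test of the wrapper (added for illustration; the shapes below
# are arbitrary and not part of the original test). With no mask and a query
# length > 1, is_causal is inferred as True.
_q = torch.randn(1, 2, 5, 8)
_k = torch.randn(1, 1, 5, 8)
_v = torch.randn(1, 1, 5, 8)
_out, _ = sdpa_attention_forward(_q, _k, _v, num_key_value_groups=2)
assert _out.shape == (1, 5, 2, 8)  # (batch, seq, heads, head_dim) after the transpose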

# Test parameter configuration (same as before)
batch_size = 2
seq_len = 1024
num_attention_heads = 4
num_key_value_heads = 2  # number of KV heads (GQA)
head_dim = 8
scaling_factor = 1 / (head_dim ** 0.5)
num_key_value_groups = num_attention_heads // num_key_value_heads
device = "cuda:0"
dtype = torch.bfloat16

# Create the input tensors
query = torch.randn(batch_size, num_attention_heads, seq_len, head_dim).to(device).to(dtype)
key = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)
value = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)

# Build an additive causal attention mask (0 = attend, large negative = masked).
# Using the target dtype's own min avoids overflow to -inf when casting to bfloat16.
causal_mask = torch.full(
    (seq_len, seq_len),
    fill_value=torch.finfo(dtype).min,
).to(device).to(dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
attention_mask = causal_mask[None, None, :, :]  # add batch and head dimensions
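
# Note (added): SDPA also accepts a boolean mask, where True means "attend" and
# False means "mask out". The boolean equivalent of the additive mask above is
# shown here for illustration only; it is not used further in this test.
bool_attention_mask = torch.tril(
    torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
)[None, None, :, :]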

# Call the function (is_causal is set explicitly to avoid auto-inference)
attn_output, _ = sdpa_attention_forward(
    query=query,
    key=key,
    value=value,
    num_key_value_groups=num_key_value_groups,  # = 2
    attention_mask=attention_mask,
    scaling=scaling_factor,
    dropout=0.0,
    is_causal=False,  # an explicit mask is provided
)
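
# Cross-check (added, not part of the original test): with no explicit mask and
# is_causal=True, SDPA builds the causal mask internally; the result should match
# the explicit-mask call above up to bfloat16 rounding.
attn_output_causal, _ = sdpa_attention_forward(
    query=query,
    key=key,
    value=value,
    num_key_value_groups=num_key_value_groups,
    attention_mask=None,
    scaling=scaling_factor,
    dropout=0.0,
    is_causal=True,
)
print("Max diff vs. implicit causal mask:",
      (attn_output - attn_output_causal).abs().max().item())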

# Check the output shape: (batch, seq_len, num_heads, head_dim)
print("SDPA Attention Output Shape:", attn_output.shape)  # expected: [2, 1024, 4, 8]
print("SDPA Attention Output device:", attn_output.device)
print("SDPA Attention Output dtype:", attn_output.dtype)