更新 test-sdpa.py

This commit is contained in:
2025-06-04 17:11:05 +08:00
parent d89d3edc4a
commit d4e6947b71

View File

@ -29,6 +29,8 @@ def sdpa_attention_forward(
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
causal_mask = causal_mask.to(torch.bool)
# 确保输入连续
query = query.contiguous()
key = key.contiguous()
@ -68,20 +70,22 @@ num_attention_heads = 4
num_key_value_heads = 2 # GQA分组数
head_dim = 8
scaling_factor = 1 / (head_dim ** 0.5)
num_key_value_groups = num_attention_heads // num_key_value_heads
device = "cuda:0"
dtype = torch.bfloat16
# 创建输入张量
query = torch.randn(batch_size, num_attention_heads, seq_len, head_dim).to(device)
key = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device)
value = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device)
query = torch.randn(batch_size, num_attention_heads, seq_len, head_dim).to(device).to(dtype)
key = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)
value = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)
# 创建因果注意力掩码
causal_mask = torch.full(
(seq_len, seq_len),
fill_value=-torch.finfo(torch.float32).max
).to(device)
causal_mask = torch.triu(causal_mask, diagonal=1).to(device)
).to(device).to(dtype)
causal_mask = torch.triu(causal_mask, diagonal=1).to(device).to(dtype)
attention_mask = causal_mask[None, None, :, :] # 增加batch和head维度
# 调用函数 (显式设置is_causal避免自动推断)
@ -89,7 +93,7 @@ attn_output, _ = sdpa_attention_forward(
query=query,
key=key,
value=value,
num_key_value_groups=num_attention_heads // num_key_value_heads, # =2
num_key_value_groups=num_key_value_groups, # =2
attention_mask=attention_mask,
scaling=scaling_factor,
dropout=0.0,
@ -98,4 +102,5 @@ attn_output, _ = sdpa_attention_forward(
# 检查输出形状
print("SDPA Attention Output Shape:", attn_output.shape) # 应为 [2, 10, 4, 8]
print("SDPA Attention Output device:", attn_output.device)
print("SDPA Attention Output device:", attn_output.device)
print("SDPA Attention Output dtype:", attn_output.dtype)