Update test-sdpa.py
test-sdpa.py (19 changed lines)
@@ -29,6 +29,8 @@ def sdpa_attention_forward(
     if attention_mask is not None and causal_mask.ndim == 4:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
 
+    causal_mask = causal_mask.to(torch.bool)
+
     # Ensure the inputs are contiguous
     query = query.contiguous()
     key = key.contiguous()
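
The added `causal_mask.to(torch.bool)` switches the mask from the additive-float convention to the boolean convention before it reaches the attention kernel. The function body is not part of this diff, so as a minimal standalone sketch (assuming the test ultimately exercises `torch.nn.functional.scaled_dot_product_attention`), the two mask conventions it accepts produce the same result:

import torch
import torch.nn.functional as F

# Standalone check of the two mask conventions accepted by
# torch.nn.functional.scaled_dot_product_attention:
#   float mask -> added to the attention scores (large negative = masked)
#   bool mask  -> True means "may attend", False means "masked out"
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

float_mask = torch.triu(
    torch.full((4, 4), fill_value=-torch.finfo(torch.float32).max), diagonal=1
)
bool_mask = float_mask == 0  # True on and below the diagonal

out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
print(torch.allclose(out_float, out_bool, atol=1e-6))  # True
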
@@ -68,20 +70,22 @@ num_attention_heads = 4
 num_key_value_heads = 2  # number of GQA groups
 head_dim = 8
 scaling_factor = 1 / (head_dim ** 0.5)
 num_key_value_groups = num_attention_heads // num_key_value_heads
 
 device = "cuda:0"
+dtype = torch.bfloat16
+
 # Create the input tensors
-query = torch.randn(batch_size, num_attention_heads, seq_len, head_dim).to(device)
-key = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device)
-value = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device)
+query = torch.randn(batch_size, num_attention_heads, seq_len, head_dim).to(device).to(dtype)
+key = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)
+value = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)
 
 # Create the causal attention mask
 causal_mask = torch.full(
     (seq_len, seq_len),
     fill_value=-torch.finfo(torch.float32).max
-).to(device)
-causal_mask = torch.triu(causal_mask, diagonal=1).to(device)
+).to(device).to(dtype)
+causal_mask = torch.triu(causal_mask, diagonal=1).to(device).to(dtype)
 attention_mask = causal_mask[None, None, :, :]  # add batch and head dimensions
 
 # Call the function (set is_causal explicitly to avoid automatic inference)
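
One detail worth noting about the new `dtype = torch.bfloat16` lines: the fill value `-torch.finfo(torch.float32).max` is not representable in bfloat16 and overflows to `-inf` when the mask is cast, which is presumably why the forward pass now converts the mask to `torch.bool` rather than relying on the additive value. A quick check (not part of the diff):

import torch

fill32 = -torch.finfo(torch.float32).max           # -3.4028e+38
print(torch.tensor(fill32, dtype=torch.bfloat16))  # tensor(-inf, dtype=torch.bfloat16)
print(torch.finfo(torch.bfloat16).min)             # -3.3895e+38, the most negative finite bfloat16 value
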
@@ -89,7 +93,7 @@ attn_output, _ = sdpa_attention_forward(
     query=query,
     key=key,
     value=value,
-    num_key_value_groups=num_attention_heads // num_key_value_heads,  # =2
+    num_key_value_groups=num_key_value_groups,  # =2
     attention_mask=attention_mask,
     scaling=scaling_factor,
     dropout=0.0,
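
The call now passes the precomputed `num_key_value_groups` instead of recomputing the ratio inline. The diff does not show how the function uses it, but in GQA implementations this value is typically the repeat factor that expands the key/value heads to match the query heads, along the lines of the `repeat_kv` pattern in Hugging Face transformers; a sketch under that assumption:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [batch, num_kv_heads, seq, head_dim] to [batch, num_kv_heads * n_rep, seq, head_dim],
    # so each key/value head is shared by n_rep query heads.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

key = torch.randn(2, 2, 10, 8)  # [batch, num_key_value_heads, seq_len, head_dim]
key = repeat_kv(key, 2)         # num_key_value_groups = 4 // 2 = 2
print(key.shape)                # torch.Size([2, 4, 10, 8])
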
@@ -98,4 +102,5 @@ attn_output, _ = sdpa_attention_forward(
 
 # Check the output shape
 print("SDPA Attention Output Shape:", attn_output.shape)  # expected: [2, 10, 4, 8]
-print("SDPA Attention Output device:", attn_output.device)
+print("SDPA Attention Output device:", attn_output.device)
+print("SDPA Attention Output dtype:", attn_output.dtype)