import torch
import torch.nn as nn
from typing import Optional, Tuple


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads for grouped-query attention:
    (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
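
# Quick shape check (illustrative addition, not part of the original script): a (1, 2, 4, 8)
# key/value tensor repeated with n_rep=2 becomes (1, 4, 4, 8), i.e. each KV head is duplicated.
_kv_demo = torch.randn(1, 2, 4, 8)
assert repeat_kv(_kv_demo, 2).shape == (1, 4, 4, 8)
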

def sdpa_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_key_value_groups: int,  # replaces module.num_key_value_groups
    attention_mask: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    # replaces module.num_key_value_groups
    if num_key_value_groups != 1:
        key = repeat_kv(key, num_key_value_groups)
        value = repeat_kv(value, num_key_value_groups)

    # Crop the mask to the key length. The additive float mask (0 = attend, large negative =
    # blocked) is passed to SDPA as-is; a boolean mask would instead use True for positions
    # that may attend.
    causal_mask = attention_mask
    if attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # Ensure contiguous inputs
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # Infer causality automatically when not specified
    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Handle JIT tracing
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    # Compute the scaling factor if not provided
    if scaling is None:
        head_dim = query.size(-1)
        scaling = 1.0 / (head_dim ** 0.5)

    # Call PyTorch's built-in SDPA
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    # Return in (batch, seq_len, num_heads, head_dim) layout
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


# Test configuration (same as before)
batch_size = 2
seq_len = 1024
num_attention_heads = 4
num_key_value_heads = 2  # number of key/value heads (GQA)
head_dim = 8
scaling_factor = 1 / (head_dim ** 0.5)
num_key_value_groups = num_attention_heads // num_key_value_heads

device = "cuda:0"
dtype = torch.bfloat16

# Create input tensors (query has all attention heads, key/value only the KV heads)
query = torch.randn(batch_size, num_attention_heads, seq_len, head_dim).to(device).to(dtype)
key = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)
value = torch.randn(batch_size, num_key_value_heads, seq_len, head_dim).to(device).to(dtype)

# Create an additive causal attention mask (0 = attend, large negative = blocked)
causal_mask = torch.full(
    (seq_len, seq_len),
    fill_value=-torch.finfo(torch.float32).max
).to(device).to(dtype)
causal_mask = torch.triu(causal_mask, diagonal=1).to(device).to(dtype)
attention_mask = causal_mask[None, None, :, :]  # add batch and head dimensions

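# Illustrative addition (not in the original script): SDPA also accepts a boolean mask,
# where True means "may attend". The boolean equivalent of the additive mask above is the
# lower-triangular matrix; the assert just documents the relationship between the two forms.
bool_attention_mask = torch.tril(
    torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
)[None, None, :, :]
assert torch.equal(bool_attention_mask[0, 0], attention_mask[0, 0] == 0)
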
# Call the function (set is_causal explicitly to avoid auto-inference)
attn_output, _ = sdpa_attention_forward(
    query=query,
    key=key,
    value=value,
    num_key_value_groups=num_key_value_groups,  # = 2
    attention_mask=attention_mask,
    scaling=scaling_factor,
    dropout=0.0,
    is_causal=False,  # an explicit mask is provided
)

# Check the output shape
print("SDPA Attention Output Shape:", attn_output.shape)  # expected: [2, 1024, 4, 8] (batch, seq_len, num_heads, head_dim)
print("SDPA Attention Output device:", attn_output.device)
print("SDPA Attention Output dtype:", attn_output.dtype)