# flake8: noqa: B950
import math
from collections.abc import Callable
from typing import Optional, TypeVar
from typing_extensions import ParamSpec

import torch
from torch.onnx.ops import _dtype_mappings


# Use ParamSpec for better type preservation instead of bound Callable TypeVar
_P = ParamSpec("_P")
_R = TypeVar("_R")

# ONNX to ATen decomp table
ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
_ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
    {
        1,  # FLOAT
        10,  # FLOAT16
        11,  # DOUBLE
        16,  # BFLOAT16
    }
)
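

# Note: the integers above are ONNX TensorProto data-type enum values. attention_23
# below checks the softmax_precision attribute against this set before casting the
# softmax input via _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.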


def _onnx_op(
    op_type: str, opset_version: int
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Decorator to register an ONNX operator with a custom implementation."""

    def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
        overload = f"opset{opset_version}"
        torch_op = torch.library.custom_op(
            f"onnx::{op_type}.{overload}", mutates_args=()
        )(func)
        ONNX_ATEN_DECOMP_TABLE[getattr(getattr(torch.ops.onnx, op_type), overload)] = (
            func  # type: ignore[assignment]
        )
        # Use the same implementation for the fake implementation
        # This is possible because we use pure aten ops to implement ONNX ops
        torch_op.register_fake(func)
        return torch_op  # type: ignore[return-value]

    return decorator
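

# A minimal sketch of what this registration provides, using the "RotaryEmbedding"
# op defined below as an example: the decorator creates a callable custom op,
# torch.ops.onnx.RotaryEmbedding.opset23, registers the same Python function as its
# fake (meta) kernel, and records the OpOverload -> function mapping in
# ONNX_ATEN_DECOMP_TABLE so the ONNX op can be decomposed back into pure ATen ops.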
@_onnx_op("RotaryEmbedding", 23)
|
|
def rotary_embedding_23(
|
|
x: torch.Tensor,
|
|
cos_cache: torch.Tensor,
|
|
sin_cache: torch.Tensor,
|
|
position_ids: Optional[torch.Tensor] = None,
|
|
*,
|
|
interleaved: bool = False,
|
|
num_heads: int = 0,
|
|
rotary_embedding_dim: int = 0,
|
|
) -> torch.Tensor:
|
|
"""RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23"""
|
|
# x has shape (batch_size, num_heads, sequence_length, head_size)
|
|
# or (batch_size, sequence_length, hidden_size)
|
|
input_shape = x.shape
|
|
input_rank = len(input_shape)
|
|
batch_size = input_shape[0]
|
|
sequence_length = input_shape[-2]
|
|
|
|
# Validate position_ids and caches match x
|
|
if position_ids is not None:
|
|
torch._check(
|
|
position_ids.dim() == 2,
|
|
lambda: f"position_ids must be 2D when provided. Received shape {position_ids.shape}",
|
|
)
|
|
torch._check(
|
|
position_ids.shape[0] == batch_size,
|
|
lambda: f"position_ids first dim (batch) must match x.shape[0] ({batch_size}). Received {position_ids.shape[0]}",
|
|
)
|
|
torch._check(
|
|
position_ids.shape[1] == sequence_length,
|
|
lambda: f"position_ids second dim (sequence) must match x.shape[-2] ({sequence_length}). Received {position_ids.shape[1]}",
|
|
)
|
|
torch._check(
|
|
cos_cache.dim() == 2 and sin_cache.dim() == 2,
|
|
lambda: "cos_cache/sin_cache must be 2D when position_ids is provided. "
|
|
f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
|
|
)
|
|
else:
|
|
torch._check(
|
|
cos_cache.dim() == 3 and sin_cache.dim() == 3,
|
|
lambda: "cos_cache/sin_cache must be 3D when position_ids is not provided. "
|
|
f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
|
|
)
|
|
|
|
# First ensure x has shape [batch_size, num_heads, seq_len, head_size]
|
|
# So that the rotation logic can be shared with reshaped 3D inputs
|
|
if input_rank == 4:
|
|
# Reshape from (batch_size, num_heads, seq_len, head_size)
|
|
# to [batch_size, seq_len, num_heads, head_size]
|
|
x = torch.permute(x, (0, 2, 1, 3))
|
|
elif input_rank == 3:
|
|
torch._check(
|
|
num_heads != 0,
|
|
lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {input_shape}",
|
|
)
|
|
hidden_size = input_shape[2]
|
|
head_size = hidden_size // num_heads
|
|
new_shape = [batch_size, sequence_length, num_heads, head_size]
|
|
x = torch.reshape(x, new_shape)
|
|
|
|
torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
|
|
head_size = x.shape[3]
|
|
|
|
# Fully or partially perform rotation on x based on rotary_embedding_dim attribute
|
|
if rotary_embedding_dim == 0:
|
|
# If rotary_embedding_dim not provided, perform full rotation by using head_size
|
|
rotary_embedding_dim = head_size
|
|
x_rotate = x[:, :, :, :rotary_embedding_dim]
|
|
x_not_rotate = x[:, :, :, rotary_embedding_dim:]
|
|
rotary_embedding_dim_half = rotary_embedding_dim // 2
|
|
|
|
# Retrieve sin and cos caches using position ids
|
|
if position_ids is not None:
|
|
cos = cos_cache[
|
|
position_ids
|
|
] # Shape: [batch_size, sequence_length, head_size/2]
|
|
sin = sin_cache[
|
|
position_ids
|
|
] # Shape: [batch_size, sequence_length, head_size/2]
|
|
else:
|
|
cos = cos_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
|
|
sin = sin_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
|
|
|
|
torch._check(
|
|
cos.shape[0] == batch_size and cos.shape[1] == sequence_length,
|
|
lambda: f"cos has shape {cos.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
|
|
)
|
|
torch._check(
|
|
sin.shape[0] == batch_size and sin.shape[1] == sequence_length,
|
|
lambda: f"sin has shape {sin.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
|
|
)
|
|
torch._check(
|
|
cos.shape[-1] == rotary_embedding_dim_half,
|
|
lambda: f"Last dimension of cos cache ({cos.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
|
|
)
|
|
torch._check(
|
|
sin.shape[-1] == rotary_embedding_dim_half,
|
|
lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
|
|
)
|
|
cos = torch.unsqueeze(
|
|
cos, 2
|
|
) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
|
|
sin = torch.unsqueeze(
|
|
sin, 2
|
|
) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
|
|
|
|
# Either divide the x in halves or interleave (based on interleaved attribute)
|
|
if interleaved:
|
|
x1 = x_rotate[:, :, :, 0::2]
|
|
x2 = x_rotate[:, :, :, 1::2]
|
|
else:
|
|
x1, x2 = torch.chunk(x_rotate, 2, dim=-1)
|
|
|
|
# Calculate real and imaginary values
|
|
real = cos * x1 - sin * x2
|
|
imag = sin * x1 + cos * x2
|
|
|
|
# Inserted rotated embeddings back to the original x
|
|
if interleaved:
|
|
# x_rotate[:, :, :, 0::2] = real
|
|
# x_rotate[:, :, :, 1::2] = imag
|
|
real = torch.unsqueeze(real, -1)
|
|
imag = torch.unsqueeze(imag, -1)
|
|
x_rotate_concat = torch.cat((real, imag), dim=-1)
|
|
x_rotate = torch.reshape(x_rotate_concat, x_rotate.shape)
|
|
else:
|
|
x_rotate = torch.cat((real, imag), dim=-1)
|
|
output = torch.cat((x_rotate, x_not_rotate), dim=-1)
|
|
if input_rank == 3:
|
|
return torch.reshape(output, input_shape)
|
|
|
|
# Return the dimensions to the original order
|
|
return torch.permute(output, (0, 2, 1, 3))
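

# A minimal usage sketch with hypothetical shapes (the op name is the one registered
# above; importing this module performs the registration). Full rotation of a 4D
# input without position_ids, so the caches are 3D with last dimension head_size / 2:
#
#     batch, n_heads, seq, head = 2, 4, 8, 64
#     x = torch.randn(batch, n_heads, seq, head)
#     cos_cache = torch.randn(batch, seq, head // 2)
#     sin_cache = torch.randn(batch, seq, head // 2)
#     out = torch.ops.onnx.RotaryEmbedding.opset23(x, cos_cache, sin_cache)
#     assert out.shape == x.shape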


def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
    """Get the scale factor for attention computation."""
    return scale if scale is not None else (1.0 / math.sqrt(head_size))
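

# For example, with the default scale=None and head_size=64 the factor is
# 1 / sqrt(64) = 0.125; an explicitly provided scale is returned unchanged.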


def _reshape_3d_to_4d(
    tensor: torch.Tensor, batch_size: int, num_heads: int
) -> torch.Tensor:
    """Reshape 3D tensor to 4D for multi-head attention."""
    sequence_length, hidden_size = tensor.shape[1], tensor.shape[2]
    head_size = hidden_size // num_heads
    return (
        tensor.view(batch_size, sequence_length, num_heads, head_size)
        .transpose(1, 2)
        .contiguous()
    )
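

# Shape walkthrough with assumed sizes: a (2, 10, 512) input with num_heads=8 is
# viewed as (2, 10, 8, 64) and transposed to (2, 8, 10, 64), i.e.
# (batch_size, num_heads, sequence_length, head_size).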


def _get_qk_output_for_aten_sdpa(
    Q: torch.Tensor,
    K: torch.Tensor,
    current_q_num_heads: int,
    current_kv_num_heads: int,
    scale: Optional[float],
    qk_matmul_output_mode: int,
) -> torch.Tensor:
    """Get QK output tensor based on the specified mode."""
    if qk_matmul_output_mode == 0:
        return _compute_qk_output_for_mode_0(
            Q, K, current_q_num_heads, current_kv_num_heads, scale
        )
    else:
        # For other modes, return a zero tensor with the correct shape
        return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1)))
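

# qk_matmul_output_mode selects which intermediate is reported as the fourth output
# of Attention-23 (mirroring the fallback branch of attention_23 below):
# 0 -> Q @ K^T, 1 -> after the attention mask/bias is added, 2 -> after softcap,
# 3 -> after softmax. Only mode 0 is reproduced alongside the SDPA fast path.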


def _validate_gqa_configuration(
    current_q_num_heads: int, current_kv_num_heads: int
) -> None:
    """Validate Group Query Attention configuration."""
    torch._check(
        current_q_num_heads % current_kv_num_heads == 0,
        lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA",
    )
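

# For example, q_num_heads=8 with kv_num_heads=2 is a valid GQA configuration
# (each KV head serves 8 / 2 = 4 query heads), while kv_num_heads=3 would be rejected.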


def _compute_qk_output_for_mode_0(
    Q: torch.Tensor,
    K: torch.Tensor,
    current_q_num_heads: int,
    current_kv_num_heads: int,
    scale: Optional[float],
) -> torch.Tensor:
    """Helper function to compute QK output for qk_matmul_output_mode == 0."""
    # Handle GQA manually for QK output
    K_for_qk = K
    if current_q_num_heads != current_kv_num_heads:
        repeat_factor = current_q_num_heads // current_kv_num_heads
        K_for_qk = K.repeat_interleave(repeat_factor, dim=1)

    scale_factor = _get_scale_factor(scale, Q.shape[3])
    # Scale both Q and K by sqrt(scale_factor) for numerical stability
    sqrt_scale = math.sqrt(scale_factor)
    Q_scaled = Q * sqrt_scale
    K_scaled = K_for_qk * sqrt_scale
    return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
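

# Why splitting the scale works: (sqrt(s) * Q) @ (sqrt(s) * K)^T == s * (Q @ K^T),
# so scaling Q and K by sqrt(s) each is mathematically equivalent to scaling the
# product by s, while keeping the intermediate magnitudes smaller.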
@_onnx_op("Attention", 23)
|
|
def attention_23(
|
|
Q: torch.Tensor,
|
|
K: torch.Tensor,
|
|
V: torch.Tensor,
|
|
attn_mask: Optional[torch.Tensor] = None,
|
|
past_key: Optional[torch.Tensor] = None,
|
|
past_value: Optional[torch.Tensor] = None,
|
|
*,
|
|
is_causal: bool = False,
|
|
kv_num_heads: int = 0,
|
|
q_num_heads: int = 0,
|
|
qk_matmul_output_mode: int = 0,
|
|
scale: Optional[float] = None,
|
|
softcap: float = 0.0,
|
|
softmax_precision: Optional[int] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23"""
|
|
|
|
num_head_dim, sequence_dim, head_dim = 1, 2, 3
|
|
|
|
# Store original input shape to determine output shape
|
|
input_shape_len = len(Q.shape)
|
|
batch_size = Q.shape[0]
|
|
|
|
# Reshape 3D inputs to 4D format
|
|
if len(Q.shape) == 3:
|
|
torch._check(
|
|
q_num_heads != 0 and kv_num_heads != 0,
|
|
lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs",
|
|
)
|
|
q_sequence_length = Q.shape[1]
|
|
Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads)
|
|
K = _reshape_3d_to_4d(K, batch_size, kv_num_heads)
|
|
V = _reshape_3d_to_4d(V, batch_size, kv_num_heads)
|
|
|
|
torch._check(
|
|
len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4,
|
|
lambda: "Q, K, and V should be 4D tensors by now",
|
|
)
|
|
|
|
# Calculate scale factor if not provided
|
|
q_head_size = Q.shape[head_dim]
|
|
scale = _get_scale_factor(scale, q_head_size)
|
|
|
|
# Handle past key/value caches
|
|
present_key = (
|
|
torch.cat([past_key, K], dim=sequence_dim)
|
|
if past_key is not None
|
|
else K.clone()
|
|
)
|
|
present_value = (
|
|
torch.cat([past_value, V], dim=sequence_dim)
|
|
if past_value is not None
|
|
else V.clone()
|
|
)
|
|
|
|
# Update K and V to include past states
|
|
K, V = present_key, present_value
|
|
|
|
# Get current dimensions
|
|
current_q_num_heads = Q.shape[num_head_dim]
|
|
current_kv_num_heads = K.shape[num_head_dim]
|
|
q_sequence_length = Q.shape[sequence_dim]
|
|
kv_sequence_length = K.shape[sequence_dim]
|
|
|
|
# Check if we can use the optimized scaled_dot_product_attention (most optimized)
|
|
can_use_sdpa = (
|
|
softcap == 0.0 # No softcap
|
|
and qk_matmul_output_mode == 0 # Default QK output mode
|
|
and softmax_precision is None # No custom softmax precision
|
|
and (attn_mask is None or attn_mask.dtype == torch.bool)
|
|
)
|
|
|
|
_validate_gqa_configuration(current_q_num_heads, current_kv_num_heads)
|
|
|
|
if can_use_sdpa:
|
|
# Use PyTorch's optimized scaled_dot_product_attention
|
|
|
|
# Prepare attention mask for SDPA
|
|
sdpa_attn_mask = None
|
|
if attn_mask is not None:
|
|
# Convert boolean mask: True means participate, SDPA expects True to mask out
|
|
sdpa_attn_mask = ~attn_mask if attn_mask.dtype == torch.bool else attn_mask
|
|
|
|
output = torch.nn.functional.scaled_dot_product_attention(
|
|
Q,
|
|
K,
|
|
V,
|
|
attn_mask=sdpa_attn_mask,
|
|
dropout_p=0.0,
|
|
is_causal=is_causal,
|
|
scale=scale,
|
|
enable_gqa=bool(
|
|
current_q_num_heads != current_kv_num_heads
|
|
), # Ensure enable_gqa is not SymBool
|
|
)
|
|
|
|
qk_output = _get_qk_output_for_aten_spda(
|
|
Q,
|
|
K,
|
|
current_q_num_heads,
|
|
current_kv_num_heads,
|
|
scale,
|
|
qk_matmul_output_mode,
|
|
)
|
|
else:
|
|
# Fallback to manual implementation for complex cases
|
|
|
|
# Handle Group Query Attention (GQA) and Multi-Query Attention (MQA)
|
|
if current_q_num_heads != current_kv_num_heads:
|
|
repeat_factor = current_q_num_heads // current_kv_num_heads
|
|
K = K.repeat_interleave(repeat_factor, dim=num_head_dim)
|
|
V = V.repeat_interleave(repeat_factor, dim=num_head_dim)
|
|
|
|
# Create attention bias
|
|
attn_bias = torch.zeros(
|
|
q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device
|
|
)
|
|
|
|
# Apply causal masking
|
|
if is_causal:
|
|
torch._check(
|
|
attn_mask is None, lambda: "Cannot use both is_causal and attn_mask"
|
|
)
|
|
causal_mask = torch.tril(
|
|
torch.ones(
|
|
q_sequence_length,
|
|
kv_sequence_length,
|
|
dtype=torch.bool,
|
|
device=Q.device,
|
|
)
|
|
)
|
|
attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))
|
|
|
|
# Apply attention mask
|
|
if attn_mask is not None:
|
|
if attn_mask.dtype == torch.bool:
|
|
# Boolean mask: True means participate in attention
|
|
attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
|
|
else:
|
|
# Float mask: added to attention scores
|
|
attn_bias = attn_bias + attn_mask
|
|
|
|
# Apply scaling factor
|
|
scale_factor = _get_scale_factor(scale, Q.shape[3])
|
|
|
|
# Scale both Q and K by sqrt(scale_factor) for numerical stability
|
|
sqrt_scale = math.sqrt(scale_factor)
|
|
Q_scaled = Q * sqrt_scale
|
|
K_scaled = K * sqrt_scale
|
|
|
|
# Compute Q @ K^T
|
|
qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
|
|
|
|
# Initialize QK output based on mode
|
|
qk_output = qk_matmul_output # Default case for mode 0
|
|
|
|
# Add attention bias
|
|
qk_with_bias = qk_matmul_output + attn_bias
|
|
|
|
if qk_matmul_output_mode == 1:
|
|
qk_output = qk_with_bias
|
|
|
|
# Apply softcap if provided
|
|
if softcap > 0.0:
|
|
qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap)
|
|
|
|
if qk_matmul_output_mode == 2:
|
|
qk_output = qk_with_bias
|
|
|
|
# Apply softmax with optional precision casting
|
|
if softmax_precision is not None:
|
|
# Map ONNX data type to torch dtype
|
|
if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS:
|
|
original_dtype = qk_with_bias.dtype
|
|
qk_with_bias = qk_with_bias.to(
|
|
_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision]
|
|
)
|
|
qk_softmax = torch.softmax(qk_with_bias, dim=-1)
|
|
qk_softmax = qk_softmax.to(original_dtype)
|
|
else:
|
|
qk_softmax = torch.softmax(qk_with_bias, dim=-1)
|
|
else:
|
|
qk_softmax = torch.softmax(qk_with_bias, dim=-1)
|
|
|
|
if qk_matmul_output_mode == 3:
|
|
qk_output = qk_softmax
|
|
|
|
# Compute attention output
|
|
output = torch.matmul(qk_softmax, V)
|
|
|
|
# Reshape output back to 3D if input was 3D
|
|
if input_shape_len == 3:
|
|
# output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size)
|
|
output = (
|
|
output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1)
|
|
)
|
|
|
|
return output, present_key, present_value, qk_output
|
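

# A minimal usage sketch with hypothetical shapes (importing this module registers
# the op): causal self-attention over 4D inputs. The op returns
# (output, present_key, present_value, qk_output).
#
#     batch, n_heads, seq, head = 2, 4, 8, 64
#     q = torch.randn(batch, n_heads, seq, head)
#     k = torch.randn(batch, n_heads, seq, head)
#     v = torch.randn(batch, n_heads, seq, head)
#     out, present_k, present_v, qk = torch.ops.onnx.Attention.opset23(
#         q, k, v, is_causal=True
#     )
#     assert out.shape == (batch, n_heads, seq, head)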