[ONNX] Implement Attention-23 (#156431)

Implement Attention-23 using sdpa and flexattention.

- I used Copilot for this.
- Also updated the conversion logic to remove trailing None inputs.
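A rough usage sketch of the new op, mirroring what the added tests exercise (shapes follow the test constants; the `dynamo=True` export call is an assumption about how users would drive the exporter):

```py
import torch


class Model(torch.nn.Module):
    def forward(self, Q, K, V):
        # Attention-23 returns (output, present_key, present_value, qk_output)
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
        return output


batch, num_heads, q_len, kv_len, head_size = 2, 8, 4, 6, 64
Q = torch.rand(batch, num_heads, q_len, head_size)
K = torch.rand(batch, num_heads, kv_len, head_size)
V = torch.rand(batch, num_heads, kv_len, head_size)

model = Model()
eager_out = model(Q, K, V)  # eager mode is backed by SDPA where possible

# Exported at opset 23, the call lowers to a single ONNX Attention node
onnx_program = torch.onnx.export(model, (Q, K, V), dynamo=True, opset_version=23)
```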

@gramalingam @kunal-vaishnavi @titaiwangms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156431
Approved by: https://github.com/titaiwangms

Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Justin Chu
2025-06-20 23:54:57 +00:00
committed by PyTorch MergeBot
parent 0ad88a2224
commit fbbab794ef
7 changed files with 1410 additions and 48 deletions


@ -15,7 +15,6 @@ inside an ``if torch.onnx.is_in_onnx_export`` block.
.. autofunction:: torch.onnx.ops.symbolic_multi_out
```
## ONNX Operators
The following operators are implemented as native PyTorch ops and can be exported as
@ -23,7 +22,7 @@ ONNX operators. They can be used natively in an ``nn.Module``.
For example, you can define a module:
```py
class Model(torch.nn.Module):
def forward(
self, input_data, cos_cache_data, sin_cache_data, position_ids_data
@ -38,7 +37,7 @@ class Model(torch.nn.Module):
and export it to ONNX using:
```py
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
sin_cache_data = torch.rand(50, 4)
@ -84,7 +83,8 @@ graph(
with the corresponding ``ExportedProgram``:
ExportedProgram:
```py
class GraphModule(torch.nn.Module):
def forward(self, input_data: "f32[s0, 3, 4, 8]", cos_cache_data: "f32[50, 4]", sin_cache_data: "f32[50, 4]", position_ids_data: "i64[s0, 3]"):
rotary_embedding: "f32[s0, 3, 4, 8]" = torch.ops.onnx.RotaryEmbedding.opset23(input_data, cos_cache_data, sin_cache_data, position_ids_data); input_data = cos_cache_data = sin_cache_data = position_ids_data = None
@ -93,6 +93,7 @@ class GraphModule(torch.nn.Module):
```{eval-rst}
.. autofunction:: torch.onnx.ops.rotary_embedding
.. autofunction:: torch.onnx.ops.attention
```
## ONNX to ATen Decomposition Table
@ -100,7 +101,7 @@ class GraphModule(torch.nn.Module):
You can use {func}`torch.onnx.ops.aten_decompositions` to obtain a decomposition table
to decompose ONNX operators defined above to ATen operators.
```py
class Model(torch.nn.Module):
def forward(
self, input_data, cos_cache_data, sin_cache_data, position_ids_data

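Continuing the rotary-embedding example above, the decomposition table is applied to an exported program roughly like this (a sketch following the pattern used in the new tests; `model` is assumed to be an instance of the `Model` class above, and the inputs are the ones defined in the example):

```py
ep = torch.export.export(
    model, (input_data, cos_cache_data, sin_cache_data, position_ids_data)
)
# Decompose ONNX ops such as onnx.RotaryEmbedding.opset23 into ATen ops
ep_aten = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
```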

@ -3,10 +3,11 @@
from __future__ import annotations
import onnx_ir.passes.common as common_passes
from onnxscript import ir
import torch
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils
@ -426,6 +427,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):
**options,
)
assert onnx_program is not None
common_passes.CheckerPass()(onnx_program.model)
return onnx_program
def test_onnx_ops_can_be_decomposed_to_aten(self):
@ -469,6 +471,17 @@ class NativeOnnxOpsTest(common_utils.TestCase):
model(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
def test_rotary_embedding_opcheck(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
torch.library.opcheck(
_impl.rotary_embedding_23,
(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
def test_rotary_embedding(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
@ -512,7 +525,928 @@ class NativeOnnxOpsTest(common_utils.TestCase):
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
print(onnx_program)
def test_attention_basic(self):
"""Test basic attention functionality."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test eager mode
torch.library.opcheck(_impl.attention_23, (Q, K, V))
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V
)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
self.assertEqual(present_key.shape, K.shape)
self.assertEqual(present_value.shape, V.shape)
self.assertEqual(
qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
)
def test_attention_3d_inputs(self):
"""Test attention with 3D inputs (requires num_heads parameters)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
torch.library.opcheck(
_impl.attention_23,
(Q, K, V),
dict(q_num_heads=q_num_heads, kv_num_heads=kv_num_heads),
)
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
)
# Output should be reshaped back to 3D
self.assertEqual(output.shape, (batch_size, q_seq_len, q_num_heads * head_size))
self.assertEqual(
present_key.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
)
self.assertEqual(
present_value.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
)
def test_attention_gqa(self):
"""Test Group Query Attention (GQA)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA: q_num_heads % kv_num_heads == 0
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V))
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V
)
expected = torch.nn.functional.scaled_dot_product_attention(
Q, K, V, None, enable_gqa=True
)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
self.assertEqual(present_key.shape, K.shape)
self.assertEqual(present_value.shape, V.shape)
torch.testing.assert_close(output, expected)
def test_attention_mqa(self):
"""Test Multi-Query Attention (MQA)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 1 # MQA: kv_num_heads = 1
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V))
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V
)
expected = torch.nn.functional.scaled_dot_product_attention(
Q, K, V, None, enable_gqa=True
)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
torch.testing.assert_close(output, expected)
def test_attention_with_2d_mask(self):
"""Test attention with 2D attention mask (q_seq_len, kv_seq_len)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test with boolean mask
bool_mask = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
# Test with float mask
float_mask = torch.randn(q_seq_len, kv_seq_len)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
self.assertEqual(
output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
def test_attention_with_4d_mask(self):
"""Test attention with 4D attention mask (batch_size, num_heads, q_seq_len, kv_seq_len)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test with boolean mask
bool_mask = torch.randint(
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
# Test with float mask
float_mask = torch.randn(batch_size, q_num_heads, q_seq_len, kv_seq_len)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
self.assertEqual(
output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
def test_attention_with_zero_float_mask(self):
"""Test attention with zero float mask."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
zero_mask = torch.zeros(q_seq_len, kv_seq_len)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=zero_mask))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=zero_mask)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_with_causal_mask_pattern(self):
"""Test attention with lower triangular causal mask pattern."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Create a lower triangular causal mask
causal_mask = torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool))
torch.library.opcheck(
_impl.attention_23, (Q, K, V), dict(attn_mask=causal_mask)
)
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=causal_mask)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_with_gqa_and_mask(self):
"""Test attention with GQA and different mask shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test 2D mask with GQA
mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_2d))
output_2d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_2d)
# Test 4D mask with GQA (note: using q_num_heads for mask heads)
mask_4d = torch.randint(
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_4d))
output_4d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_4d)
self.assertEqual(
output_2d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
output_4d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
def test_attention_with_large_negative_float_mask(self):
"""Test attention with large negative values in float mask."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Create mask with large negative values (similar to -inf masking)
float_mask = torch.full((q_seq_len, kv_seq_len), -1e9)
# Allow some positions
float_mask[:, :3] = 0.0
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_causal(self):
"""Test causal attention."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(is_causal=True))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, is_causal=True)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_with_past_kv(self):
"""Test attention with past key/value caches."""
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
torch.library.opcheck(
_impl.attention_23,
(Q, K, V),
dict(past_key=past_key, past_value=past_value),
)
output, present_key, present_value, _ = torch.onnx.ops.attention(
Q, K, V, past_key=past_key, past_value=past_value
)
# Present key/value should include past + current
expected_total_seq_len = past_seq_len + kv_seq_len
self.assertEqual(
present_key.shape,
(batch_size, kv_num_heads, expected_total_seq_len, head_size),
)
self.assertEqual(
present_value.shape,
(batch_size, kv_num_heads, expected_total_seq_len, head_size),
)
def test_attention_with_softcap(self):
"""Test attention with softcap."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(softcap=30.0))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, softcap=30.0)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_qk_output_modes(self):
"""Test different QK matmul output modes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
for mode in [0, 1, 2, 3]:
torch.library.opcheck(
_impl.attention_23,
(Q, K, V),
dict(qk_matmul_output_mode=mode),
)
output, _, _, qk_output = torch.onnx.ops.attention(
Q, K, V, qk_matmul_output_mode=mode
)
self.assertEqual(
output.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
)
def test_attention_custom_scale(self):
"""Test attention with custom scale factor."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
custom_scale = 0.25
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(scale=custom_scale))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, scale=custom_scale)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_export(self):
"""Test that attention can be exported to ONNX."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, present_key, present_value, qk_output = (
torch.onnx.ops.attention(Q, K, V)
)
return output
model = AttentionModel()
onnx_program = self.export(
model,
(Q, K, V),
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
def test_attention_export_with_dynamic_shapes(self):
"""Test attention export with dynamic shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, present_key, present_value, qk_output = (
torch.onnx.ops.attention(Q, K, V)
)
return output
model = AttentionModel()
dynamic_shapes = {
"Q": {0: "batch", 2: "q_seq_len"},
"K": {0: "batch", 2: "kv_seq_len"},
"V": {0: "batch", 2: "kv_seq_len"},
}
onnx_program = self.export(
model,
(Q, K, V),
dynamic_shapes=dynamic_shapes,
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
node = onnx_program.model.graph.node(0)
# Verify inputs
self.assertEqual(len(node.inputs), 3) # Q, K, V (no optional inputs)
self.assertEqual(
node.inputs[0].shape, ["batch", q_num_heads, "q_seq_len", head_size]
)
self.assertEqual(
node.inputs[1].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
)
self.assertEqual(
node.inputs[2].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
)
# Verify default attributes (should be minimal)
self.assertEqual(len(node.attributes), 0)
def test_attention_3d_export(self):
"""Test attention export with 3D inputs."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
)
return output
model = AttentionModel()
onnx_program = self.export(
model,
(Q, K, V),
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
def test_attention_decomposition(self):
"""Test that attention can be decomposed to aten ops."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, present_key, present_value, qk_output = (
torch.onnx.ops.attention(Q, K, V)
)
return output
model = AttentionModel()
ep = torch.export.export(model, (Q, K, V))
self.assertIn(
"onnx.Attention.opset23",
[str(node.target) for node in ep.graph.nodes],
)
# The program can be decomposed into aten ops
aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
self.assertNotIn(
"onnx.Attention.opset23",
[str(node.target) for node in aten_decomped.graph.nodes],
)
# Results should match
torch.testing.assert_close(
aten_decomped.module()(Q, K, V),
model(Q, K, V),
)
def test_attention_export_with_past_key_value(self):
"""Test export with past_key, past_value to ensure the optional input order is correct."""
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
class Model(torch.nn.Module):
def forward(self, Q, K, V, past_key, past_value):
output, _, _, _ = torch.onnx.ops.attention(
Q,
K,
V,
past_key=past_key,
attn_mask=None,
# Switched argument order
past_value=past_value,
)
return output
model = Model()
onnx_program = self.export(
model, (Q, K, V, past_key, past_value), opset_version=23
)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all 6 inputs are present
self.assertEqual(
len(node.inputs), 6
) # Q, K, V, attn_mask, past_key, past_value
self.assertEqual(
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertIsNone(node.inputs[3])
self.assertEqual(
node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
self.assertEqual(
node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
def test_attention_export_with_all_optional_inputs(self):
"""Test export with all optional inputs: mask, past_key, past_value."""
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
attn_mask = torch.randint(
0, 2, (1, 1, q_seq_len, kv_seq_len + past_seq_len), dtype=torch.bool
)
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
class FullAttentionModel(torch.nn.Module):
def forward(self, Q, K, V, attn_mask, past_key, past_value):
output, _, _, _ = torch.onnx.ops.attention(
Q,
K,
V,
attn_mask=attn_mask,
past_key=past_key,
past_value=past_value,
)
return output
model = FullAttentionModel()
onnx_program = self.export(
model, (Q, K, V, attn_mask, past_key, past_value), opset_version=23
)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all 6 inputs are present
self.assertEqual(
len(node.inputs), 6
) # Q, K, V, attn_mask, past_key, past_value
self.assertEqual(
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertEqual(
node.inputs[3].shape, [1, 1, q_seq_len, kv_seq_len + past_seq_len]
)
self.assertEqual(
node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
self.assertEqual(
node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
def test_attention_export_3d_with_num_heads_attributes(self):
"""Test export with 3D inputs and explicit num_heads attributes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA
head_size = 64
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
class Attention3DModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
)
return output
model = Attention3DModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify 3D input shapes
self.assertEqual(
node.inputs[0].shape, [batch_size, q_seq_len, q_num_heads * head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
)
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
)
# Verify num_heads attributes are set
attrs = node.attributes
self.assertIn("q_num_heads", attrs)
self.assertIn("kv_num_heads", attrs)
self.assertEqual(attrs["q_num_heads"].value, q_num_heads)
self.assertEqual(attrs["kv_num_heads"].value, kv_num_heads)
def test_attention_export_with_all_attributes(self):
"""Test export with all possible attributes set."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class FullAttributesModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q,
K,
V,
is_causal=True,
qk_matmul_output_mode=2,
scale=0.25,
softcap=30.0,
softmax_precision=1, # FLOAT
)
return output
model = FullAttributesModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all attributes are set correctly
attrs = node.attributes
self.assertIn("is_causal", attrs)
self.assertIn("qk_matmul_output_mode", attrs)
self.assertIn("scale", attrs)
self.assertIn("softcap", attrs)
self.assertIn("softmax_precision", attrs)
self.assertEqual(attrs["is_causal"].value, 1) # True as int
self.assertEqual(attrs["qk_matmul_output_mode"].value, 2)
self.assertAlmostEqual(attrs["scale"].value, 0.25, places=6)
self.assertAlmostEqual(attrs["softcap"].value, 30.0, places=6)
self.assertEqual(attrs["softmax_precision"].value, 1)
def test_attention_export_with_different_mask_shapes(self):
"""Test export with different attention mask shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test 2D mask
mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
class Mask2DModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model_2d = Mask2DModel()
onnx_program_2d = self.export(model_2d, (Q, K, V, mask_2d), opset_version=23)
node_2d = onnx_program_2d.model.graph.node(0)
self.assertEqual(node_2d.inputs[3].shape, [q_seq_len, kv_seq_len])
# Test 3D mask
mask_3d = torch.randint(
0, 2, (batch_size, 1, q_seq_len, kv_seq_len), dtype=torch.bool
)
class Mask3DModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model_3d = Mask3DModel()
onnx_program_3d = self.export(model_3d, (Q, K, V, mask_3d), opset_version=23)
node_3d = onnx_program_3d.model.graph.node(0)
self.assertEqual(
node_3d.inputs[3].shape, [batch_size, 1, q_seq_len, kv_seq_len]
)
# Test 4D mask
mask_4d = torch.randint(
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
)
class Mask4DModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model_4d = Mask4DModel()
onnx_program_4d = self.export(model_4d, (Q, K, V, mask_4d), opset_version=23)
node_4d = onnx_program_4d.model.graph.node(0)
self.assertEqual(
node_4d.inputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
)
def test_attention_export_with_float_mask(self):
"""Test export with float attention mask."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
float_mask = torch.randn(q_seq_len, kv_seq_len)
class FloatMaskModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model = FloatMaskModel()
onnx_program = self.export(model, (Q, K, V, float_mask), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
self.assertEqual(node.inputs[3].shape, [q_seq_len, kv_seq_len])
# Verify the mask input has float dtype in the ONNX model
self.assertEqual(node.inputs[3].dtype, ir.DataType.FLOAT)
def test_attention_export_qk_output_modes(self):
"""Test export with different QK output modes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
for mode in [0, 1, 2, 3]:
class QKOutputModel(torch.nn.Module):
def __init__(self, qk_mode):
super().__init__()
self.qk_mode = qk_mode
def forward(self, Q, K, V):
output, _, _, qk_output = torch.onnx.ops.attention(
Q, K, V, qk_matmul_output_mode=self.qk_mode
)
return output, qk_output
model = QKOutputModel(mode)
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify qk_matmul_output_mode attribute
attrs = node.attributes
if mode != 0:
self.assertIn("qk_matmul_output_mode", attrs)
self.assertEqual(attrs["qk_matmul_output_mode"].value, mode)
# Verify 4 outputs (output, present_key, present_value, qk_output)
self.assertEqual(len(node.outputs), 4)
def test_attention_export_mqa(self):
"""Test export with Multi-Query Attention (MQA)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 1 # MQA
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class MQAModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
return output
model = MQAModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify MQA tensor shapes
self.assertEqual(
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
) # kv_num_heads = 1
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
def test_attention_export_with_softmax_precision(self):
"""Test export with different softmax precision values."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test different ONNX precision types
precision_types = [
(1, "FLOAT"),
(10, "FLOAT16"),
(11, "DOUBLE"),
(16, "BFLOAT16"),
]
for precision_val, precision_name in precision_types:
class SoftmaxPrecisionModel(torch.nn.Module):
def __init__(self, precision):
super().__init__()
self.precision = precision
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q, K, V, softmax_precision=self.precision
)
return output
model = SoftmaxPrecisionModel(precision_val)
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify softmax_precision attribute
attrs = node.attributes
self.assertIn("softmax_precision", attrs)
self.assertEqual(attrs["softmax_precision"].value, precision_val)
def test_attention_export_gqa(self):
"""Test export and verify output tensor shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionOutputsModel(torch.nn.Module):
def forward(self, Q, K, V):
return torch.onnx.ops.attention(Q, K, V)
model = AttentionOutputsModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all 4 outputs have correct shapes
outputs = node.outputs
self.assertEqual(len(outputs), 4)
# output: (batch_size, q_num_heads, q_seq_len, head_size)
self.assertEqual(
outputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
# present_key: (batch_size, kv_num_heads, kv_seq_len, head_size)
self.assertEqual(
outputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
# present_value: (batch_size, kv_num_heads, kv_seq_len, head_size)
self.assertEqual(
outputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
# qk_output: (batch_size, q_num_heads, q_seq_len, kv_seq_len)
self.assertEqual(
outputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
)
if __name__ == "__main__":


@ -552,6 +552,11 @@ def _handle_call_function_node_with_lowering(
if _is_onnx_op(node.target):
# Handle torch.ops.onnx.* ops. These ops can be directly added to the graph
op_type, opset_version = _parse_onnx_op(node.target) # type: ignore[arg-type]
# If final inputs are None, strip them from the node inputs
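# e.g. inputs (q, k, v, None, None) become (q, k, v); ONNX allows trailing optional inputs to be omitted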
for input_ in reversed(onnx_args):
if input_ is not None:
break
onnx_args.pop()
onnx_node = ir.Node(
"",
op_type,


@ -13,6 +13,7 @@ __all__ = [
"symbolic",
"symbolic_multi_out",
"rotary_embedding",
"attention",
]
@ -334,7 +335,7 @@ def rotary_embedding(
Returns:
Tensor with same shape as input.
"""
return _impl.rotary_embedding_23(
X,
cos_cache,
sin_cache,
@ -343,3 +344,124 @@ def rotary_embedding(
num_heads=num_heads,
rotary_embedding_dim=rotary_embedding_dim,
)
def attention(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
attn_mask: torch.Tensor | None = None,
past_key: torch.Tensor | None = None,
past_value: torch.Tensor | None = None,
*,
is_causal: bool = False,
kv_num_heads: int = 0,
q_num_heads: int = 0,
qk_matmul_output_mode: int = 0,
scale: float | None = None,
softcap: float = 0.0,
softmax_precision: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Attention op in ONNX.
https://onnx.ai/onnx/operators/onnx__Attention.html
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.
This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.
For self attention, ``kv_sequence_length`` equals ``q_sequence_length``.
For cross attention, query and key might have different lengths.
This operator also covers the following 3 variants based on the number of heads:
1. Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
2. Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
3. Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`.
Attention bias to be added is calculated based on the ``attn_mask`` input and the ``is_causal`` attribute, only one of which can be provided.
1. If ``is_causal`` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment.
2. `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score.
Both past and present state key/values are optional. They shall be used together; providing only one of them is not allowed.
The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided::
The following pattern is applied by this operator:
           Q          K          V
           |          |          |
    Q*sqrt(scale) K*sqrt(scale)  |
           |          |          |
           |       Transpose     |
           |          |          |
           ---MatMul---          |
                 |               |
        at_mask---Add            |
                 |               |
        softcap (if provided)    |
                 |               |
              Softmax            |
                 |               |
                 -----MatMul------
                        |
                        Y
Args:
Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, head_size)` or 3D tensor
with shape `(batch_size, q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor,
`q_hidden_size = q_num_heads * head_size`
K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, head_size)` or 3D tensor
with shape `(batch_size, kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor,
`k_hidden_size = kv_num_heads * head_size`
V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)` or 3D tensor
with shape `(batch_size, kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor,
`v_hidden_size = kv_num_heads * v_head_size`
attn_mask: Attention mask. Shape must be broadcastable to 4D tensor with shape
`(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where
`total_sequence_length = past_sequence_length + kv_sequence_length`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element should take part in attention.
Also supports a float mask of the same type as query, key, value that is added to the attention score.
past_key: Past state cache for key with shape `(batch_size, kv_num_heads, past_sequence_length, head_size)`
past_value: Past state cache for value with shape `(batch_size, kv_num_heads, past_sequence_length, v_head_size)`
is_causal: If set to True, the attention masking is a lower triangular matrix when the mask is a square matrix.
The attention masking has the form of the upper left causal bias due to the alignment.
kv_num_heads: Number of heads of key and value. Must be used with 3D inputs of Q, K and V.
q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K and V.
qk_matmul_output_mode: If set to 0, qk_matmul_output is the output of qk matmul. If set to 1,
qk_matmul_output includes the addition of the attention mask to the output of qk matmul.
If set to 2, qk_matmul_output is the output after the softcap operation. If set to 3,
qk_matmul_output is the output after the softmax operation. Default value is 0.
scale: Scaling factor applied to Q*K^T. Default value is 1/sqrt(head_size). To prevent numerical overflow,
scale Q, K by sqrt(scale) before matmul.
softcap: Softcap value for attention weights. Default value is 0.
softmax_precision: The floating-point precision used in softmax computation. If softmax precision is not provided,
the same precision as the input of softmax (Q and K) is used.
Returns:
A tuple containing:
- The output tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, v_head_size)` or 3D tensor
with shape `(batch_size, q_sequence_length, hidden_size)`. For cases with a 3D input tensor,
`hidden_size = q_num_heads * v_head_size`
- Updated key cache with shape `(batch_size, kv_num_heads, total_sequence_length, head_size)` where
`total_sequence_length = past_sequence_length + kv_sequence_length`.
- Updated value cache with shape `(batch_size, kv_num_heads, total_sequence_length, v_head_size)` where
`total_sequence_length = past_sequence_length + kv_sequence_length`.
- The output of QK matmul. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)`
where `total_sequence_length = past_sequence_length + kv_sequence_length`.
"""
return _impl.attention_23(
Q,
K,
V,
attn_mask=attn_mask,
past_key=past_key,
past_value=past_value,
is_causal=is_causal,
kv_num_heads=kv_num_heads,
q_num_heads=q_num_heads,
qk_matmul_output_mode=qk_matmul_output_mode,
scale=scale,
softcap=softcap,
softmax_precision=softmax_precision,
)
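A minimal sketch of the 3D packed-input GQA case documented above (shapes are borrowed from the new tests):

```py
import torch

batch, q_len, kv_len, head_size = 2, 4, 6, 64
q_heads, kv_heads = 8, 4  # GQA: q_heads % kv_heads == 0

# 3D packed inputs: (batch, seq_len, num_heads * head_size)
Q = torch.rand(batch, q_len, q_heads * head_size)
K = torch.rand(batch, kv_len, kv_heads * head_size)
V = torch.rand(batch, kv_len, kv_heads * head_size)

output, present_key, present_value, qk = torch.onnx.ops.attention(
    Q, K, V, q_num_heads=q_heads, kv_num_heads=kv_heads
)
# output is returned in 3D: (batch, q_len, q_heads * head_size)
# present_key / present_value are 4D: (batch, kv_heads, kv_len, head_size)
# qk is (batch, q_heads, q_len, kv_len)
```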


@ -0,0 +1,27 @@
import torch
ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
1: torch.float32, # FLOAT
2: torch.uint8, # UINT8
3: torch.int8, # INT8
4: torch.uint16, # UINT16
5: torch.int16, # INT16
6: torch.int32, # INT32
7: torch.int64, # INT64
9: torch.bool, # BOOL
10: torch.float16, # FLOAT16
11: torch.double, # DOUBLE
12: torch.uint32, # UINT32
13: torch.uint64, # UINT64
14: torch.complex64, # COMPLEX64
15: torch.complex128, # COMPLEX128
16: torch.bfloat16, # BFLOAT16
17: torch.float8_e4m3fn, # FLOAT8E4M3FN
18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ
19: torch.float8_e5m2, # FLOAT8E5M2
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
21: torch.uint8, # UINT4
22: torch.uint8, # INT4
23: torch.float4_e2m1fn_x2, # FLOAT4E2M1
}
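For context, this table is consumed elsewhere in the change roughly as follows (a sketch; the `torch._check` validation mirrors the updated `_symbolic_impl.py` further down):

```py
import torch
from torch.onnx.ops import _dtype_mappings

onnx_dtype = 10  # ONNX FLOAT16
torch._check(
    onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
    lambda: f"{onnx_dtype} is not a recognized ONNX data type",
)
torch_dtype = _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]  # torch.float16
```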


@ -1,13 +1,24 @@
# flake8: noqa: B950
import math
import typing
from typing import Callable, Optional
import torch
from torch.onnx.ops import _dtype_mappings
_T = typing.TypeVar("_T", bound=Callable)
# ONNX to ATen decomp table
ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
_ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
{
1, # FLOAT
10, # FLOAT16
11, # DOUBLE
16, # BFLOAT16
}
)
def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]:
@ -30,7 +41,7 @@ def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]:
@_onnx_op("RotaryEmbedding", 23)
def rotary_embedding_23(
x: torch.Tensor,
cos_cache: torch.Tensor,
sin_cache: torch.Tensor,
@ -46,11 +57,14 @@ def rotary_embedding(
sequence_length = x.shape[1]
if len(x.shape) == 3:
hidden_size = x.shape[2]
torch._check(
num_heads != 0,
lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}",
)
head_size = hidden_size // num_heads
new_shape = [batch_size, sequence_length, num_heads, head_size]
x = torch.reshape(x, new_shape)
torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
head_size = x.shape[3]
# Fully or partially perform rotation on x based on rotary_embedding_dim attribute
@ -110,3 +124,273 @@ def rotary_embedding(
if len(x.shape) == 3:
output = torch.reshape(output, x.shape)
return output
def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
"""Get the scale factor for attention computation."""
return scale if scale is not None else (1.0 / math.sqrt(head_size))
def _reshape_3d_to_4d(
tensor: torch.Tensor, batch_size: int, num_heads: int
) -> torch.Tensor:
"""Reshape 3D tensor to 4D for multi-head attention."""
sequence_length, hidden_size = tensor.shape[1], tensor.shape[2]
head_size = hidden_size // num_heads
return (
tensor.view(batch_size, sequence_length, num_heads, head_size)
.transpose(1, 2)
.contiguous()
)
def _get_qk_output_for_aten_spda(
Q: torch.Tensor,
K: torch.Tensor,
current_q_num_heads: int,
current_kv_num_heads: int,
scale: Optional[float],
qk_matmul_output_mode: int,
) -> torch.Tensor:
"""Get QK output tensor based on the specified mode."""
if qk_matmul_output_mode == 0:
return _compute_qk_output_for_mode_0(
Q, K, current_q_num_heads, current_kv_num_heads, scale
)
else:
# For other modes, return a zero tensor with correct shape
return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1)))
def _validate_gqa_configuration(
current_q_num_heads: int, current_kv_num_heads: int
) -> None:
"""Validate Group Query Attention configuration."""
torch._check(
current_q_num_heads % current_kv_num_heads == 0,
lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA",
)
def _compute_qk_output_for_mode_0(
Q: torch.Tensor,
K: torch.Tensor,
current_q_num_heads: int,
current_kv_num_heads: int,
scale: Optional[float],
) -> torch.Tensor:
"""Helper function to compute QK output for qk_matmul_output_mode == 0."""
# Handle GQA manually for QK output
K_for_qk = K
if current_q_num_heads != current_kv_num_heads:
repeat_factor = current_q_num_heads // current_kv_num_heads
K_for_qk = K.repeat_interleave(repeat_factor, dim=1)
scale_factor = _get_scale_factor(scale, Q.shape[3])
# Scale both Q and K by sqrt(scale_factor) for numerical stability
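# (Q * sqrt(s)) @ (K * sqrt(s)).T == s * (Q @ K.T), so the product is unchanged while intermediates stay smaller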
sqrt_scale = math.sqrt(scale_factor)
Q_scaled = Q * sqrt_scale
K_scaled = K_for_qk * sqrt_scale
return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
@_onnx_op("Attention", 23)
def attention_23(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
past_key: Optional[torch.Tensor] = None,
past_value: Optional[torch.Tensor] = None,
*,
is_causal: bool = False,
kv_num_heads: int = 0,
q_num_heads: int = 0,
qk_matmul_output_mode: int = 0,
scale: Optional[float] = None,
softcap: float = 0.0,
softmax_precision: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23"""
num_head_dim, sequence_dim, head_dim = 1, 2, 3
# Store original input shape to determine output shape
input_shape_len = len(Q.shape)
batch_size = Q.shape[0]
# Reshape 3D inputs to 4D format
if len(Q.shape) == 3:
torch._check(
q_num_heads != 0 and kv_num_heads != 0,
lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs",
)
q_sequence_length = Q.shape[1]
Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads)
K = _reshape_3d_to_4d(K, batch_size, kv_num_heads)
V = _reshape_3d_to_4d(V, batch_size, kv_num_heads)
torch._check(
len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4,
lambda: "Q, K, and V should be 4D tensors by now",
)
# Calculate scale factor if not provided
q_head_size = Q.shape[head_dim]
scale = _get_scale_factor(scale, q_head_size)
# Handle past key/value caches
present_key = (
torch.cat([past_key, K], dim=sequence_dim)
if past_key is not None
else K.clone()
)
present_value = (
torch.cat([past_value, V], dim=sequence_dim)
if past_value is not None
else V.clone()
)
# Update K and V to include past states
K, V = present_key, present_value
# Get current dimensions
current_q_num_heads = Q.shape[num_head_dim]
current_kv_num_heads = K.shape[num_head_dim]
q_sequence_length = Q.shape[sequence_dim]
kv_sequence_length = K.shape[sequence_dim]
# Check if we can use the optimized scaled_dot_product_attention (most optimized)
can_use_sdpa = (
softcap == 0.0 # No softcap
and qk_matmul_output_mode == 0 # Default QK output mode
and softmax_precision is None # No custom softmax precision
and (attn_mask is None or attn_mask.dtype == torch.bool)
)
_validate_gqa_configuration(current_q_num_heads, current_kv_num_heads)
if can_use_sdpa:
# Use PyTorch's optimized scaled_dot_product_attention
# Prepare attention mask for SDPA
sdpa_attn_mask = None
if attn_mask is not None:
# Convert boolean mask: True means participate, SDPA expects True to mask out
sdpa_attn_mask = ~attn_mask if attn_mask.dtype == torch.bool else attn_mask
output = torch.nn.functional.scaled_dot_product_attention(
Q,
K,
V,
attn_mask=sdpa_attn_mask,
dropout_p=0.0,
is_causal=is_causal,
scale=scale,
enable_gqa=bool(
current_q_num_heads != current_kv_num_heads
), # Ensure enable_gqa is not SymBool
)
qk_output = _get_qk_output_for_aten_spda(
Q,
K,
current_q_num_heads,
current_kv_num_heads,
scale,
qk_matmul_output_mode,
)
else:
# Fallback to manual implementation for complex cases
# Handle Group Query Attention (GQA) and Multi-Query Attention (MQA)
if current_q_num_heads != current_kv_num_heads:
repeat_factor = current_q_num_heads // current_kv_num_heads
K = K.repeat_interleave(repeat_factor, dim=num_head_dim)
V = V.repeat_interleave(repeat_factor, dim=num_head_dim)
# Create attention bias
attn_bias = torch.zeros(
q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device
)
# Apply causal masking
if is_causal:
torch._check(
attn_mask is None, lambda: "Cannot use both is_causal and attn_mask"
)
causal_mask = torch.tril(
torch.ones(
q_sequence_length,
kv_sequence_length,
dtype=torch.bool,
device=Q.device,
)
)
attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))
# Apply attention mask
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
# Boolean mask: True means participate in attention
attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
else:
# Float mask: added to attention scores
attn_bias = attn_bias + attn_mask
# Apply scaling factor
scale_factor = _get_scale_factor(scale, Q.shape[3])
# Scale both Q and K by sqrt(scale_factor) for numerical stability
sqrt_scale = math.sqrt(scale_factor)
Q_scaled = Q * sqrt_scale
K_scaled = K * sqrt_scale
# Compute Q @ K^T
qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
# Initialize QK output based on mode
qk_output = qk_matmul_output # Default case for mode 0
# Add attention bias
qk_with_bias = qk_matmul_output + attn_bias
if qk_matmul_output_mode == 1:
qk_output = qk_with_bias
# Apply softcap if provided
if softcap > 0.0:
qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap)
if qk_matmul_output_mode == 2:
qk_output = qk_with_bias
# Apply softmax with optional precision casting
if softmax_precision is not None:
# Map ONNX data type to torch dtype
if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS:
original_dtype = qk_with_bias.dtype
qk_with_bias = qk_with_bias.to(
_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision]
)
qk_softmax = torch.softmax(qk_with_bias, dim=-1)
qk_softmax = qk_softmax.to(original_dtype)
else:
qk_softmax = torch.softmax(qk_with_bias, dim=-1)
else:
qk_softmax = torch.softmax(qk_with_bias, dim=-1)
if qk_matmul_output_mode == 3:
qk_output = qk_softmax
# Compute attention output
output = torch.matmul(qk_softmax, V)
# Reshape output back to 3D if input was 3D
if input_shape_len == 3:
# output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size)
output = (
output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1)
)
return output, present_key, present_value, qk_output


@ -11,38 +11,15 @@ zeros based on the input shape and dtype, and a "fake" implementation that does
or less the same thing but is required by the `torch.library.custom_op` interface.
"""
# flake8: noqa: B950
import dataclasses
from collections.abc import Sequence
from typing import Optional, Union
import torch
from torch.onnx.ops import _dtype_mappings
_INT_TYPE = "i"
_FLOAT_TYPE = "f"
_STRING_TYPE = "s"
@ -221,10 +198,12 @@ def _symbolic(
version: Optional[int] = None,
) -> torch.Tensor:
torch._check(
onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
return torch.zeros(
shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
)
@_symbolic.register_fake
@ -246,12 +225,14 @@ def _(
version: Optional[int] = None,
) -> torch.Tensor:
torch._check(
onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
# out how it can handle empty shapes
return torch.zeros(
shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
)
@torch.library.custom_op(
@ -289,10 +270,14 @@ def _symbolic_multi_out(
)
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
torch._check(
onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
outputs.append(
torch.zeros(
shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
)
)
return outputs
@ -321,10 +306,14 @@ def _(
)
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
torch._check(
onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
# out how it can handle empty shapes
outputs.append(
torch.zeros(
shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
)
)
return outputs