Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

[ONNX] Implement Attention-23 (#156431)

Implement Attention-23 using sdpa and flexattention.
- I used copilot for this.
- Also updated the conversion logic to remove trailing None inputs.

@gramalingam @kunal-vaishnavi @titaiwangms

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156431
Approved by: https://github.com/titaiwangms
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

Committed by: PyTorch MergeBot
Parent commit: 0ad88a2224
This commit: fbbab794ef
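At its core, the new `torch.onnx.ops.attention` op maps onto PyTorch's `scaled_dot_product_attention` whenever no softcap, non-default QK output mode, or float mask is involved (see `_impl.attention_23` in the diff below). The following sketch is illustrative only and not part of the commit; it checks the equivalence the implementation relies on, with shapes borrowed from the new tests and loose tolerances chosen as an assumption.

```py
import math

import torch
import torch.nn.functional as F


def naive_attention(Q, K, V, scale=None):
    # Reference formula: softmax(Q @ K^T * scale) @ V
    scale = scale if scale is not None else 1.0 / math.sqrt(Q.shape[-1])
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ V


Q = torch.rand(2, 8, 4, 64)  # (batch, q_num_heads, q_seq_len, head_size)
K = torch.rand(2, 8, 6, 64)  # (batch, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(2, 8, 6, 64)

expected = F.scaled_dot_product_attention(Q, K, V)
torch.testing.assert_close(naive_attention(Q, K, V), expected, rtol=1e-4, atol=1e-4)
```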
@@ -15,7 +15,6 @@ inside an ``if torch.onnx.is_in_onnx_export`` block.
.. autofunction:: torch.onnx.ops.symbolic_multi_out
```

## ONNX Operators

The following operators are implemented as native PyTorch ops and can be exported as
@@ -23,7 +22,7 @@ ONNX operators. They can be used natively in an ``nn.Module``.

For example, you can define a module:

```{code-block} python
```py
class Model(torch.nn.Module):
    def forward(
        self, input_data, cos_cache_data, sin_cache_data, position_ids_data
@@ -38,7 +37,7 @@ class Model(torch.nn.Module):

and export it to ONNX using:

```{code-block} python
```py
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
sin_cache_data = torch.rand(50, 4)
@@ -84,7 +83,8 @@ graph(
with the corresponding ``ExportedProgram``:

ExportedProgram:
```{code-block} python

```py
class GraphModule(torch.nn.Module):
    def forward(self, input_data: "f32[s0, 3, 4, 8]", cos_cache_data: "f32[50, 4]", sin_cache_data: "f32[50, 4]", position_ids_data: "i64[s0, 3]"):
        rotary_embedding: "f32[s0, 3, 4, 8]" = torch.ops.onnx.RotaryEmbedding.opset23(input_data, cos_cache_data, sin_cache_data, position_ids_data); input_data = cos_cache_data = sin_cache_data = position_ids_data = None
@@ -93,6 +93,7 @@ class GraphModule(torch.nn.Module):

```{eval-rst}
.. autofunction:: torch.onnx.ops.rotary_embedding
.. autofunction:: torch.onnx.ops.attention
```

## ONNX to ATen Decomposition Table
@@ -100,7 +101,7 @@ class GraphModule(torch.nn.Module):
You can use {func}`torch.onnx.ops.aten_decompositions` to obtain a decomposition table
to decompose ONNX operators defined above to ATen operators.

```{code-block} python
```py
class Model(torch.nn.Module):
    def forward(
        self, input_data, cos_cache_data, sin_cache_data, position_ids_data
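As a companion to the documentation changes above, here is a hedged end-to-end sketch (my own, modeled on the new tests rather than copied from the docs) that wraps `torch.onnx.ops.attention` in an `nn.Module` and exports it at opset 23; the exported graph is expected to contain a single ONNX `Attention` node, and `torch.onnx.ops.aten_decompositions()` can later decompose the exported program to ATen ops. It assumes `torch.onnx.export(..., dynamo=True)` is the entry point behind the tests' `self.export` helper.

```py
import torch


class AttentionModel(torch.nn.Module):
    def forward(self, Q, K, V):
        # torch.onnx.ops.attention returns (output, present_key, present_value, qk_output)
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
        return output


Q = torch.rand(2, 8, 4, 64)  # (batch, q_num_heads, q_seq_len, head_size)
K = torch.rand(2, 8, 6, 64)
V = torch.rand(2, 8, 6, 64)

onnx_program = torch.onnx.export(
    AttentionModel(), (Q, K, V), dynamo=True, opset_version=23
)
print(onnx_program.model.graph.node(0).op_type)  # expected: "Attention"
```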
@@ -3,10 +3,11 @@

from __future__ import annotations

import onnx_ir.passes.common as common_passes
from onnxscript import ir

import torch
from torch.onnx.ops import _symbolic_impl
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils


@@ -426,6 +427,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):
            **options,
        )
        assert onnx_program is not None
        common_passes.CheckerPass()(onnx_program.model)
        return onnx_program

    def test_onnx_ops_can_be_decomposed_to_aten(self):
@@ -469,6 +471,17 @@ class NativeOnnxOpsTest(common_utils.TestCase):
            model(input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )

    def test_rotary_embedding_opcheck(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 3)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

        torch.library.opcheck(
            _impl.rotary_embedding_23,
            (input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )

    def test_rotary_embedding(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 3)).long()
@ -512,7 +525,928 @@ class NativeOnnxOpsTest(common_utils.TestCase):
|
||||
)
|
||||
self.assertEqual(onnx_program.model.opset_imports[""], 23)
|
||||
self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
|
||||
print(onnx_program)
|
||||
|
||||
def test_attention_basic(self):
|
||||
"""Test basic attention functionality."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Test eager mode
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V))
|
||||
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
|
||||
Q, K, V
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
self.assertEqual(present_key.shape, K.shape)
|
||||
self.assertEqual(present_value.shape, V.shape)
|
||||
self.assertEqual(
|
||||
qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
|
||||
)
|
||||
|
||||
def test_attention_3d_inputs(self):
|
||||
"""Test attention with 3D inputs (requires num_heads parameters)."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
|
||||
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
|
||||
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
|
||||
|
||||
torch.library.opcheck(
|
||||
_impl.attention_23,
|
||||
(Q, K, V),
|
||||
dict(q_num_heads=q_num_heads, kv_num_heads=kv_num_heads),
|
||||
)
|
||||
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
|
||||
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
|
||||
)
|
||||
|
||||
# Output should be reshaped back to 3D
|
||||
self.assertEqual(output.shape, (batch_size, q_seq_len, q_num_heads * head_size))
|
||||
self.assertEqual(
|
||||
present_key.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
)
|
||||
self.assertEqual(
|
||||
present_value.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
)
|
||||
|
||||
def test_attention_gqa(self):
|
||||
"""Test Group Query Attention (GQA)."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 4 # GQA: q_num_heads % kv_num_heads = 0
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V))
|
||||
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
|
||||
Q, K, V
|
||||
)
|
||||
expected = torch.nn.functional.scaled_dot_product_attention(
|
||||
Q, K, V, None, enable_gqa=True
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
self.assertEqual(present_key.shape, K.shape)
|
||||
self.assertEqual(present_value.shape, V.shape)
|
||||
torch.testing.assert_close(output, expected)
|
||||
|
||||
def test_attention_mqa(self):
|
||||
"""Test Multi-Query Attention (MQA)."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 1 # MQA: kv_num_heads = 1
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V))
|
||||
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
|
||||
Q, K, V
|
||||
)
|
||||
expected = torch.nn.functional.scaled_dot_product_attention(
|
||||
Q, K, V, None, enable_gqa=True
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
torch.testing.assert_close(output, expected)
|
||||
|
||||
def test_attention_with_2d_mask(self):
|
||||
"""Test attention with 2D attention mask (q_seq_len, kv_seq_len)."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Test with boolean mask
|
||||
bool_mask = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
|
||||
output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
|
||||
|
||||
# Test with float mask
|
||||
float_mask = torch.randn(q_seq_len, kv_seq_len)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
|
||||
output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
|
||||
|
||||
self.assertEqual(
|
||||
output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
self.assertEqual(
|
||||
output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
|
||||
def test_attention_with_4d_mask(self):
|
||||
"""Test attention with 4D attention mask (batch_size, num_heads, q_seq_len, kv_seq_len)."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Test with boolean mask
|
||||
bool_mask = torch.randint(
|
||||
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
|
||||
)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
|
||||
output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
|
||||
|
||||
# Test with float mask
|
||||
float_mask = torch.randn(batch_size, q_num_heads, q_seq_len, kv_seq_len)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
|
||||
output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
|
||||
|
||||
self.assertEqual(
|
||||
output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
self.assertEqual(
|
||||
output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
|
||||
def test_attention_with_zero_float_mask(self):
|
||||
"""Test attention with zero float mask."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
zero_mask = torch.zeros(q_seq_len, kv_seq_len)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=zero_mask))
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=zero_mask)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
|
||||
def test_attention_with_causal_mask_pattern(self):
|
||||
"""Test attention with lower triangular causal mask pattern."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Create a lower triangular causal mask
|
||||
causal_mask = torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool))
|
||||
torch.library.opcheck(
|
||||
_impl.attention_23, (Q, K, V), dict(attn_mask=causal_mask)
|
||||
)
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=causal_mask)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
|
||||
def test_attention_with_gqa_and_mask(self):
|
||||
"""Test attention with GQA and different mask shapes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 4 # GQA
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Test 2D mask with GQA
|
||||
mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_2d))
|
||||
output_2d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_2d)
|
||||
|
||||
# Test 4D mask with GQA (note: using q_num_heads for mask heads)
|
||||
mask_4d = torch.randint(
|
||||
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
|
||||
)
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_4d))
|
||||
output_4d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_4d)
|
||||
|
||||
self.assertEqual(
|
||||
output_2d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
self.assertEqual(
|
||||
output_4d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
|
||||
def test_attention_with_large_negative_float_mask(self):
|
||||
"""Test attention with large negative values in float mask."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Create mask with large negative values (similar to -inf masking)
|
||||
float_mask = torch.full((q_seq_len, kv_seq_len), -1e9)
|
||||
# Allow some positions
|
||||
float_mask[:, :3] = 0.0
|
||||
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
|
||||
def test_attention_causal(self):
|
||||
"""Test causal attention."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(is_causal=True))
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, is_causal=True)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
|
||||
def test_attention_with_past_kv(self):
|
||||
"""Test attention with past key/value caches."""
|
||||
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
|
||||
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
|
||||
|
||||
torch.library.opcheck(
|
||||
_impl.attention_23,
|
||||
(Q, K, V),
|
||||
dict(past_key=past_key, past_value=past_value),
|
||||
)
|
||||
output, present_key, present_value, _ = torch.onnx.ops.attention(
|
||||
Q, K, V, past_key=past_key, past_value=past_value
|
||||
)
|
||||
|
||||
# Present key/value should include past + current
|
||||
expected_total_seq_len = past_seq_len + kv_seq_len
|
||||
self.assertEqual(
|
||||
present_key.shape,
|
||||
(batch_size, kv_num_heads, expected_total_seq_len, head_size),
|
||||
)
|
||||
self.assertEqual(
|
||||
present_value.shape,
|
||||
(batch_size, kv_num_heads, expected_total_seq_len, head_size),
|
||||
)
|
||||
|
||||
def test_attention_with_softcap(self):
|
||||
"""Test attention with softcap."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(softcap=30.0))
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, softcap=30.0)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
|
||||
def test_attention_qk_output_modes(self):
|
||||
"""Test different QK matmul output modes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
for mode in [0, 1, 2, 3]:
|
||||
torch.library.opcheck(
|
||||
_impl.attention_23,
|
||||
(Q, K, V),
|
||||
dict(qk_matmul_output_mode=mode),
|
||||
)
|
||||
output, _, _, qk_output = torch.onnx.ops.attention(
|
||||
Q, K, V, qk_matmul_output_mode=mode
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
output.shape, (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
)
|
||||
self.assertEqual(
|
||||
qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
|
||||
)
|
||||
|
||||
def test_attention_custom_scale(self):
|
||||
"""Test attention with custom scale factor."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
custom_scale = 0.25
|
||||
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(scale=custom_scale))
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, scale=custom_scale)
|
||||
|
||||
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
|
||||
|
||||
def test_attention_export(self):
|
||||
"""Test that attention can be exported to ONNX."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
class AttentionModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, present_key, present_value, qk_output = (
|
||||
torch.onnx.ops.attention(Q, K, V)
|
||||
)
|
||||
return output
|
||||
|
||||
model = AttentionModel()
|
||||
|
||||
onnx_program = self.export(
|
||||
model,
|
||||
(Q, K, V),
|
||||
opset_version=23,
|
||||
)
|
||||
|
||||
self.assertEqual(onnx_program.model.opset_imports[""], 23)
|
||||
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
|
||||
|
||||
def test_attention_export_with_dynamic_shapes(self):
|
||||
"""Test attention export with dynamic shapes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
class AttentionModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, present_key, present_value, qk_output = (
|
||||
torch.onnx.ops.attention(Q, K, V)
|
||||
)
|
||||
return output
|
||||
|
||||
model = AttentionModel()
|
||||
|
||||
dynamic_shapes = {
|
||||
"Q": {0: "batch", 2: "q_seq_len"},
|
||||
"K": {0: "batch", 2: "kv_seq_len"},
|
||||
"V": {0: "batch", 2: "kv_seq_len"},
|
||||
}
|
||||
|
||||
onnx_program = self.export(
|
||||
model,
|
||||
(Q, K, V),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
opset_version=23,
|
||||
)
|
||||
|
||||
self.assertEqual(onnx_program.model.opset_imports[""], 23)
|
||||
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
|
||||
node = onnx_program.model.graph.node(0)
|
||||
# Verify inputs
|
||||
self.assertEqual(len(node.inputs), 3) # Q, K, V (no optional inputs)
|
||||
self.assertEqual(
|
||||
node.inputs[0].shape, ["batch", q_num_heads, "q_seq_len", head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[1].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[2].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
|
||||
)
|
||||
|
||||
# Verify default attributes (should be minimal)
|
||||
self.assertEqual(len(node.attributes), 0)
|
||||
|
||||
def test_attention_3d_export(self):
|
||||
"""Test attention export with 3D inputs."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
|
||||
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
|
||||
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
|
||||
|
||||
class AttentionModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, _, _, _ = torch.onnx.ops.attention(
|
||||
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
|
||||
)
|
||||
return output
|
||||
|
||||
model = AttentionModel()
|
||||
|
||||
onnx_program = self.export(
|
||||
model,
|
||||
(Q, K, V),
|
||||
opset_version=23,
|
||||
)
|
||||
|
||||
self.assertEqual(onnx_program.model.opset_imports[""], 23)
|
||||
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
|
||||
|
||||
def test_attention_decomposition(self):
|
||||
"""Test that attention can be decomposed to aten ops."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
class AttentionModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, present_key, present_value, qk_output = (
|
||||
torch.onnx.ops.attention(Q, K, V)
|
||||
)
|
||||
return output
|
||||
|
||||
model = AttentionModel()
|
||||
|
||||
ep = torch.export.export(model, (Q, K, V))
|
||||
self.assertIn(
|
||||
"onnx.Attention.opset23",
|
||||
[str(node.target) for node in ep.graph.nodes],
|
||||
)
|
||||
|
||||
# The program can be decomposed into aten ops
|
||||
aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
|
||||
self.assertNotIn(
|
||||
"onnx.Attention.opset23",
|
||||
[str(node.target) for node in aten_decomped.graph.nodes],
|
||||
)
|
||||
|
||||
# Results should match
|
||||
torch.testing.assert_close(
|
||||
aten_decomped.module()(Q, K, V),
|
||||
model(Q, K, V),
|
||||
)
|
||||
|
||||
def test_attention_export_with_past_key_value(self):
|
||||
"""Test export with past_key, past_value to ensure the optional input order is correct."""
|
||||
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
|
||||
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, Q, K, V, past_key, past_value):
|
||||
output, _, _, _ = torch.onnx.ops.attention(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
past_key=past_key,
|
||||
attn_mask=None,
|
||||
# Switched argument order
|
||||
past_value=past_value,
|
||||
)
|
||||
return output
|
||||
|
||||
model = Model()
|
||||
onnx_program = self.export(
|
||||
model, (Q, K, V, past_key, past_value), opset_version=23
|
||||
)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify all 6 inputs are present
|
||||
self.assertEqual(
|
||||
len(node.inputs), 6
|
||||
) # Q, K, V, attn_mask, past_key, past_value
|
||||
self.assertEqual(
|
||||
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
self.assertIsNone(node.inputs[3])
|
||||
self.assertEqual(
|
||||
node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
|
||||
)
|
||||
|
||||
def test_attention_export_with_all_optional_inputs(self):
|
||||
"""Test export with all optional inputs: mask, past_key, past_value."""
|
||||
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
attn_mask = torch.randint(
|
||||
0, 2, (1, 1, q_seq_len, kv_seq_len + past_seq_len), dtype=torch.bool
|
||||
)
|
||||
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
|
||||
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
|
||||
|
||||
class FullAttentionModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V, attn_mask, past_key, past_value):
|
||||
output, _, _, _ = torch.onnx.ops.attention(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
attn_mask=attn_mask,
|
||||
past_key=past_key,
|
||||
past_value=past_value,
|
||||
)
|
||||
return output
|
||||
|
||||
model = FullAttentionModel()
|
||||
onnx_program = self.export(
|
||||
model, (Q, K, V, attn_mask, past_key, past_value), opset_version=23
|
||||
)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify all 6 inputs are present
|
||||
self.assertEqual(
|
||||
len(node.inputs), 6
|
||||
) # Q, K, V, attn_mask, past_key, past_value
|
||||
self.assertEqual(
|
||||
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[3].shape, [1, 1, q_seq_len, kv_seq_len + past_seq_len]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
|
||||
)
|
||||
|
||||
def test_attention_export_3d_with_num_heads_attributes(self):
|
||||
"""Test export with 3D inputs and explicit num_heads attributes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 4 # GQA
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
|
||||
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
|
||||
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
|
||||
|
||||
class Attention3DModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, _, _, _ = torch.onnx.ops.attention(
|
||||
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
|
||||
)
|
||||
return output
|
||||
|
||||
model = Attention3DModel()
|
||||
onnx_program = self.export(model, (Q, K, V), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify 3D input shapes
|
||||
self.assertEqual(
|
||||
node.inputs[0].shape, [batch_size, q_seq_len, q_num_heads * head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[1].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[2].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
|
||||
)
|
||||
|
||||
# Verify num_heads attributes are set
|
||||
attrs = node.attributes
|
||||
self.assertIn("q_num_heads", attrs)
|
||||
self.assertIn("kv_num_heads", attrs)
|
||||
self.assertEqual(attrs["q_num_heads"].value, q_num_heads)
|
||||
self.assertEqual(attrs["kv_num_heads"].value, kv_num_heads)
|
||||
|
||||
def test_attention_export_with_all_attributes(self):
|
||||
"""Test export with all possible attributes set."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
class FullAttributesModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, _, _, _ = torch.onnx.ops.attention(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
is_causal=True,
|
||||
qk_matmul_output_mode=2,
|
||||
scale=0.25,
|
||||
softcap=30.0,
|
||||
softmax_precision=1, # FLOAT
|
||||
)
|
||||
return output
|
||||
|
||||
model = FullAttributesModel()
|
||||
onnx_program = self.export(model, (Q, K, V), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify all attributes are set correctly
|
||||
attrs = node.attributes
|
||||
self.assertIn("is_causal", attrs)
|
||||
self.assertIn("qk_matmul_output_mode", attrs)
|
||||
self.assertIn("scale", attrs)
|
||||
self.assertIn("softcap", attrs)
|
||||
self.assertIn("softmax_precision", attrs)
|
||||
|
||||
self.assertEqual(attrs["is_causal"].value, 1) # True as int
|
||||
self.assertEqual(attrs["qk_matmul_output_mode"].value, 2)
|
||||
self.assertAlmostEqual(attrs["scale"].value, 0.25, places=6)
|
||||
self.assertAlmostEqual(attrs["softcap"].value, 30.0, places=6)
|
||||
self.assertEqual(attrs["softmax_precision"].value, 1)
|
||||
|
||||
def test_attention_export_with_different_mask_shapes(self):
|
||||
"""Test export with different attention mask shapes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Test 2D mask
|
||||
mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
|
||||
|
||||
class Mask2DModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V, mask):
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
|
||||
return output
|
||||
|
||||
model_2d = Mask2DModel()
|
||||
onnx_program_2d = self.export(model_2d, (Q, K, V, mask_2d), opset_version=23)
|
||||
|
||||
node_2d = onnx_program_2d.model.graph.node(0)
|
||||
self.assertEqual(node_2d.inputs[3].shape, [q_seq_len, kv_seq_len])
|
||||
|
||||
# Test 3D mask
|
||||
mask_3d = torch.randint(
|
||||
0, 2, (batch_size, 1, q_seq_len, kv_seq_len), dtype=torch.bool
|
||||
)
|
||||
|
||||
class Mask3DModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V, mask):
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
|
||||
return output
|
||||
|
||||
model_3d = Mask3DModel()
|
||||
onnx_program_3d = self.export(model_3d, (Q, K, V, mask_3d), opset_version=23)
|
||||
|
||||
node_3d = onnx_program_3d.model.graph.node(0)
|
||||
self.assertEqual(
|
||||
node_3d.inputs[3].shape, [batch_size, 1, q_seq_len, kv_seq_len]
|
||||
)
|
||||
|
||||
# Test 4D mask
|
||||
mask_4d = torch.randint(
|
||||
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
|
||||
)
|
||||
|
||||
class Mask4DModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V, mask):
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
|
||||
return output
|
||||
|
||||
model_4d = Mask4DModel()
|
||||
onnx_program_4d = self.export(model_4d, (Q, K, V, mask_4d), opset_version=23)
|
||||
|
||||
node_4d = onnx_program_4d.model.graph.node(0)
|
||||
self.assertEqual(
|
||||
node_4d.inputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
|
||||
)
|
||||
|
||||
def test_attention_export_with_float_mask(self):
|
||||
"""Test export with float attention mask."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
float_mask = torch.randn(q_seq_len, kv_seq_len)
|
||||
|
||||
class FloatMaskModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V, mask):
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
|
||||
return output
|
||||
|
||||
model = FloatMaskModel()
|
||||
onnx_program = self.export(model, (Q, K, V, float_mask), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
self.assertEqual(node.inputs[3].shape, [q_seq_len, kv_seq_len])
|
||||
# Verify the mask input has float dtype in the ONNX model
|
||||
self.assertEqual(node.inputs[3].dtype, ir.DataType.FLOAT)
|
||||
|
||||
def test_attention_export_qk_output_modes(self):
|
||||
"""Test export with different QK output modes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
for mode in [0, 1, 2, 3]:
|
||||
|
||||
class QKOutputModel(torch.nn.Module):
|
||||
def __init__(self, qk_mode):
|
||||
super().__init__()
|
||||
self.qk_mode = qk_mode
|
||||
|
||||
def forward(self, Q, K, V):
|
||||
output, _, _, qk_output = torch.onnx.ops.attention(
|
||||
Q, K, V, qk_matmul_output_mode=self.qk_mode
|
||||
)
|
||||
return output, qk_output
|
||||
|
||||
model = QKOutputModel(mode)
|
||||
onnx_program = self.export(model, (Q, K, V), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify qk_matmul_output_mode attribute
|
||||
attrs = node.attributes
|
||||
if mode != 0:
|
||||
self.assertIn("qk_matmul_output_mode", attrs)
|
||||
self.assertEqual(attrs["qk_matmul_output_mode"].value, mode)
|
||||
|
||||
# Verify 4 outputs (output, present_key, present_value, qk_output)
|
||||
self.assertEqual(len(node.outputs), 4)
|
||||
|
||||
def test_attention_export_mqa(self):
|
||||
"""Test export with Multi-Query Attention (MQA)."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 1 # MQA
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
class MQAModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
|
||||
return output
|
||||
|
||||
model = MQAModel()
|
||||
onnx_program = self.export(model, (Q, K, V), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify MQA tensor shapes
|
||||
self.assertEqual(
|
||||
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
|
||||
)
|
||||
self.assertEqual(
|
||||
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
) # kv_num_heads = 1
|
||||
self.assertEqual(
|
||||
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
|
||||
def test_attention_export_with_softmax_precision(self):
|
||||
"""Test export with different softmax precision values."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 8
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
# Test different ONNX precision types
|
||||
precision_types = [
|
||||
(1, "FLOAT"),
|
||||
(10, "FLOAT16"),
|
||||
(11, "DOUBLE"),
|
||||
(16, "BFLOAT16"),
|
||||
]
|
||||
|
||||
for precision_val, precision_name in precision_types:
|
||||
|
||||
class SoftmaxPrecisionModel(torch.nn.Module):
|
||||
def __init__(self, precision):
|
||||
super().__init__()
|
||||
self.precision = precision
|
||||
|
||||
def forward(self, Q, K, V):
|
||||
output, _, _, _ = torch.onnx.ops.attention(
|
||||
Q, K, V, softmax_precision=self.precision
|
||||
)
|
||||
return output
|
||||
|
||||
model = SoftmaxPrecisionModel(precision_val)
|
||||
onnx_program = self.export(model, (Q, K, V), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify softmax_precision attribute
|
||||
attrs = node.attributes
|
||||
self.assertIn("softmax_precision", attrs)
|
||||
self.assertEqual(attrs["softmax_precision"].value, precision_val)
|
||||
|
||||
def test_attention_export_gqa(self):
|
||||
"""Test export and verify output tensor shapes."""
|
||||
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
|
||||
q_num_heads, kv_num_heads = 8, 4 # GQA
|
||||
head_size = 64
|
||||
|
||||
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
|
||||
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
|
||||
class AttentionOutputsModel(torch.nn.Module):
|
||||
def forward(self, Q, K, V):
|
||||
return torch.onnx.ops.attention(Q, K, V)
|
||||
|
||||
model = AttentionOutputsModel()
|
||||
onnx_program = self.export(model, (Q, K, V), opset_version=23)
|
||||
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "Attention")
|
||||
|
||||
# Verify all 4 outputs have correct shapes
|
||||
outputs = node.outputs
|
||||
self.assertEqual(len(outputs), 4)
|
||||
|
||||
# output: (batch_size, q_num_heads, q_seq_len, head_size)
|
||||
self.assertEqual(
|
||||
outputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
|
||||
)
|
||||
|
||||
# present_key: (batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
self.assertEqual(
|
||||
outputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
|
||||
# present_value: (batch_size, kv_num_heads, kv_seq_len, head_size)
|
||||
self.assertEqual(
|
||||
outputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
|
||||
)
|
||||
|
||||
# qk_output: (batch_size, q_num_heads, q_seq_len, kv_seq_len)
|
||||
self.assertEqual(
|
||||
outputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -552,6 +552,11 @@ def _handle_call_function_node_with_lowering(
    if _is_onnx_op(node.target):
        # Handle torch.ops.onnx.* ops. These ops can be directly added to the graph
        op_type, opset_version = _parse_onnx_op(node.target)  # type: ignore[arg-type]
        # If final inputs are None, strip them from the node inputs
        for input_ in reversed(onnx_args):
            if input_ is not None:
                break
            onnx_args.pop()
        onnx_node = ir.Node(
            "",
            op_type,
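The loop in this hunk trims optional ONNX inputs that were left as trailing `None`s so they do not show up as empty inputs on the exported node. A standalone toy illustration of the same idea (written for this summary, not taken from the PR):

```py
# Trailing Nones are dropped in place; interior Nones (skipped optional inputs) are kept.
onnx_args = ["Q", "K", "V", None, "past_key", "past_value", None, None]

for arg in reversed(onnx_args):
    if arg is not None:
        break
    onnx_args.pop()

print(onnx_args)  # ['Q', 'K', 'V', None, 'past_key', 'past_value']
```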
@@ -13,6 +13,7 @@ __all__ = [
    "symbolic",
    "symbolic_multi_out",
    "rotary_embedding",
    "attention",
]


@@ -334,7 +335,7 @@ def rotary_embedding(
    Returns:
        Tensor with same shape as input.
    """
    return _impl.rotary_embedding(
    return _impl.rotary_embedding_23(
        X,
        cos_cache,
        sin_cache,
@@ -343,3 +344,124 @@ def rotary_embedding(
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )


def attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    past_key: torch.Tensor | None = None,
    past_value: torch.Tensor | None = None,
    *,
    is_causal: bool = False,
    kv_num_heads: int = 0,
    q_num_heads: int = 0,
    qk_matmul_output_mode: int = 0,
    scale: float | None = None,
    softcap: float = 0.0,
    softmax_precision: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Attention op in ONNX.

    https://onnx.ai/onnx/operators/onnx__Attention.html

    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.

    This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.

    For self attention, ``kv_sequence_length`` equals ``q_sequence_length``.

    For cross attention, query and key might have different lengths.

    This operator also covers the following 3 variants based on the number of heads:

    1. Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
    2. Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
    3. Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`.

    The attention bias to be added is calculated from the ``attn_mask`` input and the ``is_causal`` attribute, only one of which can be provided:

    1. If ``is_causal`` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment.
    2. `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention, or a float mask of the same type as query, key, value that is added to the attention score.

    Both past and present state key/values are optional. They shall be used together; providing only one of them is not allowed.
    The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided::

        The following pattern is applied by this operator:
              Q          K          V
              |          |          |
        Q*sqrt(scale) K*sqrt(scale) |
              |          |          |
              |       Transpose     |
              |          |          |
              ---MatMul---          |
                   |                |
         at_mask---Add              |
                   |                |
          softcap (if provided)     |
                   |                |
                Softmax             |
                   |                |
                -----MatMul------
                        |
                        Y

    Args:
        Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, head_size)` or 3D tensor
            with shape `(batch_size, q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor,
            `q_hidden_size = q_num_heads * head_size`.
        K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, head_size)` or 3D tensor
            with shape `(batch_size, kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor,
            `k_hidden_size = kv_num_heads * head_size`.
        V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)` or 3D tensor
            with shape `(batch_size, kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor,
            `v_hidden_size = kv_num_heads * v_head_size`.
        attn_mask: Attention mask. Shape must be broadcastable to a 4D tensor with shape
            `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where
            `total_sequence_length = past_sequence_length + kv_sequence_length`. Two types of masks are supported:
            a boolean mask where a value of True indicates that the element should take part in attention,
            or a float mask of the same type as query, key, value that is added to the attention score.
        past_key: Past state cache for key with shape `(batch_size, kv_num_heads, past_sequence_length, head_size)`.
        past_value: Past state cache for value with shape `(batch_size, kv_num_heads, past_sequence_length, v_head_size)`.
        is_causal: If set to True, the attention masking is a lower triangular matrix when the mask is a square matrix.
            The attention masking has the form of the upper left causal bias due to the alignment.
        kv_num_heads: Number of heads of key and value. Must be used with 3D inputs of Q, K and V.
        q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K and V.
        qk_matmul_output_mode: If set to 0, qk_matmul_output is the output of qk matmul. If set to 1,
            qk_matmul_output includes the addition of the attention mask to the output of qk matmul.
            If set to 2, qk_matmul_output is the output after the softcap operation. If set to 3,
            qk_matmul_output is the output after the softmax operation. Default value is 0.
        scale: Scaling factor applied to `Q*K^T`. Default value is `1/sqrt(head_size)`. To prevent numerical overflow,
            Q and K are scaled by `sqrt(scale)` before the matmul.
        softcap: Softcap value for attention weights. Default value is 0.
        softmax_precision: The floating-point precision used in the softmax computation. If softmax precision is not provided,
            the same precision as the input of softmax (Q and K) is used.

    Returns:
        A tuple containing:

        - The output tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, v_head_size)` or 3D tensor
          with shape `(batch_size, q_sequence_length, hidden_size)`. For cases with a 3D input tensor,
          `hidden_size = q_num_heads * v_head_size`.
        - Updated key cache with shape `(batch_size, kv_num_heads, total_sequence_length, head_size)` where
          `total_sequence_length = past_sequence_length + kv_sequence_length`.
        - Updated value cache with shape `(batch_size, kv_num_heads, total_sequence_length, v_head_size)` where
          `total_sequence_length = past_sequence_length + kv_sequence_length`.
        - The output of QK matmul. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)`
          where `total_sequence_length = past_sequence_length + kv_sequence_length`.
    """
    return _impl.attention_23(
        Q,
        K,
        V,
        attn_mask=attn_mask,
        past_key=past_key,
        past_value=past_value,
        is_causal=is_causal,
        kv_num_heads=kv_num_heads,
        q_num_heads=q_num_heads,
        qk_matmul_output_mode=qk_matmul_output_mode,
        scale=scale,
        softcap=softcap,
        softmax_precision=softmax_precision,
    )
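To make the shape conventions in the docstring concrete, here is a small usage sketch (my own illustration, mirroring the new tests rather than quoting them) that combines GQA with a past key/value cache; the present key/value outputs concatenate past and current states along the sequence dimension.

```py
import torch

batch, q_heads, kv_heads, head_size = 2, 8, 4, 64  # GQA: q_heads % kv_heads == 0
q_len, kv_len, past_len = 4, 6, 3

Q = torch.rand(batch, q_heads, q_len, head_size)
K = torch.rand(batch, kv_heads, kv_len, head_size)
V = torch.rand(batch, kv_heads, kv_len, head_size)
past_key = torch.rand(batch, kv_heads, past_len, head_size)
past_value = torch.rand(batch, kv_heads, past_len, head_size)

output, present_key, present_value, qk_output = torch.onnx.ops.attention(
    Q, K, V, past_key=past_key, past_value=past_value
)

assert output.shape == (batch, q_heads, q_len, head_size)
assert present_key.shape == (batch, kv_heads, past_len + kv_len, head_size)
assert present_value.shape == (batch, kv_heads, past_len + kv_len, head_size)
assert qk_output.shape == (batch, q_heads, q_len, past_len + kv_len)
```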
torch/onnx/ops/_dtype_mappings.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch


ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
    1: torch.float32,  # FLOAT
    2: torch.uint8,  # UINT8
    3: torch.int8,  # INT8
    4: torch.uint16,  # UINT16
    5: torch.int16,  # INT16
    6: torch.int32,  # INT32
    7: torch.int64,  # INT64
    9: torch.bool,  # BOOL
    10: torch.float16,  # FLOAT16
    11: torch.double,  # DOUBLE
    12: torch.uint32,  # UINT32
    13: torch.uint64,  # UINT64
    14: torch.complex64,  # COMPLEX64
    15: torch.complex128,  # COMPLEX128
    16: torch.bfloat16,  # BFLOAT16
    17: torch.float8_e4m3fn,  # FLOAT8E4M3FN
    18: torch.float8_e4m3fnuz,  # FLOAT8E4M3FNUZ
    19: torch.float8_e5m2,  # FLOAT8E5M2
    20: torch.float8_e5m2fnuz,  # FLOAT8E5M2FNUZ
    21: torch.uint8,  # UINT4
    22: torch.uint8,  # INT4
    23: torch.float4_e2m1fn_x2,  # FLOAT4E2M1
}
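A quick, hedged illustration of how this table is consumed (my own example; it assumes a torch build that defines the float8/float4 dtypes referenced above, and uses ONNX `TensorProto` integer codes such as 16 for BFLOAT16 and 8 for STRING):

```py
import torch
from torch.onnx.ops import _dtype_mappings

softmax_precision = 16  # ONNX BFLOAT16
assert _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision] is torch.bfloat16

# Codes without a torch equivalent (e.g. 8 = STRING) are simply absent,
# so callers can use .get() to fall back gracefully.
assert _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.get(8) is None
```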
@@ -1,13 +1,24 @@
# flake8: noqa: B950
import math
import typing
from typing import Callable, Optional

import torch
from torch.onnx.ops import _dtype_mappings


_T = typing.TypeVar("_T", bound=Callable)

# ONNX to ATen decomp table
ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
_ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
    {
        1,  # FLOAT
        10,  # FLOAT16
        11,  # DOUBLE
        16,  # BFLOAT16
    }
)


def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]:
@@ -30,7 +41,7 @@ def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]:


@_onnx_op("RotaryEmbedding", 23)
def rotary_embedding(
def rotary_embedding_23(
    x: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
@@ -46,11 +57,14 @@ def rotary_embedding(
    sequence_length = x.shape[1]
    if len(x.shape) == 3:
        hidden_size = x.shape[2]
        assert num_heads != 0
        torch._check(
            num_heads != 0,
            lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}",
        )
        head_size = hidden_size // num_heads
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        x = torch.reshape(x, new_shape)
    assert len(x.shape) == 4
    torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
    head_size = x.shape[3]

    # Fully or partially perform rotation on x based on rotary_embedding_dim attribute
@@ -110,3 +124,273 @@ def rotary_embedding(
    if len(x.shape) == 3:
        output = torch.reshape(output, x.shape)
    return output

def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
    """Get the scale factor for attention computation."""
    return scale if scale is not None else (1.0 / math.sqrt(head_size))


def _reshape_3d_to_4d(
    tensor: torch.Tensor, batch_size: int, num_heads: int
) -> torch.Tensor:
    """Reshape 3D tensor to 4D for multi-head attention."""
    sequence_length, hidden_size = tensor.shape[1], tensor.shape[2]
    head_size = hidden_size // num_heads
    return (
        tensor.view(batch_size, sequence_length, num_heads, head_size)
        .transpose(1, 2)
        .contiguous()
    )


def _get_qk_output_for_aten_spda(
    Q: torch.Tensor,
    K: torch.Tensor,
    current_q_num_heads: int,
    current_kv_num_heads: int,
    scale: Optional[float],
    qk_matmul_output_mode: int,
) -> torch.Tensor:
    """Get QK output tensor based on the specified mode."""
    if qk_matmul_output_mode == 0:
        return _compute_qk_output_for_mode_0(
            Q, K, current_q_num_heads, current_kv_num_heads, scale
        )
    else:
        # For other modes, return a zero tensor with correct shape
        return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1)))


def _validate_gqa_configuration(
    current_q_num_heads: int, current_kv_num_heads: int
) -> None:
    """Validate Group Query Attention configuration."""
    torch._check(
        current_q_num_heads % current_kv_num_heads == 0,
        lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA",
    )


def _compute_qk_output_for_mode_0(
    Q: torch.Tensor,
    K: torch.Tensor,
    current_q_num_heads: int,
    current_kv_num_heads: int,
    scale: Optional[float],
) -> torch.Tensor:
    """Helper function to compute QK output for qk_matmul_output_mode == 0."""
    # Handle GQA manually for QK output
    K_for_qk = K
    if current_q_num_heads != current_kv_num_heads:
        repeat_factor = current_q_num_heads // current_kv_num_heads
        K_for_qk = K.repeat_interleave(repeat_factor, dim=1)

    scale_factor = _get_scale_factor(scale, Q.shape[3])
    # Scale both Q and K by sqrt(scale_factor) for numerical stability
    sqrt_scale = math.sqrt(scale_factor)
    Q_scaled = Q * sqrt_scale
    K_scaled = K_for_qk * sqrt_scale
    return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))

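A side note on the sqrt-scaling used in the helper above and again in the manual path below (my own explanatory remark, not text from the PR): scaling Q and K each by the square root of the scale is algebraically the same as scaling the score matrix, while keeping intermediate magnitudes smaller,

$$(\sqrt{s}\,Q)\,(\sqrt{s}\,K)^{\top} = s\,(Q K^{\top}), \qquad s = \text{scale} = \tfrac{1}{\sqrt{\text{head\_size}}} \text{ by default}.$$
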
@_onnx_op("Attention", 23)
|
||||
def attention_23(
|
||||
Q: torch.Tensor,
|
||||
K: torch.Tensor,
|
||||
V: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
past_key: Optional[torch.Tensor] = None,
|
||||
past_value: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
is_causal: bool = False,
|
||||
kv_num_heads: int = 0,
|
||||
q_num_heads: int = 0,
|
||||
qk_matmul_output_mode: int = 0,
|
||||
scale: Optional[float] = None,
|
||||
softcap: float = 0.0,
|
||||
softmax_precision: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23"""
|
||||
|
||||
num_head_dim, sequence_dim, head_dim = 1, 2, 3
|
||||
|
||||
# Store original input shape to determine output shape
|
||||
input_shape_len = len(Q.shape)
|
||||
batch_size = Q.shape[0]
|
||||
|
||||
# Reshape 3D inputs to 4D format
|
||||
if len(Q.shape) == 3:
|
||||
torch._check(
|
||||
q_num_heads != 0 and kv_num_heads != 0,
|
||||
lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs",
|
||||
)
|
||||
q_sequence_length = Q.shape[1]
|
||||
Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads)
|
||||
K = _reshape_3d_to_4d(K, batch_size, kv_num_heads)
|
||||
V = _reshape_3d_to_4d(V, batch_size, kv_num_heads)
|
||||
|
||||
torch._check(
|
||||
len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4,
|
||||
lambda: "Q, K, and V should be 4D tensors by now",
|
||||
)
|
||||
|
||||
# Calculate scale factor if not provided
|
||||
q_head_size = Q.shape[head_dim]
|
||||
scale = _get_scale_factor(scale, q_head_size)
|
||||
|
||||
# Handle past key/value caches
|
||||
present_key = (
|
||||
torch.cat([past_key, K], dim=sequence_dim)
|
||||
if past_key is not None
|
||||
else K.clone()
|
||||
)
|
||||
present_value = (
|
||||
torch.cat([past_value, V], dim=sequence_dim)
|
||||
if past_value is not None
|
||||
else V.clone()
|
||||
)
|
||||
|
||||
    # Update K and V to include past states
    K, V = present_key, present_value

    # Get current dimensions
    current_q_num_heads = Q.shape[num_head_dim]
    current_kv_num_heads = K.shape[num_head_dim]
    q_sequence_length = Q.shape[sequence_dim]
    kv_sequence_length = K.shape[sequence_dim]

    # Check if we can use the optimized scaled_dot_product_attention (most optimized)
    can_use_sdpa = (
        softcap == 0.0  # No softcap
        and qk_matmul_output_mode == 0  # Default QK output mode
        and softmax_precision is None  # No custom softmax precision
        and (attn_mask is None or attn_mask.dtype == torch.bool)
    )

    _validate_gqa_configuration(current_q_num_heads, current_kv_num_heads)

    if can_use_sdpa:
        # Use PyTorch's optimized scaled_dot_product_attention

        # Prepare attention mask for SDPA
        sdpa_attn_mask = None
        if attn_mask is not None:
            # Convert boolean mask: True means participate, SDPA expects True to mask out
            sdpa_attn_mask = ~attn_mask if attn_mask.dtype == torch.bool else attn_mask

        output = torch.nn.functional.scaled_dot_product_attention(
            Q,
            K,
            V,
            attn_mask=sdpa_attn_mask,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=bool(
                current_q_num_heads != current_kv_num_heads
            ),  # Ensure enable_gqa is not SymBool
        )

        qk_output = _get_qk_output_for_aten_spda(
            Q,
            K,
            current_q_num_heads,
            current_kv_num_heads,
            scale,
            qk_matmul_output_mode,
        )
    else:
        # Fallback to manual implementation for complex cases

        # Handle Group Query Attention (GQA) and Multi-Query Attention (MQA)
        if current_q_num_heads != current_kv_num_heads:
            repeat_factor = current_q_num_heads // current_kv_num_heads
            K = K.repeat_interleave(repeat_factor, dim=num_head_dim)
            V = V.repeat_interleave(repeat_factor, dim=num_head_dim)

        # Create attention bias
        attn_bias = torch.zeros(
            q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device
        )

        # Apply causal masking
        if is_causal:
            torch._check(
                attn_mask is None, lambda: "Cannot use both is_causal and attn_mask"
            )
            causal_mask = torch.tril(
                torch.ones(
                    q_sequence_length,
                    kv_sequence_length,
                    dtype=torch.bool,
                    device=Q.device,
                )
            )
            attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))

        # Apply attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean mask: True means participate in attention
                attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
            else:
                # Float mask: added to attention scores
                attn_bias = attn_bias + attn_mask

        # Apply scaling factor
        scale_factor = _get_scale_factor(scale, Q.shape[3])

        # Scale both Q and K by sqrt(scale_factor) for numerical stability
        sqrt_scale = math.sqrt(scale_factor)
        Q_scaled = Q * sqrt_scale
        K_scaled = K * sqrt_scale

        # Compute Q @ K^T
        qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))

        # Initialize QK output based on mode
        qk_output = qk_matmul_output  # Default case for mode 0

        # Add attention bias
        qk_with_bias = qk_matmul_output + attn_bias

        if qk_matmul_output_mode == 1:
            qk_output = qk_with_bias

        # Apply softcap if provided
        if softcap > 0.0:
            qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap)

        if qk_matmul_output_mode == 2:
            qk_output = qk_with_bias

        # Apply softmax with optional precision casting
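        # softmax_precision, if given, is an ONNX TensorProto dtype code
        # (e.g. 1 = FLOAT, 10 = FLOAT16, 16 = BFLOAT16); values outside
        # _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS fall through to a plain softmax.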
        if softmax_precision is not None:
            # Map ONNX data type to torch dtype
            if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS:
                original_dtype = qk_with_bias.dtype
                qk_with_bias = qk_with_bias.to(
                    _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision]
                )
                qk_softmax = torch.softmax(qk_with_bias, dim=-1)
                qk_softmax = qk_softmax.to(original_dtype)
            else:
                qk_softmax = torch.softmax(qk_with_bias, dim=-1)
        else:
            qk_softmax = torch.softmax(qk_with_bias, dim=-1)

        if qk_matmul_output_mode == 3:
            qk_output = qk_softmax

        # Compute attention output
        output = torch.matmul(qk_softmax, V)

    # Reshape output back to 3D if input was 3D
    if input_shape_len == 3:
        # output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size)
        output = (
            output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1)
        )

    return output, present_key, present_value, qk_output
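
# Usage sketch (illustrative; assumes the public wrapper ``torch.onnx.ops.attention``
# forwards to this implementation with the same signature):
#
#     Q = torch.rand(2, 8, 16, 64)  # (batch, q_num_heads, q_seq_len, head_size)
#     K = torch.rand(2, 4, 16, 64)  # (batch, kv_num_heads, kv_seq_len, head_size)
#     V = torch.rand(2, 4, 16, 64)
#     output, present_key, present_value, qk = torch.onnx.ops.attention(
#         Q, K, V, is_causal=True
#     )
#     # output: (2, 8, 16, 64); present_key/present_value: (2, 4, 16, 64);
#     # qk: (2, 8, 16, 16)
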
@ -11,38 +11,15 @@ zeros based on the input shape and dtype, and a "fake" implementation that does
or less the same thing but is required by the `torch.library.custom_op` interface.
"""

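# Note (illustrative): in eager mode these symbolic ops have no real semantics;
# the "real" implementations below just return zeros of the declared
# shape/dtype so tracing and export can proceed, and the actual operator is
# only meaningful in the exported ONNX graph.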
# flake8: noqa: B950
import dataclasses
from collections.abc import Sequence
from typing import Optional, Union

import torch
from torch.onnx.ops import _dtype_mappings


_ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
    1: torch.float32,  # FLOAT
    2: torch.uint8,  # UINT8
    3: torch.int8,  # INT8
    4: torch.uint16,  # UINT16
    5: torch.int16,  # INT16
    6: torch.int32,  # INT32
    7: torch.int64,  # INT64
    9: torch.bool,  # BOOL
    10: torch.float16,  # FLOAT16
    11: torch.double,  # DOUBLE
    12: torch.uint32,  # UINT32
    13: torch.uint64,  # UINT64
    14: torch.complex64,  # COMPLEX64
    15: torch.complex128,  # COMPLEX128
    16: torch.bfloat16,  # BFLOAT16
    17: torch.float8_e4m3fn,  # FLOAT8E4M3FN
    18: torch.float8_e4m3fnuz,  # FLOAT8E4M3FNUZ
    19: torch.float8_e5m2,  # FLOAT8E5M2
    20: torch.float8_e5m2fnuz,  # FLOAT8E5M2FNUZ
    21: torch.uint8,  # UINT4
    22: torch.uint8,  # INT4
    23: torch.float4_e2m1fn_x2,  # FLOAT4E2M1
}

_INT_TYPE = "i"
_FLOAT_TYPE = "f"
_STRING_TYPE = "s"
@ -221,10 +198,12 @@ def _symbolic(
    version: Optional[int] = None,
) -> torch.Tensor:
    torch._check(
        onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
    )
    return torch.zeros(
        shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
    )
    return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])


@_symbolic.register_fake
@ -246,12 +225,14 @@ def _(
    version: Optional[int] = None,
) -> torch.Tensor:
    torch._check(
        onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
    )
    # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
    # out how it can handle empty shapes
    return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
    return torch.zeros(
        shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
    )


@torch.library.custom_op(
@ -289,10 +270,14 @@ def _symbolic_multi_out(
    )
    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
        torch._check(
            onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
            onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        )
        outputs.append(
            torch.zeros(
                shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
            )
        )
        outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
    return outputs


@ -321,10 +306,14 @@ def _(
    )
    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
        torch._check(
            onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
            onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        )
        # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
        # out how it can handle empty shapes
        outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
        outputs.append(
            torch.zeros(
                shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
            )
        )
    return outputs