pytorch/test/onnx/ops/test_ops.py
Justin Chu fdf68fa5d7 [ONNX] Fix rotary_embedding_23 implementation (#162865)
The implementation of `rotary_embedding_23` was incorrect when the input is 3D.
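
For context, the 3D path is expected to split the input into heads before applying the rotation and to flatten it back afterwards. A minimal sketch of those semantics (illustrative only; it assumes the non-interleaved, full-head-dim case with caches shaped `(batch, seq_len, head_size // 2)`, matching the shapes used in `test_rotary_embedding_3d`):

```py
import torch


def rotary_3d_sketch(x, cos, sin, num_heads):
    # x: (batch, seq_len, hidden) -> split into heads before rotating
    batch, seq_len, hidden = x.shape
    head_size = hidden // num_heads
    x = x.view(batch, seq_len, num_heads, head_size).transpose(1, 2)
    # Non-interleaved rotation: first half paired with second half
    x1, x2 = x.chunk(2, dim=-1)
    cos = cos.unsqueeze(1)  # broadcast over the heads dimension
    sin = sin.unsqueeze(1)
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Flatten back to the original (batch, seq_len, hidden) layout
    return rotated.transpose(1, 2).reshape(batch, seq_len, hidden)
```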

## Tested

Locally with

```py
import onnx_ir as ir
import onnx
import torch
import os
import numpy as np

base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]

for path in model_paths:
    print(f"Checking {path} for issues...")

    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    ).numpy()
    input1 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_1.pb"))
    ).numpy()
    input2 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_2.pb"))
    ).numpy()
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = ir.from_proto(
            onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_3.pb"))
        ).numpy()
    else:
        input3 = None
    output0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    ).numpy()

    m = ir.from_proto(model)

    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    np.testing.assert_allclose(torch_out, output0)
```

Fixes https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms
2025-09-16 03:30:05 +00:00


# Owner(s): ["module: onnx"]
"""Test torch.onnx.ops."""

from __future__ import annotations

import onnx_ir.passes.common as common_passes
import onnxruntime
from onnxscript import ir
from packaging import version

import torch
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils

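# The tests below use this to decide whether ONNX Runtime can evaluate opset 23
# models (ORT does not support these ops as of version 1.22).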
def has_onnxruntime_opset_23() -> bool:
    return version.parse(onnxruntime.__version__) >= version.parse("1.23")


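# torch.library.opcheck exercises the registered custom op against
# torch.library's correctness checks (schema, fake/meta kernels, etc.).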
class SchemaTest(common_utils.TestCase):
    def test_symbolic_has_correct_schema(self):
        torch.library.opcheck(
            _symbolic_impl._symbolic,
            ([torch.tensor(1)], "CustomOp", 1),
            dict(
                shape=[
                    1,
                ],
                attr_keys=["key"],
                attr_types=["i"],
                attr_pos=[(0, 1)],
                attr_ints=[1],
                attr_floats=[1.0],
                attr_strs=["attr"],
                metadata_props_keys=["meta_key"],
                metadata_props_values=["meta_value"],
                domain="custom_domain",
                version=42,
            ),
        )
        # Empty inputs
        torch.library.opcheck(
            _symbolic_impl._symbolic,
            ([], "CustomOp", 1),
            dict(
                shape=[
                    1,
                ],
                attr_keys=[],
                attr_types=[],
                attr_pos=[],
                attr_ints=[],
                attr_floats=[],
                attr_strs=[],
                metadata_props_keys=[],
                metadata_props_values=[],
            ),
        )

    def test_symbolic_multi_out_has_correct_schema(self):
        torch.library.opcheck(
            _symbolic_impl._symbolic_multi_out,
            ([torch.tensor(1)], "CustomMultiOutOp", [1, 2, 10]),
            dict(
                shapes=[[1, 2], [42], []],
                attr_keys=["key"],
                attr_types=["i"],
                attr_pos=[(0, 1)],
                attr_ints=[1],
                attr_floats=[1.0],
                attr_strs=["attr"],
                metadata_props_keys=["meta_key"],
                metadata_props_values=["meta_value"],
                domain="",
                version=1,
            ),
        )
        # Empty inputs
        torch.library.opcheck(
            _symbolic_impl._symbolic_multi_out,
            ([], "CustomMultiOutOp", []),
            dict(
                shapes=[],
                attr_keys=[],
                attr_types=[],
                attr_pos=[],
                attr_ints=[],
                attr_floats=[],
                attr_strs=[],
                metadata_props_keys=[],
                metadata_props_values=[],
            ),
        )


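# In eager mode, symbolic()/symbolic_multi_out() only produce placeholder
# tensors with the requested shapes and dtypes; the values are not meaningful.
# Their purpose is to record an opaque custom ONNX op for export.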
class SymbolicOpsTest(common_utils.TestCase):
    def test_symbolic_accepts_valid_inputs(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dict(
                int_key=1,
                float_key=1.0,
                str_key="attr",
                bool_key=True,
                list_int_key=[1, 2],
                list_float_key=[1.0, 2.0],
                list_str_key=["attr1", "attr2"],
                list_bool_key=[True, False],
            ),
            dtype=torch.float32,
            shape=[1, 2, 3],
            version=1,
            metadata_props={"meta_key": "meta_value"},
        )
        self.assertEqual(output.shape, torch.Size([1, 2, 3]))
        self.assertEqual(output.dtype, torch.float32)
        self.assertEqual(output.device, torch.device("cpu"))

    def test_symbolic_accepts_valid_inputs_empty_shape(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtype=torch.float32,
            shape=[],
        )
        self.assertEqual(output.shape, torch.Size([]))

    def test_symbolic_accepts_valid_inputs_integer_types(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtype=1,  # 1 is float32 in ONNX
            shape=[42],
        )
        self.assertEqual(output.dtype, torch.float32)

    def test_symbolic_accepts_valid_inputs_int4_type(self):
        output = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtype=22,  # 22 is INT4 in ONNX
            shape=[42],
        )
        # We use torch uint8 for int4
        self.assertEqual(output.dtype, torch.uint8)

    def test_symbolic_is_exportable(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x, None),
                    dict(
                        int_key=1,
                        float_key=1.0,
                        str_key="attr",
                        bool_key=True,
                        list_int_key=[1, 2],
                        list_float_key=[1.0, 2.0],
                        list_str_key=["attr1", "attr2"],
                        list_bool_key=[True, False],
                    ),
                    dtype=x.dtype,
                    shape=[1, 2, 3],
                    version=1,
                    metadata_props={"meta_key": "meta_value"},
                )

        onnx_program = torch.onnx.export(
            Model(), (torch.tensor(1),), dynamo=True, verbose=False
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        attributes = node.attributes
        self.assertEqual(
            attributes,
            dict(
                int_key=ir.AttrInt64("int_key", 1),
                float_key=ir.AttrFloat32("float_key", 1.0),
                str_key=ir.AttrString("str_key", "attr"),
                bool_key=ir.AttrInt64("bool_key", 1),
                list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
                list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
                list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
                list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
            ),
        )
        self.assertEqual(node.metadata_props["meta_key"], "meta_value")
        outputs = node.outputs
        self.assertEqual(list(outputs[0].shape), [1, 2, 3])
        self.assertEqual(outputs[0].dtype, ir.DataType.INT64)

    def test_symbolic_preserves_dynamic_shapes(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x, y),
                    dtype=x.dtype,
                    shape=[*x.shape, *y.shape],
                    version=1,
                )

        onnx_program = torch.onnx.export(
            Model(),
            (torch.zeros(2, 3), torch.zeros(1, 2)),
            dynamic_shapes=({0: "batch"}, {1: "something_else"}),
            dynamo=True,
            verbose=False,
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        inputs = onnx_program.model.graph.inputs
        self.assertEqual(str(inputs[0].shape[0]), "batch")
        self.assertEqual(inputs[0].shape[1], 3)
        self.assertEqual(inputs[1].shape[0], 1)
        self.assertEqual(str(inputs[1].shape[1]), "something_else")
        outputs = node.outputs
        self.assertEqual(str(outputs[0].shape[0]), "batch")
        self.assertEqual(outputs[0].shape[1], 3)
        self.assertEqual(outputs[0].shape[2], 1)
        self.assertEqual(str(outputs[0].shape[3]), "something_else")
        self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)

    def test_symbolic_multi_out_accepts_valid_inputs(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomMultiOutOp",
            (torch.tensor(1),),
            dict(
                int_key=1,
                float_key=1.0,
                str_key="attr",
                bool_key=True,
                list_int_key=[1, 2],
                list_float_key=[1.0, 2.0],
                list_str_key=["attr1", "attr2"],
                list_bool_key=[True, False],
            ),
            dtypes=(
                1,  # 1 is float32 in ONNX
                torch.int32,
                torch.float8_e4m3fn,
            ),
            shapes=([1, 2], [42], []),
            version=1,
            metadata_props={"meta_key": "meta_value"},
        )
        self.assertEqual(len(outputs), 3)
        self.assertEqual(outputs[0].shape, torch.Size([1, 2]))
        self.assertEqual(outputs[0].dtype, torch.float32)
        self.assertEqual(outputs[1].shape, torch.Size([42]))
        self.assertEqual(outputs[1].dtype, torch.int32)
        self.assertEqual(outputs[2].shape, torch.Size([]))
        self.assertEqual(outputs[2].dtype, torch.float8_e4m3fn)
        self.assertEqual(outputs[0].device, torch.device("cpu"))
        self.assertEqual(outputs[1].device, torch.device("cpu"))
        self.assertEqual(outputs[2].device, torch.device("cpu"))

    def test_symbolic_multi_out_accepts_valid_inputs_empty_shape(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtypes=(torch.float32,),
            shapes=[[]],
        )
        self.assertEqual(outputs[0].shape, torch.Size([]))

    def test_symbolic_multi_out_accepts_valid_inputs_integer_types(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtypes=(1,),  # 1 is float32 in ONNX
            shapes=[[42]],
        )
        self.assertEqual(outputs[0].dtype, torch.float32)

    def test_symbolic_multi_out_accepts_valid_inputs_int4_type(self):
        outputs = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (torch.tensor(1),),
            dtypes=(22,),  # 22 is INT4 in ONNX
            shapes=[[42]],
        )
        # We use torch uint8 for int4
        self.assertEqual(outputs[0].dtype, torch.uint8)

    def test_symbolic_multi_out_is_exportable(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x, None),
                    dict(
                        int_key=1,
                        float_key=1.0,
                        str_key="attr",
                        bool_key=True,
                        list_int_key=[1, 2],
                        list_float_key=[1.0, 2.0],
                        list_str_key=["attr1", "attr2"],
                        list_bool_key=[True, False],
                    ),
                    dtypes=(torch.float32, torch.int32, torch.float8_e4m3fn),
                    shapes=([1, 2], [42], []),
                    version=1,
                    metadata_props={"meta_key": "meta_value"},
                )

        onnx_program = torch.onnx.export(
            Model(), (torch.tensor(1),), dynamo=True, verbose=False
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        attributes = node.attributes
        self.assertEqual(
            attributes,
            dict(
                int_key=ir.AttrInt64("int_key", 1),
                float_key=ir.AttrFloat32("float_key", 1.0),
                str_key=ir.AttrString("str_key", "attr"),
                bool_key=ir.AttrInt64("bool_key", 1),
                list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
                list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
                list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
                list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
            ),
        )
        self.assertEqual(node.metadata_props["meta_key"], "meta_value")
        outputs = node.outputs
        self.assertEqual(list(outputs[0].shape), [1, 2])
        self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(list(outputs[1].shape), [42])
        self.assertEqual(outputs[1].dtype, ir.DataType.INT32)
        self.assertEqual(list(outputs[2].shape), [])
        self.assertEqual(outputs[2].dtype, ir.DataType.FLOAT8E4M3FN)

    def test_symbolic_multi_out_preserves_dynamic_shapes(self):
        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x, y),
                    dtypes=(x.dtype, 22),  # 22 is INT4
                    shapes=[[*x.shape, *y.shape], [42]],
                    version=1,
                )

        onnx_program = torch.onnx.export(
            Model(),
            (torch.zeros(2, 3), torch.zeros(1, 2)),
            dynamic_shapes=({0: "batch"}, {1: "something_else"}),
            dynamo=True,
            verbose=False,
        )
        assert onnx_program is not None
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "CustomOp")
        self.assertEqual(node.domain, "custom_domain")
        inputs = onnx_program.model.graph.inputs
        self.assertEqual(str(inputs[0].shape[0]), "batch")
        self.assertEqual(inputs[0].shape[1], 3)
        self.assertEqual(inputs[1].shape[0], 1)
        self.assertEqual(str(inputs[1].shape[1]), "something_else")
        outputs = node.outputs
        self.assertEqual(str(outputs[0].shape[0]), "batch")
        self.assertEqual(outputs[0].shape[1], 3)
        self.assertEqual(outputs[0].shape[2], 1)
        self.assertEqual(str(outputs[0].shape[3]), "something_else")
        self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(list(outputs[1].shape), [42])
        self.assertEqual(outputs[1].dtype, ir.DataType.INT4)

    def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self):
        with self.assertRaises(RuntimeError):
            torch.onnx.ops.symbolic_multi_out(
                "custom_domain::CustomMultiOutOp",
                (torch.tensor(1),),
                dict(
                    int_key=1,
                    float_key=1.0,
                    str_key="attr",
                    bool_key=True,
                    list_int_key=[1, 2],
                    list_float_key=[1.0, 2.0],
                    list_str_key=["attr1", "attr2"],
                    list_bool_key=[True, False],
                ),
                dtypes=(torch.float32, torch.int32),
                shapes=([1, 2], [42], []),
                version=1,
                metadata_props={"meta_key": "meta_value"},
            )
        with self.assertRaises(RuntimeError):
            torch.onnx.ops.symbolic_multi_out(
                "custom_domain::CustomMultiOutOp",
                (torch.tensor(1),),
                dict(
                    int_key=1,
                    float_key=1.0,
                    str_key="attr",
                    bool_key=True,
                    list_int_key=[1, 2],
                    list_float_key=[1.0, 2.0],
                    list_str_key=["attr1", "attr2"],
                    list_bool_key=[True, False],
                ),
                dtypes=(torch.float32,),
                shapes=([1, 2], [42]),
                version=1,
                metadata_props={"meta_key": "meta_value"},
            )


class NativeOnnxOpsTest(common_utils.TestCase):
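    # Shared helper: export with the dynamo-based exporter (fallback disabled)
    # and validate the resulting model with the ONNX checker pass.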
    def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
        onnx_program = torch.onnx.export(
            model,
            args,
            kwargs=kwargs,
            dynamo=True,
            fallback=False,
            verbose=False,
            **options,
        )
        assert onnx_program is not None
        common_passes.CheckerPass()(onnx_program.model)
        return onnx_program

    def test_onnx_ops_can_be_decomposed_to_aten(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 4)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

        class Model(torch.nn.Module):
            def forward(
                self, input_data, cos_cache_data, sin_cache_data, position_ids_data
            ):
                return torch.onnx.ops.rotary_embedding(
                    input_data,
                    cos_cache_data,
                    sin_cache_data,
                    position_ids_data,
                    interleaved=True,
                )

        model = Model()
        ep = torch.export.export(
            model,
            (input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )
        self.assertIn(
            "onnx.RotaryEmbedding.opset23",
            [str(node.target) for node in ep.graph.nodes],
        )
        # The program can be decomposed into aten ops so it is fully compatible with the PyTorch ecosystem
        aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
        self.assertNotIn(
            "onnx.RotaryEmbedding.opset23",
            [str(node.target) for node in aten_decomped.graph.nodes],
        )
        torch.testing.assert_close(
            aten_decomped.module()(
                input_data, cos_cache_data, sin_cache_data, position_ids_data
            ),
            model(input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )

    def test_rotary_embedding_opcheck(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 4)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)
        torch.library.opcheck(
            _impl.rotary_embedding_23,
            (input_data, cos_cache_data, sin_cache_data, position_ids_data),
        )

    def test_rotary_embedding(self):
        input_data = torch.rand(2, 3, 4, 8)
        position_ids_data = torch.randint(0, 50, (2, 4)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)
        # Eager mode is supported. Autograd is also supported so users can choose to use the op
        # in development and production
        result = torch.onnx.ops.rotary_embedding(
            input_data, cos_cache_data, sin_cache_data, position_ids_data
        )
        self.assertEqual(result.shape, input_data.shape)

        class Model(torch.nn.Module):
            def forward(
                self, input_data, cos_cache_data, sin_cache_data, position_ids_data
            ):
                return torch.onnx.ops.rotary_embedding(
                    input_data,
                    cos_cache_data,
                    sin_cache_data,
                    position_ids_data,
                    interleaved=True,
                )

        model = Model()
        # Dynamic shapes are supported
        dynamic_shapes = {
            "input_data": {0: torch.export.Dim.DYNAMIC},
            "cos_cache_data": None,
            "sin_cache_data": None,
            "position_ids_data": {0: torch.export.Dim.DYNAMIC},
        }
        onnx_program = self.export(
            model,
            (input_data, cos_cache_data, sin_cache_data, position_ids_data),
            dynamic_shapes=dynamic_shapes,
            opset_version=23,
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
        if has_onnxruntime_opset_23():
            onnx_testing.assert_onnx_program(onnx_program)
        else:
            # Test with reference evaluator because ORT does not support the op as of version 1.22
            onnx_testing.assert_onnx_program(onnx_program, backend="reference")

    def test_rotary_embedding_3d(self):
        num_heads = 2
        input_data = torch.rand(2, 3, 8)
        sin_cache_data = torch.rand(2, 3, 2)
        cos_cache_data = torch.rand(2, 3, 2)

        class Model(torch.nn.Module):
            def forward(self, input_data, cos_cache_data, sin_cache_data):
                return torch.onnx.ops.rotary_embedding(
                    input_data,
                    cos_cache_data,
                    sin_cache_data,
                    num_heads=num_heads,
                )

        model = Model()
        # Dynamic shapes are supported
        dynamic_shapes = {
            "input_data": {0: torch.export.Dim.DYNAMIC},
            "cos_cache_data": {0: torch.export.Dim.DYNAMIC},
            "sin_cache_data": {0: torch.export.Dim.DYNAMIC},
        }
        onnx_program = self.export(
            model,
            (input_data, cos_cache_data, sin_cache_data),
            dynamic_shapes=dynamic_shapes,
            opset_version=23,
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
        if has_onnxruntime_opset_23():
            onnx_testing.assert_onnx_program(onnx_program)
        else:
            # Test with reference evaluator because ORT does not support the op as of version 1.22
            onnx_testing.assert_onnx_program(onnx_program, backend="reference")

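    # The opset 23 Attention op produces four outputs: the attention output,
    # the present_key/present_value caches, and an intermediate QK matmul output.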
    def test_attention_basic(self):
        """Test basic attention functionality."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Test eager mode
        torch.library.opcheck(_impl.attention_23, (Q, K, V))
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V
        )
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
        self.assertEqual(present_key.shape, K.shape)
        self.assertEqual(present_value.shape, V.shape)
        self.assertEqual(
            qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
        )

    def test_attention_3d_inputs(self):
        """Test attention with 3D inputs (requires num_heads parameters)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
        K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
        V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
        torch.library.opcheck(
            _impl.attention_23,
            (Q, K, V),
            dict(q_num_heads=q_num_heads, kv_num_heads=kv_num_heads),
        )
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
        )
        # Output should be reshaped back to 3D
        self.assertEqual(output.shape, (batch_size, q_seq_len, q_num_heads * head_size))
        self.assertEqual(
            present_key.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
        )
        self.assertEqual(
            present_value.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
        )

    def test_attention_gqa(self):
        """Test Group Query Attention (GQA)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 4  # GQA: q_num_heads % kv_num_heads = 0
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        torch.library.opcheck(_impl.attention_23, (Q, K, V))
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V
        )
        expected = torch.nn.functional.scaled_dot_product_attention(
            Q, K, V, None, enable_gqa=True
        )
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
        self.assertEqual(present_key.shape, K.shape)
        self.assertEqual(present_value.shape, V.shape)
        torch.testing.assert_close(output, expected)

    def test_attention_mqa(self):
        """Test Multi-Query Attention (MQA)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 1  # MQA: kv_num_heads = 1
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        torch.library.opcheck(_impl.attention_23, (Q, K, V))
        output, present_key, present_value, qk_output = torch.onnx.ops.attention(
            Q, K, V
        )
        expected = torch.nn.functional.scaled_dot_product_attention(
            Q, K, V, None, enable_gqa=True
        )
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
        torch.testing.assert_close(output, expected)

    def test_attention_with_2d_mask(self):
        """Test attention with 2D attention mask (q_seq_len, kv_seq_len)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Test with boolean mask
        bool_mask = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
        output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
        # Test with float mask
        float_mask = torch.randn(q_seq_len, kv_seq_len)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
        output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
        self.assertEqual(
            output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )
        self.assertEqual(
            output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )

    def test_attention_with_4d_mask(self):
        """Test attention with 4D attention mask (batch_size, num_heads, q_seq_len, kv_seq_len)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Test with boolean mask
        bool_mask = torch.randint(
            0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
        )
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
        output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
        # Test with float mask
        float_mask = torch.randn(batch_size, q_num_heads, q_seq_len, kv_seq_len)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
        output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
        self.assertEqual(
            output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )
        self.assertEqual(
            output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )

    def test_attention_with_zero_float_mask(self):
        """Test attention with zero float mask."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        zero_mask = torch.zeros(q_seq_len, kv_seq_len)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=zero_mask))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=zero_mask)
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_with_causal_mask_pattern(self):
        """Test attention with lower triangular causal mask pattern."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 4  # Square for causal
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Create a lower triangular causal mask
        causal_mask = torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool))
        torch.library.opcheck(
            _impl.attention_23, (Q, K, V), dict(attn_mask=causal_mask)
        )
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=causal_mask)
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_with_gqa_and_mask(self):
        """Test attention with GQA and different mask shapes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 4  # GQA
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Test 2D mask with GQA
        mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_2d))
        output_2d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_2d)
        # Test 4D mask with GQA (note: using q_num_heads for mask heads)
        mask_4d = torch.randint(
            0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
        )
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_4d))
        output_4d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_4d)
        self.assertEqual(
            output_2d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )
        self.assertEqual(
            output_4d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
        )

    def test_attention_with_large_negative_float_mask(self):
        """Test attention with large negative values in float mask."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Create mask with large negative values (similar to -inf masking)
        float_mask = torch.full((q_seq_len, kv_seq_len), -1e9)
        # Allow some positions
        float_mask[:, :3] = 0.0
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_causal(self):
        """Test causal attention."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 4  # Square for causal
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(is_causal=True))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, is_causal=True)
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_with_past_kv(self):
        """Test attention with past key/value caches."""
        batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
        past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
        torch.library.opcheck(
            _impl.attention_23,
            (Q, K, V),
            dict(past_key=past_key, past_value=past_value),
        )
        output, present_key, present_value, _ = torch.onnx.ops.attention(
            Q, K, V, past_key=past_key, past_value=past_value
        )
        # Present key/value should include past + current
        expected_total_seq_len = past_seq_len + kv_seq_len
        self.assertEqual(
            present_key.shape,
            (batch_size, kv_num_heads, expected_total_seq_len, head_size),
        )
        self.assertEqual(
            present_value.shape,
            (batch_size, kv_num_heads, expected_total_seq_len, head_size),
        )

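    # softcap smoothly bounds the pre-softmax attention scores (commonly
    # implemented as softcap * tanh(scores / softcap)).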
    def test_attention_with_softcap(self):
        """Test attention with softcap."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(softcap=30.0))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, softcap=30.0)
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

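    # qk_matmul_output_mode selects which intermediate QK product becomes the
    # fourth output: 0 = raw QK matmul, 1 = after the mask is added, 2 = after
    # softcap, 3 = after softmax.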
    def test_attention_qk_output_modes(self):
        """Test different QK matmul output modes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        for mode in [0, 1, 2, 3]:
            torch.library.opcheck(
                _impl.attention_23,
                (Q, K, V),
                dict(qk_matmul_output_mode=mode),
            )
            output, _, _, qk_output = torch.onnx.ops.attention(
                Q, K, V, qk_matmul_output_mode=mode
            )
            self.assertEqual(
                output.shape, (batch_size, q_num_heads, q_seq_len, head_size)
            )
            self.assertEqual(
                qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
            )

    def test_attention_custom_scale(self):
        """Test attention with custom scale factor."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        custom_scale = 0.25
        torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(scale=custom_scale))
        output, _, _, _ = torch.onnx.ops.attention(Q, K, V, scale=custom_scale)
        self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))

    def test_attention_export(self):
        """Test that attention can be exported to ONNX."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, present_key, present_value, qk_output = (
                    torch.onnx.ops.attention(Q, K, V)
                )
                return output

        model = AttentionModel()
        onnx_program = self.export(
            model,
            (Q, K, V),
            opset_version=23,
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)

    def test_attention_export_with_dynamic_shapes(self):
        """Test attention export with dynamic shapes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, present_key, present_value, qk_output = (
                    torch.onnx.ops.attention(Q, K, V)
                )
                return output

        model = AttentionModel()
        dynamic_shapes = {
            "Q": {0: "batch", 2: "q_seq_len"},
            "K": {0: "batch", 2: "kv_seq_len"},
            "V": {0: "batch", 2: "kv_seq_len"},
        }
        onnx_program = self.export(
            model,
            (Q, K, V),
            dynamic_shapes=dynamic_shapes,
            opset_version=23,
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
        node = onnx_program.model.graph.node(0)
        # Verify inputs
        self.assertEqual(len(node.inputs), 3)  # Q, K, V (no optional inputs)
        self.assertEqual(
            node.inputs[0].shape, ["batch", q_num_heads, "q_seq_len", head_size]
        )
        self.assertEqual(
            node.inputs[1].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
        )
        self.assertEqual(
            node.inputs[2].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
        )
        # Verify default attributes (should be minimal)
        self.assertEqual(len(node.attributes), 0)

    def test_attention_3d_export(self):
        """Test attention export with 3D inputs."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
        K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
        V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, _, _, _ = torch.onnx.ops.attention(
                    Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
                )
                return output

        model = AttentionModel()
        onnx_program = self.export(
            model,
            (Q, K, V),
            opset_version=23,
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)

    def test_attention_decomposition(self):
        """Test that attention can be decomposed to aten ops."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        class AttentionModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, present_key, present_value, qk_output = (
                    torch.onnx.ops.attention(Q, K, V)
                )
                return output

        model = AttentionModel()
        ep = torch.export.export(model, (Q, K, V))
        self.assertIn(
            "onnx.Attention.opset23",
            [str(node.target) for node in ep.graph.nodes],
        )
        # The program can be decomposed into aten ops
        aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
        self.assertNotIn(
            "onnx.Attention.opset23",
            [str(node.target) for node in aten_decomped.graph.nodes],
        )
        # Results should match
        torch.testing.assert_close(
            aten_decomped.module()(Q, K, V),
            model(Q, K, V),
        )

    def test_attention_export_with_past_key_value(self):
        """Test export with past_key, past_value to ensure the optional input order is correct."""
        batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
        past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)

        class Model(torch.nn.Module):
            def forward(self, Q, K, V, past_key, past_value):
                output, _, _, _ = torch.onnx.ops.attention(
                    Q,
                    K,
                    V,
                    past_key=past_key,
                    attn_mask=None,
                    # Switched argument order
                    past_value=past_value,
                )
                return output

        model = Model()
        onnx_program = self.export(
            model, (Q, K, V, past_key, past_value), opset_version=23
        )
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        # Verify all 6 inputs are present
        self.assertEqual(
            len(node.inputs), 6
        )  # Q, K, V, attn_mask, past_key, past_value
        self.assertEqual(
            node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )
        self.assertIsNone(node.inputs[3])
        self.assertEqual(
            node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
        )

    def test_attention_export_with_all_optional_inputs(self):
        """Test export with all optional inputs: mask, past_key, past_value."""
        batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        attn_mask = torch.randint(
            0, 2, (1, 1, q_seq_len, kv_seq_len + past_seq_len), dtype=torch.bool
        )
        past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
        past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)

        class FullAttentionModel(torch.nn.Module):
            def forward(self, Q, K, V, attn_mask, past_key, past_value):
                output, _, _, _ = torch.onnx.ops.attention(
                    Q,
                    K,
                    V,
                    attn_mask=attn_mask,
                    past_key=past_key,
                    past_value=past_value,
                )
                return output

        model = FullAttentionModel()
        onnx_program = self.export(
            model, (Q, K, V, attn_mask, past_key, past_value), opset_version=23
        )
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        # Verify all 6 inputs are present
        self.assertEqual(
            len(node.inputs), 6
        )  # Q, K, V, attn_mask, past_key, past_value
        self.assertEqual(
            node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[3].shape, [1, 1, q_seq_len, kv_seq_len + past_seq_len]
        )
        self.assertEqual(
            node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
        )

    def test_attention_export_3d_with_num_heads_attributes(self):
        """Test export with 3D inputs and explicit num_heads attributes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 4  # GQA
        head_size = 64
        Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
        K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
        V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)

        class Attention3DModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, _, _, _ = torch.onnx.ops.attention(
                    Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
                )
                return output

        model = Attention3DModel()
        onnx_program = self.export(model, (Q, K, V), opset_version=23)
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        # Verify 3D input shapes
        self.assertEqual(
            node.inputs[0].shape, [batch_size, q_seq_len, q_num_heads * head_size]
        )
        self.assertEqual(
            node.inputs[1].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
        )
        self.assertEqual(
            node.inputs[2].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
        )
        # Verify num_heads attributes are set
        attrs = node.attributes
        self.assertIn("q_num_heads", attrs)
        self.assertIn("kv_num_heads", attrs)
        self.assertEqual(attrs["q_num_heads"].value, q_num_heads)
        self.assertEqual(attrs["kv_num_heads"].value, kv_num_heads)

    def test_attention_export_with_all_attributes(self):
        """Test export with all possible attributes set."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        class FullAttributesModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, _, _, _ = torch.onnx.ops.attention(
                    Q,
                    K,
                    V,
                    is_causal=True,
                    qk_matmul_output_mode=2,
                    scale=0.25,
                    softcap=30.0,
                    softmax_precision=1,  # FLOAT
                )
                return output

        model = FullAttributesModel()
        onnx_program = self.export(model, (Q, K, V), opset_version=23)
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        # Verify all attributes are set correctly
        attrs = node.attributes
        self.assertIn("is_causal", attrs)
        self.assertIn("qk_matmul_output_mode", attrs)
        self.assertIn("scale", attrs)
        self.assertIn("softcap", attrs)
        self.assertIn("softmax_precision", attrs)
        self.assertEqual(attrs["is_causal"].value, 1)  # True as int
        self.assertEqual(attrs["qk_matmul_output_mode"].value, 2)
        self.assertAlmostEqual(attrs["scale"].value, 0.25, places=6)
        self.assertAlmostEqual(attrs["softcap"].value, 30.0, places=6)
        self.assertEqual(attrs["softmax_precision"].value, 1)

    def test_attention_export_with_different_mask_shapes(self):
        """Test export with different attention mask shapes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        # Test 2D mask
        mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)

        class Mask2DModel(torch.nn.Module):
            def forward(self, Q, K, V, mask):
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
                return output

        model_2d = Mask2DModel()
        onnx_program_2d = self.export(model_2d, (Q, K, V, mask_2d), opset_version=23)
        node_2d = onnx_program_2d.model.graph.node(0)
        self.assertEqual(node_2d.inputs[3].shape, [q_seq_len, kv_seq_len])

        # Test 3D mask
        mask_3d = torch.randint(
            0, 2, (batch_size, 1, q_seq_len, kv_seq_len), dtype=torch.bool
        )

        class Mask3DModel(torch.nn.Module):
            def forward(self, Q, K, V, mask):
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
                return output

        model_3d = Mask3DModel()
        onnx_program_3d = self.export(model_3d, (Q, K, V, mask_3d), opset_version=23)
        node_3d = onnx_program_3d.model.graph.node(0)
        self.assertEqual(
            node_3d.inputs[3].shape, [batch_size, 1, q_seq_len, kv_seq_len]
        )

        # Test 4D mask
        mask_4d = torch.randint(
            0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
        )

        class Mask4DModel(torch.nn.Module):
            def forward(self, Q, K, V, mask):
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
                return output

        model_4d = Mask4DModel()
        onnx_program_4d = self.export(model_4d, (Q, K, V, mask_4d), opset_version=23)
        node_4d = onnx_program_4d.model.graph.node(0)
        self.assertEqual(
            node_4d.inputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
        )

    def test_attention_export_with_float_mask(self):
        """Test export with float attention mask."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        float_mask = torch.randn(q_seq_len, kv_seq_len)

        class FloatMaskModel(torch.nn.Module):
            def forward(self, Q, K, V, mask):
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
                return output

        model = FloatMaskModel()
        onnx_program = self.export(model, (Q, K, V, float_mask), opset_version=23)
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        self.assertEqual(node.inputs[3].shape, [q_seq_len, kv_seq_len])
        # Verify the mask input has float dtype in the ONNX model
        self.assertEqual(node.inputs[3].dtype, ir.DataType.FLOAT)

    def test_attention_export_qk_output_modes(self):
        """Test export with different QK output modes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        for mode in [0, 1, 2, 3]:

            class QKOutputModel(torch.nn.Module):
                def __init__(self, qk_mode):
                    super().__init__()
                    self.qk_mode = qk_mode

                def forward(self, Q, K, V):
                    output, _, _, qk_output = torch.onnx.ops.attention(
                        Q, K, V, qk_matmul_output_mode=self.qk_mode
                    )
                    return output, qk_output

            model = QKOutputModel(mode)
            onnx_program = self.export(model, (Q, K, V), opset_version=23)
            node = onnx_program.model.graph.node(0)
            self.assertEqual(node.op_type, "Attention")
            # Verify qk_matmul_output_mode attribute
            attrs = node.attributes
            if mode != 0:
                self.assertIn("qk_matmul_output_mode", attrs)
                self.assertEqual(attrs["qk_matmul_output_mode"].value, mode)
            # Verify 4 outputs (output, present_key, present_value, qk_output)
            self.assertEqual(len(node.outputs), 4)

    def test_attention_export_mqa(self):
        """Test export with Multi-Query Attention (MQA)."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 1  # MQA
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        class MQAModel(torch.nn.Module):
            def forward(self, Q, K, V):
                output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
                return output

        model = MQAModel()
        onnx_program = self.export(model, (Q, K, V), opset_version=23)
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        # Verify MQA tensor shapes
        self.assertEqual(
            node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
        )
        self.assertEqual(
            node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )  # kv_num_heads = 1
        self.assertEqual(
            node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )

    def test_attention_export_with_softmax_precision(self):
        """Test export with different softmax precision values."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 8
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        # Test different ONNX precision types
        precision_types = [
            (1, "FLOAT"),
            (10, "FLOAT16"),
            (11, "DOUBLE"),
            (16, "BFLOAT16"),
        ]
        for precision_val, precision_name in precision_types:

            class SoftmaxPrecisionModel(torch.nn.Module):
                def __init__(self, precision):
                    super().__init__()
                    self.precision = precision

                def forward(self, Q, K, V):
                    output, _, _, _ = torch.onnx.ops.attention(
                        Q, K, V, softmax_precision=self.precision
                    )
                    return output

            model = SoftmaxPrecisionModel(precision_val)
            onnx_program = self.export(model, (Q, K, V), opset_version=23)
            node = onnx_program.model.graph.node(0)
            self.assertEqual(node.op_type, "Attention")
            # Verify softmax_precision attribute
            attrs = node.attributes
            self.assertIn("softmax_precision", attrs)
            self.assertEqual(attrs["softmax_precision"].value, precision_val)

    def test_attention_export_gqa(self):
        """Test export and verify output tensor shapes."""
        batch_size, q_seq_len, kv_seq_len = 2, 4, 6
        q_num_heads, kv_num_heads = 8, 4  # GQA
        head_size = 64
        Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
        K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
        V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)

        class AttentionOutputsModel(torch.nn.Module):
            def forward(self, Q, K, V):
                return torch.onnx.ops.attention(Q, K, V)

        model = AttentionOutputsModel()
        onnx_program = self.export(model, (Q, K, V), opset_version=23)
        node = onnx_program.model.graph.node(0)
        self.assertEqual(node.op_type, "Attention")
        # Verify all 4 outputs have correct shapes
        outputs = node.outputs
        self.assertEqual(len(outputs), 4)
        # output: (batch_size, q_num_heads, q_seq_len, head_size)
        self.assertEqual(
            outputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
        )
        # present_key: (batch_size, kv_num_heads, kv_seq_len, head_size)
        self.assertEqual(
            outputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )
        # present_value: (batch_size, kv_num_heads, kv_seq_len, head_size)
        self.assertEqual(
            outputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
        )
        # qk_output: (batch_size, q_num_heads, q_seq_len, kv_seq_len)
        self.assertEqual(
            outputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
        )


if __name__ == "__main__":
    common_utils.run_tests()