[ONNX] Fix rotary_embedding_23 implementation (#162865)

The implementation of `rotary_embedding_23` was incorrect when the input is 3D.
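
For context, a 3D input of shape `(batch_size, sequence_length, hidden_size)` should behave the same as the corresponding 4D input of shape `(batch_size, num_heads, sequence_length, head_size)` obtained by splitting `hidden_size` into heads. Below is a minimal consistency check of that property; the sizes are illustrative (they mirror the new 3D test) and it assumes a torch build that exposes `torch.onnx.ops.rotary_embedding`:

```py
import torch

# Illustrative sizes only.
batch, seq, num_heads, head_size = 2, 3, 2, 4
x_3d = torch.rand(batch, seq, num_heads * head_size)  # (batch, seq, hidden)
cos = torch.rand(batch, seq, head_size // 2)  # 3D caches, no position_ids
sin = torch.rand(batch, seq, head_size // 2)

# 3D path: the op splits hidden_size into num_heads heads internally.
out_3d = torch.onnx.ops.rotary_embedding(x_3d, cos, sin, num_heads=num_heads)

# Equivalent 4D path: (batch, seq, heads, head_size) -> (batch, heads, seq, head_size).
x_4d = x_3d.reshape(batch, seq, num_heads, head_size).permute(0, 2, 1, 3)
out_4d = torch.onnx.ops.rotary_embedding(x_4d, cos, sin)

# The two paths should agree once the 4D output is laid out as 3D again.
torch.testing.assert_close(
    out_3d, out_4d.permute(0, 2, 1, 3).reshape(batch, seq, num_heads * head_size)
)
```

The fixed implementation reshapes the 3D input to `(batch_size, sequence_length, num_heads, head_size)` before applying the rotation and reshapes back afterwards, so the two paths agree.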

## Tested

Locally with

```py
import onnx_ir as ir
import onnx
import torch
import os
import numpy as np

base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]

for path in model_paths:
    print(f"Checking {path} for issues...")

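    # Load the backend test model, its recorded inputs, and the expected output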
    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    ).numpy()
    input1 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_1.pb"))
    ).numpy()
    input2 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_2.pb"))
    ).numpy()
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = ir.from_proto(
            onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_3.pb"))
        ).numpy()
    else:
        input3 = None
    output0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    ).numpy()

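    # The RotaryEmbedding node is the last node in these test graphs; read its attributes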
    m = ir.from_proto(model)

    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

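    # Run the same inputs through the torch implementation and compare with the expected output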
    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    np.testing.assert_allclose(torch_out, output0)
```

Fixes https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms
Author: Justin Chu
Date: 2025-09-16 03:30:02 +00:00
Committed by: PyTorch MergeBot
Commit: fdf68fa5d7 (parent: 7924b083c1)

2 changed files with 119 additions and 19 deletions


@@ -4,13 +4,20 @@
from __future__ import annotations
import onnx_ir.passes.common as common_passes
import onnxruntime
from onnxscript import ir
from packaging import version
import torch
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils
def has_onnxruntime_opset_23() -> bool:
return version.parse(onnxruntime.__version__) >= version.parse("1.23")
class SchemaTest(common_utils.TestCase):
def test_symbolic_has_correct_schema(self):
torch.library.opcheck(
@@ -432,7 +439,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):
def test_onnx_ops_can_be_decomposed_to_aten(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
position_ids_data = torch.randint(0, 50, (2, 4)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
@@ -473,7 +480,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):
def test_rotary_embedding_opcheck(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
position_ids_data = torch.randint(0, 50, (2, 4)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
@@ -484,7 +491,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):
def test_rotary_embedding(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
position_ids_data = torch.randint(0, 50, (2, 4)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
@@ -525,6 +532,49 @@ class NativeOnnxOpsTest(common_utils.TestCase):
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
if has_onnxruntime_opset_23():
onnx_testing.assert_onnx_program(onnx_program)
else:
# Test with reference evaluator because ORT does not support the op as of version 1.22
onnx_testing.assert_onnx_program(onnx_program, backend="reference")
def test_rotary_embedding_3d(self):
num_heads = 2
input_data = torch.rand(2, 3, 8)
sin_cache_data = torch.rand(2, 3, 2)
cos_cache_data = torch.rand(2, 3, 2)
class Model(torch.nn.Module):
def forward(self, input_data, cos_cache_data, sin_cache_data):
return torch.onnx.ops.rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
num_heads=num_heads,
)
model = Model()
# Dynamic shapes are supported
dynamic_shapes = {
"input_data": {0: torch.export.Dim.DYNAMIC},
"cos_cache_data": {0: torch.export.Dim.DYNAMIC},
"sin_cache_data": {0: torch.export.Dim.DYNAMIC},
}
onnx_program = self.export(
model,
(input_data, cos_cache_data, sin_cache_data),
dynamic_shapes=dynamic_shapes,
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
if has_onnxruntime_opset_23():
onnx_testing.assert_onnx_program(onnx_program)
else:
# Test with reference evaluator because ORT does not support the op as of version 1.22
onnx_testing.assert_onnx_program(onnx_program, backend="reference")
def test_attention_basic(self):
"""Test basic attention functionality."""


@@ -56,18 +56,55 @@ def rotary_embedding_23(
rotary_embedding_dim: int = 0,
) -> torch.Tensor:
"""RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23"""
# x has shape (batch_size, num_heads, sequence_length, head_size)
# or (batch_size, sequence_length, hidden_size)
input_shape = x.shape
input_rank = len(input_shape)
batch_size = input_shape[0]
sequence_length = input_shape[-2]
# Validate position_ids and caches match x
if position_ids is not None:
torch._check(
position_ids.dim() == 2,
lambda: f"position_ids must be 2D when provided. Received shape {position_ids.shape}",
)
torch._check(
position_ids.shape[0] == batch_size,
lambda: f"position_ids first dim (batch) must match x.shape[0] ({batch_size}). Received {position_ids.shape[0]}",
)
torch._check(
position_ids.shape[1] == sequence_length,
lambda: f"position_ids second dim (sequence) must match x.shape[-2] ({sequence_length}). Received {position_ids.shape[1]}",
)
torch._check(
cos_cache.dim() == 2 and sin_cache.dim() == 2,
lambda: "cos_cache/sin_cache must be 2D when position_ids is provided. "
f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
)
else:
torch._check(
cos_cache.dim() == 3 and sin_cache.dim() == 3,
lambda: "cos_cache/sin_cache must be 3D when position_ids is not provided. "
f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
)
# First ensure x has shape [batch_size, num_heads, seq_len, head_size]
batch_size = x.shape[0]
sequence_length = x.shape[1]
if len(x.shape) == 3:
hidden_size = x.shape[2]
# So that the rotation logic can be shared with reshaped 3D inputs
if input_rank == 4:
# Reshape from (batch_size, num_heads, seq_len, head_size)
# to [batch_size, seq_len, num_heads, head_size]
x = torch.permute(x, (0, 2, 1, 3))
elif input_rank == 3:
torch._check(
num_heads != 0,
lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}",
lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {input_shape}",
)
hidden_size = input_shape[2]
head_size = hidden_size // num_heads
new_shape = [batch_size, sequence_length, num_heads, head_size]
x = torch.reshape(x, new_shape)
torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
head_size = x.shape[3]
@@ -88,14 +125,25 @@ def rotary_embedding_23(
position_ids
] # Shape: [batch_size, sequence_length, head_size/2]
else:
cos = cos_cache
sin = sin_cache
cos = cos[
:, :, :rotary_embedding_dim_half
] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
sin = sin[
:, :, :rotary_embedding_dim_half
] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
cos = cos_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
sin = sin_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
torch._check(
cos.shape[0] == batch_size and cos.shape[1] == sequence_length,
lambda: f"cos has shape {cos.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
)
torch._check(
sin.shape[0] == batch_size and sin.shape[1] == sequence_length,
lambda: f"sin has shape {sin.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
)
torch._check(
cos.shape[-1] == rotary_embedding_dim_half,
lambda: f"Last dimension of cos cache ({cos.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
)
torch._check(
sin.shape[-1] == rotary_embedding_dim_half,
lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
)
cos = torch.unsqueeze(
cos, 2
) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
@@ -125,9 +173,11 @@ def rotary_embedding_23(
else:
x_rotate = torch.cat((real, imag), dim=-1)
output = torch.cat((x_rotate, x_not_rotate), dim=-1)
if len(x.shape) == 3:
output = torch.reshape(output, x.shape)
return output
if input_rank == 3:
return torch.reshape(output, input_shape)
# Return the dimensions to the original order
return torch.permute(output, (0, 2, 1, 3))
def _get_scale_factor(scale: Optional[float], head_size: int) -> float: