[ONNX] Fix rotary_embedding_23 implementation (#162865)
The implementation of rotary_embedding_23 when input is 3D was incorrect.

## Tested Locally with

```py
import onnx_ir as ir
import onnx
import torch
import os
import numpy as np

base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]

for path in model_paths:
    print(f"Checking {path} for issues...")
    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    ).numpy()
    input1 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_1.pb"))
    ).numpy()
    input2 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_2.pb"))
    ).numpy()
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = ir.from_proto(
            onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_3.pb"))
        ).numpy()
    else:
        input3 = None
    output0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    ).numpy()

    m = ir.from_proto(model)
    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    np.testing.assert_allclose(torch_out, output0)
```

Fix https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms
Committed by: PyTorch MergeBot
Parent: 7924b083c1
Commit: fdf68fa5d7
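For orientation before the diff: the essence of the fix is that a 3D input of shape (batch_size, sequence_length, hidden_size) must be viewed as (batch_size, sequence_length, num_heads, head_size) before the rotation is applied and then restored to its original shape. The snippet below is a minimal illustrative sketch of that idea only, assuming the non-interleaved layout, rotation over the full head size, 3D cos/sin caches, and no position_ids; the helper name rotary_3d_sketch is hypothetical and this is not the code changed in this PR.

```python
import torch


def rotary_3d_sketch(x, cos, sin, num_heads):
    # x: (batch, seq, hidden); cos/sin: (batch, seq, head_size // 2)
    batch, seq, hidden = x.shape
    head_size = hidden // num_heads
    # View the flat hidden dimension as (num_heads, head_size) so the rotation
    # can be applied per head, mirroring the 4D (batch, seq, heads, head) layout.
    x4 = x.reshape(batch, seq, num_heads, head_size)
    half = head_size // 2
    x1, x2 = x4[..., :half], x4[..., half:]  # non-interleaved split of each head
    c = cos.unsqueeze(2)  # broadcast the caches over the head dimension
    s = sin.unsqueeze(2)
    rotated = torch.cat((x1 * c - x2 * s, x1 * s + x2 * c), dim=-1)
    return rotated.reshape(batch, seq, hidden)  # restore the original 3D shape


# Shapes matching test_rotary_embedding_3d in the diff below.
x = torch.rand(2, 3, 8)
cos = torch.rand(2, 3, 2)
sin = torch.rand(2, 3, 2)
print(rotary_3d_sketch(x, cos, sin, num_heads=2).shape)  # torch.Size([2, 3, 8])
```

The actual implementation changed below additionally validates cache and position_ids shapes, supports the interleaved layout, and rotates only the first rotary_embedding_dim channels when that attribute is set.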
@@ -4,13 +4,20 @@
from __future__ import annotations

import onnx_ir.passes.common as common_passes
import onnxruntime
from onnxscript import ir
from packaging import version

import torch
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils


def has_onnxruntime_opset_23() -> bool:
    return version.parse(onnxruntime.__version__) >= version.parse("1.23")


class SchemaTest(common_utils.TestCase):
    def test_symbolic_has_correct_schema(self):
        torch.library.opcheck(
@@ -432,7 +439,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):

    def test_onnx_ops_can_be_decomposed_to_aten(self):
        input_data = torch.rand(2, 3, 4, 8)
-        position_ids_data = torch.randint(0, 50, (2, 3)).long()
+        position_ids_data = torch.randint(0, 50, (2, 4)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

@@ -473,7 +480,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):

    def test_rotary_embedding_opcheck(self):
        input_data = torch.rand(2, 3, 4, 8)
-        position_ids_data = torch.randint(0, 50, (2, 3)).long()
+        position_ids_data = torch.randint(0, 50, (2, 4)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

@@ -484,7 +491,7 @@ class NativeOnnxOpsTest(common_utils.TestCase):

    def test_rotary_embedding(self):
        input_data = torch.rand(2, 3, 4, 8)
-        position_ids_data = torch.randint(0, 50, (2, 3)).long()
+        position_ids_data = torch.randint(0, 50, (2, 4)).long()
        sin_cache_data = torch.rand(50, 4)
        cos_cache_data = torch.rand(50, 4)

@@ -525,6 +532,49 @@ class NativeOnnxOpsTest(common_utils.TestCase):
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
        if has_onnxruntime_opset_23():
            onnx_testing.assert_onnx_program(onnx_program)
        else:
            # Test with reference evaluator because ORT does not support the op as of version 1.22
            onnx_testing.assert_onnx_program(onnx_program, backend="reference")

    def test_rotary_embedding_3d(self):
        num_heads = 2
        input_data = torch.rand(2, 3, 8)
        sin_cache_data = torch.rand(2, 3, 2)
        cos_cache_data = torch.rand(2, 3, 2)

        class Model(torch.nn.Module):
            def forward(self, input_data, cos_cache_data, sin_cache_data):
                return torch.onnx.ops.rotary_embedding(
                    input_data,
                    cos_cache_data,
                    sin_cache_data,
                    num_heads=num_heads,
                )

        model = Model()

        # Dynamic shapes are supported
        dynamic_shapes = {
            "input_data": {0: torch.export.Dim.DYNAMIC},
            "cos_cache_data": {0: torch.export.Dim.DYNAMIC},
            "sin_cache_data": {0: torch.export.Dim.DYNAMIC},
        }

        onnx_program = self.export(
            model,
            (input_data, cos_cache_data, sin_cache_data),
            dynamic_shapes=dynamic_shapes,
            opset_version=23,
        )
        self.assertEqual(onnx_program.model.opset_imports[""], 23)
        self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
        if has_onnxruntime_opset_23():
            onnx_testing.assert_onnx_program(onnx_program)
        else:
            # Test with reference evaluator because ORT does not support the op as of version 1.22
            onnx_testing.assert_onnx_program(onnx_program, backend="reference")

    def test_attention_basic(self):
        """Test basic attention functionality."""
@@ -56,18 +56,55 @@ def rotary_embedding_23(
    rotary_embedding_dim: int = 0,
) -> torch.Tensor:
    """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23"""
    # x has shape (batch_size, num_heads, sequence_length, head_size)
    # or (batch_size, sequence_length, hidden_size)
    input_shape = x.shape
    input_rank = len(input_shape)
    batch_size = input_shape[0]
    sequence_length = input_shape[-2]

    # Validate position_ids and caches match x
    if position_ids is not None:
        torch._check(
            position_ids.dim() == 2,
            lambda: f"position_ids must be 2D when provided. Received shape {position_ids.shape}",
        )
        torch._check(
            position_ids.shape[0] == batch_size,
            lambda: f"position_ids first dim (batch) must match x.shape[0] ({batch_size}). Received {position_ids.shape[0]}",
        )
        torch._check(
            position_ids.shape[1] == sequence_length,
            lambda: f"position_ids second dim (sequence) must match x.shape[-2] ({sequence_length}). Received {position_ids.shape[1]}",
        )
        torch._check(
            cos_cache.dim() == 2 and sin_cache.dim() == 2,
            lambda: "cos_cache/sin_cache must be 2D when position_ids is provided. "
            f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
        )
    else:
        torch._check(
            cos_cache.dim() == 3 and sin_cache.dim() == 3,
            lambda: "cos_cache/sin_cache must be 3D when position_ids is not provided. "
            f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
        )

-    # First ensure x has shape [batch_size, num_heads, seq_len, head_size]
-    batch_size = x.shape[0]
-    sequence_length = x.shape[1]
-    if len(x.shape) == 3:
-        hidden_size = x.shape[2]
+    # So that the rotation logic can be shared with reshaped 3D inputs
+    if input_rank == 4:
+        # Reshape from (batch_size, num_heads, seq_len, head_size)
+        # to [batch_size, seq_len, num_heads, head_size]
+        x = torch.permute(x, (0, 2, 1, 3))
+    elif input_rank == 3:
        torch._check(
            num_heads != 0,
-            lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}",
+            lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {input_shape}",
        )
+        hidden_size = input_shape[2]
        head_size = hidden_size // num_heads
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        x = torch.reshape(x, new_shape)

    torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
    head_size = x.shape[3]

@@ -88,14 +125,25 @@ def rotary_embedding_23(
            position_ids
        ]  # Shape: [batch_size, sequence_length, head_size/2]
    else:
-        cos = cos_cache
-        sin = sin_cache
-        cos = cos[
-            :, :, :rotary_embedding_dim_half
-        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
-        sin = sin[
-            :, :, :rotary_embedding_dim_half
-        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
+        cos = cos_cache  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
+        sin = sin_cache  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]

    torch._check(
        cos.shape[0] == batch_size and cos.shape[1] == sequence_length,
        lambda: f"cos has shape {cos.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
    )
    torch._check(
        sin.shape[0] == batch_size and sin.shape[1] == sequence_length,
        lambda: f"sin has shape {sin.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
    )
    torch._check(
        cos.shape[-1] == rotary_embedding_dim_half,
        lambda: f"Last dimension of cos cache ({cos.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
    )
    torch._check(
        sin.shape[-1] == rotary_embedding_dim_half,
        lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
    )
    cos = torch.unsqueeze(
        cos, 2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
@@ -125,9 +173,11 @@ def rotary_embedding_23(
    else:
        x_rotate = torch.cat((real, imag), dim=-1)
    output = torch.cat((x_rotate, x_not_rotate), dim=-1)
-    if len(x.shape) == 3:
-        output = torch.reshape(output, x.shape)
-    return output
+    if input_rank == 3:
+        return torch.reshape(output, input_shape)
+
+    # Return the dimensions to the original order
+    return torch.permute(output, (0, 2, 1, 3))


def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
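As a local sanity check (not part of this PR), the 3D and 4D entry points should now agree once the layouts are aligned. The sketch below assumes the fixed eager implementation above backs torch.onnx.ops.rotary_embedding and reuses the shapes from the new test.

```python
import torch

batch, seq, num_heads, head_size = 2, 3, 2, 4
x3d = torch.rand(batch, seq, num_heads * head_size)
cos = torch.rand(batch, seq, head_size // 2)  # 3D caches, no position_ids
sin = torch.rand(batch, seq, head_size // 2)

# 3D path: the op splits hidden_size into heads internally via num_heads.
out3d = torch.onnx.ops.rotary_embedding(x3d, cos, sin, num_heads=num_heads)

# Equivalent 4D path: explicit (batch, num_heads, seq, head_size) layout.
x4d = x3d.reshape(batch, seq, num_heads, head_size).permute(0, 2, 1, 3)
out4d = torch.onnx.ops.rotary_embedding(x4d, cos, sin)

# Both paths should produce the same values once reshaped to a common layout.
torch.testing.assert_close(
    out3d.reshape(batch, seq, num_heads, head_size),
    out4d.permute(0, 2, 1, 3),
)
```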