mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The implementation of rotary_embedding_23 when input is 3D was incorrect.
## Tested
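Locally, with the following script, which compares `torch.onnx.ops.rotary_embedding` against the expected outputs stored with the ONNX backend node tests: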
```py
import os

import numpy as np
import onnx
import onnx_ir as ir
import torch

# Directory holding the ONNX backend node test cases.
# NOTE(review): this looks like a path on the author's machine -- point it at
# your own onnx checkout before running.
base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]


def load_array(case_dir, filename):
    """Load ``test_data_set_0/<filename>`` under *case_dir* as a NumPy array."""
    tensor_path = os.path.join(case_dir, "test_data_set_0", filename)
    return ir.from_proto(onnx.load_tensor(tensor_path)).numpy()


for path in model_paths:
    print(f"Checking {path} for issues...")
    model = onnx.load(os.path.join(path, "model.onnx"))

    input0 = load_array(path, "input_0.pb")
    input1 = load_array(path, "input_1.pb")
    input2 = load_array(path, "input_2.pb")
    # The fourth input is only present for some test cases; it is forwarded
    # as ``position_ids`` below when it exists.
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = load_array(path, "input_3.pb")
    else:
        input3 = None
    output0 = load_array(path, "output_0.pb")

    # Read the operator attributes from the last node of the test model,
    # passing 0 as the fallback for any attribute that is not set.
    m = ir.from_proto(model)
    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"
    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()

    # Compare the PyTorch result with the expected output stored alongside
    # the test case.
    np.testing.assert_allclose(torch_out, output0)
```
Fixes https://github.com/pytorch/pytorch/issues/162848
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms