Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
The implementation of `rotary_embedding_23` was incorrect when the input is 3D.

## Tested locally with

```py
import onnx_ir as ir
import onnx
import torch
import os
import numpy as np

base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]

for path in model_paths:
    print(f"Checking {path} for issues...")

    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    ).numpy()
    input1 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_1.pb"))
    ).numpy()
    input2 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_2.pb"))
    ).numpy()
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = ir.from_proto(
            onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_3.pb"))
        ).numpy()
    else:
        input3 = None
    output0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    ).numpy()

    m = ir.from_proto(model)

    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    np.testing.assert_allclose(torch_out, output0)
```

Fixes https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms