5 Commits

Author SHA1 Message Date
fdf68fa5d7 [ONNX] Fix rotary_embedding_23 implementation (#162865)
The implementation of rotary_embedding_23 when input is 3D was incorrect.

## Tested

Locally with

```py
import onnx_ir as ir
import onnx
import torch
import os
import numpy as np

base_path = "/home/justinchu/dev/onnx/onnx/backend/test/data/node"
test_names = [
    "test_rotary_embedding",
    "test_rotary_embedding_3d_input",
    "test_rotary_embedding_interleaved",
    "test_rotary_embedding_no_position_ids",
    "test_rotary_embedding_no_position_ids_interleaved",
    "test_rotary_embedding_no_position_ids_rotary_dim",
    "test_rotary_embedding_with_interleaved_rotary_dim",
    "test_rotary_embedding_with_rotary_dim",
]
model_paths = [os.path.join(base_path, name) for name in test_names]

for path in model_paths:
    print(f"Checking {path} for issues...")

    model = onnx.load(os.path.join(path, "model.onnx"))
    input0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    ).numpy()
    input1 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_1.pb"))
    ).numpy()
    input2 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_2.pb"))
    ).numpy()
    if os.path.exists(os.path.join(path, "test_data_set_0", "input_3.pb")):
        input3 = ir.from_proto(
            onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_3.pb"))
        ).numpy()
    else:
        input3 = None
    output0 = ir.from_proto(
        onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    ).numpy()

    m = ir.from_proto(model)

    node = m.graph[-1]
    print(node)
    assert node.op_type == "RotaryEmbedding"

    interleaved = node.attributes.get_int("interleaved", 0)
    num_heads = node.attributes.get_int("num_heads", 0)
    rotary_embedding_dim = node.attributes.get_int("rotary_embedding_dim", 0)

    torch_out = torch.onnx.ops.rotary_embedding(
        torch.tensor(input0),
        torch.tensor(input1),
        torch.tensor(input2),
        position_ids=torch.tensor(input3) if input3 is not None else None,
        interleaved=bool(interleaved),
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )
    torch_out = torch_out.detach().cpu().numpy()
    np.testing.assert_allclose(torch_out, output0)
```

Fix https://github.com/pytorch/pytorch/issues/162848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162865
Approved by: https://github.com/kunal-vaishnavi, https://github.com/titaiwangms
2025-09-16 03:30:05 +00:00
fbbab794ef [ONNX] Implement Attention-23 (#156431)
Implement Attention-23 using sdpa and flexattention.

- I used copilot for this.
- Also updated the conversion logic to remove trailing None inputs.

@gramalingam @kunal-vaishnavi @titaiwangms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156431
Approved by: https://github.com/titaiwangms

Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-06-20 23:54:57 +00:00
3e57de1251 [ONNX] Create support for rotary embeddings (#154745)
This PR registers the RotaryEmbedding op in the `torch.ops.onnx` name spaces and allows the exporter to recognize and export onnx operators.

## Design

ONNX operators of their respective opset version is implemented in torch/onnx/ops/_impl.py, and are registered in the torch.ops.onnx namespace following the following rule:

`OpType-version => torch.ops.onnx.OpType.opset{version}`

For example, `RotaryEmbedding-23` becomes `torch.ops.onnx.RotaryEmbedding.opset23`

This name is parsed by the exporter to create an onnx node in the graph without having to go through translation.

When users use the ops in the model, we provide more convenient, unversioned functions under `torch.onnx.ops` that will dispatch to the implementations based on user input (type and provided attributes). For example, users can directly call `torch.onnx.ops.rotary_embedding()` to use the op natively in their pytorch models. I chose snake case naming to make the functions more pythonic and aligned with other torch apis.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154745
Approved by: https://github.com/titaiwangms
2025-06-04 03:07:43 +00:00
3efa211e48 [ONNX] Annotate None inputs in symbolic ops (#150038)
Add `None` to type annotations of `torch.onnx.ops.symbolic*` ops and improve tests to test support for optional inputs. Previously it was omitted mistakenly even though the implementation supports it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150038
Approved by: https://github.com/titaiwangms
2025-03-27 00:01:09 +00:00
010963032c [ONNX] Create onnx_symbolic (#148905)
In the old exporter we allow users to define a symbolic() method to bypass JIT tracing for a block of logic. We can allow users to do similar things by creating symbolic ops at export.

This PR implements `torch.onnx.ops.symbolic` and `torch.onnx.ops.symbolic_multi_out` to allow users to create onnx nodes symbolically with pt2 & fx. The custom pytorch ops were designed such that the attributes are encoded to be part of a valid fx op. Users provide shape and dtype for the meta function to produce the currect fake tensor during export.

An example is

![image](https://github.com/user-attachments/assets/c62f5f21-e038-456e-a71d-b9a5d0a7cd9d)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148905
Approved by: https://github.com/titaiwangms
2025-03-18 21:32:06 +00:00