# torch.onnx.ops
```{eval-rst}
.. automodule:: torch.onnx.ops
```
## Symbolic Operators
Operators that can be used to create arbitrary ONNX ops in the FX graph symbolically.
These operators do not perform actual computation. It's recommended that you use them
inside an ``if torch.onnx.is_in_onnx_export()`` block.
```{eval-rst}
.. autofunction:: torch.onnx.ops.symbolic
.. autofunction:: torch.onnx.ops.symbolic_multi_out
```
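For example, a minimal sketch of creating a symbolic op guarded by the export check
(the ``custom_domain::CustomOp`` name and its attribute are placeholders):
```py
class CustomOpModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            # Record a symbolic ONNX op in the FX graph without computing
            # anything; dtype and shape describe the symbolic output tensor.
            return torch.onnx.ops.symbolic(
                "custom_domain::CustomOp",
                (x,),
                {"attr_key": "attr_value"},
                dtype=x.dtype,
                shape=x.shape,
                version=1,
            )
        # Eager fallback used outside of ONNX export.
        return x
```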
## ONNX Operators
The following operators are implemented as native PyTorch ops and can be exported as
ONNX operators. They can be used natively in an ``nn.Module``.
For example, you can define a module:
```py
class Model(torch.nn.Module):
    def forward(
        self, input_data, cos_cache_data, sin_cache_data, position_ids_data
    ):
        return torch.onnx.ops.rotary_embedding(
            input_data,
            cos_cache_data,
            sin_cache_data,
            position_ids_data,
        )
```
and export it to ONNX using:
```py
model = Model()

input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)

dynamic_shapes = {
    "input_data": {0: torch.export.Dim.DYNAMIC},
    "cos_cache_data": None,
    "sin_cache_data": None,
    "position_ids_data": {0: torch.export.Dim.DYNAMIC},
}

onnx_program = torch.onnx.export(
    model,
    (input_data, cos_cache_data, sin_cache_data, position_ids_data),
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
    opset_version=23,
)
```
Printing the ONNX program will show the ONNX operators used in the graph:
```
<...>
graph(
    name=main_graph,
    inputs=(
        %"input_data"<FLOAT,[s0,3,4,8]>,
        %"cos_cache_data"<FLOAT,[50,4]>,
        %"sin_cache_data"<FLOAT,[50,4]>,
        %"position_ids_data"<INT64,[s0,3]>
    ),
    outputs=(
        %"rotary_embedding"<FLOAT,[s0,3,4,8]>
    ),
) {
    0 |  # rotary_embedding
         %"rotary_embedding"<FLOAT,[s0,3,4,8]> ⬅️ ::RotaryEmbedding(%"input_data", %"cos_cache_data", %"sin_cache_data", %"position_ids_data")
    return %"rotary_embedding"<FLOAT,[s0,3,4,8]>
}
```
with the corresponding ``ExportedProgram``:
```py
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, input_data: "f32[s0, 3, 4, 8]", cos_cache_data: "f32[50, 4]", sin_cache_data: "f32[50, 4]", position_ids_data: "i64[s0, 3]"):
            rotary_embedding: "f32[s0, 3, 4, 8]" = torch.ops.onnx.RotaryEmbedding.opset23(input_data, cos_cache_data, sin_cache_data, position_ids_data); input_data = cos_cache_data = sin_cache_data = position_ids_data = None
            return (rotary_embedding,)
```
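If you just need the serialized model, the resulting {class}`torch.onnx.ONNXProgram`
can then be written to disk (the file name here is illustrative):
```py
onnx_program.save("rotary_embedding.onnx")
```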
```{eval-rst}
.. autofunction:: torch.onnx.ops.rotary_embedding
.. autofunction:: torch.onnx.ops.attention
```
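As a sketch of using the attention op (shapes are illustrative and follow the 4D
``(batch, num_heads, seq_len, head_size)`` layout of ONNX Attention-23; besides the
attention output, the op also returns the present key/value and the QK matmul output):
```py
class AttentionModel(torch.nn.Module):
    def forward(self, query, key, value):
        # attention returns (output, present_key, present_value, qk_output);
        # only the attention output is used here.
        output, _, _, _ = torch.onnx.ops.attention(query, key, value, is_causal=True)
        return output

query = torch.rand(2, 4, 8, 16)
key = torch.rand(2, 4, 8, 16)
value = torch.rand(2, 4, 8, 16)

onnx_program = torch.onnx.export(
    AttentionModel(),
    (query, key, value),
    dynamo=True,
    opset_version=23,
)
```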
## ONNX to ATen Decomposition Table
You can use {func}`torch.onnx.ops.aten_decompositions` to obtain a decomposition table
that lowers the ONNX operators defined above to ATen operators.
```py
class Model(torch.nn.Module):
    def forward(
        self, input_data, cos_cache_data, sin_cache_data, position_ids_data
    ):
        return torch.onnx.ops.rotary_embedding(
            input_data,
            cos_cache_data,
            sin_cache_data,
            position_ids_data,
        )

model = Model()
ep = torch.export.export(
    model,
    (input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
# The program can be decomposed into ATen ops
ep_decomposed = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
```
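As a sanity check (a sketch that reuses the inputs from the export example above),
the decomposed program should produce the same result as eager execution:
```py
decomposed_out = ep_decomposed.module()(
    input_data, cos_cache_data, sin_cache_data, position_ids_data
)
eager_out = model(input_data, cos_cache_data, sin_cache_data, position_ids_data)
torch.testing.assert_close(decomposed_out, eager_out)
```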
```{eval-rst}
.. autofunction:: torch.onnx.ops.aten_decompositions
```