[ONNX] Create onnx_symbolic (#148905)

The old exporter allowed users to define a symbolic() method to bypass JIT tracing for a block of logic. We can support a similar workflow by letting users create symbolic ops at export time.

This PR implements `torch.onnx.ops.symbolic` and `torch.onnx.ops.symbolic_multi_out` to allow users to create ONNX nodes symbolically with PT2 and FX. The custom PyTorch ops are designed so that the attributes are encoded as part of a valid FX op. Users provide the output shape and dtype so the meta function can produce the correct fake tensor during export.

An example:

![image](https://github.com/user-attachments/assets/c62f5f21-e038-456e-a71d-b9a5d0a7cd9d)
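
In code, the usage looks roughly like this (a sketch distilled from the tests in this PR):

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # Creates a symbolic ONNX node "CustomOp" in the "custom_domain" domain.
        # Attributes are passed as a dict; `shape` and `dtype` describe the
        # output so the meta function can build the fake tensor during export.
        return torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (x,),
            dict(attr_key="attr_value"),
            dtype=x.dtype,
            shape=[1, 2, 3],
            version=1,
        )


onnx_program = torch.onnx.export(Model(), (torch.tensor(1),), dynamo=True)
```

`symbolic_multi_out` works the same way, but takes `dtypes=` and `shapes=` sequences, one entry per output.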

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148905
Approved by: https://github.com/titaiwangms
Author: Justin Chu
Date: 2025-03-18 21:32:06 +00:00
Committed by: PyTorch MergeBot
Parent: d80a70b58a
Commit: 010963032c

9 changed files with 1175 additions and 16 deletions

@@ -88,6 +88,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
:hidden:
onnx_dynamo
onnx_ops
onnx_verification
onnx_dynamo_onnxruntime_backend
onnx_torchscript

docs/source/onnx_ops.rst (new file, 11 lines)

@@ -0,0 +1,11 @@
torch.onnx.ops
==============
.. automodule:: torch.onnx.ops
Operators
---------
.. autofunction:: torch.onnx.ops.symbolic
.. autofunction:: torch.onnx.ops.symbolic_multi_out

test/onnx/ops/test_ops.py (new file, 418 lines)

@@ -0,0 +1,418 @@
# Owner(s): ["module: onnx"]
"""Test torch.onnx.ops."""
from __future__ import annotations
from onnxscript import ir
import torch
from torch.onnx.ops import _symbolic_impl
from torch.testing._internal import common_utils
class SchemaTest(common_utils.TestCase):
def test_symbolic_has_correct_schema(self):
torch.library.opcheck(
_symbolic_impl._symbolic,
([torch.tensor(1)], "CustomOp", 1),
dict(
shape=[
1,
],
attr_keys=["key"],
attr_types=["i"],
attr_pos=[(0, 1)],
attr_ints=[1],
attr_floats=[1.0],
attr_strs=["attr"],
metadata_props_keys=["meta_key"],
metadata_props_values=["meta_value"],
domain="custom_domain",
version=42,
),
)
# Empty inputs
torch.library.opcheck(
_symbolic_impl._symbolic,
([], "CustomOp", 1),
dict(
shape=[
1,
],
attr_keys=[],
attr_types=[],
attr_pos=[],
attr_ints=[],
attr_floats=[],
attr_strs=[],
metadata_props_keys=[],
metadata_props_values=[],
),
)
def test_symbolic_multi_out_has_correct_schema(self):
torch.library.opcheck(
_symbolic_impl._symbolic_multi_out,
([torch.tensor(1)], "CustomMultiOutOp", [1, 2, 10]),
dict(
shapes=[[1, 2], [42], []],
attr_keys=["key"],
attr_types=["i"],
attr_pos=[(0, 1)],
attr_ints=[1],
attr_floats=[1.0],
attr_strs=["attr"],
metadata_props_keys=["meta_key"],
metadata_props_values=["meta_value"],
domain="",
version=1,
),
)
# Empty inputs
torch.library.opcheck(
_symbolic_impl._symbolic_multi_out,
([], "CustomMultiOutOp", []),
dict(
shapes=[],
attr_keys=[],
attr_types=[],
attr_pos=[],
attr_ints=[],
attr_floats=[],
attr_strs=[],
metadata_props_keys=[],
metadata_props_values=[],
),
)
class SymbolicOpsTest(common_utils.TestCase):
def test_symbolic_accepts_valid_inputs(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtype=torch.float32,
shape=[1, 2, 3],
version=1,
metadata_props={"meta_key": "meta_value"},
)
self.assertEqual(output.shape, torch.Size([1, 2, 3]))
self.assertEqual(output.dtype, torch.float32)
self.assertEqual(output.device, torch.device("cpu"))
def test_symbolic_accepts_valid_inputs_empty_shape(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtype=torch.float32,
shape=[],
)
self.assertEqual(output.shape, torch.Size([]))
def test_symbolic_accepts_valid_inputs_integer_types(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtype=1, # 1 is float32 in ONNX
shape=[42],
)
self.assertEqual(output.dtype, torch.float32)
def test_symbolic_accepts_valid_inputs_int4_type(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtype=22, # 22 is INT4 in ONNX
shape=[42],
)
# We use torch uint8 for int4
self.assertEqual(output.dtype, torch.uint8)
def test_symbolic_is_exportable(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(x,),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtype=x.dtype,
shape=[1, 2, 3],
version=1,
metadata_props={"meta_key": "meta_value"},
)
onnx_program = torch.onnx.export(
Model(), (torch.tensor(1),), dynamo=True, verbose=False
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
attributes = node.attributes
self.assertEqual(
attributes,
dict(
int_key=ir.AttrInt64("int_key", 1),
float_key=ir.AttrFloat32("float_key", 1.0),
str_key=ir.AttrString("str_key", "attr"),
bool_key=ir.AttrInt64("bool_key", 1),
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
),
)
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
outputs = node.outputs
self.assertEqual(list(outputs[0].shape), [1, 2, 3])
self.assertEqual(outputs[0].dtype, ir.DataType.INT64)
def test_symbolic_preserves_dynamic_shapes(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(x, y),
dtype=x.dtype,
shape=[*x.shape, *y.shape],
version=1,
)
onnx_program = torch.onnx.export(
Model(),
(torch.zeros(2, 3), torch.zeros(1, 2)),
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
dynamo=True,
verbose=False,
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
inputs = onnx_program.model.graph.inputs
self.assertEqual(str(inputs[0].shape[0]), "batch")
self.assertEqual(inputs[0].shape[1], 3)
self.assertEqual(inputs[1].shape[0], 1)
self.assertEqual(str(inputs[1].shape[1]), "something_else")
outputs = node.outputs
self.assertEqual(str(outputs[0].shape[0]), "batch")
self.assertEqual(outputs[0].shape[1], 3)
self.assertEqual(outputs[0].shape[2], 1)
self.assertEqual(str(outputs[0].shape[3]), "something_else")
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
def test_symbolic_multi_out_accepts_valid_inputs(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomMultiOutOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(
1, # 1 is float32 in ONNX
torch.int32,
torch.float8_e4m3fn,
),
shapes=([1, 2], [42], []),
version=1,
metadata_props={"meta_key": "meta_value"},
)
self.assertEqual(len(outputs), 3)
self.assertEqual(outputs[0].shape, torch.Size([1, 2]))
self.assertEqual(outputs[0].dtype, torch.float32)
self.assertEqual(outputs[1].shape, torch.Size([42]))
self.assertEqual(outputs[1].dtype, torch.int32)
self.assertEqual(outputs[2].shape, torch.Size([]))
self.assertEqual(outputs[2].dtype, torch.float8_e4m3fn)
self.assertEqual(outputs[0].device, torch.device("cpu"))
self.assertEqual(outputs[1].device, torch.device("cpu"))
self.assertEqual(outputs[2].device, torch.device("cpu"))
def test_symbolic_multi_out_accepts_valid_inputs_empty_shape(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtypes=(torch.float32,),
shapes=[[]],
)
self.assertEqual(outputs[0].shape, torch.Size([]))
def test_symbolic_multi_out_accepts_valid_inputs_integer_types(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtypes=(1,), # 1 is float32 in ONNX
shapes=[[42]],
)
self.assertEqual(outputs[0].dtype, torch.float32)
def test_symbolic_multi_out_accepts_valid_inputs_int4_type(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtypes=(22,), # 22 is INT4 in ONNX
shapes=[[42]],
)
# We use torch uint8 for int4
self.assertEqual(outputs[0].dtype, torch.uint8)
def test_symbolic_multi_out_is_exportable(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(x,),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(torch.float32, torch.int32, torch.float8_e4m3fn),
shapes=([1, 2], [42], []),
version=1,
metadata_props={"meta_key": "meta_value"},
)
onnx_program = torch.onnx.export(
Model(), (torch.tensor(1),), dynamo=True, verbose=False
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
attributes = node.attributes
self.assertEqual(
attributes,
dict(
int_key=ir.AttrInt64("int_key", 1),
float_key=ir.AttrFloat32("float_key", 1.0),
str_key=ir.AttrString("str_key", "attr"),
bool_key=ir.AttrInt64("bool_key", 1),
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
),
)
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
outputs = node.outputs
self.assertEqual(list(outputs[0].shape), [1, 2])
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
self.assertEqual(list(outputs[1].shape), [42])
self.assertEqual(outputs[1].dtype, ir.DataType.INT32)
self.assertEqual(list(outputs[2].shape), [])
self.assertEqual(outputs[2].dtype, ir.DataType.FLOAT8E4M3FN)
def test_symbolic_multi_out_preserves_dynamic_shapes(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(x, y),
dtypes=(x.dtype, 22), # 22 is INT4
shapes=[[*x.shape, *y.shape], [42]],
version=1,
)
onnx_program = torch.onnx.export(
Model(),
(torch.zeros(2, 3), torch.zeros(1, 2)),
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
dynamo=True,
verbose=False,
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
inputs = onnx_program.model.graph.inputs
self.assertEqual(str(inputs[0].shape[0]), "batch")
self.assertEqual(inputs[0].shape[1], 3)
self.assertEqual(inputs[1].shape[0], 1)
self.assertEqual(str(inputs[1].shape[1]), "something_else")
outputs = node.outputs
self.assertEqual(str(outputs[0].shape[0]), "batch")
self.assertEqual(outputs[0].shape[1], 3)
self.assertEqual(outputs[0].shape[2], 1)
self.assertEqual(str(outputs[0].shape[3]), "something_else")
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
self.assertEqual(list(outputs[1].shape), [42])
self.assertEqual(outputs[1].dtype, ir.DataType.INT4)
def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self):
with self.assertRaises(RuntimeError):
torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomMultiOutOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(torch.float32, torch.int32),
shapes=([1, 2], [42], []),
version=1,
metadata_props={"meta_key": "meta_value"},
)
with self.assertRaises(RuntimeError):
torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomMultiOutOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(torch.float32,),
shapes=([1, 2], [42]),
version=1,
metadata_props={"meta_key": "meta_value"},
)
if __name__ == "__main__":
common_utils.run_tests()

@@ -4,9 +4,10 @@ from __future__ import annotations
__all__ = [
# Modules
+"errors",
+"ops",
"symbolic_helper",
"utils",
-"errors",
# All opsets
"symbolic_caffe2",
"symbolic_opset7",
@@ -52,7 +53,6 @@ from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import deprecated
import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
@@ -77,6 +77,7 @@ from .utils import (
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
errors,
ops,
symbolic_caffe2,
symbolic_helper,
symbolic_opset7,

@@ -200,27 +200,33 @@ def _set_shape_type(
| tuple[torch.Tensor],
complex_to_float: bool,
) -> None:
# TODO: Consider using meta["tensor_meta"] for this? Would it be faster?
if isinstance(meta_val, tuple):
logger.warning("Setting shape and type of tensors is not supported yet")
if isinstance(meta_val, torch.Tensor):
# FIXME: Consider shape for complex values
dims = []
for dim in meta_val.shape:
if isinstance(dim, int):
dims.append(dim)
else:
dims.append(str(dim.node))
-value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
-if complex_to_float:
-if meta_val.dtype == torch.complex64:
-value.dtype = ir.DataType.FLOAT
-# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
-dims.append(2)
-elif meta_val.dtype == torch.complex128:
-value.dtype = ir.DataType.DOUBLE
-# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
-dims.append(2)
+# If the dtype is set already (e.g. by the onnx_symbolic ops),
+# we don't need to set it again.
+#
+# When a user specifies complex in onnx_symbolic, we consider that to
+# be the intention even though none of the ONNX ops deal with complex values.
+# In this case, we don't change the dtype or the shape of the tensor.
+if value.dtype is None:
+value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
+if complex_to_float:
+if meta_val.dtype == torch.complex64:
+value.dtype = ir.DataType.FLOAT
+# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
+dims.append(2)
+elif meta_val.dtype == torch.complex128:
+value.dtype = ir.DataType.DOUBLE
+# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
+dims.append(2)
value.shape = ir.Shape(dims)
elif isinstance(meta_val, (int, torch.SymInt)):

@@ -1,6 +1,6 @@
from __future__ import annotations
__all__ = ["core", "hop"]
__all__ = ["core", "hop", "symbolic"]
from torch.onnx._internal.exporter._torchlib.ops import core, hop
from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic

@@ -0,0 +1,149 @@
"""Implementation for higher-order operators."""
from __future__ import annotations
from typing import TYPE_CHECKING
from onnxscript.ir import convenience as ir_convenience
import torch
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
from torch.onnx._internal.exporter import _core
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
from torch.onnx.ops import _symbolic_impl
if TYPE_CHECKING:
from collections.abc import Sequence
def _call_symbolic_op(
op_type: str,
domain: str,
args: Sequence[ir.Value | None],
kwargs: dict[str, int | float | str | bool | list[int] | list[float] | list[str]],
dtypes: Sequence[int],
version: int | None,
metadata_props: dict[str, str] | None,
) -> Sequence[ir.Value]:
"""Call an operator with the given arguments and keyword arguments.
Arguments are always inputs, while keyword arguments are attributes.
"""
# This is a wrapper around the IR node creation that hooks into the _builder.OpRecorder
# tracer so that all nodes created are recorded the same way as if we were to use
# onnxscript ops directly.
assert _core.current_tracer is not None
tracer = _core.current_tracer
inputs = list(args)
# Strip trailing None entries from the node inputs
for input in reversed(inputs):
if input is not None:
break
inputs.pop()
# Construct and filter out None attributes
attributes = [
attr
for attr in ir_convenience.convert_attributes(kwargs) # type: ignore[arg-type]
if attr.value is not None # type: ignore[union-attr]
]
tracer.nodes.append(
node := ir.Node(
domain,
op_type,
inputs=inputs,
attributes=attributes,
num_outputs=len(dtypes),
version=version,
metadata_props=metadata_props,
)
)
# Set the dtypes for the outputs. We set them here because the graph builder
# uses PyTorch types, which are inaccurate for ONNX-only types
# like float4e2m1.
for value, dtype in zip(node.outputs, dtypes):
value.dtype = ir.DataType(dtype)
# The shape is set by the graph builder. We don't need to set it here.
return node.outputs
@onnx_impl(torch.ops.onnx_symbolic._symbolic.default, no_compile=True)
def onnx_symbolic_symbolic(
inputs: Sequence[ir.Value | None],
op_type: str,
onnx_dtype: int,
*,
shape: Sequence[int | ir.Value],
attr_keys: Sequence[str],
attr_types: Sequence[str],
attr_pos: Sequence[tuple[int, int]],
attr_ints: Sequence[int],
attr_floats: Sequence[float],
attr_strs: Sequence[str],
metadata_props_keys: Sequence[str] = (),
metadata_props_values: Sequence[str] = (),
domain: str = "",
version: int | None = None,
) -> ir.Value:
del shape # Unused. The shapes are set by the graph builder
encoded = _symbolic_impl.EncodedAttrs(
attr_keys=list(attr_keys),
attr_types=list(attr_types),
attr_pos=list(attr_pos),
attr_ints=list(attr_ints),
attr_floats=list(attr_floats),
attr_strs=list(attr_strs),
)
attrs = encoded.to_dict()
return _call_symbolic_op(
op_type,
domain,
inputs,
attrs,
dtypes=[onnx_dtype],
version=version,
metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
)[0]
@onnx_impl(torch.ops.onnx_symbolic._symbolic_multi_out.default, no_compile=True)
def onnx_symbolic_symbolic_multi_out(
inputs: Sequence[ir.Value | None],
op_type: str,
onnx_dtypes: Sequence[int],
*,
shapes: Sequence[Sequence[int | ir.Value]],
attr_keys: Sequence[str],
attr_types: Sequence[str],
attr_pos: Sequence[tuple[int, int]],
attr_ints: Sequence[int],
attr_floats: Sequence[float],
attr_strs: Sequence[str],
metadata_props_keys: Sequence[str] = (),
metadata_props_values: Sequence[str] = (),
domain: str = "",
version: int | None = None,
) -> Sequence[ir.Value]:
del shapes # Unused. The shapes are set by the graph builder
encoded = _symbolic_impl.EncodedAttrs(
attr_keys=list(attr_keys),
attr_types=list(attr_types),
attr_pos=list(attr_pos),
attr_ints=list(attr_ints),
attr_floats=list(attr_floats),
attr_strs=list(attr_strs),
)
attrs = encoded.to_dict()
return _call_symbolic_op(
op_type,
domain,
inputs,
attrs,
dtypes=onnx_dtypes,
version=version,
metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
)

torch/onnx/ops/__init__.py (new file, 243 lines)

@@ -0,0 +1,243 @@
"""ONNX operators as native torch.fx operators.
This module provides a set of functions to create ONNX operators in the FX graph
which are exportable to ONNX.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch.onnx.ops import _symbolic_impl
if TYPE_CHECKING:
from collections.abc import Sequence
# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
_TORCH_DTYPE_TO_ONNX_DTYPE = {
torch.float32: 1, # FLOAT
torch.uint8: 2, # UINT8
torch.int8: 3, # INT8
torch.uint16: 4, # UINT16
torch.int16: 5, # INT16
torch.int32: 6, # INT32
torch.int64: 7, # INT64
str: 8, # STRING
torch.bool: 9, # BOOL
torch.float16: 10, # FLOAT16
torch.double: 11, # DOUBLE
torch.uint32: 12, # UINT32
torch.uint64: 13, # UINT64
torch.complex64: 14, # COMPLEX64
torch.complex128: 15, # COMPLEX128
torch.bfloat16: 16, # BFLOAT16
torch.float8_e4m3fn: 17, # FLOAT8E4M3FN
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
torch.float8_e5m2: 19, # FLOAT8E5M2
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
}
def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
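# For example: "custom_domain::CustomOp" -> ("custom_domain", "CustomOp"),
# and a bare "CustomOp" -> ("", "CustomOp") (the default ONNX domain).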
splitted = domain_op.split("::", 1)
if len(splitted) == 1:
domain = ""
op_type = splitted[0]
else:
domain = splitted[0]
op_type = splitted[1]
return domain, op_type
def symbolic(
domain_op: str,
/,
inputs: Sequence[torch.Tensor],
attrs: dict[
str,
int
| float
| str
| bool
| Sequence[int]
| Sequence[float]
| Sequence[str]
| Sequence[bool],
]
| None = None,
*,
dtype: torch.dtype | int,
shape: Sequence[int | torch.SymInt],
version: int | None = None,
metadata_props: dict[str, str] | None = None,
) -> torch.Tensor:
"""Create a symbolic FX operator to represent an arbitrary ONNX operator.
This function is used to create a symbolic operator with a single output.
To create an operator with multiple outputs, use :func:`symbolic_multi_out`.
Example::
class CustomOp(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(x,),
dict(attr_key="attr_value"),
dtype=x.dtype,
shape=x.shape,
version=1,
)
# This will create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
# The output tensor will have the specified dtype and shape.
# You may then export this model to ONNX using torch.onnx.export.
Args:
domain_op: The domain and operator name, separated by "::". For example,
"custom_domain::CustomOp".
inputs: The input tensors to the operator.
attrs: The attributes of the operator. The keys are attribute names and
the values are attribute values. Valid attribute types are int, float,
str, bool, and lists of int, float, str, and bool. Tensor attributes
are unsupported.
dtype: The data type of the output tensor. This can be either a torch.dtype
or an integer representing the ONNX data type.
shape: The shape of the output tensor. This can be a list of integers or
SymInt values.
version: The version of the opset used for the operator.
metadata_props: Metadata properties for the ONNX node.
This is a dictionary of str-str pairs.
Returns:
The output tensor of the operator.
"""
if not isinstance(dtype, int):
torch._check(
dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
)
dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
domain, op_type = _parse_domain_op_type(domain_op)
if attrs is None:
attrs = {}
encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
# TODO: Parse domain
return _symbolic_impl._symbolic(
inputs,
op_type,
dtype,
shape=shape,
attr_keys=encoded_attrs.attr_keys,
attr_types=encoded_attrs.attr_types,
attr_pos=encoded_attrs.attr_pos,
attr_ints=encoded_attrs.attr_ints,
attr_floats=encoded_attrs.attr_floats,
attr_strs=encoded_attrs.attr_strs,
metadata_props_keys=metadata_props.keys() if metadata_props else [],
metadata_props_values=metadata_props.values() if metadata_props else [],
domain=domain,
version=version,
)
def symbolic_multi_out(
domain_op: str,
/,
inputs: Sequence[torch.Tensor],
attrs: dict[
str,
int
| float
| str
| bool
| Sequence[int]
| Sequence[float]
| Sequence[str]
| Sequence[bool],
]
| None = None,
*,
dtypes: Sequence[torch.dtype | int],
shapes: Sequence[Sequence[int | torch.SymInt]],
version: int | None = None,
metadata_props: dict[str, str] | None = None,
) -> Sequence[torch.Tensor]:
"""Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.
Example::
class CustomOp(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(x,),
dict(attr_key="attr_value"),
dtypes=(x.dtype, torch.float32),
shapes=(x.shape, [1, 2, 3]),
version=1,
)
# This will create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
# The output tensors will have the specified dtypes and shapes.
# You may then export this model to ONNX using torch.onnx.export.
Args:
domain_op: The domain and operator name, separated by "::". For example,
"custom_domain::CustomOp".
inputs: The input tensors to the operator.
attrs: The attributes of the operator. The keys are attribute names and
the values are attribute values. Valid attribute types are int, float,
str, bool, and lists of int, float, str, and bool. Tensor attributes
are unsupported.
dtypes: The data types of the output tensors. This can be a list of
torch.dtype or integers representing the ONNX data types. The length
of this list must be the number of outputs.
shapes: The shapes of the output tensors. This can be a list of lists of
integers or SymInt values. The length of this list must be the number of outputs.
version: The version of the opset used for the operator.
metadata_props: Metadata properties for the ONNX node.
This is a dictionary of str-str pairs.
Returns:
A list of output tensors of the operator.
"""
torch._check(
len(shapes) == len(dtypes),
lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
)
onnx_dtypes = []
for dtype in dtypes:
if not isinstance(dtype, int):
torch._check(
dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
lambda: f"Unsupported dtype: {dtype}",
)
onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])
else:
onnx_dtypes.append(dtype)
domain, op_type = _parse_domain_op_type(domain_op)
if attrs is None:
attrs = {}
encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
# Use the size of dtypes to determine the number of outputs
return _symbolic_impl._symbolic_multi_out(
inputs,
op_type,
onnx_dtypes,
shapes=shapes,
attr_keys=encoded_attrs.attr_keys,
attr_types=encoded_attrs.attr_types,
attr_pos=encoded_attrs.attr_pos,
attr_ints=encoded_attrs.attr_ints,
attr_floats=encoded_attrs.attr_floats,
attr_strs=encoded_attrs.attr_strs,
metadata_props_keys=metadata_props.keys() if metadata_props else [],
metadata_props_values=metadata_props.values() if metadata_props else [],
domain=domain,
version=version,
)

@@ -0,0 +1,330 @@
"""Implementation of symbolic FX ops to represent arbitrary ONNX ops.
This module provides a way to create symbolic FX operators that can represent
arbitrary ONNX operators.
The operators are called "symbolic" because they don't do any actual computation
but instead serve as placeholders in the computation graph.
Each implementation contains two parts: a "real" implementation that produces all
zeros based on the given shape and dtype, and a "fake" implementation that does more
or less the same thing but is required by the `torch.library.custom_op` interface.
"""
import dataclasses
from collections.abc import Sequence
from typing import Optional, Union
import torch
_ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
1: torch.float32, # FLOAT
2: torch.uint8, # UINT8
3: torch.int8, # INT8
4: torch.uint16, # UINT16
5: torch.int16, # INT16
6: torch.int32, # INT32
7: torch.int64, # INT64
9: torch.bool, # BOOL
10: torch.float16, # FLOAT16
11: torch.double, # DOUBLE
12: torch.uint32, # UINT32
13: torch.uint64, # UINT64
14: torch.complex64, # COMPLEX64
15: torch.complex128, # COMPLEX128
16: torch.bfloat16, # BFLOAT16
17: torch.float8_e4m3fn, # FLOAT8E4M3FN
18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ
19: torch.float8_e5m2, # FLOAT8E5M2
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
21: torch.uint8, # UINT4
22: torch.uint8, # INT4
23: torch.uint8, # FLOAT4E2M1
}
_INT_TYPE = "i"
_FLOAT_TYPE = "f"
_STRING_TYPE = "s"
_INT_SEQ_TYPE = "is"
_FLOAT_SEQ_TYPE = "fs"
_STRING_SEQ_TYPE = "ss"
@dataclasses.dataclass
class EncodedAttrs:
"""Class to encode attributes from dictionary into lists of FX compatible attributes.
Since FX does not support dictionaries, we need to encode the attributes into
lists. This class provides a way to encode and decode the attributes.
Attributes:
attr_keys: List of attribute keys.
attr_types: List of attribute types. Values can be "i" (int), "f" (float),
"s" (string), "is" (int sequence), "fs" (float sequence), or "ss" (string sequence).
attr_pos: List of tuples representing the start and end positions of each
attribute in the corresponding list.
attr_ints: List of integer attributes.
attr_floats: List of float attributes.
attr_strs: List of string attributes.
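
Example (illustrative encoding)::

    {"alpha": 1.0, "axes": [0, 1]} becomes
    attr_keys=["alpha", "axes"], attr_types=["f", "is"],
    attr_pos=[(0, 1), (0, 2)], attr_ints=[0, 1],
    attr_floats=[1.0], attr_strs=[]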
"""
attr_keys: list[str]
attr_types: list[str]
attr_pos: list[tuple[int, int]]
attr_ints: list[int]
attr_floats: list[float]
attr_strs: list[str]
@classmethod
def from_dict(
cls,
attrs: dict[
str,
Union[
int,
float,
str,
bool,
Sequence[int],
Sequence[float],
Sequence[str],
Sequence[bool],
],
],
) -> "EncodedAttrs":
encoded = cls(
attr_keys=[],
attr_types=[],
attr_pos=[],
attr_ints=[],
attr_floats=[],
attr_strs=[],
)
for i, (k, v) in enumerate(attrs.items()):
encoded.attr_keys.append(k)
if isinstance(v, int):
start_pos = len(encoded.attr_ints)
encoded.attr_ints.append(v)
encoded.attr_pos.append((start_pos, start_pos + 1))
encoded.attr_types.append(_INT_TYPE)
elif isinstance(v, float):
start_pos = len(encoded.attr_floats)
encoded.attr_floats.append(v)
encoded.attr_pos.append((start_pos, start_pos + 1))
encoded.attr_types.append(_FLOAT_TYPE)
elif isinstance(v, str):
start_pos = len(encoded.attr_strs)
encoded.attr_strs.append(v)
encoded.attr_pos.append((start_pos, start_pos + 1))
encoded.attr_types.append(_STRING_TYPE)
elif isinstance(v, Sequence):
if len(v) == 0:
raise ValueError(f"Empty sequence for attribute {k}")
if any(isinstance(elem, float) for elem in v):
start_pos = len(encoded.attr_floats)
encoded.attr_floats.extend([float(elem) for elem in v])
encoded.attr_pos.append((start_pos, start_pos + len(v)))
encoded.attr_types.append(_FLOAT_SEQ_TYPE)
elif isinstance(v[0], int):
start_pos = len(encoded.attr_ints)
encoded.attr_ints.extend([int(elem) for elem in v])
encoded.attr_pos.append((start_pos, start_pos + len(v)))
encoded.attr_types.append(_INT_SEQ_TYPE)
elif isinstance(v[0], str):
start_pos = len(encoded.attr_strs)
encoded.attr_strs.extend([str(elem) for elem in v])
encoded.attr_pos.append((start_pos, start_pos + len(v)))
encoded.attr_types.append(_STRING_SEQ_TYPE)
else:
raise ValueError(f"Unsupported sequence type for attribute {k}")
else:
raise ValueError(f"Unsupported attribute type for {k}: {type(v)}")
assert len(encoded.attr_keys) == len(encoded.attr_types), (
f"Mismatch between number of attribute keys and types: {len(encoded.attr_keys)} != {len(encoded.attr_types)}"
)
assert len(encoded.attr_keys) == len(encoded.attr_pos), (
f"Mismatch between number of attribute keys and positions: {len(encoded.attr_keys)} != {len(encoded.attr_pos)}"
)
return encoded
def to_dict(
self,
) -> dict[
str,
Union[
int,
float,
str,
list[int],
list[float],
list[str],
],
]:
"""Convert the encoded attributes back to a dictionary for creating an ONNX node."""
attrs: dict[
str,
Union[
int,
float,
str,
list[int],
list[float],
list[str],
],
] = {}
for i, key in enumerate(self.attr_keys):
attr_type = self.attr_types[i]
if attr_type == _INT_TYPE:
attrs[key] = self.attr_ints[self.attr_pos[i][0]]
elif attr_type == _FLOAT_TYPE:
attrs[key] = self.attr_floats[self.attr_pos[i][0]]
elif attr_type == _STRING_TYPE:
attrs[key] = self.attr_strs[self.attr_pos[i][0]]
elif attr_type == _FLOAT_SEQ_TYPE:
attrs[key] = self.attr_floats[self.attr_pos[i][0] : self.attr_pos[i][1]]
elif attr_type == _INT_SEQ_TYPE:
attrs[key] = self.attr_ints[self.attr_pos[i][0] : self.attr_pos[i][1]]
elif attr_type == _STRING_SEQ_TYPE:
attrs[key] = self.attr_strs[self.attr_pos[i][0] : self.attr_pos[i][1]]
else:
raise ValueError(f"Unsupported attribute type: {attr_type}")
return attrs
@torch.library.custom_op(
"onnx_symbolic::_symbolic",
mutates_args=(),
schema=(
"(Tensor?[] inputs, str op_type, int onnx_dtype, *,"
" SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
" int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
" str[] metadata_props_values, str domain='', int? version=None"
") -> Tensor"
),
)
def _symbolic(
inputs: Sequence[Optional[torch.Tensor]],
op_type: str,
onnx_dtype: int,
*,
shape: Sequence[Union[int, torch.SymInt]],
attr_keys: Sequence[str],
attr_types: Sequence[str],
attr_pos: Sequence[tuple[int, int]],
attr_ints: Sequence[int],
attr_floats: Sequence[float],
attr_strs: Sequence[str],
metadata_props_keys: Sequence[str] = (),
metadata_props_values: Sequence[str] = (),
domain: str = "",
version: Optional[int] = None,
) -> torch.Tensor:
torch._check(
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
@_symbolic.register_fake
def _(
inputs: Sequence[torch.Tensor],
op_type: str,
onnx_dtype: int,
*,
shape: Sequence[Union[int, torch.SymInt]],
attr_keys: Sequence[str],
attr_types: Sequence[str],
attr_pos: Sequence[tuple[int, int]],
attr_ints: Sequence[int],
attr_floats: Sequence[float],
attr_strs: Sequence[str],
metadata_props_keys: Sequence[str] = (),
metadata_props_values: Sequence[str] = (),
domain: str = "",
version: Optional[int] = None,
) -> torch.Tensor:
torch._check(
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
# out how it can handle empty shapes
return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
@torch.library.custom_op(
"onnx_symbolic::_symbolic_multi_out",
mutates_args=(),
schema=(
"(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *,"
" SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
" int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
" str[] metadata_props_values, str domain='', int? version=None"
") -> Tensor[]"
),
)
def _symbolic_multi_out(
inputs: Sequence[Optional[torch.Tensor]],
op_type: str,
onnx_dtypes: Sequence[int],
*,
shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
attr_keys: Sequence[str],
attr_types: Sequence[str],
attr_pos: Sequence[tuple[int, int]],
attr_ints: Sequence[int],
attr_floats: Sequence[float],
attr_strs: Sequence[str],
metadata_props_keys: Sequence[str] = (),
metadata_props_values: Sequence[str] = (),
domain: str = "",
version: Optional[int] = None,
) -> list[torch.Tensor]:
outputs = []
torch._check(
len(shapes) == len(onnx_dtypes),
lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
)
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
torch._check(
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
return outputs
@_symbolic_multi_out.register_fake
def _(
inputs: Sequence[torch.Tensor],
op_type: str,
onnx_dtypes: Sequence[int],
*,
shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
attr_keys: Sequence[str],
attr_types: Sequence[str],
attr_pos: Sequence[tuple[int, int]],
attr_ints: Sequence[int],
attr_floats: Sequence[float],
attr_strs: Sequence[str],
metadata_props_keys: Sequence[str] = (),
metadata_props_values: Sequence[str] = (),
domain: str = "",
version: Optional[int] = None,
) -> list[torch.Tensor]:
outputs = []
torch._check(
len(shapes) == len(onnx_dtypes),
lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
)
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
torch._check(
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
)
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
# out how it can handle empty shapes
outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
return outputs