mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Create onnx_symbolic (#148905)
In the old exporter we allow users to define a symbolic() method to bypass JIT tracing for a block of logic. We can allow users to do similar things by creating symbolic ops at export. This PR implements `torch.onnx.ops.symbolic` and `torch.onnx.ops.symbolic_multi_out` to allow users to create onnx nodes symbolically with pt2 & fx. The custom pytorch ops were designed such that the attributes are encoded to be part of a valid fx op. Users provide shape and dtype for the meta function to produce the currect fake tensor during export. An example is  Pull Request resolved: https://github.com/pytorch/pytorch/pull/148905 Approved by: https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
d80a70b58a
commit
010963032c
@ -88,6 +88,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
|
||||
:hidden:
|
||||
|
||||
onnx_dynamo
|
||||
onnx_ops
|
||||
onnx_verification
|
||||
onnx_dynamo_onnxruntime_backend
|
||||
onnx_torchscript
|
||||
|
11
docs/source/onnx_ops.rst
Normal file
11
docs/source/onnx_ops.rst
Normal file
@ -0,0 +1,11 @@
|
||||
torch.onnx.ops
|
||||
==============
|
||||
|
||||
.. automodule:: torch.onnx.ops
|
||||
|
||||
Operators
|
||||
---------
|
||||
|
||||
.. autofunction:: torch.onnx.ops.symbolic
|
||||
|
||||
.. autofunction:: torch.onnx.ops.symbolic_multi_out
|
418
test/onnx/ops/test_ops.py
Normal file
418
test/onnx/ops/test_ops.py
Normal file
@ -0,0 +1,418 @@
|
||||
# Owner(s): ["module: onnx"]
|
||||
"""Test torch.onnx.ops."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from onnxscript import ir
|
||||
|
||||
import torch
|
||||
from torch.onnx.ops import _symbolic_impl
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
class SchemaTest(common_utils.TestCase):
|
||||
def test_symbolic_has_correct_schema(self):
|
||||
torch.library.opcheck(
|
||||
_symbolic_impl._symbolic,
|
||||
([torch.tensor(1)], "CustomOp", 1),
|
||||
dict(
|
||||
shape=[
|
||||
1,
|
||||
],
|
||||
attr_keys=["key"],
|
||||
attr_types=["i"],
|
||||
attr_pos=[(0, 1)],
|
||||
attr_ints=[1],
|
||||
attr_floats=[1.0],
|
||||
attr_strs=["attr"],
|
||||
metadata_props_keys=["meta_key"],
|
||||
metadata_props_values=["meta_value"],
|
||||
domain="custom_domain",
|
||||
version=42,
|
||||
),
|
||||
)
|
||||
|
||||
# Empty inputs
|
||||
torch.library.opcheck(
|
||||
_symbolic_impl._symbolic,
|
||||
([], "CustomOp", 1),
|
||||
dict(
|
||||
shape=[
|
||||
1,
|
||||
],
|
||||
attr_keys=[],
|
||||
attr_types=[],
|
||||
attr_pos=[],
|
||||
attr_ints=[],
|
||||
attr_floats=[],
|
||||
attr_strs=[],
|
||||
metadata_props_keys=[],
|
||||
metadata_props_values=[],
|
||||
),
|
||||
)
|
||||
|
||||
def test_symbolic_multi_out_has_correct_schema(self):
|
||||
torch.library.opcheck(
|
||||
_symbolic_impl._symbolic_multi_out,
|
||||
([torch.tensor(1)], "CustomMultiOutOp", [1, 2, 10]),
|
||||
dict(
|
||||
shapes=[[1, 2], [42], []],
|
||||
attr_keys=["key"],
|
||||
attr_types=["i"],
|
||||
attr_pos=[(0, 1)],
|
||||
attr_ints=[1],
|
||||
attr_floats=[1.0],
|
||||
attr_strs=["attr"],
|
||||
metadata_props_keys=["meta_key"],
|
||||
metadata_props_values=["meta_value"],
|
||||
domain="",
|
||||
version=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Empty inputs
|
||||
torch.library.opcheck(
|
||||
_symbolic_impl._symbolic_multi_out,
|
||||
([], "CustomMultiOutOp", []),
|
||||
dict(
|
||||
shapes=[],
|
||||
attr_keys=[],
|
||||
attr_types=[],
|
||||
attr_pos=[],
|
||||
attr_ints=[],
|
||||
attr_floats=[],
|
||||
attr_strs=[],
|
||||
metadata_props_keys=[],
|
||||
metadata_props_values=[],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SymbolicOpsTest(common_utils.TestCase):
|
||||
def test_symbolic_accepts_valid_inputs(self):
|
||||
output = torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
str_key="attr",
|
||||
bool_key=True,
|
||||
list_int_key=[1, 2],
|
||||
list_float_key=[1.0, 2.0],
|
||||
list_str_key=["attr1", "attr2"],
|
||||
list_bool_key=[True, False],
|
||||
),
|
||||
dtype=torch.float32,
|
||||
shape=[1, 2, 3],
|
||||
version=1,
|
||||
metadata_props={"meta_key": "meta_value"},
|
||||
)
|
||||
self.assertEqual(output.shape, torch.Size([1, 2, 3]))
|
||||
self.assertEqual(output.dtype, torch.float32)
|
||||
self.assertEqual(output.device, torch.device("cpu"))
|
||||
|
||||
def test_symbolic_accepts_valid_inputs_empty_shape(self):
|
||||
output = torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dtype=torch.float32,
|
||||
shape=[],
|
||||
)
|
||||
self.assertEqual(output.shape, torch.Size([]))
|
||||
|
||||
def test_symbolic_accepts_valid_inputs_integer_types(self):
|
||||
output = torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dtype=1, # 1 is float32 in ONNX
|
||||
shape=[42],
|
||||
)
|
||||
self.assertEqual(output.dtype, torch.float32)
|
||||
|
||||
def test_symbolic_accepts_valid_inputs_int4_type(self):
|
||||
output = torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dtype=22, # 22 is INT4 in ONNX
|
||||
shape=[42],
|
||||
)
|
||||
# We use torch uint8 for int4
|
||||
self.assertEqual(output.dtype, torch.uint8)
|
||||
|
||||
def test_symbolic_is_exportable(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(x,),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
str_key="attr",
|
||||
bool_key=True,
|
||||
list_int_key=[1, 2],
|
||||
list_float_key=[1.0, 2.0],
|
||||
list_str_key=["attr1", "attr2"],
|
||||
list_bool_key=[True, False],
|
||||
),
|
||||
dtype=x.dtype,
|
||||
shape=[1, 2, 3],
|
||||
version=1,
|
||||
metadata_props={"meta_key": "meta_value"},
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(), (torch.tensor(1),), dynamo=True, verbose=False
|
||||
)
|
||||
assert onnx_program is not None
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "CustomOp")
|
||||
self.assertEqual(node.domain, "custom_domain")
|
||||
attributes = node.attributes
|
||||
self.assertEqual(
|
||||
attributes,
|
||||
dict(
|
||||
int_key=ir.AttrInt64("int_key", 1),
|
||||
float_key=ir.AttrFloat32("float_key", 1.0),
|
||||
str_key=ir.AttrString("str_key", "attr"),
|
||||
bool_key=ir.AttrInt64("bool_key", 1),
|
||||
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
|
||||
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
|
||||
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
|
||||
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
|
||||
),
|
||||
)
|
||||
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
|
||||
outputs = node.outputs
|
||||
self.assertEqual(list(outputs[0].shape), [1, 2, 3])
|
||||
self.assertEqual(outputs[0].dtype, ir.DataType.INT64)
|
||||
|
||||
def test_symbolic_preserves_dynamic_shapes(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(x, y),
|
||||
dtype=x.dtype,
|
||||
shape=[*x.shape, *y.shape],
|
||||
version=1,
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(),
|
||||
(torch.zeros(2, 3), torch.zeros(1, 2)),
|
||||
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
|
||||
dynamo=True,
|
||||
verbose=False,
|
||||
)
|
||||
assert onnx_program is not None
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "CustomOp")
|
||||
self.assertEqual(node.domain, "custom_domain")
|
||||
inputs = onnx_program.model.graph.inputs
|
||||
self.assertEqual(str(inputs[0].shape[0]), "batch")
|
||||
self.assertEqual(inputs[0].shape[1], 3)
|
||||
self.assertEqual(inputs[1].shape[0], 1)
|
||||
self.assertEqual(str(inputs[1].shape[1]), "something_else")
|
||||
outputs = node.outputs
|
||||
self.assertEqual(str(outputs[0].shape[0]), "batch")
|
||||
self.assertEqual(outputs[0].shape[1], 3)
|
||||
self.assertEqual(outputs[0].shape[2], 1)
|
||||
self.assertEqual(str(outputs[0].shape[3]), "something_else")
|
||||
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
|
||||
|
||||
def test_symbolic_multi_out_accepts_valid_inputs(self):
|
||||
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomMultiOutOp",
|
||||
(torch.tensor(1),),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
str_key="attr",
|
||||
bool_key=True,
|
||||
list_int_key=[1, 2],
|
||||
list_float_key=[1.0, 2.0],
|
||||
list_str_key=["attr1", "attr2"],
|
||||
list_bool_key=[True, False],
|
||||
),
|
||||
dtypes=(
|
||||
1, # 1 is float32 in ONNX
|
||||
torch.int32,
|
||||
torch.float8_e4m3fn,
|
||||
),
|
||||
shapes=([1, 2], [42], []),
|
||||
version=1,
|
||||
metadata_props={"meta_key": "meta_value"},
|
||||
)
|
||||
self.assertEqual(len(outputs), 3)
|
||||
self.assertEqual(outputs[0].shape, torch.Size([1, 2]))
|
||||
self.assertEqual(outputs[0].dtype, torch.float32)
|
||||
self.assertEqual(outputs[1].shape, torch.Size([42]))
|
||||
self.assertEqual(outputs[1].dtype, torch.int32)
|
||||
self.assertEqual(outputs[2].shape, torch.Size([]))
|
||||
self.assertEqual(outputs[2].dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(outputs[0].device, torch.device("cpu"))
|
||||
self.assertEqual(outputs[1].device, torch.device("cpu"))
|
||||
self.assertEqual(outputs[2].device, torch.device("cpu"))
|
||||
|
||||
def test_symbolic_multi_out_accepts_valid_inputs_empty_shape(self):
|
||||
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dtypes=(torch.float32,),
|
||||
shapes=[[]],
|
||||
)
|
||||
self.assertEqual(outputs[0].shape, torch.Size([]))
|
||||
|
||||
def test_symbolic_multi_out_accepts_valid_inputs_integer_types(self):
|
||||
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dtypes=(1,), # 1 is float32 in ONNX
|
||||
shapes=[[42]],
|
||||
)
|
||||
self.assertEqual(outputs[0].dtype, torch.float32)
|
||||
|
||||
def test_symbolic_multi_out_accepts_valid_inputs_int4_type(self):
|
||||
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomOp",
|
||||
(torch.tensor(1),),
|
||||
dtypes=(22,), # 22 is INT4 in ONNX
|
||||
shapes=[[42]],
|
||||
)
|
||||
# We use torch uint8 for int4
|
||||
self.assertEqual(outputs[0].dtype, torch.uint8)
|
||||
|
||||
def test_symbolic_multi_out_is_exportable(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomOp",
|
||||
(x,),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
str_key="attr",
|
||||
bool_key=True,
|
||||
list_int_key=[1, 2],
|
||||
list_float_key=[1.0, 2.0],
|
||||
list_str_key=["attr1", "attr2"],
|
||||
list_bool_key=[True, False],
|
||||
),
|
||||
dtypes=(torch.float32, torch.int32, torch.float8_e4m3fn),
|
||||
shapes=([1, 2], [42], []),
|
||||
version=1,
|
||||
metadata_props={"meta_key": "meta_value"},
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(), (torch.tensor(1),), dynamo=True, verbose=False
|
||||
)
|
||||
assert onnx_program is not None
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "CustomOp")
|
||||
self.assertEqual(node.domain, "custom_domain")
|
||||
attributes = node.attributes
|
||||
self.assertEqual(
|
||||
attributes,
|
||||
dict(
|
||||
int_key=ir.AttrInt64("int_key", 1),
|
||||
float_key=ir.AttrFloat32("float_key", 1.0),
|
||||
str_key=ir.AttrString("str_key", "attr"),
|
||||
bool_key=ir.AttrInt64("bool_key", 1),
|
||||
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
|
||||
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
|
||||
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
|
||||
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
|
||||
),
|
||||
)
|
||||
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
|
||||
outputs = node.outputs
|
||||
self.assertEqual(list(outputs[0].shape), [1, 2])
|
||||
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
|
||||
self.assertEqual(list(outputs[1].shape), [42])
|
||||
self.assertEqual(outputs[1].dtype, ir.DataType.INT32)
|
||||
self.assertEqual(list(outputs[2].shape), [])
|
||||
self.assertEqual(outputs[2].dtype, ir.DataType.FLOAT8E4M3FN)
|
||||
|
||||
def test_symbolic_multi_out_preserves_dynamic_shapes(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomOp",
|
||||
(x, y),
|
||||
dtypes=(x.dtype, 22), # 22 is INT4
|
||||
shapes=[[*x.shape, *y.shape], [42]],
|
||||
version=1,
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(),
|
||||
(torch.zeros(2, 3), torch.zeros(1, 2)),
|
||||
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
|
||||
dynamo=True,
|
||||
verbose=False,
|
||||
)
|
||||
assert onnx_program is not None
|
||||
node = onnx_program.model.graph.node(0)
|
||||
self.assertEqual(node.op_type, "CustomOp")
|
||||
self.assertEqual(node.domain, "custom_domain")
|
||||
inputs = onnx_program.model.graph.inputs
|
||||
self.assertEqual(str(inputs[0].shape[0]), "batch")
|
||||
self.assertEqual(inputs[0].shape[1], 3)
|
||||
self.assertEqual(inputs[1].shape[0], 1)
|
||||
self.assertEqual(str(inputs[1].shape[1]), "something_else")
|
||||
outputs = node.outputs
|
||||
self.assertEqual(str(outputs[0].shape[0]), "batch")
|
||||
self.assertEqual(outputs[0].shape[1], 3)
|
||||
self.assertEqual(outputs[0].shape[2], 1)
|
||||
self.assertEqual(str(outputs[0].shape[3]), "something_else")
|
||||
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
|
||||
self.assertEqual(list(outputs[1].shape), [42])
|
||||
self.assertEqual(outputs[1].dtype, ir.DataType.INT4)
|
||||
|
||||
def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomMultiOutOp",
|
||||
(torch.tensor(1),),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
str_key="attr",
|
||||
bool_key=True,
|
||||
list_int_key=[1, 2],
|
||||
list_float_key=[1.0, 2.0],
|
||||
list_str_key=["attr1", "attr2"],
|
||||
list_bool_key=[True, False],
|
||||
),
|
||||
dtypes=(torch.float32, torch.int32),
|
||||
shapes=([1, 2], [42], []),
|
||||
version=1,
|
||||
metadata_props={"meta_key": "meta_value"},
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomMultiOutOp",
|
||||
(torch.tensor(1),),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
str_key="attr",
|
||||
bool_key=True,
|
||||
list_int_key=[1, 2],
|
||||
list_float_key=[1.0, 2.0],
|
||||
list_str_key=["attr1", "attr2"],
|
||||
list_bool_key=[True, False],
|
||||
),
|
||||
dtypes=(torch.float32,),
|
||||
shapes=([1, 2], [42]),
|
||||
version=1,
|
||||
metadata_props={"meta_key": "meta_value"},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
# Modules
|
||||
"errors",
|
||||
"ops",
|
||||
"symbolic_helper",
|
||||
"utils",
|
||||
"errors",
|
||||
# All opsets
|
||||
"symbolic_caffe2",
|
||||
"symbolic_opset7",
|
||||
@ -52,7 +53,6 @@ from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||
|
||||
@ -77,6 +77,7 @@ from .utils import (
|
||||
|
||||
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
||||
errors,
|
||||
ops,
|
||||
symbolic_caffe2,
|
||||
symbolic_helper,
|
||||
symbolic_opset7,
|
||||
|
@ -200,27 +200,33 @@ def _set_shape_type(
|
||||
| tuple[torch.Tensor],
|
||||
complex_to_float: bool,
|
||||
) -> None:
|
||||
# TODO: Consider using meta["tensor_meta"] for this? Would it be faster?
|
||||
if isinstance(meta_val, tuple):
|
||||
logger.warning("Setting shape and type of tensors is not supported yet")
|
||||
if isinstance(meta_val, torch.Tensor):
|
||||
# FIXME: Consider shape for complex values
|
||||
dims = []
|
||||
for dim in meta_val.shape:
|
||||
if isinstance(dim, int):
|
||||
dims.append(dim)
|
||||
else:
|
||||
dims.append(str(dim.node))
|
||||
value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
|
||||
if complex_to_float:
|
||||
if meta_val.dtype == torch.complex64:
|
||||
value.dtype = ir.DataType.FLOAT
|
||||
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
||||
dims.append(2)
|
||||
elif meta_val.dtype == torch.complex128:
|
||||
value.dtype = ir.DataType.DOUBLE
|
||||
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
||||
dims.append(2)
|
||||
|
||||
# If the dtype is set already (e.g. by the onnx_symbolic ops),
|
||||
# we don't need to set it again.
|
||||
#
|
||||
# When a user specifies complex in onnx_symbolic, we consider that to
|
||||
# be the intention even though non of the ONNX ops deals with complex values.
|
||||
# In this case, we don't change the dtype or the shape of the tensor.
|
||||
if value.dtype is None:
|
||||
value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
|
||||
if complex_to_float:
|
||||
if meta_val.dtype == torch.complex64:
|
||||
value.dtype = ir.DataType.FLOAT
|
||||
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
||||
dims.append(2)
|
||||
elif meta_val.dtype == torch.complex128:
|
||||
value.dtype = ir.DataType.DOUBLE
|
||||
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
||||
dims.append(2)
|
||||
|
||||
value.shape = ir.Shape(dims)
|
||||
elif isinstance(meta_val, (int, torch.SymInt)):
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = ["core", "hop"]
|
||||
__all__ = ["core", "hop", "symbolic"]
|
||||
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
|
||||
|
149
torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py
Normal file
149
torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Implementation for higher-order operators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from onnxscript.ir import convenience as ir_convenience
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
|
||||
from torch.onnx._internal.exporter import _core
|
||||
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
|
||||
from torch.onnx.ops import _symbolic_impl
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def _call_symbolic_op(
|
||||
op_type: str,
|
||||
domain: str,
|
||||
args: Sequence[ir.Value | None],
|
||||
kwargs: dict[str, int | float | str | bool | list[int] | list[float] | list[str]],
|
||||
dtypes: Sequence[int],
|
||||
version: int | None,
|
||||
metadata_props: dict[str, str] | None,
|
||||
) -> Sequence[ir.Value]:
|
||||
"""Call an operator with the given arguments and keyword arguments.
|
||||
|
||||
Arguments are always inputs, while keyword arguments are attributes.
|
||||
"""
|
||||
# This is a wrapper around the IR node creation that hooks into the _builder.OpRecorder
|
||||
# tracer so that all nodes created are recorded the same way as if we were to use
|
||||
# onnxscript ops directly.
|
||||
|
||||
assert _core.current_tracer is not None
|
||||
tracer = _core.current_tracer
|
||||
|
||||
inputs = list(args)
|
||||
|
||||
# If final inputs are None, strip them from the node inputs
|
||||
for input in reversed(inputs):
|
||||
if input is not None:
|
||||
break
|
||||
inputs.pop()
|
||||
|
||||
# Construct and filter out None attributes
|
||||
attributes = [
|
||||
attr
|
||||
for attr in ir_convenience.convert_attributes(kwargs) # type: ignore[arg-type]
|
||||
if attr.value is not None # type: ignore[union-attr]
|
||||
]
|
||||
tracer.nodes.append(
|
||||
node := ir.Node(
|
||||
domain,
|
||||
op_type,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
num_outputs=len(dtypes),
|
||||
version=version,
|
||||
metadata_props=metadata_props,
|
||||
)
|
||||
)
|
||||
# Set the dtypes for the outputs. We set them here because the graph builder
|
||||
# Uses PyTorch types which are sometimes inaccurate when they are ONNX only
|
||||
# types like float4e2m1.
|
||||
for value, dtype in zip(node.outputs, dtypes):
|
||||
value.dtype = ir.DataType(dtype)
|
||||
# The shape is set by the graph builder. We don't need to set it here.
|
||||
return node.outputs
|
||||
|
||||
|
||||
@onnx_impl(torch.ops.onnx_symbolic._symbolic.default, no_compile=True)
|
||||
def onnx_symbolic_symbolic(
|
||||
inputs: Sequence[ir.Value | None],
|
||||
op_type: str,
|
||||
onnx_dtype: int,
|
||||
*,
|
||||
shape: Sequence[int | ir.Value],
|
||||
attr_keys: Sequence[str],
|
||||
attr_types: Sequence[str],
|
||||
attr_pos: Sequence[tuple[int, int]],
|
||||
attr_ints: Sequence[int],
|
||||
attr_floats: Sequence[float],
|
||||
attr_strs: Sequence[str],
|
||||
metadata_props_keys: Sequence[str] = (),
|
||||
metadata_props_values: Sequence[str] = (),
|
||||
domain: str = "",
|
||||
version: int | None = None,
|
||||
) -> ir.Value:
|
||||
del shape # Unused. The shapes are set by the graph builder
|
||||
encoded = _symbolic_impl.EncodedAttrs(
|
||||
attr_keys=list(attr_keys),
|
||||
attr_types=list(attr_types),
|
||||
attr_pos=list(attr_pos),
|
||||
attr_ints=list(attr_ints),
|
||||
attr_floats=list(attr_floats),
|
||||
attr_strs=list(attr_strs),
|
||||
)
|
||||
attrs = encoded.to_dict()
|
||||
return _call_symbolic_op(
|
||||
op_type,
|
||||
domain,
|
||||
inputs,
|
||||
attrs,
|
||||
dtypes=[onnx_dtype],
|
||||
version=version,
|
||||
metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
|
||||
)[0]
|
||||
|
||||
|
||||
@onnx_impl(torch.ops.onnx_symbolic._symbolic_multi_out.default, no_compile=True)
|
||||
def onnx_symbolic_symbolic_multi_out(
|
||||
inputs: Sequence[ir.Value | None],
|
||||
op_type: str,
|
||||
onnx_dtypes: Sequence[int],
|
||||
*,
|
||||
shapes: Sequence[Sequence[int | ir.Value]],
|
||||
attr_keys: Sequence[str],
|
||||
attr_types: Sequence[str],
|
||||
attr_pos: Sequence[tuple[int, int]],
|
||||
attr_ints: Sequence[int],
|
||||
attr_floats: Sequence[float],
|
||||
attr_strs: Sequence[str],
|
||||
metadata_props_keys: Sequence[str] = (),
|
||||
metadata_props_values: Sequence[str] = (),
|
||||
domain: str = "",
|
||||
version: int | None = None,
|
||||
) -> Sequence[ir.Value]:
|
||||
del shapes # Unused. The shapes are set by the graph builder
|
||||
encoded = _symbolic_impl.EncodedAttrs(
|
||||
attr_keys=list(attr_keys),
|
||||
attr_types=list(attr_types),
|
||||
attr_pos=list(attr_pos),
|
||||
attr_ints=list(attr_ints),
|
||||
attr_floats=list(attr_floats),
|
||||
attr_strs=list(attr_strs),
|
||||
)
|
||||
attrs = encoded.to_dict()
|
||||
return _call_symbolic_op(
|
||||
op_type,
|
||||
domain,
|
||||
inputs,
|
||||
attrs,
|
||||
dtypes=onnx_dtypes,
|
||||
version=version,
|
||||
metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
|
||||
)
|
243
torch/onnx/ops/__init__.py
Normal file
243
torch/onnx/ops/__init__.py
Normal file
@ -0,0 +1,243 @@
|
||||
"""ONNX operators as native torch.fx operators.
|
||||
|
||||
This module provides a set of functions to create ONNX operators in the FX graph
|
||||
which are exportable to ONNX.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.onnx.ops import _symbolic_impl
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
|
||||
_TORCH_DTYPE_TO_ONNX_DTYPE = {
|
||||
torch.float32: 1, # FLOAT
|
||||
torch.uint8: 2, # UINT8
|
||||
torch.int8: 3, # INT8
|
||||
torch.uint16: 4, # UINT16
|
||||
torch.int16: 5, # INT16
|
||||
torch.int32: 6, # INT32
|
||||
torch.int64: 7, # INT64
|
||||
str: 8, # STRING
|
||||
torch.bool: 9, # BOOL
|
||||
torch.float16: 10, # FLOAT16
|
||||
torch.double: 11, # DOUBLE
|
||||
torch.uint32: 12, # UINT32
|
||||
torch.uint64: 13, # UINT64
|
||||
torch.complex64: 14, # COMPLEX64
|
||||
torch.complex128: 15, # COMPLEX128
|
||||
torch.bfloat16: 16, # BFLOAT16
|
||||
torch.float8_e4m3fn: 17, # FLOAT8E4M3FN
|
||||
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
|
||||
torch.float8_e5m2: 19, # FLOAT8E5M2
|
||||
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
|
||||
}
|
||||
|
||||
|
||||
def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
|
||||
splitted = domain_op.split("::", 1)
|
||||
if len(splitted) == 1:
|
||||
domain = ""
|
||||
op_type = splitted[0]
|
||||
else:
|
||||
domain = splitted[0]
|
||||
op_type = splitted[1]
|
||||
return domain, op_type
|
||||
|
||||
|
||||
def symbolic(
|
||||
domain_op: str,
|
||||
/,
|
||||
inputs: Sequence[torch.Tensor],
|
||||
attrs: dict[
|
||||
str,
|
||||
int
|
||||
| float
|
||||
| str
|
||||
| bool
|
||||
| Sequence[int]
|
||||
| Sequence[float]
|
||||
| Sequence[str]
|
||||
| Sequence[bool],
|
||||
]
|
||||
| None = None,
|
||||
*,
|
||||
dtype: torch.dtype | int,
|
||||
shape: Sequence[int | torch.SymInt],
|
||||
version: int | None = None,
|
||||
metadata_props: dict[str, str] | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Create a symbolic FX operator to represent an arbitrary ONNX operator.
|
||||
|
||||
This function is used to create a symbolic operator with a single output.
|
||||
To create an operator with multiple outputs, use :func:`symbolic_multi_out`.
|
||||
|
||||
Example::
|
||||
|
||||
class CustomOp(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(x,),
|
||||
dict(attr_key="attr_value"),
|
||||
dtype=x.dtype,
|
||||
shape=x.shape,
|
||||
version=1,
|
||||
)
|
||||
# This will create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
|
||||
# The output tensor will have the specified dtype and shape.
|
||||
|
||||
|
||||
# You may then export this model to ONNX using torch.onnx.export.
|
||||
|
||||
Args:
|
||||
domain_op: The domain and operator name, separated by "::". For example,
|
||||
"custom_domain::CustomOp".
|
||||
inputs: The input tensors to the operator.
|
||||
attrs: The attributes of the operator. The keys are attribute names and
|
||||
the values are attribute values. Valid attribute types are int, float,
|
||||
str, bool, and lists of int, float, str, and bool. Tensor attributes
|
||||
are unsupported.
|
||||
dtype: The data type of the output tensor.This can be either a torch.dtype
|
||||
or an integer representing the ONNX data type.
|
||||
shape: The shape of the output tensor. This can be a list of integers or
|
||||
SymInt values.
|
||||
version: The version of the opset used for the operator.
|
||||
metadata_props: Metadata properties for the ONNX node.
|
||||
This is a dictionary of str-str pairs.
|
||||
|
||||
Returns:
|
||||
The output tensor of the operator.
|
||||
"""
|
||||
if not isinstance(dtype, int):
|
||||
torch._check(
|
||||
dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
|
||||
)
|
||||
dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
|
||||
domain, op_type = _parse_domain_op_type(domain_op)
|
||||
if attrs is None:
|
||||
attrs = {}
|
||||
encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
|
||||
# TODO: Parse domain
|
||||
return _symbolic_impl._symbolic(
|
||||
inputs,
|
||||
op_type,
|
||||
dtype,
|
||||
shape=shape,
|
||||
attr_keys=encoded_attrs.attr_keys,
|
||||
attr_types=encoded_attrs.attr_types,
|
||||
attr_pos=encoded_attrs.attr_pos,
|
||||
attr_ints=encoded_attrs.attr_ints,
|
||||
attr_floats=encoded_attrs.attr_floats,
|
||||
attr_strs=encoded_attrs.attr_strs,
|
||||
metadata_props_keys=metadata_props.keys() if metadata_props else [],
|
||||
metadata_props_values=metadata_props.values() if metadata_props else [],
|
||||
domain=domain,
|
||||
version=version,
|
||||
)
|
||||
|
||||
|
||||
def symbolic_multi_out(
|
||||
domain_op: str,
|
||||
/,
|
||||
inputs: Sequence[torch.Tensor],
|
||||
attrs: dict[
|
||||
str,
|
||||
int
|
||||
| float
|
||||
| str
|
||||
| bool
|
||||
| Sequence[int]
|
||||
| Sequence[float]
|
||||
| Sequence[str]
|
||||
| Sequence[bool],
|
||||
]
|
||||
| None = None,
|
||||
*,
|
||||
dtypes: Sequence[torch.dtype | int],
|
||||
shapes: Sequence[Sequence[int | torch.SymInt]],
|
||||
version: int | None = None,
|
||||
metadata_props: dict[str, str] | None = None,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
"""Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.
|
||||
|
||||
Example::
|
||||
|
||||
class CustomOp(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(x,),
|
||||
dict(attr_key="attr_value"),
|
||||
dtypes=(x.dtype, torch.float32),
|
||||
shapes=(x.shape, [1, 2, 3]),
|
||||
version=1,
|
||||
)
|
||||
# This will create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
|
||||
# The output tensor will have the specified dtype and shape.
|
||||
|
||||
|
||||
# You may then export this model to ONNX using torch.onnx.export.
|
||||
|
||||
Args:
|
||||
domain_op: The domain and operator name, separated by "::". For example,
|
||||
"custom_domain::CustomOp".
|
||||
inputs: The input tensors to the operator.
|
||||
attrs: The attributes of the operator. The keys are attribute names and
|
||||
the values are attribute values. Valid attribute types are int, float,
|
||||
str, bool, and lists of int, float, str, and bool. Tensor attributes
|
||||
are unsupported.
|
||||
dtypes: The data types of the output tensors. This can be a list of
|
||||
torch.dtype or integers representing the ONNX data types. The length
|
||||
of this list must be the number of outputs.
|
||||
shapes: The shapes of the output tensors. This can be a list of lists of
|
||||
integers or SymInt values. The length of this list must be the number of outputs.
|
||||
version: The version of the opset used for the operator.
|
||||
metadata_props: Metadata properties for the ONNX node.
|
||||
This is a dictionary of str-str pairs.
|
||||
|
||||
Returns:
|
||||
A list of output tensors of the operator.
|
||||
"""
|
||||
torch._check(
|
||||
len(shapes) == len(dtypes),
|
||||
lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
|
||||
)
|
||||
onnx_dtypes = []
|
||||
for dtype in dtypes:
|
||||
if not isinstance(dtype, int):
|
||||
torch._check(
|
||||
dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
|
||||
lambda: f"Unsupported dtype: {dtype}",
|
||||
)
|
||||
onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])
|
||||
else:
|
||||
onnx_dtypes.append(dtype)
|
||||
domain, op_type = _parse_domain_op_type(domain_op)
|
||||
if attrs is None:
|
||||
attrs = {}
|
||||
encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
|
||||
# Use the size of dtypes to determine the number of outputs
|
||||
return _symbolic_impl._symbolic_multi_out(
|
||||
inputs,
|
||||
op_type,
|
||||
onnx_dtypes,
|
||||
shapes=shapes,
|
||||
attr_keys=encoded_attrs.attr_keys,
|
||||
attr_types=encoded_attrs.attr_types,
|
||||
attr_pos=encoded_attrs.attr_pos,
|
||||
attr_ints=encoded_attrs.attr_ints,
|
||||
attr_floats=encoded_attrs.attr_floats,
|
||||
attr_strs=encoded_attrs.attr_strs,
|
||||
metadata_props_keys=metadata_props.keys() if metadata_props else [],
|
||||
metadata_props_values=metadata_props.values() if metadata_props else [],
|
||||
domain=domain,
|
||||
version=version,
|
||||
)
|
330
torch/onnx/ops/_symbolic_impl.py
Normal file
330
torch/onnx/ops/_symbolic_impl.py
Normal file
@ -0,0 +1,330 @@
|
||||
"""Implementation of symbolic FX ops to represent arbitrary ONNX ops.
|
||||
|
||||
This module provides a way to create symbolic FX operators that can represent
|
||||
arbitrary ONNX operators.
|
||||
|
||||
The operators are called "symbolic" because they don't do any actual computation
|
||||
but instead serve as placeholders in the computation graph.
|
||||
|
||||
Each implementation contains two parts: A "real" implementation that produce all
|
||||
zeros based on the input shape and dtype, and a "fake" implementation that does more
|
||||
or less the same thing but is required by the `torch.library.custom_op` interface.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
|
||||
1: torch.float32, # FLOAT
|
||||
2: torch.uint8, # UINT8
|
||||
3: torch.int8, # INT8
|
||||
4: torch.uint16, # UINT16
|
||||
5: torch.int16, # INT16
|
||||
6: torch.int32, # INT32
|
||||
7: torch.int64, # INT64
|
||||
9: torch.bool, # BOOL
|
||||
10: torch.float16, # FLOAT16
|
||||
11: torch.double, # DOUBLE
|
||||
12: torch.uint32, # UINT32
|
||||
13: torch.uint64, # UINT64
|
||||
14: torch.complex64, # COMPLEX64
|
||||
15: torch.complex128, # COMPLEX128
|
||||
16: torch.bfloat16, # BFLOAT16
|
||||
17: torch.float8_e4m3fn, # FLOAT8E4M3FN
|
||||
18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ
|
||||
19: torch.float8_e5m2, # FLOAT8E5M2
|
||||
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
|
||||
21: torch.uint8, # UINT4
|
||||
22: torch.uint8, # INT4
|
||||
23: torch.uint8, # FLOAT4E2M1
|
||||
}
|
||||
|
||||
_INT_TYPE = "i"
|
||||
_FLOAT_TYPE = "f"
|
||||
_STRING_TYPE = "s"
|
||||
_INT_SEQ_TYPE = "is"
|
||||
_FLOAT_SEQ_TYPE = "fs"
|
||||
_STRING_SEQ_TYPE = "ss"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EncodedAttrs:
|
||||
"""Class to encode attributes from dictionary into lists of FX compatible attributes.
|
||||
|
||||
Since FX does not support dictionaries, we need to encode the attributes into
|
||||
lists. This class provides a way to encode and decode the attributes.
|
||||
|
||||
Attributes:
|
||||
attr_keys: List of attribute keys.
|
||||
attr_types: List of attribute types. Values can be "i" (int), "f" (float),
|
||||
"s" (string), "is" (int sequence), "fs" (float sequence), or "ss" (string sequence).
|
||||
attr_pos: List of tuples representing the start and end positions of each
|
||||
attribute in the corresponding list.
|
||||
attr_ints: List of integer attributes.
|
||||
attr_floats: List of float attributes.
|
||||
attr_strs: List of string attributes.
|
||||
"""
|
||||
|
||||
attr_keys: list[str]
|
||||
attr_types: list[str]
|
||||
attr_pos: list[tuple[int, int]]
|
||||
attr_ints: list[int]
|
||||
attr_floats: list[float]
|
||||
attr_strs: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
attrs: dict[
|
||||
str,
|
||||
Union[
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
bool,
|
||||
Sequence[int],
|
||||
Sequence[float],
|
||||
Sequence[str],
|
||||
Sequence[bool],
|
||||
],
|
||||
],
|
||||
) -> "EncodedAttrs":
|
||||
encoded = cls(
|
||||
attr_keys=[],
|
||||
attr_types=[],
|
||||
attr_pos=[],
|
||||
attr_ints=[],
|
||||
attr_floats=[],
|
||||
attr_strs=[],
|
||||
)
|
||||
for i, (k, v) in enumerate(attrs.items()):
|
||||
encoded.attr_keys.append(k)
|
||||
if isinstance(v, int):
|
||||
start_pos = len(encoded.attr_ints)
|
||||
encoded.attr_ints.append(v)
|
||||
encoded.attr_pos.append((start_pos, start_pos + 1))
|
||||
encoded.attr_types.append(_INT_TYPE)
|
||||
elif isinstance(v, float):
|
||||
start_pos = len(encoded.attr_floats)
|
||||
encoded.attr_floats.append(v)
|
||||
encoded.attr_pos.append((start_pos, start_pos + 1))
|
||||
encoded.attr_types.append(_FLOAT_TYPE)
|
||||
elif isinstance(v, str):
|
||||
start_pos = len(encoded.attr_strs)
|
||||
encoded.attr_strs.append(v)
|
||||
encoded.attr_pos.append((start_pos, start_pos + 1))
|
||||
encoded.attr_types.append(_STRING_TYPE)
|
||||
elif isinstance(v, Sequence):
|
||||
if len(v) == 0:
|
||||
raise ValueError(f"Empty sequence for attribute {k}")
|
||||
if any(isinstance(elem, float) for elem in v):
|
||||
start_pos = len(encoded.attr_floats)
|
||||
encoded.attr_floats.extend([float(elem) for elem in v])
|
||||
encoded.attr_pos.append((start_pos, start_pos + len(v)))
|
||||
encoded.attr_types.append(_FLOAT_SEQ_TYPE)
|
||||
elif isinstance(v[0], int):
|
||||
start_pos = len(encoded.attr_ints)
|
||||
encoded.attr_ints.extend([int(elem) for elem in v])
|
||||
encoded.attr_pos.append((start_pos, start_pos + len(v)))
|
||||
encoded.attr_types.append(_INT_SEQ_TYPE)
|
||||
elif isinstance(v[0], str):
|
||||
start_pos = len(encoded.attr_strs)
|
||||
encoded.attr_strs.extend([str(elem) for elem in v])
|
||||
encoded.attr_pos.append((start_pos, start_pos + len(v)))
|
||||
encoded.attr_types.append(_STRING_SEQ_TYPE)
|
||||
else:
|
||||
raise ValueError(f"Unsupported sequence type for attribute {k}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported attribute type for {k}: {type(v)}")
|
||||
assert len(encoded.attr_keys) == len(encoded.attr_types), (
|
||||
f"Mismatch between number of attribute keys and types: {len(encoded.attr_keys)} != {len(encoded.attr_types)}"
|
||||
)
|
||||
assert len(encoded.attr_keys) == len(encoded.attr_pos), (
|
||||
f"Mismatch between number of attribute keys and positions: {len(encoded.attr_keys)} != {len(encoded.attr_pos)}"
|
||||
)
|
||||
return encoded
|
||||
|
||||
def to_dict(
|
||||
self,
|
||||
) -> dict[
|
||||
str,
|
||||
Union[
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
list[int],
|
||||
list[float],
|
||||
list[str],
|
||||
],
|
||||
]:
|
||||
"""Convert the encoded attributes back to a dictionary for creating an ONNX node."""
|
||||
attrs: dict[
|
||||
str,
|
||||
Union[
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
list[int],
|
||||
list[float],
|
||||
list[str],
|
||||
],
|
||||
] = {}
|
||||
for i, key in enumerate(self.attr_keys):
|
||||
attr_type = self.attr_types[i]
|
||||
if attr_type == _INT_TYPE:
|
||||
attrs[key] = self.attr_ints[self.attr_pos[i][0]]
|
||||
elif attr_type == _FLOAT_TYPE:
|
||||
attrs[key] = self.attr_floats[self.attr_pos[i][0]]
|
||||
elif attr_type == _STRING_TYPE:
|
||||
attrs[key] = self.attr_strs[self.attr_pos[i][0]]
|
||||
elif attr_type == _FLOAT_SEQ_TYPE:
|
||||
attrs[key] = self.attr_floats[self.attr_pos[i][0] : self.attr_pos[i][1]]
|
||||
elif attr_type == _INT_SEQ_TYPE:
|
||||
attrs[key] = self.attr_ints[self.attr_pos[i][0] : self.attr_pos[i][1]]
|
||||
elif attr_type == _STRING_SEQ_TYPE:
|
||||
attrs[key] = self.attr_strs[self.attr_pos[i][0] : self.attr_pos[i][1]]
|
||||
else:
|
||||
raise ValueError(f"Unsupported attribute type: {attr_type}")
|
||||
return attrs
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"onnx_symbolic::_symbolic",
|
||||
mutates_args=(),
|
||||
schema=(
|
||||
"(Tensor?[] inputs, str op_type, int onnx_dtype, *,"
|
||||
" SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
|
||||
" int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
|
||||
" str[] metadata_props_values, str domain='', int? version=None"
|
||||
") -> Tensor"
|
||||
),
|
||||
)
|
||||
def _symbolic(
|
||||
inputs: Sequence[Optional[torch.Tensor]],
|
||||
op_type: str,
|
||||
onnx_dtype: int,
|
||||
*,
|
||||
shape: Sequence[Union[int, torch.SymInt]],
|
||||
attr_keys: Sequence[str],
|
||||
attr_types: Sequence[str],
|
||||
attr_pos: Sequence[tuple[int, int]],
|
||||
attr_ints: Sequence[int],
|
||||
attr_floats: Sequence[float],
|
||||
attr_strs: Sequence[str],
|
||||
metadata_props_keys: Sequence[str] = (),
|
||||
metadata_props_values: Sequence[str] = (),
|
||||
domain: str = "",
|
||||
version: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
torch._check(
|
||||
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||
)
|
||||
return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
|
||||
|
||||
|
||||
@_symbolic.register_fake
|
||||
def _(
|
||||
inputs: Sequence[torch.Tensor],
|
||||
op_type: str,
|
||||
onnx_dtype: int,
|
||||
*,
|
||||
shape: Sequence[Union[int, torch.SymInt]],
|
||||
attr_keys: Sequence[str],
|
||||
attr_types: Sequence[str],
|
||||
attr_pos: Sequence[tuple[int, int]],
|
||||
attr_ints: Sequence[int],
|
||||
attr_floats: Sequence[float],
|
||||
attr_strs: Sequence[str],
|
||||
metadata_props_keys: Sequence[str] = (),
|
||||
metadata_props_values: Sequence[str] = (),
|
||||
domain: str = "",
|
||||
version: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
torch._check(
|
||||
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||
)
|
||||
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
|
||||
# out how it can handle empty shapes
|
||||
return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"onnx_symbolic::_symbolic_multi_out",
|
||||
mutates_args=(),
|
||||
schema=(
|
||||
"(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *,"
|
||||
" SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
|
||||
" int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
|
||||
" str[] metadata_props_values, str domain='', int? version=None"
|
||||
") -> Tensor[]"
|
||||
),
|
||||
)
|
||||
def _symbolic_multi_out(
|
||||
inputs: Sequence[Optional[torch.Tensor]],
|
||||
op_type: str,
|
||||
onnx_dtypes: Sequence[int],
|
||||
*,
|
||||
shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
|
||||
attr_keys: Sequence[str],
|
||||
attr_types: Sequence[str],
|
||||
attr_pos: Sequence[tuple[int, int]],
|
||||
attr_ints: Sequence[int],
|
||||
attr_floats: Sequence[float],
|
||||
attr_strs: Sequence[str],
|
||||
metadata_props_keys: Sequence[str] = (),
|
||||
metadata_props_values: Sequence[str] = (),
|
||||
domain: str = "",
|
||||
version: Optional[int] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
outputs = []
|
||||
torch._check(
|
||||
len(shapes) == len(onnx_dtypes),
|
||||
lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
|
||||
)
|
||||
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
|
||||
torch._check(
|
||||
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||
)
|
||||
outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
|
||||
return outputs
|
||||
|
||||
|
||||
@_symbolic_multi_out.register_fake
|
||||
def _(
|
||||
inputs: Sequence[torch.Tensor],
|
||||
op_type: str,
|
||||
onnx_dtypes: Sequence[int],
|
||||
*,
|
||||
shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
|
||||
attr_keys: Sequence[str],
|
||||
attr_types: Sequence[str],
|
||||
attr_pos: Sequence[tuple[int, int]],
|
||||
attr_ints: Sequence[int],
|
||||
attr_floats: Sequence[float],
|
||||
attr_strs: Sequence[str],
|
||||
metadata_props_keys: Sequence[str] = (),
|
||||
metadata_props_values: Sequence[str] = (),
|
||||
domain: str = "",
|
||||
version: Optional[int] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
outputs = []
|
||||
torch._check(
|
||||
len(shapes) == len(onnx_dtypes),
|
||||
lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
|
||||
)
|
||||
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
|
||||
torch._check(
|
||||
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||
)
|
||||
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
|
||||
# out how it can handle empty shapes
|
||||
outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
|
||||
return outputs
|
Reference in New Issue
Block a user