Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Refactor torch.onnx documentation (#108379)
* Distinguish the TorchScript-based exporter (`torch.onnx.export`) from the TorchDynamo-based exporter (`torch.onnx.dynamo_export`)
* Merge the ONNX diagnostics page with the exporter page
* Add an initial version of a quick overview of the new exporter
* Update `torch.compiler.html` with the right page for the ONNX Runtime backend for `torch.compile`
* Rename doc files to clearly identify the files belonging to the legacy and the newer ONNX exporters

Fixes #108274

https://docs-preview.pytorch.org/pytorch/pytorch/108379/index.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108379
Approved by: https://github.com/justinchuby, https://github.com/wschin, https://github.com/malfet
This commit is contained in: commit aa3355da8a (parent e91f66471c), committed by PyTorch MergeBot.

.github/merge_rules.yaml (vendored, 1 addition)
@@ -7,6 +7,7 @@
    - docs/source/onnx.rst
    - docs/source/onnx*
    - docs/source/scripts/onnx/**
    - docs/source/_static/img/onnx/**
    - scripts/onnx/**
    - test/onnx/**
    - tools/onnx/**
docs/source/Makefile
@@ -18,8 +18,8 @@ figures:
    @$(PYCMD) source/scripts/build_quantization_configs.py

onnx:
    @$(PYCMD) source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py
    @$(PYCMD) source/scripts/onnx/build_onnx_diagnostics_rules_md.py $(SOURCEDIR)/generated/onnx_diagnostics_rules
    @$(PYCMD) source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py
    @$(PYCMD) source/scripts/onnx/build_onnx_dynamo_diagnostics_rules_md.py $(SOURCEDIR)/generated/onnx_dynamo_diagnostics_rules

opset:
    @$(PYCMD) source/scripts/build_opsets.py
BIN docs/source/_static/img/onnx/onnx_dynamo_mlp_model.png (new file, 36 KiB; binary file not shown)
BIN (new image file, 11 KiB; binary file not shown)
BIN (new image file, 5.7 KiB; binary file not shown)
docs/source/index.rst
@@ -94,7 +94,6 @@ Features described in this documentation are classified by release status:
   profiler
   nn.init
   onnx
   onnx_diagnostics
   optim
   complex_numbers
   ddp_comm_hooks
docs/source/onnx.rst
@@ -1,745 +1,64 @@
torch.onnx
==========

.. contents:: :local:

.. automodule:: torch.onnx

Overview
--------

`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
format for representing machine learning models. The ``torch.onnx`` module captures the computation graph from a
native PyTorch :class:`torch.nn.Module` model and converts it into an
`ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.

The exported model can be consumed by any of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_, including
Microsoft's `ONNX Runtime <https://www.onnxruntime.ai>`_.

**There are two flavors of ONNX exporter API that you can use, as listed below:**

TorchDynamo-based ONNX Exporter
-------------------------------

*The TorchDynamo-based ONNX exporter is the newest (and Beta) exporter, for PyTorch 2.0 and newer.*

The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
bytecode into an FX Graph. The resulting FX Graph is then polished before it is finally translated into an
ONNX graph.

The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.

:doc:`Learn more about the TorchDynamo-based ONNX Exporter <onnx_dynamo>`

TorchScript-based ONNX Exporter
-------------------------------

*The TorchScript-based ONNX exporter has been available since PyTorch 1.2.0.*

`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ is leveraged to trace the model (through :func:`torch.jit.trace`)
and capture a static computation graph.

As a consequence, the resulting graph has a couple of limitations:

* It does not record any control flow, such as if-statements or loops;
* It does not handle nuances between ``training`` and ``eval`` mode;
* It does not truly handle dynamic inputs.

As an attempt to work around the static tracing limitations, the exporter also supports TorchScript scripting
(through :func:`torch.jit.script`), which adds support for data-dependent control flow. However, TorchScript
itself is a subset of the Python language, so not all Python features are supported, such as in-place operations.

:doc:`Learn more about the TorchScript-based ONNX Exporter <onnx_torchscript>`

Example: AlexNet from PyTorch to ONNX
-------------------------------------

Here is a simple script which exports a pretrained AlexNet to an ONNX file named ``alexnet.onnx``.
The call to ``torch.onnx.export`` runs the model once to trace its execution and then exports the
traced model to the specified file::

    import torch
    import torchvision

    dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
    model = torchvision.models.alexnet(pretrained=True).cuda()

    # Providing input and output names sets the display names for values
    # within the model's graph. Setting these does not change the semantics
    # of the graph; it is only for readability.
    #
    # The inputs to the network consist of the flat list of inputs (i.e.
    # the values you would pass to the forward() method) followed by the
    # flat list of parameters. You can partially specify names, i.e. provide
    # a list here shorter than the number of inputs to the model, and we will
    # only set that subset of names, starting from the beginning.
    input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
    output_names = ["output1"]

    torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

The resulting ``alexnet.onnx`` file contains a binary `protocol buffer <https://developers.google.com/protocol-buffers/>`_
which contains both the network structure and parameters of the model you exported
(in this case, AlexNet). The argument ``verbose=True`` causes the
exporter to print out a human-readable representation of the model::

    # These are the inputs and parameters to the network, which have taken on
    # the names we specified earlier.
    graph(%actual_input_1 : Float(10, 3, 224, 224)
          %learned_0 : Float(64, 3, 11, 11)
          %learned_1 : Float(64)
          %learned_2 : Float(192, 64, 5, 5)
          %learned_3 : Float(192)
          # ---- omitted for brevity ----
          %learned_14 : Float(1000, 4096)
          %learned_15 : Float(1000)) {
      # Every statement consists of some output tensors (and their types),
      # the operator to be run (with its attributes, e.g., kernels, strides,
      # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
      %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
      %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
      %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
      # ---- omitted for brevity ----
      %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
      # Dynamic means that the shape is not known. This may be because of a
      # limitation of our implementation (which we would like to fix in a
      # future release) or shapes which are truly dynamic.
      %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
      %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
      %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
      %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
      # ---- omitted for brevity ----
      %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
      return (%output1);
    }

You can also verify the output using the `ONNX <https://github.com/onnx/onnx/>`_ library,
which you can install using ``pip``::

    pip install onnx

Then, you can run::

    import onnx

    # Load the ONNX model
    model = onnx.load("alexnet.onnx")

    # Check that the model is well formed
    onnx.checker.check_model(model)

    # Print a human readable representation of the graph
    print(onnx.helper.printable_graph(model.graph))

You can also run the exported model with one of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.
For example, after installing `ONNX Runtime <https://www.onnxruntime.ai>`_, you can
load and run the model::

    import onnxruntime as ort
    import numpy as np

    ort_session = ort.InferenceSession("alexnet.onnx")

    outputs = ort_session.run(
        None,
        {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
    )
    print(outputs[0])

Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime <https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html>`_.

.. _tracing-vs-scripting:

Tracing vs Scripting
--------------------

Internally, :func:`torch.onnx.export()` requires a :class:`torch.jit.ScriptModule` rather than
a :class:`torch.nn.Module`. If the passed-in model is not already a ``ScriptModule``,
``export()`` will use *tracing* to convert it to one:

.. TODO(justinchuby): Add a word on recommending tracing over scripting for most use cases.

* **Tracing**: If ``torch.onnx.export()`` is called with a Module that is not already a
  ``ScriptModule``, it first does the equivalent of :func:`torch.jit.trace`, which executes the model
  once with the given ``args`` and records all operations that happen during that execution. This
  means that if your model is dynamic, e.g., changes behavior depending on input data, the exported
  model will *not* capture this dynamic behavior.
  We recommend examining the exported model and making sure the operators look
  reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly
  the same as the traced run. If you want to export your model with dynamic control flow, you will
  need to use *scripting*.

* **Scripting**: Compiling a model via scripting preserves dynamic control flow and is valid for inputs
  of different sizes. To use scripting:

  * Use :func:`torch.jit.script` to produce a ``ScriptModule``.
  * Call ``torch.onnx.export()`` with the ``ScriptModule`` as the model. The ``args`` are still required,
    but they will be used internally only to produce example outputs, so that the types and shapes of the
    outputs can be captured. No tracing will be performed.

See `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_
and `TorchScript <jit.html>`_ for more details, including how to compose tracing and scripting to suit the
particular requirements of different models.
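
For a concrete sense of the difference, here is a minimal, illustrative sketch (the
module and file name are placeholders, not from the original text) that exports a model
with data-dependent control flow via scripting; tracing the same module would bake in
whichever branch the example input happened to take::

    import torch

    class DynamicModel(torch.nn.Module):
        def forward(self, x):
            # Data-dependent branch: tracing would freeze one side of the `if`;
            # scripting preserves the branch in the exported graph.
            if x.sum() > 0:
                return x * 2
            return x - 1

    script_module = torch.jit.script(DynamicModel())
    # args are only used to compute example outputs; no tracing is performed.
    torch.onnx.export(script_module, (torch.randn(2, 3),), "dynamic_model.onnx")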

Avoiding Pitfalls
-----------------

Avoid NumPy and built-in Python types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch models can be written using NumPy or Python types and functions, but
during :ref:`tracing<tracing-vs-scripting>`, any variables of NumPy or Python
types (rather than torch.Tensor) are converted to constants, which will produce
the wrong result if those values should change depending on the inputs.

For example, rather than using numpy functions on numpy.ndarrays::

    # Bad! Will be replaced with constants during tracing.
    x, y = np.random.rand(1, 2), np.random.rand(1, 2)
    np.concatenate((x, y), axis=1)

Use torch operators on torch.Tensors::

    # Good! Tensor operations will be captured during tracing.
    x, y = torch.randn(1, 2), torch.randn(1, 2)
    torch.cat((x, y), dim=1)

And rather than use :func:`torch.Tensor.item` (which converts a Tensor to a Python
built-in number)::

    # Bad! y.item() will be replaced with a constant during tracing.
    def forward(self, x, y):
        return x.reshape(y.item(), -1)

Use torch's support for implicit casting of single-element tensors::

    # Good! y will be preserved as a variable during tracing.
    def forward(self, x, y):
        return x.reshape(y, -1)

Avoid Tensor.data
^^^^^^^^^^^^^^^^^

Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph.
Use :func:`torch.Tensor.detach` instead. (Work is ongoing to
`remove Tensor.data entirely <https://github.com/pytorch/pytorch/issues/30987>`_.)
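
A minimal illustration of the substitution (the surrounding code is hypothetical)::

    # Bad! x.data bypasses autograd and may be recorded incorrectly in the trace.
    y = x.data * 2

    # Good! x.detach() has the same effect and traces correctly.
    y = x.detach() * 2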

Avoid in-place operations when using tensor.shape in tracing mode
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In tracing mode, shapes obtained from ``tensor.shape`` are traced as tensors,
and share the same memory. This might cause a mismatch in the final output values.
As a workaround, avoid the use of in-place operations in these scenarios.
For example, in the model::

    class Model(torch.nn.Module):
        def forward(self, states):
            batch_size, seq_length = states.shape[:2]
            real_seq_length = seq_length
            real_seq_length += 2
            return real_seq_length + seq_length

``real_seq_length`` and ``seq_length`` share the same memory in tracing mode.
This could be avoided by rewriting the in-place operation::

    real_seq_length = real_seq_length + 2

Limitations
-----------

Types
^^^^^

* Only :class:`torch.Tensors`, numeric types that can be trivially converted to torch.Tensors (e.g. float, int),
  and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and
  outputs are accepted in :ref:`tracing<tracing-vs-scripting>` mode, but:

  * Any computation that depends on the value of a dict or a str input **will be replaced with the
    constant value** seen during the one traced execution.
  * Any output that is a dict will be silently replaced with a **flattened sequence of its values
    (keys will be removed)**. E.g. ``{"foo": 1, "bar": 2}`` becomes ``(1, 2)``.
  * Any output that is a str will be silently removed.

* Certain operations involving tuples and lists are not supported in
  :ref:`scripting<tracing-vs-scripting>` mode due to limited support in ONNX for nested sequences.
  In particular, appending a tuple to a list is not supported. In tracing mode, the nested sequences
  will be flattened automatically during the tracing.

Differences in Operator Implementations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Due to differences in implementations of operators, running the exported model on different runtimes
may produce different results from each other or from PyTorch. Normally these differences are
numerically small, so this should only be a concern if your application is sensitive to these
small differences.
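
If you need to quantify such differences, one common pattern (a sketch, assuming
`ONNX Runtime <https://www.onnxruntime.ai>`_ is installed and the ``alexnet.onnx``
file from the example above exists) is to compare the runtime's output against eager
PyTorch within a tolerance::

    import numpy as np
    import onnxruntime as ort
    import torch
    import torchvision

    torch_model = torchvision.models.alexnet(pretrained=True).eval()
    x = torch.randn(10, 3, 224, 224)

    eager_output = torch_model(x)
    ort_session = ort.InferenceSession("alexnet.onnx")
    (ort_output,) = ort_session.run(None, {"actual_input_1": x.numpy()})

    # Small numerical differences are expected; the tolerances here are illustrative.
    np.testing.assert_allclose(
        eager_output.detach().numpy(), ort_output, rtol=1e-3, atol=1e-5
    )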

.. _tensor-indexing:

Unsupported Tensor Indexing Patterns
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Tensor indexing patterns that cannot be exported are listed below.
If you are experiencing issues exporting a model that does not include any of
the unsupported patterns below, please double check that you are exporting with
the latest ``opset_version``.

Reads / Gets
~~~~~~~~~~~~

When indexing into a tensor for reading, the following patterns are not supported::

    # Tensor indices that include negative values.
    data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
    # Workarounds: use positive index values.

Writes / Sets
~~~~~~~~~~~~~

When indexing into a Tensor for writing, the following patterns are not supported::

    # Multiple tensor indices if any has rank >= 2
    data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
    # Workarounds: use a single tensor index with rank >= 2,
    # or multiple consecutive tensor indices with rank == 1.

    # Multiple tensor indices that are not consecutive
    data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
    # Workarounds: transpose `data` such that tensor indices are consecutive.

    # Tensor indices that include negative values.
    data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
    # Workarounds: use positive index values.

    # Implicit broadcasting required for new_data.
    data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
    # Workarounds: expand new_data explicitly.
    # Example:
    #   data shape:     [3, 4, 5]
    #   new_data shape: [5]
    #   expected new_data shape after broadcasting: [2, 2, 2, 5]
Adding support for operators
----------------------------

When exporting a model that includes unsupported operators, you'll see an error message like:

.. code-block:: text

    RuntimeError: ONNX export failed: Couldn't export operator foo

When that happens, there are a few things you can do:

#. Change the model to not use that operator.
#. Create a symbolic function to convert the operator and register it as a custom symbolic function.
#. Contribute to PyTorch to add the same symbolic function to :mod:`torch.onnx` itself.

If you decide to implement a symbolic function (we hope you will contribute it back to PyTorch!), here is how you can get started:

ONNX exporter internals
^^^^^^^^^^^^^^^^^^^^^^^

A "symbolic function" is a function that decomposes a PyTorch operator into a
composition of a series of ONNX operators.

During export, each node (which contains a PyTorch operator) in the TorchScript
graph is visited by the exporter in topological order.
Upon visiting a node, the exporter looks for a registered symbolic function for
that operator. Symbolic functions are implemented in Python. A symbolic function for
an op named ``foo`` would look something like::

    def foo(
        g,
        input_0: torch._C.Value,
        input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
        """
        Adds the ONNX operations representing this PyTorch function by updating the
        graph g with `g.op()` calls.

        Args:
            g (Graph): graph to write the ONNX representation into.
            input_0 (Value): value representing the variables which contain
                the first input for this operator.
            input_1 (Value): value representing the variables which contain
                the second input for this operator.

        Returns:
            A Value or List of Values specifying the ONNX nodes that compute something
            equivalent to the original PyTorch operator with the given inputs.

            None if it cannot be converted to ONNX.
        """
        ...

The ``torch._C`` types are Python wrappers around the types defined in C++ in
`ir.h <https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/ir/ir.h>`_.

The process for adding a symbolic function depends on the type of operator.

.. _adding-support-aten:

ATen operators
^^^^^^^^^^^^^^

`ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library.
If the operator is an ATen operator (shows up in the TorchScript graph with the prefix
``aten::``), make sure it is not supported already.

List of supported operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Visit the auto generated :doc:`list of supported TorchScript operators <../onnx_supported_aten_ops>`
for details on which operators are supported in each ``opset_version``.

Adding support for an aten or quantized operator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If the operator is not in the list above:

* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
  `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py>`_.
  Make sure the function has the same name as the ATen function, which may be declared in
  ``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
  build time, so will not appear in your checkout until you build PyTorch).
* By default, the first arg is the ONNX graph.
  Other arg names must EXACTLY match the names in the ``.pyi`` file,
  because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
  `ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
  we only need to create a node to represent the ONNX operator in the graph.
  If not, we can compose several standard operators that have the
  equivalent semantics to the ATen operator.

Here is an example of handling a missing symbolic function for the ``ELU`` operator.

If we run the following code::

    print(
        torch.jit.trace(
            torch.nn.ELU(),  # module
            torch.ones(1)    # example input
        ).graph
    )

We see something like::

    graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
          %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
      %4 : float = prim::Constant[value=1.]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=1]()
      %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
      return (%7)

Since we see ``aten::elu`` in the graph, we know this is an ATen operator.

We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.

We find a signature for ``elu`` in ``torch/nn/functional.pyi``::

    def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

We add the following lines to ``symbolic_opset9.py``::

    def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
        return g.op("Elu", input, alpha_f=alpha)

Now PyTorch is able to export models containing the ``aten::elu`` operator!

See the ``torch/onnx/symbolic_opset*.py`` files for more examples.
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^

If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.

Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~

You can add a static method named ``symbolic`` to your function class. It should return
ONNX operators that represent the function's behavior in ONNX. For example::

    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
            return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

.. FIXME(justinchuby): PythonOps are too complicated and the example below
.. uses private methods we do not expose. We are looking to
.. improve the experience. Since SymbolicContext is deprecated, we think
.. defining a symbolic staticmethod is a better way to go for now.

.. PythonOp Symbolic
.. ~~~~~~~~~~~~~~~~~

.. Alternatively, you can register a custom symbolic function.
.. This gives the symbolic function access to more info through the
.. ``torch.onnx.SymbolicContext`` object, which gets passed in as the first
.. argument (before the ``Graph`` object).

.. All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
.. In order to differentiate between different ``Function`` subclasses, the
.. symbolic function should use the ``name`` kwarg which gets set to the name of the class.

.. Custom symbolic functions should add type and shape information by calling ``setType(...)``
.. on Value objects before returning them (implemented in C++ by
.. ``torch::jit::Value::setType``). This is not required, but it can help the exporter's
.. shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
.. ``test_aten_embedding_2`` in
.. `test_operators.py <https://github.com/pytorch/pytorch/blob/main/test/onnx/test_operators.py>`_.

.. The example below shows how you can access ``requires_grad`` via the ``Node`` object:

..     class MyClip(torch.autograd.Function):
..         @staticmethod
..         def forward(ctx, input, min):
..             ctx.save_for_backward(input)
..             return input.clamp(min=min)

..     class MyRelu(torch.autograd.Function):
..         @staticmethod
..         def forward(ctx, input):
..             ctx.save_for_backward(input)
..             return input.clamp(min=0)

..     def symbolic_python_op(g: "GraphContext", *args, **kwargs):
..         n = ctx.cur_node
..         print("original node: ", n)
..         for i, out in enumerate(n.outputs()):
..             print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
..         import torch.onnx.symbolic_helper as sym_helper
..         for i, arg in enumerate(args):
..             requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
..             print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))

..         name = kwargs["name"]
..         ret = None
..         if name == "MyClip":
..             ret = g.op("Clip", args[0], args[1])
..         elif name == "MyRelu":
..             ret = g.op("Relu", args[0])
..         else:
..             # Logs a warning and returns None
..             return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
..         # Copy type and shape from original node.
..         ret.setType(n.type())
..         return ret

..     from torch.onnx import register_custom_op_symbolic
..     register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)

Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~

When a static symbolic method is not provided for a :class:`torch.autograd.Function`,
and no custom symbolic function has been registered for ``prim::PythonOp``,
:func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function`,
breaking the function down into the individual operators that were used within it.
The export should be successful as long as these individual operators are supported. For example::

    class MyLogExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(input)
            h = input.exp()
            return h.log().log()

There is no static symbolic method present for this model, yet it is exported as follows::

    graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
      %1 : float = onnx::Exp[](%input)
      %2 : float = onnx::Log[](%1)
      %3 : float = onnx::Log[](%2)
      return (%3)

If you need to avoid inlining of :class:`torch.autograd.Function`, you should export models with
``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``.
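
For example (``model`` and ``args`` are placeholders standing in for your own model and
example inputs)::

    torch.onnx.export(
        model,
        args,
        "model.onnx",
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
    )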

Custom operators
^^^^^^^^^^^^^^^^

You can export your model with custom operators that include a combination of many standard ONNX ops,
or that are driven by a self-defined C++ backend.

ONNX-script functions
~~~~~~~~~~~~~~~~~~~~~

If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize
`ONNX-script <https://github.com/microsoft/onnx-script>`_ to create an external ONNX function to support the operator.
You can export it by following this example::

    import torch
    import onnxscript
    # GraphContext is defined in torch.onnx._internal.jit_utils.
    from torch.onnx._internal import jit_utils

    # There are three opset versions that need to be aligned.
    # This is (1) the opset version in the ONNX function.
    from onnxscript.onnx_opset import opset15 as op
    opset_version = 15

    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    model = torch.nn.SELU()

    custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

    @onnxscript.script(custom_opset)
    def Selu(X):
        alpha = 1.67326  # auto wrapped as Constants
        gamma = 1.0507
        alphaX = op.CastLike(alpha, X)
        gammaX = op.CastLike(gamma, X)
        neg = gammaX * (alphaX * op.Exp(X) - alphaX)
        pos = gammaX * X
        zero = op.CastLike(0, X)
        return op.Where(X <= zero, neg, pos)

    # The setType API provides shape/type information to ONNX shape/type inference.
    def custom_selu(g: jit_utils.GraphContext, X):
        return g.onnxscript_op(Selu, X).setType(X.type())

    # Register the custom symbolic function.
    # There are three opset versions that need to be aligned.
    # This is (2) the opset version in the registry.
    torch.onnx.register_custom_op_symbolic(
        symbolic_name="aten::selu",
        symbolic_fn=custom_selu,
        opset_version=opset_version,
    )

    # There are three opset versions that need to be aligned.
    # This is (3) the opset version in the exporter.
    torch.onnx.export(
        model,
        x,
        "model.onnx",
        opset_version=opset_version,
        # only needed if you want to specify an opset version > 1.
        custom_opsets={"onnx-script": 2}
    )

The example above exports it as a custom operator in the "onnx-script" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.

NOTE: Be careful to align the opset versions mentioned in the example above, and make sure they are
consumed by the exporter step. The usage shown here for writing an ONNX-script function is a beta
version in terms of the active development of ONNX-script.
Please follow the latest `ONNX-script <https://github.com/microsoft/onnx-script>`_ documentation.
C++ Operators
~~~~~~~~~~~~~

If a model uses a custom operator implemented in C++ as described in
`Extending TorchScript with Custom C++ Operators <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_,
you can export it by following this example::

    from torch.onnx import symbolic_helper


    # Define custom symbolic function
    @symbolic_helper.parse_args("v", "v", "f", "i")
    def symbolic_foo_forward(g, input1, input2, attr1, attr2):
        return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


    # Register custom symbolic function
    torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


    class FooModel(torch.nn.Module):
        def __init__(self, attr1, attr2):
            super().__init__()
            self.attr1 = attr1
            self.attr2 = attr2

        def forward(self, input1, input2):
            # Calling custom op
            return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


    # attr1, attr2, example_input1 and example_input2 are placeholders;
    # they are assumed to be defined by the caller.
    model = FooModel(attr1, attr2)
    torch.onnx.export(
        model,
        (example_input1, example_input2),
        "model.onnx",
        # only needed if you want to specify an opset version > 1.
        custom_opsets={"custom_domain": 2}
    )

The example above exports it as a custom operator in the "custom_domain" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.

The runtime that consumes the model needs to support the custom op. See
`Caffe2 custom ops <https://caffe2.ai/docs/custom-operators.html>`_,
`ONNX Runtime custom ops <https://onnxruntime.ai/docs/reference/operators/add-custom-op.html>`_,
or your runtime of choice's documentation.
Discovering all unconvertible ATen ops at once
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When export fails due to an unconvertible ATen op, there may in fact be more
than one such op but the error message only mentions the first. To discover
all of the unconvertible ops in one go you can::

    # prepare model, args, opset_version
    ...

    torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
        model, args, opset_version=opset_version
    )

    print(set(unconvertible_ops))

The set is approximated because some ops may be removed during the conversion
process and don't need to be converted. Some other ops may have partial support
that will fail conversion with particular inputs, but this should give you a
general idea of what ops are not supported. Please feel free to open GitHub Issues
for op support requests.
Frequently Asked Questions
--------------------------

Q: I have exported my LSTM model, but its input size seems to be fixed?

The tracer records the shapes of the example inputs. If the model should accept
inputs of dynamic shapes, set ``dynamic_axes`` when calling :func:`torch.onnx.export`.
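
For example, a sketch that marks the batch and sequence dimensions of the input as
dynamic (the model, input, and axis names are illustrative)::

    torch.onnx.export(
        model,
        (dummy_input,),
        "lstm.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch", 1: "sequence"},
            "output": {0: "batch"},
        },
    )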

Q: How to export models containing loops?

See `Tracing vs Scripting`_.

Q: How to export models with primitive type inputs (e.g. int, float)?

Support for primitive numeric type inputs was added in PyTorch 1.9.
However, the exporter does not support models with str inputs.

Q: Does ONNX support implicit scalar datatype casting?

The ONNX standard does not, but the exporter will try to handle that part.
Scalars are exported as constant tensors.
The exporter will figure out the right data type for scalars. In rare cases when it is unable
to do so, you will need to manually specify the datatype, e.g. with ``dtype=torch.float32``.
If you see any errors, please `create a GitHub issue <https://github.com/pytorch/pytorch/issues>`_.

Q: Are lists of Tensors exportable to ONNX?

Yes, for ``opset_version`` >= 11, since ONNX introduced the Sequence type in opset 11.
Contributing / developing
-------------------------

`Developer docs <https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter>`_.

Functions
---------

.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export
.. autofunction:: enable_log
.. autofunction:: disable_log
.. autofunction:: torch.onnx.verification.find_mismatch

Classes
-------

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    JitScalarType
    torch.onnx.verification.GraphInfo
    torch.onnx.verification.VerificationOptions

Preview: torch.onnx TorchDynamo Exporter
----------------------------------------

.. warning::
    The ONNX exporter for TorchDynamo is under active development and is
    subject to rapid change.

.. autofunction:: torch.onnx.dynamo_export
.. autofunction:: torch.onnx.enable_fake_mode
.. autofunction:: torch.onnx.is_onnxrt_backend_supported

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    torch.onnx.DiagnosticOptions
    torch.onnx.ExportOptions
    torch.onnx.ExportOutput
    torch.onnx.ExportOutputSerializer
    torch.onnx.OnnxExporterError
    torch.onnx.OnnxRegistry

Contributing / Developing
-------------------------

The ONNX exporter is a community project and we welcome contributions. We follow the
`PyTorch guidelines for contributions <https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md>`_, but you might
also be interested in reading our `development wiki <https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter>`_.

.. toctree::
    :hidden:

    onnx_dynamo
    onnx_dynamo_onnxruntime_backend
    onnx_torchscript
docs/source/onnx_diagnostics.rst (deleted)
@@ -1,26 +0,0 @@
torch.onnx diagnostics
======================

.. contents:: :local:
.. automodule:: torch.onnx._internal.diagnostics
.. currentmodule:: torch.onnx._internal.diagnostics

Overview
--------

NOTE: This feature is under development and is subject to change.

The goal is to improve the diagnostics to help users debug and improve their model export to ONNX.

- The diagnostics are emitted in machine parsable `Static Analysis Results Interchange Format (SARIF) <https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html>`__.
- A new, clearer, structured way to add new diagnostic rules and keep track of them.
- Serves as a foundation for future improvements that consume the diagnostics.

Diagnostic Rules
----------------

.. toctree::
    :glob:

    generated/onnx_diagnostics_rules/*
docs/source/onnx_dynamo.rst (new file, 156 lines)
@@ -0,0 +1,156 @@
TorchDynamo-based ONNX Exporter
===============================

.. automodule:: torch.onnx
    :noindex:

.. contents:: :local:
    :depth: 3

.. warning::
    The ONNX exporter for TorchDynamo is under active development and is subject to rapid change.

Overview
--------

The ONNX exporter leverages the TorchDynamo engine to hook into Python's frame evaluation API
and dynamically rewrite its bytecode into an FX Graph.
The resulting FX Graph is then polished before it is finally translated into an ONNX graph.

The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.

The exporter is designed to be modular and extensible. It is composed of the following components:

- **ONNX Exporter**: :class:`Exporter` main class that orchestrates the export process.
- **ONNX Export Options**: :class:`ExportOptions` has a set of options that control the export process.
- **ONNX Registry**: :class:`OnnxRegistry` is the registry of ONNX operators and functions.
- **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model.
- **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models.
- **ONNX Export Output**: :class:`ExportOutput` is the output of the exporter that contains the exported ONNX graph and diagnostics.
- **ONNX Export Output Serializer**: :class:`ExportOutputSerializer` serializes the exported model to a file.
- **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter.
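
For instance, :class:`ExportOptions` is how callers tune the export process; a small,
illustrative sketch (``model`` and ``example_input`` are placeholders, and
``dynamic_shapes`` is one of the available options):

.. code-block:: python

    import torch

    export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
    export_output = torch.onnx.dynamo_export(
        model, example_input, export_options=export_options
    )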

Dependencies
------------

The ONNX exporter depends on extra Python packages:

- `ONNX <https://onnx.ai>`_
- `ONNX Script <https://onnxscript.ai>`_

They can be installed through `pip <https://pypi.org/project/pip/>`_:

.. code-block:: bash

    pip install --upgrade onnx onnxscript

A simple example
----------------

Below is a demonstration of the exporter API in action with a simple Multilayer Perceptron (MLP) as an example:

.. code-block:: python

    import torch
    import torch.nn as nn  # needed for the nn.Module / nn.Linear references below

    class MLPModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc0 = nn.Linear(8, 8, bias=True)
            self.fc1 = nn.Linear(8, 4, bias=True)
            self.fc2 = nn.Linear(4, 2, bias=True)
            self.fc3 = nn.Linear(2, 2, bias=True)

        def forward(self, tensor_x: torch.Tensor):
            tensor_x = self.fc0(tensor_x)
            tensor_x = torch.sigmoid(tensor_x)
            tensor_x = self.fc1(tensor_x)
            tensor_x = torch.sigmoid(tensor_x)
            tensor_x = self.fc2(tensor_x)
            tensor_x = torch.sigmoid(tensor_x)
            output = self.fc3(tensor_x)
            return output

    model = MLPModel()
    tensor_x = torch.rand((97, 8), dtype=torch.float32)
    export_output = torch.onnx.dynamo_export(model, tensor_x)

As the code above shows, all you need to do is provide :func:`torch.onnx.dynamo_export` with an instance of the model and its input.
The exporter will then return an instance of :class:`torch.onnx.ExportOutput` that contains the exported ONNX graph along with extra information.

The in-memory model available through ``export_output.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
The ONNX model can then be serialized into a `Protobuf file <https://protobuf.dev/>`_ using the :meth:`torch.onnx.ExportOutput.save` API.

.. code-block:: python

    export_output.save("mlp.onnx")
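
The saved file can then be executed by any runtime that supports ONNX; for example, a
sketch using `ONNX Runtime <https://www.onnxruntime.ai>`_ (assuming it is installed and
reusing ``tensor_x`` from the example above):

.. code-block:: python

    import onnxruntime as ort

    ort_session = ort.InferenceSession("mlp.onnx")
    # Look up the input name from the session rather than hard-coding it.
    onnx_input = {ort_session.get_inputs()[0].name: tensor_x.numpy()}
    (onnx_output,) = ort_session.run(None, onnx_input)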

Inspecting the ONNX model using GUI
-----------------------------------

You can view the exported model using `Netron <https://netron.app/>`__.

.. image:: _static/img/onnx/onnx_dynamo_mlp_model.png
    :width: 40%
    :alt: MLP model as viewed using Netron

Note that each layer is represented in a rectangular box with an *f* icon in the top right corner.

.. image:: _static/img/onnx/onnx_dynamo_mlp_model_function_highlight.png
    :width: 40%
    :alt: ONNX function highlighted on MLP model

Expanding it shows the function body.

.. image:: _static/img/onnx/onnx_dynamo_mlp_model_function_body.png
    :width: 50%
    :alt: ONNX function body

The function body is a sequence of ONNX operators or other functions.

Diagnosing issues with SARIF
----------------------------

ONNX diagnostics goes beyond regular logs through the adoption of
`Static Analysis Results Interchange Format (aka SARIF) <https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html>`__
to help users debug and improve their model using a GUI, such as
Visual Studio Code's `SARIF Viewer <https://marketplace.visualstudio.com/items?itemName=MS-SarifVSCode.sarif-viewer>`_.

The main advantages are:

- The diagnostics are emitted in machine parseable `Static Analysis Results Interchange Format (SARIF) <https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html>`__.
- A new, clearer, structured way to add new diagnostic rules and keep track of them.
- Serves as a foundation for future improvements that consume the diagnostics.

.. toctree::
    :maxdepth: 1
    :caption: ONNX Diagnostic SARIF Rules
    :glob:

    generated/onnx_dynamo_diagnostics_rules/*

API Reference
-------------

.. autofunction:: torch.onnx.dynamo_export

.. autoclass:: torch.onnx.ExportOptions
    :members:

.. autofunction:: torch.onnx.enable_fake_mode

.. autoclass:: torch.onnx.ExportOutput
    :members:

.. autoclass:: torch.onnx.ExportOutputSerializer
    :members:

.. autoclass:: torch.onnx.OnnxExporterError
    :members:

.. autoclass:: torch.onnx.OnnxRegistry
    :members:

.. autoclass:: torch.onnx.DiagnosticOptions
    :members:
docs/source/onnx_dynamo_onnxruntime_backend.rst (new file, 10 lines)
@@ -0,0 +1,10 @@
ONNX Backend for TorchDynamo
============================

For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.

.. warning::
    The ONNX backend for TorchDynamo is under active development and is
    subject to rapid change.

.. autofunction:: torch.onnx.is_onnxrt_backend_supported
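
A minimal sketch of selecting the ONNX Runtime backend with ``torch.compile`` (the
tiny model and input here are illustrative placeholders):

.. code-block:: python

    import torch

    model = torch.nn.Linear(4, 2)  # illustrative model
    x = torch.randn(3, 4)

    # Guard on availability before requesting the "onnxrt" backend.
    if torch.onnx.is_onnxrt_backend_supported():
        compiled_model = torch.compile(model, backend="onnxrt")
        print(compiled_model(x))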

docs/source/onnx_torchscript.rst (new file, 716 lines)
@@ -0,0 +1,716 @@
TorchScript-based ONNX Exporter
===============================

.. contents:: :local:

Example: AlexNet from PyTorch to ONNX
-------------------------------------

Here is a simple script which exports a pretrained AlexNet to an ONNX file named ``alexnet.onnx``.
The call to ``torch.onnx.export`` runs the model once to trace its execution and then exports the
traced model to the specified file::

    import torch
    import torchvision

    dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
    model = torchvision.models.alexnet(pretrained=True).cuda()

    # Providing input and output names sets the display names for values
    # within the model's graph. Setting these does not change the semantics
    # of the graph; it is only for readability.
    #
    # The inputs to the network consist of the flat list of inputs (i.e.
    # the values you would pass to the forward() method) followed by the
    # flat list of parameters. You can partially specify names, i.e. provide
    # a list here shorter than the number of inputs to the model, and we will
    # only set that subset of names, starting from the beginning.
    input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
    output_names = ["output1"]

    torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

The resulting ``alexnet.onnx`` file contains a binary `protocol buffer <https://developers.google.com/protocol-buffers/>`_
which contains both the network structure and parameters of the model you exported
(in this case, AlexNet). The argument ``verbose=True`` causes the
exporter to print out a human-readable representation of the model::

    # These are the inputs and parameters to the network, which have taken on
    # the names we specified earlier.
    graph(%actual_input_1 : Float(10, 3, 224, 224)
          %learned_0 : Float(64, 3, 11, 11)
          %learned_1 : Float(64)
          %learned_2 : Float(192, 64, 5, 5)
          %learned_3 : Float(192)
          # ---- omitted for brevity ----
          %learned_14 : Float(1000, 4096)
          %learned_15 : Float(1000)) {
      # Every statement consists of some output tensors (and their types),
      # the operator to be run (with its attributes, e.g., kernels, strides,
      # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
      %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
      %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
      %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
      # ---- omitted for brevity ----
      %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
      # Dynamic means that the shape is not known. This may be because of a
      # limitation of our implementation (which we would like to fix in a
      # future release) or shapes which are truly dynamic.
      %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
      %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
      %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
      %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
      # ---- omitted for brevity ----
      %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
      return (%output1);
    }

You can also verify the output using the `ONNX <https://github.com/onnx/onnx/>`_ library,
which you can install using ``pip``::

    pip install onnx

Then, you can run::

    import onnx

    # Load the ONNX model
    model = onnx.load("alexnet.onnx")

    # Check that the model is well formed
    onnx.checker.check_model(model)

    # Print a human readable representation of the graph
    print(onnx.helper.printable_graph(model.graph))

You can also run the exported model with one of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.
For example, after installing `ONNX Runtime <https://www.onnxruntime.ai>`_, you can
load and run the model::

    import onnxruntime as ort
    import numpy as np

    ort_session = ort.InferenceSession("alexnet.onnx")

    outputs = ort_session.run(
        None,
        {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
    )
    print(outputs[0])

Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime <https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html>`_.

.. _tracing-vs-scripting:

Tracing vs Scripting
--------------------

Internally, :func:`torch.onnx.export()` requires a :class:`torch.jit.ScriptModule` rather than
a :class:`torch.nn.Module`. If the passed-in model is not already a ``ScriptModule``,
``export()`` will use *tracing* to convert it to one:

.. TODO(justinchuby): Add a word on recommending tracing over scripting for most use cases.

* **Tracing**: If ``torch.onnx.export()`` is called with a Module that is not already a
  ``ScriptModule``, it first does the equivalent of :func:`torch.jit.trace`, which executes the model
  once with the given ``args`` and records all operations that happen during that execution. This
  means that if your model is dynamic, e.g., changes behavior depending on input data, the exported
  model will *not* capture this dynamic behavior.
  We recommend examining the exported model and making sure the operators look
  reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly
  the same as the traced run. If you want to export your model with dynamic control flow, you will
  need to use *scripting*.

* **Scripting**: Compiling a model via scripting preserves dynamic control flow and is valid for inputs
  of different sizes. To use scripting:

  * Use :func:`torch.jit.script` to produce a ``ScriptModule``.
  * Call ``torch.onnx.export()`` with the ``ScriptModule`` as the model. The ``args`` are still required,
    but they will be used internally only to produce example outputs, so that the types and shapes of the
    outputs can be captured. No tracing will be performed.

See `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_
and `TorchScript <jit.html>`_ for more details, including how to compose tracing and scripting to suit the
particular requirements of different models.
Avoiding Pitfalls
|
||||
-----------------
|
||||
|
||||
Avoid NumPy and built-in Python types
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
PyTorch models can be written using NumPy or Python types and functions, but
|
||||
during :ref:`tracing<tracing-vs-scripting>`, any variables of NumPy or Python
|
||||
types (rather than torch.Tensor) are converted to constants, which will produce
|
||||
the wrong result if those values should change depending on the inputs.
|
||||
|
||||
For example, rather than using numpy functions on numpy.ndarrays: ::
|
||||
|
||||
# Bad! Will be replaced with constants during tracing.
|
||||
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
|
||||
np.concatenate((x, y), axis=1)
|
||||
|
||||
Use torch operators on torch.Tensors: ::
|
||||
|
||||
# Good! Tensor operations will be captured during tracing.
|
||||
x, y = torch.randn(1, 2), torch.randn(1, 2)
|
||||
torch.cat((x, y), dim=1)
|
||||
|
||||
|
||||
And rather than use :func:`torch.Tensor.item` (which converts a Tensor to a Python
|
||||
built-in number): ::
|
||||
|
||||
# Bad! y.item() will be replaced with a constant during tracing.
|
||||
def forward(self, x, y):
|
||||
return x.reshape(y.item(), -1)
|
||||
|
||||
Use torch's support for implicit casting of single-element tensors: ::
|
||||
|
||||
# Good! y will be preserved as a variable during tracing.
|
||||
def forward(self, x, y):
|
||||
return x.reshape(y, -1)
|
||||
|
||||
Avoid Tensor.data
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph.
|
||||
Use :func:`torch.Tensor.detach` instead. (Work is ongoing to
|
||||
`remove Tensor.data entirely <https://github.com/pytorch/pytorch/issues/30987>`_).
|
||||
|
||||
Avoid in-place operations when using tensor.shape in tracing mode
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In tracing mode, shapes obtained from ``tensor.shape`` are traced as tensors,
and share the same memory. This might cause a mismatch in the final output values.
As a workaround, avoid the use of in-place operations in these scenarios.
For example, in the model::

    class Model(torch.nn.Module):
        def forward(self, states):
            batch_size, seq_length = states.shape[:2]
            real_seq_length = seq_length
            real_seq_length += 2
            return real_seq_length + seq_length

``real_seq_length`` and ``seq_length`` share the same memory in tracing mode.
This can be avoided by rewriting the in-place operation::

    real_seq_length = real_seq_length + 2

Limitations
-----------

Types
^^^^^

* Only :class:`torch.Tensors`, numeric types that can be trivially converted to torch.Tensors (e.g. float, int),
  and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and
  outputs are accepted in :ref:`tracing<tracing-vs-scripting>` mode, but:

  * Any computation that depends on the value of a dict or a str input **will be replaced with the
    constant value** seen during the one traced execution.
  * Any output that is a dict will be silently replaced with a **flattened sequence of its values
    (keys will be removed)**. E.g. ``{"foo": 1, "bar": 2}`` becomes ``(1, 2)``; see the sketch after this list.
  * Any output that is a str will be silently removed.

* Certain operations involving tuples and lists are not supported in
  :ref:`scripting<tracing-vs-scripting>` mode due to limited support in ONNX for nested sequences.
  In particular, appending a tuple to a list is not supported. In tracing mode, nested sequences
  are flattened automatically during the trace.

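A minimal sketch of the dict-output flattening (the module below is illustrative)::

    class DictOutput(torch.nn.Module):
        def forward(self, x):
            return {"foo": x + 1, "bar": x * 2}

    # After tracing-based export, the ONNX model returns the flattened
    # tuple (x + 1, x * 2); the keys "foo" and "bar" are dropped.
    torch.onnx.export(DictOutput(), (torch.randn(2),), "dict_output.onnx")
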
Differences in Operator Implementations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Due to differences in implementations of operators, running the exported model on different runtimes
may produce different results from each other or from PyTorch. Normally these differences are
numerically small, so this should only be a concern if your application is sensitive to these
small differences.

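To quantify these differences for your own model, you can compare PyTorch and ONNX Runtime
outputs directly; a hedged sketch (the input name ``"input"`` must match the name used at export,
and a single model output is assumed)::

    import onnxruntime as ort

    torch_out = model(x)
    session = ort.InferenceSession("model.onnx")
    (ort_out,) = session.run(None, {"input": x.numpy()})
    # Tolerances are application-specific; these values are only illustrative.
    torch.testing.assert_close(torch_out, torch.from_numpy(ort_out), rtol=1e-3, atol=1e-5)
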
.. _tensor-indexing:

Unsupported Tensor Indexing Patterns
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Tensor indexing patterns that cannot be exported are listed below.
If you are experiencing issues exporting a model that does not include any of
the unsupported patterns below, please double check that you are exporting with
the latest ``opset_version``.

Reads / Gets
~~~~~~~~~~~~

When indexing into a tensor for reading, the following patterns are not supported: ::

    # Tensor indices that include negative values.
    data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
    # Workarounds: use positive index values.

Writes / Sets
~~~~~~~~~~~~~

When indexing into a Tensor for writing, the following patterns are not supported: ::

    # Multiple tensor indices if any has rank >= 2
    data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
    # Workarounds: use a single tensor index with rank >= 2,
    #              or multiple consecutive tensor indices with rank == 1.

    # Multiple tensor indices that are not consecutive
    data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
    # Workarounds: transpose `data` such that tensor indices are consecutive.

    # Tensor indices that include negative values.
    data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
    # Workarounds: use positive index values.

    # Implicit broadcasting required for new_data.
    data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
    # Workarounds: expand new_data explicitly.
    # Example:
    #   data shape:     [3, 4, 5]
    #   new_data shape: [5]
    #   expected new_data shape after broadcasting: [2, 2, 2, 5]

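A hedged sketch of the transpose workaround for the non-consecutive case above (the shapes are
illustrative; ``transpose`` returns a view, so the in-place write lands in ``data`` itself)::

    idx0, idx1 = torch.tensor([2, 3]), torch.tensor([1, 2])
    # Unsupported: tensor indices separated by a slice.
    # data[idx0, :, idx1] = new_data
    # Supported: after transposing dims 1 and 2, the tensor indices are consecutive.
    data.transpose(1, 2)[idx0, idx1] = new_data
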
Adding support for operators
----------------------------

When exporting a model that includes unsupported operators, you'll see an error message like:

.. code-block:: text

    RuntimeError: ONNX export failed: Couldn't export operator foo

When that happens, there are a few things you can do:

#. Change the model to not use that operator.
#. Create a symbolic function to convert the operator and register it as a custom symbolic function.
#. Contribute to PyTorch to add the same symbolic function to :mod:`torch.onnx` itself.

If you decide to implement a symbolic function (we hope you will contribute it back to PyTorch!), here is how you can get started:

ONNX exporter internals
^^^^^^^^^^^^^^^^^^^^^^^

A "symbolic function" is a function that decomposes a PyTorch operator into a
composition of a series of ONNX operators.

During export, each node (which contains a PyTorch operator) in the TorchScript
graph is visited by the exporter in topological order.
Upon visiting a node, the exporter looks for a registered symbolic function for
that operator. Symbolic functions are implemented in Python. A symbolic function for
an op named ``foo`` would look something like::

    def foo(
        g,
        input_0: torch._C.Value,
        input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
        """
        Adds the ONNX operations representing this PyTorch function by updating the
        graph g with `g.op()` calls.

        Args:
            g (Graph): graph to write the ONNX representation into.
            input_0 (Value): value representing the variables which contain
                the first input for this operator.
            input_1 (Value): value representing the variables which contain
                the second input for this operator.

        Returns:
            A Value or List of Values specifying the ONNX nodes that compute something
            equivalent to the original PyTorch operator with the given inputs.

            None if it cannot be converted to ONNX.
        """
        ...

The ``torch._C`` types are Python wrappers around the types defined in C++ in
`ir.h <https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/ir/ir.h>`_.

The process for adding a symbolic function depends on the type of operator.

.. _adding-support-aten:

ATen operators
^^^^^^^^^^^^^^

`ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library.
If the operator is an ATen operator (shows up in the TorchScript graph with the prefix
``aten::``), make sure it is not supported already.

List of supported operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Visit the auto generated :doc:`list of supported TorchScript operators <../onnx_torchscript_supported_aten_ops>`
for details on which operators are supported in each ``opset_version``.

Adding support for an aten or quantized operator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If the operator is not in the list above:

* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
  `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py>`_.
  Make sure the function has the same name as the ATen function, which may be declared in
  ``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
  build time, so will not appear in your checkout until you build PyTorch).
* By default, the first arg is the ONNX graph.
  Other arg names must EXACTLY match the names in the ``.pyi`` file,
  because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
  `ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
  we only need to create a node to represent the ONNX operator in the graph.
  If not, we can compose several standard operators that have the
  equivalent semantics to the ATen operator.

Here is an example of handling a missing symbolic function for the ``ELU`` operator.

If we run the following code::

    print(
        torch.jit.trace(
            torch.nn.ELU(),  # module
            torch.ones(1)    # example input
        ).graph
    )

We see something like::

    graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
          %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
      %4 : float = prim::Constant[value=1.]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=1]()
      %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
      return (%7)

Since we see ``aten::elu`` in the graph, we know this is an ATen operator.

We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.

We find a signature for ``elu`` in ``torch/nn/functional.pyi``::

    def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

We add the following lines to ``symbolic_opset9.py``::

    def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
        return g.op("Elu", input, alpha_f=alpha)

Now PyTorch is able to export models containing the ``aten::elu`` operator!

See the ``torch/onnx/symbolic_opset*.py`` files for more examples.

torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^

If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.

Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~

You can add a static method named ``symbolic`` to your function class. It should return
ONNX operators that represent the function's behavior in ONNX. For example::

    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
            return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

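When the traced graph contains this ``Function``, the exporter picks up its ``symbolic``
method automatically. A hedged usage sketch (the wrapper module is illustrative)::

    class MyReluModule(torch.nn.Module):
        def forward(self, x):
            return MyRelu.apply(x)

    torch.onnx.export(MyReluModule(), (torch.randn(3),), "my_relu.onnx")
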
.. FIXME(justinchuby): PythonOps are too complicated and the example below
..   uses private methods we do not expose. We are looking to
..   improve the experience. Since SymbolicContext is deprecated, we think
..   defining a symbolic staticmethod is a better way to go for now.

.. PythonOp Symbolic
.. ~~~~~~~~~~~~~~~~~

.. Alternatively, you can register a custom symbolic function.
.. This gives the symbolic function access to more info through the
.. ``torch.onnx.SymbolicContext`` object, which gets passed in as the first
.. argument (before the ``Graph`` object).

.. All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
.. In order to differentiate between different ``Function`` subclasses, the
.. symbolic function should use the ``name`` kwarg which gets set to the name of the class.

.. Custom symbolic functions should add type and shape information by calling ``setType(...)``
.. on Value objects before returning them (implemented in C++ by
.. ``torch::jit::Value::setType``). This is not required, but it can help the exporter's
.. shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
.. ``test_aten_embedding_2`` in
.. `test_operators.py <https://github.com/pytorch/pytorch/blob/main/test/onnx/test_operators.py>`_.

.. The example below shows how you can access ``requires_grad`` via the ``Node`` object:

..     class MyClip(torch.autograd.Function):
..         @staticmethod
..         def forward(ctx, input, min):
..             ctx.save_for_backward(input)
..             return input.clamp(min=min)

..     class MyRelu(torch.autograd.Function):
..         @staticmethod
..         def forward(ctx, input):
..             ctx.save_for_backward(input)
..             return input.clamp(min=0)

..     def symbolic_python_op(g: "GraphContext", *args, **kwargs):
..         n = ctx.cur_node
..         print("original node: ", n)
..         for i, out in enumerate(n.outputs()):
..             print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
..         import torch.onnx.symbolic_helper as sym_helper
..         for i, arg in enumerate(args):
..             requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
..             print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))

..         name = kwargs["name"]
..         ret = None
..         if name == "MyClip":
..             ret = g.op("Clip", args[0], args[1])
..         elif name == "MyRelu":
..             ret = g.op("Relu", args[0])
..         else:
..             # Logs a warning and returns None
..             return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
..         # Copy type and shape from original node.
..         ret.setType(n.type())
..         return ret

..     from torch.onnx import register_custom_op_symbolic
..     register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)

Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~

If a :class:`torch.autograd.Function` subclass does not provide a static symbolic method, and
no custom symbolic function has been registered for ``prim::PythonOp``,
:func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function`,
breaking the function down into the individual operators used within it.
The export should be successful as long as these individual operators are supported. For example::

    class MyLogExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(input)
            h = input.exp()
            return h.log().log()

There is no static symbolic method present for this model, yet it is exported as follows::

    graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
      %1 : float = onnx::Exp[](%input)
      %2 : float = onnx::Log[](%1)
      %3 : float = onnx::Log[](%2)
      return (%3)

If you need to avoid inlining of :class:`torch.autograd.Function`, you should export models with
``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``.

Custom operators
^^^^^^^^^^^^^^^^

You can export your model with custom operators that include a combination of many standard ONNX ops,
or that are driven by a self-defined C++ backend.

ONNX-script functions
~~~~~~~~~~~~~~~~~~~~~

If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize
`ONNX-script <https://github.com/microsoft/onnx-script>`_ to create an external ONNX function to support the operator.
You can export it by following this example::

    import onnxscript
    # There are three opset versions that need to be aligned.
    # This is (1) the opset version in the ONNX function.
    from onnxscript.onnx_opset import opset15 as op
    opset_version = 15

    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    model = torch.nn.SELU()

    custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

    @onnxscript.script(custom_opset)
    def Selu(X):
        alpha = 1.67326  # auto wrapped as Constants
        gamma = 1.0507
        alphaX = op.CastLike(alpha, X)
        gammaX = op.CastLike(gamma, X)
        neg = gammaX * (alphaX * op.Exp(X) - alphaX)
        pos = gammaX * X
        zero = op.CastLike(0, X)
        return op.Where(X <= zero, neg, pos)

    # setType API provides shape/type to ONNX shape/type inference
    def custom_selu(g: jit_utils.GraphContext, X):
        return g.onnxscript_op(Selu, X).setType(X.type())

    # Register the custom symbolic function.
    # This is (2) the opset version in the registry.
    torch.onnx.register_custom_op_symbolic(
        symbolic_name="aten::selu",
        symbolic_fn=custom_selu,
        opset_version=opset_version,
    )

    # This is (3) the opset version in the exporter.
    torch.onnx.export(
        model,
        x,
        "model.onnx",
        opset_version=opset_version,
        # only needed if you want to specify an opset version > 1.
        custom_opsets={"onnx-script": 2}
    )

The example above exports ``Selu`` as a custom operator in the "onnx-script" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.

NOTE: Be careful to align the three opset versions mentioned in the example above, and make sure
they are consistent at export time. ONNX-script is under active development, so this usage example
should be considered a beta; please follow the latest `ONNX-script <https://github.com/microsoft/onnx-script>`_
documentation.

C++ Operators
~~~~~~~~~~~~~

If a model uses a custom operator implemented in C++ as described in
`Extending TorchScript with Custom C++ Operators <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_,
you can export it by following this example::

    from torch.onnx import symbolic_helper


    # Define custom symbolic function
    @symbolic_helper.parse_args("v", "v", "f", "i")
    def symbolic_foo_forward(g, input1, input2, attr1, attr2):
        return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


    # Register custom symbolic function
    torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


    class FooModel(torch.nn.Module):
        def __init__(self, attr1, attr2):
            super().__init__()
            self.attr1 = attr1
            self.attr2 = attr2

        def forward(self, input1, input2):
            # Calling custom op
            return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


    model = FooModel(attr1, attr2)
    torch.onnx.export(
        model,
        (example_input1, example_input1),
        "model.onnx",
        # only needed if you want to specify an opset version > 1.
        custom_opsets={"custom_domain": 2}
    )

The example above exports it as a custom operator in the "custom_domain" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.

The runtime that consumes the model needs to support the custom op. See
`Caffe2 custom ops <https://caffe2.ai/docs/custom-operators.html>`_,
`ONNX Runtime custom ops <https://onnxruntime.ai/docs/reference/operators/add-custom-op.html>`_,
or your runtime of choice's documentation.

Discovering all unconvertible ATen ops at once
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When export fails due to an unconvertible ATen op, there may in fact be more
than one such op but the error message only mentions the first. To discover
all of the unconvertible ops in one go you can::

    # prepare model, args, opset_version
    ...

    torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
        model, args, opset_version=opset_version
    )

    print(set(unconvertible_ops))

The set is approximated because some ops may be removed during the conversion
process and don't need to be converted. Some other ops may have partial support
that will fail conversion with particular inputs, but this should give you a
general idea of what ops are not supported. Please feel free to open GitHub Issues
for op support requests.

Frequently Asked Questions
--------------------------

Q: I have exported my LSTM model, but its input size seems to be fixed?

The tracer records the shapes of the example inputs. If the model should accept
inputs of dynamic shapes, set ``dynamic_axes`` when calling :func:`torch.onnx.export`.

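A hedged sketch of ``dynamic_axes`` (the input/output names and LSTM sizes below are
illustrative, not part of this document)::

    model = torch.nn.LSTM(input_size=10, hidden_size=20)
    dummy = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
    torch.onnx.export(
        model,
        (dummy,),
        "lstm.onnx",
        input_names=["input"],
        output_names=["output"],
        # Mark seq_len and batch as dynamic so the exported model is not
        # pinned to the example shape (5, 3, 10).
        dynamic_axes={
            "input": {0: "seq_len", 1: "batch"},
            "output": {0: "seq_len", 1: "batch"},
        },
    )
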
Q: How to export models containing loops?

See `Tracing vs Scripting`_.

Q: How to export models with primitive type inputs (e.g. int, float)?

Support for primitive numeric type inputs was added in PyTorch 1.9.
However, the exporter does not support models with str inputs.

Q: Does ONNX support implicit scalar datatype casting?

The ONNX standard does not, but the exporter will try to handle that part.
Scalars are exported as constant tensors.
The exporter will figure out the right data type for scalars. In rare cases when it is unable
to do so, you will need to manually specify the datatype, e.g. ``dtype=torch.float32``.
If you see any errors, please `create a GitHub issue <https://github.com/pytorch/pytorch/issues>`_.

Q: Are lists of Tensors exportable to ONNX?

Yes, for ``opset_version`` >= 11, since ONNX introduced the Sequence type in opset 11.

Python API
----------

.. automodule:: torch.onnx

Functions
^^^^^^^^^

.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export
.. autofunction:: enable_log
.. autofunction:: disable_log
.. autofunction:: torch.onnx.verification.find_mismatch

Classes
^^^^^^^

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    JitScalarType
    torch.onnx.verification.GraphInfo
    torch.onnx.verification.VerificationOptions

@@ -5,7 +5,7 @@ ONNX supported TorchScript operators

.. This file is automatically generated during the documentation build
.. by cross referencing ONNX operator symbolics with TorchScript operators via
.. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
.. ``docs/source/scripts/build_onnx_torchscript_supported_aten_op_csv_table.py``.
.. Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.

This page lists the TorchScript operators that are supported/unsupported by ONNX export.

@@ -56,6 +56,8 @@ Some of the most commonly used backends include:

     - CUDA graphs with AOT Autograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
   * - ``torch.compile(m, backend="ipex")``
     - Uses IPEX on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
   * - ``torch.compile(m, backend="onnxrt")``
     - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more <onnx_dynamo_onnxruntime_backend>`

**Inference-only backends**

@@ -65,10 +67,10 @@ Some of the most commonly used backends include:

   * - Backend
     - Description
   * - ``torch.compile(m, backend="onnxrt")``
     - Uses ONNXRT for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
   * - ``torch.compile(m, backend="tensorrt")``
     - Uses ONNXRT to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
     - Uses ONNX Runtime to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
   * - ``torch.compile(m, backend="ipex")``
     - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
   * - ``torch.compile(m, backend="tvm")``
     - Uses Apache TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__

@@ -1,5 +1,3 @@
"""ONNX exporter."""

from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import (

@@ -269,13 +269,16 @@ class Invocation:

@dataclasses.dataclass
class DiagnosticOptions:
    """
    Options for diagnostic context.
    """Options for diagnostic context.

    Attributes:
        verbosity_level: Set the amount of information logged for each diagnostics,
            equivalent to the 'level' in Python logging module.
        warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
    """

    verbosity_level: int = dataclasses.field(default=logging.INFO)
    """Diagnostic context verbosity level, equivalent to the 'level' in Python logging module.
    Controls the amount of information logged inside each diagnostics."""
    """Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""

    warnings_as_errors: bool = dataclasses.field(default=False)
    """If True, warning diagnostics are treated as error diagnostics."""

@@ -69,8 +69,8 @@ else:

_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
through ``ExportOptions``. This should NEVER be accessed outside of this module! Users
should reference ``ExportOptions.opset_version``."""
through :class:`ExportOptions`. This should NEVER be accessed outside of this module! Users
should reference :attr:`ExportOptions.opset_version`."""

_PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
"""The URL to the PyTorch GitHub issues page."""

@@ -89,15 +89,15 @@ class ONNXFakeContext:
    """A dataclass used to store context for model export using FakeTensor.

    This dataclass stores the FakeTensorMode instance used to convert
    real tensors and model parameters into fake tensors. This ``fake_mode`` is
    reused internally during tracing of a ``torch.nn.Module`` into a FX ``GraphModule``.
    real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
    reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`.
    """

    fake_mode: fake_tensor.FakeTensorMode
    """The fake tensor mode used for tracing model using fake tensors and parameters."""

    state_dict_paths: Optional[Tuple[Union[str, io.BytesIO]]] = None
    """List of paths of files that contain the model `state_dict`"""
    """List of paths of files that contain the model :meth:`state_dict`"""


class OnnxRegistry:

@@ -271,18 +271,29 @@ class OnnxRegistry:


class ExportOptions:
    """Options to influence the TorchDynamo ONNX exporter."""
    """Options to influence the TorchDynamo ONNX exporter.

    Attributes:
        dynamic_shapes: Shape information hint for input/output tensors.
            When ``None``, the exporter determines the most compatible setting.
            When ``True``, all input shapes are considered dynamic.
            When ``False``, all input shapes are considered static.
        op_level_debug: Whether to export the model with op-level debug information
        diagnostic_options: The diagnostic options for the exporter.
        fake_context: The fake context used for symbolic tracing.
        onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
    """

    dynamic_shapes: Optional[bool] = None
    """Shape information hint for input/output tensors.

    - ``None``: the exporter determines the most compatible setting.
    - ``True``: all input shapes are considered dynamic.
    - ``False``: all input shapes are considered static."""
    - ``False``: all input shapes are considered static.
    """

    op_level_debug: Optional[bool] = None
    """Whether to export the model with op-level debug information by evaluating
    ops through ONNX Runtime."""
    """When True export the model with op-level debug running ops through ONNX Runtime."""

    diagnostic_options: DiagnosticOptions
    """The diagnostic options for the exporter."""
@@ -291,8 +302,7 @@ class ExportOptions:
    """The fake context used for symbolic tracing."""

    onnx_registry: Optional[OnnxRegistry] = None
    """The ONNX registry used to register ATen operators to ONNX functions. Defaults to
    opset18."""
    """The ONNX registry used to register ATen operators to ONNX functions."""

    @_beartype.beartype
    def __init__(

@@ -312,8 +322,8 @@ class ExportOptions:


class ResolvedExportOptions(ExportOptions):
    """Consolidates `ExportOptions` with default values.
    All unspecified options from `ExportOptions` are assigned a default value.
    """Consolidates :class:`ExportOptions` with default values.
    All unspecified options from :class:`ExportOptions` are assigned a default value.
    This is an internal class and its API may be changed at any time without notice.
    """

@@ -412,12 +422,12 @@ class ResolvedExportOptions(ExportOptions):
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

    Internally it instantiates a `FakeTensorMode` context manager that converts
    user input and model parameters into `FakeTensor`.
    Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
    that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.

    A [FakeTensor](https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/fake_tensor.py#L870)
    is a `torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a `meta` device. Because
    A :class:`torch._subclasses.fake_tensor.FakeTensor`
    is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a ``meta`` device. Because
    there is no actual data being allocated on the device, this API allows for
    exporting large models without the actual memory footprint needed for executing it.

@@ -425,8 +435,8 @@ def enable_fake_mode():
    are too large to fit into memory.

    Returns:
        A `ONNXFakeContext` object that must be passed to `torch.onnx.dynamo_export`
        through the `ExportOptions.fake_context` argument.
        A :class:`ONNXFakeContext` object that must be passed to :func:`dynamo_export`
        through the :attr:`ExportOptions.fake_context` argument.

    Example::

@@ -499,7 +509,7 @@ class ExportOutputSerializer(Protocol):

    Example:

        A simple serializer that writes the exported ``onnx.ModelProto`` in Protobuf
        A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf
        format to ``destination``:

        ::

@@ -604,7 +614,7 @@ class ExportOutput:

    @property
    def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
        """The exported ONNX model as an ``onnx.ModelProto``."""
        """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""

        if self._export_exception is not None:
            raise self._export_exception

@@ -746,8 +756,8 @@ class ExportOutput:
            will be created to store each initializer of the ONNX model in a separate file. For example, if the
            destination is "/path/model.onnx", the initializers will be saved in "/path/model_initializers/" folder.
        model_state_dict: The state_dict of the PyTorch model containing all weights on it.
            It can be either a dict as returned by `model.state_dict()`, or a string with a file name.
            Required when ``enable_fake_mode`` is used but real initializers are needed on the ONNX graph.
            It can be either a dict as returned by :meth:`model.state_dict`, or a string with a file name.
            Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
            It can be either a string with the path to a checkpoint or a dictionary with the actual model state.

        serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.

@@ -792,7 +802,7 @@ class ExportOutput:
        if _model_state_dict_files:
            if not isinstance(destination, str):
                raise RuntimeError(
                    "`destination` must be a string with a path when model_state_dict is specified."
                    "`destination` must be a string with a path when `model_state_dict` is specified."
                )
            destination_path, destination_filename = os.path.split(destination)
            onnx_model_location = destination_filename

@@ -845,10 +855,10 @@ class ExportOutput:
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> Self:
        """
        Creates an instance of ``ExportOutput`` when the export process encounters a failure.
        Creates an instance of :class:`ExportOutput` when the export process encounters a failure.

        In case of a failed export, this method is used to encapsulate the exception
        and associated diagnostic context within an ``ExportOutput`` instance for
        and associated diagnostic context within an :class:`ExportOutput` instance for
        easier handling and debugging.

        Args:
@@ -856,7 +866,7 @@ class ExportOutput:
            diagnostic_context: The context associated with diagnostics during export.

        Returns:
            An instance of ``ExportOutput`` representing the failed export output.
            An instance of :class:`ExportOutput` representing the failed export output.
        """
        # Defer `import onnx` out of `import torch` path
        # https://github.com/pytorch/pytorch/issues/103764

@@ -1034,7 +1044,7 @@ class OnnxExporterError(RuntimeError):
    """Raised when an ONNX exporter error occurs.

    This exception is thrown when there's an error during the ONNX export process.
    It encapsulates the `ExportOutput` object generated until the failure, allowing
    It encapsulates the :class:`ExportOutput` object generated until the failure, allowing
    access to the partial export results and associated metadata.
    """

@@ -1113,19 +1123,59 @@ def dynamo_export(
    Returns:
        An in-memory representation of the exported ONNX model.

    Example:
    **Example 1 - Simplest export**
    ::

        import torch.onnx
        torch.onnx.dynamo_export(
            my_nn_module,
            torch.randn(2, 2, 2),  # positional input 1
            torch.randn(2, 2, 2),  # positional input 2
            my_nn_module_attribute="hello",  # keyword input
            export_options=ExportOptions(
                dynamic_shapes=True,
            )
        ).save("my_model.onnx")
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)
            def forward(self, x, bias=None):
                out = self.linear(x)
                out = out + bias
                return out
        model = MyModel()
        kwargs = {"bias": 3.}
        args = (torch.randn(2, 2, 2),)
        export_output = torch.onnx.dynamo_export(
            model,
            *args,
            **kwargs).save("my_simple_model.onnx")

    **Example 2 - Exporting with dynamic shapes**
    ::

        # The previous model can be exported with dynamic shapes
        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
        export_output = torch.onnx.dynamo_export(
            model,
            *args,
            **kwargs,
            export_options=export_options)
        export_output.save("my_dynamic_model.onnx")


    By printing input dynamic dimensions we can see the input shape is no longer (2,2,2)
    ::

        >>> print(export_output.model_proto.graph.input[0])
        name: "arg0"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
                dim_param: "arg0_dim_0"
              }
              dim {
                dim_param: "arg0_dim_1"
              }
              dim {
                dim_param: "arg0_dim_2"
              }
            }
          }
        }
    """

    resolved_export_options = (

@@ -1641,7 +1641,7 @@ class InsertTypePromotion(_pass.Transform):
    metadata, specifically the fake tensor stored under node.meta["val"], and ensure it
    reflects the latest changes.

    See [FXE0015: fx_node_insert_type_promotion](https://pytorch.org/docs/master/generated/onnx_diagnostics_rules/FXE0015%3Afx-node-insert-type-promotion.html) for more details.  # noqa: B950
    See [FXE0015: fx_node_insert_type_promotion](https://pytorch.org/docs/master/generated/onnx_dynamo_diagnostics_rules/FXE0015%3Afx-node-insert-type-promotion.html) for more details.  # noqa: B950
    """

    def __init__(