Refactor torch.onnx documentation (#108379)

* Distinguish between the TorchScript-based exporter (`torch.onnx.export`) and the TorchDynamo-based exporter (`torch.onnx.dynamo_export`)
* Merge the ONNX diagnostics page with the exporter page
* Add an initial version of a quick overview of the new exporter
* Update `torch.compiler.html` to point to the right page for the ONNX Runtime backend for `torch.compile`
* Rename doc files to clearly identify those belonging to the legacy and the newer ONNX exporters

Fixes #108274

https://docs-preview.pytorch.org/pytorch/pytorch/108379/index.html
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108379
Approved by: https://github.com/justinchuby, https://github.com/wschin, https://github.com/malfet
Thiago Crepaldi
2023-09-08 18:23:45 +00:00
committed by PyTorch MergeBot
parent e91f66471c
commit aa3355da8a
19 changed files with 1029 additions and 801 deletions

View File

@ -7,6 +7,7 @@
- docs/source/onnx.rst
- docs/source/onnx*
- docs/source/scripts/onnx/**
- docs/source/_static/img/onnx/**
- scripts/onnx/**
- test/onnx/**
- tools/onnx/**

View File

@ -18,8 +18,8 @@ figures:
@$(PYCMD) source/scripts/build_quantization_configs.py
onnx:
@$(PYCMD) source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py
@$(PYCMD) source/scripts/onnx/build_onnx_diagnostics_rules_md.py $(SOURCEDIR)/generated/onnx_diagnostics_rules
@$(PYCMD) source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py
@$(PYCMD) source/scripts/onnx/build_onnx_dynamo_diagnostics_rules_md.py $(SOURCEDIR)/generated/onnx_dynamo_diagnostics_rules
opset:
@$(PYCMD) source/scripts/build_opsets.py

Binary files not shown: three new images added under docs/source/_static/img/onnx/ (36 KiB, 11 KiB, and 5.7 KiB).

View File

@ -94,7 +94,6 @@ Features described in this documentation are classified by release status:
profiler
nn.init
onnx
onnx_diagnostics
optim
complex_numbers
ddp_comm_hooks

View File

@ -1,745 +1,64 @@
torch.onnx
==========
.. contents:: :local:
.. automodule:: torch.onnx
Overview
--------
`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
format for representing machine learning models. The torch.onnx module can export
PyTorch models to ONNX. The model can then be consumed by any of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.
format for representing machine learning models. The ``torch.onnx`` module captures the computation graph from a
native PyTorch :class:`torch.nn.Module` model and converts it into an
`ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
Example: AlexNet from PyTorch to ONNX
-------------------------------------
The exported model can be consumed by any of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_, including
Microsoft's `ONNX Runtime <https://www.onnxruntime.ai>`_.
Here is a simple script which exports a pretrained AlexNet to an ONNX file named ``alexnet.onnx``.
The call to ``torch.onnx.export`` runs the model once to trace its execution and then exports the
traced model to the specified file::
**There are two flavors of ONNX exporter API that you can use, as listed below:**
import torch
import torchvision
TorchDynamo-based ONNX Exporter
-------------------------------
dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()
*The TorchDynamo-based ONNX exporter is the newest (and Beta) exporter for PyTorch 2.0 and newer*
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
bytecode into an FX Graph. The resulting FX Graph is then polished before it is finally translated into an
ONNX graph.
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
The resulting ``alexnet.onnx`` file contains a binary `protocol buffer <https://developers.google.com/protocol-buffers/>`_
which contains both the network structure and parameters of the model you exported
(in this case, AlexNet). The argument ``verbose=True`` causes the
exporter to print out a human-readable representation of the model::
:doc:`Learn more about the TorchDynamo-based ONNX Exporter <onnx_dynamo>`
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
TorchScript-based ONNX Exporter
-------------------------------
You can also verify the output using the `ONNX <https://github.com/onnx/onnx/>`_ library,
which you can install using ``pip``::
*The TorchScript-based ONNX exporter has been available since PyTorch 1.2.0*
pip install onnx
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ is leveraged to trace (through :func:`torch.jit.trace`)
the model and capture a static computation graph.
Then, you can run::
As a consequence, the resulting graph has a couple of limitations:
import onnx
* It does not record any control-flow, like if-statements or loops;
* It does not handle nuances between ``training`` and ``eval`` mode;
* It does not truly handle dynamic inputs
# Load the ONNX model
model = onnx.load("alexnet.onnx")
To mitigate the limitations of static tracing, the exporter also supports TorchScript scripting
(through :func:`torch.jit.script`), which adds support for data-dependent control-flow, for example. However, TorchScript
itself is a subset of the Python language, so not all Python features are supported, such as in-place operations.
# Check that the model is well formed
onnx.checker.check_model(model)
:doc:`Learn more about the TorchScript-based ONNX Exporter <onnx_torchscript>`
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
You can also run the exported model with one of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.
For example after installing `ONNX Runtime <https://www.onnxruntime.ai>`_, you can
load and run the model::
import onnxruntime as ort
import numpy as np
ort_session = ort.InferenceSession("alexnet.onnx")
outputs = ort_session.run(
None,
{"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])
Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime <https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html>`_.
.. _tracing-vs-scripting:
Tracing vs Scripting
--------------------
Internally, :func:`torch.onnx.export()` requires a :class:`torch.jit.ScriptModule` rather than
a :class:`torch.nn.Module`. If the passed-in model is not already a ``ScriptModule``,
``export()`` will use *tracing* to convert it to one:
.. TODO(justinchuby): Add a word on recommending tracing over scripting for most use cases.
* **Tracing**: If ``torch.onnx.export()`` is called with a Module that is not already a
``ScriptModule``, it first does the equivalent of :func:`torch.jit.trace`, which executes the model
once with the given ``args`` and records all operations that happen during that execution. This
means that if your model is dynamic, e.g., changes behavior depending on input data, the exported
model will *not* capture this dynamic behavior.
We recommend examining the exported model and making sure the operators look
reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly
the same as the traced run. If you want to export your model with dynamic control flow, you will
need to use *scripting*.
* **Scripting**: Compiling a model via scripting preserves dynamic control flow and is valid for inputs
of different sizes. To use scripting:
* Use :func:`torch.jit.script` to produce a ``ScriptModule``.
* Call ``torch.onnx.export()`` with the ``ScriptModule`` as the model. The ``args`` are still required,
but they will be used internally only to produce example outputs, so that the types and shapes of the
outputs can be captured. No tracing will be performed.
See `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_
and `TorchScript <jit.html>`_ for more details, including how to compose tracing and scripting to suit the
particular requirements of different models.
Avoiding Pitfalls
-----------------
Avoid NumPy and built-in Python types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PyTorch models can be written using NumPy or Python types and functions, but
during :ref:`tracing<tracing-vs-scripting>`, any variables of NumPy or Python
types (rather than torch.Tensor) are converted to constants, which will produce
the wrong result if those values should change depending on the inputs.
For example, rather than using numpy functions on numpy.ndarrays: ::
# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)
Use torch operators on torch.Tensors: ::
# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)
And rather than use :func:`torch.Tensor.item` (which converts a Tensor to a Python
built-in number): ::
# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
return x.reshape(y.item(), -1)
Use torch's support for implicit casting of single-element tensors: ::
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
return x.reshape(y, -1)
Avoid Tensor.data
^^^^^^^^^^^^^^^^^
Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph.
Use :func:`torch.Tensor.detach` instead. (Work is ongoing to
`remove Tensor.data entirely <https://github.com/pytorch/pytorch/issues/30987>`_).
Avoid in-place operations when using tensor.shape in tracing mode
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In tracing mode, shapes obtained from ``tensor.shape`` are traced as tensors,
and share the same memory. This might cause a mismatch in the final output values.
As a workaround, avoid the use of inplace operations in these scenarios.
For example, in the model::
class Model(torch.nn.Module):
def forward(self, states):
batch_size, seq_length = states.shape[:2]
real_seq_length = seq_length
real_seq_length += 2
return real_seq_length + seq_length
``real_seq_length`` and ``seq_length`` share the same memory in tracing mode.
This could be avoided by rewriting the inplace operation::
real_seq_length = real_seq_length + 2
Limitations
-----------
Types
^^^^^
* Only :class:`torch.Tensors`, numeric types that can be trivially converted to torch.Tensors (e.g. float, int),
and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and
outputs are accepted in :ref:`tracing<tracing-vs-scripting>` mode, but:
* Any computation that depends on the value of a dict or a str input **will be replaced with the
constant value** seen during the one traced execution.
* Any output that is a dict will be silently replaced with a **flattened sequence of its values
(keys will be removed)**. E.g. ``{"foo": 1, "bar": 2}`` becomes ``(1, 2)``.
* Any output that is a str will be silently removed.
* Certain operations involving tuples and lists are not supported in
:ref:`scripting<tracing-vs-scripting>` mode due to limited support in ONNX for nested sequences.
In particular appending a tuple to a list is not supported. In tracing mode, the nested sequences
will be flattened automatically during the tracing.
Differences in Operator Implementations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Due to differences in implementations of operators, running the exported model on different runtimes
may produce different results from each other or from PyTorch. Normally these differences are
numerically small, so this should only be a concern if your application is sensitive to these
small differences.
.. _tensor-indexing:
Unsupported Tensor Indexing Patterns
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Tensor indexing patterns that cannot be exported are listed below.
If you are experiencing issues exporting a model that does not include any of
the unsupported patterns below, please double check that you are exporting with
the latest ``opset_version``.
Reads / Gets
~~~~~~~~~~~~
When indexing into a tensor for reading, the following patterns are not supported: ::
# Tensor indices that include negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.
Writes / Sets
~~~~~~~~~~~~~
When indexing into a Tensor for writing, the following patterns are not supported: ::
# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
# or multiple consecutive tensor indices with rank == 1.
# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.
# Tensor indices that include negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.
# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
# data shape: [3, 4, 5]
# new_data shape: [5]
# expected new_data shape after broadcasting: [2, 2, 2, 5]
Adding support for operators
----------------------------
When exporting a model that includes unsupported operators, you'll see an error message like:
.. code-block:: text
RuntimeError: ONNX export failed: Couldn't export operator foo
When that happens, there are a few things you can do:
#. Change the model to not use that operator.
#. Create a symbolic function to convert the operator and register it as a custom symbolic function.
#. Contribute to PyTorch to add the same symbolic function to :mod:`torch.onnx` itself.
If you decided to implement a symbolic function (we hope you will contribute it back to PyTorch!), here is how you can get started:
ONNX exporter internals
^^^^^^^^^^^^^^^^^^^^^^^
A "symbolic function" is a function that decomposes a PyTorch operator into a
composition of a series of ONNX operators.
During export, each node (which contains a PyTorch operator) in the TorchScript
graph is visited by the exporter in topological order.
Upon visiting a node, the exporter looks for a registered symbolic function for
that operator. Symbolic functions are implemented in Python. A symbolic function for
an op named ``foo`` would look something like::
def foo(
g,
input_0: torch._C.Value,
input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
"""
Adds the ONNX operations representing this PyTorch function by updating the
graph g with `g.op()` calls.
Args:
g (Graph): graph to write the ONNX representation into.
input_0 (Value): value representing the variables which contain
the first input for this operator.
input_1 (Value): value representing the variables which contain
the second input for this operator.
Returns:
A Value or List of Values specifying the ONNX nodes that compute something
equivalent to the original PyTorch operator with the given inputs.
None if it cannot be converted to ONNX.
"""
...
The ``torch._C`` types are Python wrappers around the types defined in C++ in
`ir.h <https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/ir/ir.h>`_.
The process for adding a symbolic function depends on the type of operator.
.. _adding-support-aten:
ATen operators
^^^^^^^^^^^^^^
`ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library.
If the operator is an ATen operator (shows up in the TorchScript graph with the prefix
``aten::``), make sure it is not supported already.
List of supported operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Visit the auto generated :doc:`list of supported TorchScript operators <../onnx_supported_aten_ops>`
for details on which operators are supported in each ``opset_version``.
Adding support for an aten or quantized operator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If the operator is not in the list above:
* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
`torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py>`_.
Make sure the function has the same name as the ATen function, which may be declared in
``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
build time, so will not appear in your checkout until you build PyTorch).
* By default, the first arg is the ONNX graph.
Other arg names must EXACTLY match the names in the ``.pyi`` file,
because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
`ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
we only need to create a node to represent the ONNX operator in the graph.
If not, we can compose several standard operators that have the
equivalent semantics to the ATen operator.
Here is an example of handling a missing symbolic function for the ``ELU`` operator.
If we run the following code::
print(
torch.jit.trace(
torch.nn.ELU(), # module
torch.ones(1) # example input
).graph
)
We see something like::
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%4 : float = prim::Constant[value=1.]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
return (%7)
Since we see ``aten::elu`` in the graph, we know this is an ATen operator.
We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.
We find a signature for ``elu`` in ``torch/nn/functional.pyi``::
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
We add the following lines to ``symbolic_opset9.py``::
def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
return g.op("Elu", input, alpha_f=alpha)
Now PyTorch is able to export models containing the ``aten::elu`` operator!
See the ``torch/onnx/symbolic_opset*.py`` files for more examples.
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^
If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.
Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~
You can add a static method named ``symbolic`` to your function class. It should return
ONNX operators that represent the function's behavior in ONNX. For example::
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
.. FIXME(justinchuby): PythonOps are too complicated and the example below
.. uses private methods we do not expose. We are looking to
.. improve the experience. Since SymbolicContext is deprecated, we think
.. defining a symbolic staticmethod is a better way to go for now.
.. PythonOp Symbolic
.. ~~~~~~~~~~~~~~~~~
.. Alternatively, you can register a custom symbolic function.
.. This gives the symbolic function access to more info through the
.. ``torch.onnx.SymbolicContext`` object, which gets passed in as the first
.. argument (before the ``Graph`` object).
.. All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
.. In order to differentiate between different ``Function`` subclasses, the
.. symbolic function should use the ``name`` kwarg which gets set to the name of the class.
.. Custom symbolic functions should add type and shape information by calling ``setType(...)``
.. on Value objects before returning them (implemented in C++ by
.. . ``torch::jit::Value::setType``). This is not required, but it can help the exporter's
.. shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
.. ``test_aten_embedding_2`` in
.. `test_operators.py <https://github.com/pytorch/pytorch/blob/main/test/onnx/test_operators.py>`_.
.. The example below shows how you can access ``requires_grad`` via the ``Node`` object:
.. class MyClip(torch.autograd.Function):
.. @staticmethod
.. def forward(ctx, input, min):
.. ctx.save_for_backward(input)
.. return input.clamp(min=min)
.. class MyRelu(torch.autograd.Function):
.. @staticmethod
.. def forward(ctx, input):
.. ctx.save_for_backward(input)
.. return input.clamp(min=0)
.. def symbolic_python_op(g: "GraphContext", *args, **kwargs):
.. n = ctx.cur_node
.. print("original node: ", n)
.. for i, out in enumerate(n.outputs()):
.. print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
.. import torch.onnx.symbolic_helper as sym_helper
.. for i, arg in enumerate(args):
.. requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
.. print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
.. name = kwargs["name"]
.. ret = None
.. if name == "MyClip":
.. ret = g.op("Clip", args[0], args[1])
.. elif name == "MyRelu":
.. ret = g.op("Relu", args[0])
.. else:
.. # Logs a warning and returns None
.. return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
.. # Copy type and shape from original node.
.. ret.setType(n.type())
.. return ret
.. from torch.onnx import register_custom_op_symbolic
.. . register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~
In cases where a static symbolic method is not provided for a :class:`torch.autograd.Function`, and
no custom symbolic function is registered for ``prim::PythonOp``,
:func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function`, so that
the function is broken down into the individual operators that were used within it.
The export should be successful as long as these individual operators are supported. For example::
class MyLogExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
h = input.exp()
return h.log().log()
There is no static symbolic method present for this model, yet it is exported as follows::
graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%1 : float = onnx::Exp[](%input)
%2 : float = onnx::Log[](%1)
%3 : float = onnx::Log[](%2)
return (%3)
If you need to avoid inlining of :class:`torch.autograd.Function`, you should export models with
``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``.
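For instance, a minimal sketch of the latter (assuming ``model`` is a :class:`torch.nn.Module` that calls ``MyLogExp.apply`` and ``x`` is an example input)::

    torch.onnx.export(
        model,
        x,
        "my_log_exp.onnx",
        # Avoid inlining the autograd Function; unsupported ops fall through as-is.
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
    )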
Custom operators
^^^^^^^^^^^^^^^^
You can export your model with custom operators that include a combination of many standard ONNX ops,
or that are driven by a self-defined C++ backend.
ONNX-script functions
~~~~~~~~~~~~~~~~~~~~~
If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize
`ONNX-script <https://github.com/microsoft/onnx-script>`_ to create an external ONNX function to support the operator.
You can export it by following this example::
import onnxscript
# There are three opset versions that need to be aligned.
# This is (1) the opset version in the ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15
x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
@onnxscript.script(custom_opset)
def Selu(X):
alpha = 1.67326 # auto wrapped as Constants
gamma = 1.0507
alphaX = op.CastLike(alpha, X)
gammaX = op.CastLike(gamma, X)
neg = gammaX * (alphaX * op.Exp(X) - alphaX)
pos = gammaX * X
zero = op.CastLike(0, X)
return op.Where(X <= zero, neg, pos)
# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
return g.onnxscript_op(Selu, X).setType(X.type())
# Register custom symbolic function
# There are three opset versions that need to be aligned.
# This is (2) the opset version in the registry
torch.onnx.register_custom_op_symbolic(
symbolic_name="aten::selu",
symbolic_fn=custom_selu,
opset_version=opset_version,
)
# There are three opset versions that need to be aligned.
# This is (3) the opset version in the exporter
torch.onnx.export(
model,
x,
"model.onnx",
opset_version=opset_version,
# only needed if you want to specify an opset version > 1.
custom_opsets={"onnx-script": 2}
)
The example above exports it as a custom operator in the "onnx-script" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.
NOTE: Be careful to align the opset versions mentioned in the example above, and make sure they are consumed during the export step.
The example usage of how to write an onnx-script function is a beta version, given the active development of onnx-script.
Please follow the latest `ONNX-script <https://github.com/microsoft/onnx-script>`_ documentation.
C++ Operators
~~~~~~~~~~~~~
If a model uses a custom operator implemented in C++ as described in
`Extending TorchScript with Custom C++ Operators <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_,
you can export it by following this example::
from torch.onnx import symbolic_helper
# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)
class FooModel(torch.nn.Module):
def __init__(self, attr1, attr2):
super().__init__()
self.attr1 = attr1
self.attr2 = attr2
def forward(self, input1, input2):
# Calling custom op
return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)
model = FooModel(attr1, attr2)
torch.onnx.export(
model,
(example_input1, example_input2),
"model.onnx",
# only needed if you want to specify an opset version > 1.
custom_opsets={"custom_domain": 2}
)
The example above exports it as a custom operator in the "custom_domain" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.
The runtime that consumes the model needs to support the custom op. See
`Caffe2 custom ops <https://caffe2.ai/docs/custom-operators.html>`_,
`ONNX Runtime custom ops <https://onnxruntime.ai/docs/reference/operators/add-custom-op.html>`_,
or your runtime of choice's documentation.
Discovering all unconvertible ATen ops at once
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When export fails due to an unconvertible ATen op, there may in fact be more
than one such op but the error message only mentions the first. To discover
all of the unconvertible ops in one go, you can::
# prepare model, args, opset_version
...
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
model, args, opset_version=opset_version
)
print(set(unconvertible_ops))
The set is approximate because some ops may be removed during the conversion
process and don't need to be converted. Some other ops may have partial support
that will fail conversion with particular inputs, but this should give you a
general idea of what ops are not supported. Please feel free to open GitHub Issues
for op support requests.
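As a concrete sketch with a toy model (any :class:`torch.nn.Module` and matching example inputs work here; the model, inputs, and opset version are illustrative)::

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    args = (torch.randn(1, 4),)
    graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
        model, args, opset_version=17
    )
    print(set(unconvertible_ops))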
Frequently Asked Questions
--------------------------
Q: I have exported my LSTM model, but its input size seems to be fixed?
The tracer records the shapes of the example inputs. If the model should accept
inputs of dynamic shapes, set ``dynamic_axes`` when calling :func:`torch.onnx.export`.
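For example, a minimal sketch (``model`` and ``dummy_input`` are placeholders for your own model and example input)::

    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        # Mark the batch (and sequence) dimensions as dynamic so the exported
        # model accepts inputs of varying sizes along those axes.
        dynamic_axes={
            "input": {0: "batch", 1: "sequence"},
            "output": {0: "batch"},
        },
    )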
Q: How to export models containing loops?
See `Tracing vs Scripting`_.
Q: How to export models with primitive type inputs (e.g. int, float)?
Support for primitive numeric type inputs was added in PyTorch 1.9.
However, the exporter does not support models with str inputs.
Q: Does ONNX support implicit scalar datatype casting?
The ONNX standard does not, but the exporter will try to handle that part.
Scalars are exported as constant tensors.
The exporter will figure out the right data type for scalars. In rare cases when it is unable
to do so, you will need to manually specify the datatype with e.g. `dtype=torch.float32`.
If you see any errors, please `create a GitHub issue <https://github.com/pytorch/pytorch/issues>`_.
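For example, a minimal sketch of pinning a scalar's dtype explicitly (the module and constant are illustrative)::

    class AddScale(torch.nn.Module):
        def forward(self, x):
            # Wrap the Python scalar as a typed tensor so the exporter
            # does not have to infer its dtype.
            scale = torch.tensor(0.5, dtype=torch.float32)
            return x * scale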
Q: Are lists of Tensors exportable to ONNX?
Yes, for ``opset_version`` >= 11, since ONNX introduced the Sequence type in opset 11.
Contributing / developing
Contributing / Developing
-------------------------
`Developer docs <https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter>`_.
Functions
---------
.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export
.. autofunction:: enable_log
.. autofunction:: disable_log
.. autofunction:: torch.onnx.verification.find_mismatch
The ONNX exporter is a community project and we welcome contributions. We follow the
`PyTorch guidelines for contributions <https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md>`_, but you might
also be interested in reading our `development wiki <https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter>`_.
Classes
-------
.. toctree::
:hidden:
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
JitScalarType
torch.onnx.verification.GraphInfo
torch.onnx.verification.VerificationOptions
Preview: torch.onnx TorchDynamo Exporter
----------------------------------------
.. warning::
The ONNX exporter for TorchDynamo is under active development and is
subject to rapid change.
.. autofunction:: torch.onnx.dynamo_export
.. autofunction:: torch.onnx.enable_fake_mode
.. autofunction:: torch.onnx.is_onnxrt_backend_supported
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
torch.onnx.DiagnosticOptions
torch.onnx.ExportOptions
torch.onnx.ExportOutput
torch.onnx.ExportOutputSerializer
torch.onnx.OnnxExporterError
torch.onnx.OnnxRegistry
onnx_dynamo
onnx_dynamo_onnxruntime_backend
onnx_torchscript

View File

@ -1,26 +0,0 @@
torch.onnx diagnostics
======================
.. contents:: :local:
.. automodule:: torch.onnx._internal.diagnostics
.. currentmodule:: torch.onnx._internal.diagnostics
Overview
--------
NOTE: This feature is under development and is subject to change.
The goal is to improve the diagnostics to help users debug and improve their model export to ONNX.
- The diagnostics are emitted in machine parsable `Static Analysis Results Interchange Format (SARIF) <https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html>`__.
- A new, clearer, structured way to add and keep track of diagnostic rules.
- A foundation for future improvements that consume the diagnostics.
Diagnostic Rules
----------------
.. toctree::
:glob:
generated/onnx_diagnostics_rules/*

docs/source/onnx_dynamo.rst (new file, 156 lines)
View File

@ -0,0 +1,156 @@
TorchDynamo-based ONNX Exporter
===============================
.. automodule:: torch.onnx
:noindex:
.. contents:: :local:
:depth: 3
.. warning::
The ONNX exporter for TorchDynamo is under active development and is subject to rapid change.
Overview
--------
The ONNX exporter leverages the TorchDynamo engine to hook into Python's frame evaluation API
and dynamically rewrite its bytecode into an FX Graph.
The resulting FX Graph is then polished before it is finally translated into an ONNX graph.
The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
The exporter is designed to be modular and extensible. It is composed of the following components:
- **ONNX Exporter**: :class:`Exporter` main class that orchestrates the export process.
- **ONNX Export Options**: :class:`ExportOptions` has a set of options that control the export process.
- **ONNX Registry**: :class:`OnnxRegistry` is the registry of ONNX operators and functions.
- **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model.
- **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models.
- **ONNX Export Output**: :class:`ExportOutput` is the output of the exporter that contains the exported ONNX graph and diagnostics.
- **ONNX Export Output Serializer**: :class:`ExportOutputSerializer` serializes the exported model to a file.
- **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter.
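A rough sketch of how these pieces fit together is shown below; the toy model, example input, and use of ``dynamic_shapes`` are illustrative only, and the API may change while it is in Beta.

.. code-block:: python

    import torch

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return x.relu()

    model = ToyModel()
    x = torch.randn(2, 3)

    # ExportOptions controls the export process; here dynamic shapes are requested.
    export_options = torch.onnx.ExportOptions(dynamic_shapes=True)

    # The exporter returns an ExportOutput with the ONNX graph and diagnostics.
    export_output = torch.onnx.dynamo_export(model, x, export_options=export_options)

    # ExportOutput.save serializes the model to a Protobuf file.
    export_output.save("toy_model.onnx")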
Dependencies
------------
The ONNX exporter depends on extra Python packages:
- `ONNX <https://onnx.ai>`_
- `ONNX Script <https://onnxscript.ai>`_
They can be installed through `pip <https://pypi.org/project/pip/>`_:
.. code-block:: bash
pip install --upgrade onnx onnxscript
A simple example
----------------
Below is a demonstration of the exporter API in action with a simple Multilayer Perceptron (MLP) as an example:
.. code-block:: python
import torch
import torch.nn as nn
class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.fc0 = nn.Linear(8, 8, bias=True)
self.fc1 = nn.Linear(8, 4, bias=True)
self.fc2 = nn.Linear(4, 2, bias=True)
self.fc3 = nn.Linear(2, 2, bias=True)
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.fc0(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
tensor_x = self.fc1(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
tensor_x = self.fc2(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
output = self.fc3(tensor_x)
return output
model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
export_output = torch.onnx.dynamo_export(model, tensor_x)
As the code above shows, all you need is to provide :func:`torch.onnx.dynamo_export` with an instance of the model and its input.
The exporter will then return an instance of :class:`torch.onnx.ExportOutput` that contains the exported ONNX graph along with extra information.
The in-memory model available through ``export_output.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
The ONNX model may then be serialized into a `Protobuf file <https://protobuf.dev/>`_ using the :meth:`torch.onnx.ExportOutput.save` API.
.. code-block:: python
export_output.save("mlp.onnx")
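The saved model can then be loaded and run with any runtime that supports ONNX. A minimal sketch using ONNX Runtime follows, assuming ``onnxruntime`` is installed; the graph input name is looked up at runtime because the exporter chooses its own input names.

.. code-block:: python

    import onnxruntime as ort

    ort_session = ort.InferenceSession("mlp.onnx")

    # Reuse the example input from above, keyed by the graph's input name.
    input_name = ort_session.get_inputs()[0].name
    onnx_outputs = ort_session.run(None, {input_name: tensor_x.numpy()})
    print(onnx_outputs[0])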
Inspecting the ONNX model using GUI
-----------------------------------
You can view the exported model using `Netron <https://netron.app/>`__.
.. image:: _static/img/onnx/onnx_dynamo_mlp_model.png
:width: 40%
:alt: MLP model as viewed using Netron
Note that each layer is represented in a rectangular box with an *f* icon in the top-right corner.
.. image:: _static/img/onnx/onnx_dynamo_mlp_model_function_highlight.png
:width: 40%
:alt: ONNX function highlighted on MLP model
Expanding it shows the function body.
.. image:: _static/img/onnx/onnx_dynamo_mlp_model_function_body.png
:width: 50%
:alt: ONNX function body
The function body is a sequence of ONNX operators or other functions.
Diagnosing issues with SARIF
----------------------------
ONNX diagnostics goes beyond regular logs through the adoption of
`Static Analysis Results Interchange Format (aka SARIF) <https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html>`__
to help users debug and improve their model using a GUI, such as
Visual Studio Code's `SARIF Viewer <https://marketplace.visualstudio.com/items?itemName=MS-SarifVSCode.sarif-viewer>`_.
The main advantages are:
- The diagnostics are emitted in machine parseable `Static Analysis Results Interchange Format (SARIF) <https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html>`__.
- A new, clearer, structured way to add and keep track of diagnostic rules.
- A foundation for future improvements that consume the diagnostics.
.. toctree::
:maxdepth: 1
:caption: ONNX Diagnostic SARIF Rules
:glob:
generated/onnx_dynamo_diagnostics_rules/*
API Reference
-------------
.. autofunction:: torch.onnx.dynamo_export
.. autoclass:: torch.onnx.ExportOptions
:members:
.. autofunction:: torch.onnx.enable_fake_mode
.. autoclass:: torch.onnx.ExportOutput
:members:
.. autoclass:: torch.onnx.ExportOutputSerializer
:members:
.. autoclass:: torch.onnx.OnnxExporterError
:members:
.. autoclass:: torch.onnx.OnnxRegistry
:members:
.. autoclass:: torch.onnx.DiagnosticOptions
:members:

View File

@ -0,0 +1,10 @@
ONNX Backend for TorchDynamo
============================
For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
.. warning::
The ONNX backend for TorchDynamo is under active development and is
subject to rapid change.
.. autofunction:: torch.onnx.is_onnxrt_backend_supported
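A minimal sketch of using this backend through ``torch.compile`` (the ``"onnxrt"`` backend name assumes ONNX Runtime and its dependencies are installed; check :func:`torch.onnx.is_onnxrt_backend_supported` first):

.. code-block:: python

    import torch

    if torch.onnx.is_onnxrt_backend_supported():
        # Compile a function (or nn.Module) with the ONNX Runtime backend.
        @torch.compile(backend="onnxrt")
        def f(x):
            return x.sin() + x.cos()

        print(f(torch.randn(3)))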

View File

@ -0,0 +1,716 @@
TorchScript-based ONNX Exporter
===============================
.. contents:: :local:
Example: AlexNet from PyTorch to ONNX
-------------------------------------
Here is a simple script which exports a pretrained AlexNet to an ONNX file named ``alexnet.onnx``.
The call to ``torch.onnx.export`` runs the model once to trace its execution and then exports the
traced model to the specified file::
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
The resulting ``alexnet.onnx`` file contains a binary `protocol buffer <https://developers.google.com/protocol-buffers/>`_
which contains both the network structure and parameters of the model you exported
(in this case, AlexNet). The argument ``verbose=True`` causes the
exporter to print out a human-readable representation of the model::
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
You can also verify the output using the `ONNX <https://github.com/onnx/onnx/>`_ library,
which you can install using ``pip``::
pip install onnx
Then, you can run::
import onnx
# Load the ONNX model
model = onnx.load("alexnet.onnx")
# Check that the model is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
You can also run the exported model with one of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.
For example after installing `ONNX Runtime <https://www.onnxruntime.ai>`_, you can
load and run the model::
import onnxruntime as ort
import numpy as np
ort_session = ort.InferenceSession("alexnet.onnx")
outputs = ort_session.run(
None,
{"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])
Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime <https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html>`_.
.. _tracing-vs-scripting:
Tracing vs Scripting
--------------------
Internally, :func:`torch.onnx.export()` requires a :class:`torch.jit.ScriptModule` rather than
a :class:`torch.nn.Module`. If the passed-in model is not already a ``ScriptModule``,
``export()`` will use *tracing* to convert it to one:
.. TODO(justinchuby): Add a word on recommending tracing over scripting for most use cases.
* **Tracing**: If ``torch.onnx.export()`` is called with a Module that is not already a
``ScriptModule``, it first does the equivalent of :func:`torch.jit.trace`, which executes the model
once with the given ``args`` and records all operations that happen during that execution. This
means that if your model is dynamic, e.g., changes behavior depending on input data, the exported
model will *not* capture this dynamic behavior.
We recommend examining the exported model and making sure the operators look
reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly
the same as the traced run. If you want to export your model with dynamic control flow, you will
need to use *scripting*.
* **Scripting**: Compiling a model via scripting preserves dynamic control flow and is valid for inputs
of different sizes. To use scripting:
* Use :func:`torch.jit.script` to produce a ``ScriptModule``.
* Call ``torch.onnx.export()`` with the ``ScriptModule`` as the model. The ``args`` are still required,
but they will be used internally only to produce example outputs, so that the types and shapes of the
outputs can be captured. No tracing will be performed.
See `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_
and `TorchScript <jit.html>`_ for more details, including how to compose tracing and scripting to suit the
particular requirements of different models.
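As a minimal sketch of the scripting path (``MyModule`` and ``dummy_input`` are placeholders for a module with data-dependent control flow and an example input)::

    # Compile the model to a ScriptModule; control flow is preserved.
    script_module = torch.jit.script(MyModule())

    # args are only used to produce example outputs; no tracing is performed.
    torch.onnx.export(script_module, dummy_input, "model.onnx")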
Avoiding Pitfalls
-----------------
Avoid NumPy and built-in Python types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PyTorch models can be written using NumPy or Python types and functions, but
during :ref:`tracing<tracing-vs-scripting>`, any variables of NumPy or Python
types (rather than torch.Tensor) are converted to constants, which will produce
the wrong result if those values should change depending on the inputs.
For example, rather than using numpy functions on numpy.ndarrays: ::
# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)
Use torch operators on torch.Tensors: ::
# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)
And rather than use :func:`torch.Tensor.item` (which converts a Tensor to a Python
built-in number): ::
# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
return x.reshape(y.item(), -1)
Use torch's support for implicit casting of single-element tensors: ::
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
return x.reshape(y, -1)
Avoid Tensor.data
^^^^^^^^^^^^^^^^^
Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph.
Use :func:`torch.Tensor.detach` instead. (Work is ongoing to
`remove Tensor.data entirely <https://github.com/pytorch/pytorch/issues/30987>`_).
Avoid in-place operations when using tensor.shape in tracing mode
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In tracing mode, shapes obtained from ``tensor.shape`` are traced as tensors,
and share the same memory. This might cause a mismatch in the final output values.
As a workaround, avoid the use of inplace operations in these scenarios.
For example, in the model::
class Model(torch.nn.Module):
def forward(self, states):
batch_size, seq_length = states.shape[:2]
real_seq_length = seq_length
real_seq_length += 2
return real_seq_length + seq_length
``real_seq_length`` and ``seq_length`` share the same memory in tracing mode.
This could be avoided by rewriting the inplace operation::
real_seq_length = real_seq_length + 2
Limitations
-----------
Types
^^^^^
* Only :class:`torch.Tensors`, numeric types that can be trivially converted to torch.Tensors (e.g. float, int),
and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and
outputs are accepted in :ref:`tracing<tracing-vs-scripting>` mode, but:
* Any computation that depends on the value of a dict or a str input **will be replaced with the
constant value** seen during the one traced execution.
* Any output that is a dict will be silently replaced with a **flattened sequence of its values
(keys will be removed)**. E.g. ``{"foo": 1, "bar": 2}`` becomes ``(1, 2)``.
* Any output that is a str will be silently removed.
* Certain operations involving tuples and lists are not supported in
:ref:`scripting<tracing-vs-scripting>` mode due to limited support in ONNX for nested sequences.
In particular appending a tuple to a list is not supported. In tracing mode, the nested sequences
will be flattened automatically during the tracing.
Differences in Operator Implementations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Due to differences in implementations of operators, running the exported model on different runtimes
may produce different results from each other or from PyTorch. Normally these differences are
numerically small, so this should only be a concern if your application is sensitive to these
small differences.
.. _tensor-indexing:
Unsupported Tensor Indexing Patterns
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Tensor indexing patterns that cannot be exported are listed below.
If you are experiencing issues exporting a model that does not include any of
the unsupported patterns below, please double check that you are exporting with
the latest ``opset_version``.
Reads / Gets
~~~~~~~~~~~~
When indexing into a tensor for reading, the following patterns are not supported: ::
# Tensor indices that include negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.
Writes / Sets
~~~~~~~~~~~~~
When indexing into a Tensor for writing, the following patterns are not supported: ::
# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
# or multiple consecutive tensor indices with rank == 1.
# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.
# Tensor indices that include negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.
# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
# data shape: [3, 4, 5]
# new_data shape: [5]
# expected new_data shape after broadcasting: [2, 2, 2, 5]
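For instance, a sketch of the last workaround using the shapes from the comment above (``data`` of shape ``[3, 4, 5]`` and ``new_data`` of shape ``[5]``)::

    data = torch.zeros(3, 4, 5)
    new_data = torch.arange(5, dtype=torch.float32)
    # Expand new_data to the broadcast shape explicitly before assigning.
    data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data.expand(2, 2, 2, 5)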
Adding support for operators
----------------------------
When exporting a model that includes unsupported operators, you'll see an error message like:
.. code-block:: text
RuntimeError: ONNX export failed: Couldn't export operator foo
When that happens, there are a few things you can do:
#. Change the model to not use that operator.
#. Create a symbolic function to convert the operator and register it as a custom symbolic function.
#. Contribute to PyTorch to add the same symbolic function to :mod:`torch.onnx` itself.
If you decided to implement a symbolic function (we hope you will contribute it back to PyTorch!), here is how you can get started:
ONNX exporter internals
^^^^^^^^^^^^^^^^^^^^^^^
A "symbolic function" is a function that decomposes a PyTorch operator into a
composition of a series of ONNX operators.
During export, each node (which contains a PyTorch operator) in the TorchScript
graph is visited by the exporter in topological order.
Upon visiting a node, the exporter looks for a registered symbolic function for
that operator. Symbolic functions are implemented in Python. A symbolic function for
an op named ``foo`` would look something like::
def foo(
g,
input_0: torch._C.Value,
input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
"""
Adds the ONNX operations representing this PyTorch function by updating the
graph g with `g.op()` calls.
Args:
g (Graph): graph to write the ONNX representation into.
input_0 (Value): value representing the variables which contain
the first input for this operator.
input_1 (Value): value representing the variables which contain
the second input for this operator.
Returns:
A Value or List of Values specifying the ONNX nodes that compute something
equivalent to the original PyTorch operator with the given inputs.
None if it cannot be converted to ONNX.
"""
...
The ``torch._C`` types are Python wrappers around the types defined in C++ in
`ir.h <https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/ir/ir.h>`_.
The process for adding a symbolic function depends on the type of operator.
.. _adding-support-aten:
ATen operators
^^^^^^^^^^^^^^
`ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library.
If the operator is an ATen operator (shows up in the TorchScript graph with the prefix
``aten::``), make sure it is not supported already.
List of supported operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Visit the auto generated :doc:`list of supported TorchScript operators <../onnx_torchscript_supported_aten_ops>`
for details on which operators are supported in each ``opset_version``.
Adding support for an aten or quantized operator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If the operator is not in the list above:
* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
`torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py>`_.
Make sure the function has the same name as the ATen function, which may be declared in
``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
build time, so will not appear in your checkout until you build PyTorch).
* By default, the first arg is the ONNX graph.
Other arg names must EXACTLY match the names in the ``.pyi`` file,
because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
`ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
we only need to create a node to represent the ONNX operator in the graph.
If not, we can compose several standard operators that have the
equivalent semantics to the ATen operator.
Here is an example of handling a missing symbolic function for the ``ELU`` operator.
If we run the following code::
print(
torch.jit.trace(
torch.nn.ELU(), # module
torch.ones(1) # example input
).graph
)
We see something like::
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%4 : float = prim::Constant[value=1.]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
return (%7)
Since we see ``aten::elu`` in the graph, we know this is an ATen operator.
We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.
We find a signature for ``elu`` in ``torch/nn/functional.pyi``::
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
We add the following lines to ``symbolic_opset9.py``::
def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
return g.op("Elu", input, alpha_f=alpha)
Now PyTorch is able to export models containing the ``aten::elu`` operator!
See the ``torch/onnx/symbolic_opset*.py`` files for more examples.
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^
If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.
Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~
You can add a static method named ``symbolic`` to your function class. It should return
ONNX operators that represent the function's behavior in ONNX. For example::
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
.. FIXME(justinchuby): PythonOps are too complicated and the example below
.. uses private methods we do not expose. We are looking to
.. improve the experience. Since SymbolicContext is deprecated, we think
.. defining a symbolic staticmethod is a better way to go for now.
.. PythonOp Symbolic
.. ~~~~~~~~~~~~~~~~~
.. Alternatively, you can register a custom symbolic function.
.. This gives the symbolic function access to more info through the
.. ``torch.onnx.SymbolicContext`` object, which gets passed in as the first
.. argument (before the ``Graph`` object).
.. All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
.. In order to differentiate between different ``Function`` subclasses, the
.. symbolic function should use the ``name`` kwarg which gets set to the name of the class.
.. Custom symbolic functions should add type and shape information by calling ``setType(...)``
.. on Value objects before returning them (implemented in C++ by
.. . ``torch::jit::Value::setType``). This is not required, but it can help the exporter's
.. shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
.. ``test_aten_embedding_2`` in
.. `test_operators.py <https://github.com/pytorch/pytorch/blob/main/test/onnx/test_operators.py>`_.
.. The example below shows how you can access ``requires_grad`` via the ``Node`` object:
.. class MyClip(torch.autograd.Function):
.. @staticmethod
.. def forward(ctx, input, min):
.. ctx.save_for_backward(input)
.. return input.clamp(min=min)
.. class MyRelu(torch.autograd.Function):
.. @staticmethod
.. def forward(ctx, input):
.. ctx.save_for_backward(input)
.. return input.clamp(min=0)
.. def symbolic_python_op(g: "GraphContext", *args, **kwargs):
.. n = ctx.cur_node
.. print("original node: ", n)
.. for i, out in enumerate(n.outputs()):
.. print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
.. import torch.onnx.symbolic_helper as sym_helper
.. for i, arg in enumerate(args):
.. requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
.. print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
.. name = kwargs["name"]
.. ret = None
.. if name == "MyClip":
.. ret = g.op("Clip", args[0], args[1])
.. elif name == "MyRelu":
.. ret = g.op("Relu", args[0])
.. else:
.. # Logs a warning and returns None
.. return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
.. # Copy type and shape from original node.
.. ret.setType(n.type())
.. return ret
.. from torch.onnx import register_custom_op_symbolic
.. . register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~
If a static ``symbolic`` method is not provided for the :class:`torch.autograd.Function` subclass,
and no custom symbolic function is registered for ``prim::PythonOp``,
:func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function`,
breaking it down into the individual operators that were used within the function.
The export should succeed as long as these individual operators are supported. For example::
class MyLogExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
h = input.exp()
return h.log().log()
No static symbolic method is defined for this function, yet it is exported as follows::
graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%1 : float = onnx::Exp[](%input)
%2 : float = onnx::Log[](%1)
%3 : float = onnx::Log[](%2)
return (%3)
If you need to avoid inlining of :class:`torch.autograd.Function`, you should export models with
``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``.
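For example, a sketch (the wrapper module and file name are hypothetical) of exporting the function
above with fallthrough, so that the ``prim::PythonOp`` node is kept rather than inlined; note that
the resulting model is not pure ONNX and the consuming runtime must understand the remaining node::

class MyLogExpModel(torch.nn.Module):
    def forward(self, x):
        return MyLogExp.apply(x)

torch.onnx.export(
    MyLogExpModel(),
    torch.randn(1),
    "my_log_exp.onnx",
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)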
Custom operators
^^^^^^^^^^^^^^^^
You can export your model with custom operators that include a combination of many standard ONNX ops,
or that are driven by a self-defined C++ backend.
ONNX-script functions
~~~~~~~~~~~~~~~~~~~~~
If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize
`ONNX-script <https://github.com/microsoft/onnx-script>`_ to create an external ONNX function to support the operator.
You can export it by following this example::
import onnxscript
# There are three opset versions that need to be aligned
# This is (1) the opset version in the ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15
x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
@onnxscript.script(custom_opset)
def Selu(X):
alpha = 1.67326 # auto wrapped as Constants
gamma = 1.0507
alphaX = op.CastLike(alpha, X)
gammaX = op.CastLike(gamma, X)
neg = gammaX * (alphaX * op.Exp(X) - alphaX)
pos = gammaX * X
zero = op.CastLike(0, X)
return op.Where(X <= zero, neg, pos)
# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
return g.onnxscript_op(Selu, X).setType(X.type())
# Register custom symbolic function
# There are three opset versions that need to be aligned
# This is (2) the opset version in the registry
torch.onnx.register_custom_op_symbolic(
symbolic_name="aten::selu",
symbolic_fn=custom_selu,
opset_version=opset_version,
)
# There are three opset versions that need to be aligned
# This is (3) the opset version in the exporter
torch.onnx.export(
model,
x,
"model.onnx",
opset_version=opset_version,
# only needed if you want to specify an opset version > 1.
custom_opsets={"onnx-script": 2}
)
The example above exports it as a custom operator in the "onnx-script" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.
NOTE: Be careful to align all three opset versions mentioned in the example above, and make sure they are the ones consumed in the export step.
The example usage of how to write an onnx-script function is in beta, as onnx-script is under active development.
Please follow the latest `ONNX-script <https://github.com/microsoft/onnx-script>`_ documentation.
C++ Operators
~~~~~~~~~~~~~
If a model uses a custom operator implemented in C++ as described in
`Extending TorchScript with Custom C++ Operators <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_,
you can export it by following this example::
from torch.onnx import symbolic_helper
# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)
class FooModel(torch.nn.Module):
def __init__(self, attr1, attr2):
super().__init__()
self.attr1 = attr1
self.attr2 = attr2
def forward(self, input1, input2):
# Calling custom op
return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)
model = FooModel(attr1, attr2)
torch.onnx.export(
model,
(example_input1, example_input2),
"model.onnx",
# only needed if you want to specify an opset version > 1.
custom_opsets={"custom_domain": 2}
)
The example above exports it as a custom operator in the "custom_domain" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1.
The runtime that consumes the model needs to support the custom op. See
`Caffe2 custom ops <https://caffe2.ai/docs/custom-operators.html>`_,
`ONNX Runtime custom ops <https://onnxruntime.ai/docs/reference/operators/add-custom-op.html>`_,
or your runtime of choice's documentation.
Discovering all unconvertible ATen ops at once
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When export fails due to an unconvertible ATen op, there may in fact be more
than one such op but the error message only mentions the first. To discover
all of the unconvertible ops in one go you can::
# prepare model, args, opset_version
...
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
model, args, opset_version=opset_version
)
print(set(unconvertible_ops))
The set is approximate because some ops may be removed during the conversion
process and don't need to be converted. Some other ops may have partial support
that will fail conversion with particular inputs, but this should give you a
general idea of what ops are not supported. Please feel free to open GitHub Issues
for op support requests.
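For instance, a self-contained sketch (the model below is hypothetical) might look like::

import torch

class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = MyModel()
args = (torch.randn(2, 3),)
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=17
)
print(set(unconvertible_ops))  # empty here; otherwise prints the unconvertible op names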
Frequently Asked Questions
--------------------------
Q: I have exported my LSTM model, but its input size seems to be fixed?
The tracer records the shapes of the example inputs. If the model should accept
inputs of dynamic shapes, set ``dynamic_axes`` when calling :func:`torch.onnx.export`.
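For example, a minimal sketch (the LSTM configuration and axis names below are illustrative assumptions)
that marks the batch and sequence dimensions as dynamic::

model = torch.nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
dummy_input = torch.randn(4, 5, 10)  # (batch, sequence, features)

torch.onnx.export(
    model,
    dummy_input,
    "lstm.onnx",
    input_names=["input"],
    output_names=["output", "h_n", "c_n"],
    dynamic_axes={
        "input": {0: "batch", 1: "sequence"},
        "output": {0: "batch", 1: "sequence"},
    },
)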
Q: How to export models containing loops?
See `Tracing vs Scripting`_.
Q: How to export models with primitive type inputs (e.g. int, float)?
Support for primitive numeric type inputs was added in PyTorch 1.9.
However, the exporter does not support models with str inputs.
Q: Does ONNX support implicit scalar datatype casting?
The ONNX standard does not, but the exporter will try to handle that part.
Scalars are exported as constant tensors.
The exporter will figure out the right data type for scalars. In rare cases when it is unable
to do so, you will need to manually specify the datatype with e.g. ``dtype=torch.float32``.
If you see any errors, please `create a GitHub issue <https://github.com/pytorch/pytorch/issues>`_.
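For example, in the sketch below (a hypothetical module), the Python scalar ``1.5`` is exported as a
constant tensor whose data type is inferred from ``x``; if inference fails, make the dtype explicit::

class AddScalar(torch.nn.Module):
    def forward(self, x):
        # 1.5 becomes an ONNX Constant; its dtype follows x
        return x + 1.5
        # if the exporter cannot infer the dtype, be explicit instead:
        # return x + torch.tensor(1.5, dtype=torch.float32)

torch.onnx.export(AddScalar(), torch.randn(2, 3), "add_scalar.onnx")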
Q: Are lists of Tensors exportable to ONNX?
Yes, for ``opset_version`` >= 11, since ONNX introduced the Sequence type in opset 11.
Python API
----------
.. automodule:: torch.onnx
Functions
^^^^^^^^^
.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export
.. autofunction:: enable_log
.. autofunction:: disable_log
.. autofunction:: torch.onnx.verification.find_mismatch
Classes
^^^^^^^
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
JitScalarType
torch.onnx.verification.GraphInfo
torch.onnx.verification.VerificationOptions
View File
@ -5,7 +5,7 @@ ONNX supported TorchScript operators
.. This file is automatically generated during the documentation build
.. by cross referencing ONNX operator symbolics with TorchScript operators via
.. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
.. ``docs/source/scripts/build_onnx_torchscript_supported_aten_op_csv_table.py``.
.. Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
This page lists the TorchScript operators that are supported/unsupported by ONNX export.
View File
@ -56,6 +56,8 @@ Some of the most commonly used backends include:
- CUDA graphs with AOT Autograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
* - ``torch.compile(m, backend="ipex")``
- Uses IPEX on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
* - ``torch.compile(m, backend="onnxrt")``
- Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more <onnx_dynamo_onnxruntime_backend>`
**Inference-only backends**
@ -65,10 +67,10 @@ Some of the most commonly used backends include:
* - Backend
- Description
* - ``torch.compile(m, backend="onnxrt")``
- Uses ONNXRT for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
* - ``torch.compile(m, backend="tensorrt")``
- Uses ONNXRT to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
- Uses ONNX Runtime to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
* - ``torch.compile(m, backend="ipex")``
- Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
* - ``torch.compile(m, backend="tvm")``
- Uses Apache TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__
View File
@ -1,5 +1,3 @@
"""ONNX exporter."""
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import (
View File
@ -269,13 +269,16 @@ class Invocation:
@dataclasses.dataclass
class DiagnosticOptions:
"""
Options for diagnostic context.
"""Options for diagnostic context.
Attributes:
verbosity_level: Set the amount of information logged for each diagnostics,
equivalent to the 'level' in Python logging module.
warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
"""
verbosity_level: int = dataclasses.field(default=logging.INFO)
"""Diagnostic context verbosity level, equivalent to the 'level' in Python logging module.
Controls the amount of information logged inside each diagnostics."""
"""Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""
warnings_as_errors: bool = dataclasses.field(default=False)
"""If True, warning diagnostics are treated as error diagnostics."""
View File
@ -69,8 +69,8 @@ else:
_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
through ``ExportOptions``. This should NEVER be accessed outside of this module! Users
should reference ``ExportOptions.opset_version``."""
through :class:`ExportOptions`. This should NEVER be accessed outside of this module! Users
should reference :attr:`ExportOptions.opset_version`."""
_PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
"""The URL to the PyTorch GitHub issues page."""
@ -89,15 +89,15 @@ class ONNXFakeContext:
"""A dataclass used to store context for model export using FakeTensor.
This dataclass stores the FakeTensorMode instance used to convert
real tensors and model parameters into fake tensors. This ``fake_mode`` is
reused internally during tracing of a ``torch.nn.Module`` into a FX ``GraphModule``.
real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`.
"""
fake_mode: fake_tensor.FakeTensorMode
"""The fake tensor mode used for tracing model using fake tensors and parameters."""
state_dict_paths: Optional[Tuple[Union[str, io.BytesIO]]] = None
"""List of paths of files that contain the model `state_dict`"""
"""List of paths of files that contain the model :meth:`state_dict`"""
class OnnxRegistry:
@ -271,18 +271,29 @@ class OnnxRegistry:
class ExportOptions:
"""Options to influence the TorchDynamo ONNX exporter."""
"""Options to influence the TorchDynamo ONNX exporter.
Attributes:
dynamic_shapes: Shape information hint for input/output tensors.
When ``None``, the exporter determines the most compatible setting.
When ``True``, all input shapes are considered dynamic.
When ``False``, all input shapes are considered static.
op_level_debug: Whether to export the model with op-level debug information
diagnostic_options: The diagnostic options for the exporter.
fake_context: The fake context used for symbolic tracing.
onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
"""
dynamic_shapes: Optional[bool] = None
"""Shape information hint for input/output tensors.
- ``None``: the exporter determines the most compatible setting.
- ``True``: all input shapes are considered dynamic.
- ``False``: all input shapes are considered static."""
- ``False``: all input shapes are considered static.
"""
op_level_debug: Optional[bool] = None
"""Whether to export the model with op-level debug information by evaluating
ops through ONNX Runtime."""
"""When True export the model with op-level debug running ops through ONNX Runtime."""
diagnostic_options: DiagnosticOptions
"""The diagnostic options for the exporter."""
@ -291,8 +302,7 @@ class ExportOptions:
"""The fake context used for symbolic tracing."""
onnx_registry: Optional[OnnxRegistry] = None
"""The ONNX registry used to register ATen operators to ONNX functions. Defaults to
opset18."""
"""The ONNX registry used to register ATen operators to ONNX functions."""
@_beartype.beartype
def __init__(
@ -312,8 +322,8 @@ class ExportOptions:
class ResolvedExportOptions(ExportOptions):
"""Consolidates `ExportOptions` with default values.
All unspecified options from `ExportOptions` are assigned a default value.
"""Consolidates :class:`ExportOptions` with default values.
All unspecified options from :class:`ExportOptions` are assigned a default value.
This is an internal class and its API may be changed at any time without notice.
"""
@ -412,12 +422,12 @@ class ResolvedExportOptions(ExportOptions):
def enable_fake_mode():
"""Enable fake mode for the duration of the context.
Internally it instantiates a `FakeTensorMode` context manager that converts
user input and model parameters into `FakeTensor`.
Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.
A [FakeTensor](https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/fake_tensor.py#L870)
is a `torch.Tensor` with the ability to run PyTorch code without having to
actually do computation through tensors allocated on a `meta` device. Because
A :class:`torch._subclasses.fake_tensor.FakeTensor`
is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
actually do computation through tensors allocated on a ``meta`` device. Because
there is no actual data being allocated on the device, this API allows for
exporting large models without the actual memory footprint needed for executing it.
@ -425,8 +435,8 @@ def enable_fake_mode():
are too large to fit into memory.
Returns:
A `ONNXFakeContext` object that must be passed to `torch.onnx.dynamo_export`
through the `ExportOptions.fake_context` argument.
A :class:`ONNXFakeContext` object that must be passed to :func:`dynamo_export`
through the :attr:`ExportOptions.fake_context` argument.
Example::
@ -499,7 +509,7 @@ class ExportOutputSerializer(Protocol):
Example:
A simple serializer that writes the exported ``onnx.ModelProto`` in Protobuf
A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf
format to ``destination``:
::
@ -604,7 +614,7 @@ class ExportOutput:
@property
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an ``onnx.ModelProto``."""
"""The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
if self._export_exception is not None:
raise self._export_exception
@ -746,8 +756,8 @@ class ExportOutput:
will be created to store the each initializer of the ONNX model in a separate file. For example, if the
destination is "/path/model.onnx", the initializers will be saved in "/path/model_initializers/" folder.
model_state_dict: The state_dict of the PyTorch model containing all weights on it.
It can be either a dict as returned by `model.state_dict()`, or a string with a file name.
Required when ``enable_fake_mode`` is used but real initializers are needed on the ONNX graph.
It can be either a dict as returned by :meth:`model.state_dict`, or a string with a file name.
Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
@ -792,7 +802,7 @@ class ExportOutput:
if _model_state_dict_files:
if not isinstance(destination, str):
raise RuntimeError(
"`destination` must be a string with a path when model_state_dict is specified."
"`destination` must be a string with a path when `model_state_dict` is specified."
)
destination_path, destination_filename = os.path.split(destination)
onnx_model_location = destination_filename
@ -845,10 +855,10 @@ class ExportOutput:
diagnostic_context: diagnostics.DiagnosticContext,
) -> Self:
"""
Creates an instance of ``ExportOutput`` when the export process encounters a failure.
Creates an instance of :class:`ExportOutput` when the export process encounters a failure.
In case of a failed export, this method is used to encapsulate the exception
and associated diagnostic context within an ``ExportOutput`` instance for
and associated diagnostic context within an :class:`ExportOutput` instance for
easier handling and debugging.
Args:
@ -856,7 +866,7 @@ class ExportOutput:
diagnostic_context: The context associated with diagnostics during export.
Returns:
An instance of ``ExportOutput`` representing the failed export output.
An instance of :class:`ExportOutput` representing the failed export output.
"""
# Defer `import onnx` out of `import torch` path
# https://github.com/pytorch/pytorch/issues/103764
@ -1034,7 +1044,7 @@ class OnnxExporterError(RuntimeError):
"""Raised when an ONNX exporter error occurs.
This exception is thrown when there's an error during the ONNX export process.
It encapsulates the `ExportOutput` object generated until the failure, allowing
It encapsulates the :class:`ExportOutput` object generated until the failure, allowing
access to the partial export results and associated metadata.
"""
@ -1113,19 +1123,59 @@ def dynamo_export(
Returns:
An in-memory representation of the exported ONNX model.
Example:
**Example 1 - Simplest export**
::
import torch.onnx
torch.onnx.dynamo_export(
my_nn_module,
torch.randn(2, 2, 2), # positional input 1
torch.randn(2, 2, 2), # positional input 2
my_nn_module_attribute="hello", # keyword input
export_options=ExportOptions(
dynamic_shapes=True,
)
).save("my_model.onnx")
class MyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x, bias=None):
out = self.linear(x)
out = out + bias
return out
model = MyModel()
kwargs = {"bias": 3.}
args = (torch.randn(2, 2, 2),)
export_output = torch.onnx.dynamo_export(
model,
*args,
**kwargs).save("my_simple_model.onnx")
**Example 2 - Exporting with dynamic shapes**
::
# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
export_output = torch.onnx.dynamo_export(
model,
*args,
**kwargs,
export_options=export_options)
export_output.save("my_dynamic_model.onnx")
By printing input dynamic dimensions we can see the input shape is no longer (2,2,2)
::
>>> print(export_output.model_proto.graph.input[0])
name: "arg0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "arg0_dim_0"
}
dim {
dim_param: "arg0_dim_1"
}
dim {
dim_param: "arg0_dim_2"
}
}
}
}
"""
resolved_export_options = (
View File
@ -1641,7 +1641,7 @@ class InsertTypePromotion(_pass.Transform):
metadata, specifically the fake tensor stored under node.meta["val"], and ensure it
reflects the latest changes.
See [FXE0015: fx_node_insert_type_promotion](https://pytorch.org/docs/master/generated/onnx_diagnostics_rules/FXE0015%3Afx-node-insert-type-promotion.html) for more details. # noqa: B950
See [FXE0015: fx_node_insert_type_promotion](https://pytorch.org/docs/master/generated/onnx_dynamo_diagnostics_rules/FXE0015%3Afx-node-insert-type-promotion.html) for more details. # noqa: B950
"""
def __init__(