"""Functions to export models into the ONNX IR format.
|
|
|
|
These models can be loaded with the ONNX library and then
|
|
converted to models which run on other deep learning frameworks.
|
|
"""

from __future__ import annotations

import contextlib
import copy
import inspect
import itertools
import os
import re
import textwrap
import typing
import warnings
import zipfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._C._onnx as _C_onnx
import torch.jit._trace
import torch.serialization
from torch import _C
from torch.onnx import (  # noqa: F401
    _constants,
    _exporter_states,
    _patch_torch,
    errors,
    symbolic_caffe2,
    symbolic_helper,
    symbolic_registry,
)
from torch.onnx._globals import GLOBALS

__all__ = [
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "disable_apex_o2_state_dict_hook",
    "setup_onnx_logging",
    "exporter_context",
    "export",
    "warn_on_static_input_change",
    "unpack_quantized_tensor",
    "export_to_pretty_string",
    "unconvertible_ops",
    "get_ns_op_name_from_custom_op",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
]


def is_in_onnx_export() -> bool:
    """Returns whether it is in the middle of ONNX export."""
    return GLOBALS.in_onnx_export
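
# Example (illustrative sketch, not part of this module's API): a module's
# ``forward`` may branch on the export state to emit export-friendly logic.
# ``MyModule`` below is a hypothetical user-defined module.
#
#     class MyModule(torch.nn.Module):
#         def forward(self, x):
#             if torch.onnx.is_in_onnx_export():
#                 return x * 2  # avoid in-place ops while exporting
#             return x.mul_(2)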


# TODO(justinchuby): Remove the dependency on this global variable from constant_fold.cpp
# Type check skipped because IValue cannot be imported from torch._C
_params_dict = {}  # type: ignore[var-annotated]


@contextlib.contextmanager
def select_model_mode_for_export(model, mode):
    if not isinstance(model, torch.jit.ScriptFunction):
        is_originally_training = model.training

        if mode is None:
            mode = _C_onnx.TrainingMode.EVAL
            # If the model is in training mode but the user did not specify
            # to export the model in training mode, export the model in inference
            # mode (the default) and warn them.
            if is_originally_training:
                warnings.warn(
                    "You are exporting the model to ONNX while in training mode with "
                    "the 'training' parameter not specified. The model will default to inference mode export. "
                    "If you wish to export a training-amenable ONNX model, specify training=TrainingMode.TRAINING or "
                    "training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export()."
                )

        # If mode == TrainingMode.EVAL or (mode == TrainingMode.PRESERVE and
        # not is_originally_training), then is_export_training stays False.
        is_export_training = False
        # ONNX opset 12 has better support for training-amenable models, with updated
        # versions of the dropout and batch_norm operators.
        if mode == _C_onnx.TrainingMode.TRAINING or (
            mode == _C_onnx.TrainingMode.PRESERVE and is_originally_training
        ):
            if GLOBALS.export_onnx_opset_version < 12:
                warnings.warn(
                    "You are exporting the model in training mode with onnx opset "
                    f"version {GLOBALS.export_onnx_opset_version}. "
                    "Opset versions lower than opset 12 will not be able to export "
                    "nodes such as Dropout and BatchNorm correctly."
                )
            is_export_training = True

        symbolic_helper._set_training_mode(is_export_training)
        model.train(is_export_training)
    try:
        yield
    finally:
        if not isinstance(model, torch.jit.ScriptFunction):
            # FIXME(justinchuby): is_originally_training is possibly unbound
            model.train(is_originally_training)
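
# Example (illustrative sketch): how this context manager is typically used by
# export code; the model is flipped into the requested mode for the duration of
# the block and restored afterwards.
#
#     with select_model_mode_for_export(model, _C_onnx.TrainingMode.EVAL):
#         ...  # trace or script the model while it is in eval mode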


@contextlib.contextmanager
def disable_apex_o2_state_dict_hook(model):
    # Apex O2 hooks state_dict to return fp16 weights as fp32.
    # The exporter cannot identify them as the same tensors.
    # Since this hook is only used by the optimizer, it is safe to
    # remove this hook while exporting.
    if not isinstance(model, torch.jit.ScriptFunction):
        tmp_map = {}  # type: ignore[var-annotated]
        for module in model.modules():
            for k, v in module._state_dict_hooks.items():
                if type(v).__name__ == "O2StateDictHook":
                    if module not in tmp_map:
                        tmp_map[module] = {}
                    tmp_map[module][k] = v
            if module in tmp_map:
                for k in tmp_map[module].keys():
                    module._state_dict_hooks.pop(k)
    try:
        yield
    finally:
        if not isinstance(model, torch.jit.ScriptFunction):
            # FIXME(justinchuby): tmp_map is possibly unbound
            for module, m_map in tmp_map.items():
                for k, v in m_map.items():
                    module._state_dict_hooks[k] = v


@contextlib.contextmanager
def setup_onnx_logging(verbose):
    is_originally_enabled = torch.onnx.is_onnx_log_enabled()
    if is_originally_enabled or verbose:
        torch.onnx.enable_log()
    try:
        yield
    finally:
        if not is_originally_enabled:
            torch.onnx.disable_log()


@contextlib.contextmanager
def exporter_context(model, mode, verbose):
    with select_model_mode_for_export(
        model, mode
    ) as mode_ctx, disable_apex_o2_state_dict_hook(
        model
    ) as apex_ctx, setup_onnx_logging(
        verbose
    ) as log_ctx:
        yield (mode_ctx, apex_ctx, log_ctx)
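
# Example (illustrative sketch): `exporter_context` stacks the three context
# managers above, so export entry points only need a single `with` statement.
#
#     with exporter_context(model, _C_onnx.TrainingMode.EVAL, verbose=False):
#         graph, params_dict, torch_out = _model_to_graph(model, args)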


def export(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=None,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    opset_version=None,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    export_modules_as_functions=False,
):
    """Exports a model into the ONNX format; see the torch.onnx documentation for details."""
    _export(
        model,
        args,
        f,
        export_params,
        verbose,
        training,
        input_names,
        output_names,
        operator_export_type=operator_export_type,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        dynamic_axes=dynamic_axes,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        custom_opsets=custom_opsets,
        export_modules_as_functions=export_modules_as_functions,
    )
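
# Example (illustrative sketch): a typical call to `torch.onnx.export`.
# `MyModel` and the input shape are hypothetical.
#
#     model = MyModel().eval()
#     dummy_input = torch.randn(1, 3, 224, 224)
#     torch.onnx.export(
#         model,
#         dummy_input,
#         "model.onnx",
#         input_names=["input"],
#         output_names=["output"],
#         dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
#         opset_version=13,
#     )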


def _is_constant_tensor_list(node):
    if node.kind() != "prim::Constant":
        return False
    output_type = node.output().type()
    if output_type.isSubtypeOf(_C.ListType.ofTensors()):
        return True
    if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())):
        return True
    return False


# ONNX can't handle constants that are lists of tensors, which can
# get generated in constant prop. So we split them back into prim::ListConstructs


def _split_tensor_list_constants(g, block):
    for node in block.nodes():
        for subblock in node.blocks():
            _split_tensor_list_constants(g, subblock)
        if _is_constant_tensor_list(node):
            inputs = []
            for val in node.output().toIValue():
                input = g.insertConstant(val)
                input.node().moveBefore(node)
                input.node().copyMetadata(node)
                inputs.append(input)

            lc = (
                g.create("prim::ListConstruct", inputs)
                .insertBefore(node)
                .output()
                .setType(_C.ListType.ofTensors())
            )
            lc.node().copyMetadata(node)
            node.output().replaceAllUsesWith(lc)


def _optimize_graph(
    graph: _C.Graph,
    operator_export_type: _C_onnx.OperatorExportTypes,
    _disable_torch_constant_prop: bool = False,
    fixed_batch_size: bool = False,
    params_dict=None,
    dynamic_axes=None,
    input_names=None,
    module=None,
):
    # Inline everything
    _C._jit_pass_inline(graph)

    # Remove fork/wait nodes
    _C._jit_pass_inline_fork_wait(graph)
    _C._jit_pass_lint(graph)
    _C._jit_pass_lower_all_tuples(graph)

    # we now record some ops like ones/zeros
    # into a trace where we previously recorded constants.
    # use constant prop to maintain our current level of onnx support
    # without implementing symbolics for all of them
    if _disable_torch_constant_prop is False:
        _C._jit_pass_constant_propagation(graph)

    _split_tensor_list_constants(graph, graph)
    # run dce to eliminate dead parts of the graph that might have been
    # left behind by things like symbolic_override
    _C._jit_pass_dce(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_canonicalize_graph_fuser_ops(graph)
    _C._jit_pass_lint(graph)
    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_fuse_addmm(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_lower_all_tuples(graph)
    # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
    # However, there are nodes that cannot be converted without additional context.
    # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
    # until the point where it is unpacked by listUnpack node.
    # This pass does a preprocess, and prepares the nodes such that enough context can be received
    # by the symbolic function.
    _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
    _C._jit_pass_onnx_preprocess(graph)

    # onnx does not support tuples, so try to remove them
    _C._jit_pass_lint(graph)

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    _C._jit_pass_prepare_division_for_onnx(graph)

    _C._jit_pass_onnx_remove_print(graph)
    _C._jit_pass_onnx_preprocess_caffe2(graph)

    symbolic_helper._quantized_ops.clear()
    # Unpack quantized weights for conv and linear ops and insert into graph.
    _C._jit_pass_onnx_unpack_quantized_weights(
        graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
    )
    if symbolic_helper.is_caffe2_aten_fallback():
        # Insert permutes before and after each conv op to ensure correct order.
        _C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)

        # Find consecutive permutes that are no-ops and remove them.
        _C._jit_pass_custom_pattern_based_rewrite_graph(
            textwrap.dedent(
                """\
                graph(%Pi):
                    %Pq = quantized::nhwc2nchw(%Pi)
                    %Pr = quantized::nchw2nhwc(%Pq)
                    return (%Pr)"""
            ),
            textwrap.dedent(
                """\
                graph(%Ri):
                    return (%Ri)"""
            ),
            graph,
        )

    # onnx only supports tensors, so we turn all number types into tensors
    _C._jit_pass_erase_number_types(graph)
    if GLOBALS.onnx_shape_inference:
        input_names = [] if input_names is None else input_names
        dynamic_axes = {} if dynamic_axes is None else dynamic_axes
        _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    _C._jit_pass_onnx_lint(graph)
    graph = _C._jit_pass_onnx(graph, operator_export_type)
    _C._jit_pass_onnx_lint(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_scalar_type_analysis(
        graph, True, GLOBALS.export_onnx_opset_version
    )
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_peephole(
        graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
    )
    _C._jit_pass_lint(graph)

    # graph is not a valid jit graph anymore because types have been replaced
    # (e.g. int with Tensor), so it now contains operators that don't actually
    # exist. We can't run normal dead code elimination because it'd fail trying
    # to look up if an operator has side effects, but we can run a dead code
    # elimination variant that doesn't need to look up if an op has side effects.
    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    _C._jit_pass_lint(graph)
    graph = _C._jit_pass_canonicalize(graph)
    _C._jit_pass_lint(graph)
    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )
    return graph


def warn_on_static_input_change(input_states):
    """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.

    We accept dictionaries and strings as ONNX inputs, but they should be only for
    configuration use. We detect here if these inputs are modified, and if so we warn
    the user that the changes won't take effect in the traced ONNX graph.
    """
    for input, traced_input in zip(input_states[0], input_states[1]):
        if isinstance(input, dict):
            if list(input.keys()) != list(traced_input.keys()):
                warning = (
                    "We detected that you are modifying a dictionary that is an input to your "
                    "model. "
                    "Note that dictionaries are allowed as inputs in ONNX but they should be "
                    "handled with care. "
                    "Use of dictionaries is not recommended, and should be limited to "
                    "configuration purposes. "
                    "Also note that the order and values of the keys must remain the same. "
                )
                warnings.warn(warning)
        elif isinstance(input, str):
            if input != traced_input:
                warning = (
                    "The model seems to have string inputs/outputs. "
                    "Note that strings will not appear as inputs/outputs of the ONNX graph. "
                )
                warnings.warn(warning)
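
# Example (illustrative sketch): a forward pass that adds a key to a dict
# input would trigger the warning above, because the traced graph only
# captured the original keys. `ConfigurableModel` is hypothetical.
#
#     class ConfigurableModel(torch.nn.Module):
#         def forward(self, x, config):
#             config["seen"] = True  # mutation is invisible to the traced graph
#             return x + 1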


def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
    """Resolves the arguments that are ignored when operator_export_type is not OperatorExportTypes.ONNX."""
    if (
        operator_export_type is not operator_export_type.ONNX
        and _C_onnx._CAFFE2_ATEN_FALLBACK
    ):
        if arg_value is True:
            warnings.warn(
                f"'{arg_name}' can be set to True only when 'operator_export_type' is "
                "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
                f"'{arg_name}' argument will be ignored."
            )
        arg_value = False
    return arg_value


def _decide_keep_init_as_input(
    keep_initializers_as_inputs: Optional[bool],
    operator_export_type: _C_onnx.OperatorExportTypes,
    opset_version: int,
):
    """Decides whether the initializers in the graph should be listed as ONNX graph inputs.

    This method encapsulates the logic to decide whether the initializers in the graph
    should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4).
    If keep_initializers_as_inputs is not specified (None), then we decide whether to keep
    initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type
    is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other
    export types keep initializers as input (val_keep_init_as_ip=True).
    If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8,
    in which case it must be ignored because for opset version <= 8, all initializers MUST be
    part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True.

    Special handling is needed for opset version 8 or lower, because irrespective
    of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3
    semantics, i.e. all initializers must be listed as ONNX graph input.
    """

    if opset_version < 9:
        if keep_initializers_as_inputs is False:
            warnings.warn(
                "Setting 'keep_initializers_as_inputs=False' for opset version "
                "8 or lower would lead to an invalid ONNX graph. Therefore, "
                "'keep_initializers_as_inputs=False' is ignored during export. "
                "Exported model will have initializers as graph inputs (compliant "
                "to ONNX IR v3)."
            )
        return True  # i.e. True == initializers are part of graph input (ONNX IR v3)
    val_keep_init_as_ip = (
        True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
    )
    if (
        keep_initializers_as_inputs is None
        and operator_export_type is _C_onnx.OperatorExportTypes.ONNX
    ):
        val_keep_init_as_ip = False
    return val_keep_init_as_ip


def _decide_add_node_names(add_node_names, operator_export_type):
    return _resolve_args_by_export_type(
        "add_node_names", add_node_names, operator_export_type
    )


def _decide_constant_folding(do_constant_folding, operator_export_type, training):
    do_constant_folding = _resolve_args_by_export_type(
        "do_constant_folding", do_constant_folding, operator_export_type
    )
    if do_constant_folding and (
        training is not None and training is not _C_onnx.TrainingMode.EVAL
    ):
        warnings.warn(
            "It is recommended that constant folding be turned off ('do_constant_folding=False') "
            "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' "
            "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some "
            "learnable model parameters may not translate correctly in the exported ONNX model "
            "because constant folding mutates model parameters. Please consider "
            "turning off constant folding or setting training=TrainingMode.EVAL."
        )
    return do_constant_folding


def _signature(model) -> inspect.Signature:
    should_be_callable = getattr(model, "forward", model)
    if callable(should_be_callable):
        return inspect.signature(should_be_callable)
    raise ValueError("model has no forward method and is not callable")


def _decide_input_format(model, args):
    try:
        sig = _signature(model)
    except ValueError as e:
        warnings.warn(f"{e}, skipping _decide_input_format")
        return args
    try:
        ordered_list_keys = list(sig.parameters.keys())
        if ordered_list_keys[0] == "self":
            ordered_list_keys = ordered_list_keys[1:]
        args_dict: Dict = {}
        if isinstance(args, list):
            args_list = args
        elif isinstance(args, tuple):
            args_list = list(args)
        else:
            args_list = [args]
        if isinstance(args_list[-1], dict):
            args_dict = args_list[-1]
            args_list = args_list[:-1]
        n_nonkeyword = len(args_list)
        for optional_arg in ordered_list_keys[n_nonkeyword:]:
            if optional_arg in args_dict:
                args_list.append(args_dict[optional_arg])
            # Check if this arg has a default value
            else:
                param = sig.parameters[optional_arg]
                if param.default != param.empty:
                    args_list.append(param.default)
        args = args_list if isinstance(args, list) else tuple(args_list)
    # Cases of models with no input args
    except IndexError:
        warnings.warn("No input args, skipping _decide_input_format")
    except Exception as e:
        warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}")

    return args
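
# Example (illustrative sketch): `args` may end with a dict that supplies
# keyword arguments by name. For a hypothetical `forward(self, x, y=None, z=None)`:
#
#     args = (x, {"z": z})  # positional x, keyword z, y omitted
#     # _decide_input_format reorders this to (x, None, z),
#     # filling y with its declared default.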


def _trace(func, args, operator_export_type, return_outs=False):
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args,)

    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
        func, args, strict=False, _force_outplace=False, _return_inputs_states=True
    )
    warn_on_static_input_change(inputs_states)

    trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
    if return_outs:
        return trace_graph, torch_out
    return trace_graph


def _trace_and_get_graph_from_model(model, args):
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model. Fail fast!
    orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()

    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
        model, args, strict=False, _force_outplace=False, _return_inputs_states=True
    )
    warn_on_static_input_change(inputs_states)

    if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
        raise RuntimeError(
            "state_dict changed after running the tracer; "
            "something weird is happening in your model!"
        )

    return trace_graph, torch_out


def _get_param_count_list(method_graph, args_params):
    param_count_list = []
    for input_, arg_params_ in zip(method_graph.inputs(), args_params):
        if "PackedParams" in str(input_.type()):
            in_vars, _ = torch.jit._flatten(arg_params_)
            param_count_list.append(len(in_vars))
        else:
            param_count_list.append(arg_params_ is not None)

    return param_count_list


def _check_flatten_did_not_remove(original, jit_flattened):
    """torch.jit._flatten removes None. Check if it did so in this case."""

    def flatten(x):
        if isinstance(x, (list, tuple)):
            for inner in x:
                yield from flatten(inner)
        elif isinstance(x, dict):
            for inner in x.values():
                yield from flatten(inner)
        else:
            yield x

    flattened_with_none = list(flatten(original))
    num_none = len(flattened_with_none) - len(jit_flattened)
    assert num_none >= 0
    if num_none:
        raise ValueError(
            f"args contained {num_none} None values after flattening. "
            "When exporting a ScriptModule or ScriptFunction, no args may "
            "be None because that breaks type propagation."
        )


def _create_jit_graph(model, args):
    torch_out = None
    params: Union[List, Tuple]
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        _check_flatten_did_not_remove(args, flattened_args)
    if isinstance(model, torch.jit.ScriptModule):
        try:
            graph = model.forward.graph
        except AttributeError as e:
            raise RuntimeError("'forward' method must be a script method") from e
        _C._jit_pass_onnx_function_substitution(graph)
        freezed_m = _C._freeze_module(model._c, preserveParameters=True)
        module, params = _C._jit_onnx_list_model_parameters(freezed_m)
        method_graph = module._get_method("forward").graph
        args_params = tuple(args) + tuple(params)
        param_count_list = _get_param_count_list(method_graph, args_params)
        in_vars, _ = torch.jit._flatten(args_params)
        graph = _C._propagate_and_assign_input_shapes(
            method_graph, tuple(in_vars), param_count_list, False, False
        )
        return graph, params, torch_out, module
    elif isinstance(model, torch.jit.ScriptFunction):
        params = ()
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        # FIXME(justinchuby): flattened_args is possibly unbound
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args)
        _C._jit_pass_onnx_lint(graph)
        state_dict = torch.jit._unique_state_dict(model)
        params = list(state_dict.values())
        graph_inputs = list(graph.inputs())
        user_input_num = len(graph_inputs) - len(state_dict)
        param_names = list(state_dict.keys())
        for i, inp in enumerate(graph_inputs):
            if i >= user_input_num:
                inp.setDebugName(param_names[i - user_input_num])
        _C._jit_pass_onnx_function_substitution(graph)
        return graph, params, torch_out, None


def _get_named_param_dict(graph, params):
    input_and_param_names = [val.debugName() for val in graph.inputs()]
    param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
    _params_dict = dict(zip(param_names, params))
    return _params_dict


def _get_example_outputs(model, args):
    input_args = copy.deepcopy(args)
    input_kwargs = {}
    if input_args and isinstance(input_args[-1], dict):
        input_kwargs = input_args[-1]
        input_args = input_args[:-1]

    example_outputs = model(*input_args, **input_kwargs)
    if isinstance(example_outputs, list):
        example_outputs = [example_outputs]
    elif not isinstance(example_outputs, tuple):
        example_outputs = (example_outputs,)

    return example_outputs


_qtype_vtype_map = {
    torch.quint8: torch.uint8,
    torch.qint8: torch.int8,
    torch.qint32: torch.int32,
    torch.quint4x2: torch.int8,
}


def unpack_quantized_tensor(value):
    if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
        q_value_dequantize = value.dequantize()
        q_scale = torch.tensor(value.q_scale(), dtype=torch.double)
        q_zero_point = torch.tensor(value.q_zero_point(), dtype=torch.int64)
        q_value = q_value_dequantize / q_scale + q_zero_point
        q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype])
        return q_value, q_scale, q_zero_point
    else:
        return (value,)
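
# Example (illustrative sketch): a per-tensor quantized tensor unpacks into
# (integer values, scale, zero_point); non-quantized values pass through
# unchanged as a 1-tuple.
#
#     q = torch.quantize_per_tensor(
#         torch.tensor([1.0, 2.0]), scale=0.1, zero_point=10, dtype=torch.quint8
#     )
#     values, scale, zero_point = unpack_quantized_tensor(q)
#     (t,) = unpack_quantized_tensor(torch.tensor([1.0]))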


def _pre_trace_quant_model(model, args):
    r"""Returns `torch.jit.trace(model, args)` if the model is quantized. Otherwise returns
    the original model.

    This is due to https://github.com/pytorch/pytorch/issues/75761.
    """
    if any(
        hasattr(m, "_packed_params") for m in getattr(model, "modules", lambda: [])()
    ) or any(getattr(arg, "is_quantized", False) for arg in args):
        return torch.jit.trace(model, args)
    return model


def _assign_onnx_node_name(graph, node_names):
    """Takes in an ONNX graph and a mapping from _C.Node to node name in the exported ONNX ModelProto.

    Returns:
        graph (_C.Graph): A TorchScript IR Graph with ONNX nodes, where each _C.Node gets its name
        in the exported ONNX ModelProto assigned as the attribute ``onnx_name``.
    """

    def n_fn(n, b_fn, node_names):
        for b in n.blocks():
            b_fn(b, node_names)
        if n in node_names:
            n.s_("onnx_name", node_names[n])

    def b_fn(b, node_names):
        for n in b.nodes():
            n_fn(n, b_fn, node_names)

    b_fn(graph, node_names)
    return graph


def _model_to_graph(
    model,
    args,
    verbose=False,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    do_constant_folding=True,
    _disable_torch_constant_prop=False,
    fixed_batch_size=False,
    training=None,
    dynamic_axes=None,
) -> Tuple[
    _C.Graph,
    Dict[str, torch.Tensor],
    Optional[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]],
]:
    """Converts model into an ONNX graph.

    Returns:
        graph: A TorchScript IR Graph with ONNX nodes.
        params_dict: Dict from input param name to param value.
        torch_out: The output tensors resulting from the trace of ``model``.
            If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
            this will be None, since we are not doing any tracing.
    """
    # TODO: can we simplify this to always return a tuple of Tensor or None?

    # Special case for common case of passing a single Tensor
    if isinstance(args, (torch.Tensor, int, float, bool)):
        args = (args,)

    model = _pre_trace_quant_model(model, args)
    graph, params, torch_out, module = _create_jit_graph(model, args)
    params_dict = _get_named_param_dict(graph, params)

    try:
        graph = _optimize_graph(
            graph,
            operator_export_type,
            _disable_torch_constant_prop=_disable_torch_constant_prop,
            fixed_batch_size=fixed_batch_size,
            params_dict=params_dict,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            module=module,
        )
    except Exception:
        torch.onnx.log("Torch IR graph at exception: ", graph)
        raise

    is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
    if is_script:
        example_outputs = _get_example_outputs(model, args)
        example_outputs_final = ()
        for example_output in example_outputs:
            example_outputs_final += unpack_quantized_tensor(example_output)
        out_vars, desc = torch.jit._flatten(example_outputs_final)
        _C._jit_pass_onnx_assign_output_shape(
            graph, out_vars, desc, GLOBALS.onnx_shape_inference, is_script
        )

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    else:
        if not isinstance(torch_out, (list, tuple)):
            output_wrapped = [torch_out]
        else:
            output_wrapped = torch_out  # type: ignore[assignment]

        output_tensors, out_desc = _C._jit_flatten(tuple(output_wrapped))
        # assign_output_shape pass is not compatible with quantized outputs.
        # Quantized outputs are flattened to 3 values in ONNX, while packed as
        # single value in PyTorch.
        if not any(getattr(out, "is_quantized", False) for out in output_tensors):
            _C._jit_pass_onnx_assign_output_shape(
                graph,
                output_tensors,
                out_desc,
                GLOBALS.onnx_shape_inference,
                is_script,
            )

    _set_input_and_output_names(graph, input_names, output_names)
    params_dict = _get_named_param_dict(graph, params)

    if training is None or training == _C_onnx.TrainingMode.EVAL:
        params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)

    if (
        do_constant_folding
        and GLOBALS.export_onnx_opset_version in _constants.onnx_constant_folding_opsets
    ):
        params_dict = _C._jit_pass_onnx_constant_fold(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )

    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass, transform constants of other data types to float/double + a cast operator.
    if GLOBALS.export_onnx_opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)

    # If output names lack a proper name and are identified only by their unique
    # ids, give them a legible name for debugging purposes.
    _apply_friendly_debug_names(graph, params_dict)

    return graph, params_dict, torch_out


def export_to_pretty_string(
    model,
    args,
    export_params=True,
    verbose=False,
    training=None,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    export_type=None,
    google_printer=False,
    opset_version=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    add_node_names=True,
    do_constant_folding=True,
    dynamic_axes=None,
):
    if opset_version is None:
        opset_version = _constants.onnx_default_opset
    if custom_opsets is None:
        custom_opsets = {}
    symbolic_helper._set_opset_version(opset_version)
    symbolic_helper._set_operator_export_type(operator_export_type)

    symbolic_helper._set_onnx_shape_inference(True)
    with exporter_context(model, training, verbose):
        val_keep_init_as_ip = _decide_keep_init_as_input(
            keep_initializers_as_inputs, operator_export_type, opset_version
        )
        val_add_node_names = _decide_add_node_names(
            add_node_names, operator_export_type
        )
        val_do_constant_folding = _decide_constant_folding(
            do_constant_folding, operator_export_type, training
        )
        args = _decide_input_format(model, args)
        graph, params_dict, torch_out = _model_to_graph(
            model,
            args,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            val_do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return graph._pretty_print_onnx(  # type: ignore[attr-defined]
            params_dict,
            opset_version,
            False,
            operator_export_type,
            google_printer,
            val_keep_init_as_ip,
            custom_opsets,
            val_add_node_names,
        )
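
# Example (illustrative sketch): obtain a human-readable dump of the exported
# graph without writing a file; handy when debugging symbolic functions.
#
#     pretty = torch.onnx.export_to_pretty_string(model, (dummy_input,))
#     print(pretty)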


def unconvertible_ops(
    model, args, training=_C_onnx.TrainingMode.EVAL, opset_version=None
):
    r"""
    Converts the model with operator_export_type set to
    torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of
    all the ops that are not supported/implemented by the exporter.

    Args:
        model: Same as the corresponding arg to torch.onnx.export.
        args: Same as the corresponding arg to torch.onnx.export.
        training: Same as the corresponding arg to torch.onnx.export.
        opset_version: Same as the corresponding arg to torch.onnx.export.

    Returns:
        Tuple[torch._C.Graph, List[str]], where the list includes the names
        of the unconvertible ops.
    """

    opset_version = opset_version or _constants.onnx_default_opset
    symbolic_helper._set_opset_version(opset_version)
    # operator_export_type is set to ONNX_FALLTHROUGH by default so that if an op is not supported
    # in ONNX, fallthrough will occur and export the operator as is, as a custom ONNX op.
    with exporter_context(model, training, False):
        args = _decide_input_format(model, args)
        graph, params_dict, torch_out = _model_to_graph(
            model,
            args,
            # So that if an op cannot be converted to ONNX, it will be kept
            # as-is rather than cause a failure.
            operator_export_type=_C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
        )
    unsupported_ops = []
    supported_namespaces = ("onnx", "prim", "quantized")
    for node in graph.nodes():
        if node.kind().split(":")[0] not in supported_namespaces:
            unsupported_ops.append(node.kind())
    return graph, unsupported_ops
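
# Example (illustrative sketch): list the ops that would not convert at a
# given opset before attempting a full export.
#
#     graph, unsupported = torch.onnx.utils.unconvertible_ops(
#         model, (dummy_input,), opset_version=13
#     )
#     print(unsupported)  # e.g. ["aten::some_unsupported_op"]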


def _setup_trace_module_map(model, export_modules_as_functions):
    def __setup_trace_module_map():
        trace_module_map = {_m: torch.typename(type(_m)) for _m in model.modules()}
        torch.jit._trace._trace_module_map = trace_module_map
        return trace_module_map

    def __register_attribute_hook():
        attr_name = "_onnx_attrs"

        def _track_module_attributes_forward_pre_hook(module, input):
            setattr(module, attr_name, _get_module_attributes(module))

        def _track_module_attributes_forward_hook(module, input, output):
            tracing_state = _C._get_tracing_state()
            if not tracing_state:
                return

            graph = tracing_state.graph()
            onnx_attrs = {}
            if hasattr(module, attr_name):
                onnx_attrs = getattr(module, attr_name)
                delattr(module, attr_name)

            _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)

        for m in model.modules():
            m.register_forward_hook(_track_module_attributes_forward_hook)
            m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)

    if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
        trace_module_map = __setup_trace_module_map()
        export_modules_as_functions = {v for k, v in trace_module_map.items()}
    elif (
        isinstance(export_modules_as_functions, set)
        and len(export_modules_as_functions) > 0
    ):

        def _find_typename(v):
            if isinstance(v, type):
                return torch.typename(v)
            else:
                raise RuntimeError(
                    "Only types of `nn.Module` subclasses should be "
                    "passed in the set for argument `export_modules_as_functions`. "
                    f"Got `{type(v).__name__}`."
                )

        trace_module_map = __setup_trace_module_map()
        module_typenames = {_find_typename(v) for v in export_modules_as_functions}
        export_modules_as_functions = module_typenames
    else:
        export_modules_as_functions = None

    if export_modules_as_functions:
        __register_attribute_hook()

    return export_modules_as_functions
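
# Example (illustrative sketch): the `export_modules_as_functions` argument
# handled above accepts either `True` (all modules) or a set of module
# *types*; `MyBlock` is a hypothetical `nn.Module` subclass. This requires
# opset_version >= 15 (see `_export` below).
#
#     torch.onnx.export(
#         model,
#         dummy_input,
#         "model.onnx",
#         opset_version=15,
#         export_modules_as_functions={MyBlock},
#     )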


def _reset_trace_module_map():
    torch.jit._trace._trace_module_map = None
    _C._jit_pass_onnx_clear_scope_records()


def _get_module_attributes(module):
    annotations = typing.get_type_hints(type(module))
    base_m_annotations = typing.get_type_hints(torch.nn.Module)
    for k in base_m_annotations:
        annotations.pop(k, None)
    return {k: getattr(module, k) for k in annotations}


def _export(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=None,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    export_type=None,
    opset_version=None,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    fixed_batch_size=False,
    custom_opsets=None,
    add_node_names=True,
    onnx_shape_inference=True,
    export_modules_as_functions=False,
):
    if export_type is None:
        export_type = _exporter_states.ExportTypes.PROTOBUF_FILE

    if isinstance(model, torch.nn.DataParallel):
        raise ValueError(
            "torch.nn.DataParallel is not supported by the ONNX "
            "exporter. Please use the 'module' attribute to "
            "unwrap the model from torch.nn.DataParallel. Try "
            "torch.onnx.export(model.module, ...)"
        )
    assert GLOBALS.in_onnx_export is False
    GLOBALS.in_onnx_export = True
    try:
        symbolic_helper._set_onnx_shape_inference(onnx_shape_inference)

        if opset_version is None:
            opset_version = _constants.onnx_default_opset

        if export_modules_as_functions and opset_version < 15:
            raise ValueError(
                "`export_modules_as_functions` is not supported for `opset_version` < 15. "
                "This is because `opset_version` < 15 implies IR version < 8, which means "
                "no local function support. "
            )
        export_modules_as_functions = _setup_trace_module_map(
            model, export_modules_as_functions
        )

        if not operator_export_type:
            if _C_onnx._CAFFE2_ATEN_FALLBACK:
                operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
            else:
                operator_export_type = _C_onnx.OperatorExportTypes.ONNX

        # By default, training=None (which defaults to TrainingMode.EVAL),
        # which is good because running a model in training mode could result in
        # internal buffers getting updated, dropout getting applied, etc.
        # If you really know what you're doing, you can set
        # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE
        # (to preserve whatever the original training mode was).
        symbolic_helper._set_opset_version(opset_version)
        symbolic_helper._set_operator_export_type(operator_export_type)
        with exporter_context(model, training, verbose):
            val_keep_init_as_ip = _decide_keep_init_as_input(
                keep_initializers_as_inputs, operator_export_type, opset_version
            )
            val_add_node_names = _decide_add_node_names(
                add_node_names, operator_export_type
            )
            val_do_constant_folding = _decide_constant_folding(
                do_constant_folding, operator_export_type, training
            )
            # Normally f can be a file-like object, but for large models, the external data format requires a
            # valid `model_file_location`. Code in export.cpp will enforce this.
            if isinstance(f, str):
                model_file_location = f
            else:
                model_file_location = ""
            args = _decide_input_format(model, args)
            if dynamic_axes is None:
                dynamic_axes = {}
            _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

            graph, params_dict, torch_out = _model_to_graph(
                model,
                args,
                verbose,
                input_names,
                output_names,
                operator_export_type,
                val_do_constant_folding,
                fixed_batch_size=fixed_batch_size,
                training=training,
                dynamic_axes=dynamic_axes,
            )

            # TODO: Don't allocate an in-memory string for the protobuf
            defer_weight_export = (
                export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
            )
            if custom_opsets is None:
                custom_opsets = {}

            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
            node_attr_to_name = {}  # type: ignore[var-annotated]
            if export_modules_as_functions:
                # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
                node_attr_to_name = _C._jit_pass_onnx_function_extraction(
                    graph, export_modules_as_functions, list(params_dict.keys())
                )
            params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
                graph, params_dict, getattr(model, "training", False)  # type: ignore[arg-type]
            )
            if export_params:
                (
                    proto,
                    export_map,
                    val_use_external_data_format,
                    node_names,
                ) = graph._export_onnx(  # type: ignore[attr-defined]
                    params_dict,
                    opset_version,
                    dynamic_axes,
                    defer_weight_export,
                    operator_export_type,
                    not verbose,
                    val_keep_init_as_ip,
                    custom_opsets,
                    val_add_node_names,
                    model_file_location,
                    node_attr_to_name,
                )
            else:
                (
                    proto,
                    export_map,
                    val_use_external_data_format,
                    node_names,
                ) = graph._export_onnx(  # type: ignore[attr-defined]
                    {},
                    opset_version,
                    dynamic_axes,
                    False,
                    operator_export_type,
                    not verbose,
                    val_keep_init_as_ip,
                    custom_opsets,
                    val_add_node_names,
                    model_file_location,
                    node_attr_to_name,
                )
            if verbose:
                torch.onnx.log(
                    "Exported graph: ", _assign_onnx_node_name(graph, node_names)
                )
            if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
                assert len(export_map) == 0
                with torch.serialization._open_file_like(f, "wb") as opened_file:
                    opened_file.write(proto)
            elif export_type in [
                _exporter_states.ExportTypes.ZIP_ARCHIVE,
                _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
            ]:
                compression = (
                    zipfile.ZIP_DEFLATED
                    if export_type
                    == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
                    else zipfile.ZIP_STORED
                )
                with zipfile.ZipFile(f, "w", compression=compression) as z:
                    z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                    for k, v in export_map.items():
                        z.writestr(k, v)
            elif export_type == _exporter_states.ExportTypes.DIRECTORY:
                if os.path.exists(f):
                    assert os.path.isdir(f)
                else:
                    os.makedirs(f)

                model_proto_file = os.path.join(
                    f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME
                )
                with torch.serialization._open_file_like(
                    model_proto_file, "wb"
                ) as opened_file:
                    opened_file.write(proto)

                for k, v in export_map.items():
                    weight_proto_file = os.path.join(f, k)
                    with torch.serialization._open_file_like(
                        weight_proto_file, "wb"
                    ) as opened_file:
                        opened_file.write(v)
            else:
                raise RuntimeError("Unknown export type")

            # The ONNX checker only works for ONNX graphs. So if the operator_export_type is not ONNX,
            # we can skip this check.
            # If large model format export is enabled, proto will only contain data location instead of
            # raw data and _check_onnx_proto() will fail because it can only handle the raw ONNX proto
            # string in memory.
            if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and (
                not val_use_external_data_format
            ):
                try:
                    _C._check_onnx_proto(proto, full_check=True)
                except RuntimeError as e:
                    raise errors.CheckerError(e)
    finally:
        assert GLOBALS.in_onnx_export
        GLOBALS.in_onnx_export = False
        _reset_trace_module_map()

    return torch_out


def _apply_friendly_debug_names(graph, params):
    for n in graph.nodes():
        for v in n.inputs():
            old_name = v.debugName()
            if old_name != str(v.unique()):
                continue
            new_name = f"{n.kind()}_{v.unique()}"
            v.setDebugName(new_name)
            if old_name in params:
                params[new_name] = params.pop(old_name)


def _set_input_and_output_names(graph, input_names, output_names):
    def set_names(node_list, name_list, descriptor):
        if name_list is None:
            return
        if len(name_list) > len(node_list):
            raise RuntimeError(
                "number of %s names provided (%d) exceeded number of %ss (%d)"
                % (descriptor, len(name_list), descriptor, len(node_list))
            )

        # Mark if the output node DebugName is set before.
        output_node_set = set()
        for i, (name, node) in enumerate(zip(name_list, node_list)):
            # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName().
            if descriptor == "output":
                if node in output_node_set:
                    identity_node = graph.create("onnx::Identity")
                    identity_node.insertAfter(node.node())
                    identity_node.addInput(node)
                    identity_node.output().setType(node.type())
                    graph.return_node().replaceInput(i, identity_node.output())
                    node = identity_node.output()
                output_node_set.add(node)

            if node.debugName() != name:
                node.setDebugName(name)

    set_names(list(graph.inputs()), input_names, "input")
    set_names(list(graph.outputs()), output_names, "output")


def _run_symbolic_method(g, op_name, symbolic_fn, args):
    r"""
    This trampoline function gets invoked for every symbolic method
    call from C++.
    """
    try:
        return symbolic_fn(g, *args)
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch
        # to symbolic_fn. Otherwise, the backtrace will have the clues
        # you need.
        e.args = (f"{e.args[0]} (occurred when translating {op_name})",)
        raise


def _add_block(node: _C.Node):
    return node.addBlock()  # type: ignore[attr-defined]


def _add_input_to_block(block: _C.Block):
    return block.addInputToBlock()  # type: ignore[attr-defined]


def _add_output_to_block(block: _C.Block, value: _C.Value):
    new_output = block.registerOutput(value)  # type: ignore[attr-defined]
    return new_output


# Note [Export inplace]
# ~~~~~~~~~~~~~~~~~~~~~
# In abstract, it would be better for us to export inplace annotations,
# than to not export them, since it is useful information that can
# help the consumer of an ONNX export run the model more efficiently.
# However, ONNX doesn't currently formalize inplace. Fortunately, it's
# sound to drop inplace annotations, but we are losing information this way.


def _find_symbolic_in_registry(
    domain: str,
    op_name: str,
    opset_version: int,
    operator_export_type: _C_onnx.OperatorExportTypes,
) -> Optional[Callable]:
    """Looks up the symbolic function in the registry.

    Args:
        domain: The domain of the symbolic function.
        op_name: The name of the op.
        opset_version: The opset version currently in use.
        operator_export_type: An enum in _C_onnx.OperatorExportTypes.

    Returns:
        The symbolic function if found, None otherwise.
    """

    if not symbolic_registry.is_registered_op(op_name, domain, opset_version):
        if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
            # Use the original node directly
            return None
    return symbolic_registry.get_registered_op(op_name, domain, opset_version)


def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
    is_exportable_aten_op = symbolic_registry.is_registered_op(
        op_name, "", opset_version
    )
    is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
    is_aten_fallback_export = (
        operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    )
    return is_onnx_aten_export or (
        not is_exportable_aten_op and is_aten_fallback_export
    )


def _need_symbolic_context(symbolic_fn) -> bool:
    """Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
    params = tuple(inspect.signature(symbolic_fn).parameters.values())
    # When annotation evaluation is postponed (PEP 563), the annotation is a string
    # and not a type. We need to use get_type_hints to get the real type.
    if not params:
        return False
    first_param_name = params[0].name
    type_hints = typing.get_type_hints(symbolic_fn)
    if first_param_name not in type_hints:
        return False
    param_type = type_hints[first_param_name]
    return issubclass(param_type, _exporter_states.SymbolicContext)


def _get_aten_op_overload_name(n: _C.Node) -> str:
    # Returns the `overload_name` attribute of ATen ops on non-Caffe2 builds
    schema = n.schema()
    if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
        return ""
    return _C.parse_schema(schema).overload_name


def _run_symbolic_function(
    g: _C.Graph,
    block: _C.Block,
    n: _C.Node,
    inputs: Any,
    env: Dict[_C.Value, _C.Value],
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
) -> Optional[Union[_C.Value, Tuple[_C.Value, ...]]]:
    """Runs a symbolic function.

    The function is used in C++ to export the node to ONNX.

    Returns:
        A single or a tuple of Values.
        None when the node gets cloned as is into the new graph.
    """

    opset_version = GLOBALS.export_onnx_opset_version

    # See Note [Export inplace]
    # TODO(ezyang): I think this is not necessary anymore
    if n.kind().endswith("_"):
        ns_op_name = n.kind()[:-1]
    else:
        ns_op_name = n.kind()
    ns, op_name = ns_op_name.split("::")

    try:
        symbolic_registry.register_version("", opset_version)

        # Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
        if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
            symbolic_caffe2.register_quantized_ops("caffe2", opset_version)

        if ns == "aten":
            domain = ""
        elif ns == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
            domain = "caffe2"
        else:
            domain = ns

        if symbolic_registry.is_registered_op(op_name, domain, opset_version):
            symbolic_fn = _find_symbolic_in_registry(
                domain, op_name, opset_version, operator_export_type
            )
            assert symbolic_fn is not None

            attrs = {k: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]
            if _need_symbolic_context(symbolic_fn):
                ctx = _exporter_states.SymbolicContext(_params_dict, env, n, block)
                return symbolic_fn(ctx, g, *inputs, **attrs)
            # The PythonOp symbolic needs access to the node to resolve a name
            # conflict; this is inconsistent with regular op symbolics.
            if op_name == "PythonOp":
                inputs = (n, *inputs)
            return symbolic_fn(g, *inputs, **attrs)
        elif ns == "onnx":
            # Clone node to trigger ONNX shape inference
            attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]
            return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize())  # type: ignore[attr-defined]
        elif _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
            # Direct ATen export requested
            attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]
            outputs = n.outputsSize()
            attrs["outputs"] = outputs
            # `overload_name` is set for non-Caffe2 builds only
            return g.at(  # type: ignore[attr-defined]
                op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
            )
        else:
            raise errors.UnsupportedOperatorError(
                domain,
                op_name,
                opset_version,
                symbolic_registry.get_op_supported_version(
                    op_name, domain, opset_version
                ),
            )
    except RuntimeError:
        if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        elif (
            operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
            and not symbolic_helper.is_caffe2_aten_fallback()
        ):
            # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
            attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]
            return g.at(  # type: ignore[attr-defined]
                op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
            )
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",)
        raise


def get_ns_op_name_from_custom_op(symbolic_name):
    if not bool(
        re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
    ):
        raise ValueError(
            f"Failed to register operator {symbolic_name}. "
            "The symbolic name must match the format Domain::Name, "
            "and should start with a letter and contain only "
            "alphanumeric characters"
        )

    ns, op_name = symbolic_name.split("::")
    if ns == "onnx":
        raise ValueError(
            f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
        )

    if ns == "aten":
        ns = ""

    return ns, op_name


def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
    """Registers a symbolic function for a custom operator.

    When the user registers a symbolic function for a custom/contrib op,
    it is highly recommended to add shape inference for that operator via the setType API;
    otherwise the exported graph may have incorrect shape inference in some extreme cases.
    An example of setType is `test_aten_embedding_2` in `test_operators.py`.
    """
    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)

    for version in itertools.chain(
        _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
    ):
        if version >= opset_version:
            symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
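
# Example (illustrative sketch): register a symbolic for a hypothetical custom
# op "mylib::my_relu" that lowers to a standard ONNX op.
#
#     def my_relu_symbolic(g, input):
#         return g.op("Relu", input)
#
#     torch.onnx.register_custom_op_symbolic(
#         "mylib::my_relu", my_relu_symbolic, opset_version=9
#     )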


def unregister_custom_op_symbolic(symbolic_name, opset_version):
    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)

    for version in itertools.chain(
        _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
    ):
        if version >= opset_version:
            symbolic_registry.unregister_op(op_name, ns, version)


def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
    """Ensures the dynamic_axes argument follows the expected format."""
    if len(dynamic_axes) == 0:
        return

    if hasattr(model, "graph"):
        # Extracting set of valid input/output names that shall be used for dynamic_axes
        if (input_names is None) or len(input_names) == 0:
            input_names = [x.debugName() for x in model.graph.inputs()]
        if (output_names is None) or len(output_names) == 0:
            output_names = [y.debugName() for y in model.graph.outputs()]

    valid_names = set((input_names or []) + (output_names or []))

    # If dynamic axes are provided as a list rather than a dictionary, they should
    # first get converted to a dictionary in the expected format. If desired axis names
    # are not provided for dynamic axes, automatic names shall be generated for
    # the provided dynamic axes of the specified input/output.
    for key, value in dynamic_axes.items():
        if key not in valid_names:
            warnings.warn(
                f"Provided key {key} for dynamic axes is not a valid input/output name"
            )
        if isinstance(value, list):
            warnings.warn(
                "No names were found for specified dynamic axes of provided input. "
                f"Automatically generated names will be applied to each dynamic axis of input {key}"
            )

            value_dict = {}
            for i, x in enumerate(value):
                if not isinstance(x, int):
                    raise ValueError(
                        "The type of axis index is expected to be an integer"
                    )
                if x in value_dict:
                    warnings.warn(
                        f"Duplicate dynamic axis index {x} was provided for input {key}."
                    )
                else:
                    value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
            dynamic_axes[key] = value_dict
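
# Example (illustrative sketch): the two accepted dynamic_axes formats for a
# hypothetical input named "input".
#
#     dynamic_axes = {"input": {0: "batch", 2: "height"}}  # explicit axis names
#     dynamic_axes = {"input": [0, 2]}  # names auto-generated by the loop
#                                       # above, e.g. "input_dynamic_axes_1"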