mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
A graph is exported for each set of inputs. The exported graphs are then compared to each other, and discrepancies are reported. This function first checks the JIT graph and then the ONNX graph. Unless otherwise specified, the JIT/ONNX graph is expected to be the same regardless of the inputs used for exporting. A discrepancy implies that the exported graph is not accurate when run with another set of inputs, which typically results in runtime errors or output mismatches. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78323 Approved by: https://github.com/justinchuby, https://github.com/garymm
29 lines
1.0 KiB
Python
29 lines
1.0 KiB
Python
"""Experimental classes and functions used by ONNX export."""
|
|
|
|
import dataclasses
|
|
from typing import Mapping, Optional, Sequence, Set, Type, Union
|
|
|
|
import torch
|
|
import torch._C._onnx as _C_onnx
|
|
|
|
|
|
@dataclasses.dataclass()
class ExportOptions:
    """Bundle of the keyword arguments accepted by :func:`torch.onnx.export`.

    Each field is named after the :func:`torch.onnx.export` argument it is
    meant to stand in for, and carries the default declared below. Fields
    appear in their positional ``__init__`` order, so reordering them would
    change the generated constructor signature.

    TODO: Adopt this in `torch.onnx.export` api to replace keyword arguments.
    """

    # --- Parameter / logging switches ---------------------------------------
    export_params: bool = True
    verbose: bool = False

    # --- Mode (defaults to TrainingMode.EVAL) --------------------------------
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL

    # --- Graph input / output names (None means "not specified") ------------
    input_names: Union[Sequence[str], None] = None
    output_names: Union[Sequence[str], None] = None

    # --- Operator conversion -------------------------------------------------
    operator_export_type: _C_onnx.OperatorExportTypes = (
        _C_onnx.OperatorExportTypes.ONNX
    )
    opset_version: Union[int, None] = None
    do_constant_folding: bool = True

    # --- Dynamic axes --------------------------------------------------------
    # Maps a string key (presumably an input/output name -- confirm against
    # torch.onnx.export) to either an int -> str mapping or a sequence of ints
    # (presumably axis indices -- TODO confirm).
    dynamic_axes: Union[
        Mapping[str, Union[Mapping[int, str], Sequence[int]]], None
    ] = None

    # --- Initializers, custom opsets, module-as-function export -------------
    keep_initializers_as_inputs: Union[bool, None] = None
    custom_opsets: Union[Mapping[str, int], None] = None
    export_modules_as_functions: Union[bool, Set[Type[torch.nn.Module]]] = False
|