[ONNX] Add repro export from GraphInfo (#89947)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89947
Approved by: https://github.com/justinchuby

Committed by: PyTorch MergeBot
Parent: 525c33c09f
Commit: f258753799
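
This commit adds a repro-export workflow to the ONNX verification tooling: `GraphInfo.export_repro` dumps a mismatched subgraph together with its input/output data as a self-contained ONNX test case, and the new `OnnxTestCaseRepro` class loads such a case back and re-validates it against a backend. A minimal usage sketch, pieced together from the test added in this diff; the Linear model and default options are illustrative placeholders, and `verification.find_mismatch` plus an installed onnxruntime are assumed available:

import torch
from torch.onnx import verification

# Illustrative model/inputs; any module whose ONNX export mismatches works here.
model = torch.nn.Linear(4, 2)
args = (torch.randn(1, 4),)

# Bisect the graph into a GraphInfo tree of verified subgraphs.
graph_info = verification.find_mismatch(model, args)

leaves = graph_info.all_mismatch_leaf_graph_info()
if leaves:
    # Writes "onnx_debug/test_<name>/" with model.onnx and test_data_set_0/*.pb.
    repro_dir = leaves[0].export_repro()
    # Later (or elsewhere): reload the repro and re-run it against a backend.
    options = verification.VerificationOptions()
    verification.OnnxTestCaseRepro(repro_dir).validate(options)
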
@@ -3,7 +3,4 @@
 ==================================== Tree: =====================================
 1 ✓
 id:
-=========================== Mismatch leaf subgraphs: ===========================
-[]
-============================= Mismatch node kinds: =============================
-{}
+============================== No mismatch found. ==============================
@@ -2,6 +2,7 @@
 import contextlib
 import io
+import tempfile
 import unittest

 import numpy as np
@@ -272,6 +273,17 @@ class TestFindMismatch(pytorch_test_common.ExportTestCase):
         )
         self.assertExpected(f.getvalue())

+    def test_export_repro_for_mismatch(self):
+        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
+        self.assertTrue(len(mismatch_leaves) > 0)
+        leaf_info = mismatch_leaves[0]
+        with tempfile.TemporaryDirectory() as temp_dir:
+            repro_dir = leaf_info.export_repro(temp_dir)
+
+            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
+                options = verification.VerificationOptions(backend=self.onnx_backend)
+                verification.OnnxTestCaseRepro(repro_dir).validate(options)
+

 if __name__ == "__main__":
     common_utils.run_tests()
@@ -1,9 +1,11 @@
 """Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""

+import glob
 import io
 import os
+import shutil
 import zipfile
-from typing import List, Mapping, Set, Union
+from typing import Any, List, Mapping, Set, Tuple, Union

 import torch
 import torch.jit._trace
@@ -12,6 +14,150 @@ from torch.onnx import _constants, _exporter_states, errors
 from torch.onnx._internal import _beartype, jit_utils, registration


+@_beartype.beartype
+def export_as_test_case(
+    model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
+) -> str:
+    """Export an ONNX model as a self-contained ONNX test case.
+
+    The test case contains the model and the inputs/outputs data. The directory
+    structure is as follows:
+
+    dir
+    ├── test_<name>
+    │   ├── model.onnx
+    │   └── test_data_set_0
+    │       ├── input_0.pb
+    │       ├── input_1.pb
+    │       ├── output_0.pb
+    │       └── output_1.pb
+
+    Args:
+        model_bytes: The ONNX model in bytes.
+        inputs_data: The inputs data, nested data structure of numpy.ndarray.
+        outputs_data: The outputs data, nested data structure of numpy.ndarray.
+        name: The name of the test case.
+        dir: The directory in which to create the test case.
+
+    Returns:
+        The path to the test case directory.
+    """
+    try:
+        import onnx
+    except ImportError:
+        raise ImportError(
+            "Export test case to ONNX format failed: Please install ONNX."
+        )
+
+    test_case_dir = os.path.join(dir, "test_" + name)
+    os.makedirs(test_case_dir, exist_ok=True)
+    _export_file(
+        model_bytes,
+        os.path.join(test_case_dir, "model.onnx"),
+        _exporter_states.ExportTypes.PROTOBUF_FILE,
+        {},
+    )
+    data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
+    if os.path.exists(data_set_dir):
+        shutil.rmtree(data_set_dir)
+    os.makedirs(data_set_dir)
+
+    proto = onnx.load_from_string(model_bytes)
+
+    for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
+        export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
+    for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)):
+        export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb"))
+
+    return test_case_dir
+
+
+@_beartype.beartype
+def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
+    """Load a self-contained ONNX test case from a directory.
+
+    The test case must contain the model and the inputs/outputs data. The directory
+    structure should be as follows (`dir` is the test case directory itself):
+
+    dir
+    ├── model.onnx
+    └── test_data_set_0
+        ├── input_0.pb
+        ├── input_1.pb
+        ├── output_0.pb
+        └── output_1.pb
+
+    Args:
+        dir: The directory containing the test case.
+
+    Returns:
+        model_bytes: The ONNX model in bytes.
+        inputs: The inputs data, mapping from input name to numpy.ndarray.
+        outputs: The outputs data, mapping from output name to numpy.ndarray.
+    """
+    try:
+        import onnx
+        from onnx import numpy_helper
+    except ImportError:
+        raise ImportError(
+            "Load test case from ONNX format failed: Please install ONNX."
+        )
+
+    with open(os.path.join(dir, "model.onnx"), "rb") as f:
+        model_bytes = f.read()
+
+    test_data_dir = os.path.join(dir, "test_data_set_0")
+
+    inputs = {}
+    input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
+    for input_file in input_files:
+        tensor = onnx.load_tensor(input_file)
+        inputs[tensor.name] = numpy_helper.to_array(tensor)
+    outputs = {}
+    output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
+    for output_file in output_files:
+        tensor = onnx.load_tensor(output_file)
+        outputs[tensor.name] = numpy_helper.to_array(tensor)
+
+    return model_bytes, inputs, outputs
+
+
+@_beartype.beartype
+def export_data(data, value_info_proto, f: str) -> None:
+    """Export data to ONNX protobuf format.
+
+    Args:
+        data: The data to export, nested data structure of numpy.ndarray.
+        value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto
+            determines how the data is stored.
+        f: The file to write the data to.
+    """
+    try:
+        from onnx import numpy_helper
+    except ImportError:
+        raise ImportError("Export data to ONNX format failed: Please install ONNX.")
+
+    with open(f, "wb") as opened_file:
+        if value_info_proto.type.HasField("map_type"):
+            opened_file.write(
+                numpy_helper.from_dict(data, value_info_proto.name).SerializeToString()
+            )
+        elif value_info_proto.type.HasField("sequence_type"):
+            opened_file.write(
+                numpy_helper.from_list(data, value_info_proto.name).SerializeToString()
+            )
+        elif value_info_proto.type.HasField("optional_type"):
+            opened_file.write(
+                numpy_helper.from_optional(
+                    data, value_info_proto.name
+                ).SerializeToString()
+            )
+        else:
+            assert value_info_proto.type.HasField("tensor_type")
+            opened_file.write(
+                numpy_helper.from_array(data, value_info_proto.name).SerializeToString()
+            )
+
+
 @_beartype.beartype
 def _export_file(
     model_bytes: bytes,
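
`export_as_test_case` and `load_test_case` above write and read the standard ONNX backend-test layout shown in their docstrings. A minimal round trip as a sketch, assuming the `onnx` package is installed and that the module is importable as `torch.onnx._internal.onnx_proto_utils` (the one-node Relu model is illustrative, not from the PR):

import tempfile

import numpy as np
from onnx import TensorProto, helper

from torch.onnx._internal import onnx_proto_utils

# A tiny one-node model: y = Relu(x).
node = helper.make_node("Relu", ["x"], ["y"])
graph = helper.make_graph(
    [node],
    "tiny_relu",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2])],
)
model_bytes = helper.make_model(graph).SerializeToString()

x = np.array([-1.0, 2.0], dtype=np.float32)
y = np.maximum(x, 0.0)

with tempfile.TemporaryDirectory() as tmp:
    # Writes tmp/test_relu/model.onnx and tmp/test_relu/test_data_set_0/*.pb.
    case_dir = onnx_proto_utils.export_as_test_case(
        model_bytes, [x], [y], "relu", tmp
    )
    # Reads them back; inputs/outputs are keyed by tensor name.
    loaded_bytes, inputs, outputs = onnx_proto_utils.load_test_case(case_dir)
    assert np.array_equal(inputs["x"], x) and np.array_equal(outputs["y"], y)
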
@@ -8,6 +8,7 @@ from __future__ import annotations
 import contextlib
 import copy
 import dataclasses
+import datetime
 import difflib
 import enum
 import functools
@@ -47,6 +48,7 @@ _NumericType = Union[Number, torch.Tensor, np.ndarray]
 _ModelType = Union[torch.nn.Module, torch.jit.ScriptModule]
 _InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
 _InputKwargsType = Mapping[str, Any]
+_OutputsType = Union[Sequence[_NumericType], Sequence]


 class OnnxBackend(enum.Enum):
@@ -146,7 +148,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:


 @_beartype.beartype
-def _run_onnx(onnx_session, inputs):
+def _run_onnx(onnx_session, inputs) -> _OutputsType:
     kw_inputs = {}
     if inputs and isinstance(inputs[-1], dict):
         kw_inputs = inputs[-1]
@@ -229,34 +231,14 @@ def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):


 @_beartype.beartype
-def _compare_onnx_pytorch_outputs(
-    onnx_outs: Union[Sequence[_NumericType], Sequence],
-    pt_outs: Optional[Union[_NumericType, Sequence[_NumericType], Sequence, Dict]],
+def _compare_onnx_pytorch_outputs_in_np(
+    onnx_outs: _OutputsType,
+    pt_outs: _OutputsType,
     options: VerificationOptions,
 ):
-    """
-    Compare ONNX and PyTorch outputs.
-
-    Args:
-        onnx_outs: outputs from ONNX backend.
-        pt_outs: outputs from PyTorch.
-        options: options for verification.
-
-    Raises:
-        AssertionError: if outputs from ONNX model and PyTorch model are not
-            equal up to specified precision.
-        ValueError: if arguments provided are invalid.
-    """
-    if options.ignore_none:
-        # torch.jit._flatten filters None type
-        pt_outs, _ = torch.jit._flatten(pt_outs)
-    else:
-        pt_outs = _inline_flatten_list([pt_outs], [])
-    pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
-    onnx_outs = _inline_flatten_list(onnx_outs, [])
     assert len(onnx_outs) == len(
-        pt_outs_np
-    ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs_np)})"
+        pt_outs
+    ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
     acceptable_error_percentage = options.acceptable_error_percentage
     if acceptable_error_percentage and (
         acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
@@ -265,7 +247,7 @@ def _compare_onnx_pytorch_outputs(
             "If set, acceptable_error_percentage should be between 0.0 and 1.0"
         )

-    for ort_out, pt_out in zip(onnx_outs, pt_outs_np):
+    for ort_out, pt_out in zip(onnx_outs, pt_outs):
         try:
             # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
             if not options.check_shape:
@@ -298,6 +280,35 @@ def _compare_onnx_pytorch_outputs(
             raise


+@_beartype.beartype
+def _compare_onnx_pytorch_outputs(
+    onnx_outs: _OutputsType,
+    pt_outs: Any,
+    options: VerificationOptions,
+):
+    """
+    Compare ONNX and PyTorch outputs.
+
+    Args:
+        onnx_outs: outputs from ONNX backend.
+        pt_outs: outputs from PyTorch.
+        options: options for verification.
+
+    Raises:
+        AssertionError: if outputs from ONNX model and PyTorch model are not
+            equal up to specified precision.
+        ValueError: if arguments provided are invalid.
+    """
+    if options.ignore_none:
+        # torch.jit._flatten filters None type
+        pt_outs, _ = torch.jit._flatten(pt_outs)
+    else:
+        pt_outs = _inline_flatten_list([pt_outs], [])
+    pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
+    onnx_outs = _inline_flatten_list(onnx_outs, [])
+    _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)
+
+
 @_beartype.beartype
 def _prepare_input_for_pytorch(args, kwargs):
     """Prepare input for PyTorch model execution.
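
The change above splits the comparator in two: `_compare_onnx_pytorch_outputs` keeps its role of flattening PyTorch outputs and converting them to numpy, then delegates to the new `_compare_onnx_pytorch_outputs_in_np`, which compares numpy values only. The numpy-level entry point is what `OnnxTestCaseRepro.validate` (added further down) uses to check outputs loaded from disk without a live PyTorch model. A minimal sketch of a direct call (private API, shown purely for illustration):

import numpy as np

from torch.onnx import verification

out = np.array([1.0, 2.0], dtype=np.float32)
# Returns silently when the outputs match; raises AssertionError otherwise.
verification._compare_onnx_pytorch_outputs_in_np(
    [out], [out.copy()], verification.VerificationOptions()
)
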
@@ -655,6 +666,101 @@ def _onnx_graph_from_model(
     return onnx_graph


+@_beartype.beartype
+def _onnx_graph_from_aten_graph(
+    graph: torch.Graph,
+    export_options: _experimental.ExportOptions,
+    params_dict: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Graph, Dict[str, Any]]:
+    if params_dict is None:
+        params_dict = {}
+    operator_export_type = export_options.operator_export_type
+    dynamic_axes = export_options.dynamic_axes or {}
+    input_names = export_options.input_names
+    training = export_options.training
+    do_constant_folding = export_options.do_constant_folding
+    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
+
+    do_constant_folding = utils._decide_constant_folding(
+        do_constant_folding, operator_export_type, training
+    )
+
+    # TODO: Below is doing aten graph to onnx. It should be abstracted as a
+    # function in torch/onnx/utils.py.
+    graph = graph.copy()
+    graph = utils._optimize_graph(
+        graph,
+        operator_export_type,
+        params_dict=params_dict,
+        dynamic_axes=dynamic_axes,
+        input_names=input_names,
+    )
+
+    if training is None or training == _C_onnx.TrainingMode.EVAL:
+        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
+
+    if (
+        do_constant_folding
+        and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
+    ):
+        params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version)
+        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
+
+    if GLOBALS.onnx_shape_inference:
+        _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version)
+
+    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
+
+    # For ONNX opset < 9, constants only have three data types: float16, float, double.
+    # In this pass transform constants of other data types to float/double + cast operator.
+    if opset_version < 9:
+        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
+
+    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
+    _C._jit_decay_packed_param_input_types(graph)
+
+    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
+
+    if export_options.verbose:
+        print("ONNX graph: ", graph)
+
+    return graph, params_dict
+
+
+@_beartype.beartype
+def _onnx_proto_from_onnx_graph(
+    onnx_graph: torch.Graph,
+    export_options: _experimental.ExportOptions,
+    params_dict: Dict[str, Any],
+) -> Tuple[bytes, Mapping[str, bytes]]:
+    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
+    dynamic_axes = export_options.dynamic_axes or {}
+    operator_export_type = export_options.operator_export_type
+    val_keep_init_as_ip = utils._decide_keep_init_as_input(
+        export_options.keep_initializers_as_inputs,
+        operator_export_type,
+        opset_version,
+    )
+    val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
+    custom_opsets = export_options.custom_opsets or {}
+
+    proto, export_map, _, _ = onnx_graph._export_onnx(  # type: ignore[attr-defined]
+        params_dict,
+        opset_version,
+        dynamic_axes,
+        False,
+        operator_export_type,
+        not export_options.verbose,
+        val_keep_init_as_ip,
+        custom_opsets,
+        val_add_node_names,
+        "",
+        {},
+    )
+
+    return proto, export_map
+
+
 @_beartype.beartype
 def check_export_model_diff(
     model: Union[torch.nn.Module, torch.jit.ScriptModule],
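
`_onnx_graph_from_aten_graph` and `_onnx_proto_from_onnx_graph` factor the second half of `verify_aten_graph` into reusable steps (the next hunk rewires it onto them), which is what lets `GraphInfo.export_repro` serialize a single subgraph. A sketch of the pipeline using the fields a `GraphInfo` carries (`graph`, `export_options`, `params_dict`); the Linear model is illustrative, and `find_mismatch` plus onnxruntime are assumed available:

import onnx
import torch
from torch.onnx import verification

model = torch.nn.Linear(4, 2)
graph_info = verification.find_mismatch(model, (torch.randn(1, 4),))

# Lower the aten-level torch.Graph to an ONNX graph plus an updated params dict.
onnx_graph, onnx_params = verification._onnx_graph_from_aten_graph(
    graph_info.graph, graph_info.export_options, graph_info.params_dict
)
# Serialize the ONNX graph and its initializers to proto bytes.
proto, _export_map = verification._onnx_proto_from_onnx_graph(
    onnx_graph, graph_info.export_options, onnx_params
)
# The bytes are a regular ONNX model.
print(onnx.load_from_string(proto).ir_version)
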
@@ -795,41 +901,14 @@ def verify_aten_graph(
     export_options: _experimental.ExportOptions,
     params_dict: Optional[Dict[str, Any]] = None,
     verification_options: Optional[VerificationOptions] = None,
-) -> Tuple[
-    Optional[AssertionError],
-    torch.Graph,
-    Union[_NumericType, Sequence[_NumericType]],
-    Union[_NumericType, Sequence[_NumericType]],
-]:
+) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
     if verification_options is None:
         verification_options = VerificationOptions()
-
-    original_jit_graph = graph
-    graph = graph.copy()
-
-    operator_export_type = export_options.operator_export_type
-    dynamic_axes = export_options.dynamic_axes
-    if dynamic_axes is None:
-        dynamic_axes = {}
-    input_names = export_options.input_names
-    training = export_options.training
-    do_constant_folding = export_options.do_constant_folding
-    opset_version = export_options.opset_version
     if params_dict is None:
         params_dict = {}

-    if opset_version is None:
-        opset_version = _constants.ONNX_DEFAULT_OPSET
-
-    val_keep_init_as_ip = utils._decide_keep_init_as_input(
-        export_options.keep_initializers_as_inputs,
-        operator_export_type,
-        opset_version,
-    )
-    val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
-    do_constant_folding = utils._decide_constant_folding(
-        do_constant_folding, operator_export_type, training
-    )
+    original_jit_graph = graph
+    graph = graph.copy()

     # Execute aten graph and get reference torch jit outputs.
     graph_inputs = list(v for v in graph.inputs())
@@ -840,69 +919,18 @@ def verify_aten_graph(
     jit_inputs = copy.deepcopy(jit_inputs)
     jit_input_and_parameters = jit_inputs + tuple(weights)
     jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters)  # type: ignore[attr-defined]
     if not isinstance(jit_outs, (list, tuple)):
         jit_outs = [jit_outs]

-    # Convert aten graph to onnx.
-    graph = utils._optimize_graph(
-        graph,
-        operator_export_type,
-        params_dict=params_dict,
-        dynamic_axes=dynamic_axes,
-        input_names=input_names,
-    )
-
-    # TODO(bowbao): Below is doing aten graph to onnx. It should be abstracted as a
-    # function in torch/onnx/utils.py.
-    if training is None or training == _C_onnx.TrainingMode.EVAL:
-        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
-
-    if (
-        do_constant_folding
-        and GLOBALS.export_onnx_opset_version
-        >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
-    ):
-        params_dict = _C._jit_pass_onnx_constant_fold(
-            graph, params_dict, GLOBALS.export_onnx_opset_version
-        )
-        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
-
-    if GLOBALS.onnx_shape_inference:
-        _C._jit_pass_onnx_graph_shape_type_inference(
-            graph, params_dict, GLOBALS.export_onnx_opset_version
-        )
-
-    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
-
-    # For ONNX opset < 9, constants only have three data types: float16, float, double.
-    # In this pass transform constants of other data types to float/double + cast operator.
-    if GLOBALS.export_onnx_opset_version < 9:
-        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
-
-    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
-    _C._jit_decay_packed_param_input_types(graph)
-
-    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
-
-    if export_options.verbose:
-        print("ONNX graph: ", graph)
-
-    custom_opsets = export_options.custom_opsets
-    if custom_opsets is None:
-        custom_opsets = {}
-
+    # Convert aten graph to onnx graph.
+    graph, onnx_params_dict = _onnx_graph_from_aten_graph(
+        graph, export_options, params_dict
+    )
+
+    proto, export_map = _onnx_proto_from_onnx_graph(
+        graph, export_options, onnx_params_dict
+    )
     model_f: Union[str, io.BytesIO] = io.BytesIO()
-    proto, export_map, _, _ = graph._export_onnx(  # type: ignore[attr-defined]
-        params_dict,
-        opset_version,
-        dynamic_axes,
-        False,
-        operator_export_type,
-        not export_options.verbose,
-        val_keep_init_as_ip,
-        custom_opsets,
-        val_add_node_names,
-        "",
-        {},
-    )
     export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
     onnx_proto_utils._export_file(proto, model_f, export_type, export_map)
@@ -933,7 +961,6 @@ def verify_aten_graph(
             options=verification_options,
         )
     except AssertionError as e:
-        # print("Has mismatch: ", e)
         return e, graph, jit_outs, onnx_outs

     return None, graph, jit_outs, onnx_outs
@@ -967,6 +994,7 @@ class GraphInfoPrettyPrinter:
         self.upper_printer = None
         self.lower_printer = None

+    @_beartype.beartype
     def _total_rows(self) -> int:
         if self.graph_info is None:
             return 1
@@ -976,6 +1004,7 @@ class GraphInfoPrettyPrinter:
         )
         return 2  # Two lines: node count + id.

+    @_beartype.beartype
     def _node_count_segment_str(self) -> str:
         if self.graph_info is None:
             return "..."
@@ -989,16 +1018,19 @@ class GraphInfoPrettyPrinter:

         return f"{node_count} {'X' if has_mismatch else '✓'} {error_node_kind}"

+    @_beartype.beartype
     def _graph_id_segment_str(self) -> str:
         if self.graph_info is None:
             return ""
         return f"id: {self.graph_info.id}"

+    @_beartype.beartype
     def _max_segment_columns(self) -> int:
         return max(
             map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
         )

+    @_beartype.beartype
     def _graph_segment_str_at_line(self, line: int) -> str:
         """Get the string representation of the graph segment at the given line."""
         if line == 0:
@@ -1013,6 +1045,7 @@ class GraphInfoPrettyPrinter:
             return " " * self._max_segment_columns()
         return ""

+    @_beartype.beartype
     def _connector_segment_str_at_line(self, line: int) -> str:
         """Get the connector segment string at the given line."""
         if self.upper_printer is None and self.lower_printer is None:
@@ -1029,6 +1062,7 @@ class GraphInfoPrettyPrinter:
             return " "
         return ""

+    @_beartype.beartype
     def _children_str_at_line(self, line: int) -> str:
         """Get the string representation of the children at the given line.

@@ -1050,6 +1084,7 @@ class GraphInfoPrettyPrinter:
         )
         return ""

+    @_beartype.beartype
     def _str_at_line(self, line: int) -> str:
         """Get the string representation of the graph at the given line."""
         return (
@@ -1067,25 +1102,97 @@ class GraphInfoPrettyPrinter:
         total_rows = self._total_rows()
         for line in range(total_rows):
             print(self._str_at_line(line).rstrip())
-        # Summarize leaf subgraphs with mismatch.
-        print(" Mismatch leaf subgraphs: ".center(80, "="))
-        print(
-            [
-                graph_info.id
-                for graph_info in self.graph_info.all_mismatch_leaf_graph_info()
-            ]
-        )
-        # Summarize node kinds with mismatch.
-        mismatch_node_kinds: Dict[str, int] = {}
-        for graph_info in self.graph_info.all_mismatch_leaf_graph_info():
-            node_kinds = graph_info.essential_node_kinds()
-            if len(node_kinds) == 1:
-                node_kind = node_kinds.pop()
-                mismatch_node_kinds[node_kind] = (
-                    mismatch_node_kinds.get(node_kind, 0) + 1
-                )
-        print(" Mismatch node kinds: ".center(80, "="))
-        print(mismatch_node_kinds)
+        if self.graph_info.has_mismatch():
+            # Summarize leaf subgraphs with mismatch.
+            print(" Mismatch leaf subgraphs: ".center(80, "="))
+            print(
+                [
+                    graph_info.id
+                    for graph_info in self.graph_info.all_mismatch_leaf_graph_info()
+                ]
+            )
+            # Summarize node kinds with mismatch.
+            mismatch_node_kinds: Dict[str, int] = {}
+            for graph_info in self.graph_info.all_mismatch_leaf_graph_info():
+                node_kinds = graph_info.essential_node_kinds()
+                if len(node_kinds) == 1:
+                    node_kind = node_kinds.pop()
+                    mismatch_node_kinds[node_kind] = (
+                        mismatch_node_kinds.get(node_kind, 0) + 1
+                    )
+            print(" Mismatch node kinds: ".center(80, "="))
+            print(mismatch_node_kinds)
+        else:
+            print(" No mismatch found. ".center(80, "="))
+
+
+class OnnxTestCaseRepro:
+    def __init__(self, repro_dir):
+        self.repro_dir = repro_dir
+        self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case(
+            repro_dir
+        )
+
+    @classmethod
+    @_beartype.beartype
+    def create_test_case_repro(
+        cls, proto: bytes, inputs, outputs, dir: str, name: Optional[str] = None
+    ):
+        """Create a repro under "{dir}/test_{name}" for an ONNX test case.
+
+        The test case contains the model and the inputs/outputs data. The directory
+        structure is as follows:
+
+        dir
+        ├── test_<name>
+        │   ├── model.onnx
+        │   └── test_data_set_0
+        │       ├── input_0.pb
+        │       ├── input_1.pb
+        │       ├── output_0.pb
+        │       └── output_1.pb
+
+        Args:
+            proto: ONNX model proto.
+            inputs: Inputs to the model.
+            outputs: Outputs of the model.
+            dir: Directory to save the repro.
+            name: Name of the test case. If not specified, a name based on the
+                current time will be generated.
+
+        Returns:
+            Path to the repro.
+        """
+        if name is None:
+            name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
+        return onnx_proto_utils.export_as_test_case(
+            proto,
+            _to_numpy(inputs),
+            _to_numpy(outputs),
+            name,
+            dir,
+        )
+
+    @_beartype.beartype
+    def validate(self, options: VerificationOptions):
+        """Run the ONNX test case with options.backend and compare with the expected outputs.
+
+        Args:
+            options: Options for validation.
+
+        Raises:
+            AssertionError: if outputs from options.backend and expected outputs are not
+                equal up to specified precision.
+        """
+        onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend)
+        run_outputs = onnx_session.run(None, self.inputs)
+        if hasattr(onnx_session, "get_outputs"):
+            output_names = [o.name for o in onnx_session.get_outputs()]
+        elif hasattr(onnx_session, "output_names"):
+            output_names = onnx_session.output_names
+        else:
+            raise ValueError(f"Unknown onnx session type: {type(onnx_session)}")
+        expected_outs = [self.outputs[name] for name in output_names]
+        _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options)
+
+
 @dataclasses.dataclass
@@ -1099,7 +1206,7 @@ class GraphInfo:
     mismatch_error: Optional[AssertionError] = dataclasses.field(
         default=None, init=False
     )
-    pt_outs: Optional[Union[_NumericType, Sequence[_NumericType]]] = dataclasses.field(
+    pt_outs: Optional[Sequence[_NumericType]] = dataclasses.field(
         default=None, init=False
     )
     upper_graph_info: Optional[GraphInfo] = dataclasses.field(default=None, init=False)
@@ -1175,16 +1282,19 @@ class GraphInfo:
         else:
             print(" No mismatch ".center(80, "="))

-    def has_mismatch(self):
+    @_beartype.beartype
+    def has_mismatch(self) -> bool:
         """Return True if the subgraph has output mismatch between torch and ONNX."""
         return self.mismatch_error is not None

+    @_beartype.beartype
     def essential_node_count(self) -> int:
         """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
         return sum(
             1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
         )

+    @_beartype.beartype
     def essential_node_kinds(self) -> Set[str]:
         """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
         return {
@@ -1193,7 +1303,8 @@ class GraphInfo:
             if n.kind() not in self._EXCLUDED_NODE_KINDS
         }

-    def all_mismatch_leaf_graph_info(self) -> List[GraphInfo]:
+    @_beartype.beartype
+    def all_mismatch_leaf_graph_info(self) -> List["GraphInfo"]:
         """Return a list of all leaf `GraphInfo` objects that have mismatch."""
         if not self.has_mismatch():
             return []
@@ -1215,7 +1326,8 @@ class GraphInfo:

         return results

-    def find_partition(self, id: str) -> Optional[GraphInfo]:
+    @_beartype.beartype
+    def find_partition(self, id: str) -> Optional["GraphInfo"]:
         """Find the `GraphInfo` object with the given id."""
         if id == self.id:
             return self
@@ -1227,6 +1339,48 @@ class GraphInfo:
             return self.lower_graph_info.find_partition(id)
         return None

+    @_beartype.beartype
+    def export_repro(
+        self, repro_dir: Optional[str] = None, name: Optional[str] = None
+    ) -> str:
+        """Export the subgraph to ONNX along with the input/output data for repro.
+
+        The repro directory will contain the following files:
+
+        dir
+        ├── test_<name>
+        │   ├── model.onnx
+        │   └── test_data_set_0
+        │       ├── input_0.pb
+        │       ├── input_1.pb
+        │       ├── output_0.pb
+        │       └── output_1.pb
+
+        Args:
+            repro_dir: The directory to export the repro files to. Defaults to the
+                current working directory if None.
+            name: An optional name for the test case folder: "test_{name}".
+
+        Returns:
+            The path to the exported repro directory.
+        """
+        if repro_dir is None:
+            repro_dir = os.getcwd()
+        repro_dir = os.path.join(repro_dir, "onnx_debug")
+
+        onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
+            self.graph, self.export_options, self.params_dict
+        )
+
+        proto, _ = _onnx_proto_from_onnx_graph(
+            onnx_graph, self.export_options, onnx_params_dict
+        )
+        return OnnxTestCaseRepro.create_test_case_repro(
+            proto, self.input_args, self.pt_outs, repro_dir, name
+        )
+
+    @_beartype.beartype
     def _graph_partition_pivot(self) -> int:
         """Find the pivot index to partition the graph.

@@ -1249,6 +1403,7 @@ class GraphInfo:
             return included_node_indices[half_idx] + 1
         return -1

+    @_beartype.beartype
     def _partition_upper_graph(self) -> torch.Graph:
         pivot = self._graph_partition_pivot()
         if pivot == -1:
@@ -1292,6 +1447,7 @@ class GraphInfo:

         return graph

+    @_beartype.beartype
     def _partition_lower_graph(self) -> torch.Graph:
         pivot = self._graph_partition_pivot()
         if pivot == -1:
@@ -1346,6 +1502,7 @@ class GraphInfo:

         return graph

+    @_beartype.beartype
     def _partition_node(
         self,
         node: torch.Node,
@@ -1384,6 +1541,7 @@ class GraphInfo:
     ):
         covered_bridge_values.add(process_bridge_value(output))

+    @_beartype.beartype
     def _partition_nodes(
         self,
         graph: torch.Graph,
@@ -1423,18 +1581,17 @@ class GraphInfo:
             complete_lower_nodes_set,
         )

+    @_beartype.beartype
     def _bridge_kwargs(self):
         pt_outs = self.pt_outs
-        if pt_outs is None:
-            raise RuntimeError("pt_outs is not set")
-        if not isinstance(pt_outs, (list, tuple)):
-            pt_outs = [pt_outs]  # type: ignore[list-item]
         graph_outputs = list(self.graph.outputs())
+        assert pt_outs is not None
         assert len(graph_outputs) == len(
             pt_outs
         ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
         return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}

+    @_beartype.beartype
     def _args_and_params_for_partition_graph(
         self,
         graph: torch.Graph,
@@ -1451,14 +1608,10 @@ class GraphInfo:
         ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
         return args, params

+    @_beartype.beartype
     def verify_export(
         self, options: VerificationOptions
-    ) -> Tuple[
-        Optional[AssertionError],
-        torch.Graph,
-        Union[_NumericType, Sequence[_NumericType]],
-        Union[_NumericType, Sequence[_NumericType]],
-    ]:
+    ) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
         return verify_aten_graph(
             self.graph,
             input_args=self.input_args,
@@ -1467,6 +1620,7 @@ class GraphInfo:
             verification_options=options,
         )

+    @_beartype.beartype
     def find_mismatch(
         self,
         options: Optional[VerificationOptions] = None,
@@ -1536,6 +1690,7 @@ class GraphInfo:
         self.lower_graph_info.find_mismatch(options)


+@_beartype.beartype
 def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
     all_nodes = set(nodes)
     for n in nodes:
|
||||
return all_nodes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
|
||||
if any(use.user in nodes for use in value.uses()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
|
||||
for output in node.outputs():
|
||||
if _has_uses_by_nodes(output, nodes):
|
||||
@@ -1557,6 +1714,7 @@ def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
     return False


+@_beartype.beartype
 def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
     return value.node() in nodes