From f2587537992e873f29a7ff68870c37eef9a3fa53 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Fri, 9 Dec 2022 10:18:19 -0800 Subject: [PATCH] [ONNX] Add repro export from `GraphInfo` (#89947) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89947 Approved by: https://github.com/justinchuby --- ...rints_correct_info_when_no_mismatch.expect | 5 +- test/onnx/test_verification.py | 12 + torch/onnx/_internal/onnx_proto_utils.py | 148 +++++- torch/onnx/verification.py | 456 ++++++++++++------ 4 files changed, 467 insertions(+), 154 deletions(-) diff --git a/test/onnx/expect/TestFindMismatch_ONNX_RUNTIME_CPU.test_find_mismatch_prints_correct_info_when_no_mismatch.expect b/test/onnx/expect/TestFindMismatch_ONNX_RUNTIME_CPU.test_find_mismatch_prints_correct_info_when_no_mismatch.expect index a97987522c4b..80cfa609d20a 100644 --- a/test/onnx/expect/TestFindMismatch_ONNX_RUNTIME_CPU.test_find_mismatch_prints_correct_info_when_no_mismatch.expect +++ b/test/onnx/expect/TestFindMismatch_ONNX_RUNTIME_CPU.test_find_mismatch_prints_correct_info_when_no_mismatch.expect @@ -3,7 +3,4 @@ ==================================== Tree: ===================================== 1 ✓ id: -=========================== Mismatch leaf subgraphs: =========================== -[] -============================= Mismatch node kinds: ============================= -{} +============================== No mismatch found. 
============================== diff --git a/test/onnx/test_verification.py b/test/onnx/test_verification.py index 8b3dc669afc8..cf62e1f18696 100644 --- a/test/onnx/test_verification.py +++ b/test/onnx/test_verification.py @@ -2,6 +2,7 @@ import contextlib import io +import tempfile import unittest import numpy as np @@ -272,6 +273,17 @@ class TestFindMismatch(pytorch_test_common.ExportTestCase): ) self.assertExpected(f.getvalue()) + def test_export_repro_for_mismatch(self): + mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info() + self.assertTrue(len(mismatch_leaves) > 0) + leaf_info = mismatch_leaves[0] + with tempfile.TemporaryDirectory() as temp_dir: + repro_dir = leaf_info.export_repro(temp_dir) + + with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): + options = verification.VerificationOptions(backend=self.onnx_backend) + verification.OnnxTestCaseRepro(repro_dir).validate(options) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py index 6c8b1e420ec3..e8d85d80a0af 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -1,9 +1,11 @@ """Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto.""" +import glob import io import os +import shutil import zipfile -from typing import List, Mapping, Set, Union +from typing import Any, List, Mapping, Set, Tuple, Union import torch import torch.jit._trace @@ -12,6 +14,150 @@ from torch.onnx import _constants, _exporter_states, errors from torch.onnx._internal import _beartype, jit_utils, registration +@_beartype.beartype +def export_as_test_case( + model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str +) -> str: + """Export an ONNX model as a self contained ONNX test case. + + The test case contains the model and the inputs/outputs data. 
The directory structure + is as follows: + + dir + ├── test_{name} + │ ├── model.onnx + │ └── test_data_set_0 + │ ├── input_0.pb + │ ├── input_1.pb + │ ├── output_0.pb + │ └── output_1.pb + + Args: + model_bytes: The ONNX model in bytes. + inputs_data: The inputs data, nested data structure of numpy.ndarray. + outputs_data: The outputs data, nested data structure of numpy.ndarray. + + Returns: + The path to the test case directory. + """ + try: + import onnx + except ImportError: + raise ImportError( + "Export test case to ONNX format failed: Please install ONNX." + ) + + test_case_dir = os.path.join(dir, "test_" + name) + os.makedirs(test_case_dir, exist_ok=True) + _export_file( + model_bytes, + os.path.join(test_case_dir, "model.onnx"), + _exporter_states.ExportTypes.PROTOBUF_FILE, + {}, + ) + data_set_dir = os.path.join(test_case_dir, "test_data_set_0") + if os.path.exists(data_set_dir): + shutil.rmtree(data_set_dir) + os.makedirs(data_set_dir) + + proto = onnx.load_from_string(model_bytes) + + for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)): + export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb")) + for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)): + export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb")) + + return test_case_dir + + +@_beartype.beartype +def load_test_case(dir: str) -> Tuple[bytes, Any, Any]: + """Load a self contained ONNX test case from a directory. + + The test case must contain the model and the inputs/outputs data. The directory structure + should be as follows: + + dir + ├── test_{name} + │ ├── model.onnx + │ └── test_data_set_0 + │ ├── input_0.pb + │ ├── input_1.pb + │ ├── output_0.pb + │ └── output_1.pb + + Args: + dir: The directory containing the test case. + + Returns: + model_bytes: The ONNX model in bytes. + inputs: the inputs data, mapping from input name to numpy.ndarray.
+ outputs: the outputs data, mapping from output name to numpy.ndarray. + """ + try: + import onnx + from onnx import numpy_helper + except ImportError: + raise ImportError( + "Load test case from ONNX format failed: Please install ONNX." + ) + + with open(os.path.join(dir, "model.onnx"), "rb") as f: + model_bytes = f.read() + + test_data_dir = os.path.join(dir, "test_data_set_0") + + inputs = {} + input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb")) + for input_file in input_files: + tensor = onnx.load_tensor(input_file) + inputs[tensor.name] = numpy_helper.to_array(tensor) + outputs = {} + output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb")) + for output_file in output_files: + tensor = onnx.load_tensor(output_file) + outputs[tensor.name] = numpy_helper.to_array(tensor) + + return model_bytes, inputs, outputs + + +@_beartype.beartype +def export_data(data, value_info_proto, f: str) -> None: + """Export data to ONNX protobuf format. + + Args: + data: The data to export, nested data structure of numpy.ndarray. + value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto + determines how the data is stored. + f: The file to write the data to. 
+ """ + try: + from onnx import numpy_helper + except ImportError: + raise ImportError("Export data to ONNX format failed: Please install ONNX.") + + with open(f, "wb") as opened_file: + if value_info_proto.type.HasField("map_type"): + opened_file.write( + numpy_helper.from_dict(data, value_info_proto.name).SerializeToString() + ) + elif value_info_proto.type.HasField("sequence_type"): + opened_file.write( + numpy_helper.from_list(data, value_info_proto.name).SerializeToString() + ) + elif value_info_proto.type.HasField("optional_type"): + opened_file.write( + numpy_helper.from_optional( + data, value_info_proto.name + ).SerializeToString() + ) + else: + assert value_info_proto.type.HasField("tensor_type") + opened_file.write( + numpy_helper.from_array(data, value_info_proto.name).SerializeToString() + ) + + @_beartype.beartype def _export_file( model_bytes: bytes, diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index efa010a3d250..a774250ed187 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -8,6 +8,7 @@ from __future__ import annotations import contextlib import copy import dataclasses +import datetime import difflib import enum import functools @@ -47,6 +48,7 @@ _NumericType = Union[Number, torch.Tensor, np.ndarray] _ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] _InputArgsType = Union[torch.Tensor, Tuple[Any, ...]] _InputKwargsType = Mapping[str, Any] +_OutputsType = Union[Sequence[_NumericType], Sequence] class OnnxBackend(enum.Enum): @@ -146,7 +148,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: @_beartype.beartype -def _run_onnx(onnx_session, inputs): +def _run_onnx(onnx_session, inputs) -> _OutputsType: kw_inputs = {} if inputs and isinstance(inputs[-1], dict): kw_inputs = inputs[-1] @@ -229,34 +231,14 @@ def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend): @_beartype.beartype -def _compare_onnx_pytorch_outputs( - onnx_outs: 
Union[Sequence[_NumericType], Sequence], - pt_outs: Optional[Union[_NumericType, Sequence[_NumericType], Sequence, Dict]], +def _compare_onnx_pytorch_outputs_in_np( + onnx_outs: _OutputsType, + pt_outs: _OutputsType, options: VerificationOptions, ): - """ - Compare ONNX and PyTorch outputs. - - Args: - onnx_outs: outputs from ONNX backend. - pt_outs: outputs from PyTorch. - options: options for verification. - - Raises: - AssertionError: if outputs from ONNX model and PyTorch model are not - equal up to specified precision. - ValueError: if arguments provided are invalid. - """ - if options.ignore_none: - # torch.jit._flatten filters None type - pt_outs, _ = torch.jit._flatten(pt_outs) - else: - pt_outs = _inline_flatten_list([pt_outs], []) - pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) - onnx_outs = _inline_flatten_list(onnx_outs, []) assert len(onnx_outs) == len( - pt_outs_np - ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs_np)})" + pt_outs + ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" acceptable_error_percentage = options.acceptable_error_percentage if acceptable_error_percentage and ( acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 @@ -265,7 +247,7 @@ def _compare_onnx_pytorch_outputs( "If set, acceptable_error_percentage should be between 0.0 and 1.0" ) - for ort_out, pt_out in zip(onnx_outs, pt_outs_np): + for ort_out, pt_out in zip(onnx_outs, pt_outs): try: # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. if not options.check_shape: @@ -298,6 +280,35 @@ def _compare_onnx_pytorch_outputs( raise +@_beartype.beartype +def _compare_onnx_pytorch_outputs( + onnx_outs: _OutputsType, + pt_outs: Any, + options: VerificationOptions, +): + """ + Compare ONNX and PyTorch outputs. + + Args: + onnx_outs: outputs from ONNX backend. + pt_outs: outputs from PyTorch. + options: options for verification. 
+ + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options.ignore_none: + # torch.jit._flatten filters None type + pt_outs, _ = torch.jit._flatten(pt_outs) + else: + pt_outs = _inline_flatten_list([pt_outs], []) + pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) + onnx_outs = _inline_flatten_list(onnx_outs, []) + _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) + + @_beartype.beartype def _prepare_input_for_pytorch(args, kwargs): """Prepare input for PyTorch model execution. @@ -655,6 +666,101 @@ def _onnx_graph_from_model( return onnx_graph +@_beartype.beartype +def _onnx_graph_from_aten_graph( + graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: Optional[Dict[str, Any]] = None, +) -> Tuple[torch.Graph, Dict[str, Any]]: + if params_dict is None: + params_dict = {} + operator_export_type = export_options.operator_export_type + dynamic_axes = export_options.dynamic_axes or {} + input_names = export_options.input_names + training = export_options.training + do_constant_folding = export_options.do_constant_folding + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + + do_constant_folding = utils._decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + + # TODO: Below is doing aten graph to onnx. It should be abstracted as a + # function in torch/onnx/utils.py. 
+ graph = graph.copy() + graph = utils._optimize_graph( + graph, + operator_export_type, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + ) + + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) + + if ( + do_constant_folding + and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. + # In this pass transform constants of other data types to float/double + cast operator. + if opset_version < 9: + _C._jit_pass_onnx_cast_all_constant_to_floating(graph) + + params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) + _C._jit_decay_packed_param_input_types(graph) + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if export_options.verbose: + print("ONNX graph: ", graph) + + return graph, params_dict + + +@_beartype.beartype +def _onnx_proto_from_onnx_graph( + onnx_graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: Dict[str, Any], +) -> Tuple[bytes, Mapping[str, bytes]]: + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + dynamic_axes = export_options.dynamic_axes or {} + operator_export_type = export_options.operator_export_type + val_keep_init_as_ip = utils._decide_keep_init_as_input( + export_options.keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = utils._decide_add_node_names(True, operator_export_type) + custom_opsets = export_options.custom_opsets or {} + + proto, 
export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + False, + operator_export_type, + not export_options.verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + "", + {}, + ) + + return proto, export_map + + @_beartype.beartype def check_export_model_diff( model: Union[torch.nn.Module, torch.jit.ScriptModule], @@ -795,41 +901,14 @@ def verify_aten_graph( export_options: _experimental.ExportOptions, params_dict: Optional[Dict[str, Any]] = None, verification_options: Optional[VerificationOptions] = None, -) -> Tuple[ - Optional[AssertionError], - torch.Graph, - Union[_NumericType, Sequence[_NumericType]], - Union[_NumericType, Sequence[_NumericType]], -]: +) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]: if verification_options is None: verification_options = VerificationOptions() - - original_jit_graph = graph - graph = graph.copy() - - operator_export_type = export_options.operator_export_type - dynamic_axes = export_options.dynamic_axes - if dynamic_axes is None: - dynamic_axes = {} - input_names = export_options.input_names - training = export_options.training - do_constant_folding = export_options.do_constant_folding - opset_version = export_options.opset_version if params_dict is None: params_dict = {} - if opset_version is None: - opset_version = _constants.ONNX_DEFAULT_OPSET - - val_keep_init_as_ip = utils._decide_keep_init_as_input( - export_options.keep_initializers_as_inputs, - operator_export_type, - opset_version, - ) - val_add_node_names = utils._decide_add_node_names(True, operator_export_type) - do_constant_folding = utils._decide_constant_folding( - do_constant_folding, operator_export_type, training - ) + original_jit_graph = graph + graph = graph.copy() # Execute aten graph and get reference torch jit outputs. 
graph_inputs = list(v for v in graph.inputs()) @@ -840,69 +919,18 @@ def verify_aten_graph( jit_inputs = copy.deepcopy(jit_inputs) jit_input_and_parameters = jit_inputs + tuple(weights) jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] + if not isinstance(jit_outs, (list, tuple)): + jit_outs = [jit_outs] - # Convert aten graph to onnx. - graph = utils._optimize_graph( - graph, - operator_export_type, - params_dict=params_dict, - dynamic_axes=dynamic_axes, - input_names=input_names, + # Convert aten graph to onnx graph. + graph, onnx_params_dict = _onnx_graph_from_aten_graph( + graph, export_options, params_dict ) - # TODO(bowbao): Below is doing aten graph to onnx. It should be abstracted as a - # function in torch/onnx/utils.py. - if training is None or training == _C_onnx.TrainingMode.EVAL: - params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) - - if ( - do_constant_folding - and GLOBALS.export_onnx_opset_version - >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET - ): - params_dict = _C._jit_pass_onnx_constant_fold( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - - if GLOBALS.onnx_shape_inference: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - - params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) - - # For ONNX opset < 9, constants only have three data types: float16, float, double. - # In this pass transform constants of other data types to float/double + cast operator. 
- if GLOBALS.export_onnx_opset_version < 9: - _C._jit_pass_onnx_cast_all_constant_to_floating(graph) - - params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) - _C._jit_decay_packed_param_input_types(graph) - - _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) - - if export_options.verbose: - print("ONNX graph: ", graph) - - custom_opsets = export_options.custom_opsets - if custom_opsets is None: - custom_opsets = {} - + proto, export_map = _onnx_proto_from_onnx_graph( + graph, export_options, onnx_params_dict + ) model_f: Union[str, io.BytesIO] = io.BytesIO() - proto, export_map, _, _ = graph._export_onnx( # type: ignore[attr-defined] - params_dict, - opset_version, - dynamic_axes, - False, - operator_export_type, - not export_options.verbose, - val_keep_init_as_ip, - custom_opsets, - val_add_node_names, - "", - {}, - ) export_type = _exporter_states.ExportTypes.PROTOBUF_FILE onnx_proto_utils._export_file(proto, model_f, export_type, export_map) @@ -933,7 +961,6 @@ def verify_aten_graph( options=verification_options, ) except AssertionError as e: - # print("Has mismatch: ", e) return e, graph, jit_outs, onnx_outs return None, graph, jit_outs, onnx_outs @@ -967,6 +994,7 @@ class GraphInfoPrettyPrinter: self.upper_printer = None self.lower_printer = None + @_beartype.beartype def _total_rows(self) -> int: if self.graph_info is None: return 1 @@ -976,6 +1004,7 @@ class GraphInfoPrettyPrinter: ) return 2 # Two lines: node count + id. + @_beartype.beartype def _node_count_segment_str(self) -> str: if self.graph_info is None: return "..." 
@@ -989,16 +1018,19 @@ class GraphInfoPrettyPrinter: return f"{node_count} {'X' if has_mismatch else '✓'} {error_node_kind}" + @_beartype.beartype def _graph_id_segment_str(self) -> str: if self.graph_info is None: return "" return f"id: {self.graph_info.id}" + @_beartype.beartype def _max_segment_columns(self) -> int: return max( map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) ) + @_beartype.beartype def _graph_segment_str_at_line(self, line: int) -> str: """Get the string representation of the graph segment at the given line.""" if line == 0: @@ -1013,6 +1045,7 @@ class GraphInfoPrettyPrinter: return " " * self._max_segment_columns() return "" + @_beartype.beartype def _connector_segment_str_at_line(self, line: int) -> str: """Get the connector segment string at the given line.""" if self.upper_printer is None and self.lower_printer is None: @@ -1029,6 +1062,7 @@ class GraphInfoPrettyPrinter: return " " return "" + @_beartype.beartype def _children_str_at_line(self, line: int) -> str: """Get the string representation of the children at the given line. @@ -1050,6 +1084,7 @@ class GraphInfoPrettyPrinter: ) return "" + @_beartype.beartype def _str_at_line(self, line: int) -> str: """Get the string representation of the graph at the given line.""" return ( @@ -1067,25 +1102,97 @@ class GraphInfoPrettyPrinter: total_rows = self._total_rows() for line in range(total_rows): print(self._str_at_line(line).rstrip()) - # Summarize leaf subgraphs with mismatch. - print(" Mismatch leaf subgraphs: ".center(80, "=")) - print( - [ - graph_info.id - for graph_info in self.graph_info.all_mismatch_leaf_graph_info() - ] + if self.graph_info.has_mismatch(): + # Summarize leaf subgraphs with mismatch. + print(" Mismatch leaf subgraphs: ".center(80, "=")) + print( + [ + graph_info.id + for graph_info in self.graph_info.all_mismatch_leaf_graph_info() + ] + ) + # Summarize node kinds with mismatch. 
+ mismatch_node_kinds: Dict[str, int] = {} + for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): + node_kinds = graph_info.essential_node_kinds() + if len(node_kinds) == 1: + node_kind = node_kinds.pop() + mismatch_node_kinds[node_kind] = ( + mismatch_node_kinds.get(node_kind, 0) + 1 + ) + print(" Mismatch node kinds: ".center(80, "=")) + print(mismatch_node_kinds) + else: + print(" No mismatch found. ".center(80, "=")) + + +class OnnxTestCaseRepro: + def __init__(self, repro_dir): + self.repro_dir = repro_dir + self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( + repro_dir ) - # Summarize node kinds with mismatch. - mismatch_node_kinds: Dict[str, int] = {} - for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): - node_kinds = graph_info.essential_node_kinds() - if len(node_kinds) == 1: - node_kind = node_kinds.pop() - mismatch_node_kinds[node_kind] = ( - mismatch_node_kinds.get(node_kind, 0) + 1 - ) - print(" Mismatch node kinds: ".center(80, "=")) - print(mismatch_node_kinds) + + @classmethod + @_beartype.beartype + def create_test_case_repro( + cls, proto: bytes, inputs, outputs, dir: str, name: Optional[str] = None + ): + """Create a repro under "{dir}/test_{name}" for an ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory + structure is as follows: + + dir + ├── test_{name} + │ ├── model.onnx + │ └── test_data_set_0 + │ ├── input_0.pb + │ ├── input_1.pb + │ ├── output_0.pb + │ └── output_1.pb + + Args: + proto: ONNX model proto. + inputs: Inputs to the model. + outputs: Outputs of the model. + dir: Directory to save the repro. + name: Name of the test case. If not specified, a name based on current time + will be generated. + Returns: + Path to the repro.
+ """ + if name is None: + name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + return onnx_proto_utils.export_as_test_case( + proto, + _to_numpy(inputs), + _to_numpy(outputs), + name, + dir, + ) + + @_beartype.beartype + def validate(self, options: VerificationOptions): + """Run the ONNX test case with options.backend, and compare with the expected outputs. + + Args: + options: Options for validation. + + Raises: + AssertionError: if outputs from options.backend and expected outputs are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) + run_outputs = onnx_session.run(None, self.inputs) + if hasattr(onnx_session, "get_outputs"): + output_names = [o.name for o in onnx_session.get_outputs()] + elif hasattr(onnx_session, "output_names"): + output_names = onnx_session.output_names + else: + raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") + expected_outs = [self.outputs[name] for name in output_names] + _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) @dataclasses.dataclass @@ -1099,7 +1206,7 @@ class GraphInfo: mismatch_error: Optional[AssertionError] = dataclasses.field( default=None, init=False ) - pt_outs: Optional[Union[_NumericType, Sequence[_NumericType]]] = dataclasses.field( + pt_outs: Optional[Sequence[_NumericType]] = dataclasses.field( default=None, init=False ) upper_graph_info: Optional[GraphInfo] = dataclasses.field(default=None, init=False) @@ -1175,16 +1282,19 @@ class GraphInfo: else: print(" No mismatch ".center(80, "=")) - def has_mismatch(self): + @_beartype.beartype + def has_mismatch(self) -> bool: """Return True if the subgraph has output mismatch between torch and ONNX.""" return self.mismatch_error is not None + @_beartype.beartype def essential_node_count(self) -> int: """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" return sum( 1 for n in self.graph.nodes() if n.kind()
not in self._EXCLUDED_NODE_KINDS ) + @_beartype.beartype def essential_node_kinds(self) -> Set[str]: """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" return { @@ -1193,7 +1303,8 @@ class GraphInfo: if n.kind() not in self._EXCLUDED_NODE_KINDS } - def all_mismatch_leaf_graph_info(self) -> List[GraphInfo]: + @_beartype.beartype + def all_mismatch_leaf_graph_info(self) -> List["GraphInfo"]: """Return a list of all leaf `GraphInfo` objects that have mismatch.""" if not self.has_mismatch(): return [] @@ -1215,7 +1326,8 @@ class GraphInfo: return results - def find_partition(self, id: str) -> Optional[GraphInfo]: + @_beartype.beartype + def find_partition(self, id: str) -> Optional["GraphInfo"]: """Find the `GraphInfo` object with the given id.""" if id == self.id: return self @@ -1227,6 +1339,48 @@ class GraphInfo: return self.lower_graph_info.find_partition(id) return None + @_beartype.beartype + def export_repro( + self, repro_dir: Optional[str] = None, name: Optional[str] = None + ) -> str: + """Export the subgraph to ONNX along with the input/output data for repro. + + The repro directory will contain the following files: + + dir + ├── test_{name} + │ ├── model.onnx + │ └── test_data_set_0 + │ ├── input_0.pb + │ ├── input_1.pb + │ ├── output_0.pb + │ └── output_1.pb + + Args: + repro_dir: The directory to export the repro files to. Defaults to current + working directory if None. + name: An optional name for the test case folder: "test_{name}". + + Returns: + The path to the exported repro directory.
+ """ + + if repro_dir is None: + repro_dir = os.getcwd() + repro_dir = os.path.join(repro_dir, "onnx_debug") + + onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( + self.graph, self.export_options, self.params_dict + ) + + proto, _ = _onnx_proto_from_onnx_graph( + onnx_graph, self.export_options, onnx_params_dict + ) + return OnnxTestCaseRepro.create_test_case_repro( + proto, self.input_args, self.pt_outs, repro_dir, name + ) + + @_beartype.beartype def _graph_partition_pivot(self) -> int: """Find the pivot index to partition the graph. @@ -1249,6 +1403,7 @@ class GraphInfo: return included_node_indices[half_idx] + 1 return -1 + @_beartype.beartype def _partition_upper_graph(self) -> torch.Graph: pivot = self._graph_partition_pivot() if pivot == -1: @@ -1292,6 +1447,7 @@ class GraphInfo: return graph + @_beartype.beartype def _partition_lower_graph(self) -> torch.Graph: pivot = self._graph_partition_pivot() if pivot == -1: @@ -1346,6 +1502,7 @@ class GraphInfo: return graph + @_beartype.beartype def _partition_node( self, node: torch.Node, @@ -1384,6 +1541,7 @@ class GraphInfo: ): covered_bridge_values.add(process_bridge_value(output)) + @_beartype.beartype def _partition_nodes( self, graph: torch.Graph, @@ -1423,18 +1581,17 @@ class GraphInfo: complete_lower_nodes_set, ) + @_beartype.beartype def _bridge_kwargs(self): pt_outs = self.pt_outs - if pt_outs is None: - raise RuntimeError("pt_outs is not set") - if not isinstance(pt_outs, (list, tuple)): - pt_outs = [pt_outs] # type: ignore[list-item] graph_outputs = list(self.graph.outputs()) + assert pt_outs is not None assert len(graph_outputs) == len( pt_outs ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} + @_beartype.beartype def _args_and_params_for_partition_graph( self, graph: torch.Graph, @@ -1451,14 +1608,10 @@ class GraphInfo: ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" return args, 
params + @_beartype.beartype def verify_export( self, options: VerificationOptions - ) -> Tuple[ - Optional[AssertionError], - torch.Graph, - Union[_NumericType, Sequence[_NumericType]], - Union[_NumericType, Sequence[_NumericType]], - ]: + ) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]: return verify_aten_graph( self.graph, input_args=self.input_args, @@ -1467,6 +1620,7 @@ class GraphInfo: verification_options=options, ) + @_beartype.beartype def find_mismatch( self, options: Optional[VerificationOptions] = None, @@ -1536,6 +1690,7 @@ class GraphInfo: self.lower_graph_info.find_mismatch(options) +@_beartype.beartype def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]: all_nodes = set(nodes) for n in nodes: @@ -1544,12 +1699,14 @@ def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]: return all_nodes +@_beartype.beartype def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: if any(use.user in nodes for use in value.uses()): return True return False +@_beartype.beartype def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: for output in node.outputs(): if _has_uses_by_nodes(output, nodes): @@ -1557,6 +1714,7 @@ def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: return False +@_beartype.beartype def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: return value.node() in nodes