[ONNX] Add repro export from GraphInfo (#89947)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89947
Approved by: https://github.com/justinchuby
Author: BowenBao
Date: 2022-12-09 10:18:19 -08:00
Committed by: PyTorch MergeBot
Parent: 525c33c09f
Commit: f258753799
4 changed files with 467 additions and 154 deletions
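In outline, the repro-export workflow this commit adds looks like the following sketch (mirroring the new test below). Here `graph_info` stands in for a verification.GraphInfo whose find_mismatch pass has already run; everything else comes from this diff:

    import tempfile

    from torch.onnx import verification

    # `graph_info` is assumed to come from a prior mismatch search.
    leaves = graph_info.all_mismatch_leaf_graph_info()
    if leaves:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Export the smallest mismatching subgraph plus its input/output data.
            repro_dir = leaves[0].export_repro(temp_dir)
            # Replaying the repro raises AssertionError on the recorded mismatch.
            options = verification.VerificationOptions()
            verification.OnnxTestCaseRepro(repro_dir).validate(options)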


@@ -3,7 +3,4 @@
==================================== Tree: =====================================
1 ✓
id:
-=========================== Mismatch leaf subgraphs: ===========================
-[]
-============================= Mismatch node kinds: =============================
-{}
+============================== No mismatch found. ==============================


@@ -2,6 +2,7 @@
import contextlib
import io
import tempfile
import unittest
import numpy as np
@@ -272,6 +273,17 @@ class TestFindMismatch(pytorch_test_common.ExportTestCase):
)
self.assertExpected(f.getvalue())
def test_export_repro_for_mismatch(self):
mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
self.assertTrue(len(mismatch_leaves) > 0)
leaf_info = mismatch_leaves[0]
with tempfile.TemporaryDirectory() as temp_dir:
repro_dir = leaf_info.export_repro(temp_dir)
with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
options = verification.VerificationOptions(backend=self.onnx_backend)
verification.OnnxTestCaseRepro(repro_dir).validate(options)
if __name__ == "__main__":
common_utils.run_tests()


@@ -1,9 +1,11 @@
"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""
+import glob
import io
import os
+import shutil
import zipfile
-from typing import List, Mapping, Set, Union
+from typing import Any, List, Mapping, Set, Tuple, Union
import torch
import torch.jit._trace
@@ -12,6 +14,150 @@ from torch.onnx import _constants, _exporter_states, errors
from torch.onnx._internal import _beartype, jit_utils, registration
@_beartype.beartype
def export_as_test_case(
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
"""Export an ONNX model as a self contained ONNX test case.
The test case contains the model and the inputs/outputs data. The directory structure
is as follows:
dir
├── test_<name>
│ ├── model.onnx
│ └── test_data_set_0
│ ├── input_0.pb
│ ├── input_1.pb
│ ├── output_0.pb
│ └── output_1.pb
Args:
model_bytes: The ONNX model in bytes.
inputs_data: The inputs data, nested data structure of numpy.ndarray.
outputs_data: The outputs data, nested data structure of numpy.ndarray.
name: The name of the test case, used to form the "test_<name>" directory.
dir: The directory under which the test case is created.
Returns:
The path to the test case directory.
"""
try:
import onnx
except ImportError:
raise ImportError(
"Export test case to ONNX format failed: Please install ONNX."
)
test_case_dir = os.path.join(dir, "test_" + name)
os.makedirs(test_case_dir, exist_ok=True)
_export_file(
model_bytes,
os.path.join(test_case_dir, "model.onnx"),
_exporter_states.ExportTypes.PROTOBUF_FILE,
{},
)
data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
if os.path.exists(data_set_dir):
shutil.rmtree(data_set_dir)
os.makedirs(data_set_dir)
proto = onnx.load_from_string(model_bytes)
for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)):
export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb"))
return test_case_dir
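As illustration, a hedged sketch of calling this helper; the toy Linear model, its inputs, and the output directory are assumptions, not part of this diff:

    import io

    import torch

    model = torch.nn.Linear(2, 2)  # illustrative toy model
    x = torch.randn(1, 2)

    buf = io.BytesIO()
    torch.onnx.export(model, (x,), buf)  # serialize the model to ONNX bytes

    test_dir = export_as_test_case(
        buf.getvalue(),
        [x.numpy()],                  # inputs_data
        [model(x).detach().numpy()],  # outputs_data
        "linear",                     # name, yields "test_linear"
        ".",                          # dir
    )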
@_beartype.beartype
def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
"""Load a self contained ONNX test case from a directory.
The test case must contain the model and the inputs/outputs data. The directory structure
should be as follows:
dir
├── test_<name>
│ ├── model.onnx
│ └── test_data_set_0
│ ├── input_0.pb
│ ├── input_1.pb
│ ├── output_0.pb
│ └── output_1.pb
Args:
dir: The directory containing the test case.
Returns:
model_bytes: The ONNX model in bytes.
inputs: The inputs data, mapping from input name to numpy.ndarray.
outputs: The outputs data, mapping from output name to numpy.ndarray.
"""
try:
import onnx
from onnx import numpy_helper
except ImportError:
raise ImportError(
"Load test case from ONNX format failed: Please install ONNX."
)
with open(os.path.join(dir, "model.onnx"), "rb") as f:
model_bytes = f.read()
test_data_dir = os.path.join(dir, "test_data_set_0")
inputs = {}
input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
for input_file in input_files:
tensor = onnx.load_tensor(input_file)
inputs[tensor.name] = numpy_helper.to_array(tensor)
outputs = {}
output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
for output_file in output_files:
tensor = onnx.load_tensor(output_file)
outputs[tensor.name] = numpy_helper.to_array(tensor)
return model_bytes, inputs, outputs
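A sketch of replaying a saved test case with onnxruntime (assumed installed); the directory path is hypothetical:

    import onnxruntime  # assumed available

    model_bytes, inputs, outputs = load_test_case("./test_linear")  # hypothetical path
    session = onnxruntime.InferenceSession(
        model_bytes, providers=["CPUExecutionProvider"]
    )
    actual = session.run(None, inputs)  # `inputs` maps input names to arrays
    # `actual` can then be compared against the recorded `outputs`.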
@_beartype.beartype
def export_data(data, value_info_proto, f: str) -> None:
"""Export data to ONNX protobuf format.
Args:
data: The data to export, nested data structure of numpy.ndarray.
value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto
determines how the data is stored.
f: The file to write the data to.
"""
try:
from onnx import numpy_helper
except ImportError:
raise ImportError("Export data to ONNX format failed: Please install ONNX.")
with open(f, "wb") as opened_file:
if value_info_proto.type.HasField("map_type"):
opened_file.write(
numpy_helper.from_dict(data, value_info_proto.name).SerializeToString()
)
elif value_info_proto.type.HasField("sequence_type"):
opened_file.write(
numpy_helper.from_list(data, value_info_proto.name).SerializeToString()
)
elif value_info_proto.type.HasField("optional_type"):
opened_file.write(
numpy_helper.from_optional(
data, value_info_proto.name
).SerializeToString()
)
else:
assert value_info_proto.type.HasField("tensor_type")
opened_file.write(
numpy_helper.from_array(data, value_info_proto.name).SerializeToString()
)
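For the common tensor case, the branch above reduces to this standalone snippet (the array and names are illustrative):

    import numpy as np
    from onnx import numpy_helper

    data = np.arange(4, dtype=np.float32)  # illustrative payload
    with open("input_0.pb", "wb") as f:
        f.write(numpy_helper.from_array(data, "x").SerializeToString())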
@_beartype.beartype
def _export_file(
model_bytes: bytes,


@@ -8,6 +8,7 @@ from __future__ import annotations
import contextlib
import copy
import dataclasses
import datetime
import difflib
import enum
import functools
@@ -47,6 +48,7 @@ _NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule]
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_InputKwargsType = Mapping[str, Any]
_OutputsType = Union[Sequence[_NumericType], Sequence]
class OnnxBackend(enum.Enum):
@@ -146,7 +148,7 @@ def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
@_beartype.beartype
-def _run_onnx(onnx_session, inputs):
+def _run_onnx(onnx_session, inputs) -> _OutputsType:
kw_inputs = {}
if inputs and isinstance(inputs[-1], dict):
kw_inputs = inputs[-1]
@@ -229,34 +231,14 @@ def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):
@_beartype.beartype
-def _compare_onnx_pytorch_outputs(
-onnx_outs: Union[Sequence[_NumericType], Sequence],
-pt_outs: Optional[Union[_NumericType, Sequence[_NumericType], Sequence, Dict]],
+def _compare_onnx_pytorch_outputs_in_np(
+onnx_outs: _OutputsType,
+pt_outs: _OutputsType,
options: VerificationOptions,
):
-"""
-Compare ONNX and PyTorch outputs.
-Args:
-onnx_outs: outputs from ONNX backend.
-pt_outs: outputs from PyTorch.
-options: options for verification.
-Raises:
-AssertionError: if outputs from ONNX model and PyTorch model are not
-equal up to specified precision.
-ValueError: if arguments provided are invalid.
-"""
-if options.ignore_none:
-# torch.jit._flatten filters None type
-pt_outs, _ = torch.jit._flatten(pt_outs)
-else:
-pt_outs = _inline_flatten_list([pt_outs], [])
-pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
-onnx_outs = _inline_flatten_list(onnx_outs, [])
assert len(onnx_outs) == len(
-pt_outs_np
-), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs_np)})"
+pt_outs
+), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
acceptable_error_percentage = options.acceptable_error_percentage
if acceptable_error_percentage and (
acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
@@ -265,7 +247,7 @@ def _compare_onnx_pytorch_outputs(
"If set, acceptable_error_percentage should be between 0.0 and 1.0"
)
-for ort_out, pt_out in zip(onnx_outs, pt_outs_np):
+for ort_out, pt_out in zip(onnx_outs, pt_outs):
try:
# TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
if not options.check_shape:
@@ -298,6 +280,35 @@ def _compare_onnx_pytorch_outputs(
raise
@_beartype.beartype
def _compare_onnx_pytorch_outputs(
onnx_outs: _OutputsType,
pt_outs: Any,
options: VerificationOptions,
):
"""
Compare ONNX and PyTorch outputs.
Args:
onnx_outs: outputs from ONNX backend.
pt_outs: outputs from PyTorch.
options: options for verification.
Raises:
AssertionError: if outputs from ONNX model and PyTorch model are not
equal up to specified precision.
ValueError: if arguments provided are invalid.
"""
if options.ignore_none:
# torch.jit._flatten filters None type
pt_outs, _ = torch.jit._flatten(pt_outs)
else:
pt_outs = _inline_flatten_list([pt_outs], [])
pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
onnx_outs = _inline_flatten_list(onnx_outs, [])
_compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)
@_beartype.beartype
def _prepare_input_for_pytorch(args, kwargs):
"""Prepare input for PyTorch model execution.
@@ -655,6 +666,101 @@ def _onnx_graph_from_model(
return onnx_graph
@_beartype.beartype
def _onnx_graph_from_aten_graph(
graph: torch.Graph,
export_options: _experimental.ExportOptions,
params_dict: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Graph, Dict[str, Any]]:
if params_dict is None:
params_dict = {}
operator_export_type = export_options.operator_export_type
dynamic_axes = export_options.dynamic_axes or {}
input_names = export_options.input_names
training = export_options.training
do_constant_folding = export_options.do_constant_folding
opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
do_constant_folding = utils._decide_constant_folding(
do_constant_folding, operator_export_type, training
)
# TODO: Below is doing aten graph to onnx. It should be abstracted as a
# function in torch/onnx/utils.py.
graph = graph.copy()
graph = utils._optimize_graph(
graph,
operator_export_type,
params_dict=params_dict,
dynamic_axes=dynamic_axes,
input_names=input_names,
)
if training is None or training == _C_onnx.TrainingMode.EVAL:
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
if (
do_constant_folding
and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
):
params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version)
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version)
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
# For ONNX opset < 9, constants only have three data types: float16, float, double.
# In this pass transform constants of other data types to float/double + cast operator.
if opset_version < 9:
_C._jit_pass_onnx_cast_all_constant_to_floating(graph)
params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
_C._jit_decay_packed_param_input_types(graph)
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
if export_options.verbose:
print("ONNX graph: ", graph)
return graph, params_dict
@_beartype.beartype
def _onnx_proto_from_onnx_graph(
onnx_graph: torch.Graph,
export_options: _experimental.ExportOptions,
params_dict: Dict[str, Any],
) -> Tuple[bytes, Mapping[str, bytes]]:
opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
dynamic_axes = export_options.dynamic_axes or {}
operator_export_type = export_options.operator_export_type
val_keep_init_as_ip = utils._decide_keep_init_as_input(
export_options.keep_initializers_as_inputs,
operator_export_type,
opset_version,
)
val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
custom_opsets = export_options.custom_opsets or {}
proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined]
params_dict,
opset_version,
dynamic_axes,
False,
operator_export_type,
not export_options.verbose,
val_keep_init_as_ip,
custom_opsets,
val_add_node_names,
"",
{},
)
return proto, export_map
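Taken together, these two private helpers split the old monolithic path in verify_aten_graph into composable steps, which is what makes export_repro below possible. Roughly, with `aten_graph` assumed to be a TorchScript graph of aten ops:

    from torch.onnx import _experimental

    export_options = _experimental.ExportOptions()  # default export settings
    onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
        aten_graph, export_options
    )
    proto, _export_map = _onnx_proto_from_onnx_graph(
        onnx_graph, export_options, onnx_params_dict
    )
    # `proto` now holds the serialized ONNX model bytes.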
@_beartype.beartype
def check_export_model_diff(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
@@ -795,41 +901,14 @@ def verify_aten_graph(
export_options: _experimental.ExportOptions,
params_dict: Optional[Dict[str, Any]] = None,
verification_options: Optional[VerificationOptions] = None,
-) -> Tuple[
-Optional[AssertionError],
-torch.Graph,
-Union[_NumericType, Sequence[_NumericType]],
-Union[_NumericType, Sequence[_NumericType]],
-]:
+) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
if verification_options is None:
verification_options = VerificationOptions()
-original_jit_graph = graph
-graph = graph.copy()
-operator_export_type = export_options.operator_export_type
-dynamic_axes = export_options.dynamic_axes
-if dynamic_axes is None:
-dynamic_axes = {}
-input_names = export_options.input_names
-training = export_options.training
-do_constant_folding = export_options.do_constant_folding
-opset_version = export_options.opset_version
if params_dict is None:
params_dict = {}
-if opset_version is None:
-opset_version = _constants.ONNX_DEFAULT_OPSET
-val_keep_init_as_ip = utils._decide_keep_init_as_input(
-export_options.keep_initializers_as_inputs,
-operator_export_type,
-opset_version,
-)
-val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
-do_constant_folding = utils._decide_constant_folding(
-do_constant_folding, operator_export_type, training
-)
+original_jit_graph = graph
+graph = graph.copy()
# Execute aten graph and get reference torch jit outputs.
graph_inputs = list(v for v in graph.inputs())
@@ -840,69 +919,18 @@ def verify_aten_graph(
jit_inputs = copy.deepcopy(jit_inputs)
jit_input_and_parameters = jit_inputs + tuple(weights)
jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters)  # type: ignore[attr-defined]
+if not isinstance(jit_outs, (list, tuple)):
+jit_outs = [jit_outs]
-# Convert aten graph to onnx.
-graph = utils._optimize_graph(
-graph,
-operator_export_type,
-params_dict=params_dict,
-dynamic_axes=dynamic_axes,
-input_names=input_names,
+# Convert aten graph to onnx graph.
+graph, onnx_params_dict = _onnx_graph_from_aten_graph(
+graph, export_options, params_dict
)
-# TODO(bowbao): Below is doing aten graph to onnx. It should be abstracted as a
-# function in torch/onnx/utils.py.
-if training is None or training == _C_onnx.TrainingMode.EVAL:
-params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
-if (
-do_constant_folding
-and GLOBALS.export_onnx_opset_version
->= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
-):
-params_dict = _C._jit_pass_onnx_constant_fold(
-graph, params_dict, GLOBALS.export_onnx_opset_version
-)
-_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
-if GLOBALS.onnx_shape_inference:
-_C._jit_pass_onnx_graph_shape_type_inference(
-graph, params_dict, GLOBALS.export_onnx_opset_version
-)
-params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
-# For ONNX opset < 9, constants only have three data types: float16, float, double.
-# In this pass transform constants of other data types to float/double + cast operator.
-if GLOBALS.export_onnx_opset_version < 9:
-_C._jit_pass_onnx_cast_all_constant_to_floating(graph)
-params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
-_C._jit_decay_packed_param_input_types(graph)
-_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
-if export_options.verbose:
-print("ONNX graph: ", graph)
-custom_opsets = export_options.custom_opsets
-if custom_opsets is None:
-custom_opsets = {}
+proto, export_map = _onnx_proto_from_onnx_graph(
+graph, export_options, onnx_params_dict
+)
model_f: Union[str, io.BytesIO] = io.BytesIO()
-proto, export_map, _, _ = graph._export_onnx(  # type: ignore[attr-defined]
-params_dict,
-opset_version,
-dynamic_axes,
-False,
-operator_export_type,
-not export_options.verbose,
-val_keep_init_as_ip,
-custom_opsets,
-val_add_node_names,
-"",
-{},
-)
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
onnx_proto_utils._export_file(proto, model_f, export_type, export_map)
@@ -933,7 +961,6 @@ def verify_aten_graph(
options=verification_options,
)
except AssertionError as e:
-# print("Has mismatch: ", e)
return e, graph, jit_outs, onnx_outs
return None, graph, jit_outs, onnx_outs
@@ -967,6 +994,7 @@ class GraphInfoPrettyPrinter:
self.upper_printer = None
self.lower_printer = None
@_beartype.beartype
def _total_rows(self) -> int:
if self.graph_info is None:
return 1
@@ -976,6 +1004,7 @@
)
return 2 # Two lines: node count + id.
@_beartype.beartype
def _node_count_segment_str(self) -> str:
if self.graph_info is None:
return "..."
@@ -989,16 +1018,19 @@
return f"{node_count} {'X' if has_mismatch else '✓'} {error_node_kind}"
@_beartype.beartype
def _graph_id_segment_str(self) -> str:
if self.graph_info is None:
return ""
return f"id: {self.graph_info.id}"
@_beartype.beartype
def _max_segment_columns(self) -> int:
return max(
map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
)
@_beartype.beartype
def _graph_segment_str_at_line(self, line: int) -> str:
"""Get the string representation of the graph segment at the given line."""
if line == 0:
@@ -1013,6 +1045,7 @@
return " " * self._max_segment_columns()
return ""
@_beartype.beartype
def _connector_segment_str_at_line(self, line: int) -> str:
"""Get the connector segment string at the given line."""
if self.upper_printer is None and self.lower_printer is None:
@@ -1029,6 +1062,7 @@
return " "
return ""
@_beartype.beartype
def _children_str_at_line(self, line: int) -> str:
"""Get the string representation of the children at the given line.
@@ -1050,6 +1084,7 @@
)
return ""
@_beartype.beartype
def _str_at_line(self, line: int) -> str:
"""Get the string representation of the graph at the given line."""
return (
@@ -1067,25 +1102,97 @@
total_rows = self._total_rows()
for line in range(total_rows):
print(self._str_at_line(line).rstrip())
-# Summarize leaf subgraphs with mismatch.
-print(" Mismatch leaf subgraphs: ".center(80, "="))
-print(
-[
-graph_info.id
-for graph_info in self.graph_info.all_mismatch_leaf_graph_info()
-]
-)
-# Summarize node kinds with mismatch.
-mismatch_node_kinds: Dict[str, int] = {}
-for graph_info in self.graph_info.all_mismatch_leaf_graph_info():
-node_kinds = graph_info.essential_node_kinds()
-if len(node_kinds) == 1:
-node_kind = node_kinds.pop()
-mismatch_node_kinds[node_kind] = (
-mismatch_node_kinds.get(node_kind, 0) + 1
-)
-print(" Mismatch node kinds: ".center(80, "="))
-print(mismatch_node_kinds)
+if self.graph_info.has_mismatch():
+# Summarize leaf subgraphs with mismatch.
+print(" Mismatch leaf subgraphs: ".center(80, "="))
+print(
+[
+graph_info.id
+for graph_info in self.graph_info.all_mismatch_leaf_graph_info()
+]
+)
+# Summarize node kinds with mismatch.
+mismatch_node_kinds: Dict[str, int] = {}
+for graph_info in self.graph_info.all_mismatch_leaf_graph_info():
+node_kinds = graph_info.essential_node_kinds()
+if len(node_kinds) == 1:
+node_kind = node_kinds.pop()
+mismatch_node_kinds[node_kind] = (
+mismatch_node_kinds.get(node_kind, 0) + 1
+)
+print(" Mismatch node kinds: ".center(80, "="))
+print(mismatch_node_kinds)
+else:
+print(" No mismatch found. ".center(80, "="))
+class OnnxTestCaseRepro:
+def __init__(self, repro_dir):
+self.repro_dir = repro_dir
+self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case(
+repro_dir
+)
@classmethod
@_beartype.beartype
def create_test_case_repro(
cls, proto: bytes, inputs, outputs, dir: str, name: Optional[str] = None
):
"""Create a repro under "{dir}/test_{name}" for an ONNX test case.
The test case contains the model and the inputs/outputs data. The directory
structure is as follows:
dir
├── test_<name>
│ ├── model.onnx
│ └── test_data_set_0
│ ├── input_0.pb
│ ├── input_1.pb
│ ├── output_0.pb
│ └── output_1.pb
Args:
proto: ONNX model proto.
inputs: Inputs to the model.
outputs: Outputs of the model.
dir: Directory to save the repro.
name: Name of the test case. If not specified, a name based on current time
will be generated.
Returns:
Path to the repro.
"""
if name is None:
name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
return onnx_proto_utils.export_as_test_case(
proto,
_to_numpy(inputs),
_to_numpy(outputs),
name,
dir,
)
@_beartype.beartype
def validate(self, options: VerificationOptions):
"""Run the ONNX test case with options.backend, and compare with the expected outputs.
Args:
options: Options for validation.
Raises:
AssertionError: if outputs from options.backend and expected outputs are not
equal up to specified precision.
"""
onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend)
run_outputs = onnx_session.run(None, self.inputs)
if hasattr(onnx_session, "get_outputs"):
output_names = [o.name for o in onnx_session.get_outputs()]
elif hasattr(onnx_session, "output_names"):
output_names = onnx_session.output_names
else:
raise ValueError(f"Unknown onnx session type: {type(onnx_session)}")
expected_outs = [self.outputs[name] for name in output_names]
_compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options)
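A sketch of re-checking a previously exported repro in a fresh session; the directory name is hypothetical:

    options = VerificationOptions()  # default comparison settings
    repro = OnnxTestCaseRepro("./onnx_debug/test_mismatch_repro")  # hypothetical dir
    repro.validate(options)  # raises AssertionError if outputs still diverge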
@dataclasses.dataclass
@@ -1099,7 +1206,7 @@ class GraphInfo:
mismatch_error: Optional[AssertionError] = dataclasses.field(
default=None, init=False
)
-pt_outs: Optional[Union[_NumericType, Sequence[_NumericType]]] = dataclasses.field(
+pt_outs: Optional[Sequence[_NumericType]] = dataclasses.field(
default=None, init=False
)
upper_graph_info: Optional[GraphInfo] = dataclasses.field(default=None, init=False)
@@ -1175,16 +1282,19 @@
else:
print(" No mismatch ".center(80, "="))
-def has_mismatch(self):
+@_beartype.beartype
+def has_mismatch(self) -> bool:
"""Return True if the subgraph has output mismatch between torch and ONNX."""
return self.mismatch_error is not None
@_beartype.beartype
def essential_node_count(self) -> int:
"""Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
return sum(
1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
)
@_beartype.beartype
def essential_node_kinds(self) -> Set[str]:
"""Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
return {
@@ -1193,7 +1303,8 @@
if n.kind() not in self._EXCLUDED_NODE_KINDS
}
-def all_mismatch_leaf_graph_info(self) -> List[GraphInfo]:
+@_beartype.beartype
+def all_mismatch_leaf_graph_info(self) -> List["GraphInfo"]:
"""Return a list of all leaf `GraphInfo` objects that have mismatch."""
if not self.has_mismatch():
return []
@@ -1215,7 +1326,8 @@
return results
-def find_partition(self, id: str) -> Optional[GraphInfo]:
+@_beartype.beartype
+def find_partition(self, id: str) -> Optional["GraphInfo"]:
"""Find the `GraphInfo` object with the given id."""
if id == self.id:
return self
@@ -1227,6 +1339,48 @@
return self.lower_graph_info.find_partition(id)
return None
@_beartype.beartype
def export_repro(
self, repro_dir: Optional[str] = None, name: Optional[str] = None
) -> str:
"""Export the subgraph to ONNX along with the input/output data for repro.
The repro directory will contain the following files:
dir
├── test_<name>
│ ├── model.onnx
│ └── test_data_set_0
│ ├── input_0.pb
│ ├── input_1.pb
│ ├── output_0.pb
│ └── output_1.pb
Args:
repro_dir: The directory to export the repro files to. Defaults to current
working directory if None.
name: An optional name for the test case folder: "test_{name}".
Returns:
The path to the exported repro directory.
"""
if repro_dir is None:
repro_dir = os.getcwd()
repro_dir = os.path.join(repro_dir, "onnx_debug")
onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
self.graph, self.export_options, self.params_dict
)
proto, _ = _onnx_proto_from_onnx_graph(
onnx_graph, self.export_options, onnx_params_dict
)
return OnnxTestCaseRepro.create_test_case_repro(
proto, self.input_args, self.pt_outs, repro_dir, name
)
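Combined with find_partition above, a single suspicious partition can be exported by id; a sketch where the id "01" and the name are illustrative:

    # Assumes find_mismatch() has populated `graph_info` and its subgraphs.
    leaf = graph_info.find_partition("01")  # illustrative partition id
    if leaf is not None and leaf.has_mismatch():
        # Defaults to "<cwd>/onnx_debug/test_my_mismatch".
        repro_dir = leaf.export_repro(name="my_mismatch")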
@_beartype.beartype
def _graph_partition_pivot(self) -> int:
"""Find the pivot index to partition the graph.
@@ -1249,6 +1403,7 @@
return included_node_indices[half_idx] + 1
return -1
@_beartype.beartype
def _partition_upper_graph(self) -> torch.Graph:
pivot = self._graph_partition_pivot()
if pivot == -1:
@@ -1292,6 +1447,7 @@
return graph
@_beartype.beartype
def _partition_lower_graph(self) -> torch.Graph:
pivot = self._graph_partition_pivot()
if pivot == -1:
@@ -1346,6 +1502,7 @@
return graph
@_beartype.beartype
def _partition_node(
self,
node: torch.Node,
@@ -1384,6 +1541,7 @@
):
covered_bridge_values.add(process_bridge_value(output))
@_beartype.beartype
def _partition_nodes(
self,
graph: torch.Graph,
@@ -1423,18 +1581,17 @@
complete_lower_nodes_set,
)
+@_beartype.beartype
def _bridge_kwargs(self):
pt_outs = self.pt_outs
-if pt_outs is None:
-raise RuntimeError("pt_outs is not set")
-if not isinstance(pt_outs, (list, tuple)):
-pt_outs = [pt_outs]  # type: ignore[list-item]
graph_outputs = list(self.graph.outputs())
+assert pt_outs is not None
assert len(graph_outputs) == len(
pt_outs
), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
@_beartype.beartype
def _args_and_params_for_partition_graph(
self,
graph: torch.Graph,
@@ -1451,14 +1608,10 @@
), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
return args, params
+@_beartype.beartype
def verify_export(
self, options: VerificationOptions
-) -> Tuple[
-Optional[AssertionError],
-torch.Graph,
-Union[_NumericType, Sequence[_NumericType]],
-Union[_NumericType, Sequence[_NumericType]],
-]:
+) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
return verify_aten_graph(
self.graph,
input_args=self.input_args,
@@ -1467,6 +1620,7 @@
verification_options=options,
)
@_beartype.beartype
def find_mismatch(
self,
options: Optional[VerificationOptions] = None,
@@ -1536,6 +1690,7 @@
self.lower_graph_info.find_mismatch(options)
@_beartype.beartype
def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
all_nodes = set(nodes)
for n in nodes:
@@ -1544,12 +1699,14 @@ def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
return all_nodes
@_beartype.beartype
def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
if any(use.user in nodes for use in value.uses()):
return True
return False
@_beartype.beartype
def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
for output in node.outputs():
if _has_uses_by_nodes(output, nodes):
@@ -1557,6 +1714,7 @@ def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
return False
@_beartype.beartype
def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
return value.node() in nodes