mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: See https://github.com/pytorch/pytorch/issues/42919 Pull Request resolved: https://github.com/pytorch/pytorch/pull/53349 Reviewed By: malfet Differential Revision: D27039089 Pulled By: bugra fbshipit-source-id: 8063dc184248604506a8dbb1bcb73da8ec85bb18
73 lines
3.1 KiB
Python
73 lines
3.1 KiB
Python
import unittest
|
|
import torch
|
|
|
|
import copy
|
|
|
|
import test_pytorch_onnx_onnxruntime
|
|
from test_pytorch_onnx_onnxruntime import TestONNXRuntime
|
|
from torch.onnx import utils, OperatorExportTypes, TrainingMode
|
|
from torch.onnx.utils import _validate_dynamic_axes
|
|
from torch.onnx.symbolic_helper import (_set_opset_version, _set_operator_export_type,
|
|
_set_onnx_shape_inference, _set_training_mode,
|
|
_is_tensor_list, _is_tensor, _is_none)
|
|
|
|
|
|
def verify_inferred_shape(graph):
|
|
# Check every node in graph has type properly assigned.
|
|
for n in graph.nodes():
|
|
for out in n.outputs():
|
|
if not _is_tensor_list(out) and not _is_tensor(out) and not _is_none(out):
|
|
raise RuntimeError("Output of node is neither type Tensor nor type list of Tensor: ", out)
|
|
if _is_tensor(out) and out.type().scalarType() is None:
|
|
raise RuntimeError("Output of node does not have type assigned", out)
|
|
if _is_tensor(out) and out.type().dim() is None:
|
|
raise RuntimeError("Output of node does not have shape assigned", out)
|
|
|
|
|
|
def run_model_test(self, model, batch_size=2, state_dict=None,
|
|
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
|
|
example_outputs=None, do_constant_folding=True,
|
|
dynamic_axes=None, test_with_inputs=None,
|
|
input_names=None, output_names=None,
|
|
fixed_batch_size=False):
|
|
model.eval()
|
|
|
|
if input is None:
|
|
input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
|
|
|
|
with torch.no_grad():
|
|
if isinstance(input, torch.Tensor):
|
|
input = (input,)
|
|
# In-place operators will update input tensor data as well.
|
|
# Thus inputs are replicated before every forward call.
|
|
input_copy = copy.deepcopy(input)
|
|
output = model(*input_copy)
|
|
if isinstance(output, torch.Tensor):
|
|
output = (output,)
|
|
|
|
_set_opset_version(self.opset_version)
|
|
_set_operator_export_type(OperatorExportTypes.ONNX)
|
|
_set_onnx_shape_inference(True)
|
|
_set_training_mode(False)
|
|
if dynamic_axes is None:
|
|
dynamic_axes = {}
|
|
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
|
|
|
|
input_copy = copy.deepcopy(input)
|
|
graph, _, _ = utils._model_to_graph(model, input_copy,
|
|
input_names=input_names,
|
|
output_names=output_names,
|
|
operator_export_type=OperatorExportTypes.ONNX,
|
|
example_outputs=output,
|
|
do_constant_folding=do_constant_folding,
|
|
training=TrainingMode.EVAL,
|
|
dynamic_axes=dynamic_axes)
|
|
verify_inferred_shape(graph)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
TestONNXRuntime.opset_version = 12
|
|
test_pytorch_onnx_onnxruntime.run_model_test = run_model_test
|
|
|
|
unittest.main()
|