Files
pytorch/test/onnx/test_utility_funs.py
Lara Haidar f4d9bfaa4d Support Exports to Multiple ONNX Opset (#19294)
Summary:
Support exporting multiple ONNX opsets (more specifically opset 10 for now), following the proposal in https://gist.github.com/spandantiwari/99700e60919c43bd167838038d20f353.
And add support for custom ops (merge with https://github.com/pytorch/pytorch/pull/18297).

This PR will be followed by another PR containing the changes related to testing the ops for different opsets.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19294

Reviewed By: zrphercule

Differential Revision: D15043951

Pulled By: houseroad

fbshipit-source-id: d336fc35b8827145639137bc348ae07e3c14bb1c
2019-05-10 18:37:12 -07:00

166 lines
6.1 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
from test_pytorch_common import TestCase, run_tests
import torch
import torch.onnx
from torch.onnx import utils
from torch.onnx.symbolic_helper import _set_opset_version
import onnx
import io
import copy
class TestUtilityFuns(TestCase):
def test_is_in_onnx_export(self):
test_self = self
class MyModule(torch.nn.Module):
def forward(self, x):
test_self.assertTrue(torch.onnx.is_in_onnx_export())
raise ValueError
return x + 1
x = torch.randn(3, 4)
f = io.BytesIO()
try:
torch.onnx.export(MyModule(), x, f)
except ValueError:
self.assertFalse(torch.onnx.is_in_onnx_export())
def test_constant_fold_transpose(self):
class TransposeModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.transpose(a, 1, 0)
return b + x
_set_opset_version(9)
x = torch.ones(3, 2)
graph, _, __ = utils._model_to_graph(TransposeModule(), (x, ),
do_constant_folding=True,
_disable_torch_constant_prop=True)
for node in graph.nodes():
assert node.kind() != "onnx::Transpose"
assert node.kind() != "onnx::Cast"
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
def test_constant_fold_slice(self):
class SliceModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.narrow(a, 0, 0, 1)
return b + x
_set_opset_version(9)
x = torch.ones(1, 3)
graph, _, __ = utils._model_to_graph(SliceModule(), (x, ),
do_constant_folding=True,
_disable_torch_constant_prop=True)
for node in graph.nodes():
assert node.kind() != "onnx::Slice"
assert node.kind() != "onnx::Cast"
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
def test_constant_fold_unsqueeze(self):
class UnsqueezeModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.unsqueeze(a, 0)
return b + x
_set_opset_version(9)
x = torch.ones(1, 2, 3)
graph, _, __ = utils._model_to_graph(UnsqueezeModule(), (x, ),
do_constant_folding=True,
_disable_torch_constant_prop=True)
for node in graph.nodes():
assert node.kind() != "onnx::Unsqueeeze"
assert node.kind() != "onnx::Cast"
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
def test_constant_fold_concat(self):
class ConcatModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1., 2., 3.]])
b = torch.tensor([[4., 5., 6.]])
c = torch.cat((a, b), 0)
return b + c
_set_opset_version(9)
x = torch.ones(2, 3)
graph, _, __ = utils._model_to_graph(ConcatModule(), (x, ),
do_constant_folding=True,
_disable_torch_constant_prop=True)
for node in graph.nodes():
assert node.kind() != "onnx::Concat"
assert node.kind() != "onnx::Cast"
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
def test_constant_fold_lstm(self):
class GruNet(torch.nn.Module):
def __init__(self):
super(GruNet, self).__init__()
self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)
def forward(self, input, initial_state):
return self.mygru(input, initial_state)
_set_opset_version(9)
input = torch.randn(5, 3, 7)
h0 = torch.randn(1, 3, 3)
graph, _, __ = utils._model_to_graph(GruNet(), (input, h0),
do_constant_folding=True)
for node in graph.nodes():
assert node.kind() != "onnx::Slice"
assert node.kind() != "onnx::Concat"
assert node.kind() != "onnx::Unsqueeze"
assert len(list(graph.nodes())) == 3
def test_constant_fold_transpose_matmul(self):
class MatMulNet(torch.nn.Module):
def __init__(self):
super(MatMulNet, self).__init__()
self.B = torch.nn.Parameter(torch.ones(5, 3))
def forward(self, A):
return torch.matmul(A, torch.transpose(self.B, -1, -2))
_set_opset_version(9)
A = torch.randn(2, 3)
graph, _, __ = utils._model_to_graph(MatMulNet(), (A),
do_constant_folding=True)
for node in graph.nodes():
assert node.kind() != "onnx::Transpose"
assert len(list(graph.nodes())) == 1
def test_strip_doc_string(self):
class MyModule(torch.nn.Module):
def forward(self, input):
return torch.exp(input)
x = torch.randn(3, 4)
def is_model_stripped(f, strip_doc_string=None):
if strip_doc_string is None:
torch.onnx.export(MyModule(), x, f)
else:
torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string)
model = onnx.load(io.BytesIO(f.getvalue()))
model_strip = copy.copy(model)
onnx.helper.strip_doc_string(model_strip)
return model == model_strip
# test strip_doc_string=True (default)
self.assertTrue(is_model_stripped(io.BytesIO()))
# test strip_doc_string=False
self.assertFalse(is_model_stripped(io.BytesIO(), False))
if __name__ == '__main__':
run_tests()