PyTorch export to ONNX Opset 7 and 8 - Cont (#22421)
Summary: This is an extension of the original PR https://github.com/pytorch/pytorch/pull/21765.
1. Increase the coverage of support for different opsets, comments, and blacklisting.
2. Add backend tests for both Caffe2 and ONNX Runtime on opset 7 and opset 8.
3. Reuse the Caffe2 ONNX model tests for ONNX Runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22421
Reviewed By: zrphercule
Differential Revision: D16225518
Pulled By: houseroad
fbshipit-source-id: 01ae3eed85111a83a0124e9e95512b80109d6aee
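The added opsets ride on the existing opset_version argument of torch.onnx.export. A minimal sketch of what the extended coverage enables (the module and shapes are illustrative, not part of the patch):

import io
import torch

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

x = torch.randn(3, 4)
for opset in [7, 8, 9, 10]:  # 7 and 8 are the newly supported stable opsets
    f = io.BytesIO()
    torch.onnx.export(TwoLayer(), x, f, opset_version=opset)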
Committed by: Facebook Github Bot
Parent: 9f8e2c067f
Commit: b3147bc674
@@ -55,6 +55,7 @@ pytest "${args[@]}" \
  'not (TestOperators and test_full_like) and not (TestOperators and test_zeros_like) and not (TestOperators and test_ones_like) and not (TestModels and test_vgg16) and not (TestModels and test_vgg16_bn) and not (TestModels and test_vgg19) and not (TestModels and test_vgg19_bn)' \
  --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
  --ignore "$top_dir/test/onnx/test_custom_ops.py" \
  --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
  "${test_paths[@]}"

# onnxruntime only supports py3
@@ -63,5 +64,6 @@ if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
  pip install --user onnxruntime
  pytest "${args[@]}" "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
  pytest "${args[@]}" "$top_dir/test/onnx/test_custom_ops.py"
  pytest "${args[@]}" "$top_dir/test/onnx/test_models_onnxruntime.py"
fi
test/onnx/test_models_onnxruntime.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import onnxruntime  # noqa

from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test


def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    opset_versions = opset_versions if opset_versions else [7, 8, 9, 10]
    for opset_version in opset_versions:
        self.opset_version = opset_version
        run_model_test(self, model, False,
                       input=inputs, rtol=rtol, atol=atol)


if __name__ == '__main__':
    TestModels.exportTest = exportTest
    unittest.main()
@@ -132,6 +132,45 @@ class TestONNXOpset(TestCase):
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_upsample(self):
        class MyModule(Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode='nearest')

        module = MyModule()
        ops8 = [{"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3},
                                                         {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6}]}]
        ops9 = [{"op_name" : "Constant"},
                {"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}]
        ops = {8 : ops8, 9 : ops9}
        x = torch.randn(2, 2, 2, 2)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_cast_constant(self):
        class MyModule(Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, x):
                return torch._dim_arange(x, 1)

        module = MyModule()
        ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
                 {"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
                 {"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
                 {"op_name" : "Range"}]
        ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
                 {"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
                 {"op_name" : "Range"}]
        ops = {8 : ops_8, 9 : ops_9}
        x = torch.ones(5, 6)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_slice(self):
        class MyModule(Module):
            def forward(self, x):
@@ -40,7 +40,7 @@ import onnx
import caffe2.python.onnx.backend as c2

from test_pytorch_common import skipIfTravis, skipIfNoLapack, skipIfNoCuda
from test_pytorch_common import skipIfUnsupportedOpsetVersion
from test_pytorch_common import skipIfUnsupportedOpsetVersion, skipIfUnsupportedMinOpsetVersion
import verify

skip = unittest.skip

@@ -664,6 +664,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9)
        self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_erf(self):
        class MyModel(torch.nn.Module):
            def __init__(self):

@@ -807,16 +808,19 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        x = torch.randn(20, 16, 50, 44, 30, requires_grad=True)
        self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_adaptive_max_pool1D(self):
        model = torch.nn.AdaptiveMaxPool1d((5))
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_adaptive_max_pool2D(self):
        model = torch.nn.AdaptiveMaxPool2d((5, 4))
        x = torch.randn(20, 16, 50, 32, requires_grad=True)
        self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_adaptive_max_pool3D(self):
        model = torch.nn.AdaptiveMaxPool3d((5, 4, 3))
        x = torch.randn(20, 16, 50, 44, 30, requires_grad=True)

@@ -993,7 +997,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        self.run_model_test(model, train=False, input=(x),
                            batch_size=BATCH_SIZE, use_gpu=False)

    @skipIfUnsupportedOpsetVersion([10])
    @skipIfUnsupportedOpsetVersion([7, 8, 10])
    def test_interpolate_upsample_dynamic_sizes(self):
        class MyModel(torch.nn.Module):
            def __init__(self):

@@ -1291,6 +1295,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
                            use_gpu=False, example_outputs=FullClass()(x))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_where_functional(self):
        class WhereFunctional(torch.nn.Module):
            def forward(self, x):

@@ -1299,6 +1304,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        x = torch.randn(3, 4)
        self.run_model_test(WhereFunctional(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_where_method(self):
        class WhereMethod(torch.nn.Module):
            def forward(self, x):

@@ -1353,6 +1359,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        self.run_model_test(RsubModel(), train=False, input=(x,),
                            batch_size=BATCH_SIZE, use_gpu=False)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_isnan(self):
        class IsNaNModel(torch.nn.Module):
            def forward(self, input):

@@ -1361,6 +1368,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        x = torch.tensor([1.0, float('nan'), 2.0])
        self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter(self):
        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices, values):

@@ -1396,6 +1404,23 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        x = torch.randn(4, 4, requires_grad=True)
        self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)

    def test_max_keepdim(self):
        class MaxModel(torch.nn.Module):
            def forward(self, input):
                return torch.max(input, dim=1, keepdim=True)

        x = torch.randn(4, 4, requires_grad=True)
        self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)

    def test_max_tensors(self):
        class MaxModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.max(input, other)

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        self.run_model_test(MaxModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)

    def test_min(self):
        class MinModel(torch.nn.Module):
            def forward(self, input):

@@ -1841,6 +1866,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        x = torch.randn(1, 2, 3)
        self.run_model_test(DropoutModel(), train=False, input=x, batch_size=BATCH_SIZE)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_while(self):
        class WhileModel(torch.jit.ScriptModule):
            @torch.jit.script_method

@@ -1901,6 +1927,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
                            example_outputs=(outputs,))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_nested_loops(self):
        class NestedLoopsModel(torch.jit.ScriptModule):
            @torch.jit.script_method

@@ -2013,6 +2040,24 @@ TestCaffe2BackendEmbed_opset9 = type(str("TestCaffe2BackendEmbed_opset9"),
                                     (unittest.TestCase,),
                                     dict(TestCaffe2Backend_opset9.__dict__, embed_params=True))

# opset 7 tests
TestCaffe2Backend_opset7 = type(str("TestCaffe2Backend_opset7"),
                                (unittest.TestCase,),
                                dict(TestCaffe2Backend_opset9.__dict__, opset_version=7))
TestCaffe2BackendEmbed_opset7 = type(str("TestCaffe2BackendEmbed_opset7"),
                                     (unittest.TestCase,),
                                     dict(TestCaffe2Backend_opset9.__dict__,
                                          embed_params=True, opset_version=7))

# opset 8 tests
TestCaffe2Backend_opset8 = type(str("TestCaffe2Backend_opset8"),
                                (unittest.TestCase,),
                                dict(TestCaffe2Backend_opset9.__dict__, opset_version=8))
TestCaffe2BackendEmbed_opset8 = type(str("TestCaffe2BackendEmbed_opset8"),
                                     (unittest.TestCase,),
                                     dict(TestCaffe2Backend_opset9.__dict__,
                                          embed_params=True, opset_version=8))

# opset 10 tests
TestCaffe2Backend_opset10 = type(str("TestCaffe2Backend_opset10"),
                                 (unittest.TestCase,),
@@ -6,58 +6,97 @@ from __future__ import unicode_literals
import unittest
import onnxruntime  # noqa
import torch

import numpy as np
import io

from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion


import model_defs.word_language_model as word_language_model


def run_model_test(self, model, train, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True):
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)
        output = model(*input)
        if isinstance(output, torch.Tensor):
            output = (output,)

        # export the model to ONNX
        f = io.BytesIO()
        torch.onnx.export(model, input, f,
                          opset_version=self.opset_version,
                          example_outputs=output)

        input, _ = torch.jit._flatten(input)
        output, _ = torch.jit._flatten(output)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            else:
                return tensor.cpu().numpy()

        inputs = list(map(to_numpy, input))
        outputs = list(map(to_numpy, output))

        # compute onnxruntime output prediction
        ort_sess = onnxruntime.InferenceSession(f.getvalue())
        ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
        ort_outs = ort_sess.run(None, ort_inputs)

        # compare onnxruntime and PyTorch results
        assert len(outputs) == len(ort_outs), "number of outputs differ"

        [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]


class TestONNXRuntime(unittest.TestCase):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    opset_version = _export_onnx_opset_version

    def run_test(self, model, inputs, rtol=1e-05, atol=1e-08):
        outputs = model(inputs) if isinstance(inputs, torch.Tensor) else model(*inputs)
    def run_test(self, model, input, rtol=1e-3, atol=1e-7):
        run_model_test(self, model, False, None,
                       input=input, rtol=rtol, atol=atol)

        # export the model to ONNX
        f = io.BytesIO()
        torch.onnx.export(model, inputs, f,
                          opset_version=self.opset_version,
                          example_outputs=outputs)
    def run_word_language_model(self, model_name):
        ntokens = 50
        emsize = 5
        nhid = 5
        nlayers = 5
        dropout = 0.2
        tied = False
        batchsize = 5
        model = word_language_model.RNNModel(model_name, ntokens, emsize,
                                             nhid, nlayers, dropout, tied,
                                             batchsize)
        x = torch.arange(0, ntokens).long().view(-1, batchsize)
        # Only support CPU version, since tracer is not working in GPU RNN.
        self.run_test(model, (x, model.hidden))

        def get_numpy_value_at_index(t, i):
            return t[i].detach().numpy() if t[i].requires_grad else t[i].numpy()
    def test_word_language_model_RNN_TANH(self):
        self.run_word_language_model("RNN_TANH")

        def get_numpy_value(t):
            return t.detach().numpy() if t.requires_grad else t.numpy()
    def test_word_language_model_RNN_RELU(self):
        self.run_word_language_model("RNN_RELU")

        def get_ort_inputs():
            ort_inputs = {}
            if isinstance(inputs, torch.Tensor):
                ort_inputs = {ort_sess.get_inputs()[0].name: get_numpy_value(inputs)}
            else:
                for i in range(0, len(outputs)):
                    ort_inputs[ort_sess.get_inputs()[i].name] = get_numpy_value_at_index(inputs, i)
            return ort_inputs
    def test_word_language_model_LSTM(self):
        self.run_word_language_model("LSTM")

        # compute onnxruntime output prediction
        ort_sess = onnxruntime.InferenceSession(f.getvalue())
        ort_inputs = get_ort_inputs()
        ort_outs = ort_sess.run(None, ort_inputs)

        # compare onnxruntime and PyTorch results
        assert (isinstance(outputs, torch.Tensor) and len(ort_outs) == 1) or \
            len(outputs) == len(ort_outs), \
            "number of outputs differ"

        if isinstance(outputs, torch.Tensor):
            assert np.allclose(get_numpy_value(outputs), ort_outs[0],
                               rtol=rtol, atol=atol)
        else:
            for i in range(0, len(outputs)):
                assert np.allclose(get_numpy_value_at_index(outputs, i), ort_outs[i],
                                   rtol=rtol, atol=atol)
    def test_word_language_model_GRU(self):
        self.run_word_language_model("GRU")

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_trace(self):
        class FullModel(torch.nn.Module):
            def forward(self, x):

@@ -66,6 +105,7 @@ class TestONNXRuntime(unittest.TestCase):
        x = torch.tensor(12)
        self.run_test(FullModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_script(self):
        class FullModelScripting(torch.jit.ScriptModule):
            @torch.jit.script_method

@@ -80,6 +120,12 @@ class TestONNXRuntime(unittest.TestCase):
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_with_indices(self):
        model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
        x = torch.randn(20, 16, 50)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_maxpool_dilation(self):
        model = torch.nn.MaxPool1d(2, stride=1, dilation=2)

@@ -117,13 +163,23 @@ class TestONNXRuntime(unittest.TestCase):
        x = torch.tensor(np.arange(6.0).reshape(2, 3))
        self.run_test(MyModule(), x)

    def test_interpolate(self):
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_interpolate_scale(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2)
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(MyModel(), x)

    # NOTE: Supported in onnxruntime master, enable this after 0.5 release.
    @skipIfUnsupportedOpsetVersion([10])
    def test_interpolate_output_size(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, mode="nearest", size=(6, 8))
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(MyModel(), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_interpolate_downsample(self):
        class MyModel(torch.nn.Module):

@@ -168,10 +224,162 @@ class TestONNXRuntime(unittest.TestCase):
        x = torch.randn(4, 4, requires_grad=True)
        self.run_test(ReduceLogSumExpModel(), x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_adaptive_max_pool(self):
        model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, x)

    def test_maxpool_2d(self):
        model = torch.nn.MaxPool2d(5, padding=(1, 2))
        x = torch.randn(1, 20, 16, 50, requires_grad=True)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_max_tensors(self):
        class MaxModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.max(input, other)

        model = MaxModel()
        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 1, requires_grad=True)
        self.run_test(model, (x, y))

    def test_gt(self):
        class GreaterModel(torch.nn.Module):
            def forward(self, input, other):
                return input > other

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(GreaterModel(), (x, y))

        x = torch.randint(10, (3, 4), dtype=torch.int32)
        y = torch.randint(10, (3, 4), dtype=torch.int32)
        self.run_test(GreaterModel(), (x, y))

    def test_gt_scalar(self):
        class GreaterModel(torch.nn.Module):
            def forward(self, input):
                return input > 1

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(GreaterModel(), x)

        x = torch.randint(10, (3, 4), dtype=torch.int32)
        self.run_test(GreaterModel(), x)

    def test_lt(self):
        class LessModel(torch.nn.Module):
            def forward(self, input, other):
                return input < other

        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.run_test(LessModel(), (x, y))

        x = torch.randint(10, (3, 4), dtype=torch.int32)
        y = torch.randint(10, (3, 4), dtype=torch.int32)
        self.run_test(LessModel(), (x, y))

    def test_matmul(self):
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(4, 5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (3, 4))
        y = torch.randint(10, (4, 5))
        self.run_test(MatmulModel(), (x, y))

    def test_matmul_batch(self):
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        x = torch.randn(2, 3, 4, requires_grad=True)
        y = torch.randn(2, 4, 5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 4, 5))
        self.run_test(MatmulModel(), (x, y))

    def test_view(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input):
                return input.view(4, 24)

        x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
        self.run_test(ViewModel(), x)

    def test_flatten(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    def test_flatten2d(self):
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input, 1)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories(self):
        class TensorFactory(torch.nn.Module):
            def forward(self, x):
                return torch.zeros(x.size()) + torch.ones(x.size())

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories_script(self):
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_like_factories_script(self):
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
                ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
                return zeros + ones

        x = torch.randn(2, 3, 4)
        self.run_test(TensorFactory(), x)


# opset 7 tests
TestONNXRuntime_opset7 = type(str("TestONNXRuntime_opset7"),
                              (unittest.TestCase,),
                              dict(TestONNXRuntime.__dict__, opset_version=7))

# opset 8 tests
TestONNXRuntime_opset8 = type(str("TestONNXRuntime_opset8"),
                              (unittest.TestCase,),
                              dict(TestONNXRuntime.__dict__, opset_version=8))

# opset 10 tests
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
                               (unittest.TestCase,),
                               dict(TestONNXRuntime.__dict__, opset_version=10))


if __name__ == '__main__':
    unittest.main()
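The type() calls above stamp out one unittest.TestCase per opset: each clone copies TestONNXRuntime.__dict__ and overrides only opset_version, so every test method reruns under every opset. The same pattern in miniature (the names here are illustrative, not part of the patch):

import unittest

class Base(unittest.TestCase):
    opset_version = 9

    def test_version_is_supported(self):
        self.assertIn(self.opset_version, (7, 8, 9, 10))

# clone the suite with a single attribute overridden
Base_opset7 = type(str("Base_opset7"), (unittest.TestCase,),
                   dict(Base.__dict__, opset_version=7))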
@@ -232,6 +232,7 @@ def add_torch_libs():
    "torch/csrc/jit/init.cpp",
    "torch/csrc/jit/passes/inline_fork_wait.cpp",
    "torch/csrc/jit/passes/onnx.cpp",
    "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
    "torch/csrc/jit/passes/onnx/constant_fold.cpp",
    "torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp",
    "torch/csrc/jit/passes/onnx/peephole.cpp",
@@ -38,11 +38,11 @@
# Run Clang-Tidy
# The negative filters below are to exclude files that include onnx_pb.h or
# caffe2_pb.h, otherwise we'd have to build protos as part of this CI job.
time python tools/clang_tidy.py \
  --verbose \
  --paths torch/csrc/ \
  --diff "$BASE_BRANCH" \
  -g"-torch/csrc/jit/export.cpp" \
  -g"-torch/csrc/jit/import.cpp" \
  -g"-torch/csrc/jit/netdef_converter.cpp" \
time python tools/clang_tidy.py \
  --verbose \
  --paths torch/csrc/ \
  --diff "$BASE_BRANCH" \
  -g"-torch/csrc/jit/export.cpp" \
  -g"-torch/csrc/jit/import.cpp" \
  -g"-torch/csrc/jit/netdef_converter.cpp" \
  "$@"
@@ -74,6 +74,7 @@ set(TORCH_PYTHON_SRCS
  ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
  ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
  ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/peephole.cpp
  ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp
  ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/constant_fold.cpp
  ${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp
  ${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp
@@ -24,6 +24,7 @@
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>

@@ -112,6 +113,7 @@ void initJITBindings(PyObject* module) {
      .def("_jit_pass_onnx", ToONNX)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
      .def("_jit_pass_onnx_cast_all_constant_to_floating", CastAllConstantToFloating)
      .def(
          "_jit_pass_onnx_constant_fold",
          [](std::shared_ptr<Graph>& graph,
torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp (new file, 70 lines)
@@ -0,0 +1,70 @@
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>

namespace torch {
namespace jit {
namespace onnx {
using namespace ::c10::onnx;
}

// For ONNX opset < 9, the Constant operator supports only three data types:
// float16, float, and double. Constants of other data types are exported as
// float or double and then cast back to their original data type with a Cast node.
// That transformation is done in this pass.
// The motivation for doing this as a post-processing pass, as opposed to handling
// it in the symbolic functions, is that many Constant operators have already been
// removed during export before this step. Conversely, if the Cast were inserted in
// the symbolic functions, conversion of subsequent nodes would break whenever it
// depends on certain inputs being constant.
void CastAllConstantToFloating(Block* block) {
  auto graph = block->owningGraph();
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto node = *it;
    ++it;
    for (auto block : node->blocks()) {
      CastAllConstantToFloating(block);
    }

    if (node->kind() == onnx::Constant) {
      auto val = node->t(attr::value);
      at::ScalarType dtype = val.scalar_type();
      if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float &&
          dtype != at::ScalarType::Half) {
        int to_type;
        switch (val.scalar_type()) {
          case at::ScalarType::Byte:
          case at::ScalarType::Char:
          case at::ScalarType::Int:
          case at::ScalarType::Short:
          case at::ScalarType::Bool:
            to_type = 6; // ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
            val = val.to(at::ScalarType::Float);
            break;

          case at::ScalarType::Long:
            to_type = 7; // ::ONNX_NAMESPACE::TensorProto_DataType_INT64;
            val = val.to(at::ScalarType::Double);
            break;

          default:
            throw std::runtime_error("Unsupported types: complex, string");
        }
        // create a cast node
        node->removeAttribute(attr::value);
        node->t_(attr::value, val);
        Node* cast_node = graph->create(onnx::Cast, 1);
        cast_node->i_(attr::to, to_type);
        cast_node->insertAfter(node);
        // rewire all uses of the constant to take their input from the cast node
        node->outputs().at(0)->replaceAllUsesWith(cast_node->outputs().at(0));
        // feed the constant's output into the cast node
        cast_node->addInput(node->outputs().at(0));
      }
    }
  }
}

void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph) {
  CastAllConstantToFloating(graph->block());
}
} // namespace jit
} // namespace torch
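The observable effect of the pass: below opset 9, an integer constant leaves the exporter as a float or double Constant followed by a Cast back to the original integer type. A rough way to confirm this (a sketch, assuming the onnx Python package is installed; the module is illustrative):

import io
import onnx
import torch

class AddConst(torch.nn.Module):
    def forward(self, x):
        return x + torch.tensor([1, 2, 3], dtype=torch.long)

f = io.BytesIO()
torch.onnx.export(AddConst(), torch.zeros(3, dtype=torch.long), f, opset_version=8)
graph = onnx.load_from_string(f.getvalue()).graph
print([n.op_type for n in graph.node])  # expect a Constant followed by a Cast to INT64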
torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h (new file, 12 lines)
@@ -0,0 +1,12 @@
#pragma once

#include <torch/csrc/jit/ir.h>

#include <memory>

namespace torch {
namespace jit {
// see .cpp for docs
TORCH_API void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
@@ -351,7 +351,7 @@ void hackFixupPadPackedShapes(Block* graph) {
  }
}

void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
void fixDefaultRNNState(Graph* graph, Node* n, int input_index, int opset_version) {
  auto initial_state = n->inputs()[input_index];

  // The RNN code in pytorch accepts an optional hidden state. When it

@@ -420,21 +420,29 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
  concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
  concated_dims->addInput(hidden_size->outputs()[0]);

  Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
  constant_of_shape->insertBefore(n);
  constant_of_shape->addInput(concated_dims->outputs()[0]);
  n->replaceInput(input_index, constant_of_shape->outputs()[0]);
  if (opset_version < 9) {
    Node* constant_fill = graph->create(onnx::ConstantFill, 1);
    constant_fill->insertBefore(n);
    constant_fill->i_(attr::input_as_shape, 1);
    constant_fill->addInput(concated_dims->outputs()[0]);
    n->replaceInput(input_index, constant_fill->outputs()[0]);
  } else {
    Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
    constant_of_shape->insertBefore(n);
    constant_of_shape->addInput(concated_dims->outputs()[0]);
    n->replaceInput(input_index, constant_of_shape->outputs()[0]);
  }

  if (initial_state->uses().size() == 0) {
    initial_state->node()->destroy();
  }
}

void fixDefaultRnnHiddenState(Block* b) {
void fixDefaultRnnHiddenState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultRnnHiddenState(child_block);
      fixDefaultRnnHiddenState(child_block, opset_version);
    }

    if (!isRNN(n)) {

@@ -445,15 +453,15 @@ void fixDefaultRnnHiddenState(Block* b) {
    if (n->inputs().size() < 6) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 5);
    fixDefaultRNNState(b->owningGraph(), n, 5, opset_version);
  }
}

void fixDefaultLstmCellState(Block* b) {
void fixDefaultLstmCellState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultLstmCellState(child_block);
      fixDefaultLstmCellState(child_block, opset_version);
    }

    if (n->kind() != onnx::LSTM) {

@@ -464,7 +472,7 @@ void fixDefaultLstmCellState(Block* b) {
    if (n->inputs().size() < 7) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 6);
    fixDefaultRNNState(b->owningGraph(), n, 6, opset_version);
  }
}

@@ -625,15 +633,15 @@ void removeMaxPoolUnusedOutput(Block* b) {
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well. The
// optimizations here are ONLY applied on ONNX update
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version) {
  // TODO: decide on fixpoint strategy
  // TODO: make it easier not to do O(k) iterations over the graph, where
  // k is the number of distinct peephole optimizations
  hackFixupPadPackedShapes(graph->block());
  pushPackingPastRnn(graph->block());
  removeNopPacking(graph->block());
  fixDefaultRnnHiddenState(graph->block());
  fixDefaultLstmCellState(graph->block());
  fixDefaultRnnHiddenState(graph->block(), opset_version);
  fixDefaultLstmCellState(graph->block(), opset_version);
  fuseBroadcast(graph->block());
  fuseConsecutiveTransposes(graph->block());
  eliminateNopTranspose(graph->block());
@@ -5,7 +5,7 @@
namespace torch {
namespace jit {

void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph);
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version);

}
} // namespace torch
@@ -177,9 +177,9 @@ def _unimplemented(op, msg):

def _black_list_in_opset(name):
    def symbolic_fn(*args, **kwargs):
        warnings.warn("ONNX export failed on {}, which is not yet implemented for opset 10. "
                      "Try exporting with a previous opset version."
                      .format(name))
        raise RuntimeError("ONNX export failed on {}, which is not implemented for opset {}. "
                           "Try exporting with other opset versions."
                           .format(name, _export_onnx_opset_version))
    return symbolic_fn


@@ -192,10 +192,10 @@ def _try_get_scalar_type(*args):
    return None

def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
    if _export_onnx_opset_version == 9:
    if _export_onnx_opset_version <= 9:
        from torch.onnx.symbolic_opset9 import _slice
        return _slice(g, input, axes, starts, ends)
    if _export_onnx_opset_version == 10:
    else:
        from torch.onnx.symbolic_opset10 import _slice
        return _slice(g, input, axes, starts, ends, steps, dynamic_slice)

@@ -228,7 +228,7 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False)

_default_onnx_opset_version = 9
_onnx_master_opset = 10
_onnx_stable_opsets = [9, 10]
_onnx_stable_opsets = [7, 8, 9, 10]
_export_onnx_opset_version = _default_onnx_opset_version
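With _black_list_in_opset now raising instead of warning, an export that reaches an unsupported operator fails fast. A hypothetical repro (erf is one of the operators blacklisted in symbolic_opset8.py later in this diff; the module is illustrative):

import io
import torch

class Erf(torch.nn.Module):
    def forward(self, x):
        return torch.erf(x)

try:
    torch.onnx.export(Erf(), torch.randn(2, 3), io.BytesIO(), opset_version=8)
except RuntimeError as err:
    print(err)  # ONNX export failed on erf, which is not implemented for opset 8. ...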
@@ -65,8 +65,8 @@ def _max_pool(name, tuple_fn, ndims, return_indices):
                      strides_i=[1 for _ in range(ndims)])
        # convert indices to have non-flattened indices values
        from torch.onnx.symbolic_opset9 import sub
        s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
                      starts=tuple_fn(0), ends=tuple_fn(1))
        s = sym_help._slice_helper(g, flattened_indices, axes=[2 + i for i in range(ndims)],
                                   starts=tuple_fn(0), ends=tuple_fn(1))
        indices = sub(g, indices, s)
        return r, indices
    else:
torch/onnx/symbolic_opset7.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from torch.onnx.symbolic_helper import _black_list_in_opset

import torch.onnx.symbolic_opset9 as sym_opset9

import warnings


# Note [ONNX operators that are added/updated from opset 7 to opset 8]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# New operators:
#   Expand
#
# Updated operators:
#   Min, Max, Sum, Mean: supports multidirectional broadcasting.
#   MaxPool: added optional indices output.
#   Scan

black_listed_operators = [
    "scan", "expand", "expand_as",
    "adaptive_max_pool1d", "adaptive_max_pool2d", "adaptive_max_pool3d",
    "max_pool1d_with_indices", "max_pool2d_with_indices", "max_pool3d_with_indices"
]


# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
def max(g, self, dim_or_y=None, keepdim=None):
    # torch.max(input, other)
    if keepdim is None and dim_or_y is not None:
        warnings.warn("Multidirectional broadcasting is not supported in opset 7. "
                      "This might cause the onnx model to be incorrect, if inputs to max operators "
                      "have different shapes")
    return sym_opset9.max(g, self, dim_or_y, keepdim)


def min(g, self, dim_or_y=None, keepdim=None):
    # torch.min(input, other)
    if keepdim is None and dim_or_y is not None:
        warnings.warn("Multidirectional broadcasting is not supported in opset 7. "
                      "This might cause the onnx model to be incorrect, if inputs to min operators "
                      "have different shapes")
    return sym_opset9.min(g, self, dim_or_y, keepdim)


for black_listed_op in black_listed_operators:
    vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
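The two interfaces the note refers to are dispatched on the argument pattern, and only the elementwise form can trigger broadcasting, which is why the warning is gated on keepdim is None and dim_or_y is not None. For illustration:

import torch

x = torch.randn(3, 4)
y = torch.randn(3, 1)  # would broadcast against x in the elementwise form
values, indices = torch.max(x, dim=1, keepdim=True)  # reduction interface
elementwise = torch.max(x, y)  # elementwise interface; needs matching shapes at opset 7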
torch/onnx/symbolic_opset8.py (new file, 240 lines)
@@ -0,0 +1,240 @@
import torch
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_opset9 as sym_opset9

from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_opset9 import wrap_logical_op_with_cast_to, _cast_Float

import warnings

# Note [ONNX operators that are added/updated from opset 8 to opset 9]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# New operators:
#   Compress
#   ConstantOfShape
#   EyeLike
#   MaxUnpool
#   OneHot
#   Sinh
#   Cosh
#   Asinh
#   Acosh
#   Atanh
#   Shrink
#   IsNaN
#   Sign
#   Erf
#   Scatter
#   Where
#   NonZero
#   TfIdfVectorizer
#   MeanVarianceNormalization
#
# Updated operators:
#   BatchNormalization: removed spatial attribute.
#   Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types (integers) supported.
#   Cast: more data types (string) supported.
#   Upsample: moved scales from attribute to input.
#   Scan

black_listed_operators = [
    "nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
]

for black_listed_op in black_listed_operators:
    vars()[black_listed_op] = _black_list_in_opset(black_listed_op)


def upsample_nearest2d(g, input, output_size, align_corners=None):
    align_corners = sym_help._maybe_get_scalar(align_corners)
    if align_corners:
        return _unimplemented("upsample_nearest2d", "align_corners == True")

    output_size = sym_help._maybe_get_const(output_size, 'is')
    if sym_help._is_value(output_size):
        return _unimplemented("upsample_nearest2d", "torch._C.Value (output_size) indexing")
    else:
        height_scale = float(output_size[-2]) / input.type().sizes()[-2]
        width_scale = float(output_size[-1]) / input.type().sizes()[-1]
        scales = [1., 1., height_scale, width_scale]
        return g.op("Upsample", input, mode_s="nearest",
                    scales_f=scales)


# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of the input tensor,
# which is lost after casting.
def _try_cast_integer_to_float(g, *args):
    floating_scalar_types = ['Half', 'Float', 'Double']
    old_type = None
    # Cast the input tensor to Float if its scalarType is known and is not floating point.
    # If casting is performed, return the old scalarType; otherwise return None.
    if args[0].type().kind() == "DimensionedTensorType" or args[0].type().kind() == "CompleteTensorType":
        old_type = args[0].type().scalarType()
        if old_type not in floating_scalar_types:
            args = tuple(_cast_Float(g, arg, False) for arg in args)
        else:
            return (None,) + args
    else:
        warnings.warn("Only floating datatype is supported for these operators: "
                      "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
                      "the onnx model to be incorrect, if inputs have integer datatypes.")
    return (old_type,) + args


def _cast_to_type(g, input, to_type):
    if to_type is None:
        return input
    return getattr(sym_opset9, '_cast_{}'.format(to_type))(g, input, False)


def _comparison_operator(g, input, other, op_name):
    other = sym_help._maybe_get_scalar(other)
    other = sym_help._if_scalar_type_as(g, other, input)
    _, input, other = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, input, other)


# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input types are not supported in opset 8. Cast to float if possible.
@wrap_logical_op_with_cast_to('Byte')
def gt(g, input, other):
    return _comparison_operator(g, input, other, "Greater")


@wrap_logical_op_with_cast_to('Byte')
def lt(g, input, other):
    return _comparison_operator(g, input, other, "Less")


def bmm(g, self, other):
    if _try_get_scalar_type(self):
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
        return _cast_to_type(g, g.op("MatMul", self, other), old_type)
    else:
        return g.op("MatMul", self, other)


def matmul(g, self, other):
    return bmm(g, self, other)


def prelu(g, self, weight):
    if self.isCompleteTensor():
        self_sizes = self.type().sizes()
        if self_sizes and len(self_sizes) > 2:
            weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1)))
    if _try_get_scalar_type(self):
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
        return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
    else:
        return g.op("PRelu", self, weight)


def mm(g, self, other):
    # Create a dummy C tensor. Only needed for API purposes; the value is
    # ignored since beta = 0
    ty = sym_help._try_get_scalar_type(self, other).lower()
    C = g.constant(0, [1], ty)
    if _try_get_scalar_type(self):
        old_type, self, other, C = _try_cast_integer_to_float(g, self, other, C)
        return _cast_to_type(g, g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0), old_type)
    else:
        return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)


@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
    if _try_get_scalar_type(self):
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
        return _cast_to_type(
            g, g.op("Gemm", mat1, mat2, self,
                    beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha)), old_type)
    else:
        return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))


def view(g, self, size):
    size = sym_help._maybe_get_const(size, 'is')
    if sym_help._is_value(size):
        shape = size
    else:
        if self.isCompleteTensor():
            self_sizes = self.type().sizes()
            if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
                old_type, self = _try_cast_integer_to_float(g, self)
                return _cast_to_type(g, g.op("Flatten", self, axis_i=1), old_type)
        shape = g.op("Constant", value_t=torch.LongTensor(size))
    return g.op("Reshape", self, shape)


def flatten(g, input, start_dim, end_dim):
    start_dim_i = sym_help._get_const(start_dim, 'i', 'start_dim')
    end_dim_i = sym_help._get_const(end_dim, 'i', 'end_dim')

    dim = input.type().dim()
    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        if _try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(g, g.op("Flatten", input, axis_i=start_dim_i), old_type)
        else:
            return g.op("Flatten", input, axis_i=start_dim_i)
    if start_dim_i == 0 and end_dim_i == dim - 2:
        if _try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type)
        else:
            return g.op("Flatten", input, axis_i=end_dim_i + 1)

    return sym_opset9.flatten(g, input, start_dim, end_dim)


def _constant_fill(g, sizes, dtype, const_value):
    if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
        result = g.op(
            "ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)
        return sym_help._cast_func_template(sym_help.scalar_type_to_onnx[dtype], g, result, None)
    else:
        return g.op("ConstantFill", sizes, dtype_i=sym_help.scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=const_value)


@parse_args('v', 'i', 'v', 'v', 'v')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
    # NOTE: no way to set device and layout in ONNX, so we ignore it
    return _constant_fill(g, sizes, dtype, 0)


@parse_args('v', 'i', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 0)


@parse_args('v', 'i', 'v', 'v', 'v')
def ones(g, sizes, dtype, layout, device, pin_memory=False):
    return _constant_fill(g, sizes, dtype, 1)


@parse_args('v', 'i', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 1)


def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    const_value = sym_help._maybe_get_const(value, 't')
    if sym_help._is_value(const_value):
        tmp = zeros(g, sizes, dtype, layout, device)
        return sym_opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = sym_help._get_const(dtype, 'i', 'dtype')
        return _constant_fill(g, sizes, dtype, const_value)


@parse_args('v', 'f', 'i', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, fill_value)
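The cast-and-restore helpers above are what let integer inputs survive opset 8 export for Greater, Less, MatMul, Gemm, PRelu, and Flatten. A sketch of how one might inspect the result (assumptions: the onnx package is installed, and int32 inputs take the cast path in _try_cast_integer_to_float):

import io
import onnx
import torch

class Gt(torch.nn.Module):
    def forward(self, a, b):
        return a > b

a = torch.randint(10, (3, 4), dtype=torch.int32)
b = torch.randint(10, (3, 4), dtype=torch.int32)
f = io.BytesIO()
torch.onnx.export(Gt(), (a, b), f, opset_version=8)
print([n.op_type for n in onnx.load_from_string(f.getvalue()).graph.node])
# expect Cast nodes feeding Greater, mirroring the cast-to-float wrapper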
@@ -23,19 +23,40 @@ def register_version(domain, version):
    register_ops_in_version(domain, version)


def register_ops_helper(domain, version, iter_version):
    version_ops = get_ops_in_version(iter_version)
    for op in version_ops:
        if isfunction(op[1]) and not is_registered_op(op[0], domain, version):
            register_op(op[0], op[1], domain, version)


def register_ops_in_version(domain, version):
    # iterates through the symbolic functions of
    # the specified opset version, and the previous
    # opset versions for operators supported in
    # previous versions
    # previous versions.

    # Opset 9 is the base version. It is selected as the base version because
    #   1. It is the first opset version supported by PyTorch export.
    #   2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
    #      in that certain basic operators cannot be expressed in ONNX. Instead of building on these
    #      limitations, we chose to handle them as special cases separately.
    # Backward support for opset versions earlier than opset 7 is not in our roadmap.

    # Opset versions other than 9 by default inherit the symbolic functions defined in
    # symbolic_opset9.py.
    # To extend support for updated operators in different opset versions on top of opset 9,
    # simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
    # Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py, for examples.
    iter_version = version
    while iter_version >= 9:
        version_ops = get_ops_in_version(iter_version)
        for op in version_ops:
            if isfunction(op[1]) and \
                    not is_registered_op(op[0], domain, version):
                register_op(op[0], op[1], domain, version)
        iter_version = iter_version - 1
    while iter_version != 9:
        register_ops_helper(domain, version, iter_version)
        if iter_version > 9:
            iter_version = iter_version - 1
        else:
            iter_version = iter_version + 1

    register_ops_helper(domain, version, 9)


def get_ops_in_version(version):
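When the requested opset has no override for an operator, the loop above falls back toward the opset 9 base: from opset 7 it walks up through 8 to 9, and from opset 10 it walks down to 9. A standalone sketch of the implied lookup order (the helper name is mine, not part of the patch):

def resolution_order(version):
    # mirrors register_ops_in_version: most specific opset first, base opset 9 last
    order = []
    iter_version = version
    while iter_version != 9:
        order.append(iter_version)
        iter_version = iter_version - 1 if iter_version > 9 else iter_version + 1
    order.append(9)
    return order

print(resolution_order(7))   # [7, 8, 9]
print(resolution_order(10))  # [10, 9]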
@@ -231,7 +231,8 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa

    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    torch._C._jit_pass_lint(graph)
    torch._C._jit_pass_onnx_peephole(graph)
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version)
    torch._C._jit_pass_lint(graph)

    # graph is not a valid jit graph anymore because types have been replaced

@@ -354,6 +355,11 @@ def _model_to_graph(model, args, verbose=False, training=False,
                        _export_onnx_opset_version)
    torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # This pass transforms constants of other data types into float/double plus a Cast operator.
    if _export_onnx_opset_version < 9:
        torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    if verbose:
        print(graph)