PyTorch export to ONNX Opset 7 and 8 - Cont (#22421)

Summary:
This is an extension to the original PR https://github.com/pytorch/pytorch/pull/21765

1. Increase coverage of support for different opsets, with comments and blacklisting of unsupported operators.
2. Add backend tests for both Caffe2 and ONNX Runtime on opset 7 and opset 8.
3. Reuse the ONNX model tests from Caffe2 for ONNX Runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22421

Reviewed By: zrphercule

Differential Revision: D16225518

Pulled By: houseroad

fbshipit-source-id: 01ae3eed85111a83a0124e9e95512b80109d6aee
BowenBao
2019-07-12 14:45:52 -07:00
committed by Facebook Github Bot
parent 9f8e2c067f
commit b3147bc674
19 changed files with 803 additions and 78 deletions
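
For orientation, the user-visible effect is that opset_version values 7 and 8 are now accepted by torch.onnx.export; a minimal sketch (not part of the diff below):

import io
import torch

# any simple traced model works; the point is the newly accepted opset values
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
x = torch.randn(1, 3, 32, 32)
for opset in (7, 8):
    torch.onnx.export(model, x, io.BytesIO(), opset_version=opset)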

View File

@ -55,6 +55,7 @@ pytest "${args[@]}" \
'not (TestOperators and test_full_like) and not (TestOperators and test_zeros_like) and not (TestOperators and test_ones_like) and not (TestModels and test_vgg16) and not (TestModels and test_vgg16_bn) and not (TestModels and test_vgg19) and not (TestModels and test_vgg19_bn)' \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
"${test_paths[@]}"
# onnxruntime only supports py3
@ -63,5 +64,6 @@ if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
pip install --user onnxruntime
pytest "${args[@]}" "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
pytest "${args[@]}" "$top_dir/test/onnx/test_custom_ops.py"
pytest "${args[@]}" "$top_dir/test/onnx/test_models_onnxruntime.py"
fi

View File

@ -0,0 +1,23 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import onnxruntime # noqa
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10]
for opset_version in opset_versions:
self.opset_version = opset_version
run_model_test(self, model, False,
input=inputs, rtol=rtol, atol=atol)
if __name__ == '__main__':
TestModels.exportTest = exportTest
unittest.main()

View File

@ -132,6 +132,45 @@ class TestONNXOpset(TestCase):
x = torch.randn(20, 16, 50)
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
def test_upsample(self):
class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, x):
size = [v * 2 for v in x.size()[2:]]
size = [int(i) for i in size]
return torch.nn.functional.interpolate(x, size=size, mode='nearest')
module = MyModule()
ops8 = [{"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3},
{"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6}]}]
ops9 = [{"op_name" : "Constant"},
{"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}]
ops = {8 : ops8, 9 : ops9}
x = torch.randn(2, 2, 2, 2)
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
def test_cast_constant(self):
class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, x):
return torch._dim_arange(x, 1)
module = MyModule()
ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
{"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
{"op_name" : "Range"}]
ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
{"op_name" : "Range"}]
ops = {8 : ops_8, 9 : ops_9}
x = torch.ones(5, 6)
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
def test_slice(self):
class MyModule(Module):
def forward(self, x):

View File

@ -40,7 +40,7 @@ import onnx
import caffe2.python.onnx.backend as c2
from test_pytorch_common import skipIfTravis, skipIfNoLapack, skipIfNoCuda
from test_pytorch_common import skipIfUnsupportedOpsetVersion
from test_pytorch_common import skipIfUnsupportedOpsetVersion, skipIfUnsupportedMinOpsetVersion
import verify
skip = unittest.skip
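# The skip decorators used below gate each test on the suite's opset_version.
# A plausible sketch of skipIfUnsupportedMinOpsetVersion (illustration only;
# the real implementation lives in test_pytorch_common.py):
def _skip_if_min_opset_sketch(min_opset_version):
    def decorator(func):
        def wrapper(self, *args, **kwargs):
            # each generated TestCase class carries its own opset_version
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest("opset {} < {}".format(
                    self.opset_version, min_opset_version))
            return func(self, *args, **kwargs)
        return wrapper
    return decorator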
@ -664,6 +664,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9)
self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(9)
def test_erf(self):
class MyModel(torch.nn.Module):
def __init__(self):
@ -807,16 +808,19 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(20, 16, 50, 44, 30, requires_grad=True)
self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(8)
def test_adaptive_max_pool1D(self):
model = torch.nn.AdaptiveMaxPool1d((5))
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(8)
def test_adaptive_max_pool2D(self):
model = torch.nn.AdaptiveMaxPool2d((5, 4))
x = torch.randn(20, 16, 50, 32, requires_grad=True)
self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(8)
def test_adaptive_max_pool3D(self):
model = torch.nn.AdaptiveMaxPool3d((5, 4, 3))
x = torch.randn(20, 16, 50, 44, 30, requires_grad=True)
@ -993,7 +997,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
self.run_model_test(model, train=False, input=(x),
batch_size=BATCH_SIZE, use_gpu=False)
@skipIfUnsupportedOpsetVersion([10])
@skipIfUnsupportedOpsetVersion([7, 8, 10])
def test_interpolate_upsample_dynamic_sizes(self):
class MyModel(torch.nn.Module):
def __init__(self):
@ -1291,6 +1295,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=FullClass()(x))
@skipIfUnsupportedMinOpsetVersion(9)
def test_where_functional(self):
class WhereFunctional(torch.nn.Module):
def forward(self, x):
@ -1299,6 +1304,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(3, 4)
self.run_model_test(WhereFunctional(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
@skipIfUnsupportedMinOpsetVersion(9)
def test_where_method(self):
class WhereMethod(torch.nn.Module):
def forward(self, x):
@ -1353,6 +1359,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
self.run_model_test(RsubModel(), train=False, input=(x,),
batch_size=BATCH_SIZE, use_gpu=False)
@skipIfUnsupportedMinOpsetVersion(9)
def test_isnan(self):
class IsNaNModel(torch.nn.Module):
def forward(self, input):
@ -1361,6 +1368,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.tensor([1.0, float('nan'), 2.0])
self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter(self):
class ScatterModel(torch.nn.Module):
def forward(self, input, indices, values):
@ -1396,6 +1404,23 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_max_keepdim(self):
class MaxModel(torch.nn.Module):
def forward(self, input):
return torch.max(input, dim=1, keepdim=True)
x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_max_tensors(self):
class MaxModel(torch.nn.Module):
def forward(self, input, other):
return torch.max(input, other)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
self.run_model_test(MaxModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
def test_min(self):
class MinModel(torch.nn.Module):
def forward(self, input):
@ -1841,6 +1866,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(1, 2, 3)
self.run_model_test(DropoutModel(), train=False, input=x, batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(9)
def test_while(self):
class WhileModel(torch.jit.ScriptModule):
@torch.jit.script_method
@ -1901,6 +1927,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
@skipIfUnsupportedMinOpsetVersion(9)
def test_nested_loops(self):
class NestedLoopsModel(torch.jit.ScriptModule):
@torch.jit.script_method
@ -2013,6 +2040,24 @@ TestCaffe2BackendEmbed_opset9 = type(str("TestCaffe2BackendEmbed_opset9"),
(unittest.TestCase,),
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True))
# opset 7 tests
TestCaffe2Backend_opset7 = type(str("TestCaffe2Backend_opset7"),
(unittest.TestCase,),
dict(TestCaffe2Backend_opset9.__dict__, opset_version=7))
TestCaffe2BackendEmbed_opset7 = type(str("TestCaffe2BackendEmbed_opset7"),
(unittest.TestCase,),
dict(TestCaffe2Backend_opset9.__dict__,
embed_params=True, opset_version=7))
# opset 8 tests
TestCaffe2Backend_opset8 = type(str("TestCaffe2Backend_opset8"),
(unittest.TestCase,),
dict(TestCaffe2Backend_opset9.__dict__, opset_version=8))
TestCaffe2BackendEmbed_opset8 = type(str("TestCaffe2BackendEmbed_opset8"),
(unittest.TestCase,),
dict(TestCaffe2Backend_opset9.__dict__,
embed_params=True, opset_version=8))
# opset 10 tests
TestCaffe2Backend_opset10 = type(str("TestCaffe2Backend_opset10"),
(unittest.TestCase,),

View File

@ -6,58 +6,97 @@ from __future__ import unicode_literals
import unittest
import onnxruntime # noqa
import torch
import numpy as np
import io
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
import model_defs.word_language_model as word_language_model
def run_model_test(self, model, train, batch_size=2, state_dict=None,
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True):
model.eval()
if input is None:
input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
with torch.no_grad():
if isinstance(input, torch.Tensor):
input = (input,)
output = model(*input)
if isinstance(output, torch.Tensor):
output = (output,)
# export the model to ONNX
f = io.BytesIO()
torch.onnx.export(model, input, f,
opset_version=self.opset_version,
example_outputs=output)
input, _ = torch.jit._flatten(input)
output, _ = torch.jit._flatten(output)
def to_numpy(tensor):
if tensor.requires_grad:
return tensor.detach().cpu().numpy()
else:
return tensor.cpu().numpy()
inputs = list(map(to_numpy, input))
outputs = list(map(to_numpy, output))
# compute onnxruntime output prediction
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
ort_outs = ort_sess.run(None, ort_inputs)
# compare onnxruntime and PyTorch results
assert len(outputs) == len(ort_outs), "number of outputs differ"
# compare onnxruntime and PyTorch results
[np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
class TestONNXRuntime(unittest.TestCase):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
opset_version = _export_onnx_opset_version
def run_test(self, model, inputs, rtol=1e-05, atol=1e-08):
outputs = model(inputs) if isinstance(inputs, torch.Tensor) else model(*inputs)
def run_test(self, model, input, rtol=1e-3, atol=1e-7):
run_model_test(self, model, False, None,
input=input, rtol=rtol, atol=atol)
# export the model to ONNX
f = io.BytesIO()
torch.onnx.export(model, inputs, f,
opset_version=self.opset_version,
example_outputs=outputs)
def run_word_language_model(self, model_name):
ntokens = 50
emsize = 5
nhid = 5
nlayers = 5
dropout = 0.2
tied = False
batchsize = 5
model = word_language_model.RNNModel(model_name, ntokens, emsize,
nhid, nlayers, dropout, tied,
batchsize)
x = torch.arange(0, ntokens).long().view(-1, batchsize)
# Only the CPU version is supported, since the tracer does not work with GPU RNNs.
self.run_test(model, (x, model.hidden))
def get_numpy_value_at_index(t, i):
return t[i].detach().numpy() if t[i].requires_grad else t[i].numpy()
def test_word_language_model_RNN_TANH(self):
self.run_word_language_model("RNN_TANH")
def get_numpy_value(t):
return t.detach().numpy() if t.requires_grad else t.numpy()
def test_word_language_model_RNN_RELU(self):
self.run_word_language_model("RNN_RELU")
def get_ort_inputs():
ort_inputs = {}
if isinstance(inputs, torch.Tensor):
ort_inputs = {ort_sess.get_inputs()[0].name: get_numpy_value(inputs)}
else:
for i in range(0, len(outputs)):
ort_inputs[ort_sess.get_inputs()[i].name] = get_numpy_value_at_index(inputs, i)
return ort_inputs
def test_word_language_model_LSTM(self):
self.run_word_language_model("LSTM")
# compute onnxruntime output prediction
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = get_ort_inputs()
ort_outs = ort_sess.run(None, ort_inputs)
# compare onnxruntime and PyTorch results
assert (isinstance(outputs, torch.Tensor) and len(ort_outs) == 1) or \
len(outputs) == len(ort_outs), \
"number of outputs differ"
if isinstance(outputs, torch.Tensor):
assert np.allclose(get_numpy_value(outputs), ort_outs[0],
rtol=rtol, atol=atol)
else:
for i in range(0, len(outputs)):
assert np.allclose(get_numpy_value_at_index(outputs, i), ort_outs[i],
rtol=rtol, atol=atol)
def test_word_language_model_GRU(self):
self.run_word_language_model("GRU")
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_trace(self):
class FullModel(torch.nn.Module):
def forward(self, x):
@ -66,6 +105,7 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.tensor(12)
self.run_test(FullModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_script(self):
class FullModelScripting(torch.jit.ScriptModule):
@torch.jit.script_method
@ -80,6 +120,12 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(20, 16, 50)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_with_indices(self):
model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
x = torch.randn(20, 16, 50)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
@ -117,13 +163,23 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.tensor(np.arange(6.0).reshape(2, 3))
self.run_test(MyModule(), x)
def test_interpolate(self):
@skipIfUnsupportedMinOpsetVersion(9)
def test_interpolate_scale(self):
class MyModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2)
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_test(MyModel(), x)
# NOTE: Supported in onnxruntime master; enable this after the 0.5 release.
@skipIfUnsupportedOpsetVersion([10])
def test_interpolate_output_size(self):
class MyModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.interpolate(x, mode="nearest", size=(6, 8))
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_test(MyModel(), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_interpolate_downsample(self):
class MyModel(torch.nn.Module):
@ -168,10 +224,162 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(4, 4, requires_grad=True)
self.run_test(ReduceLogSumExpModel(), x)
@skipIfUnsupportedMinOpsetVersion(8)
def test_adaptive_max_pool(self):
model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(model, x)
def test_maxpool_2d(self):
model = torch.nn.MaxPool2d(5, padding=(1, 2))
x = torch.randn(1, 20, 16, 50, requires_grad=True)
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(8)
def test_max_tensors(self):
class MaxModel(torch.nn.Module):
def forward(self, input, other):
return torch.max(input, other)
model = MaxModel()
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 1, requires_grad=True)
self.run_test(model, (x, y))
def test_gt(self):
class GreaterModel(torch.nn.Module):
def forward(self, input, other):
return input > other
x = torch.randn(1, 2, 3, 4, requires_grad=True)
y = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_test(GreaterModel(), (x, y))
x = torch.randint(10, (3, 4), dtype=torch.int32)
y = torch.randint(10, (3, 4), dtype=torch.int32)
self.run_test(GreaterModel(), (x, y))
def test_gt_scalar(self):
class GreaterModel(torch.nn.Module):
def forward(self, input):
return input > 1
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_test(GreaterModel(), x)
x = torch.randint(10, (3, 4), dtype=torch.int32)
self.run_test(GreaterModel(), x)
def test_lt(self):
class LessModel(torch.nn.Module):
def forward(self, input, other):
return input < other
x = torch.randn(1, 2, 3, 4, requires_grad=True)
y = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_test(LessModel(), (x, y))
x = torch.randint(10, (3, 4), dtype=torch.int32)
y = torch.randint(10, (3, 4), dtype=torch.int32)
self.run_test(LessModel(), (x, y))
def test_matmul(self):
class MatmulModel(torch.nn.Module):
def forward(self, input, other):
return torch.matmul(input, other)
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(4, 5, requires_grad=True)
self.run_test(MatmulModel(), (x, y))
x = torch.randint(10, (3, 4))
y = torch.randint(10, (4, 5))
self.run_test(MatmulModel(), (x, y))
def test_matmul_batch(self):
class MatmulModel(torch.nn.Module):
def forward(self, input, other):
return torch.matmul(input, other)
x = torch.randn(2, 3, 4, requires_grad=True)
y = torch.randn(2, 4, 5, requires_grad=True)
self.run_test(MatmulModel(), (x, y))
x = torch.randint(10, (2, 3, 4))
y = torch.randint(10, (2, 4, 5))
self.run_test(MatmulModel(), (x, y))
def test_view(self):
class ViewModel(torch.nn.Module):
def forward(self, input):
return input.view(4, 24)
x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
self.run_test(ViewModel(), x)
def test_flatten(self):
class FlattenModel(torch.nn.Module):
def forward(self, input):
return torch.flatten(input)
x = torch.randint(10, (1, 2, 3, 4))
self.run_test(FlattenModel(), x)
def test_flatten2d(self):
class FlattenModel(torch.nn.Module):
def forward(self, input):
return torch.flatten(input, 1)
x = torch.randint(10, (1, 2, 3, 4))
self.run_test(FlattenModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories(self):
class TensorFactory(torch.nn.Module):
def forward(self, x):
return torch.zeros(x.size()) + torch.ones(x.size())
x = torch.randn(2, 3, 4)
self.run_test(TensorFactory(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories_script(self):
class TensorFactory(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)
x = torch.randn(2, 3, 4)
self.run_test(TensorFactory(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_like_factories_script(self):
class TensorFactory(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device('cpu'))
return zeros + ones
x = torch.randn(2, 3, 4)
self.run_test(TensorFactory(), x)
# opset 7 tests
TestONNXRuntime_opset7 = type(str("TestONNXRuntime_opset7"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=7))
# opset 8 tests
TestONNXRuntime_opset8 = type(str("TestONNXRuntime_opset8"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=8))
# opset 10 tests
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=10))
if __name__ == '__main__':
unittest.main()
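
The type(...) calls above copy the base suite's dict and override opset_version. Conceptually this is close to plain subclassing; a sketch (the shipped code builds each suite on a fresh unittest.TestCase so the copies stay independent):

class TestONNXRuntimeOpset8Sketch(TestONNXRuntime):  # illustration only
    opset_version = 8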

View File

@ -232,6 +232,7 @@ def add_torch_libs():
"torch/csrc/jit/init.cpp",
"torch/csrc/jit/passes/inline_fork_wait.cpp",
"torch/csrc/jit/passes/onnx.cpp",
"torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
"torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp",
"torch/csrc/jit/passes/onnx/peephole.cpp",

View File

@ -38,11 +38,11 @@ fi
# Run Clang-Tidy
# The negative filters below are to exclude files that include onnx_pb.h or
# caffe2_pb.h, otherwise we'd have to build protos as part of this CI job.
time python tools/clang_tidy.py \
--verbose \
--paths torch/csrc/ \
--diff "$BASE_BRANCH" \
-g"-torch/csrc/jit/export.cpp" \
-g"-torch/csrc/jit/import.cpp" \
-g"-torch/csrc/jit/netdef_converter.cpp" \
time python tools/clang_tidy.py \
--verbose \
--paths torch/csrc/ \
--diff "$BASE_BRANCH" \
-g"-torch/csrc/jit/export.cpp" \
-g"-torch/csrc/jit/import.cpp" \
-g"-torch/csrc/jit/netdef_converter.cpp" \
"$@"

View File

@ -74,6 +74,7 @@ set(TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/peephole.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/constant_fold.cpp
${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp
${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp

View File

@ -24,6 +24,7 @@
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
@ -112,6 +113,7 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_onnx", ToONNX)
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
.def("_jit_pass_onnx_cast_all_constant_to_floating", CastAllConstantToFloating)
.def(
"_jit_pass_onnx_constant_fold",
[](std::shared_ptr<Graph>& graph,

View File

@ -0,0 +1,70 @@
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
namespace torch {
namespace jit {
namespace onnx {
using namespace ::c10::onnx;
}
// For ONNX opset < 9, the Constant operator supports only three data types:
// float16, float, and double. Constants of other data types are exported as
// float or double and then cast back to their original data type with a Cast node.
// That transformation is what this pass performs.
// The motivation for doing this as a post-process pass, as opposed to handling it
// in the symbolic functions, is that by this point many Constant operators have
// already been removed during export. Conversely, inserting the Cast in the
// symbolic functions would break subsequent node conversions that depend on
// certain inputs being constant.
void CastAllConstantToFloating(Block* block) {
auto graph = block->owningGraph();
auto it = block->nodes().begin();
while (it != block->nodes().end()) {
auto node = *it;
++it;
for (auto block : node->blocks()) {
CastAllConstantToFloating(block);
}
if (node->kind() == onnx::Constant) {
auto val = node->t(attr::value);
at::ScalarType dtype = val.scalar_type();
if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float && dtype != at::ScalarType::Half) {
int to_type;
switch (val.scalar_type()){
case at::ScalarType::Byte:
case at::ScalarType::Char:
case at::ScalarType::Int:
case at::ScalarType::Short:
case at::ScalarType::Bool:
to_type = 6; // ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
val = val.to(at::ScalarType::Float);
break;
case at::ScalarType::Long:
to_type = 7; // ::ONNX_NAMESPACE::TensorProto_DataType_INT64;
val = val.to(at::ScalarType::Double);
break;
default:
throw std::runtime_error("Unsupported types: complex, string");
}
// create a cast node
node->removeAttribute(attr::value);
node->t_(attr::value, val);
Node* cast_node = graph->create(onnx::Cast, 1);
cast_node->i_(attr::to, to_type);
cast_node->insertAfter(node);
// get input from cast node
node->outputs().at(0)->replaceAllUsesWith(cast_node->outputs().at(0));
// add input from constant to cast node
cast_node->addInput(node->outputs().at(0));
}
}
}
}
void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph) {
CastAllConstantToFloating(graph->block());
}
} // namespace jit
} // namespace torch
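
A quick way to observe the pass (a sketch, assuming the onnx package is installed and the constant survives export as an onnx::Constant node rather than being folded):

import io
import onnx
import torch

class AddConst(torch.nn.Module):
    def forward(self, x):
        # int64 constant: below opset 9 it must be stored as double + Cast
        return x + torch.tensor([1, 2, 3], dtype=torch.long)

f = io.BytesIO()
torch.onnx.export(AddConst(), torch.zeros(3, dtype=torch.long), f, opset_version=8)
graph = onnx.load_model_from_string(f.getvalue()).graph
print([n.op_type for n in graph.node])  # expect a Cast(to=INT64) after the Constant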

View File

@ -0,0 +1,12 @@
#pragma once
#include <torch/csrc/jit/ir.h>
#include <memory>
namespace torch {
namespace jit {
// see .cpp for docs
TORCH_API void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch

View File

@ -351,7 +351,7 @@ void hackFixupPadPackedShapes(Block* graph) {
}
}
void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
void fixDefaultRNNState(Graph* graph, Node* n, int input_index, int opset_version) {
auto initial_state = n->inputs()[input_index];
// The RNN code in pytorch accepts an optional hidden state. When it
@ -420,21 +420,29 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
concated_dims->addInput(hidden_size->outputs()[0]);
Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
constant_of_shape->insertBefore(n);
constant_of_shape->addInput(concated_dims->outputs()[0]);
n->replaceInput(input_index, constant_of_shape->outputs()[0]);
if (opset_version < 9) {
Node* constant_fill = graph->create(onnx::ConstantFill, 1);
constant_fill->insertBefore(n);
constant_fill->i_(attr::input_as_shape, 1);
constant_fill->addInput(concated_dims->outputs()[0]);
n->replaceInput(input_index, constant_fill->outputs()[0]);
} else {
Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
constant_of_shape->insertBefore(n);
constant_of_shape->addInput(concated_dims->outputs()[0]);
n->replaceInput(input_index, constant_of_shape->outputs()[0]);
}
if (initial_state->uses().size() == 0) {
initial_state->node()->destroy();
}
}
void fixDefaultRnnHiddenState(Block* b) {
void fixDefaultRnnHiddenState(Block* b, int opset_version) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
auto* n = *it;
for (auto* child_block : n->blocks()) {
fixDefaultRnnHiddenState(child_block);
fixDefaultRnnHiddenState(child_block, opset_version);
}
if (!isRNN(n)) {
@ -445,15 +453,15 @@ void fixDefaultRnnHiddenState(Block* b) {
if (n->inputs().size() < 6) {
continue;
}
fixDefaultRNNState(b->owningGraph(), n, 5);
fixDefaultRNNState(b->owningGraph(), n, 5, opset_version);
}
}
void fixDefaultLstmCellState(Block* b) {
void fixDefaultLstmCellState(Block* b, int opset_version) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
auto* n = *it;
for (auto* child_block : n->blocks()) {
fixDefaultLstmCellState(child_block);
fixDefaultLstmCellState(child_block, opset_version);
}
if (n->kind() != onnx::LSTM) {
@ -464,7 +472,7 @@ void fixDefaultLstmCellState(Block* b) {
if (n->inputs().size() < 7) {
continue;
}
fixDefaultRNNState(b->owningGraph(), n, 6);
fixDefaultRNNState(b->owningGraph(), n, 6, opset_version);
}
}
@ -625,15 +633,15 @@ void removeMaxPoolUnusedOutput(Block* b) {
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well. The
// optimizations here are ONLY applied on ONNX update
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version) {
// TODO: decide on fixpoint strategy
// TODO: make it easier not to do O(k) iterations over the graph, where
// k is the number of distinct peephole optimizations
hackFixupPadPackedShapes(graph->block());
pushPackingPastRnn(graph->block());
removeNopPacking(graph->block());
fixDefaultRnnHiddenState(graph->block());
fixDefaultLstmCellState(graph->block());
fixDefaultRnnHiddenState(graph->block(), opset_version);
fixDefaultLstmCellState(graph->block(), opset_version);
fuseBroadcast(graph->block());
fuseConsecutiveTransposes(graph->block());
eliminateNopTranspose(graph->block());
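
The practical difference shows up in the exported graph; a sketch, assuming the default hidden state is materialized rather than folded away:

import io
import onnx
import torch

lstm = torch.nn.LSTM(10, 20, 1)
x = torch.randn(5, 3, 10)
f = io.BytesIO()
torch.onnx.export(lstm, (x,), f, opset_version=8)
ops = [n.op_type for n in onnx.load_model_from_string(f.getvalue()).graph.node]
# expect the zero default h0/c0 to come from ConstantFill at opset 8;
# the same export at opset 9+ uses ConstantOfShape instead
print("ConstantFill" in ops, "ConstantOfShape" in ops)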

View File

@ -5,7 +5,7 @@
namespace torch {
namespace jit {
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph);
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version);
}
} // namespace torch

View File

@ -177,9 +177,9 @@ def _unimplemented(op, msg):
def _black_list_in_opset(name):
def symbolic_fn(*args, **kwargs):
warnings.warn("ONNX export failed on {}, which is not yet implemented for opset 10. "
"Try exporting with a previous opset version."
.format(name))
raise RuntimeError("ONNX export failed on {}, which is not implemented for opset {}. "
"Try exporting with other opset versions."
.format(name, _export_onnx_opset_version))
return symbolic_fn
@ -192,10 +192,10 @@ def _try_get_scalar_type(*args):
return None
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
if _export_onnx_opset_version == 9:
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import _slice
return _slice(g, input, axes, starts, ends)
if _export_onnx_opset_version == 10:
else:
from torch.onnx.symbolic_opset10 import _slice
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
@ -228,7 +228,7 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False)
_default_onnx_opset_version = 9
_onnx_master_opset = 10
_onnx_stable_opsets = [9, 10]
_onnx_stable_opsets = [7, 8, 9, 10]
_export_onnx_opset_version = _default_onnx_opset_version

View File

@ -65,8 +65,8 @@ def _max_pool(name, tuple_fn, ndims, return_indices):
strides_i=[1 for _ in range(ndims)])
# convert indices to have non-flattened indices values
from torch.onnx.symbolic_opset9 import sub
s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
starts=tuple_fn(0), ends=tuple_fn(1))
s = sym_help._slice_helper(g, flattened_indices, axes=[2 + i for i in range(ndims)],
starts=tuple_fn(0), ends=tuple_fn(1))
indices = sub(g, indices, s)
return r, indices
else:

View File

@ -0,0 +1,47 @@
from torch.onnx.symbolic_helper import _black_list_in_opset
import torch.onnx.symbolic_opset9 as sym_opset9
import warnings
# Note [ONNX operators that are added/updated from opset 7 to opset 8]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# New operators:
# Expand
#
# Updated operators:
# Min, Max, Sum, Mean: supports multidirectional broadcasting.
# MaxPool: added optional indices output.
# Scan
black_listed_operators = [
"scan", "expand", "expand_as",
"adaptive_max_pool1d", "adaptive_max_pool2d", "adaptive_max_pool3d",
"max_pool1d_with_indices", "max_pool2d_with_indices", "max_pool3d_with_indices"
]
# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
def max(g, self, dim_or_y=None, keepdim=None):
# torch.max(input, other)
if keepdim is None and dim_or_y is not None:
warnings.warn("Multidirectional broadcasting is not supported in opset 7. "
"This might cause the onnx model to be incorrect, if inputs to max operators "
"have different shapes")
return sym_opset9.max(g, self, dim_or_y, keepdim)
def min(g, self, dim_or_y=None, keepdim=None):
# torch.min(input, other)
if keepdim is None and dim_or_y is not None:
warnings.warn("Multidirectional broadcasting is not supported in opset 7. "
"This might cause the onnx model to be incorrect, if inputs to min operators "
"have different shapes")
return sym_opset9.min(g, self, dim_or_y, keepdim)
for black_listed_op in black_listed_operators:
vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
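
How the blacklist surfaces to users, with the error text following _black_list_in_opset in symbolic_helper.py (a sketch):

import io
import torch

model = torch.nn.AdaptiveMaxPool1d(5)  # blacklisted below opset 8
x = torch.randn(20, 16, 50)
try:
    torch.onnx.export(model, x, io.BytesIO(), opset_version=7)
except RuntimeError as e:
    # expected: "ONNX export failed on adaptive_max_pool1d, which is not
    # implemented for opset 7. Try exporting with other opset versions."
    print(e)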

View File

@ -0,0 +1,240 @@
import torch
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_opset9 as sym_opset9
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_opset9 import wrap_logical_op_with_cast_to, _cast_Float
import warnings
# Note [ONNX operators that are added/updated from opset 8 to opset 9]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# New operators:
# Compress
# ConstantOfShape
# EyeLike
# MaxUnpool
# OneHot
# Sinh
# Cosh
# Asinh
# Acosh
# Atanh
# Shrink
# IsNaN
# Sign
# Erf
# Scatter
# Where
# NonZero
# TfIdfVectorizer
# MeanVarianceNormalization
#
# Updated operators:
# BatchNormalization: removed spatial attribute.
# Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types (integers) supported.
# Cast: more data types (string) supported.
# Upsample: moved scales from attribute to input.
# Scan
black_listed_operators = [
"nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
]
for black_listed_op in black_listed_operators:
vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
def upsample_nearest2d(g, input, output_size, align_corners=None):
align_corners = sym_help._maybe_get_scalar(align_corners)
if align_corners:
return _unimplemented("upsample_neareset2d", "align_corners == True")
output_size = sym_help._maybe_get_const(output_size, 'is')
if sym_help._is_value(output_size):
return _unimplemented("upsample_nearest2d", "torch._C.Value (output_size) indexing")
else:
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
scales = [1., 1., height_scale, width_scale]
return g.op("Upsample", input, mode_s="nearest",
scales_f=scales)
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of the input tensor, which
# is lost after casting.
def _try_cast_integer_to_float(g, *args):
floating_scalar_types = ['Half', 'Float', 'Double']
old_type = None
# Cast the input tensor to Float if its scalarType is known and is not a floating-point type.
# If casting is performed, return the old scalarType, otherwise return None.
if args[0].type().kind() == "DimensionedTensorType" or args[0].type().kind() == "CompleteTensorType":
old_type = args[0].type().scalarType()
if old_type not in floating_scalar_types:
args = tuple(_cast_Float(g, arg, False) for arg in args)
else:
return (None,) + args
else:
warnings.warn("Only floating datatype is supported for these operators: "
"{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
"the onnx model to be incorrect, if inputs have integer datatypes.")
return (old_type,) + args
def _cast_to_type(g, input, to_type):
if to_type is None:
return input
return getattr(sym_opset9, '_cast_{}'.format(to_type))(g, input, False)
def _comparison_operator(g, input, other, op_name):
other = sym_help._maybe_get_scalar(other)
other = sym_help._if_scalar_type_as(g, other, input)
_, input, other = _try_cast_integer_to_float(g, input, other)
return g.op(op_name, input, other)
# NOTE: For the symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input types are not supported in opset 8. Cast to float if possible.
@wrap_logical_op_with_cast_to('Byte')
def gt(g, input, other):
return _comparison_operator(g, input, other, "Greater")
@wrap_logical_op_with_cast_to('Byte')
def lt(g, input, other):
return _comparison_operator(g, input, other, "Less")
def bmm(g, self, other):
if _try_get_scalar_type(self):
old_type, self, other = _try_cast_integer_to_float(g, self, other)
return _cast_to_type(g, g.op("MatMul", self, other), old_type)
else:
return g.op("MatMul", self, other)
def matmul(g, self, other):
return bmm(g, self, other)
def prelu(g, self, weight):
if self.isCompleteTensor():
self_sizes = self.type().sizes()
if self_sizes and len(self_sizes) > 2:
weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1)))
if _try_get_scalar_type(self):
old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
else:
return g.op("PRelu", self, weight)
def mm(g, self, other):
# Create a dummy C tensor. It is only needed for API purposes; the value is
# irrelevant since beta = 0
ty = sym_help._try_get_scalar_type(self, other).lower()
C = g.constant(0, [1], ty)
if _try_get_scalar_type(self):
old_type, self, other, C = _try_cast_integer_to_float(g, self, other, C)
return _cast_to_type(g, g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0), old_type)
else:
return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
if _try_get_scalar_type(self):
old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
return _cast_to_type(
g, g.op("Gemm", mat1, mat2, self,
beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha)), old_type)
else:
return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
def view(g, self, size):
size = sym_help._maybe_get_const(size, 'is')
if sym_help._is_value(size):
shape = size
else:
if self.isCompleteTensor():
self_sizes = self.type().sizes()
if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
old_type, self = _try_cast_integer_to_float(g, self)
return _cast_to_type(g, g.op("Flatten", self, axis_i=1), old_type)
shape = g.op("Constant", value_t=torch.LongTensor(size))
return g.op("Reshape", self, shape)
def flatten(g, input, start_dim, end_dim):
start_dim_i = sym_help._get_const(start_dim, 'i', 'start_dim')
end_dim_i = sym_help._get_const(end_dim, 'i', 'end_dim')
dim = input.type().dim()
if end_dim_i < 0:
end_dim_i = dim + end_dim_i
# use ONNX's Flatten operator for cases where the output shape is 2D
if start_dim_i == 1 and end_dim_i == dim - 1:
if _try_get_scalar_type(input):
old_type, input = _try_cast_integer_to_float(g, input)
return _cast_to_type(g, g.op("Flatten", input, axis_i=start_dim_i), old_type)
else:
return g.op("Flatten", input, axis_i=start_dim_i)
if start_dim_i == 0 and end_dim_i == dim - 2:
if _try_get_scalar_type(input):
old_type, input = _try_cast_integer_to_float(g, input)
return _cast_to_type(g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type)
else:
return g.op("Flatten", input, axis_i=end_dim_i + 1)
return sym_opset9.flatten(g, input, start_dim, end_dim)
def _constant_fill(g, sizes, dtype, const_value):
if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
result = g.op(
"ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)
return sym_help._cast_func_template(sym_help.scalar_type_to_onnx[dtype], g, result, None)
else:
return g.op("ConstantFill", sizes, dtype_i=sym_help.scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=const_value)
@parse_args('v', 'i', 'v', 'v', 'v')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
# NOTE: there is no way to set device and layout in ONNX, so we ignore them
return _constant_fill(g, sizes, dtype, 0)
@parse_args('v', 'i', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 0)
@parse_args('v', 'i', 'v', 'v', 'v')
def ones(g, sizes, dtype, layout, device, pin_memory=False):
return _constant_fill(g, sizes, dtype, 1)
@parse_args('v', 'i', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 1)
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
const_value = sym_help._maybe_get_const(value, 't')
if sym_help._is_value(const_value):
tmp = zeros(g, sizes, dtype, layout, device)
return sym_opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
dtype = sym_help._get_const(dtype, 'i', 'dtype')
return _constant_fill(g, sizes, dtype, const_value)
@parse_args('v', 'f', 'i', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, fill_value)
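
To see the integer-to-float workaround end to end, export an integer matmul at opset 8 and inspect the node types; a sketch, assuming the onnx package is installed:

import io
import onnx
import torch

class MatmulModel(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

f = io.BytesIO()
torch.onnx.export(MatmulModel(),
                  (torch.randint(10, (3, 4)), torch.randint(10, (4, 5))),
                  f, opset_version=8)
# expected pattern: Cast (to float) -> MatMul -> Cast (back to int64)
print([n.op_type for n in onnx.load_model_from_string(f.getvalue()).graph.node])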

View File

@ -23,19 +23,40 @@ def register_version(domain, version):
register_ops_in_version(domain, version)
def register_ops_helper(domain, version, iter_version):
version_ops = get_ops_in_version(iter_version)
for op in version_ops:
if isfunction(op[1]) and not is_registered_op(op[0], domain, version):
register_op(op[0], op[1], domain, version)
def register_ops_in_version(domain, version):
# iterates through the symbolic functions of
# the specified opset version, and the previous
# opset versions for operators supported in
# previous versions
# previous versions.
# Opset 9 is the base version. It was selected as the base because:
# 1. It is the first opset version supported by PyTorch export.
# 2. Opset 9 is more robust than earlier opset versions. Opset versions such as 7/8 have
#    limitations in that certain basic operators cannot be expressed in ONNX. Rather than
#    building on top of those limitations, we handle them separately as special cases.
# Backward support for opset versions earlier than opset 7 is not on our roadmap.
# Opset versions other than 9 inherit, by default, the symbolic functions defined in
# symbolic_opset9.py.
# To extend support for operators updated in a given opset version on top of opset 9,
# simply add the updated symbolic functions to the corresponding symbolic_opset{version}.py file.
# See topk in symbolic_opset10.py and upsample_nearest2d in symbolic_opset8.py for examples.
iter_version = version
while iter_version >= 9:
version_ops = get_ops_in_version(iter_version)
for op in version_ops:
if isfunction(op[1]) and \
not is_registered_op(op[0], domain, version):
register_op(op[0], op[1], domain, version)
iter_version = iter_version - 1
while iter_version != 9:
register_ops_helper(domain, version, iter_version)
if iter_version > 9:
iter_version = iter_version - 1
else:
iter_version = iter_version + 1
register_ops_helper(domain, version, 9)
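# The loop above gives each target opset a fallback chain ending at the base
# opset 9. A hypothetical helper (illustration only, not used by this module)
# showing the resulting lookup order:
def _resolution_order_sketch(version, base=9):
    if version <= base:
        return list(range(version, base + 1))      # e.g. 7 -> [7, 8, 9]
    return list(range(version, base - 1, -1))      # e.g. 10 -> [10, 9]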
def get_ops_in_version(version):

View File

@ -231,7 +231,8 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_onnx_peephole(graph)
from torch.onnx.symbolic_helper import _export_onnx_opset_version
torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version)
torch._C._jit_pass_lint(graph)
# graph is not a valid jit graph anymore because types have been replaced
@ -354,6 +355,11 @@ def _model_to_graph(model, args, verbose=False, training=False,
_export_onnx_opset_version)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
# For ONNX opset < 9, constants support only three data types: float16, float, and double.
# This pass transforms constants of other data types into float/double plus a Cast operator.
if _export_onnx_opset_version < 9:
torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)
if verbose:
print(graph)