Revert D19710370: [pytorch][PR] ONNX Update training ops and training amenable export API

Test Plan: revert-hammer

Differential Revision: D19710370

Original commit changeset: e5e79d385529

fbshipit-source-id: d0114dc561a3415869805d3fbf43b92730bbcf54
Alban Desmaison
2020-03-27 06:47:14 -07:00
committed by Facebook GitHub Bot
parent e5cd17cc9e
commit 45e1be9762
19 changed files with 145 additions and 732 deletions

View File

@ -160,7 +160,6 @@ jobs:
-g"-torch/csrc/jit/export.cpp" \
-g"-torch/csrc/jit/import.cpp" \
-g"-torch/csrc/jit/netdef_converter.cpp" \
-g"-torch/csrc/onnx/init.cpp" \
"$@" > ${GITHUB_WORKSPACE}/clang-tidy-output.txt
cat ${GITHUB_WORKSPACE}/clang-tidy-output.txt

View File

@ -54,7 +54,6 @@ pytest "${args[@]}" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
"${test_paths[@]}"
# onnxruntime only supports py3
@ -65,8 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *ort1-py3.6* ]]; then
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset8" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime" \
"$top_dir/test/onnx/test_custom_ops.py" \
"$top_dir/test/onnx/test_models_onnxruntime.py" \
"$top_dir/test/onnx/test_utility_funs.py"
"$top_dir/test/onnx/test_models_onnxruntime.py"
fi
if [[ "$BUILD_ENVIRONMENT" == *ort2-py3.6* ]]; then
# Update the loop for new opsets

View File

@ -1,182 +0,0 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
output: "6"
name: "Constant_0"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 9
raw_data: "\001"
}
type: TENSOR
}
}
node {
input: "input"
input: "weight"
input: "bias"
input: "running_mean"
input: "running_var"
input: "6"
output: "7"
output: "8"
output: "9"
output: "batch_norm_dead_output-14"
output: "batch_norm_dead_output-15"
name: "BatchNormalization_1"
op_type: "BatchNormalization"
attribute {
name: "epsilon"
f: 1e-05
type: FLOAT
}
attribute {
name: "momentum"
f: 0.9
type: FLOAT
}
}
name: "torch-jit-export"
initializer {
dims: 2
data_type: 1
name: "bias"
raw_data: "\000\000\000\000\000\000\000\000"
}
initializer {
data_type: 7
name: "num_batches_tracked"
raw_data: "\001\000\000\000\000\000\000\000"
}
initializer {
dims: 2
data_type: 1
name: "running_mean"
raw_data: "\315\314\314=\315\314\314="
}
initializer {
dims: 2
data_type: 1
name: "running_var"
raw_data: "fff?fff?"
}
initializer {
dims: 2
data_type: 1
name: "weight"
raw_data: "\000\000\200?\000\000\200?"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "weight"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "running_mean"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "running_var"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "num_batches_tracked"
type {
tensor_type {
elem_type: 7
shape {
}
}
}
}
output {
name: "7"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 12
}
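For reference, the deleted expect file above corresponds to the removed test_batchnorm_training_opset12 case: under the pre-revert API (which this commit removes), a sketch like the following would yield a single BatchNormalization node with five outputs, the saved mean/var outputs renamed with a "batch_norm_dead_output-" prefix. The TrainingMode enum named here exists only before this revert.

import io
import torch
import torch.nn as nn

# Sketch of the pre-revert export call (torch.onnx.TrainingMode is removed
# by this commit); mirrors the deleted test_batchnorm_training_opset12.
model = nn.BatchNorm2d(2)
x = torch.ones(2, 2, 2, 2, requires_grad=True)
f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=12,
                  training=torch.onnx.TrainingMode.TRAINING,
                  keep_initializers_as_inputs=True)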

View File

@ -1,46 +0,0 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
input: "x"
output: "1"
name: "ReduceMax_0"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -1,58 +0,0 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
input: "x"
output: "1"
output: "2"
name: "Dropout_0"
op_type: "Dropout"
attribute {
name: "ratio"
f: 0.5
type: FLOAT
}
}
node {
input: "1"
output: "3"
name: "ReduceMax_1"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "3"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -1,67 +0,0 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
output: "1"
name: "Constant_0"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 1
raw_data: "\000\000\000?"
}
type: TENSOR
}
}
node {
input: "x"
input: "1"
output: "2"
output: "3"
name: "Dropout_1"
op_type: "Dropout"
}
node {
input: "2"
output: "4"
name: "ReduceMax_2"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "4"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 12
}
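The three deleted dropout expect files differ in how the op is emitted: in eval mode, dropout is dropped and only ReduceMax remains; with training export at opset 9, Dropout carries a ratio attribute; at opset 12, the ratio arrives as a Constant input. A hedged repro sketch, mirroring the removed test_dropout_training_opset12 and again using the pre-revert TrainingMode API:

import io
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaxDropout(nn.Module):
    def forward(self, x):
        # p defaults to 0.5, matching the ratio shown in the graphs above
        return torch.max(F.dropout(x))

x = torch.randn(3, 4, requires_grad=True)
f = io.BytesIO()
# Pre-revert API sketch; at opset 9 the same call emits Dropout with a
# "ratio" attribute instead of the Constant input seen at opset 12.
torch.onnx.export(MaxDropout(), (x,), f, opset_version=12,
                  training=torch.onnx.TrainingMode.TRAINING)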

View File

@ -37,10 +37,9 @@ BATCH_SIZE = 2
class TestModels(TestCase):
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
with torch.onnx.select_model_mode_for_export(model, None):
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
torch._C._jit_pass_lint(graph)
verify(model, inputs, backend, rtol=rtol, atol=atol)
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
torch._C._jit_pass_lint(graph)
verify(model, inputs, backend, rtol=rtol, atol=atol)
def test_ops(self):
x = Variable(

View File

@ -45,7 +45,7 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None):
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=False, example_outputs=None):
for opset_version in opset_versions:
f = io.BytesIO()
torch.onnx.export(module, x, f,
@ -238,12 +238,12 @@ class TestONNXOpset(TestCase):
# test training mode
ops = [{"op_name" : "Dropout", "attributes" : [{"name" : "ratio", "f" : 0.5, "type" : 1}]}]
ops = {9 : ops, 10 : ops}
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=torch.onnx.TrainingMode.TRAINING)
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=True)
# test eval mode
ops = []
ops = {9 : ops, 10 : ops}
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=torch.onnx.TrainingMode.EVAL)
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=False)
def test_full(self):
class MyModule(Module):

View File

@ -16,6 +16,7 @@ import os
import shutil
import torch.testing._internal.common_utils as common
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--no-onnx: no onnx python dependence
--produce-onnx-test-data: generate onnx test data
@ -254,12 +255,7 @@ class TestOperators(TestCase):
def test_batchnorm_training(self):
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.assertONNX(nn.BatchNorm2d(2), x, training=torch.onnx.TrainingMode.TRAINING, keep_initializers_as_inputs=True)
def test_batchnorm_training_opset12(self):
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.assertONNX(nn.BatchNorm2d(2), x, training=torch.onnx.TrainingMode.TRAINING,
keep_initializers_as_inputs=True, opset_version=12)
self.assertONNX(nn.BatchNorm2d(2), x, training=True, keep_initializers_as_inputs=True)
def test_conv(self):
x = torch.ones(20, 16, 50, 40, requires_grad=True)
@ -676,18 +672,6 @@ class TestOperators(TestCase):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)
def test_dropout_default(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x,)), x)
def test_dropout_training(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, training=torch.onnx.TrainingMode.TRAINING)
def test_dropout_training_opset12(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, opset_version=12, training=torch.onnx.TrainingMode.TRAINING)
def test_nonzero(self):
x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
self.assertONNX(lambda x: torch.nonzero(x), x)

View File

@ -5,11 +5,8 @@ import torch
import torch.onnx
from torch.onnx import utils, OperatorExportTypes
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
from test_pytorch_common import skipIfUnsupportedOpsetVersion
import onnx
import onnxruntime # noqa
import numpy as np
import io
import copy
@ -55,8 +52,6 @@ class TestUtilityFuns(TestCase):
assert "Provided key invalid_name2 for dynamic axes is not a valid input/output name" in messages
assert len(messages) == 2
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_transpose(self):
class TransposeModule(torch.nn.Module):
def forward(self, x):
@ -77,8 +72,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_slice(self):
class NarrowModule(torch.nn.Module):
def forward(self, x):
@ -99,8 +92,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_slice_index_exceeds_dim(self):
class SliceIndexExceedsDimModule(torch.nn.Module):
def forward(self, x):
@ -122,8 +113,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_slice_negative_index(self):
class SliceNegativeIndexModule(torch.nn.Module):
def forward(self, x):
@ -144,8 +133,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_unsqueeze(self):
class UnsqueezeModule(torch.nn.Module):
def forward(self, x):
@ -166,8 +153,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_concat(self):
class ConcatModule(torch.nn.Module):
def forward(self, x):
@ -205,8 +190,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 2
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_lstm(self):
class GruNet(torch.nn.Module):
def __init__(self):
@ -229,8 +212,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Unsqueeze"
assert len(list(graph.nodes())) == 3
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_transpose_matmul(self):
class MatMulNet(torch.nn.Module):
def __init__(self):
@ -252,8 +233,6 @@ class TestUtilityFuns(TestCase):
# TODO we need to figure out the root cause and fix the problem
@skip("causing segmentation fault")
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_reshape(self):
class ReshapeModule(torch.nn.Module):
def __init__(self, ):
@ -273,8 +252,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Reshape"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_div(self):
class Module(torch.nn.Module):
def __init__(self, ):
@ -294,8 +271,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Div"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_mul(self):
class Module(torch.nn.Module):
def __init__(self, ):
@ -315,8 +290,6 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Mul"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_sqrt(self):
class Module(torch.nn.Module):
def __init__(self, ):
@ -369,95 +342,6 @@ class TestUtilityFuns(TestCase):
'unwrap model from torch.nn.DataParallel. Try '):
torch.onnx.export(model, x, f, opset_version=self.opset_version)
def test_export_mode(self):
class MyModule(torch.nn.Module):
def forward(self, x):
y = x + 1
return y
model = MyModule()
x = torch.randn(10, 3, 128, 128)
f = io.BytesIO()
# set model to inference mode and export in training mode
model.eval()
old_state = model.training
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
# verify that the model state is preserved
assert model.training == old_state
# set model to training mode and export in inference mode
model.train()
old_state = model.training
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
# verify that the model state is preserved
assert model.training == old_state
# TODO: Enable test when BatchNorm is implemented in ORT for opset 12.
@skipIfUnsupportedOpsetVersion([12])
def test_batchnorm_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(3, affine=True)
def forward(self, x):
bn = self.bn(x)
return bn
model = MyModule()
x = torch.randn(10, 3, 128, 128)
model.train()
out = model(x)
# state after 1 train epoch
running_mean = model.bn.running_mean
running_var = model.bn.running_var
saved_mean = x.mean((0, 2, 3))
saved_var = x.var((0, 2, 3))
pytorch_out = [out.detach().numpy(),
running_mean.cpu().numpy(), running_var.cpu().numpy(),
saved_mean.cpu().numpy(), saved_var.cpu().numpy()]
model_export = MyModule()
f = io.BytesIO()
torch.onnx.export(model_export, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name : x.cpu().numpy()}
ort_outs = ort_sess.run(None, ort_inputs)
[np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)]
# TODO: Enable test when Dropout is implemented in ORT for opset 12.
@skipIfUnsupportedOpsetVersion([12])
def test_dropout_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.dropout = torch.nn.Dropout(0.4)
def forward(self, x):
dropout = self.dropout(x)
return dropout
model = MyModule()
x = torch.randn(10, 3, 128, 128)
model.train()
f = io.BytesIO()
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name : x.cpu().numpy()}
ort_outs = ort_sess.run(None, ort_inputs)
assert x != ort_outs[0]
# opset 10 tests
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
@ -470,11 +354,6 @@ TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=11))
# opset 12 tests
TestUtilityFuns_opset12 = type(str("TestUtilityFuns_opset12"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=12))
# opset 12 tests
TestUtilityFuns_opset12 = type(str("TestUtilityFuns_opset12"),

View File

@ -8,6 +8,7 @@ import onnx.helper
import numpy as np
import difflib
import contextlib
import io
@ -225,7 +226,24 @@ class Errors(object):
if exc_type == self.exc_class:
raise RuntimeError("ShortCircuit was raised, but no errors were recorded")
def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode.EVAL, rtol=1e-3, atol=1e-7,
@contextlib.contextmanager
def set_training(model, mode):
"""
A context manager to temporarily set the training mode of 'model'
to 'mode', resetting it when we exit the with-block.
"""
old_mode = model.training
if old_mode != mode:
model.train(mode)
try:
yield
finally:
if old_mode != mode:
model.train(old_mode)
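A minimal usage sketch of the restored helper (assumes only the set_training definition above):

import torch

model = torch.nn.Dropout(0.5)   # freshly built modules start in training mode
with set_training(model, False):
    assert not model.training   # eval mode inside the block
assert model.training           # original mode restored on exit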
def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
test_args=2, do_constant_folding=True, example_outputs=None, opset_version=None,
keep_initializers_as_inputs=True, add_node_names=False):
"""
@ -341,7 +359,7 @@ def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode
if isinstance(args, torch.Tensor):
args = (args,)
with torch.onnx.select_model_mode_for_export(model, training):
with set_training(model, training):
proto_bytes = io.BytesIO()
torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
do_constant_folding=do_constant_folding,

View File

@ -30,11 +30,6 @@ void initONNXBindings(PyObject* module) {
.value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
.value("RAW", OperatorExportTypes::RAW);
py::enum_<TrainingMode>(onnx, "TrainingMode")
.value("EVAL", TrainingMode::EVAL)
.value("PRESERVE", TrainingMode::PRESERVE)
.value("TRAINING", TrainingMode::TRAINING);
onnx.attr("IR_VERSION") = IR_VERSION;
onnx.attr("PRODUCER_VERSION") = py::str(PRODUCER_VERSION);

View File

@ -9,12 +9,6 @@ enum class OperatorExportTypes {
RAW, // Raw export (no ONNX)
};
enum class TrainingMode {
EVAL, // Inference mode
PRESERVE, // Preserve model state (eval/training)
TRAINING, // Training mode
};
// We pin the IR version to 6 (12/11/2019) instead of using
// onnx::IR_VERSION. With this change, test_operators.py will be more
// stable. Only bump it when necessary.

View File

@ -2,7 +2,6 @@ import torch._C as _C
TensorProtoDataType = _C._onnx.TensorProtoDataType
OperatorExportTypes = _C._onnx.OperatorExportTypes
TrainingMode = _C._onnx.TrainingMode
PYTORCH_ONNX_CAFFE2_BUNDLE = _C._onnx.PYTORCH_ONNX_CAFFE2_BUNDLE
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
@ -29,7 +28,7 @@ def _export(*args, **kwargs):
return result
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
def export(model, args, f, export_params=True, verbose=False, training=False,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
@ -60,11 +59,9 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
as arguments, the ordering as specified by ``model.state_dict().values()``
verbose (bool, default False): if specified, we will print out a debug
description of the trace being exported.
training (enum, default TrainingMode.EVAL):
TrainingMode.EVAL: export the model in inference mode.
TrainingMode.PRESERVE: export the model in inference mode if model.training is
False and to a training friendly mode if model.training is True.
TrainingMode.TRAINING: export the model in a training friendly mode.
training (bool, default False): export the model in training mode. At
the moment, ONNX is oriented towards exporting models for inference
only, so you will generally not need to set this to True.
input_names(list of strings, default empty list): names to assign to the
input nodes of the graph, in order
output_names(list of strings, default empty list): names to assign to the
@ -187,7 +184,7 @@ def _optimize_trace(graph, operator_export_type):
return utils._optimize_graph(graph, operator_export_type)
def select_model_mode_for_export(model, mode):
def set_training(model, mode):
r"""
A context manager to temporarily set the training mode of 'model'
to 'mode', resetting it when we exit the with-block. A no-op if
@ -195,7 +192,7 @@ def select_model_mode_for_export(model, mode):
"""
from torch.onnx import utils
return utils.select_model_mode_for_export(model, mode)
return utils.set_training(model, mode)
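To make the API change concrete, a sketch of the two call styles; the enum form is the one removed by this revert:

import io
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)
f = io.BytesIO()

# After this revert: `training` is a plain bool again (default False,
# i.e. inference-mode export).
torch.onnx.export(model, (x,), f, training=False)

# Before the revert, the same intent was spelled with the enum
# (removed by this commit):
#   torch.onnx.export(model, (x,), f, training=torch.onnx.TrainingMode.EVAL)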
def _run_symbolic_function(*args, **kwargs):

View File

@ -413,20 +413,6 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na
padding = tuple(tuple_fn(padding))
return padding
def assert_training_mode(op_mode, op_name):
global _training_mode
op_mode = True if op_mode == 1 else False
if op_mode != _training_mode:
op_mode = "training " if op_mode else "inference"
training_mode = "training " if _training_mode else "inference"
# setting the model mode could result in op_mode != _training_mode
# if the model is a FuncModule. In this case we warn the user of
# the state and export depending on training_mode
warnings.warn("ONNX export mode is set to " + training_mode +
" mode, but operator " + op_name + " is set to " +
op_mode + " mode. The model will be exported in " +
training_mode + ", as specified by the export mode.")
# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------
@ -475,11 +461,6 @@ def _set_operator_export_type(operator_export_type):
global _operator_export_type
_operator_export_type = operator_export_type
_training_mode = None
def _set_training_mode(training_mode):
global _training_mode
_training_mode = training_mode
# Metaprogram symbolics for each ATen native specialized cast operator.
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
# ONNX cast node with `to` attribute 'UINT8'

View File

@ -1,6 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args
@ -10,62 +9,11 @@ from torch.onnx.symbolic_helper import parse_args
# This file exports ONNX ops for opset 12
black_listed_operators = [
"ArgMin", "ArgMax"
]
@parse_args('s', 'v')
def einsum(g, equation, tensor_list):
tensors = sym_help._unpack_list(tensor_list)
return g.op("Einsum", *tensors, equation_s=equation)
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
sym_help.assert_training_mode(train, "dropout")
# in eval mode, dropout is a no-op: if the node's train param is set to False, return the input unchanged
if not sym_help._training_mode:
return input
p = g.op("Constant", value_t=torch.tensor(p))
r, _ = g.op("Dropout", input, p, outputs=2)
return r
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
sym_help.assert_training_mode(training, "batch_norm")
input_sizes = input.type().sizes()
if weight is None or sym_help._is_none(weight):
assert len(input_sizes) > 1
weight_value = torch.tensor([1.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
weight = g.op("Constant", value_t=weight_value)
if bias is None or sym_help._is_none(bias):
assert len(input_sizes) > 1
bias_value = torch.tensor([0.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
bias = g.op("Constant", value_t=bias_value)
if not sym_help._training_mode:
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=1)
return out
else:
training_mode = g.op("Constant", value_t=torch.tensor(True))
res, new_running_mean, new_running_var, saved_mean, saved_var = g.op("BatchNormalization",
input,
weight, bias,
running_mean, running_var, training_mode,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=5)
new_running_mean.setType(running_mean.type())
new_running_var.setType(running_var.type())
saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
return res
def nll_loss(g, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]

View File

@ -1049,7 +1049,6 @@ def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, gr
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
sym_help.assert_training_mode(training, "batch_norm")
input_sizes = input.type().sizes()
if weight is None or sym_help._is_none(weight):
@ -1066,8 +1065,8 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=1 if not sym_help._training_mode else 5)
if not sym_help._training_mode:
outputs=1 if not training else 5)
if not training:
return out
else:
res, new_running_mean, new_running_var, saved_mean, saved_var = out
@ -1299,9 +1298,7 @@ def exp(g, self):
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
sym_help.assert_training_mode(train, "dropout")
# in eval mode, dropout is a no-op: if the node's train param is set to False, return the input unchanged
if not sym_help._training_mode:
if not train: # in eval mode, dropout is a no-op
return input
warnings.warn("Dropout is a training op and should not be exported in inference mode. "
"Make sure to call eval() on the model, and to export it with param training=False.")

View File

@ -17,7 +17,7 @@ import numbers
import warnings
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes, _check_onnx_proto
@ -31,44 +31,21 @@ def is_in_onnx_export():
@contextlib.contextmanager
def select_model_mode_for_export(model, mode):
if not isinstance(model, torch.jit.ScriptFunction):
is_originally_training = model.training
if mode is None:
mode = TrainingMode.EVAL
# if the model is in training mode but the user did not specify
# to export the model in training mode, export the model in inference
# mode (default) and warn them
if is_originally_training:
warnings.warn("You are exporting the model to ONNX while in training mode with "
"'train' parameter not specified. The model will default to inference mode export. "
"If you wish to export a training amenable ONNX model, specify train=TrainingMode.TRAIN or "
"train=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export().")
# if mode == TrainingMode.EVAL or (mode == TrainingMode.PRESERVE and not is_originally_training) => is_training = False
is_export_training = False
# ONNX opset 12 has better support for training amenable models, with updated
# versions of the dropout and batch_norm operators
if mode == TrainingMode.TRAINING or (mode == TrainingMode.PRESERVE and is_originally_training):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
if _export_onnx_opset_version < 12:
warnings.warn("You are exporting the model in training mode with onnx opset version {}. "
"Opset versions lower than opset 12 will not be able to export nodes such as"
"Dropout and BatchNorm correctly.".format(_export_onnx_opset_version))
is_export_training = True
from torch.onnx.symbolic_helper import _set_training_mode
_set_training_mode(is_export_training)
model.train(is_export_training)
def set_training(model, mode):
if mode is None:
yield
return
old_mode = model.training
if old_mode != mode:
model.train(mode)
try:
yield
finally:
if not isinstance(model, torch.jit.ScriptFunction):
model.train(is_originally_training)
if old_mode != mode:
model.train(old_mode)
def export(model, args, f, export_params=True, verbose=False, training=None,
def export(model, args, f, export_params=True, verbose=False, training=False,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
@ -298,15 +275,21 @@ def _trace(func, args, operator_export_type, return_outs=False):
return trace_graph
def _trace_and_get_graph_from_model(model, args):
def _trace_and_get_graph_from_model(model, args, training):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!
orig_state_dict_keys = _unique_state_dict(model).keys()
trace_graph, torch_out, inputs_states = \
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
warn_on_static_input_change(inputs_states)
# By default, training=False, which is good because running a model in
# training mode could result in internal buffers getting updated, dropout
# getting applied, etc. If you really know what you're doing, you
# can turn training=True (or None, to preserve whatever the original
# training mode was).
with set_training(model, training):
trace_graph, torch_out, inputs_states = \
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
warn_on_static_input_change(inputs_states)
if orig_state_dict_keys != _unique_state_dict(model).keys():
raise RuntimeError("state_dict changed after running the tracer; "
@ -315,7 +298,7 @@ def _trace_and_get_graph_from_model(model, args):
return trace_graph, torch_out
def _model_to_graph(model, args, verbose=False,
def _model_to_graph(model, args, verbose=False, training=False,
input_names=None, output_names=None,
operator_export_type=OperatorExportTypes.ONNX,
example_outputs=None, propagate=False,
@ -348,7 +331,7 @@ def _model_to_graph(model, args, verbose=False,
graph = _propagate_and_assign_input_shapes(
model.graph, tuple(in_vars), False, propagate)
else:
graph, torch_out = _trace_and_get_graph_from_model(model, args)
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
state_dict = _unique_state_dict(model)
params = list(state_dict.values())
if _retain_param_name:
@ -404,7 +387,7 @@ def _model_to_graph(model, args, verbose=False,
return graph, params_dict, torch_out
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
example_outputs=None, propagate=False, google_printer=False,
@ -427,7 +410,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
custom_opsets=custom_opsets)
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
google_printer=False, opset_version=None, _retain_param_name=False,
@ -441,27 +424,27 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
custom_opsets = {}
_set_opset_version(opset_version)
_set_operator_export_type(operator_export_type)
with select_model_mode_for_export(model, training):
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
example_outputs, propagate, _retain_param_name,
val_do_constant_folding, fixed_batch_size=fixed_batch_size)
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
training, input_names,
output_names, operator_export_type,
example_outputs, propagate, _retain_param_name,
val_do_constant_folding, fixed_batch_size=fixed_batch_size)
return graph._pretty_print_onnx(params_dict, opset_version, False,
operator_export_type, google_printer,
val_keep_init_as_ip, custom_opsets, val_add_node_names)
# NOTE: the output `torch_out` will contain the output tensors resulting from
# the trace of a Module. In the case that a torch.nn.ScriptModule is passed in,
# this output will be None, since we are not doing any tracing but rather
# directly extracting the graph.
def _export(model, args, f, export_params=True, verbose=False, training=None,
def _export(model, args, f, export_params=True, verbose=False, training=False,
input_names=None, output_names=None, operator_export_type=None,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
opset_version=None, _retain_param_name=False, do_constant_folding=True,
@ -487,86 +470,80 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
else:
operator_export_type = OperatorExportTypes.ONNX
# By default, training=None (which defaults to TrainingMode.EVAL),
# which is good because running a model in training mode could result in
# internal buffers getting updated, dropout getting applied, etc.
# If you really know what you're doing, you can turn
# training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE
# (to preserve whatever the original training mode was).
with select_model_mode_for_export(model, training):
_set_opset_version(opset_version)
_set_operator_export_type(operator_export_type)
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
operator_export_type,
f)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
example_outputs, propagate,
_retain_param_name, val_do_constant_folding,
fixed_batch_size=fixed_batch_size)
_set_opset_version(opset_version)
_set_operator_export_type(operator_export_type)
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
operator_export_type,
f)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
training, input_names,
output_names, operator_export_type,
example_outputs, propagate,
_retain_param_name, val_do_constant_folding,
fixed_batch_size=fixed_batch_size)
# TODO: Don't allocate an in-memory string for the protobuf
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
if dynamic_axes is None:
dynamic_axes = {}
if custom_opsets is None:
custom_opsets = {}
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
if export_params:
proto, export_map = graph._export_onnx(
params_dict, opset_version, dynamic_axes, defer_weight_export,
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
val_add_node_names, val_use_external_data_format, model_file_location)
else:
proto, export_map = graph._export_onnx(
{}, opset_version, dynamic_axes, False, operator_export_type,
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
val_use_external_data_format, model_file_location)
if enable_onnx_checker and \
operator_export_type is OperatorExportTypes.ONNX_ATEN_FALLBACK and \
not val_use_external_data_format:
# Only run checker if enabled and we are not using ATEN fallback and
# large model format export is not enabled.
_check_onnx_proto(proto)
if enable_onnx_checker and \
operator_export_type is OperatorExportTypes.ONNX and \
not val_use_external_data_format:
# Only run checker if enabled and we are not using ATEN fallback and
# large model format export is not enabled.
_check_onnx_proto(proto)
if export_type == ExportTypes.PROTOBUF_FILE:
assert(len(export_map) == 0)
with torch.serialization._open_file_like(f, 'wb') as opened_file:
opened_file.write(proto)
elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
import zipfile
compression = zipfile.ZIP_DEFLATED \
if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
else zipfile.ZIP_STORED
with zipfile.ZipFile(f, 'w', compression=compression) as z:
z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
for k, v in export_map.items():
    z.writestr(k, v)
elif export_type == ExportTypes.DIRECTORY:
import os
if os.path.exists(f):
    assert(os.path.isdir(f))
else:
    os.makedirs(f)
model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
with torch.serialization._open_file_like(model_proto_file, 'wb') as opened_file:
opened_file.write(proto)
for k, v in export_map.items():
weight_proto_file = os.path.join(f, k)
with torch.serialization._open_file_like(weight_proto_file, 'wb') as opened_file:
opened_file.write(v)
else:
raise RuntimeError('Unknown export type')
finally:
assert __IN_ONNX_EXPORT
__IN_ONNX_EXPORT = False

View File

@ -280,7 +280,7 @@ def graph(model, args, verbose=False):
verbose (bool): Whether to print out verbose information while
processing.
"""
with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL): # TODO: move outside of torch.onnx?
with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
try:
trace = torch.jit.trace(model, args)
graph = trace.graph
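A sketch of the pattern this restores: force eval mode while tracing so that dropout and batch norm behave deterministically, then read the traced graph.

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
args = torch.randn(1, 4)
# torch.onnx.set_training is the wrapper restored by this commit
with torch.onnx.set_training(model, False):
    trace = torch.jit.trace(model, args)
    graph = trace.graph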