mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D19710370: [pytorch][PR] ONNX Update training ops and training amenable export API
Test Plan: revert-hammer Differential Revision: D19710370 Original commit changeset: e5e79d385529 fbshipit-source-id: d0114dc561a3415869805d3fbf43b92730bbcf54
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e5cd17cc9e
commit
45e1be9762
1
.github/workflows/lint.yml
vendored
1
.github/workflows/lint.yml
vendored
@ -160,7 +160,6 @@ jobs:
|
||||
-g"-torch/csrc/jit/export.cpp" \
|
||||
-g"-torch/csrc/jit/import.cpp" \
|
||||
-g"-torch/csrc/jit/netdef_converter.cpp" \
|
||||
-g"-torch/csrc/onnx/init.cpp" \
|
||||
"$@" > ${GITHUB_WORKSPACE}/clang-tidy-output.txt
|
||||
|
||||
cat ${GITHUB_WORKSPACE}/clang-tidy-output.txt
|
||||
|
@ -54,7 +54,6 @@ pytest "${args[@]}" \
|
||||
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
|
||||
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
|
||||
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
|
||||
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
|
||||
"${test_paths[@]}"
|
||||
|
||||
# onnxruntime only support py3
|
||||
@ -65,8 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *ort1-py3.6* ]]; then
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset8" \
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime" \
|
||||
"$top_dir/test/onnx/test_custom_ops.py" \
|
||||
"$top_dir/test/onnx/test_models_onnxruntime.py" \
|
||||
"$top_dir/test/onnx/test_utility_funs.py"
|
||||
"$top_dir/test/onnx/test_models_onnxruntime.py"
|
||||
fi
|
||||
if [[ "$BUILD_ENVIRONMENT" == *ort2-py3.6* ]]; then
|
||||
# Update the loop for new opsets
|
||||
|
@ -1,182 +0,0 @@
|
||||
ir_version: 6
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.5"
|
||||
graph {
|
||||
node {
|
||||
output: "6"
|
||||
name: "Constant_0"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
data_type: 9
|
||||
raw_data: "\001"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "input"
|
||||
input: "weight"
|
||||
input: "bias"
|
||||
input: "running_mean"
|
||||
input: "running_var"
|
||||
input: "6"
|
||||
output: "7"
|
||||
output: "8"
|
||||
output: "9"
|
||||
output: "batch_norm_dead_output-14"
|
||||
output: "batch_norm_dead_output-15"
|
||||
name: "BatchNormalization_1"
|
||||
op_type: "BatchNormalization"
|
||||
attribute {
|
||||
name: "epsilon"
|
||||
f: 1e-05
|
||||
type: FLOAT
|
||||
}
|
||||
attribute {
|
||||
name: "momentum"
|
||||
f: 0.9
|
||||
type: FLOAT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
initializer {
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "bias"
|
||||
raw_data: "\000\000\000\000\000\000\000\000"
|
||||
}
|
||||
initializer {
|
||||
data_type: 7
|
||||
name: "num_batches_tracked"
|
||||
raw_data: "\001\000\000\000\000\000\000\000"
|
||||
}
|
||||
initializer {
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "running_mean"
|
||||
raw_data: "\315\314\314=\315\314\314="
|
||||
}
|
||||
initializer {
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "running_var"
|
||||
raw_data: "fff?fff?"
|
||||
}
|
||||
initializer {
|
||||
dims: 2
|
||||
data_type: 1
|
||||
name: "weight"
|
||||
raw_data: "\000\000\200?\000\000\200?"
|
||||
}
|
||||
input {
|
||||
name: "input"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "weight"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "bias"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "running_mean"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "running_var"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "num_batches_tracked"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "7"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
@ -1,46 +0,0 @@
|
||||
ir_version: 6
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.5"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
output: "1"
|
||||
name: "ReduceMax_0"
|
||||
op_type: "ReduceMax"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 9
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
ir_version: 6
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.5"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
output: "1"
|
||||
output: "2"
|
||||
name: "Dropout_0"
|
||||
op_type: "Dropout"
|
||||
attribute {
|
||||
name: "ratio"
|
||||
f: 0.5
|
||||
type: FLOAT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "1"
|
||||
output: "3"
|
||||
name: "ReduceMax_1"
|
||||
op_type: "ReduceMax"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 9
|
||||
}
|
@ -1,67 +0,0 @@
|
||||
ir_version: 6
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.5"
|
||||
graph {
|
||||
node {
|
||||
output: "1"
|
||||
name: "Constant_0"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
data_type: 1
|
||||
raw_data: "\000\000\000?"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "x"
|
||||
input: "1"
|
||||
output: "2"
|
||||
output: "3"
|
||||
name: "Dropout_1"
|
||||
op_type: "Dropout"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
output: "4"
|
||||
name: "ReduceMax_2"
|
||||
op_type: "ReduceMax"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "4"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
@ -37,10 +37,9 @@ BATCH_SIZE = 2
|
||||
|
||||
class TestModels(TestCase):
|
||||
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
|
||||
with torch.onnx.select_model_mode_for_export(model, None):
|
||||
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
verify(model, inputs, backend, rtol=rtol, atol=atol)
|
||||
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
verify(model, inputs, backend, rtol=rtol, atol=atol)
|
||||
|
||||
def test_ops(self):
|
||||
x = Variable(
|
||||
|
@ -45,7 +45,7 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
|
||||
assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
|
||||
|
||||
|
||||
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None):
|
||||
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=False, example_outputs=None):
|
||||
for opset_version in opset_versions:
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(module, x, f,
|
||||
@ -238,12 +238,12 @@ class TestONNXOpset(TestCase):
|
||||
# test training mode
|
||||
ops = [{"op_name" : "Dropout", "attributes" : [{"name" : "ratio", "f" : 0.5, "type" : 1}]}]
|
||||
ops = {9 : ops, 10 : ops}
|
||||
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=torch.onnx.TrainingMode.TRAINING)
|
||||
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=True)
|
||||
|
||||
# test eval mode
|
||||
ops = []
|
||||
ops = {9 : ops, 10 : ops}
|
||||
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=torch.onnx.TrainingMode.EVAL)
|
||||
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=False)
|
||||
|
||||
def test_full(self):
|
||||
class MyModule(Module):
|
||||
|
@ -16,6 +16,7 @@ import os
|
||||
import shutil
|
||||
import torch.testing._internal.common_utils as common
|
||||
|
||||
|
||||
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
|
||||
--no-onnx: no onnx python dependence
|
||||
--produce-onnx-test-data: generate onnx test data
|
||||
@ -254,12 +255,7 @@ class TestOperators(TestCase):
|
||||
|
||||
def test_batchnorm_training(self):
|
||||
x = torch.ones(2, 2, 2, 2, requires_grad=True)
|
||||
self.assertONNX(nn.BatchNorm2d(2), x, training=torch.onnx.TrainingMode.TRAINING, keep_initializers_as_inputs=True)
|
||||
|
||||
def test_batchnorm_training_opset12(self):
|
||||
x = torch.ones(2, 2, 2, 2, requires_grad=True)
|
||||
self.assertONNX(nn.BatchNorm2d(2), x, training=torch.onnx.TrainingMode.TRAINING,
|
||||
keep_initializers_as_inputs=True, opset_version=12)
|
||||
self.assertONNX(nn.BatchNorm2d(2), x, training=True, keep_initializers_as_inputs=True)
|
||||
|
||||
def test_conv(self):
|
||||
x = torch.ones(20, 16, 50, 40, requires_grad=True)
|
||||
@ -676,18 +672,6 @@ class TestOperators(TestCase):
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)
|
||||
|
||||
def test_dropout_default(self):
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.max(functional.dropout(x,)), x)
|
||||
|
||||
def test_dropout_training(self):
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, training=torch.onnx.TrainingMode.TRAINING)
|
||||
|
||||
def test_dropout_training_opset12(self):
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, opset_version=12, training=torch.onnx.TrainingMode.TRAINING)
|
||||
|
||||
def test_nonzero(self):
|
||||
x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.nonzero(x), x)
|
||||
|
@ -5,11 +5,8 @@ import torch
|
||||
import torch.onnx
|
||||
from torch.onnx import utils, OperatorExportTypes
|
||||
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
|
||||
from test_pytorch_common import skipIfUnsupportedOpsetVersion
|
||||
|
||||
import onnx
|
||||
import onnxruntime # noqa
|
||||
import numpy as np
|
||||
|
||||
import io
|
||||
import copy
|
||||
@ -55,8 +52,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert "Provided key invalid_name2 for dynamic axes is not a valid input/output name" in messages
|
||||
assert len(messages) == 2
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_transpose(self):
|
||||
class TransposeModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -77,8 +72,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Constant"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_slice(self):
|
||||
class NarrowModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -99,8 +92,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Constant"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_slice_index_exceeds_dim(self):
|
||||
class SliceIndexExceedsDimModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -122,8 +113,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Constant"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_slice_negative_index(self):
|
||||
class SliceNegativeIndexModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -144,8 +133,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Constant"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_unsqueeze(self):
|
||||
class UnsqueezeModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -166,8 +153,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Constant"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_concat(self):
|
||||
class ConcatModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -205,8 +190,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Constant"
|
||||
assert len(list(graph.nodes())) == 2
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_lstm(self):
|
||||
class GruNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -229,8 +212,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Unsqueeze"
|
||||
assert len(list(graph.nodes())) == 3
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_transpose_matmul(self):
|
||||
class MatMulNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -252,8 +233,6 @@ class TestUtilityFuns(TestCase):
|
||||
|
||||
# TODO we need to figure out the root cause and fix the problem
|
||||
@skip("causing segmentation fault")
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_reshape(self):
|
||||
class ReshapeModule(torch.nn.Module):
|
||||
def __init__(self, ):
|
||||
@ -273,8 +252,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Reshape"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_div(self):
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self, ):
|
||||
@ -294,8 +271,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Div"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_mul(self):
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self, ):
|
||||
@ -315,8 +290,6 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Mul"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
# TODO : enable when constant folding is enabled for opset 12
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_constant_fold_sqrt(self):
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self, ):
|
||||
@ -369,95 +342,6 @@ class TestUtilityFuns(TestCase):
|
||||
'unwrap model from torch.nn.DataParallel. Try '):
|
||||
torch.onnx.export(model, x, f, opset_version=self.opset_version)
|
||||
|
||||
def test_export_mode(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = x + 1
|
||||
return y
|
||||
|
||||
model = MyModule()
|
||||
x = torch.randn(10, 3, 128, 128)
|
||||
f = io.BytesIO()
|
||||
|
||||
# set mode to in inference mode and export in training mode
|
||||
model.eval()
|
||||
old_state = model.training
|
||||
torch.onnx.export(model, (x,), f,
|
||||
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
|
||||
# verify that the model state is preserved
|
||||
assert model.training == old_state
|
||||
|
||||
# set mode to training mode and export in inference mode
|
||||
model.train()
|
||||
old_state = model.training
|
||||
torch.onnx.export(model, (x,), f,
|
||||
opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
|
||||
# verify that the model state is preserved
|
||||
assert model.training == old_state
|
||||
|
||||
# TODO: Enable test when BatchNorm is implemented in ORT for opset 12.
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_batchnorm_training(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
self.bn = torch.nn.BatchNorm2d(3, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
bn = self.bn(x)
|
||||
return bn
|
||||
|
||||
model = MyModule()
|
||||
x = torch.randn(10, 3, 128, 128)
|
||||
|
||||
model.train()
|
||||
out = model(x)
|
||||
|
||||
# state after 1 train epoch
|
||||
running_mean = model.bn.running_mean
|
||||
running_var = model.bn.running_var
|
||||
saved_mean = x.mean((0, 2, 3))
|
||||
saved_var = x.var((0, 2, 3))
|
||||
|
||||
pytorch_out = [out.detach().numpy(),
|
||||
running_mean.cpu().numpy(), running_var.cpu().numpy(),
|
||||
saved_mean.cpu().numpy(), saved_var.cpu().numpy()]
|
||||
|
||||
model_export = MyModule()
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model_export, (x,), f,
|
||||
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
|
||||
ort_sess = onnxruntime.InferenceSession(f.getvalue())
|
||||
|
||||
ort_inputs = {ort_sess.get_inputs()[0].name : x.cpu().numpy()}
|
||||
ort_outs = ort_sess.run(None, ort_inputs)
|
||||
[np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)]
|
||||
|
||||
# TODO: Enable test when Dropout is implemented in ORT for opset 12.
|
||||
@skipIfUnsupportedOpsetVersion([12])
|
||||
def test_dropout_training(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
self.dropout = torch.nn.Dropout(0.4)
|
||||
|
||||
def forward(self, x):
|
||||
dropout = self.dropout(x)
|
||||
return dropout
|
||||
|
||||
model = MyModule()
|
||||
x = torch.randn(10, 3, 128, 128)
|
||||
|
||||
model.train()
|
||||
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, (x,), f,
|
||||
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
|
||||
ort_sess = onnxruntime.InferenceSession(f.getvalue())
|
||||
ort_inputs = {ort_sess.get_inputs()[0].name : x.cpu().numpy()}
|
||||
ort_outs = ort_sess.run(None, ort_inputs)
|
||||
assert x != ort_outs[0]
|
||||
|
||||
|
||||
# opset 10 tests
|
||||
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
|
||||
@ -470,11 +354,6 @@ TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
|
||||
(TestCase,),
|
||||
dict(TestUtilityFuns.__dict__, opset_version=11))
|
||||
|
||||
# opset 12 tests
|
||||
TestUtilityFuns_opset12 = type(str("TestUtilityFuns_opset12"),
|
||||
(TestCase,),
|
||||
dict(TestUtilityFuns.__dict__, opset_version=12))
|
||||
|
||||
|
||||
# opset 12tests
|
||||
TestUtilityFuns_opset12 = type(str("TestUtilityFuns_opset12"),
|
||||
|
@ -8,6 +8,7 @@ import onnx.helper
|
||||
import numpy as np
|
||||
|
||||
import difflib
|
||||
import contextlib
|
||||
import io
|
||||
|
||||
|
||||
@ -225,7 +226,24 @@ class Errors(object):
|
||||
if exc_type == self.exc_class:
|
||||
raise RuntimeError("ShortCircuit was raised, but no errors were recorded")
|
||||
|
||||
def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode.EVAL, rtol=1e-3, atol=1e-7,
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_training(model, mode):
|
||||
"""
|
||||
A context manager to temporarily set the training mode of 'model'
|
||||
to 'mode', resetting it when we exit the with-block.
|
||||
"""
|
||||
old_mode = model.training
|
||||
if old_mode != mode:
|
||||
model.train(mode)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old_mode != mode:
|
||||
model.train(old_mode)
|
||||
|
||||
|
||||
def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
|
||||
test_args=2, do_constant_folding=True, example_outputs=None, opset_version=None,
|
||||
keep_initializers_as_inputs=True, add_node_names=False):
|
||||
"""
|
||||
@ -341,7 +359,7 @@ def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode
|
||||
if isinstance(args, torch.Tensor):
|
||||
args = (args,)
|
||||
|
||||
with torch.onnx.select_model_mode_for_export(model, training):
|
||||
with set_training(model, training):
|
||||
proto_bytes = io.BytesIO()
|
||||
torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
|
||||
do_constant_folding=do_constant_folding,
|
||||
|
@ -30,11 +30,6 @@ void initONNXBindings(PyObject* module) {
|
||||
.value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
|
||||
.value("RAW", OperatorExportTypes::RAW);
|
||||
|
||||
py::enum_<TrainingMode>(onnx, "TrainingMode")
|
||||
.value("EVAL", TrainingMode::EVAL)
|
||||
.value("PRESERVE", TrainingMode::PRESERVE)
|
||||
.value("TRAINING", TrainingMode::TRAINING);
|
||||
|
||||
onnx.attr("IR_VERSION") = IR_VERSION;
|
||||
onnx.attr("PRODUCER_VERSION") = py::str(PRODUCER_VERSION);
|
||||
|
||||
|
@ -9,12 +9,6 @@ enum class OperatorExportTypes {
|
||||
RAW, // Raw export (no ONNX)
|
||||
};
|
||||
|
||||
enum class TrainingMode {
|
||||
EVAL, // Inference mode
|
||||
PRESERVE, // Preserve model state (eval/training)
|
||||
TRAINING, // Training mode
|
||||
};
|
||||
|
||||
// we pin IR version to version 6 (12/11/2019) instead of using
|
||||
// onnx::IR_VERSION. with this change, the test_operators.py will be more
|
||||
// stable. only bump it when it's necessary
|
||||
|
@ -2,7 +2,6 @@ import torch._C as _C
|
||||
|
||||
TensorProtoDataType = _C._onnx.TensorProtoDataType
|
||||
OperatorExportTypes = _C._onnx.OperatorExportTypes
|
||||
TrainingMode = _C._onnx.TrainingMode
|
||||
PYTORCH_ONNX_CAFFE2_BUNDLE = _C._onnx.PYTORCH_ONNX_CAFFE2_BUNDLE
|
||||
|
||||
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
|
||||
@ -29,7 +28,7 @@ def _export(*args, **kwargs):
|
||||
return result
|
||||
|
||||
|
||||
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
|
||||
def export(model, args, f, export_params=True, verbose=False, training=False,
|
||||
input_names=None, output_names=None, aten=False, export_raw_ir=False,
|
||||
operator_export_type=None, opset_version=None, _retain_param_name=True,
|
||||
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
|
||||
@ -60,11 +59,9 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
||||
as arguments, the ordering as specified by ``model.state_dict().values()``
|
||||
verbose (bool, default False): if specified, we will print out a debug
|
||||
description of the trace being exported.
|
||||
training (enum, default TrainingMode.EVAL):
|
||||
TrainingMode.EVAL: export the model in inference mode.
|
||||
TrainingMode.PRESERVE: export the model in inference mode if model.training is
|
||||
False and to a training friendly mode if model.training is True.
|
||||
TrainingMode.TRAINING: export the model in a training friendly mode.
|
||||
training (bool, default False): export the model in training mode. At
|
||||
the moment, ONNX is oriented towards exporting models for inference
|
||||
only, so you will generally not need to set this to True.
|
||||
input_names(list of strings, default empty list): names to assign to the
|
||||
input nodes of the graph, in order
|
||||
output_names(list of strings, default empty list): names to assign to the
|
||||
@ -187,7 +184,7 @@ def _optimize_trace(graph, operator_export_type):
|
||||
return utils._optimize_graph(graph, operator_export_type)
|
||||
|
||||
|
||||
def select_model_mode_for_export(model, mode):
|
||||
def set_training(model, mode):
|
||||
r"""
|
||||
A context manager to temporarily set the training mode of 'model'
|
||||
to 'mode', resetting it when we exit the with-block. A no-op if
|
||||
@ -195,7 +192,7 @@ def select_model_mode_for_export(model, mode):
|
||||
"""
|
||||
|
||||
from torch.onnx import utils
|
||||
return utils.select_model_mode_for_export(model, mode)
|
||||
return utils.set_training(model, mode)
|
||||
|
||||
|
||||
def _run_symbolic_function(*args, **kwargs):
|
||||
|
@ -413,20 +413,6 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na
|
||||
padding = tuple(tuple_fn(padding))
|
||||
return padding
|
||||
|
||||
def assert_training_mode(op_mode, op_name):
|
||||
global _training_mode
|
||||
op_mode = True if op_mode == 1 else False
|
||||
if op_mode != _training_mode:
|
||||
op_mode = "training " if op_mode else "inference"
|
||||
training_mode = "training " if _training_mode else "inference"
|
||||
# setting the model mode could result in op_mode != _training_mode
|
||||
# if the model is a FuncModule. In this case we warn the user of
|
||||
# the state and export depending on training_mode
|
||||
warnings.warn("ONNX export mode is set to " + training_mode +
|
||||
" mode, but operator " + op_name + " is set to " +
|
||||
op_mode + " mode. The model will be exported in " +
|
||||
training_mode + ", as specified by the export mode.")
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# ONNX operator version
|
||||
# ---------------------------------------------------------------------
|
||||
@ -475,11 +461,6 @@ def _set_operator_export_type(operator_export_type):
|
||||
global _operator_export_type
|
||||
_operator_export_type = operator_export_type
|
||||
|
||||
_training_mode = None
|
||||
def _set_training_mode(training_mode):
|
||||
global _training_mode
|
||||
_training_mode = training_mode
|
||||
|
||||
# Metaprogram symbolics for each ATen native specialized cast operator.
|
||||
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
|
||||
# ONNX cast node with `to` attribute 'UINT8'
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import torch
|
||||
import torch.onnx.symbolic_helper as sym_help
|
||||
from torch.onnx.symbolic_helper import parse_args
|
||||
|
||||
@ -10,62 +9,11 @@ from torch.onnx.symbolic_helper import parse_args
|
||||
|
||||
# This file exports ONNX ops for opset 12
|
||||
|
||||
black_listed_operators = [
|
||||
"ArgMin", "ArgMax"
|
||||
]
|
||||
|
||||
@parse_args('s', 'v')
|
||||
def einsum(g, equation, tensor_list):
|
||||
tensors = sym_help._unpack_list(tensor_list)
|
||||
return g.op("Einsum", *tensors, equation_s=equation)
|
||||
|
||||
@parse_args('v', 'f', 'i')
|
||||
def dropout(g, input, p, train):
|
||||
sym_help.assert_training_mode(train, "dropout")
|
||||
# in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
|
||||
if not sym_help._training_mode:
|
||||
return input
|
||||
p = g.op("Constant", value_t=torch.tensor(p))
|
||||
r, _ = g.op("Dropout", input, p, outputs=2)
|
||||
return r
|
||||
|
||||
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
|
||||
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
|
||||
sym_help.assert_training_mode(training, "batch_norm")
|
||||
input_sizes = input.type().sizes()
|
||||
|
||||
if weight is None or sym_help._is_none(weight):
|
||||
assert len(input_sizes) > 1
|
||||
weight_value = torch.tensor([1.] * input_sizes[1]).type(
|
||||
'torch.' + input.type().scalarType() + 'Tensor')
|
||||
weight = g.op("Constant", value_t=weight_value)
|
||||
if bias is None or sym_help._is_none(bias):
|
||||
assert len(input_sizes) > 1
|
||||
bias_value = torch.tensor([0.] * input_sizes[1]).type(
|
||||
'torch.' + input.type().scalarType() + 'Tensor')
|
||||
bias = g.op("Constant", value_t=bias_value)
|
||||
|
||||
if not sym_help._training_mode:
|
||||
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
|
||||
epsilon_f=eps,
|
||||
momentum_f=1 - momentum,
|
||||
outputs=1)
|
||||
return out
|
||||
else:
|
||||
training_mode = g.op("Constant", value_t=torch.tensor(True))
|
||||
res, new_running_mean, new_running_var, saved_mean, saved_var = g.op("BatchNormalization",
|
||||
input,
|
||||
weight, bias,
|
||||
running_mean, running_var, training_mode,
|
||||
epsilon_f=eps,
|
||||
momentum_f=1 - momentum,
|
||||
outputs=5)
|
||||
new_running_mean.setType(running_mean.type())
|
||||
new_running_var.setType(running_var.type())
|
||||
saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
|
||||
saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
|
||||
return res
|
||||
|
||||
def nll_loss(g, self, target, weight, reduction, ignore_index):
|
||||
# none reduction : onnx::Constant[value={0}]
|
||||
# mean reduction : onnx::Constant[value={1}]
|
||||
|
@ -1049,7 +1049,6 @@ def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, gr
|
||||
|
||||
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
|
||||
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
|
||||
sym_help.assert_training_mode(training, "dropout")
|
||||
input_sizes = input.type().sizes()
|
||||
|
||||
if weight is None or sym_help._is_none(weight):
|
||||
@ -1066,8 +1065,8 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
|
||||
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
|
||||
epsilon_f=eps,
|
||||
momentum_f=1 - momentum,
|
||||
outputs=1 if not sym_help._training_mode else 5)
|
||||
if not sym_help._training_mode:
|
||||
outputs=1 if not training else 5)
|
||||
if not training:
|
||||
return out
|
||||
else:
|
||||
res, new_running_mean, new_running_var, saved_mean, saved_var = out
|
||||
@ -1299,9 +1298,7 @@ def exp(g, self):
|
||||
|
||||
@parse_args('v', 'f', 'i')
|
||||
def dropout(g, input, p, train):
|
||||
sym_help.assert_training_mode(train, "dropout")
|
||||
# in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
|
||||
if not sym_help._training_mode:
|
||||
if not train: # in eval mode, dropout is non-op
|
||||
return input
|
||||
warnings.warn("Dropout is a training op and should not be exported in inference mode. "
|
||||
"Make sure to call eval() on the model, and to export it with param training=False.")
|
||||
|
@ -17,7 +17,7 @@ import numbers
|
||||
import warnings
|
||||
from torch._six import string_classes
|
||||
from torch.jit import _unique_state_dict
|
||||
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
|
||||
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
|
||||
from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes, _check_onnx_proto
|
||||
|
||||
|
||||
@ -31,44 +31,21 @@ def is_in_onnx_export():
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def select_model_mode_for_export(model, mode):
|
||||
if not isinstance(model, torch.jit.ScriptFunction):
|
||||
is_originally_training = model.training
|
||||
|
||||
if mode is None:
|
||||
mode = TrainingMode.EVAL
|
||||
# if the model is in training mode but the user did not specify
|
||||
# to export the model in training mode, export the model in inference
|
||||
# mode (default) and warn them
|
||||
if is_originally_training:
|
||||
warnings.warn("You are exporting the model to ONNX while in training mode with "
|
||||
"'train' parameter not specified. The model will default to inference mode export. "
|
||||
"If you wish to export a training amenable ONNX model, specify train=TrainingMode.TRAIN or "
|
||||
"train=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export().")
|
||||
|
||||
# if mode == TrainingMode.EVAL or (mode == TrainingMode.PRESERVE and not is_originally_training) => is_training = False
|
||||
is_export_training = False
|
||||
# ONNX opset 12 has better support for training amenable models, with updated
|
||||
# versions of the dropout and batch_norm operators
|
||||
if mode == TrainingMode.TRAINING or (mode == TrainingMode.PRESERVE and is_originally_training):
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
||||
if _export_onnx_opset_version < 12:
|
||||
warnings.warn("You are exporting the model in training mode with onnx opset version {}. "
|
||||
"Opset versions lower than opset 12 will not be able to export nodes such as"
|
||||
"Dropout and BatchNorm correctly.".format(_export_onnx_opset_version))
|
||||
is_export_training = True
|
||||
|
||||
from torch.onnx.symbolic_helper import _set_training_mode
|
||||
_set_training_mode(is_export_training)
|
||||
model.train(is_export_training)
|
||||
def set_training(model, mode):
|
||||
if mode is None:
|
||||
yield
|
||||
return
|
||||
old_mode = model.training
|
||||
if old_mode != mode:
|
||||
model.train(mode)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if not isinstance(model, torch.jit.ScriptFunction):
|
||||
model.train(is_originally_training)
|
||||
if old_mode != mode:
|
||||
model.train(old_mode)
|
||||
|
||||
|
||||
def export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
def export(model, args, f, export_params=True, verbose=False, training=False,
|
||||
input_names=None, output_names=None, aten=False, export_raw_ir=False,
|
||||
operator_export_type=None, opset_version=None, _retain_param_name=True,
|
||||
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
|
||||
@ -298,15 +275,21 @@ def _trace(func, args, operator_export_type, return_outs=False):
|
||||
return trace_graph
|
||||
|
||||
|
||||
def _trace_and_get_graph_from_model(model, args):
|
||||
def _trace_and_get_graph_from_model(model, args, training):
|
||||
|
||||
# A basic sanity check: make sure the state_dict keys are the same
|
||||
# before and after running the model. Fail fast!
|
||||
orig_state_dict_keys = _unique_state_dict(model).keys()
|
||||
|
||||
trace_graph, torch_out, inputs_states = \
|
||||
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
|
||||
warn_on_static_input_change(inputs_states)
|
||||
# By default, training=False, which is good because running a model in
|
||||
# training mode could result in internal buffers getting updated, dropout
|
||||
# getting applied, etc. If you really know what you're doing, you
|
||||
# can turn training=True (or None, to preserve whatever the original
|
||||
# training mode was.)
|
||||
with set_training(model, training):
|
||||
trace_graph, torch_out, inputs_states = \
|
||||
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
|
||||
warn_on_static_input_change(inputs_states)
|
||||
|
||||
if orig_state_dict_keys != _unique_state_dict(model).keys():
|
||||
raise RuntimeError("state_dict changed after running the tracer; "
|
||||
@ -315,7 +298,7 @@ def _trace_and_get_graph_from_model(model, args):
|
||||
return trace_graph, torch_out
|
||||
|
||||
|
||||
def _model_to_graph(model, args, verbose=False,
|
||||
def _model_to_graph(model, args, verbose=False, training=False,
|
||||
input_names=None, output_names=None,
|
||||
operator_export_type=OperatorExportTypes.ONNX,
|
||||
example_outputs=None, propagate=False,
|
||||
@ -348,7 +331,7 @@ def _model_to_graph(model, args, verbose=False,
|
||||
graph = _propagate_and_assign_input_shapes(
|
||||
model.graph, tuple(in_vars), False, propagate)
|
||||
else:
|
||||
graph, torch_out = _trace_and_get_graph_from_model(model, args)
|
||||
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
|
||||
state_dict = _unique_state_dict(model)
|
||||
params = list(state_dict.values())
|
||||
if _retain_param_name:
|
||||
@ -404,7 +387,7 @@ def _model_to_graph(model, args, verbose=False,
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
||||
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
|
||||
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
|
||||
input_names=None, output_names=None, aten=False, export_raw_ir=False,
|
||||
operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
|
||||
example_outputs=None, propagate=False, google_printer=False,
|
||||
@ -427,7 +410,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
|
||||
custom_opsets=custom_opsets)
|
||||
|
||||
|
||||
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
|
||||
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
|
||||
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
|
||||
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
|
||||
google_printer=False, opset_version=None, _retain_param_name=False,
|
||||
@ -441,27 +424,27 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
|
||||
custom_opsets = {}
|
||||
_set_opset_version(opset_version)
|
||||
_set_operator_export_type(operator_export_type)
|
||||
with select_model_mode_for_export(model, training):
|
||||
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
|
||||
operator_export_type,
|
||||
opset_version)
|
||||
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
|
||||
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
|
||||
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
|
||||
output_names, operator_export_type,
|
||||
example_outputs, propagate, _retain_param_name,
|
||||
val_do_constant_folding, fixed_batch_size=fixed_batch_size)
|
||||
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
|
||||
operator_export_type,
|
||||
opset_version)
|
||||
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
|
||||
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
|
||||
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
|
||||
training, input_names,
|
||||
output_names, operator_export_type,
|
||||
example_outputs, propagate, _retain_param_name,
|
||||
val_do_constant_folding, fixed_batch_size=fixed_batch_size)
|
||||
|
||||
return graph._pretty_print_onnx(params_dict, opset_version, False,
|
||||
operator_export_type, google_printer,
|
||||
val_keep_init_as_ip, custom_opsets, val_add_node_names)
|
||||
return graph._pretty_print_onnx(params_dict, opset_version, False,
|
||||
operator_export_type, google_printer,
|
||||
val_keep_init_as_ip, custom_opsets, val_add_node_names)
|
||||
|
||||
|
||||
# NOTE: the output `torch_out` will contain the output tensors resulting from
|
||||
# the trace of a Module. In the case that a torch.nn.ScriptModule is passed in,
|
||||
# this output will be None, since we are not doing any tracing but rather
|
||||
# directly extracting the graph.
|
||||
def _export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
def _export(model, args, f, export_params=True, verbose=False, training=False,
|
||||
input_names=None, output_names=None, operator_export_type=None,
|
||||
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
|
||||
opset_version=None, _retain_param_name=False, do_constant_folding=True,
|
||||
@ -487,86 +470,80 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
else:
|
||||
operator_export_type = OperatorExportTypes.ONNX
|
||||
|
||||
# By default, training=None, (which defaults to TrainingMode.EVAL),
|
||||
# which is good because running a model in training mode could result in
|
||||
# internal buffers getting updated, dropout getting applied, etc.
|
||||
# If you really know what you're doing, you can turn
|
||||
# training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
|
||||
# (to preserve whatever the original training mode was.)
|
||||
with select_model_mode_for_export(model, training):
|
||||
_set_opset_version(opset_version)
|
||||
_set_operator_export_type(operator_export_type)
|
||||
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
|
||||
operator_export_type,
|
||||
opset_version)
|
||||
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
|
||||
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
|
||||
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
|
||||
operator_export_type,
|
||||
f)
|
||||
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
|
||||
output_names, operator_export_type,
|
||||
example_outputs, propagate,
|
||||
_retain_param_name, val_do_constant_folding,
|
||||
fixed_batch_size=fixed_batch_size)
|
||||
_set_opset_version(opset_version)
|
||||
_set_operator_export_type(operator_export_type)
|
||||
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
|
||||
operator_export_type,
|
||||
opset_version)
|
||||
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
|
||||
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
|
||||
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
|
||||
operator_export_type,
|
||||
f)
|
||||
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
|
||||
training, input_names,
|
||||
output_names, operator_export_type,
|
||||
example_outputs, propagate,
|
||||
_retain_param_name, val_do_constant_folding,
|
||||
fixed_batch_size=fixed_batch_size)
|
||||
|
||||
# TODO: Don't allocate a in-memory string for the protobuf
|
||||
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
|
||||
if dynamic_axes is None:
|
||||
dynamic_axes = {}
|
||||
if custom_opsets is None:
|
||||
custom_opsets = {}
|
||||
# TODO: Don't allocate a in-memory string for the protobuf
|
||||
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
|
||||
if dynamic_axes is None:
|
||||
dynamic_axes = {}
|
||||
if custom_opsets is None:
|
||||
custom_opsets = {}
|
||||
|
||||
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
|
||||
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
|
||||
|
||||
if export_params:
|
||||
proto, export_map = graph._export_onnx(
|
||||
params_dict, opset_version, dynamic_axes, defer_weight_export,
|
||||
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
|
||||
val_add_node_names, val_use_external_data_format, model_file_location)
|
||||
else:
|
||||
proto, export_map = graph._export_onnx(
|
||||
{}, opset_version, dynamic_axes, False, operator_export_type,
|
||||
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
|
||||
val_use_external_data_format, model_file_location)
|
||||
if export_params:
|
||||
proto, export_map = graph._export_onnx(
|
||||
params_dict, opset_version, dynamic_axes, defer_weight_export,
|
||||
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
|
||||
val_add_node_names, val_use_external_data_format, model_file_location)
|
||||
else:
|
||||
proto, export_map = graph._export_onnx(
|
||||
{}, opset_version, dynamic_axes, False, operator_export_type,
|
||||
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
|
||||
val_use_external_data_format, model_file_location)
|
||||
|
||||
if enable_onnx_checker and \
|
||||
operator_export_type is OperatorExportTypes.ONNX_ATEN_FALLBACK and \
|
||||
not val_use_external_data_format:
|
||||
# Only run checker if enabled and we are not using ATEN fallback and
|
||||
# large model format export in not enabled.
|
||||
_check_onnx_proto(proto)
|
||||
|
||||
if export_type == ExportTypes.PROTOBUF_FILE:
|
||||
assert(len(export_map) == 0)
|
||||
with torch.serialization._open_file_like(f, 'wb') as opened_file:
|
||||
opened_file.write(proto)
|
||||
elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
|
||||
import zipfile
|
||||
compression = zipfile.ZIP_DEFLATED \
|
||||
if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
|
||||
else zipfile.ZIP_STORED
|
||||
with zipfile.ZipFile(f, 'w', compression=compression) as z:
|
||||
z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
|
||||
for k, v in export_map.items():
|
||||
z.writestr(k, v)
|
||||
elif export_type == ExportTypes.DIRECTORY:
|
||||
import os
|
||||
if os.path.exists(f):
|
||||
assert(os.path.isdir(f))
|
||||
else:
|
||||
os.makedirs(f)
|
||||
|
||||
model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
|
||||
with torch.serialization._open_file_like(model_proto_file, 'wb') as opened_file:
|
||||
opened_file.write(proto)
|
||||
if enable_onnx_checker and \
|
||||
operator_export_type is OperatorExportTypes.ONNX and \
|
||||
not val_use_external_data_format:
|
||||
# Only run checker if enabled and we are not using ATEN fallback and
|
||||
# large model format export in not enabled.
|
||||
_check_onnx_proto(proto)
|
||||
|
||||
if export_type == ExportTypes.PROTOBUF_FILE:
|
||||
assert(len(export_map) == 0)
|
||||
with torch.serialization._open_file_like(f, 'wb') as opened_file:
|
||||
opened_file.write(proto)
|
||||
elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
|
||||
import zipfile
|
||||
compression = zipfile.ZIP_DEFLATED \
|
||||
if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
|
||||
else zipfile.ZIP_STORED
|
||||
with zipfile.ZipFile(f, 'w', compression=compression) as z:
|
||||
z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
|
||||
for k, v in export_map.items():
|
||||
weight_proto_file = os.path.join(f, k)
|
||||
with torch.serialization._open_file_like(weight_proto_file, 'wb') as opened_file:
|
||||
opened_file.write(v)
|
||||
z.writestr(k, v)
|
||||
elif export_type == ExportTypes.DIRECTORY:
|
||||
import os
|
||||
if os.path.exists(f):
|
||||
assert(os.path.isdir(f))
|
||||
else:
|
||||
raise RuntimeError('Unknown export type')
|
||||
os.makedirs(f)
|
||||
|
||||
model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
|
||||
with torch.serialization._open_file_like(model_proto_file, 'wb') as opened_file:
|
||||
opened_file.write(proto)
|
||||
|
||||
for k, v in export_map.items():
|
||||
weight_proto_file = os.path.join(f, k)
|
||||
with torch.serialization._open_file_like(weight_proto_file, 'wb') as opened_file:
|
||||
opened_file.write(v)
|
||||
else:
|
||||
raise RuntimeError('Unknown export type')
|
||||
finally:
|
||||
assert __IN_ONNX_EXPORT
|
||||
__IN_ONNX_EXPORT = False
|
||||
|
@ -280,7 +280,7 @@ def graph(model, args, verbose=False):
|
||||
verbose (bool): Whether to print out verbose information while
|
||||
processing.
|
||||
"""
|
||||
with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL): # TODO: move outside of torch.onnx?
|
||||
with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
|
||||
try:
|
||||
trace = torch.jit.trace(model, args)
|
||||
graph = trace.graph
|
||||
|
Reference in New Issue
Block a user