Bump up opset version to 7 in Caffe2 ONNX exporter (#8854)

Summary:
A follow-up PR will bump to opset 8 to match the latest ONNX opset version.

Already tested by regenerating the models in the current model zoo.
Closes https://github.com/pytorch/pytorch/pull/8854

Reviewed By: ezyang

Differential Revision: D8666437

Pulled By: houseroad

fbshipit-source-id: feffdf704dd3136aa59c0f1ff1830c14d1bd20aa
Author: Lu Fang
Date: 2018-06-28 07:32:01 -07:00
Committed by: Facebook Github Bot
Parent: 148088a681
Commit: 63233f98ad
7 changed files with 92 additions and 10 deletions
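For context: the jump from opset 6 to opset 7 changes two things that drive most of the hunks below. Binary elementwise operators (Add, Sub, Mul, ...) lose their explicit `broadcast`/`axis` attributes and use numpy-style broadcasting instead, and `Dropout`/`BatchNormalization` lose the `is_test` attribute. A minimal numpy sketch of the broadcasting change (shapes are illustrative, not taken from the diff):

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
y = np.random.rand(4, 5).astype(np.float32)

# Opset 6: Add(x, y) needs an explicit broadcast=1 attribute, which aligns y
# with the trailing dimensions of x.
# Opset 7: Add(x, y) broadcasts implicitly, with the same right-aligned rules
# as numpy, so no attribute is needed:
z = x + y
assert z.shape == (2, 3, 4, 5)
```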


@@ -608,15 +608,19 @@ Caffe2Ops Caffe2Backend::CreateGemm(OnnxNode* onnx_node, int opset_version) {
     caffe2::Argument arg_trans_b;
     arg_trans_b.set_name("trans_b");
     arg_trans_b.set_i(trans_b);
-    caffe2::Argument arg_broadcast;
-    arg_broadcast.set_name("broadcast");
-    arg_broadcast.set_i(broadcast);
 
     auto* c2_op = ret.ops.Add();
     BuildOperator(
         c2_op, "MatMul", {input_a, input_b}, {ab}, {arg_trans_a, arg_trans_b});
     c2_op = ret.ops.Add();
-    BuildOperator(c2_op, "Add", {ab, input_c}, {output}, {arg_broadcast});
+    if (opset_version >= 7) {
+      BuildOperator(c2_op, "Add", {ab, input_c}, {output});
+    } else {
+      caffe2::Argument arg_broadcast;
+      arg_broadcast.set_name("broadcast");
+      arg_broadcast.set_i(broadcast);
+      BuildOperator(c2_op, "Add", {ab, input_c}, {output}, {arg_broadcast});
+    }
   }
 
   return ret;
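Here the backend lowers ONNX `Gemm` to Caffe2 `MatMul` + `Add`; for opset >= 7 the bias `Add` is emitted without the legacy `broadcast` argument and relies on numpy-style broadcasting. A numpy sketch of the lowered arithmetic (names and shapes are illustrative; alpha/beta handling is not shown):

```python
import numpy as np

# Gemm(A, B, C) with trans_b=1, lowered as above: Y = MatMul(A, B^T) + C.
a = np.random.rand(4, 16).astype(np.float32)  # input_a
b = np.random.rand(8, 16).astype(np.float32)  # input_b, stored transposed
c = np.random.rand(8).astype(np.float32)      # input_c, the bias

ab = a @ b.T   # the "MatMul" op, shape (4, 8)
y = ab + c     # the "Add" op: (4, 8) + (8,) broadcasts with no explicit argument
assert y.shape == (4, 8)
```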
@@ -854,7 +858,7 @@ Caffe2Ops Caffe2Backend::CreateBatchNormalization(
     attributes.remove("consumed_inputs");
   }
 
-  if (opset_version > 6) {
+  if (opset_version >= 7) {
     auto& attributes = onnx_node->attributes;
     auto* attr = attributes.AddRewrittenAttribute("is_test");
     attr->set_i(1);
@@ -914,7 +918,7 @@ Caffe2Ops Caffe2Backend::CreateUpsample(OnnxNode* onnx_node, int opset_version)
 }
 
 Caffe2Ops Caffe2Backend::CreateDropout(OnnxNode* onnx_node, int opset_version) {
-  if (opset_version > 6) {
+  if (opset_version >= 7) {
     auto& attributes = onnx_node->attributes;
     auto* attr = attributes.AddRewrittenAttribute("is_test");
     attr->set_i(1);
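Opset 7 removed `is_test` from `Dropout` and `BatchNormalization`, so imported nodes no longer carry the attribute and the backend injects `is_test = 1` itself via `AddRewrittenAttribute`. An illustrative comparison using the `onnx` Python helpers (not part of this diff):

```python
from onnx import helper

# An opset-6 style node still spells out is_test; an opset-7 node does not.
dropout_v6 = helper.make_node("Dropout", ["X"], ["Y"], is_test=1, ratio=0.5)
dropout_v7 = helper.make_node("Dropout", ["X"], ["Y"], ratio=0.5)

print(sorted(a.name for a in dropout_v6.attribute))  # ['is_test', 'ratio']
print(sorted(a.name for a in dropout_v7.attribute))  # ['ratio']
```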


@@ -3,6 +3,7 @@
 #include "caffe2/onnx/helper.h"
 #include "caffe2/proto/caffe2_legacy.pb.h"
 #include "caffe2/utils/map_utils.h"
+#include "caffe2/utils/proto_utils.h"
 
 #include <unordered_set>
@@ -221,6 +222,17 @@ const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
 OnnxExporter::get_special_operators() const {
   const static std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>
       kSpecialOperators = {
+          {"Add", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Sub", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Mul", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Div", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Pow", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"And", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Or", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Xor", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Equal", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Greater", &OnnxExporter::CreateBinaryElementwiseOpNodes},
+          {"Less", &OnnxExporter::CreateBinaryElementwiseOpNodes},
           {"Cast", &OnnxExporter::CreateCastNodes},
           {"Conv", &OnnxExporter::CreateConvPoolNodes},
           {"ConvTranspose", &OnnxExporter::CreateConvPoolNodes},
@@ -278,7 +290,9 @@ bool OnnxExporter::IsBlackListed(const caffe2::Argument& arg) {
       kBlackListString = {{"order", {"NCHW"}}};
   const static std::unordered_map<std::string, std::unordered_set<int64_t>>
       kBlackListInt = {{"cudnn_exhaustive_search", {0, 1}},
-                       {"use_cudnn", {0, 1}}};
+                       {"use_cudnn", {0, 1}},
+                       {"is_test", {0, 1}},
+                       {"broadcast", {0, 1}}};
 
   if (arg.has_i()) {
     const auto it = kBlackListInt.find(arg.name());
@@ -337,6 +351,49 @@ ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes(
   return result;
 }
 
+ConvertedResult OnnxExporter::CreateBinaryElementwiseOpNodes(
+    const caffe2::OperatorDef& def,
+    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+  caffe2::OperatorDef mdef(def); // The modified def without broadcast and axis
+  const auto& x = mdef.input(0);
+  const auto& y = def.input(1); // Refer to the old def, later won't change it.
+  const auto& x_shape = shapes.at(x);
+  const auto& y_shape = shapes.at(y);
+  for (int i = 0; i < mdef.arg_size(); ++i) {
+    const auto& arg = mdef.arg(i);
+    if (arg.name() == "broadcast") {
+      ArgumentHelper::RemoveArgument(mdef, i);
+      break;
+    }
+  }
+  std::vector<int64_t> axes;
+  for (int i = 0; i < mdef.arg_size(); ++i) {
+    const auto& arg = mdef.arg(i);
+    if (arg.name() == "axis") {
+      int64_t axis = arg.i();
+      if (x_shape.dims().size() - axis != y_shape.dims().size()) {
+        // The upper bound (excluded) of expanded y.
+        int64_t end_dim =
+            y_shape.dims().size() - 1 - axis + x_shape.dims().size();
+        axes.resize(end_dim - y_shape.dims().size());
+        std::iota(axes.begin(), axes.end(), y_shape.dims().size());
+        mdef.set_input(1, dummy_->NewDummyName());
+      }
+      ArgumentHelper::RemoveArgument(mdef, i);
+      break;
+    }
+  }
+
+  auto result = CommonCaffe2OpToOnnxNodes(mdef);
+  if (axes.size() != 0) {
+    result.first.insert(
+        result.first.begin(),
+        MakeNode(
+            "Unsqueeze", {y}, {mdef.input(1)}, {MakeAttribute("axes", axes)}));
+  }
+  return result;
+}
+
 ConvertedResult OnnxExporter::CreateCastNodes(
     const caffe2::OperatorDef& def,
     const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
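The exporter strips the legacy `broadcast`/`axis` arguments and, when `axis` does not already match numpy's right-aligned convention, prepends an `Unsqueeze` that appends trailing singleton dimensions to the second input. The axes arithmetic, mirrored in numpy for a rank-1 second input (shapes are illustrative):

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
y = np.random.rand(3).astype(np.float32)
axis = 1  # legacy Caffe2 attribute: y is aligned with dimension 1 of x

# Same computation as the C++ above.
end_dim = y.ndim - 1 - axis + x.ndim      # 1 - 1 - 1 + 4 = 3
axes = list(range(y.ndim, end_dim))       # [1, 2]

# Unsqueeze(y, axes=[1, 2]) turns (3,) into (3, 1, 1); the following Add then
# broadcasts under plain opset-7 rules.
y_unsqueezed = np.expand_dims(y, axis=tuple(axes))
assert y_unsqueezed.shape == (3, 1, 1)
assert np.allclose(x + y_unsqueezed, x + y.reshape(1, 3, 1, 1))
```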
@@ -747,7 +804,7 @@ ConvertedResult OnnxExporter::CreateGemmNodes(
       "Gemm",
       {x, w, b},
       {gemm_y_output},
-      {MakeAttribute("transB", 1L), MakeAttribute("broadcast", 1)},
+      {MakeAttribute("transB", 1L)},
       def.name()));
 
   if (has_axis) {


@@ -52,6 +52,10 @@ class OnnxExporter {
  private:
   ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
 
+  ConvertedResult CreateBinaryElementwiseOpNodes(
+      const caffe2::OperatorDef& def,
+      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
   ConvertedResult CreateCastNodes(
       const caffe2::OperatorDef& def,
       const std::unordered_map<std::string, caffe2::TensorShape>& shapes);


@@ -38,7 +38,7 @@ class Caffe2Frontend(object):
     # ONNX makes a BC breaking change to semantics of operators, having this set
     # to an accurate number will prevent our models form exporting. However,
     # we should strive to keep this up-to-date as much as possible.
-    target_opset_version = 6
+    target_opset_version = 7
 
     _renamed_operators = {
         'SpatialBN': 'BatchNormalization',
@@ -136,7 +136,6 @@ class Caffe2Frontend(object):
     @classmethod
     def caffe2_op_to_onnx_node(cls, op_def, shapes):
         if C.support_onnx_export(op_def.type):
-            shape_list = list(shapes.values())
             node_strs, tensor_strs = C.export_to_onnx(cls._dummy_name, op_def.SerializeToString(), shapes)
             nodes = []
             for s in node_strs:
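With `target_opset_version` bumped to 7, models produced by this frontend should declare opset 7 in their `opset_import`. A quick check on an exported file (the path is hypothetical):

```python
import onnx

model = onnx.load("exported_from_caffe2.onnx")  # hypothetical exporter output
print([(imp.domain or "ai.onnx", imp.version) for imp in model.opset_import])
# Expected after this change: [('ai.onnx', 7)]
```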


@@ -397,6 +397,11 @@ CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
 CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT
 
+template <>
+bool ArgumentHelper::RemoveArgument(OperatorDef& def, int index);
+template <>
+bool ArgumentHelper::RemoveArgument(NetDef& def, int index);
+
 template <>
 Argument MakeArgument(const string& name, const MessageLite& value) {
   Argument arg;


@@ -234,6 +234,18 @@ class ArgumentHelper {
     return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
   }
 
+  template <typename Def>
+  static bool RemoveArgument(Def& def, int index) {
+    if (index >= def.arg_size()) {
+      return false;
+    }
+    if (index < def.arg_size() - 1) {
+      def.mutable_arg()->SwapElements(index, def.arg_size() - 1);
+    }
+    def.mutable_arg()->RemoveLast();
+    return true;
+  }
+
   explicit ArgumentHelper(const OperatorDef& def);
   explicit ArgumentHelper(const NetDef& netdef);
   bool HasArgument(const string& name) const;
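`RemoveArgument` deletes in O(1) by swapping the target with the last argument and popping it, so the relative order of the remaining arguments is not preserved; the call sites in the exporter break out of their scan right after each removal, which is consistent with that. The same idiom on a plain Python list, for illustration only:

```python
def remove_at(args, index):
    """Swap-with-last removal, mirroring ArgumentHelper::RemoveArgument."""
    if index >= len(args):
        return False
    if index < len(args) - 1:
        args[index], args[-1] = args[-1], args[index]
    args.pop()
    return True

args = ["broadcast", "axis", "trans_a"]
remove_at(args, 0)
print(args)  # ['trans_a', 'axis'] -- survivor order is not preserved
```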


@@ -218,6 +218,7 @@ model_mapping = {
     'inception_v1': 'inception_v1',
     'inception_v2': 'inception_v2',
     'resnet50': 'resnet50',
+    'shufflenet': 'shufflenet',
     'squeezenet': 'squeezenet_old',
     #'vgg16': 'vgg16',
     'vgg19': 'vgg19',