mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Bump up opset version to 7 in Caffe2 ONNX exporter (#8854)
Summary: Will bump up to opset 8 in another PR to match the current opset version. Already tested through generating the models in current model zoo. Closes https://github.com/pytorch/pytorch/pull/8854 Reviewed By: ezyang Differential Revision: D8666437 Pulled By: houseroad fbshipit-source-id: feffdf704dd3136aa59c0f1ff1830c14d1bd20aa
This commit is contained in:
committed by
Facebook Github Bot
parent
148088a681
commit
63233f98ad
@ -608,15 +608,19 @@ Caffe2Ops Caffe2Backend::CreateGemm(OnnxNode* onnx_node, int opset_version) {
|
||||
caffe2::Argument arg_trans_b;
|
||||
arg_trans_b.set_name("trans_b");
|
||||
arg_trans_b.set_i(trans_b);
|
||||
caffe2::Argument arg_broadcast;
|
||||
arg_broadcast.set_name("broadcast");
|
||||
arg_broadcast.set_i(broadcast);
|
||||
|
||||
auto* c2_op = ret.ops.Add();
|
||||
BuildOperator(
|
||||
c2_op, "MatMul", {input_a, input_b}, {ab}, {arg_trans_a, arg_trans_b});
|
||||
c2_op = ret.ops.Add();
|
||||
BuildOperator(c2_op, "Add", {ab, input_c}, {output}, {arg_broadcast});
|
||||
if (opset_version >= 7) {
|
||||
BuildOperator(c2_op, "Add", {ab, input_c}, {output});
|
||||
} else {
|
||||
caffe2::Argument arg_broadcast;
|
||||
arg_broadcast.set_name("broadcast");
|
||||
arg_broadcast.set_i(broadcast);
|
||||
BuildOperator(c2_op, "Add", {ab, input_c}, {output}, {arg_broadcast});
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
@ -854,7 +858,7 @@ Caffe2Ops Caffe2Backend::CreateBatchNormalization(
|
||||
attributes.remove("consumed_inputs");
|
||||
}
|
||||
|
||||
if (opset_version > 6) {
|
||||
if (opset_version >= 7) {
|
||||
auto& attributes = onnx_node->attributes;
|
||||
auto* attr = attributes.AddRewrittenAttribute("is_test");
|
||||
attr->set_i(1);
|
||||
@ -914,7 +918,7 @@ Caffe2Ops Caffe2Backend::CreateUpsample(OnnxNode* onnx_node, int opset_version)
|
||||
}
|
||||
|
||||
Caffe2Ops Caffe2Backend::CreateDropout(OnnxNode* onnx_node, int opset_version) {
|
||||
if (opset_version > 6) {
|
||||
if (opset_version >= 7) {
|
||||
auto& attributes = onnx_node->attributes;
|
||||
auto* attr = attributes.AddRewrittenAttribute("is_test");
|
||||
attr->set_i(1);
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "caffe2/onnx/helper.h"
|
||||
#include "caffe2/proto/caffe2_legacy.pb.h"
|
||||
#include "caffe2/utils/map_utils.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
@ -221,6 +222,17 @@ const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
|
||||
OnnxExporter::get_special_operators() const {
|
||||
const static std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>
|
||||
kSpecialOperators = {
|
||||
{"Add", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Sub", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Mul", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Div", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Pow", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"And", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Or", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Xor", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Equal", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Greater", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Less", &OnnxExporter::CreateBinaryElementwiseOpNodes},
|
||||
{"Cast", &OnnxExporter::CreateCastNodes},
|
||||
{"Conv", &OnnxExporter::CreateConvPoolNodes},
|
||||
{"ConvTranspose", &OnnxExporter::CreateConvPoolNodes},
|
||||
@ -278,7 +290,9 @@ bool OnnxExporter::IsBlackListed(const caffe2::Argument& arg) {
|
||||
kBlackListString = {{"order", {"NCHW"}}};
|
||||
const static std::unordered_map<std::string, std::unordered_set<int64_t>>
|
||||
kBlackListInt = {{"cudnn_exhaustive_search", {0, 1}},
|
||||
{"use_cudnn", {0, 1}}};
|
||||
{"use_cudnn", {0, 1}},
|
||||
{"is_test", {0, 1}},
|
||||
{"broadcast", {0, 1}}};
|
||||
|
||||
if (arg.has_i()) {
|
||||
const auto it = kBlackListInt.find(arg.name());
|
||||
@ -337,6 +351,49 @@ ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes(
|
||||
return result;
|
||||
}
|
||||
|
||||
ConvertedResult OnnxExporter::CreateBinaryElementwiseOpNodes(
|
||||
const caffe2::OperatorDef& def,
|
||||
const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
|
||||
caffe2::OperatorDef mdef(def); // The modified def without broadcast and axis
|
||||
const auto& x = mdef.input(0);
|
||||
const auto& y = def.input(1); // Refer to the old def, later won't change it.
|
||||
const auto& x_shape = shapes.at(x);
|
||||
const auto& y_shape = shapes.at(y);
|
||||
for (int i = 0; i < mdef.arg_size(); ++i) {
|
||||
const auto& arg = mdef.arg(i);
|
||||
if (arg.name() == "broadcast") {
|
||||
ArgumentHelper::RemoveArgument(mdef, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> axes;
|
||||
for (int i = 0; i < mdef.arg_size(); ++i) {
|
||||
const auto& arg = mdef.arg(i);
|
||||
if (arg.name() == "axis") {
|
||||
int64_t axis = arg.i();
|
||||
if (x_shape.dims().size() - axis != y_shape.dims().size()) {
|
||||
// The upper bound (excluded) of expanded y.
|
||||
int64_t end_dim =
|
||||
y_shape.dims().size() - 1 - axis + x_shape.dims().size();
|
||||
axes.resize(end_dim - y_shape.dims().size());
|
||||
std::iota(axes.begin(), axes.end(), y_shape.dims().size());
|
||||
mdef.set_input(1, dummy_->NewDummyName());
|
||||
}
|
||||
ArgumentHelper::RemoveArgument(mdef, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto result = CommonCaffe2OpToOnnxNodes(mdef);
|
||||
if (axes.size() != 0) {
|
||||
result.first.insert(
|
||||
result.first.begin(),
|
||||
MakeNode(
|
||||
"Unsqueeze", {y}, {mdef.input(1)}, {MakeAttribute("axes", axes)}));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
ConvertedResult OnnxExporter::CreateCastNodes(
|
||||
const caffe2::OperatorDef& def,
|
||||
const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
|
||||
@ -747,7 +804,7 @@ ConvertedResult OnnxExporter::CreateGemmNodes(
|
||||
"Gemm",
|
||||
{x, w, b},
|
||||
{gemm_y_output},
|
||||
{MakeAttribute("transB", 1L), MakeAttribute("broadcast", 1)},
|
||||
{MakeAttribute("transB", 1L)},
|
||||
def.name()));
|
||||
|
||||
if (has_axis) {
|
||||
|
@ -52,6 +52,10 @@ class OnnxExporter {
|
||||
private:
|
||||
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
|
||||
|
||||
ConvertedResult CreateBinaryElementwiseOpNodes(
|
||||
const caffe2::OperatorDef& def,
|
||||
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
||||
|
||||
ConvertedResult CreateCastNodes(
|
||||
const caffe2::OperatorDef& def,
|
||||
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
|
||||
|
@ -38,7 +38,7 @@ class Caffe2Frontend(object):
|
||||
# ONNX makes a BC breaking change to semantics of operators, having this set
|
||||
# to an accurate number will prevent our models form exporting. However,
|
||||
# we should strive to keep this up-to-date as much as possible.
|
||||
target_opset_version = 6
|
||||
target_opset_version = 7
|
||||
|
||||
_renamed_operators = {
|
||||
'SpatialBN': 'BatchNormalization',
|
||||
@ -136,7 +136,6 @@ class Caffe2Frontend(object):
|
||||
@classmethod
|
||||
def caffe2_op_to_onnx_node(cls, op_def, shapes):
|
||||
if C.support_onnx_export(op_def.type):
|
||||
shape_list = list(shapes.values())
|
||||
node_strs, tensor_strs = C.export_to_onnx(cls._dummy_name, op_def.SerializeToString(), shapes)
|
||||
nodes = []
|
||||
for s in node_strs:
|
||||
|
@ -397,6 +397,11 @@ CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
|
||||
CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
|
||||
#undef CAFFE2_MAKE_SINGULAR_ARGUMENT
|
||||
|
||||
template <>
|
||||
bool ArgumentHelper::RemoveArgument(OperatorDef& def, int index);
|
||||
template <>
|
||||
bool ArgumentHelper::RemoveArgument(NetDef& def, int index);
|
||||
|
||||
template <>
|
||||
Argument MakeArgument(const string& name, const MessageLite& value) {
|
||||
Argument arg;
|
||||
|
@ -234,6 +234,18 @@ class ArgumentHelper {
|
||||
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
|
||||
}
|
||||
|
||||
template <typename Def>
|
||||
static bool RemoveArgument(Def& def, int index) {
|
||||
if (index >= def.arg_size()) {
|
||||
return false;
|
||||
}
|
||||
if (index < def.arg_size() - 1) {
|
||||
def.mutable_arg()->SwapElements(index, def.arg_size() - 1);
|
||||
}
|
||||
def.mutable_arg()->RemoveLast();
|
||||
return true;
|
||||
}
|
||||
|
||||
explicit ArgumentHelper(const OperatorDef& def);
|
||||
explicit ArgumentHelper(const NetDef& netdef);
|
||||
bool HasArgument(const string& name) const;
|
||||
|
@ -218,6 +218,7 @@ model_mapping = {
|
||||
'inception_v1': 'inception_v1',
|
||||
'inception_v2': 'inception_v2',
|
||||
'resnet50': 'resnet50',
|
||||
'shufflenet': 'shufflenet',
|
||||
'squeezenet': 'squeezenet_old',
|
||||
#'vgg16': 'vgg16',
|
||||
'vgg19': 'vgg19',
|
||||
|
Reference in New Issue
Block a user