Add if ops support for onnxifi and ssa-rewrite (#19585)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19585

Originally we would unroll every If op into many separate subnets.
Now we no longer unroll them; instead we add all external inputs of the If op's subnets to the op itself and SSA-rewrite all external inputs/outputs. That is sufficient.
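
For illustration, here is a minimal, self-contained sketch of the idea (toy types and hypothetical blob names, not the caffe2 API; it assumes SSA names of the form <blob>_<version>):

// Toy sketch: collect a subnet's external inputs/outputs and rename them with
// their current SSA versions, the way the If-op handling in this diff does for
// real NetDefs.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // A hypothetical then_net: Y = Add(X, W); Z = Relu(Y)
  std::vector<ToyOp> then_net = {{{"X", "W"}, {"Y"}}, {{"Y"}, {"Z"}}};

  // Blobs consumed and produced anywhere inside the subnet.
  std::set<std::string> in_input, in_output;
  for (const auto& op : then_net) {
    in_input.insert(op.inputs.begin(), op.inputs.end());
    in_output.insert(op.outputs.begin(), op.outputs.end());
  }

  // External inputs are consumed but never produced inside the subnet;
  // external outputs are produced but never consumed inside it. These are the
  // blobs that get attached to the If op and SSA-rewritten.
  std::vector<std::string> external_input, external_output;
  for (const auto& blob : in_input) {
    if (!in_output.count(blob)) {
      external_input.push_back(blob); // W, X
    }
  }
  for (const auto& blob : in_output) {
    if (!in_input.count(blob)) {
      external_output.push_back(blob); // Z
    }
  }

  // Suppose the outer SSA rewrite currently has X at version 2, W at version
  // 0, and has just bumped Z to version 1.
  std::map<std::string, int> blob_versions = {{"X", 2}, {"W", 0}, {"Z", 1}};
  for (const auto& blob : external_input) {
    std::cout << blob << " -> " << blob << "_" << blob_versions[blob] << "\n";
  }
  for (const auto& blob : external_output) {
    std::cout << blob << " -> " << blob << "_" << blob_versions[blob] << "\n";
  }
  return 0;
}

Running it prints W -> W_0, X -> X_2 and Z -> Z_1; Y never appears because it is produced and consumed entirely inside the subnet.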

Reviewed By: yinghai

Differential Revision: D15038139

fbshipit-source-id: 8532216d8749068acd5558ad0d8cb1d98463a063
Author: Rui Zhu
Date: 2019-04-24 10:53:59 -07:00
Committed by Facebook Github Bot
Parent: 41486306d9
Commit: 2f73b3d26e
2 changed files with 170 additions and 6 deletions


@@ -5,6 +5,7 @@
#include "caffe2/proto/caffe2_legacy.pb.h"
#include "caffe2/utils/map_utils.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
#include <numeric>
#include <unordered_set>
@@ -127,6 +128,123 @@ NodeProto AddShapeNode(const std::string& input, const std::string& output) {
#undef CAFFE2_TO_ONNX_TYPE
}
void collectExternalsFromIfOpSubnet(
    const NetDef* net,
    std::vector<std::string>* input,
    std::vector<std::string>* output) {
  std::set<std::string> in_input, in_output;
  for (const auto& op : net->op()) {
    for (const auto& blob : op.input()) {
      in_input.emplace(blob);
    }
    for (const auto& blob : op.output()) {
      in_output.emplace(blob);
    }
  }
  for (const auto& blob : in_input) {
    if (!in_output.count(blob)) {
      input->push_back(blob);
    }
  }
  for (const auto& blob : in_output) {
    if (!in_input.count(blob)) {
      output->push_back(blob);
    }
  }
}
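// Illustrative example (not part of the commit): for a subnet
// { Y = Add(X, W); Z = Relu(Y) }, in_input = {X, W, Y} and in_output = {Y, Z},
// so the collected external inputs are {X, W} and the external output is {Z};
// Y is produced and consumed internally and is skipped.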
void rewriteSubnet(
    Argument* arg,
    std::map<std::string, std::string> oldname_to_newname) {
  NetDef* net = arg->mutable_n();
  for (auto& op : *(net->mutable_op())) {
    for (auto& input : *(op.mutable_input())) {
      if (oldname_to_newname.find(input) != oldname_to_newname.end()) {
        input = oldname_to_newname[input];
      }
    }
    for (auto& output : *(op.mutable_output())) {
      if (oldname_to_newname.find(output) != oldname_to_newname.end()) {
        output = oldname_to_newname[output];
      }
    }
  }
}
Argument* getArgumentFromName(OperatorDef* op, const std::string& name) {
  for (int i = 0; i < op->arg_size(); i++) {
    if (op->mutable_arg(i)->name() == name) {
      return op->mutable_arg(i);
    }
  }
  return nullptr;
}
void ssaRewriteForIfOp(
    OperatorDef* op,
    std::unordered_map<std::string, int>* blob_versions,
    std::set<std::string>* is_initialized_tensor) {
  // Get all the "external" inputs and outputs of the subnet.
  // Since then_net and else_net have the same external inputs/outputs, we only
  // collect them from one of the subnets and perform the rewrite on both
  // then_net and else_net.
  std::vector<std::string> if_external_input;
  std::vector<std::string> if_external_output;
  ArgumentHelper helper(*op);
  Argument *then_arg = nullptr, *else_arg = nullptr;
  NetDef* target_net = nullptr;
  bool has_then = false, has_else = false;
  if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
    then_arg = getArgumentFromName(op, "then_net");
    target_net = then_arg->mutable_n();
    has_then = true;
  }
  if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
    else_arg = getArgumentFromName(op, "else_net");
    if (!has_then) {
      target_net = else_arg->mutable_n();
    }
    has_else = true;
  }
  if (has_then || has_else) {
    collectExternalsFromIfOpSubnet(
        target_net, &if_external_input, &if_external_output);
    std::map<string, string> oldname_to_newname;
    // Build the oldname_to_newname map
    for (auto& input : if_external_input) {
      const auto it = blob_versions->find(input);
      if (it != blob_versions->end()) {
        oldname_to_newname[input] = SsaName(input, it->second);
      }
    }
    for (auto& output : if_external_output) {
      auto it = blob_versions->find(output);
      if (it != blob_versions->end()) {
        if (is_initialized_tensor->count(output) == 0) {
          it->second += 1;
        } else {
          is_initialized_tensor->erase(output);
        }
        oldname_to_newname[output] = SsaName(output, it->second);
      } else {
        blob_versions->emplace(output, 0);
        oldname_to_newname[output] = SsaName(output, 0);
      }
    }
    if (has_then) {
      rewriteSubnet(then_arg, oldname_to_newname);
    }
    if (has_else) {
      rewriteSubnet(else_arg, oldname_to_newname);
    }
  }
}
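// Note on the interplay with SsaRewrite below (illustrative, not part of the
// commit): if blob "Z" is an external output of the If subnets and currently
// has version 0, ssaRewriteForIfOp bumps it to version 1 and renames it inside
// then_net/else_net; the output loop in SsaRewrite then reuses that already
// bumped version for the If op's own output, which is why it skips the extra
// increment when op.type() == "If".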
std::unordered_map<std::string, std::string> SsaRewrite(
    caffe2::NetDef* init_net,
    caffe2::NetDef* pred_net) {
@@ -147,6 +265,7 @@ std::unordered_map<std::string, std::string> SsaRewrite(
    blob_versions.clear();
  }
  std::set<std::string> is_initialized_tensor;
  if (pred_net) {
    std::unordered_set<std::string> external_outputs;
    for (const auto& input : pred_net->external_input()) {
@@ -168,13 +287,38 @@ std::unordered_map<std::string, std::string> SsaRewrite(
          continue;
        }
      }
      // Special SSA rewrite for the subnets of an If operator
      if (op.type() == "If") {
        ssaRewriteForIfOp(&op, &blob_versions, &is_initialized_tensor);
      }
      for (auto& output : *op.mutable_output()) {
        auto it = blob_versions.find(output);
        if (it != blob_versions.end()) {
          it->second += 1;
          if (op.type() != "If") {
            if (is_initialized_tensor.count(output) == 0) {
              it->second += 1;
            } else {
              is_initialized_tensor.erase(output);
            }
          }
          output = SsaName(output, it->second);
        } else {
          blob_versions.emplace(output, 0);
          // These filling ops provide a default value for tensors generated by
          // ops like If. For example, if an If op's condition is not satisfied
          // and it has no else_net, it will not generate any output blob,
          // which may cause errors later on. Here we make sure these tensors
          // are SSA-rewritten only once, not twice (once for the filling
          // operator and once for the If op).
          if ((caffe2::StartsWith(op.type(), "GivenTensor") &&
               caffe2::EndsWith(op.type(), "Fill")) ||
              op.type() == "ConstantFill" ||
              op.type() == "Int8GivenTensorFill" ||
              op.type() == "Int8GivenIntTensorFill") {
            is_initialized_tensor.insert(output);
          }
          output = SsaName(output, 0);
        }
      }


@@ -134,24 +134,44 @@ void getWeightsAndInputs(
}
}
void unrollIfOps(NetDef* net) {
void collectInputsAndOutputs(
    const OperatorDef& op,
    std::set<std::string>* inputs,
    std::set<std::string>* outputs) {
  for (const auto& blob : op.input()) {
    inputs->emplace(blob);
  }
  for (const auto& blob : op.output()) {
    outputs->emplace(blob);
  }
}
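// Illustrative effect (not part of the commit): for an If op whose then_net
// computes Z = Relu(X), the nested op stays inside the subnet, but "X" is
// added as an explicit input of the If op so downstream passes can see it;
// blobs produced inside the subnet (like Z) are not added.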
void fetchInputsToIfOpsSubnet(NetDef* net) {
  NetDef clone(*net);
  clone.clear_op();
  for (const auto& op : net->op()) {
  for (auto& op : net->op()) {
    if (op.type() == "If") {
      OperatorDef new_op(op);
      ArgumentHelper helper(op);
      std::set<std::string> subnet_inputs, subnet_outputs;
      if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
        auto then_net = helper.GetSingleArgument<NetDef>("then_net", NetDef());
        for (const auto& nested_op : then_net.op()) {
          clone.add_op()->CopyFrom(nested_op);
          collectInputsAndOutputs(nested_op, &subnet_inputs, &subnet_outputs);
        }
      }
      if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
        auto else_net = helper.GetSingleArgument<NetDef>("else_net", NetDef());
        for (const auto& nested_op : else_net.op()) {
          clone.add_op()->CopyFrom(nested_op);
          collectInputsAndOutputs(nested_op, &subnet_inputs, &subnet_outputs);
        }
      }
      for (const std::string& blob : subnet_inputs) {
        if (subnet_outputs.count(blob) == 0) {
          new_op.add_input(blob);
        }
      }
      clone.add_op()->CopyFrom(new_op);
    } else {
      clone.add_op()->CopyFrom(op);
    }
@@ -897,7 +917,7 @@ void OnnxifiTransformer::transform(
  onnxifi_op_id_ = 0;
  // Unroll If ops
  unrollIfOps(pred_net);
  fetchInputsToIfOpsSubnet(pred_net);
  std::unordered_set<std::string> weights(
      weight_names.begin(), weight_names.end());