mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add if ops support for onnxifi and ssa-rewrite (#19585)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19585 Originally, every If op was unrolled into many different subnets. Now If ops are no longer unrolled; instead, all external inputs of an If op's subnets are added to the If op itself, and all external inputs/outputs of those subnets are SSA-rewritten. That is sufficient. Reviewed By: yinghai Differential Revision: D15038139 fbshipit-source-id: 8532216d8749068acd5558ad0d8cb1d98463a063
This commit is contained in:
committed by
Facebook Github Bot
parent
41486306d9
commit
2f73b3d26e
@ -5,6 +5,7 @@
|
||||
#include "caffe2/proto/caffe2_legacy.pb.h"
|
||||
#include "caffe2/utils/map_utils.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
|
||||
#include <numeric>
|
||||
#include <unordered_set>
|
||||
@ -127,6 +128,123 @@ NodeProto AddShapeNode(const std::string& input, const std::string& output) {
|
||||
#undef CAFFE2_TO_ONNX_TYPE
|
||||
}
|
||||
|
||||
void collectExternalsFromIfOpSubnet(
|
||||
const NetDef* net,
|
||||
std::vector<std::string>* input,
|
||||
std::vector<std::string>* output) {
|
||||
std::set<std::string> in_input, in_output;
|
||||
for (const auto& op : net->op()) {
|
||||
for (const auto& blob : op.input()) {
|
||||
in_input.emplace(blob);
|
||||
}
|
||||
for (const auto& blob : op.output()) {
|
||||
in_output.emplace(blob);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& blob : in_input) {
|
||||
if (!in_output.count(blob)) {
|
||||
input->push_back(blob);
|
||||
}
|
||||
}
|
||||
for (const auto& blob : in_output) {
|
||||
if (!in_input.count(blob)) {
|
||||
output->push_back(blob);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Renames blobs inside the subnet stored in `arg`.
//
// Every input and output name of every op in the subnet that appears as a key
// in `oldname_to_newname` is replaced by the mapped value; names that are not
// keys are left untouched.
//
// `arg`                - argument whose NetDef payload (arg->mutable_n()) is
//                        rewritten in place.
// `oldname_to_newname` - rename table, keyed by the current blob name. It is
//                        taken by value to keep the existing signature.
void rewriteSubnet(
    Argument* arg,
    std::map<std::string, std::string> oldname_to_newname) {
  // One lookup per name: the iterator from find() is reused for the
  // assignment instead of searching the map a second time via operator[].
  const auto rename_if_mapped = [&oldname_to_newname](std::string* name) {
    const auto it = oldname_to_newname.find(*name);
    if (it != oldname_to_newname.end()) {
      *name = it->second;
    }
  };

  NetDef* net = arg->mutable_n();
  for (auto& op : *(net->mutable_op())) {
    for (auto& input : *(op.mutable_input())) {
      rename_if_mapped(&input);
    }
    for (auto& output : *(op.mutable_output())) {
      rename_if_mapped(&output);
    }
  }
}
|
||||
|
||||
// Returns a mutable pointer to the first argument of `op` whose name equals
// `name`, or nullptr when `op` has no such argument.
Argument* getArgumentFromName(OperatorDef* op, const std::string& name) {
  for (auto& candidate : *(op->mutable_arg())) {
    if (candidate.name() == name) {
      return &candidate;
    }
  }
  return nullptr;
}
|
||||
|
||||
// SSA-rewrites the subnets ("then_net" / "else_net") carried by an If op.
//
// The external inputs and outputs are collected from a single subnet
// (then_net when present, otherwise else_net), because both subnets are
// expected to share the same external inputs/outputs. The resulting rename
// table is then applied to every subnet that is present.
//
// `op`                    - the If operator; its subnet arguments are
//                           rewritten in place.
// `blob_versions`         - blob name -> version number handed to SsaName().
//                           Versions of external outputs are bumped here, and
//                           outputs not seen before are registered at 0.
// `is_initialized_tensor` - names whose next write must not bump the version;
//                           an external output found here is removed from the
//                           set instead of being bumped.
//
// Does nothing when the op carries neither subnet.
void ssaRewriteForIfOp(
    OperatorDef* op,
    std::unordered_map<std::string, int>* blob_versions,
    std::set<std::string>* is_initialized_tensor) {
  ArgumentHelper helper(*op);
  const bool has_then = helper.HasSingleArgumentOfType<NetDef>("then_net");
  const bool has_else = helper.HasSingleArgumentOfType<NetDef>("else_net");
  if (!has_then && !has_else) {
    return;
  }

  Argument* then_arg = has_then ? getArgumentFromName(op, "then_net") : nullptr;
  Argument* else_arg = has_else ? getArgumentFromName(op, "else_net") : nullptr;

  // then_net takes precedence as the subnet the externals are read from.
  NetDef* target_net = has_then ? then_arg->mutable_n() : else_arg->mutable_n();

  // Get all the "external" inputs and outputs of the chosen subnet.
  std::vector<std::string> if_external_input;
  std::vector<std::string> if_external_output;
  collectExternalsFromIfOpSubnet(
      target_net, &if_external_input, &if_external_output);

  // Build the old-name -> SSA-name table.
  std::map<std::string, std::string> oldname_to_newname;

  // External inputs refer to the current version of an already known blob;
  // inputs with no recorded version keep their original name.
  for (const auto& blob : if_external_input) {
    const auto known = blob_versions->find(blob);
    if (known == blob_versions->end()) {
      continue;
    }
    oldname_to_newname[blob] = SsaName(blob, known->second);
  }

  for (const auto& blob : if_external_output) {
    auto known = blob_versions->find(blob);
    if (known == blob_versions->end()) {
      // First time this blob is written: it starts at version 0.
      blob_versions->emplace(blob, 0);
      oldname_to_newname[blob] = SsaName(blob, 0);
      continue;
    }
    // erase() returns how many entries it removed. Zero means the blob was
    // not marked as initialized, so this write produces a new version;
    // otherwise the mark is consumed and the current version is reused.
    if (is_initialized_tensor->erase(blob) == 0) {
      known->second += 1;
    }
    oldname_to_newname[blob] = SsaName(blob, known->second);
  }

  // Apply the same renaming to both subnets.
  if (has_then) {
    rewriteSubnet(then_arg, oldname_to_newname);
  }
  if (has_else) {
    rewriteSubnet(else_arg, oldname_to_newname);
  }
}
|
||||
|
||||
std::unordered_map<std::string, std::string> SsaRewrite(
|
||||
caffe2::NetDef* init_net,
|
||||
caffe2::NetDef* pred_net) {
|
||||
@ -147,6 +265,7 @@ std::unordered_map<std::string, std::string> SsaRewrite(
|
||||
blob_versions.clear();
|
||||
}
|
||||
|
||||
std::set<std::string> is_initialized_tensor;
|
||||
if (pred_net) {
|
||||
std::unordered_set<std::string> external_outputs;
|
||||
for (const auto& input : pred_net->external_input()) {
|
||||
@ -168,13 +287,38 @@ std::unordered_map<std::string, std::string> SsaRewrite(
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// Special SSA Rewrite for subnet of If Operator
|
||||
if (op.type() == "If") {
|
||||
ssaRewriteForIfOp(&op, &blob_versions, &is_initialized_tensor);
|
||||
}
|
||||
for (auto& output : *op.mutable_output()) {
|
||||
auto it = blob_versions.find(output);
|
||||
if (it != blob_versions.end()) {
|
||||
it->second += 1;
|
||||
if (op.type() != "If") {
|
||||
if (is_initialized_tensor.count(output) == 0) {
|
||||
it->second += 1;
|
||||
} else {
|
||||
is_initialized_tensor.erase(output);
|
||||
}
|
||||
}
|
||||
output = SsaName(output, it->second);
|
||||
|
||||
} else {
|
||||
blob_versions.emplace(output, 0);
|
||||
// These filling ops are designed for a by-default value for the
|
||||
// tensors generated by ops like If. For example, if an If op's
|
||||
// condition is not satisfied, and it does not have else_net, then it
|
||||
// will not generate any output blob, which may cause some error in
|
||||
// the future. Here we would like to ensure these tensors only been
|
||||
// ssa re-write once but not twice. (One in the filling operator, one
|
||||
// in If op)
|
||||
if ((caffe2::StartsWith(op.type(), "GivenTensor") &&
|
||||
caffe2::EndsWith(op.type(), "Fill")) ||
|
||||
op.type() == "ConstantFill" ||
|
||||
op.type() == "Int8GivenTensorFill" ||
|
||||
op.type() == "Int8GivenIntTensorFill") {
|
||||
is_initialized_tensor.insert(output);
|
||||
}
|
||||
output = SsaName(output, 0);
|
||||
}
|
||||
}
|
||||
|
@ -134,24 +134,44 @@ void getWeightsAndInputs(
|
||||
}
|
||||
}
|
||||
|
||||
void unrollIfOps(NetDef* net) {
|
||||
void collectInputsAndOutputs(
|
||||
const OperatorDef& op,
|
||||
std::set<std::string>* inputs,
|
||||
std::set<std::string>* outputs) {
|
||||
for (const auto& blob : op.input()) {
|
||||
inputs->emplace(blob);
|
||||
}
|
||||
for (const auto& blob : op.output()) {
|
||||
outputs->emplace(blob);
|
||||
}
|
||||
}
|
||||
|
||||
void fetchInputsToIfOpsSubnet(NetDef* net) {
|
||||
NetDef clone(*net);
|
||||
clone.clear_op();
|
||||
for (const auto& op : net->op()) {
|
||||
for (auto& op : net->op()) {
|
||||
if (op.type() == "If") {
|
||||
OperatorDef new_op(op);
|
||||
ArgumentHelper helper(op);
|
||||
std::set<std::string> subnet_inputs, subnet_outputs;
|
||||
if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
|
||||
auto then_net = helper.GetSingleArgument<NetDef>("then_net", NetDef());
|
||||
for (const auto& nested_op : then_net.op()) {
|
||||
clone.add_op()->CopyFrom(nested_op);
|
||||
collectInputsAndOutputs(nested_op, &subnet_inputs, &subnet_outputs);
|
||||
}
|
||||
}
|
||||
if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
|
||||
auto else_net = helper.GetSingleArgument<NetDef>("else_net", NetDef());
|
||||
for (const auto& nested_op : else_net.op()) {
|
||||
clone.add_op()->CopyFrom(nested_op);
|
||||
collectInputsAndOutputs(nested_op, &subnet_inputs, &subnet_outputs);
|
||||
}
|
||||
}
|
||||
for (const std::string& blob : subnet_inputs) {
|
||||
if (subnet_outputs.count(blob) == 0) {
|
||||
new_op.add_input(blob);
|
||||
}
|
||||
}
|
||||
clone.add_op()->CopyFrom(new_op);
|
||||
} else {
|
||||
clone.add_op()->CopyFrom(op);
|
||||
}
|
||||
@ -897,7 +917,7 @@ void OnnxifiTransformer::transform(
|
||||
onnxifi_op_id_ = 0;
|
||||
|
||||
// Add the external inputs of each If op's subnets as inputs of the If op
|
||||
unrollIfOps(pred_net);
|
||||
fetchInputsToIfOpsSubnet(pred_net);
|
||||
|
||||
std::unordered_set<std::string> weights(
|
||||
weight_names.begin(), weight_names.end());
|
||||
|
Reference in New Issue
Block a user