mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-07 01:50:04 +08:00
Summary: Overall context: open-source BlackBoxPredictor as the entry point for inference in Caffe2 (thread safe abstraction for Caffe2 inference). This should be used in ThroughputBenchmark for the purpose of framework comparison. This specific diff: There should be no harm in moving transformation code to OSS. On the advantages side we will be able to compare production Caffe2 setup with PyTorch in the most fair way via ThroughputBenchmark. This approach avoids any complicated transformation registries. Building those properly would be a significant engineering effort as well as a production risk. In the past we had SEVs related to transforms being turned off due to various refactors. Given that we don't plan to build any other significant investments into transformation logic except existing ones (like TVM and Glow), and those also relate to open-source technologies, I came to the conclusion of moving the whole thing to OSS. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23350 ghstack-source-id: 87121538 Pull Request resolved: https://github.com/pytorch/pytorch/pull/24928 Test Plan: waitforsandcastle Differential Revision: D16445133 Pulled By: salexspb fbshipit-source-id: a93106489611dfe427b0f144717bc720d04e47f3
330 lines
12 KiB
C++
330 lines
12 KiB
C++
#include "caffe2/caffe2/fb/predictor/Transforms.h"
|
|
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace {
|
|
// Returns the first blob name derived from `prefix` that `ws` does not
// already contain. Candidates are tried in order: `prefix`, `prefix_1`,
// `prefix_2`, ... for at most `max_tries` candidates in total.
// Raises via CAFFE_THROW when every candidate is already taken.
string
NextBlob(const Workspace& ws, const string& prefix, int max_tries = 1000000) {
  int attempt = 0;
  while (attempt < max_tries) {
    std::stringstream name_builder;
    name_builder << prefix;
    // The very first candidate is the bare prefix; later ones get a suffix.
    if (attempt != 0) {
      name_builder << '_' << attempt;
    }
    string candidate = name_builder.str();
    if (!ws.HasBlob(candidate)) {
      return candidate;
    }
    ++attempt;
  }
  CAFFE_THROW("Failed to find a new blob name");
  // Kept after CAFFE_THROW so that every path ends in a return statement.
  return "";
}
|
|
|
|
bool HasInput(const string& blob, const OperatorDef& op) {
|
|
for (auto in : op.input()) {
|
|
if (blob == in) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool HasOutput(const string& blob, const OperatorDef& op) {
|
|
for (auto out : op.output()) {
|
|
if (blob == out) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/*
|
|
Tests if it is valid to simply rename "from" to "to" after start (at or after
|
|
start) ignoring in-place ops of operator types in ignoreTypes. There is a set of
|
|
ignoreTypes instead of a single ignoreType for uses cases where multiple
|
|
op-types will be in-placed; in such cases this function will give a more optimal
|
|
answer if it can consider all types at the same time.
|
|
*/
|
|
bool CanRenameForwards(
|
|
const string& from,
|
|
const string& to,
|
|
const std::shared_ptr<NetDef>& net,
|
|
const std::set<string>& netOutputs,
|
|
const std::set<string>& ignoreTypes,
|
|
int start) {
|
|
bool redefined_to = false;
|
|
for (int i = start; i < net->op_size(); i++) {
|
|
auto op = net->op(i);
|
|
bool uses_from = HasInput(from, op);
|
|
|
|
// If there's a use of "from" after a redefine of "to" then we can't rename
|
|
// this "from" to "to" b/c it will use the wrong "to" (this op's "to"
|
|
// instead of the "to" that would be produced by renaming ops[start-1] to
|
|
// output "to" to "from")
|
|
if (redefined_to && uses_from) {
|
|
VLOG(7) << "CanRenameForwards " << to << " is redefined before " << from
|
|
<< " is used. Cannot rename from " << from << " to " << to;
|
|
return false;
|
|
}
|
|
// If "to" is redefined then we have to be careful of in-placing this op or
|
|
// blocking future uses of "from"
|
|
if (HasOutput(to, op)) {
|
|
// If this op also uses "from", then renaming "from" to "to" will make
|
|
// this op inplace
|
|
if (uses_from) {
|
|
// In-placing of this op is not allowed
|
|
if (!ignoreTypes.count(op.type())) {
|
|
VLOG(7) << "CanRenameForwards detected in-placing of op of type "
|
|
<< op.type();
|
|
return false;
|
|
}
|
|
VLOG(7) << "CanRenameForwards will make " << op.DebugString()
|
|
<< " in-place, but this is in the okay-types-to-inplace "
|
|
<< "whitelist";
|
|
}
|
|
// This op won't be inplaced (or it's okay if it is) but we still have to
|
|
// watch out for future uses of "from"
|
|
redefined_to = true;
|
|
VLOG(7) << "CanRenameForwards " << to << " is redefined at " << i;
|
|
}
|
|
// If this op redefines "from" then renaming will stop with the inputs of
|
|
// this op. Since we haven't found any problems with renaming, it's okay to
|
|
// rename. Note that we don't need to check if "from" is a network output,
|
|
// as this op will produce "from"
|
|
if (HasOutput(from, op)) {
|
|
VLOG(7) << "CanRenameForwards found another op making " << from
|
|
<< " so it's fine to rename " << from << " to " << to
|
|
<< " in earlier ops of the net.";
|
|
return true;
|
|
}
|
|
}
|
|
// We reached the end of the ops. There have been no redefinitions of "from",
|
|
// so if "from" is needed in the network outputs then we can't rename it
|
|
return !netOutputs.count(from);
|
|
}
|
|
|
|
bool CanRenameBackwards(
|
|
const string& from,
|
|
const string& to,
|
|
const std::shared_ptr<NetDef>& net,
|
|
const std::set<string>& netInputs,
|
|
const std::set<string>& netOutputs,
|
|
const std::set<string>& ignoreTypes,
|
|
int end) {
|
|
for (int i = end; i >= 0; i--) {
|
|
auto op = net->op(i);
|
|
// If this op defines "to", then all ops after this point will use this op's
|
|
// "to" instead of the producer-of-"from"s, so the producer can't be renamed
|
|
// to produce "to" instead of "from"
|
|
// FUTURE_POSSIBILITY: We might be able to rename this op to not produce to
|
|
if (HasOutput(to, op)) {
|
|
VLOG(7) << "CanRenameBackwards " << to << " is defined after " << from
|
|
<< " is. Cannot rename.";
|
|
return false;
|
|
}
|
|
// Because of the previous question, we know that no op in between this op
|
|
// and end has "to" as an output, so it is impossible to make any of them
|
|
// in-place by adding "to" as an input to any of them
|
|
|
|
// If we find the producer of "from", then we will stop renaming backwards
|
|
// (we won't rename this ops inputs)
|
|
if (HasOutput(from, op)) {
|
|
// If this op has "to" as an input, then renaming "from" to "to" will
|
|
// in-place this op
|
|
if (HasInput(to, op)) {
|
|
if (!ignoreTypes.count(op.type())) {
|
|
// In future, you could maybe check if "to" could be renamed to a
|
|
// brand new unique blob name
|
|
VLOG(7) << "CanRenameBackwards will in-place producer of " << from;
|
|
return false;
|
|
}
|
|
VLOG(7) << "CanRenameBackwards will in-place the producer of " << from
|
|
<< " but this is okay because it's in our in-placeable "
|
|
<< "whitelist";
|
|
}
|
|
// This op won't be in-placed (or it will but it's still okay), but we
|
|
// have to check the forwards logic too
|
|
// TODO why? what's the specific case again?
|
|
VLOG(7) << "CanRenameBackwards found the parent of " << from
|
|
<< ". Recursively testing if we can rename the parent.";
|
|
return CanRenameForwards(from, to, net, netOutputs, ignoreTypes, i + 1);
|
|
}
|
|
// After this point, inputs of this op may be renamed
|
|
|
|
// If this blob uses "to", then renaming the producer of "from" to produce
|
|
// "to" will interfere with this op. Technically it will only interfere if
|
|
// the producer of "from" would overwrite this op's "to", but if that wasn't
|
|
// the case then there's some op that produces "to" after the producer of
|
|
// "from", and this will be caught in the first if (HasOutput(to, op))
|
|
if (HasInput(to, op)) {
|
|
VLOG(7) << "CanRenameBackwards " << to << " is used after " << from
|
|
<< " is defined. Cannot rename.";
|
|
return false;
|
|
}
|
|
}
|
|
// Found no parent, so must be a network output. We cannot rename it
|
|
CAFFE_ENFORCE(netInputs.count(from));
|
|
VLOG(7) << "CanRenameBackwards " << from << " is a network input. Cannot "
|
|
<< "rename.";
|
|
return false;
|
|
}
|
|
|
|
void RenameInputs(const string& from, const string& to, OperatorDef* def) {
|
|
VLOG(6) << "RenameInputs(from=" << from << ", to=" << to << ", "
|
|
<< def->DebugString() << ")";
|
|
for (int i = 0; i < def->input_size(); i++) {
|
|
if (def->input(i) == from) {
|
|
*def->mutable_input(i) = to;
|
|
}
|
|
}
|
|
}
|
|
|
|
void RenameOutputs(const string& from, const string& to, OperatorDef* def) {
|
|
VLOG(6) << "RenameOutputs(from=" << from << ", to=" << to << ", "
|
|
<< def->DebugString() << ")";
|
|
for (string& output : *def->mutable_output()) {
|
|
if (output == from) {
|
|
output = to;
|
|
}
|
|
}
|
|
}
|
|
|
|
void RenameInputsInChildren(
|
|
const string& from,
|
|
const string& to,
|
|
std::shared_ptr<caffe2::NetDef> net,
|
|
int pidx) {
|
|
// This does NOT continue through in-place ops
|
|
VLOG(4) << "RenameInputsInChildren(from=" << from << ", to=" << to;
|
|
for (int j = pidx + 1; j < net->op_size(); j++) {
|
|
if (HasInput(from, net->op(j))) {
|
|
RenameInputs(from, to, net->mutable_op(j));
|
|
}
|
|
// If any child op redefines from, then future ops no longer use this op's
|
|
// (at j) version of from
|
|
if (HasOutput(from, net->op(j))) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void InPlaceOps(const InferenceGraph& graph, const std::string& op_type) {
|
|
int num_inplaced = 0;
|
|
auto net = graph.predict_net_def;
|
|
|
|
// Collect blob names that we can never rename
|
|
std::set<string> netInputs(
|
|
graph.parameter_names.begin(), graph.parameter_names.end());
|
|
netInputs.insert(graph.input_names.begin(), graph.input_names.end());
|
|
std::set<string> netOutputs(
|
|
graph.output_names.begin(), graph.output_names.end());
|
|
|
|
// In-place ops greedily in a forward manner
|
|
for (int i = 0; i < net->op_size(); i++) {
|
|
OperatorDef op = net->op(i);
|
|
|
|
// Only inplace the requested
|
|
if (op.type() != op_type) {
|
|
VLOG(2) << "InPlaceObs: Type is " << op_type << ". Not in-placing";
|
|
continue;
|
|
}
|
|
|
|
if (op.input_size() != 1 || op.output_size() != 1) {
|
|
LOG(ERROR) << "InPlaceOps only supports ops with exactly 1 output "
|
|
<< "and exactly 1 input. Skipping op " << op.DebugString();
|
|
continue;
|
|
}
|
|
|
|
const string& in = op.input(0);
|
|
const string& out = op.output(0);
|
|
|
|
// If it's already in place then let's not do any more work
|
|
if (in == out) {
|
|
continue;
|
|
}
|
|
|
|
// Otherwise check if we can rename things
|
|
bool can_rename_forwards =
|
|
CanRenameForwards(out, in, net, netOutputs, {op_type}, i + 1);
|
|
|
|
// If renaming is impossible (or complicated) then skip this op
|
|
if (!can_rename_forwards &&
|
|
!CanRenameBackwards(
|
|
in, out, net, netInputs, netOutputs, {op_type}, i - 1)) {
|
|
VLOG(2) << "InPlaceOps: Complicated or impossible remove for op: "
|
|
<< op.DebugString();
|
|
continue;
|
|
}
|
|
VLOG(2) << "InPlaceOps will inplace " << op.DebugString();
|
|
num_inplaced++;
|
|
|
|
// Handle renaming
|
|
if (can_rename_forwards) {
|
|
// Rename out to in
|
|
VLOG(3) << "InPlaceOps can rename in children from " << out << " to "
|
|
<< in;
|
|
RenameInputsInChildren(out, in, net, i);
|
|
|
|
} else {
|
|
// Since out is an output of the network, we must rename in parents of op
|
|
VLOG(3) << "InPlaceOps must find parent that produced " << in
|
|
<< " and rename in all of its children from " << in << " to "
|
|
<< out;
|
|
for (int pidx = i - 1; pidx >= 0; pidx--) {
|
|
if (HasOutput(in, net->op(pidx))) {
|
|
VLOG(5) << "InPlaceOps found parent is "
|
|
<< net->op(pidx).DebugString();
|
|
RenameOutputs(in, out, net->mutable_op(pidx));
|
|
RenameInputsInChildren(in, out, net, pidx);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
} // For every op
|
|
VLOG(1) << "InPlaceOps(" << op_type << ") renamed " << num_inplaced << " ops";
|
|
}
|
|
|
|
void RemoveOpsByType(const InferenceGraph& graph, const std::string& op_type) {
|
|
int num_removed = 0;
|
|
auto net = graph.predict_net_def;
|
|
|
|
// Rename all the ops we want to delete
|
|
InPlaceOps(graph, op_type);
|
|
|
|
// Now the only ops we can delete are inplaced ones
|
|
for (int i = 0; i < net->op_size(); i++) {
|
|
OperatorDef op = net->op(i);
|
|
|
|
// Only remove ops of the requested type
|
|
if (op.type() != op_type) {
|
|
VLOG(2) << "RemoveOpsByType: Type is " << op_type << ". Not removing";
|
|
continue;
|
|
}
|
|
|
|
if (op.input_size() != 1 || op.output_size() != 1) {
|
|
LOG(ERROR) << "RemoveOpsByType only supports ops with exactly 1 output "
|
|
<< "and exactly 1 input. Skipping op " << op.DebugString();
|
|
continue;
|
|
}
|
|
|
|
const string& in = op.input(0);
|
|
const string& out = op.output(0);
|
|
|
|
// If the op is in-place then we can always delete it
|
|
if (in == out) {
|
|
VLOG(1) << "RemoveOpsByType(" << op_type << ") deleting inplace op";
|
|
net->mutable_op()->erase(net->mutable_op()->begin() + i);
|
|
i--;
|
|
num_removed++;
|
|
} else {
|
|
VLOG(2) << "RemoveOpsByType(" << op_type << ") can't delete.";
|
|
}
|
|
} // For every op
|
|
VLOG(1) << "RemoveOpsByType(" << op_type << ") removed " << num_removed
|
|
<< " ops";
|
|
}
|
|
|
|
} // namespace caffe2
|