Files
pytorch/caffe2/core/net.cc
Yangqing Jia 7d5f7ed270 Using c10 namespace across caffe2. (#12714)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12714

This is a short change to enable the c10 namespace in caffe2. We did not enable
it before because of confusion around gflags global variables, but that should
have been mostly cleaned up by now. Right now, the plan on record is that namespace
caffe2 and namespace aten will be full supersets of namespace c10.

Most of the diff is codemod. The only two non-codemod changes are in caffe2/core/common.h, where

```
using namespace c10;
```

is added, and in Flags.h, where, instead of creating aliasing variables in the c10 namespace, we put them directly in the global namespace to match gflags (the behavior is the same when the build does not use gflags).

Reviewed By: dzhulgakov

Differential Revision: D10390486

fbshipit-source-id: 5e2df730e28e29a052f513bddc558d9f78a23b9b
2018-10-17 12:57:19 -07:00

218 lines
6.3 KiB
C++

#include "caffe2/core/net.h"
#include "caffe2/core/net_simple.h"
#include <set>
#include <unordered_map>
#include <unordered_set>
#include "caffe2/core/init.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
// Command-line flag holding executor (net type) overrides. It defaults to the
// empty string. checkExecutorOverride() below splits the value on ',' and
// reads the pieces as consecutive (original type, replacement type) pairs, so
// the list must contain an even number of entries.
C10_DEFINE_string(
caffe2_override_executor,
"",
"Comma-separated list of executor overrides");
namespace caffe2 {
// Defines NetRegistry, the registry of NetBase implementations. CreateNet()
// below instantiates nets through it with
// NetRegistry()->Create(net_type, net_def, ws), i.e. keyed by the net type
// string and passing the shared NetDef and the Workspace pointer.
C10_DEFINE_REGISTRY(
NetRegistry,
NetBase,
const std::shared_ptr<const NetDef>&,
Workspace*);
/**
 * Initializes the state shared by all nets from `def` and validates how the
 * operators in the definition are wired together.
 *
 * The constructor copies the declared external inputs/outputs and the net
 * name out of `def` and keeps a reference to the definition itself. It then
 * performs three checks:
 *   1. No operator may carry a node_name in its device option.
 *   2. When external inputs are declared, every operator input must be either
 *      an external input or the output of an earlier operator; otherwise the
 *      constructor throws. When no external inputs are declared, unknown
 *      inputs are only reported through VLOG(1).
 *   3. Every declared external output must be an external input or be
 *      produced by some operator.
 *
 * @param def  Net definition to build from.
 * @param (unnamed) Workspace pointer; not used by this constructor.
 */
NetBase::NetBase(
    const std::shared_ptr<const NetDef>& def,
    Workspace* /* unused */)
    : external_input_(
          def->external_input().begin(),
          def->external_input().end()),
      external_output_(
          def->external_output().begin(),
          def->external_output().end()),
      name_(def->name()),
      net_def_(def) {
  // NOTE(review): presumably verifies that global initialization has run;
  // the guard type is defined outside this file - confirm.
  static GlobalInitIsCalledGuard guard;

  // Operators must not carry a node_name by the time they are executed.
  for (const OperatorDef& op : def->op()) {
    if (!op.has_device_option()) {
      continue;
    }
    CAFFE_ENFORCE(
        !op.device_option().has_node_name(),
        "node_name must be empty for all operators at execution time.");
  }

  // Blobs that are available so far: initially just the external inputs.
  std::set<string> available_blobs(
      external_input_.begin(), external_input_.end());
  // Declared outputs that nothing has produced yet. An output that is also an
  // external input counts as already satisfied.
  std::set<string> remaining_output(
      external_output_.begin(), external_output_.end());
  for (const auto& input_blob : available_blobs) {
    remaining_output.erase(input_blob);
  }

  // Walk the operators in definition order, tracking which blobs exist.
  for (const OperatorDef& op_def : def->op()) {
    for (const string& input_name : op_def.input()) {
      if (available_blobs.count(input_name)) {
        continue;
      }
      if (!external_input_.empty()) {
        // Inputs were declared, so an input with no source is a hard error.
        CAFFE_THROW(
            "op ",
            op_def.type(),
            ": Source for input ",
            input_name,
            " is unknown for net ",
            def->name(),
            ", operator ",
            ProtoDebugString(op_def));
      }
      // No inputs/outputs were declared for this net: just log the unknown
      // input for debugging purposes.
      VLOG(1) << "op " << op_def.type() << ": input " << input_name
              << " is unknown.";
    }
    for (const string& output_name : op_def.output()) {
      available_blobs.insert(output_name);
      remaining_output.erase(output_name);
    }
  }

  // Every declared output must have been produced by now.
  CAFFE_ENFORCE(
      remaining_output.size() == 0,
      "Some of the blobs are declared as output but never produced by the "
      "net ",
      def->name(),
      ", the first one is ",
      *remaining_output.begin());
}
bool NetBase::RunAsync() {
for (auto& op : GetOperators()) {
op->ResetEvent();
}
return DoRunAsync();
}
namespace {
// Net type that CreateNet() falls back to when the NetDef has no type set.
const std::string kSimpleNet = "simple";
// Returns a pointer to the list of global net observer creators. The list is
// a function-local static, so every caller receives the same vector.
std::vector<NetObserverCreator>* GetNetObserverCreators() {
  static std::vector<NetObserverCreator> registered_creators;
  return &registered_creators;
}
// Returns the built-in mapping from a requested net type to the net type that
// is used in its place. The table is a function-local static and is returned
// by const reference.
const std::unordered_map<std::string, std::string>& defaultOverrides() {
  using OverrideTable = std::unordered_map<std::string, std::string>;
  static const OverrideTable kBuiltinOverrides{
      {"dag", "async_scheduling"},
      {"prof_dag", "async_scheduling"},
      {"async_dag", "async_scheduling"},
      {"async_polling", "async_scheduling"},
      {"async_simple", "simple"},
  };
  return kBuiltinOverrides;
}
// Rewrites `net_type` in place if an executor override applies to it.
//
// The effective override table is defaultOverrides() with the pairs from the
// caffe2_override_executor flag layered on top, so flag entries replace
// defaults with the same key. The flag value is split on ',' and read as
// consecutive (original type, replacement type) pairs; an odd number of
// entries fails the enforce below. If `net_type` has no entry it is left
// untouched.
void checkExecutorOverride(std::string& net_type) {
  auto executors = caffe2::split(',', FLAGS_caffe2_override_executor);
  CAFFE_ENFORCE(
      executors.size() % 2 == 0, "Invalid override executors flag value");

  // Start from a private copy of the defaults so the flag can shadow them.
  std::unordered_map<std::string, std::string> effective(defaultOverrides());
  for (size_t pos = 0; pos < executors.size(); pos += 2) {
    effective[executors[pos]] = executors[pos + 1];
  }

  const auto match = effective.find(net_type);
  if (match == effective.end()) {
    return;
  }
  VLOG(1) << "Overrode net type '" << net_type << "' with '" << match->second
          << "'";
  net_type = match->second;
}
} // namespace
// Appends `creator` to the global list of net observer creators. CreateNet()
// invokes every creator in that list on each net it successfully creates.
void AddGlobalNetObserverCreator(NetObserverCreator creator) {
  auto* registered = GetNetObserverCreators();
  registered->push_back(creator);
  VLOG(1) << "Have set a custom GlobalNetObserverCreator";
}
// Empties the global list of net observer creators. Only the list of
// creators is cleared here; nothing in this function touches existing nets.
void ClearGlobalNetObservers() {
  auto* registered = GetNetObserverCreators();
  registered->clear();
  VLOG(1) << "All net observers cleared";
}
// Convenience overload: copies `net_def` into a shared, heap-allocated NetDef
// and forwards to the shared_ptr overload of CreateNet().
//
// @param net_def  Definition to copy; the caller keeps ownership of its own.
// @param ws       Workspace pointer forwarded unchanged.
// @return         Whatever the shared_ptr overload returns for the copy.
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
  // make_shared avoids a naked `new` and allocates the NetDef and its control
  // block together.
  std::shared_ptr<const NetDef> owned_def = std::make_shared<NetDef>(net_def);
  return CreateNet(owned_def, ws);
}
// Creates a net for `net_def` through NetRegistry.
//
// The net type is taken from the definition, or kSimpleNet when the
// definition has no type; it is then passed through checkExecutorOverride(),
// which may replace it. If the registry returns a net, every registered
// global observer creator is invoked on it and the resulting observer is
// attached. The (possibly null) net is returned as-is.
unique_ptr<NetBase> CreateNet(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws) {
  // With no explicit type we fall back to the simple net, which just runs all
  // operators sequentially.
  std::string net_type = net_def->has_type() ? net_def->type() : kSimpleNet;
  checkExecutorOverride(net_type);

  unique_ptr<NetBase> net = NetRegistry()->Create(net_type, net_def, ws);
  VLOG(1) << "Adding a global observer to a net";
  if (!net) {
    return net;
  }
  for (auto& make_observer : *GetNetObserverCreators()) {
    net->AttachObserver(make_observer(net.get()));
  }
  return net;
}
// Implementation of GetPool defined in this file: it ignores the device
// option and always raises through CAFFE_THROW with "Not implemented".
// NOTE(review): presumably executors that own thread pools override this -
// confirm against the ExecutorHelper declaration.
TaskThreadPoolBase* ExecutorHelper::GetPool(
const DeviceOption& /* unused */) const {
CAFFE_THROW("Not implemented");
}
/**
 * Benchmarks whole-net execution.
 *
 * Runs the net `warmup_runs` times untimed, then `main_runs` times while
 * timing the whole loop, and logs the average time per iteration and the
 * iteration rate.
 *
 * @param warmup_runs     Number of untimed runs; must be >= 0.
 * @param main_runs       Number of timed runs; must be >= 0.
 * @param run_individual  Per-operator benchmarking request. It is not
 *                        supported here; a log message is emitted instead.
 * @return A single-element vector holding the average milliseconds per main
 *         run. When `main_runs` is 0 the element is 0, since no iteration
 *         was timed.
 * @throws Via CAFFE_ENFORCE when a run count is negative or any Run() call
 *         returns false.
 */
std::vector<float> NetBase::TEST_Benchmark(
    const int warmup_runs,
    const int main_runs,
    const bool run_individual) {
  LOG(INFO) << "Starting benchmark, running warmup runs";
  CAFFE_ENFORCE(
      warmup_runs >= 0,
      "Number of warm up runs should be non negative, provided ",
      warmup_runs);
  for (int run_idx = 0; run_idx < warmup_runs; ++run_idx) {
    CAFFE_ENFORCE(Run(), "Warmup run ", run_idx, " has failed");
  }

  LOG(INFO) << "Running main runs";
  CAFFE_ENFORCE(
      main_runs >= 0,
      "Number of main runs should be non negative, provided ",
      main_runs);
  Timer timer;
  for (int run_idx = 0; run_idx < main_runs; ++run_idx) {
    CAFFE_ENFORCE(Run(), "Main run ", run_idx, " has failed");
  }
  auto millis = timer.MilliSeconds();

  // main_runs == 0 is accepted by the check above; dividing by it would put
  // inf/NaN into the log and the returned vector, so report 0 instead.
  const bool has_main_runs = main_runs > 0;
  const auto millis_per_iter = has_main_runs ? millis / main_runs : 0.0f;
  const auto iters_per_second =
      has_main_runs ? 1000.0 * main_runs / millis : 0.0;
  LOG(INFO) << "Main runs finished. Milliseconds per iter: "
            << millis_per_iter
            << ". Iters per second: " << iters_per_second;

  if (run_individual) {
    LOG(INFO) << "Net does not support per-op benchmark; "
                 "to run it, switch to a simple net type";
  }
  return std::vector<float>{millis_per_iter};
}
} // namespace caffe2