mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10946 ``` codemod -d . --extensions cc,cpp,cu,cuh,h caffe2/proto/caffe2.pb.h caffe2/proto/caffe2_pb.h ``` Reviewed By: houseroad Differential Revision: D9539945 fbshipit-source-id: 497d04720e8e7e61c05ffe1b23733d0cb774de7e
346 lines | 11 KiB | C++
#include "caffe2/core/net_dag.h"
|
|
|
|
#include <iostream>
|
|
#include <set>
|
|
#include <stack>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/core/static_tracepoint.h"
|
|
#include "caffe2/core/timer.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include "caffe2/utils/proto_utils.h"
|
|
#include "caffe2/utils/thread_name.h"
|
|
|
|
// Command-line flag: when true, every operator becomes its own single-op
// chain (dag_utils::singleChains) instead of being grouped by
// dag_utils::computeChains. Useful for debugging chaining problems.
CAFFE2_DEFINE_bool(
    caffe2_disable_chaining,
    false,
    "Disable chaining logic (some latent multi-device issues).");

// Command-line flag: when true, per-device timing statistics (queue wait
// time, chain time-to-success) are recorded via CAFFE_EVENT.
CAFFE2_DEFINE_bool(
    caffe2_dag_net_collect_stats,
    false,
    "Collect time stats in DAG net");
|
namespace caffe2 {
|
|
|
|
// Constructs the DAG executor for a net: prepares operator nodes, groups
// them into parallel execution chains, records the initial frontier
// (operators with no parents), and sets up per-chain timers, per-device
// stats and the tracer. Worker threads themselves are started lazily in
// DoRunAsync(), not here.
DAGNetBase::DAGNetBase(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws)
    : NetBase(net_def, ws), caught_exception_yet_(false), iter_(0) {
  // Blob creator allows us to track which operator created which blob.
  VLOG(1) << "Constructing DAGNet " << net_def->name();

  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);

  // Chaining groups operators that can run back-to-back on one worker;
  // the flag forces one-op-per-chain for debugging.
  execution_chains_ =
      (FLAGS_caffe2_disable_chaining
           ? dag_utils::singleChains(operator_nodes_)
           : dag_utils::computeChains(operator_nodes_));

  operators_.reserve(operator_nodes_.size());
  for (const auto& node : operator_nodes_) {
    operators_.push_back(node.operator_.get());
  }

  LOG(INFO) << "Number of parallel execution chains "
            << execution_chains_.size()
            << " Number of operators = " << net_def->op_size();
  // TODO: do we want to make sure that there are no loops in the
  // dependency graph?

  // Figure out the initial frontier - this is the one we will feed into the
  // job queue to start a run.
  for (size_t idx = 0; idx < operator_nodes_.size(); ++idx) {
    if (operator_nodes_[idx].parents_.empty()) {
      initial_frontier_.push_back(idx);
    }
  }
  // Finally, start the workers.
  const int num_workers =
      net_def->has_num_workers() ? net_def->num_workers() : 1;
  CAFFE_ENFORCE(num_workers > 0, "Must have a positive number of workers.");
  if (num_workers == 1) {
    LOG(WARNING) << "Number of workers is 1: this means that all operators "
                 << "will be executed sequentially. Did you forget to set "
                 << "num_workers in the NetDef?";
  }
  num_workers_ = num_workers;

  // Only chain starts are ever scheduled through the job queue, so only
  // they need a wait-time timer.
  for (size_t idx = 0; idx < operator_nodes_.size(); ++idx) {
    if (operator_nodes_[idx].is_chain_start_) {
      task_timers_[idx] = caffe2::make_unique<Timer>();
    }
  }
  // One stats object per device type, indexed by the device-type id so
  // WorkerFunction/RunAt can look it up from a DeviceOption.
  stats_.reserve(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
  for (auto device_idx = 0;
       device_idx < DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
       ++device_idx) {
    stats_.emplace_back(
        "dag_net/stats/" + net_def->name() + "/" +
        caffe2::DeviceTypeName(device_idx));
  }

  // tracing::create may return null when tracing is disabled for this net.
  tracer_ = tracing::create(this, net_def->name());
  if (tracer_) {
    LOG(INFO) << "Tracing net: " << net_def->name();
  }
}
|
// Tears down the worker pool: close the job queue so blocked workers wake
// up and exit, then join every worker thread before members are destroyed.
// If the queue was never created (net never run), there is nothing to do.
DAGNetBase::~DAGNetBase() {
  if (!job_queue_) {
    return;
  }
  job_queue_->NoMoreJobs();
  VLOG(1) << "Joining workers.";
  for (auto& worker : workers_) {
    worker.join();
  }
}
|
// Runs one iteration of the net. Despite the name this call is blocking:
// it seeds the job queue with the initial frontier, lets the worker
// threads drain the DAG, and waits on cv_ until every operator has run or
// some chain failed. On failure it joins and discards the worker pool and
// (when exception_ptr support is compiled in) rethrows the first captured
// exception on the caller's thread. Returns success_.
bool DAGNetBase::DoRunAsync() {
  StartAllObservers();

  tracing::startIter(tracer_);

  // Lock run_in_progress_ to prevent concurrent Run()s.
  std::unique_lock<std::mutex> run_lock(run_in_progress_);
  VLOG(1) << "Running parallel net.";
  // First, set up job queue.
  remaining_ops_ = operator_nodes_.size();
  success_ = true;
  iter_++;
  if (!job_queue_) {
    job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
  }
  // Figure out number of workers to start.
  size_t num_workers_to_start = num_workers_ - workers_.size();

  // Ensure the number of workers matches the defined in case
  // any of the previously started threads terminated.
  for (size_t i = 0; i < num_workers_to_start; i++) {
    VLOG(1) << "Start worker #" << workers_.size();
    workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
  }
  // Initialize the runtime parent count.
  for (auto& node : operator_nodes_) {
    node.runtime_parent_count_ = node.parents_.size();
  }
  // Kickstart the job queue.
  for (auto& value : initial_frontier_) {
    if (FLAGS_caffe2_dag_net_collect_stats) {
      task_timers_[value]->Start();
    }
    job_queue_->Push(value);
  }
  // Wait for failure or completed execution. Workers signal cv_ when
  // remaining_ops_ hits zero or success_ is cleared.
  {
    std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
    for (;;) {
      if (remaining_ops_ == 0 || !success_) {
        break;
      }
      cv_.wait(mutex_lock);
    }
  }
  // Wait for all workers to terminate after failure.
  // If there is a failure, it is unlikely that the net is executed
  // again without modifications. Therefore it's easier to let the
  // workers terminate here, versus adding a drain state to make the
  // sure the job queue is cleared.
  if (!success_) {
    for (auto& worker : workers_) {
      worker.join();
    }
    workers_.clear();
    job_queue_.reset(nullptr);
#ifdef CAFFE2_USE_EXCEPTION_PTR
    if (caught_exception_) {
      // Reset flag here in case Net gets run again
      caught_exception_yet_ = false;
      std::rethrow_exception(caught_exception_);
    }
#endif // CAFFE2_USE_EXCEPTION_PTR
    return success_;
  }
  VLOG(2) << "All ops finished running.";
  // Sanity check: workers must have decremented every operator's runtime
  // parent count back to zero on a successful run.
  for (const auto& op : operator_nodes_) {
    CAFFE_ENFORCE(
        op.runtime_parent_count_ == 0,
        "Operator ",
        op.operator_->debug_def().name(),
        "(",
        op.operator_->debug_def().type(),
        ") has some runtime parents left.");
  }

  StopAllObservers();
  // If the above while loop finished, we know that the current run finished.
  return success_;
}
|
|
// Records an exception raised while running the operator chain starting
// at operator_idx. With exception_ptr support, the first exception across
// all workers (decided by the atomic caught_exception_yet_ flag) is
// captured into caught_exception_ so DoRunAsync can rethrow it on the
// caller's thread; subsequent exceptions are only logged as "secondary".
// Without exception_ptr support the exception is rethrown in place.
// Must be called from inside a catch block (uses std::current_exception
// and bare `throw`).
void DAGNetBase::HandleException(
    int operator_idx,
    const std::string& exception_str) {
  const std::string& operator_name =
      operator_nodes_[operator_idx].operator_->debug_def().name();
  const std::string& operator_type =
      operator_nodes_[operator_idx].operator_->debug_def().type();
  const char* prefix = "Exception from operator chain starting at '";
#ifdef CAFFE2_USE_EXCEPTION_PTR
  // exchange() ensures exactly one worker wins the race to store the
  // primary exception.
  if (!caught_exception_yet_.exchange(true)) {
    caught_exception_ = std::current_exception();
  } else {
    prefix = "Secondary exception from operator chain starting at '";
  }
#endif // CAFFE2_USE_EXCEPTION_PTR
  LOG(ERROR) << prefix << operator_name << "' (type '" << operator_type
             << "'): " << exception_str << "\n";
#ifndef CAFFE2_USE_EXCEPTION_PTR
  throw; // Can't capture for dispatch to other thread, re-throw here
#endif // CAFFE2_USE_EXCEPTION_PTR
}
|
void DAGNetBase::WorkerFunction() {
|
|
setThreadName("CaffeDAGNet");
|
|
|
|
// WorkerFunctions() is an infinite loop until there are no more jobs to run.
|
|
while (true) {
|
|
int idx = 0;
|
|
|
|
// Return if there are no more operators to run (e.g. the
|
|
// DAGNetBase is destructing, or there was an error on another
|
|
// worker and we're cleaning up).
|
|
if (!job_queue_->Pop(&idx)) {
|
|
return;
|
|
}
|
|
if (FLAGS_caffe2_dag_net_collect_stats) {
|
|
auto device_option =
|
|
operator_nodes_[idx].operator_->event().GetDeviceOption();
|
|
CAFFE_EVENT(
|
|
stats_[device_option.device_type()],
|
|
task_pool_wait_time_us,
|
|
task_timers_[idx]->MicroSeconds());
|
|
}
|
|
|
|
VLOG(1) << "Running chain starting at operator #" << idx << " "
|
|
<< operator_nodes_[idx].operator_->debug_def().name() << "("
|
|
<< operator_nodes_[idx].operator_->debug_def().type() << ").";
|
|
CAFFE_ENFORCE(
|
|
execution_chains_.find(idx) != execution_chains_.end(),
|
|
"Can't find chain ",
|
|
idx,
|
|
".");
|
|
bool this_success = false;
|
|
try {
|
|
this_success = RunAt(idx, execution_chains_[idx]);
|
|
|
|
if (!this_success) {
|
|
// If an exception was thrown, the operator def will get printed
|
|
// by Operator::Run[Async], but if no exception occurs we print it here.
|
|
LOG(ERROR) << "Operator chain failed starting at: "
|
|
<< ProtoDebugString(
|
|
operator_nodes_[idx].operator_->debug_def());
|
|
}
|
|
} catch (std::exception& e) {
|
|
std::string exception_str = GetExceptionString(e);
|
|
HandleException(idx, exception_str);
|
|
} catch (...) {
|
|
std::string exception_str = "Unknown exception";
|
|
HandleException(idx, exception_str);
|
|
}
|
|
|
|
// Do book-keeping
|
|
std::vector<int> chains_to_queue;
|
|
const auto& chain = execution_chains_[idx];
|
|
for (const auto idx : chain) {
|
|
for (const auto child : operator_nodes_[idx].children_) {
|
|
const int count = --operator_nodes_[child].runtime_parent_count_;
|
|
CAFFE_ENFORCE(
|
|
count >= 0,
|
|
"Found runtime parent count smaller than zero for ",
|
|
"operator node ",
|
|
operator_nodes_[child].operator_->debug_def().name(),
|
|
"(",
|
|
operator_nodes_[child].operator_->debug_def().type(),
|
|
").");
|
|
|
|
if (count != 0) {
|
|
continue;
|
|
}
|
|
|
|
if (operator_nodes_[child].is_chain_start_) {
|
|
VLOG(2) << "Pushing chain #" << child << " to queue.";
|
|
chains_to_queue.push_back(child);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Notify the caller of Run
|
|
{
|
|
std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
|
|
remaining_ops_ -= chain.size();
|
|
CAFFE_ENFORCE(remaining_ops_ >= 0);
|
|
success_ &= this_success;
|
|
if (remaining_ops_ == 0 || !success_) {
|
|
cv_.notify_one();
|
|
}
|
|
|
|
// Terminate thread if this or any other operator chain failed.
|
|
if (!success_) {
|
|
job_queue_->NoMoreJobs();
|
|
return;
|
|
}
|
|
|
|
// Queue follow up operator chains.
|
|
// Can't do this inline because it can race with another thread
|
|
// calling NoMoreJobs(). So the lock needs to be held on push.
|
|
for (const auto idx : chains_to_queue) {
|
|
if (FLAGS_caffe2_dag_net_collect_stats) {
|
|
task_timers_[idx]->Start();
|
|
}
|
|
job_queue_->Push(idx);
|
|
}
|
|
}
|
|
|
|
VLOG(2) << "Finished executing operator #" << idx;
|
|
}
|
|
}
|
|
|
|
bool DAGNet::RunAt(int chain_id, const std::vector<int>& chain) {
|
|
for (const auto i : chain) {
|
|
#ifdef CAFFE2_ENABLE_SDT
|
|
const auto& op_name =
|
|
operator_nodes_[i].operator_->debug_def().name().c_str();
|
|
const auto& op_type =
|
|
operator_nodes_[i].operator_->debug_def().type().c_str();
|
|
auto* op_ptr = operator_nodes_[i].operator_.get();
|
|
const auto& net_name = name_.c_str();
|
|
CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
|
|
#endif
|
|
bool success = false;
|
|
{
|
|
TRACE_EVENT(tracing::TRACE_OP, i, tracing::TRACE_TASK, chain_id);
|
|
success = operator_nodes_[i].operator_->Run();
|
|
}
|
|
#ifdef CAFFE2_ENABLE_SDT
|
|
CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
|
|
#endif
|
|
if (!success) {
|
|
return false;
|
|
}
|
|
}
|
|
if (FLAGS_caffe2_dag_net_collect_stats) {
|
|
auto device_option =
|
|
operator_nodes_[chain_id].operator_->event().GetDeviceOption();
|
|
CAFFE_EVENT(
|
|
stats_[device_option.device_type()],
|
|
task_time_to_succeeded_ms,
|
|
task_timers_[chain_id]->MilliSeconds());
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Register this executor in the net registry under the type name "dag".
REGISTER_NET(dag, DAGNet);
|
|
|
} // namespace caffe2
|