Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Back out "Revert D13043261: [caffe2] Task graph and task future abstractions in executor"
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15030

Reviewed By: bddppq

Differential Revision: D13408998

fbshipit-source-id: 9eb675e09fbc4829eab34df7aa660a0590816feb
Commit e9cd781681 (parent 83f32eebd9), committed by Facebook Github Bot
caffe2/core/net_async_task.cc (new file, 107 lines)
@@ -0,0 +1,107 @@
#include "caffe2/core/net_async_task.h"

#include "caffe2/core/net_async_task_graph.h"

namespace caffe2 {

AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
  CAFFE_ENFORCE(!ops_.empty());
  device_option_ = ops_.front()->device_option();
  for (auto& op : ops_) {
    CAFFE_ENFORCE(IsSameDevice(device_option_, op->device_option()));
  }
  Reset();
}

void AsyncTask::handleChainError(
    OperatorBase* op,
    const char* err_str,
    bool save_exception) {
  std::string err_msg = err_str;
  if (op) {
    err_msg += ", op " + (op->has_debug_def() ? op->type() : "unknown");
  }
  LOG(ERROR) << err_msg;

  // save error message and exception in the chain's Event
  auto last_op = ops_.back();
  if (save_exception) {
    last_op->event().SetFinishedWithException(err_msg.c_str());
  } else {
    last_op->event().SetFinished(err_msg.c_str());
  }

  // set the future as completed with an error
  // TODO: exceptions in future
  future_.SetCompleted(err_msg.c_str());
}

bool AsyncTask::Run(const ExecutionOptions& options) {
  // TODO: insert CUDA's async stream waits; tracing and counters
  OperatorBase* op = nullptr;
  try {
    for (size_t op_idx = 0; op_idx < ops_.size(); ++op_idx) {
      op = ops_[op_idx];
      int stream_id = 0; // TODO: thread local stream id
      if (!op->RunAsync(stream_id)) {
        handleChainError(op, "Failed to execute an op");
        return false;
      }
    }

    if (options.finish_chain_) {
      op = ops_.back();
      op->Finish();
    }

    // set the future as successfully completed or, in case of async CPU,
    // use the op's callback
    if (IsCPUDeviceType(device_option_.device_type()) &&
        ops_.back()->HasAsyncPart()) {
      auto& event = ops_.back()->event();
      event.SetCallback([this, &event]() {
        CAFFE_ENFORCE(event.IsFinished());
        if (event.Query() == EventStatus::EVENT_SUCCESS) {
          future_.SetCompleted();
        } else {
          // TODO: support for exceptions
          future_.SetCompleted(event.ErrorMessage().c_str());
        }
      });
    } else {
      future_.SetCompleted();
    }
  } catch (const std::exception& e) {
    handleChainError(op, e.what(), /* save_exception */ true);
    return false;
  } catch (...) {
    handleChainError(
        op,
        "Failed to execute task: unknown error",
        /* save_exception */ true);
    return false;
  }

  return true;
}

void AsyncTask::Reset() {
  for (auto& op : ops_) {
    op->ResetEvent();
  }
  future_.ResetState();
}

DeviceOption AsyncTask::GetDeviceOption() const {
  return device_option_;
}

AsyncTaskFuture& AsyncTask::GetFuture() {
  return future_;
}

const AsyncTaskFuture& AsyncTask::GetFuture() const {
  return future_;
}

} // namespace caffe2
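To make the control flow of AsyncTask::Run easier to see in isolation, here is a minimal standalone sketch of the same chain-execution pattern: run steps in order, stop at the first failure, and record a single error message for the whole chain. The Step type and all names are hypothetical; only the short-circuit pattern mirrors the code above.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for one op in a chain: returns false on failure.
using Step = std::function<bool()>;

// Runs steps in order; on the first failure records one error message for
// the whole chain and stops, mirroring AsyncTask::Run's short-circuit.
bool RunChain(const std::vector<Step>& steps, std::string* err_msg) {
  for (size_t i = 0; i < steps.size(); ++i) {
    if (!steps[i]()) {
      *err_msg = "Failed to execute step " + std::to_string(i);
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<Step> chain = {
      []() { return true; },
      []() { return false; }, // fails, so the third step never runs
      []() { return true; }};
  std::string err;
  if (!RunChain(chain, &err)) {
    std::cout << err << std::endl; // prints "Failed to execute step 1"
  }
}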
caffe2/core/net_async_task.h (new file, 39 lines)
@@ -0,0 +1,39 @@
#ifndef CAFFE2_NET_ASYNC_TASK_H
#define CAFFE2_NET_ASYNC_TASK_H

#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"

#include <vector>

namespace caffe2 {

// AsyncTask represents an asynchronous execution of a chain of ops.
class AsyncTask {
 public:
  AsyncTask(const std::vector<OperatorBase*>& ops);

  bool Run(const ExecutionOptions& options);

  void Reset();

  DeviceOption GetDeviceOption() const;

  AsyncTaskFuture& GetFuture();
  const AsyncTaskFuture& GetFuture() const;

 private:
  void handleChainError(
      OperatorBase* op,
      const char* err_msg,
      bool save_exception = false);

  std::vector<OperatorBase*> ops_;
  DeviceOption device_option_;
  AsyncTaskFuture future_;
};

} // namespace caffe2

#endif // CAFFE2_NET_ASYNC_TASK_H
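As a rough illustration of the intended call sequence (construct, run, wait on the future, reset for reuse), here is a hedged sketch. It assumes an in-tree caffe2 build, a chain of already-created same-device operators, and an ExecutionOptions instance built elsewhere; the helper function itself is hypothetical, not part of this commit.

#include "caffe2/core/net_async_task.h"

namespace caffe2 {

// Hypothetical helper: runs one chain of ops synchronously via AsyncTask.
// Assumes `ops` are valid, on the same device, and owned elsewhere.
bool RunChainOnce(
    const std::vector<OperatorBase*>& ops,
    const ExecutionOptions& options) {
  AsyncTask task(ops); // enforces non-empty, same-device ops
  if (!task.Run(options)) {
    return false;
  }
  task.GetFuture().Wait(); // block until the chain's future completes
  bool ok = !task.GetFuture().IsFailed();
  task.Reset(); // prepare events and future for the next run
  return ok;
}

} // namespace caffe2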
caffe2/core/net_async_task_future.cc (new file, 110 lines)
@@ -0,0 +1,110 @@
#include "caffe2/core/net_async_task_future.h"

#include "c10/util/Logging.h"
#include "caffe2/core/common.h"

namespace caffe2 {

AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}

AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
    : completed_(false), failed_(false) {
  if (futures.size() > 1) {
    parent_counter_ = caffe2::make_unique<ParentCounter>(futures.size());
    for (auto future : futures) {
      future->SetCallback([this](const AsyncTaskFuture* f) {
        if (f->IsFailed()) {
          std::unique_lock<std::mutex> lock(parent_counter_->err_mutex);
          if (parent_counter_->parent_failed) {
            parent_counter_->err_msg += ", " + f->ErrorMessage();
          } else {
            parent_counter_->parent_failed = true;
            parent_counter_->err_msg = f->ErrorMessage();
          }
        }
        int count = --parent_counter_->parent_count;
        if (count == 0) {
          // thread safe to use parent_counter_ here
          if (!parent_counter_->parent_failed) {
            SetCompleted();
          } else {
            SetCompleted(parent_counter_->err_msg.c_str());
          }
        }
      });
    }
  } else {
    CAFFE_ENFORCE_EQ(futures.size(), 1);
    auto future = futures.back();
    future->SetCallback([this](const AsyncTaskFuture* f) {
      if (!f->IsFailed()) {
        SetCompleted();
      } else {
        SetCompleted(f->ErrorMessage().c_str());
      }
    });
  }
}

bool AsyncTaskFuture::IsCompleted() const {
  return completed_;
}

bool AsyncTaskFuture::IsFailed() const {
  return failed_;
}

std::string AsyncTaskFuture::ErrorMessage() const {
  return err_msg_;
}

void AsyncTaskFuture::Wait() const {
  std::unique_lock<std::mutex> lock(mutex_);
  while (!completed_) {
    cv_completed_.wait(lock);
  }
}

void AsyncTaskFuture::SetCallback(
    std::function<void(const AsyncTaskFuture*)> callback) {
  std::unique_lock<std::mutex> lock(mutex_);

  callbacks_.push_back(callback);
  if (completed_) {
    callback(this);
  }
}

void AsyncTaskFuture::SetCompleted(const char* err_msg) {
  std::unique_lock<std::mutex> lock(mutex_);

  CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
  completed_ = true;

  if (err_msg) {
    failed_ = true;
    err_msg_ = err_msg;
  }

  for (auto& callback : callbacks_) {
    callback(this);
  }

  cv_completed_.notify_all();
}

// ResetState is called on a completed future; it does not reset callbacks,
// in order to keep the task graph structure
void AsyncTaskFuture::ResetState() {
  std::unique_lock<std::mutex> lock(mutex_);
  if (parent_counter_) {
    parent_counter_->Reset();
  }
  completed_ = false;
  failed_ = false;
  err_msg_ = "";
}

AsyncTaskFuture::~AsyncTaskFuture() {}

} // namespace caffe2
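The many-parents constructor above relies on an atomic countdown: whichever parent callback decrements the counter to zero is the one that completes the child future. Below is a minimal standalone sketch of that pattern, with names hypothetical and error aggregation reduced to a single flag.

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

// Completes a "child" action exactly once, after all parents have finished;
// the parent whose decrement reaches zero fires the completion, as in the
// ParentCounter logic above.
int main() {
  const int num_parents = 4;
  std::atomic<int> remaining(num_parents);
  std::atomic<bool> any_failed(false);

  auto on_parent_done = [&](bool failed) {
    if (failed) {
      any_failed = true;
    }
    if (--remaining == 0) {
      // Safe: all parents are done, so no one else touches the state.
      std::cout << (any_failed ? "child: failed" : "child: completed")
                << std::endl;
    }
  };

  std::vector<std::thread> parents;
  for (int i = 0; i < num_parents; ++i) {
    parents.emplace_back([&, i]() { on_parent_done(/* failed */ i == 2); });
  }
  for (auto& t : parents) {
    t.join();
  }
}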
caffe2/core/net_async_task_future.h (new file, 76 lines)
@@ -0,0 +1,76 @@
#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
#define CAFFE2_NET_ASYNC_TASK_FUTURE_H

#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

namespace caffe2 {

// Represents the state of AsyncTask execution, which can be queried with
// IsCompleted/IsFailed. Callbacks are supported through SetCallback and
// are called upon the future's completion.
class AsyncTaskFuture {
 public:
  AsyncTaskFuture();
  // Creates a future completed when all given futures are completed
  explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);
  ~AsyncTaskFuture();

  AsyncTaskFuture(const AsyncTaskFuture&) = delete;
  AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;

  bool IsCompleted() const;

  bool IsFailed() const;

  std::string ErrorMessage() const;

  void Wait() const;

  void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);

  void SetCompleted(const char* err_msg = nullptr);

  void ResetState();

 private:
  mutable std::mutex mutex_;
  mutable std::condition_variable cv_completed_;
  std::atomic<bool> completed_;
  std::atomic<bool> failed_;
  std::string err_msg_;
  std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;

  struct ParentCounter {
    explicit ParentCounter(int init_parent_count)
        : init_parent_count_(init_parent_count),
          parent_count(init_parent_count),
          parent_failed(false) {}

    void Reset() {
      std::unique_lock<std::mutex> lock(err_mutex);
      parent_count = init_parent_count_;
      parent_failed = false;
      err_msg = "";
    }

    const int init_parent_count_;
    std::atomic<int> parent_count;
    std::mutex err_mutex;
    std::atomic<bool> parent_failed;
    std::string err_msg;
  };

  std::unique_ptr<ParentCounter> parent_counter_;
};

} // namespace caffe2

#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H
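One subtlety of this interface is that SetCallback must fire immediately when the future is already completed, so late subscribers are not lost. Here is a minimal standalone sketch of that contract, using a simplified future (not the class above); as in the original, callbacks are invoked while the lock is held, which keeps the completed/callbacks state consistent at the cost of forbidding reentrant calls.

#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

// Simplified future illustrating the SetCallback contract: callbacks
// registered after completion are invoked right away.
class MiniFuture {
 public:
  void SetCallback(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mutex_);
    callbacks_.push_back(cb);
    if (completed_) {
      cb(); // already done: invoke the late subscriber immediately
    }
  }

  void SetCompleted() {
    std::lock_guard<std::mutex> lock(mutex_);
    completed_ = true;
    for (auto& cb : callbacks_) {
      cb();
    }
  }

 private:
  std::mutex mutex_;
  bool completed_ = false;
  std::vector<std::function<void()>> callbacks_;
};

int main() {
  MiniFuture f;
  f.SetCallback([]() { std::cout << "early subscriber" << std::endl; });
  f.SetCompleted(); // fires the early subscriber
  f.SetCallback([]() { std::cout << "late subscriber" << std::endl; }); // fires at once
}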
caffe2/core/net_async_task_graph.cc (new file, 139 lines)
@@ -0,0 +1,139 @@
#include "caffe2/core/net_async_task_graph.h"

#include "caffe2/core/net_parallel.h"

namespace caffe2 {

AsyncTaskGraph::AsyncTaskGraph(
    ExecutorHelper* helper,
    const ExecutionOptions& options)
    : helper_(helper), options_(options), frozen_(false) {}

bool AsyncTaskGraph::CreateNode(
    int node_id,
    const std::vector<OperatorBase*>& ops) {
  CAFFE_ENFORCE(!frozen_);
  if (!nodes_.count(node_id)) {
    nodes_[node_id] = caffe2::make_unique<AsyncTask>(ops);
    return true;
  } else {
    return false;
  }
}

bool AsyncTaskGraph::AddDependency(
    int child_node_id,
    const std::vector<int>& parent_node_ids) {
  CAFFE_ENFORCE(!frozen_);
  CAFFE_ENFORCE(!parent_node_ids.empty());
  CAFFE_ENFORCE(nodes_.count(child_node_id));
  for (auto node_id : parent_node_ids) {
    CAFFE_ENFORCE(nodes_.count(node_id));
  }
  CAFFE_ENFORCE(!parents_.count(child_node_id));

  auto* child_task = nodes_[child_node_id].get();
  auto child_device = child_task->GetDeviceOption();

  std::vector<AsyncTaskFuture*> parent_futures;
  for (auto node_id : parent_node_ids) {
    parents_[child_node_id].insert(node_id);
    children_[node_id].insert(child_node_id);
    parent_futures.push_back(&nodes_[node_id]->GetFuture());
  }

  AsyncTaskFuture* parents_future = nullptr;
  if (parent_futures.size() > 1) {
    edge_futures_.push_back(
        caffe2::make_unique<AsyncTaskFuture>(parent_futures));
    parents_future = edge_futures_.back().get();
  } else {
    CAFFE_ENFORCE_EQ(parent_futures.size(), 1);
    parents_future = parent_futures.back();
  }

  // TODO: CUDA polling
  parents_future->SetCallback(
      [this, child_task, child_device](const AsyncTaskFuture* f) {
        CAFFE_ENFORCE(f->IsCompleted());
        if (!f->IsFailed()) {
          // if we're in the correct thread pool and DFS scheduling is enabled,
          // call the task inline immediately, otherwise send the task into the
          // thread pool
          auto* pool = helper_->GetPool(child_device);
          if (pool->inThreadPool() && options_.use_dfs_scheduling_) {
            child_task->Run(options_);
          } else {
            pool->run([this, child_task]() { child_task->Run(options_); });
          }
        } else {
          // skip task execution and propagate the error further
          child_task->GetFuture().SetCompleted(f->ErrorMessage().c_str());
        }
      });

  return true;
}

void AsyncTaskGraph::FreezeGraph() {
  if (frozen_) {
    return;
  }

  CAFFE_ENFORCE(!run_future_);
  CAFFE_ENFORCE(root_tasks_.empty());

  std::vector<AsyncTaskFuture*> final_futures;
  for (auto& kv : nodes_) {
    auto task_id = kv.first;
    auto* task = kv.second.get();

    if (parents_[task_id].empty()) {
      root_tasks_.push_back(task);
    }

    if (children_[task_id].empty()) {
      auto& future = task->GetFuture();
      final_futures.push_back(&future);
    }
  }

  CAFFE_ENFORCE(!root_tasks_.empty());
  CAFFE_ENFORCE(!final_futures.empty());

  run_future_ = caffe2::make_unique<AsyncTaskFuture>(final_futures);

  frozen_ = true;
}

AsyncTaskFuture* AsyncTaskGraph::ExecuteGraph() {
  CAFFE_ENFORCE(frozen_);
  CAFFE_ENFORCE(run_future_ && !run_future_->IsCompleted());

  // TODO: run root tasks inline in inference mode
  for (auto* task : root_tasks_) {
    auto task_device = task->GetDeviceOption();
    helper_->GetPool(task_device)->run([this, task]() { task->Run(options_); });
  }

  return run_future_.get();
}

AsyncTaskFuture* AsyncTaskGraph::GetFuture() {
  CAFFE_ENFORCE(frozen_);
  return run_future_.get();
}

void AsyncTaskGraph::Reset() {
  CAFFE_ENFORCE(frozen_);
  for (auto& kv : nodes_) {
    kv.second->Reset();
  }
  for (auto& future : edge_futures_) {
    future->ResetState();
  }
  if (run_future_) {
    run_future_->ResetState();
  }
}

} // namespace caffe2
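FreezeGraph's root/leaf scan is worth seeing on its own: roots (nodes with no parents) are the tasks scheduled first by ExecuteGraph, while leaves (nodes with no children) contribute the futures combined into the run future. A standalone sketch over plain adjacency maps, with hypothetical names and a hard-coded diamond graph:

#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // Diamond graph: 0 -> {1, 2} -> 3, stored as node -> set-of-parents
  // and node -> set-of-children, as in AsyncTaskGraph.
  std::unordered_map<int, std::unordered_set<int>> parents = {
      {0, {}}, {1, {0}}, {2, {0}}, {3, {1, 2}}};
  std::unordered_map<int, std::unordered_set<int>> children = {
      {0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {}}};

  std::vector<int> roots, leaves;
  for (const auto& kv : parents) {
    if (kv.second.empty()) {
      roots.push_back(kv.first); // scheduled first by ExecuteGraph
    }
    if (children[kv.first].empty()) {
      leaves.push_back(kv.first); // their futures feed the run future
    }
  }

  std::cout << "roots: " << roots.size() // 1 (node 0)
            << ", leaves: " << leaves.size() << std::endl; // 1 (node 3)
}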
caffe2/core/net_async_task_graph.h (new file, 78 lines)
@@ -0,0 +1,78 @@
#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
#define CAFFE2_NET_ASYNC_TASK_GRAPH_H

#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// AsyncTaskGraph represents an execution of a net; it owns the tasks and
// associated futures, sets up future callbacks and propagates errors.
// Usage steps:
// - add graph nodes and edges through CreateNode/AddDependency;
// - freeze the graph (FreezeGraph); after freezing, a future
//   can be obtained using GetFuture;
// - schedule execution of the graph through ExecuteGraph; after each
//   execution, Reset must be called to prepare the graph for the next run

class AsyncTaskGraphBase {
 public:
  virtual bool CreateNode(
      int node_id,
      const std::vector<OperatorBase*>& ops) = 0;

  virtual bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) = 0;

  virtual void FreezeGraph() = 0;

  virtual AsyncTaskFuture* ExecuteGraph() = 0;

  virtual AsyncTaskFuture* GetFuture() = 0;

  virtual void Reset() = 0;

  virtual ~AsyncTaskGraphBase() noexcept {}
};

class AsyncTaskGraph : public AsyncTaskGraphBase {
 public:
  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);

  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;

  bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
      override;

  void FreezeGraph() override;

  AsyncTaskFuture* ExecuteGraph() override;

  AsyncTaskFuture* GetFuture() override;

  void Reset() override;

 private:
  // used to, e.g., get access to the executor's thread pools
  // TODO: pass tracer and counters through ExecutorHelper
  ExecutorHelper* helper_;
  ExecutionOptions options_;

  bool frozen_;

  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
  std::unordered_map<int, std::unordered_set<int>> parents_;
  std::unordered_map<int, std::unordered_set<int>> children_;
  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;

  std::vector<AsyncTask*> root_tasks_;

  std::unique_ptr<AsyncTaskFuture> run_future_;
};

} // namespace caffe2

#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
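The usage steps in the comment above translate into roughly the following call sequence. This is a hedged sketch assuming an in-tree caffe2 build, a valid ExecutorHelper, an ExecutionOptions instance, and two pre-built op chains; the helper function itself is hypothetical.

#include "caffe2/core/net_async_task_graph.h"

namespace caffe2 {

// Hypothetical helper: builds a two-node graph (chain 1 depends on chain 0),
// runs it once, and resets it for the next run.
bool RunTwoChainGraph(
    ExecutorHelper* helper,
    const ExecutionOptions& options,
    const std::vector<OperatorBase*>& chain0,
    const std::vector<OperatorBase*>& chain1) {
  AsyncTaskGraph graph(helper, options);

  // 1. add graph nodes and edges
  graph.CreateNode(0, chain0);
  graph.CreateNode(1, chain1);
  graph.AddDependency(1, {0});

  // 2. freeze; after this the run future is available
  graph.FreezeGraph();

  // 3. schedule execution and wait on the returned future
  AsyncTaskFuture* future = graph.ExecuteGraph();
  future->Wait();
  bool ok = !future->IsFailed();

  // 4. reset before the next run
  graph.Reset();
  return ok;
}

} // namespace caffe2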
caffe2/core/net_parallel.cc (new file, 197 lines)
@@ -0,0 +1,197 @@
#include "caffe2/core/net_parallel.h"

#include "caffe2/core/operator.h"

#include <sstream>

C10_DEFINE_string(
    caffe2_task_graph_engine,
    "futures",
    "Task graph engine type used by net executor");

namespace caffe2 {

ParallelNet::ParallelNet(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws)
    : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
  num_workers_ = net_def->num_workers();
  CAFFE_ENFORCE_GT(
      num_workers_, 0, "Expected positive number of worker threads");

  helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(this);
  task_graph_ = TaskGraphRegistry()->Create(
      FLAGS_caffe2_task_graph_engine, helper_.get(), options_);

  // initialize operators
  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
  operators_.reserve(operator_nodes_.size());
  for (const auto& node : operator_nodes_) {
    auto op = node.operator_.get();
    op->SetExecutorHelper(helper_.get());
    operators_.push_back(op);
  }

  // compute chains
  // TODO: inference mode for chaining
  auto execution_chains = dag_utils::computeChains(operator_nodes_);
  std::vector<std::vector<int>> chains;
  chains.reserve(execution_chains.size());
  for (const auto& kv : execution_chains) {
    chains.push_back(kv.second);
  }
  auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
  CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());

  // disable unused events
  for (const auto& chain : chains) {
    for (const auto& op_id : chain) {
      if (op_id == chain.back() || op_id == chain.front()) {
        continue;
      }
      auto op = operators_[op_id];
      if (IsCPUDeviceType(op->device_option().device_type()) &&
          op->HasAsyncPart()) {
        continue;
      }
      op->DisableEvent();
    }
  }

  // initialize task graph
  for (auto chain_id = 0; chain_id < chains.size(); ++chain_id) {
    std::vector<OperatorBase*> ops;
    ops.reserve(chains[chain_id].size());
    for (auto op_id : chains[chain_id]) {
      ops.push_back(operators_[op_id]);
    }
    CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
  }
  for (auto chain_id = 0; chain_id < chain_nodes.size(); ++chain_id) {
    if (!chain_nodes[chain_id].parents_.empty()) {
      CAFFE_ENFORCE(
          task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
    }
  }

  // Freeze graph and initialize graph execution future
  task_graph_->FreezeGraph();
  run_future_ = task_graph_->GetFuture();
  run_future_->SetCallback([this](const AsyncTaskFuture* /* unused */) {
    StopAllObservers();
    finishRun();
  });

  LOG(INFO) << "Initialized parallel net: '" << Name()
            << "', #ops: " << net_def->op_size()
            << ", #chains: " << chains.size() << ", #workers: " << num_workers_
            << ", dfs scheduling: " << options_.use_dfs_scheduling_
            << ", task graph engine: " << FLAGS_caffe2_task_graph_engine;
}

bool ParallelNet::RunAsync() {
  reset();
  StartAllObservers();

  try {
    task_graph_->ExecuteGraph();
  } catch (const std::exception&) {
    StopAllObservers();
    return false;
  }

  return true;
}

void ParallelNet::Wait() {
  CAFFE_ENFORCE(run_future_);
  run_future_->Wait();
}

void ParallelNet::reset() {
  task_graph_->Reset();
}

bool ParallelNet::handleRunError() {
  CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
  // TODO: throw saved exceptions
  if (run_future_->IsFailed()) {
    LOG(ERROR) << "Failed parallel run (" << Name()
               << "): " << run_future_->ErrorMessage();
  }
  return !run_future_->IsFailed();
}

TaskThreadPoolBase* ParallelNet::poolGetter(
    PoolsMap& pools,
    int device_type,
    int device_id,
    int pool_size) {
  std::unique_lock<std::mutex> pools_lock(pools_mutex_);
  auto pool = pools[device_id][pool_size];
  if (!pool) {
    pool = ThreadPoolRegistry()->Create(
        DeviceTypeName(device_type),
        device_id,
        pool_size,
        options_.use_per_net_pools_);
    pools[device_id][pool_size] = pool;
  }
  return pool.get();
}

TaskThreadPoolBase* ParallelNet::Pool(const DeviceOption& device_option) {
  if (options_.use_single_pool_) {
    return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
  }
  const auto device_type = device_option.device_type();
  if (IsCPUDeviceType(device_type)) {
    auto numa_node_id = -1;
    if (device_option.has_numa_node_id()) {
      numa_node_id = device_option.numa_node_id();
      CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
    }
    CAFFE_ENFORCE_LT(
        numa_node_id,
        FLAGS_caffe2_net_async_max_numa_nodes,
        "Invalid NUMA node id: ",
        numa_node_id);
    return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
  } else if (IsGPUDeviceType(device_type)) {
    auto gpu_id = device_option.device_id();
    CAFFE_ENFORCE(
        gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
        "Invalid GPU id: " + caffe2::to_string(gpu_id));
    return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
  } else {
    CAFFE_THROW("Unsupported device type " + caffe2::to_string(device_type));
  }
}

bool ParallelNet::SupportsAsync() {
  return true;
}

void ParallelNet::finishRun() {}

std::vector<OperatorBase*> ParallelNet::GetOperators() const {
  return operators_;
}

std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
    ExecutorHelper* helper,
    const ExecutionOptions& options) {
  return std::make_shared<AsyncTaskGraph>(helper, options);
}

C10_DEFINE_SHARED_REGISTRY(
    TaskGraphRegistry,
    AsyncTaskGraphBase,
    ExecutorHelper*,
    const ExecutionOptions&);

C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);

REGISTER_NET(parallel, ParallelNet);

} // namespace caffe2
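Because the engine is looked up through TaskGraphRegistry using the caffe2_task_graph_engine flag, an alternative task graph implementation can be plugged in without touching ParallelNet. A hedged sketch, where the "mygraph" engine name and GetMyTaskGraph factory are hypothetical; for brevity the factory just returns the stock AsyncTaskGraph, but it could construct any AsyncTaskGraphBase subclass.

#include "caffe2/core/net_parallel.h"

namespace caffe2 {

// Hypothetical factory for an alternative task graph engine.
std::shared_ptr<AsyncTaskGraphBase> GetMyTaskGraph(
    ExecutorHelper* helper,
    const ExecutionOptions& options) {
  return std::make_shared<AsyncTaskGraph>(helper, options);
}

// Registers the engine under the name "mygraph"; a run then selects it
// with --caffe2_task_graph_engine=mygraph.
C10_REGISTER_CREATOR(TaskGraphRegistry, mygraph, GetMyTaskGraph);

} // namespace caffe2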
caffe2/core/net_parallel.h (new file, 77 lines)
@@ -0,0 +1,77 @@
#ifndef CAFFE2_CORE_NET_PARALLEL_H
#define CAFFE2_CORE_NET_PARALLEL_H

#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task_graph.h"

C10_DECLARE_string(caffe2_task_graph_engine);

namespace caffe2 {

class ParallelNetExecutorHelper;

class CAFFE2_API ParallelNet : public NetBase {
 public:
  ParallelNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);

  bool RunAsync() override;
  void Wait() override;

  bool SupportsAsync() override;
  std::vector<OperatorBase*> GetOperators() const override;

  TaskThreadPoolBase* Pool(const DeviceOption& device_option);

 protected:
  bool handleRunError() override;
  virtual void finishRun();
  virtual void reset();

  ExecutionOptions options_;
  int num_workers_;

  std::unique_ptr<ParallelNetExecutorHelper> helper_;
  std::shared_ptr<AsyncTaskGraphBase> task_graph_;
  AsyncTaskFuture* run_future_;

  std::vector<dag_utils::OperatorNode> operator_nodes_;
  std::vector<OperatorBase*> operators_;

  std::mutex pools_mutex_;
  typedef std::unordered_map<
      int,
      std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
      PoolsMap;
  PoolsMap cpu_pools_;
  PoolsMap gpu_pools_;
  TaskThreadPoolBase*
  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);

  friend class ParallelNetExecutorHelper;
  C10_DISABLE_COPY_AND_ASSIGN(ParallelNet);
};

C10_DECLARE_SHARED_REGISTRY(
    TaskGraphRegistry,
    AsyncTaskGraphBase,
    ExecutorHelper*,
    const ExecutionOptions&);

std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
    ExecutorHelper* helper,
    const ExecutionOptions& options);

class ParallelNetExecutorHelper : public ExecutorHelper {
 public:
  explicit ParallelNetExecutorHelper(ParallelNet* net) : net_(net) {}
  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
    return net_->Pool(option);
  }

 private:
  ParallelNet* net_;
};

} // namespace caffe2

#endif // CAFFE2_CORE_NET_PARALLEL_H
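To exercise the new executor end to end, a net only needs its type set to "parallel" (the name registered via REGISTER_NET in net_parallel.cc) and a positive worker count before being handed to CreateNet. A hedged sketch assuming an in-tree caffe2 build and a NetDef whose operators are already filled in; the helper function and the worker count of 4 are hypothetical.

#include "caffe2/core/net.h"
#include "caffe2/core/workspace.h"

namespace caffe2 {

// Hypothetical helper: runs a pre-populated NetDef on the parallel executor.
// Assumes `net_def` already contains valid ops for `ws`.
bool RunOnParallelExecutor(NetDef net_def, Workspace* ws) {
  net_def.set_type("parallel"); // select ParallelNet
  net_def.set_num_workers(4); // ParallelNet enforces a positive worker count
  auto def = std::make_shared<const NetDef>(std::move(net_def));
  std::unique_ptr<NetBase> net = CreateNet(def, ws);
  return net && net->Run(); // Run schedules the graph and waits on its future
}

} // namespace caffe2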
@@ -19,7 +19,7 @@ import hypothesis.strategies as st

 import unittest


-EXECUTORS = ["async_scheduling", "dag", "async_dag"]
+EXECUTORS = ["parallel", "async_scheduling"]
 ITERATIONS = 1