pytorch/caffe2/core/net_async_base.h
Yangqing Jia 7d5f7ed270 Using c10 namespace across caffe2. (#12714)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12714

This is a short change to enable the c10 namespace in caffe2. We did not enable
it before due to gflags global-variable confusion, but that has mostly been
cleaned up now. Right now, the plan on record is that namespace caffe2 and
namespace aten will each be a full superset of namespace c10.

Most of the diff is codemod; the only two non-codemod changes are in caffe2/core/common.h, where

```
using namespace c10;
```

is added, and in Flags.h, where instead of creating aliasing variables in the c10 namespace, we put them directly in the global namespace to match gflags (with the same behavior when gflags is not built in).
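
For illustration, a minimal sketch of what this means at a flag's use site (the flag name is taken from the header below; the literal macro expansion is elided):

```
// Declared via the c10 flag macros:
C10_DECLARE_int(caffe2_net_async_cpu_pool_size);

// The generated variable now lives in the global namespace, matching
// gflags' FLAGS_* convention, so it is read without a c10:: qualifier:
int size = FLAGS_caffe2_net_async_cpu_pool_size;
```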

Reviewed By: dzhulgakov

Differential Revision: D10390486

fbshipit-source-id: 5e2df730e28e29a052f513bddc558d9f78a23b9b
2018-10-17 12:57:19 -07:00


#ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
#define CAFFE2_CORE_NET_ASYNC_BASE_H_

#include "c10/util/Registry.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/net_dag_utils.h"
#include "caffe2/core/prof_dag_counters.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/prof_dag.pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/thread_pool.h"
C10_DECLARE_int(caffe2_streams_per_gpu);
C10_DECLARE_bool(caffe2_net_async_finish_chain);
C10_DECLARE_bool(caffe2_net_async_always_schedule_child);
C10_DECLARE_int(caffe2_net_async_max_gpus);
C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
C10_DECLARE_int(caffe2_net_async_cpu_pool_size);
C10_DECLARE_bool(caffe2_net_async_check_stream_status);
C10_DECLARE_bool(caffe2_net_async_use_single_pool);
C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);

namespace caffe2 {

class AsyncNetExecutorHelper;

namespace tracing {
class Tracer;
}
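
// Base class for asynchronous net executors: operators are grouped into
// chains (tasks) that are scheduled onto thread pools and device streams.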
class CAFFE2_API AsyncNetBase : public NetBase {
 public:
  AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
  ~AsyncNetBase() override;

  bool SupportsAsync() override {
    return true;
  }

  vector<OperatorBase*> GetOperators() const override {
    return operators_;
  }

  bool RunAsync() override;

  const dag_utils::ExecutionChains& TEST_execution_chains() const {
    return execution_chains_;
  }

  ProfDAGProtos GetOperatorStats() const;
  ProfDAGProtos GetPerOperatorCost() const;

 protected:
  bool canSchedule(
      int chain_id,
      const std::vector<EventStatus>* status = nullptr,
      bool* parent_failed = nullptr);
  bool canSchedule(int parent_id, int child_id);

  int tasksNum() const;
  Event& event(int task_id) const;
  EventStatus query(int task_id) const;
  const std::vector<int>& children(int task_id) const;
  const std::vector<int>& parents(int task_id) const;
  int updateParentCount(int child_id);
  int getParentCount(int child_id);
  bool testAndSetScheduled(int task_id);
  int numOps(int task_id) const;

  int firstTaskOpId(int task_id) const;
  int lastTaskOpId(int task_id) const;
  const OperatorBase* firstTaskOp(int task_id) const;
  const OperatorBase* lastTaskOp(int task_id) const;
  OperatorBase* firstTaskOp(int task_id);
  OperatorBase* lastTaskOp(int task_id);

  void asyncWait(
      int task_id,
      int stream_id,
      const std::vector<int>& wait_task_ids) const;
  bool run(int task_id, int stream_id);
  int stream(int task_id);
  TaskThreadPoolBase* pool(const DeviceOption& device_option);

  void finishTasks(const std::unordered_set<int>& task_ids);
  void finalizeEvents();

  bool isStreamFree(int task_id, int stream_id) const;

  virtual void reset();

  bool handleRunError() override;

  // Operator/task graph
  std::vector<OperatorBase*> operators_;
  std::vector<dag_utils::OperatorNode> operator_nodes_;
  std::vector<std::vector<int>> chains_;
  std::vector<dag_utils::OpGraphNode> chain_nodes_; // chains' parents/children
  dag_utils::ExecutionChains execution_chains_; // for testing

  // Pools and streams
  std::mutex pools_mutex_;
  // first int key - device id, second - pool size, one pool per (device, size)
  typedef std::unordered_map<
      int,
      std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
      PoolsMap;
  PoolsMap cpu_pools_;
  PoolsMap gpu_pools_;
  static std::vector<int>& getStreamCounters();
  int num_workers_;

  // Exception/error handling
  void setTaskErrorMessage(int task_id, const std::string& err_msg);
  std::atomic<bool> success_;
#ifdef CAFFE2_USE_EXCEPTION_PTR
  // Mutex that protects caught_exception_
  std::mutex exception_mutex_;
  std::exception_ptr caught_exception_;
#endif // CAFFE2_USE_EXCEPTION_PTR

  // Tracing
  std::shared_ptr<tracing::Tracer> tracer_;

  // execution mode flags
  void computeExecutionModeFlags();
  int streams_per_gpu_;
  bool finish_chain_;
  bool always_schedule_child_;
  bool check_stream_status_;
  bool use_single_pool_;
  bool use_per_net_pools_;
  bool is_blocking_;
  bool report_stats_;

  ProfDAGCounters counters_;

  C10_DISABLE_COPY_AND_ASSIGN(AsyncNetBase);

 private:
  void storeExceptionPtr();

  TaskThreadPoolBase*
  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);

  std::unique_ptr<AsyncNetExecutorHelper> helper_;

  friend class AsyncNetExecutorHelper;
  friend class tracing::Tracer;
};
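
// Shared registry of thread pool creators; the (int, int, bool) creator
// arguments correspond to (numa_node_id, pool_size, create_new), as in
// GetAsyncNetCPUThreadPool below.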
C10_DECLARE_SHARED_REGISTRY(
    ThreadPoolRegistry,
    TaskThreadPoolBase,
    int,
    int,
    bool);

class AsyncNetExecutorHelper : public ExecutorHelper {
 public:
  explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}

  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
    return net_->pool(option);
  }

 private:
  AsyncNetBase* net_;
};

// Returns a CPU thread pool for the given NUMA node and pool size. With
// create_new set, a dedicated pool is returned; otherwise pools are shared
// and cached per (numa_node_id, pool_size) via weak_ptrs in a static map.
template <class TaskThreadPoolImpl>
std::shared_ptr<TaskThreadPoolBase>
GetAsyncNetCPUThreadPool(int numa_node_id, int pool_size, bool create_new) {
  // Note: numa_node_id = -1 corresponds to no NUMA used
  static std::unordered_map<
      int,
      std::unordered_map<int, std::weak_ptr<TaskThreadPoolBase>>>
      pools;
  static std::mutex pool_mutex;

  if (pool_size <= 0) {
    if (FLAGS_caffe2_net_async_cpu_pool_size > 0) {
      pool_size = FLAGS_caffe2_net_async_cpu_pool_size;
      LOG(INFO) << "Using default CPU pool size: " << pool_size
                << "; NUMA node id: " << numa_node_id;
    } else {
      // Fall back to the number of hardware threads.
      auto num_cores = std::thread::hardware_concurrency();
      CAFFE_ENFORCE(num_cores > 0, "Failed to get number of CPU cores");
      LOG(INFO) << "Using estimated CPU pool size: " << num_cores
                << "; NUMA node id: " << numa_node_id;
      pool_size = num_cores;
    }
  } else {
    LOG(INFO) << "Using specified CPU pool size: " << pool_size
              << "; NUMA node id: " << numa_node_id;
  }

  if (create_new) {
    LOG(INFO) << "Created new CPU pool, size: " << pool_size
              << "; NUMA node id: " << numa_node_id;
    return std::make_shared<TaskThreadPoolImpl>(pool_size, numa_node_id);
  } else {
    std::lock_guard<std::mutex> lock(pool_mutex);
    auto shared_pool = pools[numa_node_id][pool_size].lock();
    if (!shared_pool) {
      LOG(INFO) << "Created shared CPU pool, size: " << pool_size
                << "; NUMA node id: " << numa_node_id;
      shared_pool =
          std::make_shared<TaskThreadPoolImpl>(pool_size, numa_node_id);
      pools[numa_node_id][pool_size] = shared_pool;
    }
    return shared_pool;
  }
}
} // namespace caffe2
#endif // CAFFE2_CORE_NET_ASYNC_BASE_H_
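
For context, a minimal sketch of how a creator with this signature could be hooked up to ThreadPoolRegistry. The registration site, the "CPU" key, and the TaskThreadPool implementation type are illustrative assumptions; only the (int, int, bool) creator signature comes from the declaration in this header.

```
// Hypothetical registration in a .cc file; the creator's (int, int, bool)
// argument list matches the ThreadPoolRegistry declared in this header.
C10_REGISTER_CREATOR(
    ThreadPoolRegistry,
    CPU,
    GetAsyncNetCPUThreadPool<TaskThreadPool>);

// A pool could then be obtained through the registry, e.g.:
// auto pool = ThreadPoolRegistry()->Create(
//     "CPU", /* numa_node_id */ -1, /* pool_size */ 0, /* create_new */ false);
```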