Fix init_thread calls in thread pool initialization (#20848)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20848
ghimport-source-id: e542858a198252838c1f3100dbfbe90fd3960f07

Differential Revision: D15466918

Pulled By: ilia-cher

fbshipit-source-id: e75d38f51edd5b508c4ca28a292e4141e90f209f
This commit is contained in:
Ilia Cherniavskii
2019-05-24 01:08:17 -07:00
committed by Facebook Github Bot
parent 1bb728fe14
commit 6b74856747
4 changed files with 20 additions and 23 deletions

View File

@ -137,12 +137,10 @@ std::string get_parallel_info() {
PTThreadPool::PTThreadPool(
int pool_size,
int numa_node_id)
: c10::ThreadPool(pool_size, numa_node_id) {}
void PTThreadPool::init_thread() {
c10::setThreadName("PTThreadPool");
at::init_num_threads();
}
: c10::ThreadPool(pool_size, numa_node_id, [](){
c10::setThreadName("PTThreadPool");
at::init_num_threads();
}) {}
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);

View File

@ -155,8 +155,6 @@ class CAFFE2_API PTThreadPool : public c10::ThreadPool {
explicit PTThreadPool(
int pool_size,
int numa_node_id = -1);
void init_thread() override;
};
// Sets number of threads used for inter-op parallelism

View File

@ -2,7 +2,10 @@
namespace c10 {
ThreadPool::ThreadPool(int pool_size, int numa_node_id)
ThreadPool::ThreadPool(
int pool_size,
int numa_node_id,
std::function<void()> init_thread)
: threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
running_(true),
complete_(true),
@ -10,7 +13,12 @@ ThreadPool::ThreadPool(int pool_size, int numa_node_id)
total_(threads_.size()),
numa_node_id_(numa_node_id) {
for (std::size_t i = 0; i < threads_.size(); ++i) {
threads_[i] = std::thread(std::bind(&ThreadPool::main_loop, this, i));
threads_[i] = std::thread([this, i, init_thread](){
if (init_thread) {
init_thread();
}
this->main_loop(i);
});
}
}
@ -68,8 +76,6 @@ void ThreadPool::waitWorkComplete() {
}
void ThreadPool::main_loop(std::size_t index) {
init_thread();
std::unique_lock<std::mutex> lock(mutex_);
while (running_) {
// Wait on condition variable while the task is empty and

View File

@ -71,7 +71,8 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
explicit ThreadPool(
int pool_size,
int numa_node_id = -1);
int numa_node_id = -1,
std::function<void()> init_thread = nullptr);
~ThreadPool();
@ -98,9 +99,6 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
/// @brief Wait for queue to be empty
void waitWorkComplete();
protected:
virtual void init_thread() {}
private:
// @brief Entry point for pool threads.
void main_loop(std::size_t index);
@ -111,13 +109,10 @@ class C10_API TaskThreadPool : public c10::ThreadPool {
explicit TaskThreadPool(
std::size_t pool_size,
int numa_node_id = -1)
: ThreadPool(pool_size, numa_node_id) {}
// TODO move this to ATen/core/thread_pool.h
void init_thread() override {
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id_);
}
: ThreadPool(pool_size, numa_node_id, [numa_node_id](){
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id);
}) {}
};
C10_DECLARE_SHARED_REGISTRY(