mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix init_thread calls in thread pool initialization (#20848)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20848 ghimport-source-id: e542858a198252838c1f3100dbfbe90fd3960f07 Differential Revision: D15466918 Pulled By: ilia-cher fbshipit-source-id: e75d38f51edd5b508c4ca28a292e4141e90f209f
This commit is contained in:
committed by
Facebook Github Bot
parent
1bb728fe14
commit
6b74856747
@ -137,12 +137,10 @@ std::string get_parallel_info() {
|
||||
PTThreadPool::PTThreadPool(
|
||||
int pool_size,
|
||||
int numa_node_id)
|
||||
: c10::ThreadPool(pool_size, numa_node_id) {}
|
||||
|
||||
void PTThreadPool::init_thread() {
|
||||
c10::setThreadName("PTThreadPool");
|
||||
at::init_num_threads();
|
||||
}
|
||||
: c10::ThreadPool(pool_size, numa_node_id, [](){
|
||||
c10::setThreadName("PTThreadPool");
|
||||
at::init_num_threads();
|
||||
}) {}
|
||||
|
||||
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
|
||||
|
||||
|
@ -155,8 +155,6 @@ class CAFFE2_API PTThreadPool : public c10::ThreadPool {
|
||||
explicit PTThreadPool(
|
||||
int pool_size,
|
||||
int numa_node_id = -1);
|
||||
|
||||
void init_thread() override;
|
||||
};
|
||||
|
||||
// Sets number of threads used for inter-op parallelism
|
||||
|
@ -2,7 +2,10 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
ThreadPool::ThreadPool(int pool_size, int numa_node_id)
|
||||
ThreadPool::ThreadPool(
|
||||
int pool_size,
|
||||
int numa_node_id,
|
||||
std::function<void()> init_thread)
|
||||
: threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
|
||||
running_(true),
|
||||
complete_(true),
|
||||
@ -10,7 +13,12 @@ ThreadPool::ThreadPool(int pool_size, int numa_node_id)
|
||||
total_(threads_.size()),
|
||||
numa_node_id_(numa_node_id) {
|
||||
for (std::size_t i = 0; i < threads_.size(); ++i) {
|
||||
threads_[i] = std::thread(std::bind(&ThreadPool::main_loop, this, i));
|
||||
threads_[i] = std::thread([this, i, init_thread](){
|
||||
if (init_thread) {
|
||||
init_thread();
|
||||
}
|
||||
this->main_loop(i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,8 +76,6 @@ void ThreadPool::waitWorkComplete() {
|
||||
}
|
||||
|
||||
void ThreadPool::main_loop(std::size_t index) {
|
||||
init_thread();
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (running_) {
|
||||
// Wait on condition variable while the task is empty and
|
||||
|
@ -71,7 +71,8 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
|
||||
explicit ThreadPool(
|
||||
int pool_size,
|
||||
int numa_node_id = -1);
|
||||
int numa_node_id = -1,
|
||||
std::function<void()> init_thread = nullptr);
|
||||
|
||||
~ThreadPool();
|
||||
|
||||
@ -98,9 +99,6 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
/// @brief Wait for queue to be empty
|
||||
void waitWorkComplete();
|
||||
|
||||
protected:
|
||||
virtual void init_thread() {}
|
||||
|
||||
private:
|
||||
// @brief Entry point for pool threads.
|
||||
void main_loop(std::size_t index);
|
||||
@ -111,13 +109,10 @@ class C10_API TaskThreadPool : public c10::ThreadPool {
|
||||
explicit TaskThreadPool(
|
||||
std::size_t pool_size,
|
||||
int numa_node_id = -1)
|
||||
: ThreadPool(pool_size, numa_node_id) {}
|
||||
|
||||
// TODO move this to ATen/core/thread_pool.h
|
||||
void init_thread() override {
|
||||
setThreadName("CaffeTaskThread");
|
||||
NUMABind(numa_node_id_);
|
||||
}
|
||||
: ThreadPool(pool_size, numa_node_id, [numa_node_id](){
|
||||
setThreadName("CaffeTaskThread");
|
||||
NUMABind(numa_node_id);
|
||||
}) {}
|
||||
};
|
||||
|
||||
C10_DECLARE_SHARED_REGISTRY(
|
||||
|
Reference in New Issue
Block a user