From 2d465e4d1d8e09f6f278e2c4bfe5848bc49934fe Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 2 Oct 2024 20:02:28 +0000 Subject: [PATCH] [non ghstack] Init threadpool with user defined num_threads before default (#137051) Very similar to https://github.com/pytorch/pytorch/pull/136793, but adds back `pool->set_thread_count` call as it is still necessary (I am guessing due to the mutex) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/137051 Approved by: https://github.com/albanD --- aten/src/ATen/ParallelOpenMP.cpp | 3 +-- caffe2/utils/threadpool/pthreadpool-cpp.cc | 14 +++++++++----- caffe2/utils/threadpool/pthreadpool-cpp.h | 1 + 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index 40257882ea20..1c128bfc3b28 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -61,9 +61,8 @@ void set_num_threads(int nthreads) { #endif #ifdef USE_PTHREADPOOL // because PyTorch uses caffe2::pthreadpool() in QNNPACK - caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads); TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); - pool->set_thread_count(nthreads); #endif #if AT_MKLDNN_ENABLED() at::native::mkldnn::clear_computation_cache(); diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.cc b/caffe2/utils/threadpool/pthreadpool-cpp.cc index e281fa2cb40e..6766b13d2b84 100644 --- a/caffe2/utils/threadpool/pthreadpool-cpp.cc +++ b/caffe2/utils/threadpool/pthreadpool-cpp.cc @@ -82,12 +82,9 @@ void PThreadPool::run( 0u); } -// Forward declaration -size_t getDefaultNumThreads(); - -PThreadPool* pthreadpool() { +PThreadPool* pthreadpool(size_t thread_count) { static auto threadpool = - std::make_unique(getDefaultNumThreads()); + std::make_unique(thread_count); #if !(defined(WIN32)) static std::once_flag flag; std::call_once(flag, []() { @@ -105,6 +102,13 @@ PThreadPool* pthreadpool() { return threadpool.get(); } +// Forward declaration +size_t getDefaultNumThreads(); + +PThreadPool* pthreadpool() { + return pthreadpool(getDefaultNumThreads()); +} + pthreadpool_t pthreadpool_() { if (caffe2::_NoPThreadPoolGuard::is_enabled()) { return nullptr; diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.h b/caffe2/utils/threadpool/pthreadpool-cpp.h index 99acff4df027..f6fc5a2d8243 100644 --- a/caffe2/utils/threadpool/pthreadpool-cpp.h +++ b/caffe2/utils/threadpool/pthreadpool-cpp.h @@ -42,6 +42,7 @@ class PThreadPool final { // Return a singleton instance of PThreadPool for ATen/TH multithreading. PThreadPool* pthreadpool(); +PThreadPool* pthreadpool(size_t thread_count); // Exposes the underlying implementation of PThreadPool. // Only for use in external libraries so as to unify threading across