[non ghstack] Init threadpool with user defined num_threads before default (#137051)

Very similar to https://github.com/pytorch/pytorch/pull/136793, but adds back the `pool->set_thread_count` call, as it is still necessary (I am guessing due to the mutex).

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137051
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu
2024-10-02 20:02:28 +00:00
committed by PyTorch MergeBot
parent 59d7cf7342
commit 2d465e4d1d
3 changed files with 11 additions and 7 deletions

View File

@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
#endif
#ifdef USE_PTHREADPOOL
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif
#if AT_MKLDNN_ENABLED()
at::native::mkldnn::clear_computation_cache();

View File

@ -82,12 +82,9 @@ void PThreadPool::run(
0u);
}
// Forward declaration
size_t getDefaultNumThreads();
PThreadPool* pthreadpool() {
PThreadPool* pthreadpool(size_t thread_count) {
static auto threadpool =
std::make_unique<PThreadPool>(getDefaultNumThreads());
std::make_unique<PThreadPool>(thread_count);
#if !(defined(WIN32))
static std::once_flag flag;
std::call_once(flag, []() {
@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
return threadpool.get();
}
// Forward declaration
size_t getDefaultNumThreads();
// Zero-argument accessor: asks getDefaultNumThreads() for a thread count and
// forwards it to the pthreadpool(size_t) overload, returning whatever that yields.
PThreadPool* pthreadpool() {
  const size_t default_thread_count = getDefaultNumThreads();
  return pthreadpool(default_thread_count);
}
pthreadpool_t pthreadpool_() {
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
return nullptr;

View File

@ -42,6 +42,7 @@ class PThreadPool final {
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
// The no-argument form forwards getDefaultNumThreads() to the overload below.
PThreadPool* pthreadpool();
// Same singleton accessor, but thread_count is the value handed to the
// PThreadPool constructor when the static instance is built.
// NOTE(review): presumably has no effect once the pool already exists — confirm.
PThreadPool* pthreadpool(size_t thread_count);
// Exposes the underlying implementation of PThreadPool.
// Only for use in external libraries so as to unify threading across