mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[non ghstack] Init threadpool with user defined num_threads before default (#137051)
Very similar to https://github.com/pytorch/pytorch/pull/136793, but adds back `pool->set_thread_count` call as it is still necessary (I am guessing due to the mutex) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/137051 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
59d7cf7342
commit
2d465e4d1d
@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
|
||||
#endif
|
||||
#ifdef USE_PTHREADPOOL
|
||||
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
|
||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
|
||||
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
|
||||
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
|
||||
pool->set_thread_count(nthreads);
|
||||
#endif
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
at::native::mkldnn::clear_computation_cache();
|
||||
|
@ -82,12 +82,9 @@ void PThreadPool::run(
|
||||
0u);
|
||||
}
|
||||
|
||||
// Forward declaration
|
||||
size_t getDefaultNumThreads();
|
||||
|
||||
PThreadPool* pthreadpool() {
|
||||
PThreadPool* pthreadpool(size_t thread_count) {
|
||||
static auto threadpool =
|
||||
std::make_unique<PThreadPool>(getDefaultNumThreads());
|
||||
std::make_unique<PThreadPool>(thread_count);
|
||||
#if !(defined(WIN32))
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, []() {
|
||||
@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
|
||||
return threadpool.get();
|
||||
}
|
||||
|
||||
// Forward declaration
|
||||
size_t getDefaultNumThreads();
|
||||
|
||||
PThreadPool* pthreadpool() {
|
||||
return pthreadpool(getDefaultNumThreads());
|
||||
}
|
||||
|
||||
pthreadpool_t pthreadpool_() {
|
||||
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
|
||||
return nullptr;
|
||||
|
@ -42,6 +42,7 @@ class PThreadPool final {
|
||||
|
||||
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
|
||||
PThreadPool* pthreadpool();
|
||||
PThreadPool* pthreadpool(size_t thread_count);
|
||||
|
||||
// Exposes the underlying implementation of PThreadPool.
|
||||
// Only for use in external libraries so as to unify threading across
|
||||
|
Reference in New Issue
Block a user