Add guard to run on current thread (#52361)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52361

Test Plan:
buck build //xplat/caffe2:aten_test_test_thread_pool_guard
./aten_test_test_thread_pool_guard

Reviewed By: kimishpatel

Differential Revision: D26429540

fbshipit-source-id: 16e4a56d4bf9b73b1ea1ff88d7dc6730e0b1e029
This commit is contained in:
Akshit Khurana
2021-03-03 11:37:36 -08:00
committed by Facebook GitHub Bot
parent 0f81a69a96
commit ecd8e4c1d5
6 changed files with 80 additions and 2 deletions

View File

@ -1355,6 +1355,7 @@ filegroup(
"caffe2/utils/threadpool/ThreadPool.cc",
"caffe2/utils/threadpool/pthreadpool.cc",
"caffe2/utils/threadpool/pthreadpool_impl.cc",
"caffe2/utils/threadpool/thread_pool_guard.cpp",
],
)

View File

@ -0,0 +1,31 @@
#include <gtest/gtest.h>
#include <caffe2/utils/threadpool/thread_pool_guard.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
TEST(TestThreadPoolGuard, TestThreadPoolGuard) {
auto threadpool_ptr = caffe2::pthreadpool_();
ASSERT_NE(threadpool_ptr, nullptr);
{
caffe2::_NoPThreadPoolGuard g1;
auto threadpool_ptr1 = caffe2::pthreadpool_();
ASSERT_EQ(threadpool_ptr1, nullptr);
{
caffe2::_NoPThreadPoolGuard g2;
auto threadpool_ptr2 = caffe2::pthreadpool_();
ASSERT_EQ(threadpool_ptr2, nullptr);
}
// Guard should restore prev value (nullptr)
auto threadpool_ptr3 = caffe2::pthreadpool_();
ASSERT_EQ(threadpool_ptr3, nullptr);
}
// Guard should restore prev value (pthreadpool_)
auto threadpool_ptr4 = caffe2::pthreadpool_();
ASSERT_NE(threadpool_ptr4, nullptr);
ASSERT_EQ(threadpool_ptr4, threadpool_ptr);
}

View File

@ -5,7 +5,10 @@ if((NOT BUILD_CAFFE2) OR (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE))
)
if(USE_PTHREADPOOL AND NOT USE_INTERNAL_PTHREADPOOL_IMPL)
list(APPEND Caffe2_CPU_SRCS utils/threadpool/pthreadpool-cpp.cc)
list(APPEND Caffe2_CPU_SRCS
utils/threadpool/pthreadpool-cpp.cc
utils/threadpool/thread_pool_guard.cpp
)
endif()
if(NOT BUILD_CAFFE2)
@ -37,7 +40,8 @@ list(APPEND Caffe2_CPU_SRCS
if(USE_PTHREADPOOL)
list(APPEND Caffe2_CPU_SRCS
utils/threadpool/pthreadpool-cpp.cc)
utils/threadpool/pthreadpool-cpp.cc
utils/threadpool/thread_pool_guard.cpp)
if(USE_INTERNAL_PTHREADPOOL_IMPL)
list(APPEND Caffe2_CPU_SRCS
utils/threadpool/pthreadpool.cc

View File

@ -1,4 +1,5 @@
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <caffe2/utils/threadpool/thread_pool_guard.h>
#include <c10/util/Exception.h>
namespace caffe2 {
@ -62,6 +63,9 @@ PThreadPool* pthreadpool() {
}
pthreadpool_t pthreadpool_() {
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
return nullptr;
}
PThreadPool* const threadpool = pthreadpool();
TORCH_INTERNAL_ASSERT(
threadpool, "Failed to acquire an instance of PThreadPool!");

View File

@ -0,0 +1,15 @@
#include <caffe2/utils/threadpool/thread_pool_guard.h>
namespace caffe2 {
thread_local bool _NoPThreadPoolGuard_enabled = false;
bool _NoPThreadPoolGuard::is_enabled() {
return _NoPThreadPoolGuard_enabled;
}
void _NoPThreadPoolGuard::set_enabled(bool enabled) {
_NoPThreadPoolGuard_enabled = enabled;
}
} // namespace at

View File

@ -0,0 +1,23 @@
#pragma once
#include <c10/macros/Macros.h>
namespace caffe2 {
// A RAII, thread local (!) guard that enables or disables grad mode upon
// construction, and sets it back to the original value upon destruction.
struct TORCH_API _NoPThreadPoolGuard {
static bool is_enabled();
static void set_enabled(bool enabled);
_NoPThreadPoolGuard(): prev_mode_(_NoPThreadPoolGuard::is_enabled()) {
_NoPThreadPoolGuard::set_enabled(true);
}
~_NoPThreadPoolGuard() {
_NoPThreadPoolGuard::set_enabled(prev_mode_);
}
private:
bool prev_mode_;
};
}