mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add guard to run on current thread (#52361)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52361 Test Plan: buck build //xplat/caffe2:aten_test_test_thread_pool_guard ./aten_test_test_thread_pool_guard Reviewed By: kimishpatel Differential Revision: D26429540 fbshipit-source-id: 16e4a56d4bf9b73b1ea1ff88d7dc6730e0b1e029
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0f81a69a96
commit
ecd8e4c1d5
@ -1355,6 +1355,7 @@ filegroup(
|
||||
"caffe2/utils/threadpool/ThreadPool.cc",
|
||||
"caffe2/utils/threadpool/pthreadpool.cc",
|
||||
"caffe2/utils/threadpool/pthreadpool_impl.cc",
|
||||
"caffe2/utils/threadpool/thread_pool_guard.cpp",
|
||||
],
|
||||
)
|
||||
|
||||
|
31
aten/src/ATen/test/test_thread_pool_guard.cpp
Normal file
31
aten/src/ATen/test/test_thread_pool_guard.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <caffe2/utils/threadpool/thread_pool_guard.h>
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
|
||||
|
||||
TEST(TestThreadPoolGuard, TestThreadPoolGuard) {
|
||||
auto threadpool_ptr = caffe2::pthreadpool_();
|
||||
|
||||
ASSERT_NE(threadpool_ptr, nullptr);
|
||||
{
|
||||
caffe2::_NoPThreadPoolGuard g1;
|
||||
auto threadpool_ptr1 = caffe2::pthreadpool_();
|
||||
ASSERT_EQ(threadpool_ptr1, nullptr);
|
||||
|
||||
{
|
||||
caffe2::_NoPThreadPoolGuard g2;
|
||||
auto threadpool_ptr2 = caffe2::pthreadpool_();
|
||||
ASSERT_EQ(threadpool_ptr2, nullptr);
|
||||
}
|
||||
|
||||
// Guard should restore prev value (nullptr)
|
||||
auto threadpool_ptr3 = caffe2::pthreadpool_();
|
||||
ASSERT_EQ(threadpool_ptr3, nullptr);
|
||||
}
|
||||
|
||||
// Guard should restore prev value (pthreadpool_)
|
||||
auto threadpool_ptr4 = caffe2::pthreadpool_();
|
||||
ASSERT_NE(threadpool_ptr4, nullptr);
|
||||
ASSERT_EQ(threadpool_ptr4, threadpool_ptr);
|
||||
}
|
@ -5,7 +5,10 @@ if((NOT BUILD_CAFFE2) OR (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE))
|
||||
)
|
||||
|
||||
if(USE_PTHREADPOOL AND NOT USE_INTERNAL_PTHREADPOOL_IMPL)
|
||||
list(APPEND Caffe2_CPU_SRCS utils/threadpool/pthreadpool-cpp.cc)
|
||||
list(APPEND Caffe2_CPU_SRCS
|
||||
utils/threadpool/pthreadpool-cpp.cc
|
||||
utils/threadpool/thread_pool_guard.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
if(NOT BUILD_CAFFE2)
|
||||
@ -37,7 +40,8 @@ list(APPEND Caffe2_CPU_SRCS
|
||||
|
||||
if(USE_PTHREADPOOL)
|
||||
list(APPEND Caffe2_CPU_SRCS
|
||||
utils/threadpool/pthreadpool-cpp.cc)
|
||||
utils/threadpool/pthreadpool-cpp.cc
|
||||
utils/threadpool/thread_pool_guard.cpp)
|
||||
if(USE_INTERNAL_PTHREADPOOL_IMPL)
|
||||
list(APPEND Caffe2_CPU_SRCS
|
||||
utils/threadpool/pthreadpool.cc
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
#include <caffe2/utils/threadpool/thread_pool_guard.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace caffe2 {
|
||||
@ -62,6 +63,9 @@ PThreadPool* pthreadpool() {
|
||||
}
|
||||
|
||||
pthreadpool_t pthreadpool_() {
|
||||
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
|
||||
return nullptr;
|
||||
}
|
||||
PThreadPool* const threadpool = pthreadpool();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
threadpool, "Failed to acquire an instance of PThreadPool!");
|
||||
|
15
caffe2/utils/threadpool/thread_pool_guard.cpp
Normal file
15
caffe2/utils/threadpool/thread_pool_guard.cpp
Normal file
@ -0,0 +1,15 @@
|
||||
#include <caffe2/utils/threadpool/thread_pool_guard.h>
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
thread_local bool _NoPThreadPoolGuard_enabled = false;
|
||||
|
||||
bool _NoPThreadPoolGuard::is_enabled() {
|
||||
return _NoPThreadPoolGuard_enabled;
|
||||
}
|
||||
|
||||
void _NoPThreadPoolGuard::set_enabled(bool enabled) {
|
||||
_NoPThreadPoolGuard_enabled = enabled;
|
||||
}
|
||||
|
||||
} // namespace at
|
23
caffe2/utils/threadpool/thread_pool_guard.h
Normal file
23
caffe2/utils/threadpool/thread_pool_guard.h
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
// A RAII, thread local (!) guard that enables or disables grad mode upon
|
||||
// construction, and sets it back to the original value upon destruction.
|
||||
struct TORCH_API _NoPThreadPoolGuard {
|
||||
static bool is_enabled();
|
||||
static void set_enabled(bool enabled);
|
||||
|
||||
_NoPThreadPoolGuard(): prev_mode_(_NoPThreadPoolGuard::is_enabled()) {
|
||||
_NoPThreadPoolGuard::set_enabled(true);
|
||||
}
|
||||
~_NoPThreadPoolGuard() {
|
||||
_NoPThreadPoolGuard::set_enabled(prev_mode_);
|
||||
}
|
||||
private:
|
||||
bool prev_mode_;
|
||||
};
|
||||
|
||||
}
|
Reference in New Issue
Block a user