Move inter-op settings into ATen/Parallel (#20050)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20050
ghimport-source-id: cc102bab8abf3e56c099245976786317ed63ea14

Differential Revision: D15248576

Pulled By: ilia-cher

fbshipit-source-id: 55ddcb7af387ddfc68a42ac7167de07ea648e249
This commit is contained in:
Ilia Cherniavskii
2019-05-17 03:09:11 -07:00
committed by Facebook Github Bot
parent 36d3398aa5
commit 409200df59
11 changed files with 143 additions and 56 deletions

View File

@ -5,6 +5,7 @@
#include <atomic>
#include <sstream>
#include <thread>
#ifdef TH_BLAS_MKL
#include <mkl.h>
@ -13,8 +14,43 @@
namespace at {
namespace {
const int NOT_SET = -1;
const int CONSUMED = -2;
// Number of threads set by the user
std::atomic<int> num_threads(-1);
std::atomic<int> num_threads{NOT_SET};
// Number of inter-op threads set by the user;
// NOT_SET -> positive value -> CONSUMED
// (CONSUMED - thread pool is initialized)
// or
// NOT_SET -> CONSUMED
std::atomic<int> num_interop_threads{NOT_SET};
// thread pool global instance is hidden,
// users should use at::launch and get/set_num_interop_threads interface
TaskThreadPoolBase& get_pool() {
static std::shared_ptr<TaskThreadPoolBase> pool =
ThreadPoolRegistry()->Create(
"C10",
/* device_id */ 0,
/* pool_size */ num_interop_threads.exchange(CONSUMED),
/* create_new */ true);
return *pool;
}
// Factory function for ThreadPoolRegistry
std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
int device_id,
int pool_size,
bool create_new) {
// For now, the only accepted device id is 0
AT_CHECK(device_id == 0);
// Create new thread pool
AT_CHECK(create_new);
return std::make_shared<PTThreadPool>(pool_size);
}
}
void init_num_threads() {
@ -32,10 +68,9 @@ void init_num_threads() {
}
}
void set_num_threads(size_t nthreads) {
if (nthreads == 0) {
return;
}
void set_num_threads(int nthreads) {
AT_CHECK(nthreads > 0, "Expected positive number of threads");
num_threads.store(nthreads);
#ifdef _OPENMP
omp_set_num_threads(nthreads);
@ -56,7 +91,7 @@ void set_num_threads(size_t nthreads) {
// region might be different in the new thread;
// Use init_num_threads() during thread initialization to ensure
// consistent size of parallel region in different threads
size_t get_num_threads() {
int get_num_threads() {
#ifdef _OPENMP
return omp_get_max_threads();
#else
@ -100,7 +135,7 @@ std::string get_parallel_info() {
}
PTThreadPool::PTThreadPool(
std::size_t pool_size,
int pool_size,
int numa_node_id)
: c10::ThreadPool(pool_size, numa_node_id) {}
@ -109,26 +144,31 @@ void PTThreadPool::init_thread() {
at::init_num_threads();
}
namespace {
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
int device_id,
int pool_size,
bool create_new) {
static std::shared_ptr<TaskThreadPoolBase> pool =
std::make_shared<PTThreadPool>(pool_size);
// For now, the only accepted device id is 0
// for the JIT inter-op pool (CPU),
AT_ASSERT(device_id == 0);
// we use the shared thread pool
AT_ASSERT(!create_new);
// and the size does not change
AT_ASSERT(pool->size() == pool_size);
return pool;
void set_num_interop_threads(int nthreads) {
AT_CHECK(nthreads > 0, "Expected positive number of threads");
int no_value = NOT_SET;
AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
"Error: cannot set number of interop threads after parallel work "
"has started or set_num_interop_threads called");
}
} // namespace
int get_num_interop_threads() {
int nthreads = num_interop_threads.load();
if (nthreads > 0) {
return nthreads;
} else if (nthreads == NOT_SET) {
// return default value
return TaskThreadPoolBase::defaultNumThreads();
} else {
return get_pool().size();
}
}
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
void launch(const std::function<void()>& func) {
get_pool().run(func);
}
} // namespace at

View File

@ -30,10 +30,10 @@ inline int64_t divup(int64_t x, int64_t y) {
CAFFE2_API void init_num_threads();
// Sets the number of threads to be used in parallel region
CAFFE2_API void set_num_threads(size_t);
CAFFE2_API void set_num_threads(int);
// Returns the number of threads used in parallel region
CAFFE2_API size_t get_num_threads();
CAFFE2_API int get_num_threads();
// Returns the current thread number (starting from 0)
// in the current parallel region, or 0 in the sequential region
@ -153,10 +153,19 @@ CAFFE2_API std::string get_parallel_info();
class CAFFE2_API PTThreadPool : public c10::ThreadPool {
public:
explicit PTThreadPool(
std::size_t pool_size,
int pool_size,
int numa_node_id = -1);
void init_thread() override;
};
// Sets number of threads used for inter-op parallelism
CAFFE2_API void set_num_interop_threads(int);
// Returns the number of threads used for inter-op parallelism
CAFFE2_API int get_num_interop_threads();
// Launches inter-op parallel task
CAFFE2_API void launch(const std::function<void()>& func);
} // namespace at

View File

@ -1,6 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/test/test_assert.h>
#include <test/cpp/jit/test_base.h>
#include <thread>
@ -11,8 +11,8 @@
void test(int given_num_threads) {
at::init_num_threads();
auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
ASSERT(given_num_threads >= 0);
ASSERT(at::get_num_threads() == given_num_threads);
ASSERT_TRUE(given_num_threads >= 0);
ASSERT_EQ(at::get_num_threads(), given_num_threads);
auto t_sum = t.sum();
for (int i = 0; i < 1000; ++i) {
t_sum = t_sum + t.sum();
@ -38,5 +38,11 @@ int main() {
at::set_num_threads(5);
test(at::get_num_threads());
// test inter-op settings
ASSERT_EQ(at::get_num_interop_threads(), std::thread::hardware_concurrency());
at::set_num_interop_threads(5);
ASSERT_EQ(at::get_num_interop_threads(), 5);
ASSERT_ANY_THROW(at::set_num_interop_threads(6));
return 0;
}

View File

@ -2,8 +2,8 @@
namespace c10 {
ThreadPool::ThreadPool(std::size_t pool_size, int numa_node_id)
: threads_(pool_size),
ThreadPool::ThreadPool(int pool_size, int numa_node_id)
: threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
running_(true),
complete_(true),
available_(threads_.size()),
@ -48,6 +48,9 @@ bool ThreadPool::inThreadPool() const {
}
void ThreadPool::run(const std::function<void()>& func) {
if (threads_.size() == 0) {
throw std::runtime_error("No threads to run a task");
}
std::unique_lock<std::mutex> lock(mutex_);
// Set task and signal condition variable so that a worker thread will
@ -120,20 +123,6 @@ void ThreadPool::main_loop(std::size_t index) {
} // while running_
}
// constexpr initialization guaranteed to be before any static initialization
std::atomic<int> num_threads{1};
void setNumThreads(size_t v) {
if(-1 == num_threads.exchange(v)) {
throw std::runtime_error("Error: cannot set num threads after pool has started");
}
}
TaskThreadPoolBase& global_work_queue() {
static std::shared_ptr<TaskThreadPoolBase> pool =
ThreadPoolRegistry()->Create("C10", 0, num_threads.exchange(-1), false);
return *pool;
}
C10_DEFINE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,

View File

@ -36,6 +36,10 @@ class C10_API TaskThreadPoolBase {
virtual bool inThreadPool() const = 0;
virtual ~TaskThreadPoolBase() noexcept {}
static size_t defaultNumThreads() {
return std::thread::hardware_concurrency();
}
};
class C10_API ThreadPool : public c10::TaskThreadPoolBase {
@ -66,7 +70,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
ThreadPool() = delete;
explicit ThreadPool(
std::size_t pool_size,
int pool_size,
int numa_node_id = -1);
~ThreadPool();
@ -102,10 +106,6 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
void main_loop(std::size_t index);
};
C10_API void setNumThreads(size_t v);
C10_API TaskThreadPoolBase& global_work_queue();
class C10_API TaskThreadPool : public c10::ThreadPool {
public:
explicit TaskThreadPool(

View File

@ -125,6 +125,8 @@ Parallelism
----------------------------------
.. autofunction:: get_num_threads
.. autofunction:: set_num_threads
.. autofunction:: get_num_interop_threads
.. autofunction:: set_num_interop_threads
Locally disabling gradient computation
--------------------------------------

View File

@ -335,6 +335,8 @@ def gen_pyi(declarations_path, out):
'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
'get_num_threads': ['def get_num_threads() -> _int: ...'],
'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
# These functions are explicitly disabled by
# SKIP_PYTHON_BINDINGS because they are hand bound.
# Correspondingly, we must hand-write their signatures.

View File

@ -2147,7 +2147,15 @@ add_docstr(torch.get_num_threads,
r"""
get_num_threads() -> int
Gets the number of threads used for parallelizing CPU operations
Returns the number of threads used for parallelizing CPU operations
""")
add_docstr(torch.get_num_interop_threads,
r"""
get_num_interop_threads() -> int
Returns the number of threads used for inter-op parallelism on CPU
(e.g. in JIT interpreter)
""")
add_docstr(torch.gt,
@ -4304,6 +4312,16 @@ To ensure that the correct number of threads is used, set_num_threads
must be called before running eager, JIT or autograd code.
""")
add_docstr(torch.set_num_interop_threads,
r"""
set_num_interop_threads(int)
Sets the number of threads used for interop parallelism
(e.g. in JIT interpreter) on CPU.
WARNING: Can only be called once and before any inter-op parallel work
is started (e.g. JIT execution).
""")
add_docstr(torch.sigmoid,
r"""
sigmoid(input, out=None) -> Tensor

View File

@ -158,7 +158,24 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
{
THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, "
"but got %s", THPUtils_typename(arg));
at::set_num_threads((int)THPUtils_unpackLong(arg));
int nthreads = (int)THPUtils_unpackLong(arg);
THPUtils_assert(nthreads > 0, "set_num_threads expects a positive integer");
at::set_num_threads(nthreads);
Py_RETURN_NONE;
}
static PyObject * THPModule_getNumInteropThreads(PyObject *module)
{
return PyLong_FromLong(at::get_num_interop_threads());
}
static PyObject * THPModule_setNumInteropThreads(PyObject *module, PyObject *arg)
{
THPUtils_assert(THPUtils_checkLong(arg), "set_num_interop_threads expects an int, "
"but got %s", THPUtils_typename(arg));
int nthreads = (int)THPUtils_unpackLong(arg);
THPUtils_assert(nthreads > 0, "set_num_interop_threads expects a positive integer");
at::set_num_interop_threads(nthreads);
Py_RETURN_NONE;
}
@ -458,6 +475,8 @@ static PyMethodDef TorchMethods[] = {
{"_get_backcompat_keepdim_warn", (PyCFunction)THPModule_getBackcompatKeepdimWarn, METH_NOARGS, nullptr},
{"get_num_threads", (PyCFunction)THPModule_getNumThreads, METH_NOARGS, nullptr},
{"set_num_threads", (PyCFunction)THPModule_setNumThreads, METH_O, nullptr},
{"get_num_interop_threads", (PyCFunction)THPModule_getNumInteropThreads, METH_NOARGS, nullptr},
{"set_num_interop_threads", (PyCFunction)THPModule_setNumInteropThreads, METH_O, nullptr},
{"_get_cudnn_enabled", (PyCFunction)THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
{"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr},
{"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/interpreter.h>
#include <ATen/core/ivalue.h>
#include <ATen/Parallel.h>
#include <c10/core/thread_pool.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/edge.h>
@ -700,8 +701,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
// the current thread will continue running before it suspends.
InterpreterState state(intrusive_from_this());
e.future->addCallback([state]() {
c10::global_work_queue().run(InterpreterContinuation(
state, Stack(), autograd::GradMode::is_enabled()));
at::launch(InterpreterContinuation(state, Stack(),
autograd::GradMode::is_enabled()));
});
return true;

View File

@ -18,6 +18,7 @@
#include <torch/csrc/jit/script/logging.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/Dict.h>
@ -906,7 +907,7 @@ RegisterOperators reg(
push(stack, forked_interprester.getFuture());
c10::global_work_queue().run(std::move(continuation));
at::launch(std::move(continuation));
return 0;
};
}),