From 409200df59dcaabed986b50bdcc4078811a5f139 Mon Sep 17 00:00:00 2001 From: Ilia Cherniavskii Date: Fri, 17 May 2019 03:09:11 -0700 Subject: [PATCH] Move inter-op settings into ATen/Parallel (#20050) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20050 ghimport-source-id: cc102bab8abf3e56c099245976786317ed63ea14 Differential Revision: D15248576 Pulled By: ilia-cher fbshipit-source-id: 55ddcb7af387ddfc68a42ac7167de07ea648e249 --- aten/src/ATen/Parallel.cpp | 88 ++++++++++++++++++------- aten/src/ATen/Parallel.h | 15 ++++- aten/src/ATen/test/thread_init_test.cpp | 12 +++- c10/core/thread_pool.cpp | 21 ++---- c10/core/thread_pool.h | 10 +-- docs/source/torch.rst | 2 + tools/pyi/gen_pyi.py | 2 + torch/_torch_docs.py | 20 +++++- torch/csrc/Module.cpp | 21 +++++- torch/csrc/jit/interpreter.cpp | 5 +- torch/csrc/jit/register_prim_ops.cpp | 3 +- 11 files changed, 143 insertions(+), 56 deletions(-) diff --git a/aten/src/ATen/Parallel.cpp b/aten/src/ATen/Parallel.cpp index 0a965f7dba06..0d81dd443699 100644 --- a/aten/src/ATen/Parallel.cpp +++ b/aten/src/ATen/Parallel.cpp @@ -5,6 +5,7 @@ #include #include +#include #ifdef TH_BLAS_MKL #include @@ -13,8 +14,43 @@ namespace at { namespace { +const int NOT_SET = -1; +const int CONSUMED = -2; + // Number of threads set by the user -std::atomic num_threads(-1); +std::atomic num_threads{NOT_SET}; + +// Number of inter-op threads set by the user; +// NOT_SET -> positive value -> CONSUMED +// (CONSUMED - thread pool is initialized) +// or +// NOT_SET -> CONSUMED +std::atomic num_interop_threads{NOT_SET}; + +// thread pool global instance is hidden, +// users should use at::launch and get/set_num_interop_threads interface +TaskThreadPoolBase& get_pool() { + static std::shared_ptr pool = + ThreadPoolRegistry()->Create( + "C10", + /* device_id */ 0, + /* pool_size */ num_interop_threads.exchange(CONSUMED), + /* create_new */ true); + return *pool; +} + + // Factory function for ThreadPoolRegistry +std::shared_ptr create_c10_threadpool( + int device_id, + int pool_size, + bool create_new) { + // For now, the only accepted device id is 0 + AT_CHECK(device_id == 0); + // Create new thread pool + AT_CHECK(create_new); + return std::make_shared(pool_size); +} + } void init_num_threads() { @@ -32,10 +68,9 @@ void init_num_threads() { } } -void set_num_threads(size_t nthreads) { - if (nthreads == 0) { - return; - } +void set_num_threads(int nthreads) { + AT_CHECK(nthreads > 0, "Expected positive number of threads"); + num_threads.store(nthreads); #ifdef _OPENMP omp_set_num_threads(nthreads); @@ -56,7 +91,7 @@ void set_num_threads(size_t nthreads) { // region might be different in the new thread; // Use init_num_threads() during thread initialization to ensure // consistent size of parallel region in different threads -size_t get_num_threads() { +int get_num_threads() { #ifdef _OPENMP return omp_get_max_threads(); #else @@ -100,7 +135,7 @@ std::string get_parallel_info() { } PTThreadPool::PTThreadPool( - std::size_t pool_size, + int pool_size, int numa_node_id) : c10::ThreadPool(pool_size, numa_node_id) {} @@ -109,26 +144,31 @@ void PTThreadPool::init_thread() { at::init_num_threads(); } -namespace { +C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool); -std::shared_ptr createC10ThreadPool( - int device_id, - int pool_size, - bool create_new) { - static std::shared_ptr pool = - std::make_shared(pool_size); - // For now, the only accepted device id is 0 - // for the JIT inter-op pool (CPU), - AT_ASSERT(device_id == 0); - // we use the shared thread pool - AT_ASSERT(!create_new); - // and the size does not change - AT_ASSERT(pool->size() == pool_size); - return pool; +void set_num_interop_threads(int nthreads) { + AT_CHECK(nthreads > 0, "Expected positive number of threads"); + + int no_value = NOT_SET; + AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads), + "Error: cannot set number of interop threads after parallel work " + "has started or set_num_interop_threads called"); } -} // namespace +int get_num_interop_threads() { + int nthreads = num_interop_threads.load(); + if (nthreads > 0) { + return nthreads; + } else if (nthreads == NOT_SET) { + // return default value + return TaskThreadPoolBase::defaultNumThreads(); + } else { + return get_pool().size(); + } +} -C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool); +void launch(const std::function& func) { + get_pool().run(func); +} } // namespace at diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 2668619436c2..fe7530793589 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -30,10 +30,10 @@ inline int64_t divup(int64_t x, int64_t y) { CAFFE2_API void init_num_threads(); // Sets the number of threads to be used in parallel region -CAFFE2_API void set_num_threads(size_t); +CAFFE2_API void set_num_threads(int); // Returns the number of threads used in parallel region -CAFFE2_API size_t get_num_threads(); +CAFFE2_API int get_num_threads(); // Returns the current thread number (starting from 0) // in the current parallel region, or 0 in the sequential region @@ -153,10 +153,19 @@ CAFFE2_API std::string get_parallel_info(); class CAFFE2_API PTThreadPool : public c10::ThreadPool { public: explicit PTThreadPool( - std::size_t pool_size, + int pool_size, int numa_node_id = -1); void init_thread() override; }; +// Sets number of threads used for inter-op parallelism +CAFFE2_API void set_num_interop_threads(int); + +// Returns the number of threads used for inter-op parallelism +CAFFE2_API int get_num_interop_threads(); + +// Launches inter-op parallel task +CAFFE2_API void launch(const std::function& func); + } // namespace at diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 1c0d8576e32d..6f9ae19485e4 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include @@ -11,8 +11,8 @@ void test(int given_num_threads) { at::init_num_threads(); auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat)); - ASSERT(given_num_threads >= 0); - ASSERT(at::get_num_threads() == given_num_threads); + ASSERT_TRUE(given_num_threads >= 0); + ASSERT_EQ(at::get_num_threads(), given_num_threads); auto t_sum = t.sum(); for (int i = 0; i < 1000; ++i) { t_sum = t_sum + t.sum(); @@ -38,5 +38,11 @@ int main() { at::set_num_threads(5); test(at::get_num_threads()); + // test inter-op settings + ASSERT_EQ(at::get_num_interop_threads(), std::thread::hardware_concurrency()); + at::set_num_interop_threads(5); + ASSERT_EQ(at::get_num_interop_threads(), 5); + ASSERT_ANY_THROW(at::set_num_interop_threads(6)); + return 0; } diff --git a/c10/core/thread_pool.cpp b/c10/core/thread_pool.cpp index cc13566e29ee..1529f74bd4a9 100644 --- a/c10/core/thread_pool.cpp +++ b/c10/core/thread_pool.cpp @@ -2,8 +2,8 @@ namespace c10 { -ThreadPool::ThreadPool(std::size_t pool_size, int numa_node_id) - : threads_(pool_size), +ThreadPool::ThreadPool(int pool_size, int numa_node_id) + : threads_(pool_size < 0 ? defaultNumThreads() : pool_size), running_(true), complete_(true), available_(threads_.size()), @@ -48,6 +48,9 @@ bool ThreadPool::inThreadPool() const { } void ThreadPool::run(const std::function& func) { + if (threads_.size() == 0) { + throw std::runtime_error("No threads to run a task"); + } std::unique_lock lock(mutex_); // Set task and signal condition variable so that a worker thread will @@ -120,20 +123,6 @@ void ThreadPool::main_loop(std::size_t index) { } // while running_ } -// constexpr initialization guaranteed to be before any static initialization -std::atomic num_threads{1}; -void setNumThreads(size_t v) { - if(-1 == num_threads.exchange(v)) { - throw std::runtime_error("Error: cannot set num threads after pool has started"); - } -} - -TaskThreadPoolBase& global_work_queue() { - static std::shared_ptr pool = - ThreadPoolRegistry()->Create("C10", 0, num_threads.exchange(-1), false); - return *pool; -} - C10_DEFINE_SHARED_REGISTRY( ThreadPoolRegistry, TaskThreadPoolBase, diff --git a/c10/core/thread_pool.h b/c10/core/thread_pool.h index b4a716ac5b6a..5fe8b416c6f9 100644 --- a/c10/core/thread_pool.h +++ b/c10/core/thread_pool.h @@ -36,6 +36,10 @@ class C10_API TaskThreadPoolBase { virtual bool inThreadPool() const = 0; virtual ~TaskThreadPoolBase() noexcept {} + + static size_t defaultNumThreads() { + return std::thread::hardware_concurrency(); + } }; class C10_API ThreadPool : public c10::TaskThreadPoolBase { @@ -66,7 +70,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase { ThreadPool() = delete; explicit ThreadPool( - std::size_t pool_size, + int pool_size, int numa_node_id = -1); ~ThreadPool(); @@ -102,10 +106,6 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase { void main_loop(std::size_t index); }; -C10_API void setNumThreads(size_t v); - -C10_API TaskThreadPoolBase& global_work_queue(); - class C10_API TaskThreadPool : public c10::ThreadPool { public: explicit TaskThreadPool( diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 6a6370c8a65c..cd4d6975101a 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -125,6 +125,8 @@ Parallelism ---------------------------------- .. autofunction:: get_num_threads .. autofunction:: set_num_threads +.. autofunction:: get_num_interop_threads +.. autofunction:: set_num_interop_threads Locally disabling gradient computation -------------------------------------- diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index f3ddc100e346..fd4abf591783 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -335,6 +335,8 @@ def gen_pyi(declarations_path, out): 'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."], 'get_num_threads': ['def get_num_threads() -> _int: ...'], 'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'], + 'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'], + 'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'], # These functions are explicitly disabled by # SKIP_PYTHON_BINDINGS because they are hand bound. # Correspondingly, we must hand-write their signatures. diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index caccb9322212..8683f3434552 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2147,7 +2147,15 @@ add_docstr(torch.get_num_threads, r""" get_num_threads() -> int -Gets the number of threads used for parallelizing CPU operations +Returns the number of threads used for parallelizing CPU operations +""") + +add_docstr(torch.get_num_interop_threads, + r""" +get_num_interop_threads() -> int + +Returns the number of threads used for inter-op parallelism on CPU +(e.g. in JIT interpreter) """) add_docstr(torch.gt, @@ -4304,6 +4312,16 @@ To ensure that the correct number of threads is used, set_num_threads must be called before running eager, JIT or autograd code. """) +add_docstr(torch.set_num_interop_threads, + r""" +set_num_interop_threads(int) + +Sets the number of threads used for interop parallelism +(e.g. in JIT interpreter) on CPU. +WARNING: Can only be called once and before any inter-op parallel work +is started (e.g. JIT execution). +""") + add_docstr(torch.sigmoid, r""" sigmoid(input, out=None) -> Tensor diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 4693407ff107..3675a146b189 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -158,7 +158,24 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg) { THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, " "but got %s", THPUtils_typename(arg)); - at::set_num_threads((int)THPUtils_unpackLong(arg)); + int nthreads = (int)THPUtils_unpackLong(arg); + THPUtils_assert(nthreads > 0, "set_num_threads expects a positive integer"); + at::set_num_threads(nthreads); + Py_RETURN_NONE; +} + +static PyObject * THPModule_getNumInteropThreads(PyObject *module) +{ + return PyLong_FromLong(at::get_num_interop_threads()); +} + +static PyObject * THPModule_setNumInteropThreads(PyObject *module, PyObject *arg) +{ + THPUtils_assert(THPUtils_checkLong(arg), "set_num_interop_threads expects an int, " + "but got %s", THPUtils_typename(arg)); + int nthreads = (int)THPUtils_unpackLong(arg); + THPUtils_assert(nthreads > 0, "set_num_interop_threads expects a positive integer"); + at::set_num_interop_threads(nthreads); Py_RETURN_NONE; } @@ -458,6 +475,8 @@ static PyMethodDef TorchMethods[] = { {"_get_backcompat_keepdim_warn", (PyCFunction)THPModule_getBackcompatKeepdimWarn, METH_NOARGS, nullptr}, {"get_num_threads", (PyCFunction)THPModule_getNumThreads, METH_NOARGS, nullptr}, {"set_num_threads", (PyCFunction)THPModule_setNumThreads, METH_O, nullptr}, + {"get_num_interop_threads", (PyCFunction)THPModule_getNumInteropThreads, METH_NOARGS, nullptr}, + {"set_num_interop_threads", (PyCFunction)THPModule_setNumInteropThreads, METH_O, nullptr}, {"_get_cudnn_enabled", (PyCFunction)THPModule_userEnabledCuDNN, METH_NOARGS, nullptr}, {"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr}, {"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 18a4c7532a66..808412883d0a 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -700,8 +701,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // the current thread will continue running before it suspends. InterpreterState state(intrusive_from_this()); e.future->addCallback([state]() { - c10::global_work_queue().run(InterpreterContinuation( - state, Stack(), autograd::GradMode::is_enabled())); + at::launch(InterpreterContinuation(state, Stack(), + autograd::GradMode::is_enabled())); }); return true; diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 189ce0b7ca79..bfcca87cc8f6 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -906,7 +907,7 @@ RegisterOperators reg( push(stack, forked_interprester.getFuture()); - c10::global_work_queue().run(std::move(continuation)); + at::launch(std::move(continuation)); return 0; }; }),