mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Move inter-op settings into ATen/Parallel (#20050)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20050 ghimport-source-id: cc102bab8abf3e56c099245976786317ed63ea14 Differential Revision: D15248576 Pulled By: ilia-cher fbshipit-source-id: 55ddcb7af387ddfc68a42ac7167de07ea648e249
This commit is contained in:
committed by
Facebook Github Bot
parent
36d3398aa5
commit
409200df59
@ -5,6 +5,7 @@
|
||||
|
||||
#include <atomic>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#ifdef TH_BLAS_MKL
|
||||
#include <mkl.h>
|
||||
@ -13,8 +14,43 @@
|
||||
namespace at {
|
||||
|
||||
namespace {
|
||||
const int NOT_SET = -1;
|
||||
const int CONSUMED = -2;
|
||||
|
||||
// Number of threads set by the user
|
||||
std::atomic<int> num_threads(-1);
|
||||
std::atomic<int> num_threads{NOT_SET};
|
||||
|
||||
// Number of inter-op threads set by the user;
|
||||
// NOT_SET -> positive value -> CONSUMED
|
||||
// (CONSUMED - thread pool is initialized)
|
||||
// or
|
||||
// NOT_SET -> CONSUMED
|
||||
std::atomic<int> num_interop_threads{NOT_SET};
|
||||
|
||||
// thread pool global instance is hidden,
|
||||
// users should use at::launch and get/set_num_interop_threads interface
|
||||
TaskThreadPoolBase& get_pool() {
|
||||
static std::shared_ptr<TaskThreadPoolBase> pool =
|
||||
ThreadPoolRegistry()->Create(
|
||||
"C10",
|
||||
/* device_id */ 0,
|
||||
/* pool_size */ num_interop_threads.exchange(CONSUMED),
|
||||
/* create_new */ true);
|
||||
return *pool;
|
||||
}
|
||||
|
||||
// Factory function for ThreadPoolRegistry
|
||||
std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
|
||||
int device_id,
|
||||
int pool_size,
|
||||
bool create_new) {
|
||||
// For now, the only accepted device id is 0
|
||||
AT_CHECK(device_id == 0);
|
||||
// Create new thread pool
|
||||
AT_CHECK(create_new);
|
||||
return std::make_shared<PTThreadPool>(pool_size);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void init_num_threads() {
|
||||
@ -32,10 +68,9 @@ void init_num_threads() {
|
||||
}
|
||||
}
|
||||
|
||||
void set_num_threads(size_t nthreads) {
|
||||
if (nthreads == 0) {
|
||||
return;
|
||||
}
|
||||
void set_num_threads(int nthreads) {
|
||||
AT_CHECK(nthreads > 0, "Expected positive number of threads");
|
||||
|
||||
num_threads.store(nthreads);
|
||||
#ifdef _OPENMP
|
||||
omp_set_num_threads(nthreads);
|
||||
@ -56,7 +91,7 @@ void set_num_threads(size_t nthreads) {
|
||||
// region might be different in the new thread;
|
||||
// Use init_num_threads() during thread initialization to ensure
|
||||
// consistent size of parallel region in different threads
|
||||
size_t get_num_threads() {
|
||||
int get_num_threads() {
|
||||
#ifdef _OPENMP
|
||||
return omp_get_max_threads();
|
||||
#else
|
||||
@ -100,7 +135,7 @@ std::string get_parallel_info() {
|
||||
}
|
||||
|
||||
PTThreadPool::PTThreadPool(
|
||||
std::size_t pool_size,
|
||||
int pool_size,
|
||||
int numa_node_id)
|
||||
: c10::ThreadPool(pool_size, numa_node_id) {}
|
||||
|
||||
@ -109,26 +144,31 @@ void PTThreadPool::init_thread() {
|
||||
at::init_num_threads();
|
||||
}
|
||||
|
||||
namespace {
|
||||
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
|
||||
|
||||
std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
|
||||
int device_id,
|
||||
int pool_size,
|
||||
bool create_new) {
|
||||
static std::shared_ptr<TaskThreadPoolBase> pool =
|
||||
std::make_shared<PTThreadPool>(pool_size);
|
||||
// For now, the only accepted device id is 0
|
||||
// for the JIT inter-op pool (CPU),
|
||||
AT_ASSERT(device_id == 0);
|
||||
// we use the shared thread pool
|
||||
AT_ASSERT(!create_new);
|
||||
// and the size does not change
|
||||
AT_ASSERT(pool->size() == pool_size);
|
||||
return pool;
|
||||
void set_num_interop_threads(int nthreads) {
|
||||
AT_CHECK(nthreads > 0, "Expected positive number of threads");
|
||||
|
||||
int no_value = NOT_SET;
|
||||
AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
|
||||
"Error: cannot set number of interop threads after parallel work "
|
||||
"has started or set_num_interop_threads called");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
int get_num_interop_threads() {
|
||||
int nthreads = num_interop_threads.load();
|
||||
if (nthreads > 0) {
|
||||
return nthreads;
|
||||
} else if (nthreads == NOT_SET) {
|
||||
// return default value
|
||||
return TaskThreadPoolBase::defaultNumThreads();
|
||||
} else {
|
||||
return get_pool().size();
|
||||
}
|
||||
}
|
||||
|
||||
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
|
||||
void launch(const std::function<void()>& func) {
|
||||
get_pool().run(func);
|
||||
}
|
||||
|
||||
} // namespace at
|
||||
|
@ -30,10 +30,10 @@ inline int64_t divup(int64_t x, int64_t y) {
|
||||
CAFFE2_API void init_num_threads();
|
||||
|
||||
// Sets the number of threads to be used in parallel region
|
||||
CAFFE2_API void set_num_threads(size_t);
|
||||
CAFFE2_API void set_num_threads(int);
|
||||
|
||||
// Returns the number of threads used in parallel region
|
||||
CAFFE2_API size_t get_num_threads();
|
||||
CAFFE2_API int get_num_threads();
|
||||
|
||||
// Returns the current thread number (starting from 0)
|
||||
// in the current parallel region, or 0 in the sequential region
|
||||
@ -153,10 +153,19 @@ CAFFE2_API std::string get_parallel_info();
|
||||
class CAFFE2_API PTThreadPool : public c10::ThreadPool {
|
||||
public:
|
||||
explicit PTThreadPool(
|
||||
std::size_t pool_size,
|
||||
int pool_size,
|
||||
int numa_node_id = -1);
|
||||
|
||||
void init_thread() override;
|
||||
};
|
||||
|
||||
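// Sets the number of threads used for inter-op parallelism.
// The value must be positive, and it can only be set once, before any
// inter-op parallel work has started; otherwise an error is raised.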
|
||||
CAFFE2_API void set_num_interop_threads(int);
|
||||
|
||||
// Returns the number of threads used for inter-op parallelism
|
||||
CAFFE2_API int get_num_interop_threads();
|
||||
|
||||
// Launches inter-op parallel task
|
||||
CAFFE2_API void launch(const std::function<void()>& func);
|
||||
|
||||
} // namespace at
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/test/test_assert.h>
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <thread>
|
||||
|
||||
|
||||
@ -11,8 +11,8 @@
|
||||
void test(int given_num_threads) {
|
||||
at::init_num_threads();
|
||||
auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
|
||||
ASSERT(given_num_threads >= 0);
|
||||
ASSERT(at::get_num_threads() == given_num_threads);
|
||||
ASSERT_TRUE(given_num_threads >= 0);
|
||||
ASSERT_EQ(at::get_num_threads(), given_num_threads);
|
||||
auto t_sum = t.sum();
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
t_sum = t_sum + t.sum();
|
||||
@ -38,5 +38,11 @@ int main() {
|
||||
at::set_num_threads(5);
|
||||
test(at::get_num_threads());
|
||||
|
||||
// test inter-op settings
|
||||
ASSERT_EQ(at::get_num_interop_threads(), std::thread::hardware_concurrency());
|
||||
at::set_num_interop_threads(5);
|
||||
ASSERT_EQ(at::get_num_interop_threads(), 5);
|
||||
ASSERT_ANY_THROW(at::set_num_interop_threads(6));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
ThreadPool::ThreadPool(std::size_t pool_size, int numa_node_id)
|
||||
: threads_(pool_size),
|
||||
ThreadPool::ThreadPool(int pool_size, int numa_node_id)
|
||||
: threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
|
||||
running_(true),
|
||||
complete_(true),
|
||||
available_(threads_.size()),
|
||||
@ -48,6 +48,9 @@ bool ThreadPool::inThreadPool() const {
|
||||
}
|
||||
|
||||
void ThreadPool::run(const std::function<void()>& func) {
|
||||
if (threads_.size() == 0) {
|
||||
throw std::runtime_error("No threads to run a task");
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
// Set task and signal condition variable so that a worker thread will
|
||||
@ -120,20 +123,6 @@ void ThreadPool::main_loop(std::size_t index) {
|
||||
} // while running_
|
||||
}
|
||||
|
||||
// constexpr initialization guaranteed to be before any static initialization
|
||||
std::atomic<int> num_threads{1};
|
||||
void setNumThreads(size_t v) {
|
||||
if(-1 == num_threads.exchange(v)) {
|
||||
throw std::runtime_error("Error: cannot set num threads after pool has started");
|
||||
}
|
||||
}
|
||||
|
||||
TaskThreadPoolBase& global_work_queue() {
|
||||
static std::shared_ptr<TaskThreadPoolBase> pool =
|
||||
ThreadPoolRegistry()->Create("C10", 0, num_threads.exchange(-1), false);
|
||||
return *pool;
|
||||
}
|
||||
|
||||
C10_DEFINE_SHARED_REGISTRY(
|
||||
ThreadPoolRegistry,
|
||||
TaskThreadPoolBase,
|
||||
|
@ -36,6 +36,10 @@ class C10_API TaskThreadPoolBase {
|
||||
virtual bool inThreadPool() const = 0;
|
||||
|
||||
virtual ~TaskThreadPoolBase() noexcept {}
|
||||
|
||||
// Default number of threads for a pool: the hardware thread count reported
// by the standard library, but never less than one.
// std::thread::hardware_concurrency() is allowed to return 0 when the value
// cannot be determined; clamping keeps a default-sized pool from being
// created with no worker threads.
static size_t defaultNumThreads() {
  const size_t hw_threads = std::thread::hardware_concurrency();
  return hw_threads > 0 ? hw_threads : 1;
}
|
||||
};
|
||||
|
||||
class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
@ -66,7 +70,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
ThreadPool() = delete;
|
||||
|
||||
explicit ThreadPool(
|
||||
std::size_t pool_size,
|
||||
int pool_size,
|
||||
int numa_node_id = -1);
|
||||
|
||||
~ThreadPool();
|
||||
@ -102,10 +106,6 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
void main_loop(std::size_t index);
|
||||
};
|
||||
|
||||
C10_API void setNumThreads(size_t v);
|
||||
|
||||
C10_API TaskThreadPoolBase& global_work_queue();
|
||||
|
||||
class C10_API TaskThreadPool : public c10::ThreadPool {
|
||||
public:
|
||||
explicit TaskThreadPool(
|
||||
|
@ -125,6 +125,8 @@ Parallelism
|
||||
----------------------------------
|
||||
.. autofunction:: get_num_threads
|
||||
.. autofunction:: set_num_threads
|
||||
.. autofunction:: get_num_interop_threads
|
||||
.. autofunction:: set_num_interop_threads
|
||||
|
||||
Locally disabling gradient computation
|
||||
--------------------------------------
|
||||
|
@ -335,6 +335,8 @@ def gen_pyi(declarations_path, out):
|
||||
'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
|
||||
'get_num_threads': ['def get_num_threads() -> _int: ...'],
|
||||
'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
|
||||
'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
|
||||
'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
|
||||
# These functions are explicitly disabled by
|
||||
# SKIP_PYTHON_BINDINGS because they are hand bound.
|
||||
# Correspondingly, we must hand-write their signatures.
|
||||
|
@ -2147,7 +2147,15 @@ add_docstr(torch.get_num_threads,
|
||||
r"""
|
||||
get_num_threads() -> int
|
||||
|
||||
Gets the number of threads used for parallelizing CPU operations
|
||||
Returns the number of threads used for parallelizing CPU operations
|
||||
""")
|
||||
|
||||
add_docstr(torch.get_num_interop_threads,
|
||||
r"""
|
||||
get_num_interop_threads() -> int
|
||||
|
||||
Returns the number of threads used for inter-op parallelism on CPU
|
||||
(e.g. in JIT interpreter)
|
||||
""")
|
||||
|
||||
add_docstr(torch.gt,
|
||||
@ -4304,6 +4312,16 @@ To ensure that the correct number of threads is used, set_num_threads
|
||||
must be called before running eager, JIT or autograd code.
|
||||
""")
|
||||
|
||||
add_docstr(torch.set_num_interop_threads,
|
||||
r"""
|
||||
set_num_interop_threads(int)
|
||||
|
||||
Sets the number of threads used for interop parallelism
|
||||
(e.g. in JIT interpreter) on CPU.
|
||||
WARNING: Can only be called once and before any inter-op parallel work
|
||||
is started (e.g. JIT execution).
|
||||
""")
|
||||
|
||||
add_docstr(torch.sigmoid,
|
||||
r"""
|
||||
sigmoid(input, out=None) -> Tensor
|
||||
|
@ -158,7 +158,24 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
|
||||
{
|
||||
THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, "
|
||||
"but got %s", THPUtils_typename(arg));
|
||||
at::set_num_threads((int)THPUtils_unpackLong(arg));
|
||||
int nthreads = (int)THPUtils_unpackLong(arg);
|
||||
THPUtils_assert(nthreads > 0, "set_num_threads expects a positive integer");
|
||||
at::set_num_threads(nthreads);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Python binding: returns at::get_num_interop_threads() as a Python int.
static PyObject * THPModule_getNumInteropThreads(PyObject *module)
{
  const int interop_threads = at::get_num_interop_threads();
  return PyLong_FromLong(interop_threads);
}
|
||||
|
||||
// Python binding for at::set_num_interop_threads.
// `arg` must be a Python int holding a positive value that fits in a C++
// int; otherwise a Python error is set through THPUtils_assert.
// On success the value is forwarded to at::set_num_interop_threads and None
// is returned.
// NOTE(review): at::set_num_interop_threads can itself raise a C++ error
// (e.g. when called a second time); no exception translation is visible in
// this function - confirm it is handled before reaching the interpreter.
static PyObject * THPModule_setNumInteropThreads(PyObject *module, PyObject *arg)
{
  THPUtils_assert(THPUtils_checkLong(arg), "set_num_interop_threads expects an int, "
          "but got %s", THPUtils_typename(arg));
  const auto requested = THPUtils_unpackLong(arg);
  const int nthreads = static_cast<int>(requested);
  // Reject values that do not survive the narrowing to int, instead of
  // silently using a truncated thread count.
  const bool fits_in_int = static_cast<decltype(requested)>(nthreads) == requested;
  THPUtils_assert(fits_in_int && nthreads > 0,
          "set_num_interop_threads expects a positive integer");
  at::set_num_interop_threads(nthreads);
  Py_RETURN_NONE;
}
|
||||
|
||||
@ -458,6 +475,8 @@ static PyMethodDef TorchMethods[] = {
|
||||
{"_get_backcompat_keepdim_warn", (PyCFunction)THPModule_getBackcompatKeepdimWarn, METH_NOARGS, nullptr},
|
||||
{"get_num_threads", (PyCFunction)THPModule_getNumThreads, METH_NOARGS, nullptr},
|
||||
{"set_num_threads", (PyCFunction)THPModule_setNumThreads, METH_O, nullptr},
|
||||
{"get_num_interop_threads", (PyCFunction)THPModule_getNumInteropThreads, METH_NOARGS, nullptr},
|
||||
{"set_num_interop_threads", (PyCFunction)THPModule_setNumInteropThreads, METH_O, nullptr},
|
||||
{"_get_cudnn_enabled", (PyCFunction)THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
|
||||
{"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr},
|
||||
{"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/jit/interpreter.h>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/autograd/edge.h>
|
||||
@ -700,8 +701,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
// the current thread will continue running before it suspends.
|
||||
InterpreterState state(intrusive_from_this());
|
||||
e.future->addCallback([state]() {
|
||||
c10::global_work_queue().run(InterpreterContinuation(
|
||||
state, Stack(), autograd::GradMode::is_enabled()));
|
||||
at::launch(InterpreterContinuation(state, Stack(),
|
||||
autograd::GradMode::is_enabled()));
|
||||
});
|
||||
|
||||
return true;
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include <torch/csrc/jit/script/logging.h>
|
||||
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/Dict.h>
|
||||
@ -906,7 +907,7 @@ RegisterOperators reg(
|
||||
|
||||
push(stack, forked_interprester.getFuture());
|
||||
|
||||
c10::global_work_queue().run(std::move(continuation));
|
||||
at::launch(std::move(continuation));
|
||||
return 0;
|
||||
};
|
||||
}),
|
||||
|
Reference in New Issue
Block a user