Resend "Split ATen/Parallel into interface and backend" (#20825)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20825
ghimport-source-id: 0371fbd37cb37635647d473d5ac9f2859e787061
Differential Revision: D15458073
Pulled By: ilia-cher
fbshipit-source-id: cd27d0da1691f6be1183cd152348ac0d93a53996
Committed by: Facebook Github Bot
Parent: 6b74856747
Commit: c3d05e86cc
19  aten/src/ATen/PTThreadPool.h  (new file)
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/Parallel.h>
#include <c10/core/thread_pool.h>

namespace at {

class CAFFE2_API PTThreadPool : public c10::ThreadPool {
 public:
  explicit PTThreadPool(
      int pool_size,
      int numa_node_id = -1)
      : c10::ThreadPool(pool_size, numa_node_id, [](){
          c10::setThreadName("PTThreadPool");
          at::init_num_threads();
        }) {}
};

} // namespace at
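For orientation, a minimal usage sketch of the new pool type; this is not part of the patch and assumes PTThreadPool inherits run() from c10::ThreadPool / TaskThreadPoolBase, as the rest of this change suggests.

// Hypothetical example, not from this commit: submit work to a dedicated pool.
#include <ATen/PTThreadPool.h>

void example() {
  // Each worker is named "PTThreadPool" and runs at::init_num_threads()
  // before executing tasks, per the constructor above.
  at::PTThreadPool pool(/* pool_size */ 2);
  pool.run([]() {
    // task body runs asynchronously on a pool thread
  });
}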
@@ -1,172 +0,0 @@
#include <ATen/Parallel.h>

#include <ATen/Config.h>
#include <ATen/Version.h>

#include <atomic>
#include <sstream>
#include <thread>

#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif

namespace at {

namespace {
const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of threads set by the user
std::atomic<int> num_threads{NOT_SET};

// Number of inter-op threads set by the user;
// NOT_SET -> positive value -> CONSUMED
// (CONSUMED - thread pool is initialized)
// or
// NOT_SET -> CONSUMED
std::atomic<int> num_interop_threads{NOT_SET};

// thread pool global instance is hidden,
// users should use at::launch and get/set_num_interop_threads interface
TaskThreadPoolBase& get_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ num_interop_threads.exchange(CONSUMED),
          /* create_new */ true);
  return *pool;
}

// Factory function for ThreadPoolRegistry
std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
    int device_id,
    int pool_size,
    bool create_new) {
  // For now, the only accepted device id is 0
  AT_CHECK(device_id == 0);
  // Create new thread pool
  AT_CHECK(create_new);
  return std::make_shared<PTThreadPool>(pool_size);
}

}

void init_num_threads() {
  auto nthreads = num_threads.load();
  if (nthreads > 0) {
    set_num_threads(nthreads);
  } else {
#if defined(_OPENMP) && defined(TH_BLAS_MKL)
    // If we are using MKL and OpenMP, make sure the number of threads match.
    // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
    // size of the OpenMP thread pool, resulting in worse performance (and memory
    // leaks in GCC 5.4)
    omp_set_num_threads(mkl_get_max_threads());
#endif
  }
}

void set_num_threads(int nthreads) {
  AT_CHECK(nthreads > 0, "Expected positive number of threads");

  num_threads.store(nthreads);
#ifdef _OPENMP
  omp_set_num_threads(nthreads);
#endif
#ifdef TH_BLAS_MKL
  mkl_set_num_threads(nthreads);

  // because PyTorch uses OpenMP outside of MKL invocations
  // as well, we want this flag to be false, so that
  // threads aren't destroyed and recreated across every
  // MKL / non-MKL boundary of OpenMP usage
  // See https://github.com/pytorch/pytorch/issues/13757
  mkl_set_dynamic(false);
#endif
}

// Explicitly calling omp_get_max_threads() as the size of the parallel
// region might be different in the new thread;
// Use init_num_threads() during thread initialization to ensure
// consistent size of parallel region in different threads
int get_num_threads() {
#ifdef _OPENMP
  return omp_get_max_threads();
#else
  return 1;
#endif
}

namespace {
const char* get_env_var(const char* var_name) {
  const char* value = std::getenv(var_name);
  return value ? value : "[not set]";
}
}

std::string get_parallel_info() {
  std::ostringstream ss;

  ss << "ATen/Parallel:\n\tat::get_num_threads() : "
     << at::get_num_threads() << std::endl;

  ss << at::get_openmp_version() << std::endl;
#ifdef _OPENMP
  ss << "\tomp_get_max_threads() : " << omp_get_max_threads() << std::endl;
#endif

  ss << at::get_mkl_version() << std::endl;
#ifdef TH_BLAS_MKL
  ss << "\tmkl_get_max_threads() : " << mkl_get_max_threads() << std::endl;
#endif

  ss << at::get_mkldnn_version() << std::endl;

  ss << "std::thread::hardware_concurrency() : "
     << std::thread::hardware_concurrency() << std::endl;

  ss << "Environment variables:" << std::endl;
  ss << "\tOMP_NUM_THREADS : " << get_env_var("OMP_NUM_THREADS") << std::endl;
  ss << "\tMKL_NUM_THREADS : " << get_env_var("MKL_NUM_THREADS") << std::endl;

  return ss.str();
}

PTThreadPool::PTThreadPool(
    int pool_size,
    int numa_node_id)
    : c10::ThreadPool(pool_size, numa_node_id, [](){
        c10::setThreadName("PTThreadPool");
        at::init_num_threads();
      }) {}

C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);

void set_num_interop_threads(int nthreads) {
  AT_CHECK(nthreads > 0, "Expected positive number of threads");

  int no_value = NOT_SET;
  AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
      "Error: cannot set number of interop threads after parallel work "
      "has started or set_num_interop_threads called");
}

int get_num_interop_threads() {
  int nthreads = num_interop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    // return default value
    return TaskThreadPoolBase::defaultNumThreads();
  } else {
    return get_pool().size();
  }
}

void launch(const std::function<void()>& func) {
  get_pool().run(func);
}

} // namespace at
@@ -1,16 +1,5 @@
#pragma once
#include <ATen/ATen.h>
#include <c10/core/thread_pool.h>

#include <atomic>
#include <cstddef>
#include <exception>

#ifdef _OPENMP
#define INTRA_OP_PARALLEL

#include <omp.h>
#endif

namespace at {
namespace internal {

@@ -37,56 +26,17 @@ CAFFE2_API int get_num_threads();

// Returns the current thread number (starting from 0)
// in the current parallel region, or 0 in the sequential region
inline int get_thread_num() {
#ifdef _OPENMP
  return omp_get_thread_num();
#else
  return 0;
#endif
}
CAFFE2_API int get_thread_num();

inline bool in_parallel_region() {
#ifdef _OPENMP
  return omp_in_parallel();
#else
  return false;
#endif
}
// Checks whether the code runs in parallel region
CAFFE2_API bool in_parallel_region();

template <class F>
inline void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
#ifdef _OPENMP
  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  std::exception_ptr eptr;
#pragma omp parallel if (!omp_in_parallel() && ((end - begin) >= grain_size))
  {
    int64_t num_threads = omp_get_num_threads();
    int64_t tid = omp_get_thread_num();
    int64_t chunk_size = divup((end - begin), num_threads);
    int64_t begin_tid = begin + tid * chunk_size;
    if (begin_tid < end) {
      try {
        f(begin_tid, std::min(end, chunk_size + begin_tid));
      } catch (...) {
        if (!err_flag.test_and_set()) {
          eptr = std::current_exception();
        }
      }
    }
  }
  if (eptr) {
    std::rethrow_exception(eptr);
  }
#else
  if (begin < end) {
    f(begin, end);
  }
#endif
}
    const F& f);

/*
parallel_reduce

@@ -127,36 +77,12 @@ inline scalar_t parallel_reduce(
    const int64_t end,
    const int64_t grain_size,
    const scalar_t ident,
    const F f,
    const SF sf) {
  if (in_parallel_region() || get_num_threads() == 1) {
    return f(begin, end, ident);
  } else {
    const int64_t num_results = divup((end - begin), grain_size);
    std::vector<scalar_t> results(num_results);
    scalar_t* results_data = results.data();
#ifdef _OPENMP
#pragma omp parallel for if ((end - begin) >= grain_size)
#endif
    for (int64_t id = 0; id < num_results; id++) {
      int64_t i = begin + id * grain_size;
      results_data[id] = f(i, i + std::min(end - i, grain_size), ident);
    }
    return std::accumulate(
        results_data, results_data + results.size(), ident, sf);
  }
}
    const F& f,
    const SF& sf);

// Returns a detailed string describing parallelization settings
CAFFE2_API std::string get_parallel_info();

class CAFFE2_API PTThreadPool : public c10::ThreadPool {
 public:
  explicit PTThreadPool(
      int pool_size,
      int numa_node_id = -1);
};

// Sets number of threads used for inter-op parallelism
CAFFE2_API void set_num_interop_threads(int);

@@ -167,3 +93,7 @@ CAFFE2_API int get_num_interop_threads();
CAFFE2_API void launch(const std::function<void()>& func);

} // namespace at

#if AT_PARALLEL_OPENMP
#include <ATen/ParallelOpenMP.h>
#endif
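For readers of the new interface, a short caller-side sketch (illustrative, not part of the diff): Parallel.h now only declares parallel_for, and the definition is pulled in from the backend header selected at the bottom of the file (ParallelOpenMP.h here). The container and grain size below are made up.

// Minimal sketch of intra-op usage through the declared interface.
#include <ATen/Parallel.h>
#include <vector>

void scale(std::vector<float>& data) {
  // f receives a [begin, end) sub-range; the backend splits the full range
  // across threads once it is at least grain_size elements long.
  at::parallel_for(0, static_cast<int64_t>(data.size()), /* grain_size */ 2048,
      [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; ++i) {
          data[i] *= 2.0f;
        }
      });
}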
56  aten/src/ATen/ParallelCommon.cpp  (new file)
@@ -0,0 +1,56 @@
#include <ATen/Parallel.h>

#include <ATen/Config.h>
#include <ATen/Version.h>

#include <sstream>
#include <thread>

#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif

#ifdef _OPENMP
#include <omp.h>
#endif

namespace at {

namespace {

const char* get_env_var(const char* var_name) {
  const char* value = std::getenv(var_name);
  return value ? value : "[not set]";
}

} // namespace

std::string get_parallel_info() {
  std::ostringstream ss;

  ss << "ATen/Parallel:\n\tat::get_num_threads() : "
     << at::get_num_threads() << std::endl;

  ss << at::get_openmp_version() << std::endl;
#ifdef _OPENMP
  ss << "\tomp_get_max_threads() : " << omp_get_max_threads() << std::endl;
#endif

  ss << at::get_mkl_version() << std::endl;
#ifdef TH_BLAS_MKL
  ss << "\tmkl_get_max_threads() : " << mkl_get_max_threads() << std::endl;
#endif

  ss << at::get_mkldnn_version() << std::endl;

  ss << "std::thread::hardware_concurrency() : "
     << std::thread::hardware_concurrency() << std::endl;

  ss << "Environment variables:" << std::endl;
  ss << "\tOMP_NUM_THREADS : " << get_env_var("OMP_NUM_THREADS") << std::endl;
  ss << "\tMKL_NUM_THREADS : " << get_env_var("MKL_NUM_THREADS") << std::endl;

  return ss.str();
}

} // namespace at
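A small illustrative sketch of consuming the string assembled above; the surrounding function is hypothetical.

// Illustrative only: dump the parallelization settings reported by ATen.
#include <ATen/Parallel.h>
#include <iostream>

void print_parallel_settings() {
  // Reports at::get_num_threads(), the OpenMP/MKL/MKL-DNN versions, and the
  // OMP_NUM_THREADS / MKL_NUM_THREADS environment variables.
  std::cout << at::get_parallel_info() << std::endl;
}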
80  aten/src/ATen/ParallelOpenMP.cpp  (new file)
@@ -0,0 +1,80 @@
#ifdef AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>

#include <atomic>

#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif

namespace at {

namespace {
// Number of threads set by the user
std::atomic<int> num_threads{-1};

} // namespace

void init_num_threads() {
  auto nthreads = num_threads.load();
  if (nthreads > 0) {
    set_num_threads(nthreads);
  } else {
#if defined(_OPENMP) && defined(TH_BLAS_MKL)
    // If we are using MKL and OpenMP, make sure the number of threads match.
    // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
    // size of the OpenMP thread pool, resulting in worse performance (and memory
    // leaks in GCC 5.4)
    omp_set_num_threads(mkl_get_max_threads());
#endif
  }
}

void set_num_threads(int nthreads) {
  AT_CHECK(nthreads > 0, "Expected positive number of threads");
  num_threads.store(nthreads);
#ifdef _OPENMP
  omp_set_num_threads(nthreads);
#endif
#ifdef TH_BLAS_MKL
  mkl_set_num_threads(nthreads);

  // because PyTorch uses OpenMP outside of MKL invocations
  // as well, we want this flag to be false, so that
  // threads aren't destroyed and recreated across every
  // MKL / non-MKL boundary of OpenMP usage
  // See https://github.com/pytorch/pytorch/issues/13757
  mkl_set_dynamic(false);
#endif
}

// Explicitly calling omp_get_max_threads() as the size of the parallel
// region might be different in the new thread;
// Use init_num_threads() during thread initialization to ensure
// consistent size of parallel region in different threads
int get_num_threads() {
#ifdef _OPENMP
  return omp_get_max_threads();
#else
  return 1;
#endif
}

int get_thread_num() {
#ifdef _OPENMP
  return omp_get_thread_num();
#else
  return 0;
#endif
}

bool in_parallel_region() {
#ifdef _OPENMP
  return omp_in_parallel();
#else
  return false;
#endif
}

} // namespace at
#endif
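As the comment above get_num_threads() notes, a freshly created thread does not automatically see the same parallel-region size as the thread that configured it, so a hedged sketch of the intended pattern for user-managed threads looks like this (the thread body is illustrative).

// Illustrative sketch: make a user-created thread observe the same
// intra-op thread count and MKL settings as the main thread.
#include <ATen/Parallel.h>
#include <thread>

void spawn_worker() {
  std::thread worker([]() {
    at::init_num_threads();  // re-apply num_threads / MKL settings in this thread
    // ... run ATen work that uses at::parallel_for here ...
  });
  worker.join();
}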
91  aten/src/ATen/ParallelOpenMP.h  (new file)
@@ -0,0 +1,91 @@
#pragma once
#include <ATen/ATen.h>

#include <cstddef>
#include <exception>

#ifdef _OPENMP
#define INTRA_OP_PARALLEL

#include <omp.h>
#endif

namespace at {

template <class F>
inline void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
  if (begin >= end) {
    return;
  }
#ifdef _OPENMP
  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  std::exception_ptr eptr;
#pragma omp parallel if (!omp_in_parallel() && ((end - begin) >= grain_size))
  {
    int64_t num_threads = omp_get_num_threads();
    int64_t tid = omp_get_thread_num();
    int64_t chunk_size = divup((end - begin), num_threads);
    int64_t begin_tid = begin + tid * chunk_size;
    if (begin_tid < end) {
      try {
        f(begin_tid, std::min(end, chunk_size + begin_tid));
      } catch (...) {
        if (!err_flag.test_and_set()) {
          eptr = std::current_exception();
        }
      }
    }
  }
  if (eptr) {
    std::rethrow_exception(eptr);
  }
#else
  f(begin, end);
#endif
}

template <class scalar_t, class F, class SF>
inline scalar_t parallel_reduce(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const scalar_t ident,
    const F& f,
    const SF& sf) {
  if (begin >= end) {
    return ident;
  } else if (in_parallel_region() || get_num_threads() == 1) {
    return f(begin, end, ident);
  } else {
    const int64_t num_results = divup((end - begin), grain_size);
    std::vector<scalar_t> results(num_results);
    scalar_t* results_data = results.data();
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
#pragma omp parallel for if ((end - begin) >= grain_size)
    for (int64_t id = 0; id < num_results; id++) {
      int64_t i = begin + id * grain_size;
      try {
        results_data[id] = f(i, i + std::min(end - i, grain_size), ident);
      } catch (...) {
        if (!err_flag.test_and_set()) {
          eptr = std::current_exception();
        }
      }
    }
    if (eptr) {
      std::rethrow_exception(eptr);
    }
    scalar_t result = ident;
    for (auto partial_result : results) {
      result = sf(result, partial_result);
    }
    return result;
  }
}

} // namespace at
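A hedged usage sketch of parallel_reduce as defined above: partial results are computed per grain-size chunk starting from ident and then combined with the scalar reduction sf. The data and grain size are made up for illustration.

// Illustrative sketch: parallel sum built on the parallel_reduce defined above.
#include <ATen/Parallel.h>
#include <vector>

float parallel_sum(const std::vector<float>& v) {
  return at::parallel_reduce(
      0, static_cast<int64_t>(v.size()), /* grain_size */ 4096, /* ident */ 0.0f,
      // f: reduce one [begin, end) chunk, starting from ident
      [&](int64_t begin, int64_t end, float partial) {
        for (int64_t i = begin; i < end; ++i) {
          partial += v[i];
        }
        return partial;
      },
      // sf: combine two partial results
      [](float a, float b) { return a + b; });
}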
74  aten/src/ATen/ParallelThreadPoolNative.cpp  (new file)
@@ -0,0 +1,74 @@
#if AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>
#include <ATen/PTThreadPool.h>

#include <atomic>

namespace at {

namespace {
const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of inter-op threads set by the user;
// NOT_SET -> positive value -> CONSUMED
// (CONSUMED - thread pool is initialized)
// or
// NOT_SET -> CONSUMED
std::atomic<int> num_interop_threads{NOT_SET};

// thread pool global instance is hidden,
// users should use at::launch and get/set_num_interop_threads interface
TaskThreadPoolBase& get_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ num_interop_threads.exchange(CONSUMED),
          /* create_new */ true);
  return *pool;
}

// Factory function for ThreadPoolRegistry
std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
    int device_id,
    int pool_size,
    bool create_new) {
  // For now, the only accepted device id is 0
  AT_CHECK(device_id == 0);
  // Create new thread pool
  AT_CHECK(create_new);
  return std::make_shared<PTThreadPool>(pool_size);
}

} // namespace

C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);

void set_num_interop_threads(int nthreads) {
  AT_CHECK(nthreads > 0, "Expected positive number of threads");

  int no_value = NOT_SET;
  AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
      "Error: cannot set number of interop threads after parallel work "
      "has started or set_num_interop_threads called");
}

int get_num_interop_threads() {
  int nthreads = num_interop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    // return default value
    return TaskThreadPoolBase::defaultNumThreads();
  } else {
    return get_pool().size();
  }
}

void launch(const std::function<void()>& func) {
  get_pool().run(func);
}

} // namespace at
#endif
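A hedged sketch of the inter-op entry points implemented above; note that set_num_interop_threads() is only valid before the pool is first used, per the compare_exchange on num_interop_threads. The calling function is hypothetical.

// Illustrative sketch: configure the inter-op pool once, then submit tasks.
#include <ATen/Parallel.h>

void run_async_work() {
  at::set_num_interop_threads(4);  // must run before any at::launch / pool use
  at::launch([]() {
    // task body runs on a PTThreadPool worker thread
  });
}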
@@ -862,6 +862,16 @@ torch_set_target_props(caffe2)
#endif()

target_compile_options(caffe2 PRIVATE "-DCAFFE2_BUILD_MAIN_LIB")

# Parallelism settings
# OPENMP - OpenMP for intra-op, native thread pool for inter-op parallelism
set(PARALLEL_BACKEND "OPENMP" CACHE STRING "ATen parallel backend")
if ("${PARALLEL_BACKEND}" STREQUAL "OPENMP")
  target_compile_definitions(caffe2 PUBLIC "-DAT_PARALLEL_OPENMP=1")
else()
  message(FATAL_ERROR "Unknown parallel backend: ${PARALLEL_BACKEND}")
endif()

if (MSVC AND NOT BUILD_SHARED_LIBS)
  # Note [Supporting both static and dynamic libraries on Windows]
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5  setup.py
@@ -149,6 +149,11 @@
#   LIBRARY_PATH
#   LD_LIBRARY_PATH
#     we will search for libraries in these paths
#
#   PARALLEL_BACKEND
#     parallel backend to use for intra- and inter-op parallelism
#     possible values:
#       OPENMP - use OpenMP for intra-op and native backend for inter-op tasks

from __future__ import print_function
from setuptools import setup, Extension, distutils, find_packages
@@ -226,6 +226,10 @@ def run_cmake(version,
    if mkldnn_threading:
        cmake_defines(cmake_args, MKLDNN_THREADING=mkldnn_threading)

    parallel_backend = os.getenv('PARALLEL_BACKEND')
    if parallel_backend:
        cmake_defines(cmake_args, PARALLEL_BACKEND=parallel_backend)

    if USE_GLOO_IBVERBS:
        cmake_defines(cmake_args, USE_IBVERBS="1", USE_GLOO_IBVERBS="1")