diff --git a/aten/src/ATen/PTThreadPool.h b/aten/src/ATen/PTThreadPool.h new file mode 100644 index 000000000000..f5e8a1a18256 --- /dev/null +++ b/aten/src/ATen/PTThreadPool.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace at { + +class CAFFE2_API PTThreadPool : public c10::ThreadPool { +public: + explicit PTThreadPool( + int pool_size, + int numa_node_id = -1) + : c10::ThreadPool(pool_size, numa_node_id, [](){ + c10::setThreadName("PTThreadPool"); + at::init_num_threads(); + }) {} +}; + +} // namespace at diff --git a/aten/src/ATen/Parallel.cpp b/aten/src/ATen/Parallel.cpp deleted file mode 100644 index 4af4776f0b97..000000000000 --- a/aten/src/ATen/Parallel.cpp +++ /dev/null @@ -1,172 +0,0 @@ -#include - -#include -#include - -#include -#include -#include - -#ifdef TH_BLAS_MKL -#include -#endif - -namespace at { - -namespace { -const int NOT_SET = -1; -const int CONSUMED = -2; - -// Number of threads set by the user -std::atomic num_threads{NOT_SET}; - -// Number of inter-op threads set by the user; -// NOT_SET -> positive value -> CONSUMED -// (CONSUMED - thread pool is initialized) -// or -// NOT_SET -> CONSUMED -std::atomic num_interop_threads{NOT_SET}; - -// thread pool global instance is hidden, -// users should use at::launch and get/set_num_interop_threads interface -TaskThreadPoolBase& get_pool() { - static std::shared_ptr pool = - ThreadPoolRegistry()->Create( - "C10", - /* device_id */ 0, - /* pool_size */ num_interop_threads.exchange(CONSUMED), - /* create_new */ true); - return *pool; -} - - // Factory function for ThreadPoolRegistry -std::shared_ptr create_c10_threadpool( - int device_id, - int pool_size, - bool create_new) { - // For now, the only accepted device id is 0 - AT_CHECK(device_id == 0); - // Create new thread pool - AT_CHECK(create_new); - return std::make_shared(pool_size); -} - -} - -void init_num_threads() { - auto nthreads = num_threads.load(); - if (nthreads > 0) { - set_num_threads(nthreads); - } else { -#if 
defined(_OPENMP) && defined(TH_BLAS_MKL) - // If we are using MKL an OpenMP make sure the number of threads match. - // Otherwise, MKL and our OpenMP-enabled functions will keep changing the - // size of the OpenMP thread pool, resulting in worse performance (and memory - // leaks in GCC 5.4) - omp_set_num_threads(mkl_get_max_threads()); -#endif - } -} - -void set_num_threads(int nthreads) { - AT_CHECK(nthreads > 0, "Expected positive number of threads"); - - num_threads.store(nthreads); -#ifdef _OPENMP - omp_set_num_threads(nthreads); -#endif -#ifdef TH_BLAS_MKL - mkl_set_num_threads(nthreads); - - // because PyTorch uses OpenMP outside of MKL invocations - // as well, we want this flag to be false, so that - // threads aren't destroyed and recreated across every - // MKL / non-MKL boundary of OpenMP usage - // See https://github.com/pytorch/pytorch/issues/13757 - mkl_set_dynamic(false); -#endif -} - -// Explicitly calling omp_get_max_threads() as the size of the parallel -// region might be different in the new thread; -// Use init_num_threads() during thread initialization to ensure -// consistent size of parallel region in different threads -int get_num_threads() { -#ifdef _OPENMP - return omp_get_max_threads(); -#else - return 1; -#endif -} - -namespace { -const char* get_env_var(const char* var_name) { - const char* value = std::getenv(var_name); - return value ? 
value : "[not set]"; -} -} - -std::string get_parallel_info() { - std::ostringstream ss; - - ss << "ATen/Parallel:\n\tat::get_num_threads() : " - << at::get_num_threads() << std::endl; - - ss << at::get_openmp_version() << std::endl; -#ifdef _OPENMP - ss << "\tomp_get_max_threads() : " << omp_get_max_threads() << std::endl; -#endif - - ss << at::get_mkl_version() << std::endl; -#ifdef TH_BLAS_MKL - ss << "\tmkl_get_max_threads() : " << mkl_get_max_threads() << std::endl; -#endif - - ss << at::get_mkldnn_version() << std::endl; - - ss << "std::thread::hardware_concurrency() : " - << std::thread::hardware_concurrency() << std::endl; - - ss << "Environment variables:" << std::endl; - ss << "\tOMP_NUM_THREADS : " << get_env_var("OMP_NUM_THREADS") << std::endl; - ss << "\tMKL_NUM_THREADS : " << get_env_var("MKL_NUM_THREADS") << std::endl; - - return ss.str(); -} - -PTThreadPool::PTThreadPool( - int pool_size, - int numa_node_id) - : c10::ThreadPool(pool_size, numa_node_id, [](){ - c10::setThreadName("PTThreadPool"); - at::init_num_threads(); - }) {} - -C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool); - -void set_num_interop_threads(int nthreads) { - AT_CHECK(nthreads > 0, "Expected positive number of threads"); - - int no_value = NOT_SET; - AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads), - "Error: cannot set number of interop threads after parallel work " - "has started or set_num_interop_threads called"); -} - -int get_num_interop_threads() { - int nthreads = num_interop_threads.load(); - if (nthreads > 0) { - return nthreads; - } else if (nthreads == NOT_SET) { - // return default value - return TaskThreadPoolBase::defaultNumThreads(); - } else { - return get_pool().size(); - } -} - -void launch(const std::function& func) { - get_pool().run(func); -} - -} // namespace at diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 452ae15128ea..4a3d5a006178 100644 --- a/aten/src/ATen/Parallel.h +++ 
b/aten/src/ATen/Parallel.h @@ -1,16 +1,5 @@ #pragma once #include -#include - -#include -#include -#include - -#ifdef _OPENMP -#define INTRA_OP_PARALLEL - -#include -#endif namespace at { namespace internal { @@ -37,56 +26,17 @@ CAFFE2_API int get_num_threads(); // Returns the current thread number (starting from 0) // in the current parallel region, or 0 in the sequential region -inline int get_thread_num() { -#ifdef _OPENMP - return omp_get_thread_num(); -#else - return 0; -#endif -} +CAFFE2_API int get_thread_num(); -inline bool in_parallel_region() { -#ifdef _OPENMP - return omp_in_parallel(); -#else - return false; -#endif -} +// Checks whether the code runs in parallel region +CAFFE2_API bool in_parallel_region(); template inline void parallel_for( const int64_t begin, const int64_t end, const int64_t grain_size, - const F& f) { -#ifdef _OPENMP - std::atomic_flag err_flag = ATOMIC_FLAG_INIT; - std::exception_ptr eptr; -#pragma omp parallel if (!omp_in_parallel() && ((end - begin) >= grain_size)) - { - int64_t num_threads = omp_get_num_threads(); - int64_t tid = omp_get_thread_num(); - int64_t chunk_size = divup((end - begin), num_threads); - int64_t begin_tid = begin + tid * chunk_size; - if (begin_tid < end) { - try { - f(begin_tid, std::min(end, chunk_size + begin_tid)); - } catch (...) 
{ - if (!err_flag.test_and_set()) { - eptr = std::current_exception(); - } - } - } - } - if (eptr) { - std::rethrow_exception(eptr); - } -#else - if (begin < end) { - f(begin, end); - } -#endif -} + const F& f); /* parallel_reduce @@ -127,36 +77,12 @@ inline scalar_t parallel_reduce( const int64_t end, const int64_t grain_size, const scalar_t ident, - const F f, - const SF sf) { - if (in_parallel_region() || get_num_threads() == 1) { - return f(begin, end, ident); - } else { - const int64_t num_results = divup((end - begin), grain_size); - std::vector results(num_results); - scalar_t* results_data = results.data(); -#ifdef _OPENMP -#pragma omp parallel for if ((end - begin) >= grain_size) -#endif - for (int64_t id = 0; id < num_results; id++) { - int64_t i = begin + id * grain_size; - results_data[id] = f(i, i + std::min(end - i, grain_size), ident); - } - return std::accumulate( - results_data, results_data + results.size(), ident, sf); - } -} + const F& f, + const SF& sf); // Returns a detailed string describing parallelization settings CAFFE2_API std::string get_parallel_info(); -class CAFFE2_API PTThreadPool : public c10::ThreadPool { - public: - explicit PTThreadPool( - int pool_size, - int numa_node_id = -1); -}; - // Sets number of threads used for inter-op parallelism CAFFE2_API void set_num_interop_threads(int); @@ -167,3 +93,7 @@ CAFFE2_API int get_num_interop_threads(); CAFFE2_API void launch(const std::function& func); } // namespace at + +#if AT_PARALLEL_OPENMP +#include +#endif diff --git a/aten/src/ATen/ParallelCommon.cpp b/aten/src/ATen/ParallelCommon.cpp new file mode 100644 index 000000000000..fd2f9ddeb5cc --- /dev/null +++ b/aten/src/ATen/ParallelCommon.cpp @@ -0,0 +1,56 @@ +#include + +#include +#include + +#include +#include + +#ifdef TH_BLAS_MKL +#include +#endif + +#ifdef _OPENMP +#include +#endif + +namespace at { + +namespace { + +const char* get_env_var(const char* var_name) { + const char* value = std::getenv(var_name); + return value ? 
value : "[not set]"; +} + +} // namespace + +std::string get_parallel_info() { + std::ostringstream ss; + + ss << "ATen/Parallel:\n\tat::get_num_threads() : " + << at::get_num_threads() << std::endl; + + ss << at::get_openmp_version() << std::endl; +#ifdef _OPENMP + ss << "\tomp_get_max_threads() : " << omp_get_max_threads() << std::endl; +#endif + + ss << at::get_mkl_version() << std::endl; +#ifdef TH_BLAS_MKL + ss << "\tmkl_get_max_threads() : " << mkl_get_max_threads() << std::endl; +#endif + + ss << at::get_mkldnn_version() << std::endl; + + ss << "std::thread::hardware_concurrency() : " + << std::thread::hardware_concurrency() << std::endl; + + ss << "Environment variables:" << std::endl; + ss << "\tOMP_NUM_THREADS : " << get_env_var("OMP_NUM_THREADS") << std::endl; + ss << "\tMKL_NUM_THREADS : " << get_env_var("MKL_NUM_THREADS") << std::endl; + + return ss.str(); +} + +} // namespace at diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp new file mode 100644 index 000000000000..b0255b27d493 --- /dev/null +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -0,0 +1,80 @@ +#if AT_PARALLEL_OPENMP // tested by value, matching the guards in Parallel.h and ParallelThreadPoolNative.cpp +#include + +#include + +#ifdef TH_BLAS_MKL +#include +#endif + +namespace at { + +namespace { +// Number of threads set by the user +std::atomic num_threads{-1}; + +} // namespace + +void init_num_threads() { + auto nthreads = num_threads.load(); + if (nthreads > 0) { + set_num_threads(nthreads); + } else { +#if defined(_OPENMP) && defined(TH_BLAS_MKL) + // If we are using MKL and OpenMP make sure the number of threads match. 
+ // Otherwise, MKL and our OpenMP-enabled functions will keep changing the + // size of the OpenMP thread pool, resulting in worse performance (and memory + // leaks in GCC 5.4) + omp_set_num_threads(mkl_get_max_threads()); +#endif + } +} + +void set_num_threads(int nthreads) { + AT_CHECK(nthreads > 0, "Expected positive number of threads"); + num_threads.store(nthreads); +#ifdef _OPENMP + omp_set_num_threads(nthreads); +#endif +#ifdef TH_BLAS_MKL + mkl_set_num_threads(nthreads); + + // because PyTorch uses OpenMP outside of MKL invocations + // as well, we want this flag to be false, so that + // threads aren't destroyed and recreated across every + // MKL / non-MKL boundary of OpenMP usage + // See https://github.com/pytorch/pytorch/issues/13757 + mkl_set_dynamic(false); +#endif +} + +// Explicitly calling omp_get_max_threads() as the size of the parallel +// region might be different in the new thread; +// Use init_num_threads() during thread initialization to ensure +// consistent size of parallel region in different threads +int get_num_threads() { +#ifdef _OPENMP + return omp_get_max_threads(); +#else + return 1; +#endif +} + +int get_thread_num() { +#ifdef _OPENMP + return omp_get_thread_num(); +#else + return 0; +#endif +} + +bool in_parallel_region() { +#ifdef _OPENMP + return omp_in_parallel(); +#else + return false; +#endif +} + +} // namespace at +#endif diff --git a/aten/src/ATen/ParallelOpenMP.h b/aten/src/ATen/ParallelOpenMP.h new file mode 100644 index 000000000000..1d9918ed6c2b --- /dev/null +++ b/aten/src/ATen/ParallelOpenMP.h @@ -0,0 +1,91 @@ +#pragma once +#include + +#include +#include + +#ifdef _OPENMP +#define INTRA_OP_PARALLEL + +#include +#endif + +namespace at { + +template +inline void parallel_for( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f) { + if (begin >= end) { + return; + } +#ifdef _OPENMP + std::atomic_flag err_flag = ATOMIC_FLAG_INIT; + std::exception_ptr eptr; +#pragma omp parallel if 
(!omp_in_parallel() && ((end - begin) >= grain_size)) + { + int64_t num_threads = omp_get_num_threads(); + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = divup((end - begin), num_threads); + int64_t begin_tid = begin + tid * chunk_size; + if (begin_tid < end) { + try { + f(begin_tid, std::min(end, chunk_size + begin_tid)); + } catch (...) { + if (!err_flag.test_and_set()) { + eptr = std::current_exception(); + } + } + } + } + if (eptr) { + std::rethrow_exception(eptr); + } +#else + f(begin, end); +#endif +} + +template +inline scalar_t parallel_reduce( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const scalar_t ident, + const F& f, + const SF& sf) { + if (begin >= end) { + return ident; + } else if (in_parallel_region() || get_num_threads() == 1) { + return f(begin, end, ident); + } else { + const int64_t num_results = divup((end - begin), grain_size); + std::vector results(num_results); + scalar_t* results_data = results.data(); + std::atomic_flag err_flag = ATOMIC_FLAG_INIT; + std::exception_ptr eptr; +#pragma omp parallel for if ((end - begin) >= grain_size) + for (int64_t id = 0; id < num_results; id++) { + int64_t i = begin + id * grain_size; + try { + results_data[id] = f(i, i + std::min(end - i, grain_size), ident); + } catch (...) 
{ + if (!err_flag.test_and_set()) { + eptr = std::current_exception(); + } + } + } + if (eptr) { + std::rethrow_exception(eptr); + } + scalar_t result = ident; + for (auto partial_result : results) { + result = sf(result, partial_result); + } + return result; + } +} + +} // namespace at diff --git a/aten/src/ATen/ParallelThreadPoolNative.cpp b/aten/src/ATen/ParallelThreadPoolNative.cpp new file mode 100644 index 000000000000..6387e49c2ea7 --- /dev/null +++ b/aten/src/ATen/ParallelThreadPoolNative.cpp @@ -0,0 +1,74 @@ +#if AT_PARALLEL_OPENMP +#include +#include + +#include + +namespace at { + +namespace { +const int NOT_SET = -1; +const int CONSUMED = -2; + +// Number of inter-op threads set by the user; +// NOT_SET -> positive value -> CONSUMED +// (CONSUMED - thread pool is initialized) +// or +// NOT_SET -> CONSUMED +std::atomic num_interop_threads{NOT_SET}; + +// thread pool global instance is hidden, +// users should use at::launch and get/set_num_interop_threads interface +TaskThreadPoolBase& get_pool() { + static std::shared_ptr pool = + ThreadPoolRegistry()->Create( + "C10", + /* device_id */ 0, + /* pool_size */ num_interop_threads.exchange(CONSUMED), + /* create_new */ true); + return *pool; +} + +// Factory function for ThreadPoolRegistry +std::shared_ptr create_c10_threadpool( + int device_id, + int pool_size, + bool create_new) { + // For now, the only accepted device id is 0 + AT_CHECK(device_id == 0); + // Create new thread pool + AT_CHECK(create_new); + return std::make_shared(pool_size); +} + +} // namespace + +C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool); + +void set_num_interop_threads(int nthreads) { + AT_CHECK(nthreads > 0, "Expected positive number of threads"); + + int no_value = NOT_SET; + AT_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads), + "Error: cannot set number of interop threads after parallel work " + "has started or set_num_interop_threads called"); +} + +int get_num_interop_threads() { 
+ int nthreads = num_interop_threads.load(); + if (nthreads > 0) { + return nthreads; + } else if (nthreads == NOT_SET) { + // return default value + return TaskThreadPoolBase::defaultNumThreads(); + } else { + return get_pool().size(); + } +} + +void launch(const std::function& func) { + get_pool().run(func); +} + +} // namespace at +#endif diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 96c2ed92b81a..4774d9abb0c4 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -862,6 +862,16 @@ torch_set_target_props(caffe2) #endif() target_compile_options(caffe2 PRIVATE "-DCAFFE2_BUILD_MAIN_LIB") + +# Parallelism settings +# OPENMP - OpenMP for intra-op, native thread pool for inter-op parallelism +set(PARALLEL_BACKEND "OPENMP" CACHE STRING "ATen parallel backend") +if ("${PARALLEL_BACKEND}" STREQUAL "OPENMP") + target_compile_definitions(caffe2 PUBLIC "-DAT_PARALLEL_OPENMP=1") +else() + message(FATAL_ERROR "Unknown parallel backend: ${PARALLEL_BACKEND}") +endif() + if (MSVC AND NOT BUILD_SHARED_LIBS) # Note [Supporting both static and dynamic libraries on Windows] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/setup.py b/setup.py index 8c0a5cb56858..e82e1e3f1c23 100644 --- a/setup.py +++ b/setup.py @@ -149,6 +149,11 @@ # LIBRARY_PATH # LD_LIBRARY_PATH # we will search for libraries in these paths +# +# PARALLEL_BACKEND +# parallel backend to use for intra- and inter-op parallelism +# possible values: +# OPENMP - use OpenMP for intra-op and native backend for inter-op tasks from __future__ import print_function from setuptools import setup, Extension, distutils, find_packages diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 4d21610bb41e..a53fd0649ac8 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -226,6 +226,10 @@ def run_cmake(version, if mkldnn_threading: cmake_defines(cmake_args, MKLDNN_THREADING=mkldnn_threading) + parallel_backend = 
os.getenv('PARALLEL_BACKEND') + if parallel_backend: + cmake_defines(cmake_args, PARALLEL_BACKEND=parallel_backend) + if USE_GLOO_IBVERBS: cmake_defines(cmake_args, USE_IBVERBS="1", USE_GLOO_IBVERBS="1")