diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh index 5429cadd52e7..b530521f7f2d 100644 --- a/.circleci/scripts/binary_ios_upload.sh +++ b/.circleci/scripts/binary_ios_upload.sh @@ -14,7 +14,7 @@ mkdir -p ${ZIP_DIR}/src cp -R ${ARTIFACTS_DIR}/arm64/include ${ZIP_DIR}/install/ # build a FAT binary cd ${ZIP_DIR}/install/lib -target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a) +target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpthreadpool.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a) for lib in ${target_libs[*]} do if [ -f "${ARTIFACTS_DIR}/x86_64/lib/${lib}" ] && [ -f "${ARTIFACTS_DIR}/arm64/lib/${lib}" ]; then diff --git a/BUILD.bazel b/BUILD.bazel index 5606719b175d..4095bcbec332 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1350,7 +1350,6 @@ filegroup( "caffe2/utils/smart_tensor_printer.cc", "caffe2/utils/string_utils.cc", "caffe2/utils/threadpool/ThreadPool.cc", - "caffe2/utils/threadpool/ThreadPoolMobile.cc", "caffe2/utils/threadpool/pthreadpool.cc", "caffe2/utils/threadpool/pthreadpool_impl.cc", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 140a192339e0..50f765744054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -481,7 +481,7 @@ if(USE_PYTORCH_QNNPACK) endif() if(USE_XNNPACK) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK") endif() if(USE_VULKAN) diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt index 8705fe54420e..f81b1bf05527 100644 --- a/android/pytorch_android/CMakeLists.txt +++ b/android/pytorch_android/CMakeLists.txt @@ -99,6 +99,7 @@ if(ANDROID_ABI) import_static_lib(libnnpack) import_static_lib(libXNNPACK) import_static_lib(libpytorch_qnnpack) + import_static_lib(libpthreadpool) import_static_lib(libeigen_blas) import_static_lib(libcpuinfo) import_static_lib(libclog) @@ -115,6 +116,7 @@ if(ANDROID_ABI) libnnpack libXNNPACK libpytorch_qnnpack + libpthreadpool libeigen_blas libcpuinfo libclog @@ -129,6 +131,7 @@ else() nnpack XNNPACK pytorch_qnnpack + pthreadpool cpuinfo clog ) diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp index 6451921d0239..f297aad8f5f5 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp @@ -8,8 +8,10 @@ #include "pytorch_jni_common.h" #if defined(__ANDROID__) -#include -#include +#ifndef USE_PTHREADPOOL +#define USE_PTHREADPOOL +#endif /* USE_PTHREADPOOL */ +#include <caffe2/utils/threadpool/pthreadpool-cpp.h> #endif namespace pytorch_jni { @@ -605,7 +607,7 @@ class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> { } static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) { - caffe2::mobile_threadpool()->setNumThreads(numThreads); + caffe2::pthreadpool()->set_thread_count(numThreads); } }; #endif diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index d5ea017e1496..1f26bbf15d3a 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -6,8 +6,7 @@ #ifndef C10_MOBILE #include #else -#include -#include +#include <caffe2/utils/threadpool/pthreadpool-cpp.h> #endif // C10_MOBILE #include @@ -88,15 +87,15 @@ void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) { // Run the first task on the current thread directly.
fn(0, 0); #else - caffe2::ThreadPool* pool = caffe2::mobile_threadpool(); - if (pool) { - // caffe2::ThreadPool can utilize the current thread. - pool->run(fn, range); - } else { - for (size_t i = 0; i < range; ++i) { - fn(0, i); - } - } + caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); + + pool->run( + // PThreadPool::run() is blocking. A std::function [const] reference to + // this lambda cannot go out of scope before PThreadPool::run() returns. + [&fn](const size_t task_id) { + fn(0 /* unused */, task_id); + }, range); #endif // C10_MOBILE } @@ -184,7 +183,7 @@ void init_num_threads() { #endif #ifdef C10_MOBILE - caffe2::mobile_threadpool(); + caffe2::pthreadpool(); #endif } @@ -208,7 +207,9 @@ void set_num_threads(int nthreads) { } } #else - TORCH_CHECK(false, "set_num_threads is not supported for mobile."); + caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); + pool->set_thread_count(nthreads); #endif // C10_MOBILE } @@ -226,9 +227,9 @@ int get_num_threads() { return _get_intraop_pool().size() + 1; } #else - caffe2::ThreadPool* pool = caffe2::mobile_threadpool(); - // caffe2::ThreadPool::getNumThreads() counts the current thread. - return !pool || in_parallel_region() ? 1 /* current thread */ : pool->getNumThreads(); + caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!") + return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count(); #endif // C10_MOBILE } @@ -257,8 +258,8 @@ void intraop_launch(std::function func) { func(); } #else - // TODO: caffe2::ThreadPool doesn't support submitting tasks separately and - // running in parallel. Should fix it when this API becomes popular. + // TODO: caffe2::PThreadPool only provides a data-parallel API. + // Task parallelism is not currently supported. func(); #endif // C10_MOBILE } @@ -280,8 +281,8 @@ std::shared_ptr intraop_launch_future( } return future; #else - // TODO: caffe2::ThreadPool doesn't support submitting tasks separately and - // running in parallel. Should fix it when this API becomes popular. + // TODO: caffe2::PThreadPool only provides a data-parallel API. + // Task parallelism is not currently supported. auto future = std::make_shared(NoneType::get()); func(); future->markCompleted(); diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index 31453e81aeb4..b9f86a285289 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -58,7 +58,7 @@ bool _nnpack_available() { #include -#include +#include #include namespace at { @@ -87,35 +87,7 @@ static bool init_nnpack() { } static pthreadpool_t nnpack_threadpool() { - // Try initializing a threadpool for NNPACK's use. If we fail to - // successfully initialize an implementation, return nullptr which will - // instruct NNPACK to run single threaded. - -#ifdef C10_MOBILE - // If building for mobile, use Caffe 2's mobile-friendly threadpool. - return caffe2::mobile_pthreadpool(); -#else - // Otherwise, try using pthreadpool if we manage to initialize it successfully. 
- static pthreadpool_t nnpack_threadpool_ = nullptr; - static bool called_nnpack_threadpool_ = false; - - if (!called_nnpack_threadpool_) { - called_nnpack_threadpool_ = true; - -#ifdef INTRA_OP_PARALLEL - const uint32_t threads = at::get_num_threads(); -#else - const uint32_t threads = std::thread::hardware_concurrency(); -#endif - - nnpack_threadpool_ = pthreadpool_create(threads); - if (!nnpack_threadpool_) { - LOG(WARNING) << "Failed to initialize pthreadpool! Running NNPACK in single-threaded mode."; - } - } - - return nnpack_threadpool_; -#endif + return caffe2::pthreadpool_(); } bool _nnpack_available() { diff --git a/aten/src/ATen/native/quantized/cpu/q_avgpool.cpp b/aten/src/ATen/native/quantized/cpu/q_avgpool.cpp index d1b0f0c2eab3..2aede06b71eb 100644 --- a/aten/src/ATen/native/quantized/cpu/q_avgpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/q_avgpool.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include @@ -375,7 +375,7 @@ Tensor qnnpack_avg_pool2d( CAFFE_ENFORCE( setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Average Pooling operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( diff --git a/aten/src/ATen/native/quantized/cpu/q_avgpool3d.cpp b/aten/src/ATen/native/quantized/cpu/q_avgpool3d.cpp index 062b2feab723..8c7a34b633d5 100644 --- a/aten/src/ATen/native/quantized/cpu/q_avgpool3d.cpp +++ b/aten/src/ATen/native/quantized/cpu/q_avgpool3d.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp index 85a861d87353..0248e2852a77 100644 --- a/aten/src/ATen/native/quantized/cpu/qadd.cpp +++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include @@ -194,7 +194,7 @@ Tensor qnnpack_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) { setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Add operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(qnnpack_operator, threadpool); diff --git a/aten/src/ATen/native/quantized/cpu/qchannel_shuffle.cpp b/aten/src/ATen/native/quantized/cpu/qchannel_shuffle.cpp index 6ae0b1e5da37..7ec75e3593a7 100644 --- a/aten/src/ATen/native/quantized/cpu/qchannel_shuffle.cpp +++ b/aten/src/ATen/native/quantized/cpu/qchannel_shuffle.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include @@ -82,7 +82,7 @@ Tensor quantized_channel_shuffle_impl( setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK ChannelShuffle operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( diff --git a/aten/src/ATen/native/quantized/cpu/qclamp.cpp b/aten/src/ATen/native/quantized/cpu/qclamp.cpp index 481af9d56104..9b1a45257f37 100644 --- a/aten/src/ATen/native/quantized/cpu/qclamp.cpp +++ b/aten/src/ATen/native/quantized/cpu/qclamp.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include @@ -64,7 +64,7 @@ Tensor qnnpack_clamp(Tensor input, Scalar min, Scalar max) { 
TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Clamp operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(clamp_op, threadpool); diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 0aaace2731c2..157f10f32a33 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include template bool ConvDimChecks( @@ -603,7 +603,7 @@ at::Tensor PackedConvWeightsQnnp::apply_impl( output_min, output_max, reinterpret_cast(output.template data_ptr()), - caffe2::mobile_pthreadpool()); + caffe2::pthreadpool_()); TORCH_INTERNAL_ASSERT( run_status == pytorch_qnnp_status_success, diff --git a/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp index 7cbf5acbe80d..525d32c6bcf8 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include @@ -57,7 +57,7 @@ Tensor qnnpack_hardsigmoid(Tensor input) { TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Hardsigmoid operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(hardsigmoid_op, threadpool); diff --git a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp index f7cecb817bec..f0dbd644b2be 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include @@ -51,7 +51,7 @@ Tensor qnnpack_hardswish(const Tensor& qx, Tensor& qy) { TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Hardswish operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(hardswish_op, threadpool); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index bf630e747722..8ad2776009d0 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include @@ -341,7 +341,9 @@ at::Tensor PackedLinearWeightsQnnp::apply_impl( packB->getPackedWeights(), (uint8_t*)output.data_ptr(), rows_w /* output_stride */, - caffe2::mobile_pthreadpool() /* threadpool */); + // TODO (Ashkan): Disabling temporarily. + // Throws a floating point exception with OSS pthreadpool. 
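+      // Until then, a null threadpool is passed; QNNPACK treats a null + // threadpool as a request to run the operator single-threaded on the + // calling thread.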
+ nullptr); TORCH_INTERNAL_ASSERT( runStatus == pytorch_qnnp_status_success, diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 98ac72122511..d9bd095d6066 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include @@ -327,7 +327,7 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) { bias_ptr, output.data_ptr(), rows_w /* output_stride */, - caffe2::mobile_pthreadpool() /* threadpool */); + caffe2::pthreadpool_() /* threadpool */); TORCH_INTERNAL_ASSERT( runStatus == pytorch_qnnp_status_success, diff --git a/aten/src/ATen/native/quantized/cpu/qpool.cpp b/aten/src/ATen/native/quantized/cpu/qpool.cpp index 6c0bf0010376..f986ab4934b9 100644 --- a/aten/src/ATen/native/quantized/cpu/qpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/qpool.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include @@ -346,7 +346,7 @@ void check_maxpool2d_params( setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK MaxPool operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( diff --git a/aten/src/ATen/native/quantized/cpu/qreduction.cpp b/aten/src/ATen/native/quantized/cpu/qreduction.cpp index 2a82499143fe..f1f262d07be4 100644 --- a/aten/src/ATen/native/quantized/cpu/qreduction.cpp +++ b/aten/src/ATen/native/quantized/cpu/qreduction.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include namespace at { namespace native { @@ -66,7 +66,7 @@ Tensor qnnpack_mean(const Tensor& input, IntArrayRef dim) { CAFFE_ENFORCE( setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Global Average Pooling operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp index d43b59973c2e..cbd934b0b23b 100644 --- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include @@ -69,7 +69,7 @@ Tensor qnnpack_relu(Tensor input) { setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Relu operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(qnnpack_operator, threadpool); diff --git a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp index 5147086cc81f..5d869c31665e 100644 --- a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include @@ -66,7 +66,7 @@ Tensor qnnpack_sigmoid(Tensor input) { TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK sigmoid operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status 
runStatus = pytorch_qnnp_run_operator(sigmoid_op, threadpool); diff --git a/aten/src/ATen/native/quantized/cpu/qtanh.cpp b/aten/src/ATen/native/quantized/cpu/qtanh.cpp index 43c5261d4def..d2ccc143aa34 100644 --- a/aten/src/ATen/native/quantized/cpu/qtanh.cpp +++ b/aten/src/ATen/native/quantized/cpu/qtanh.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include @@ -64,7 +64,7 @@ Tensor qnnpack_tanh(Tensor input) { TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK TanH operator"); - pthreadpool_t threadpool = caffe2::mobile_pthreadpool(); + pthreadpool_t threadpool = caffe2::pthreadpool_(); const pytorch_qnnp_status runStatus = pytorch_qnnp_run_operator(tanh_op, threadpool); diff --git a/aten/src/ATen/native/xnnpack/Common.h b/aten/src/ATen/native/xnnpack/Common.h index 658b609a75d9..9ce410922e4f 100644 --- a/aten/src/ATen/native/xnnpack/Common.h +++ b/aten/src/ATen/native/xnnpack/Common.h @@ -5,7 +5,7 @@ #ifdef USE_XNNPACK #include -#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp index f8c3a51ac3b4..61a7894e639a 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.cpp +++ b/aten/src/ATen/native/xnnpack/Convolution.cpp @@ -208,15 +208,15 @@ Tensor run( padded_input_nhwc.size(Layout::Activation4D::width), // input_width padded_input_nhwc.data_ptr(), // input output.data_ptr(), // output - caffe2::xnnpack_threadpool()); // threadpool + caffe2::pthreadpool_()); // threadpool TORCH_CHECK( xnn_status_success == setup_status, "xnn_setup_convolution2d_nhwc_f32 failed!"); const xnn_status run_status = xnn_run_operator( - context.op.get(), // operator - caffe2::xnnpack_threadpool()); // threadpool + context.op.get(), // operator + caffe2::pthreadpool_()); // threadpool TORCH_INTERNAL_ASSERT( xnn_status_success == run_status, diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp index 89d3c4e85642..08f8efe40be4 100644 --- a/aten/src/ATen/native/xnnpack/Linear.cpp +++ b/aten/src/ATen/native/xnnpack/Linear.cpp @@ -137,15 +137,15 @@ Tensor run( Layout::ActivationND::batch(padded_input.sizes()), // Batch, padded_input.data_ptr(), // input output.data_ptr(), // output - caffe2::xnnpack_threadpool()); // threadpool + caffe2::pthreadpool_()); // threadpool TORCH_CHECK( xnn_status_success == setup_status, "xnn_setup_fully_connected_nc_f32 failed!"); const xnn_status run_status = xnn_run_operator( - context.op.get(), // operator - caffe2::xnnpack_threadpool()); // threadpool + context.op.get(), // operator + caffe2::pthreadpool_()); // threadpool TORCH_INTERNAL_ASSERT( xnn_status_success == run_status, diff --git a/aten/src/ATen/native/xnnpack/MaxPooling.cpp b/aten/src/ATen/native/xnnpack/MaxPooling.cpp index 088e90078509..693cd2c4f111 100644 --- a/aten/src/ATen/native/xnnpack/MaxPooling.cpp +++ b/aten/src/ATen/native/xnnpack/MaxPooling.cpp @@ -219,15 +219,15 @@ Tensor max_pool2d( input_padded_contig_nhwc.size(Layout::Activation4D::width), // input_width input_padded_contig_nhwc.data_ptr(), // input output_padded_contig_nhwc.data_ptr(), // output - caffe2::xnnpack_threadpool()); // threadpool + caffe2::pthreadpool_()); // threadpool TORCH_CHECK( xnn_status_success == setup_status, "xnn_setup_max_pooling2d_nhwc_f32 failed!"); const xnn_status run_status = xnn_run_operator( - max_pool_op, // operator - caffe2::xnnpack_threadpool()); // threadpool + max_pool_op, // operator + caffe2::pthreadpool_()); // 
threadpool TORCH_INTERNAL_ASSERT( xnn_status_success == run_status, diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index fa685376f7af..584d56e6a40e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -87,7 +87,6 @@ endif() # Note: the folders that are being commented out have not been properly # addressed yet. -# For pthreadpool_new_if_impl. TODO: Remove when threadpools are unitied. if(NOT MSVC AND USE_XNNPACK) if(NOT TARGET fxdiv) set(FXDIV_BUILD_TESTS OFF CACHE BOOL "") @@ -96,10 +95,6 @@ if(NOT MSVC AND USE_XNNPACK) "${FXDIV_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/FXdiv") endif() - if(NOT (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE)) - set_source_files_properties( - utils/threadpool/pthreadpool_new_if_impl.c PROPERTIES COMPILE_FLAGS -fno-openmp) - endif() endif() add_subdirectory(core) diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index 2925f4145b0e..19b0b430e725 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -1,15 +1,8 @@ -# TODO: Add ThreadPoolXNNPACK.cc when XNNPACK integration is updated -# to pass the actual threadpool ptr instead of nullptr. if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) - add_definitions(-DUSE_INTERNAL_THREADPOOL_IMPL) list(APPEND Caffe2_CPU_SRCS utils/string_utils.cc - utils/threadpool/pthreadpool.cc - utils/threadpool/pthreadpool_impl.cc - utils/threadpool/pthreadpool_new_if_impl.c + utils/threadpool/pthreadpool-cpp.cc utils/threadpool/ThreadPool.cc - utils/threadpool/ThreadPoolMobile.cc - utils/threadpool/ThreadPoolXNNPACK.cc ) set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) return() @@ -28,23 +21,19 @@ list(APPEND Caffe2_CPU_SRCS utils/proto_convert.cc utils/proto_utils.cc utils/proto_wrap.cc + utils/threadpool/ThreadPool.cc utils/signal_handler.cc utils/smart_tensor_printer.cc - utils/string_utils.cc - utils/threadpool/ThreadPool.cc) + utils/string_utils.cc) -# ---[ threadpool/pthreadpool* is a local modification of the NNPACK -# pthreadpool with a very similar interface. Neither NNPACK, nor this -# thread pool supports Windows. 
-if(NOT MSVC AND USE_XNNPACK) - add_definitions(-DUSE_INTERNAL_THREADPOOL_IMPL) - set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} - utils/threadpool/pthreadpool.cc - utils/threadpool/pthreadpool_impl.cc - utils/threadpool/pthreadpool_new_if_impl.c - utils/threadpool/ThreadPoolMobile.cc - utils/threadpool/ThreadPoolXNNPACK.cc - ) +if(USE_PTHREADPOOL) + list(APPEND Caffe2_CPU_SRCS + utils/threadpool/pthreadpool-cpp.cc) + if(USE_INTERNAL_PTHREADPOOL_IMPL) + list(APPEND Caffe2_CPU_SRCS + utils/threadpool/pthreadpool.cc + utils/threadpool/pthreadpool_impl.cc) + endif() endif() set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} diff --git a/caffe2/utils/threadpool/ThreadPoolMobile.cc b/caffe2/utils/threadpool/ThreadPoolMobile.cc deleted file mode 100644 index 76b312bdac83..000000000000 --- a/caffe2/utils/threadpool/ThreadPoolMobile.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include -#include -#include - -namespace caffe2 { - -caffe2::ThreadPool* mobile_threadpool() { -#ifdef C10_MOBILE - static std::unique_ptr thread_pool = - caffe2::ThreadPool::defaultThreadPool(); - return thread_pool.get(); -#else - return nullptr; -#endif -} - -pthreadpool_t mobile_pthreadpool() { - return reinterpret_cast(mobile_threadpool()); -} - -} // namespace caffe2 diff --git a/caffe2/utils/threadpool/ThreadPoolMobile.h b/caffe2/utils/threadpool/ThreadPoolMobile.h deleted file mode 100644 index 12b46067ebde..000000000000 --- a/caffe2/utils/threadpool/ThreadPoolMobile.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once -#include - -// TODO Implement a parallel_for version for Mobile here, add to Aten/Parallel.h - -namespace caffe2 { - -class ThreadPool; - -// Return a singleton instance of caffe2::ThreadPool for ATen/TH multithreading. -ThreadPool* mobile_threadpool(); - -// NOTE: This interface is temporary and should not be used. -// Please use Aten/Parallel.h for parallel primitives in pytorch. -// This implementation will be used by pytorch mobile, specifically -// NNPACK/QNNPACK. For mobile we need to use caffe2::ThreadPool instead of the -// 3rd party pthreadpool. Future work (TODO) Implement a mobile version of -// "at::parallel_for" using caffe2::ThreadPool so all ATen/TH multithreading -// usage is mobile friendly; Refactor QNNPACK or pthreadpool to explicitly using -// "at::parallel_for" primitive to replace pthreadpool_compute_1d for Pytorch; -pthreadpool_t mobile_pthreadpool(); - -size_t getDefaultNumThreads(); -} // namespace caffe2 diff --git a/caffe2/utils/threadpool/ThreadPoolXNNPACK.cc b/caffe2/utils/threadpool/ThreadPoolXNNPACK.cc deleted file mode 100644 index 6194165849a6..000000000000 --- a/caffe2/utils/threadpool/ThreadPoolXNNPACK.cc +++ /dev/null @@ -1,22 +0,0 @@ -#include -#include -#include -#include - -namespace caffe2 { - -// Will be unified. -pthreadpool_t xnnpack_threadpool() { -// Depending on internal implemenation vs. OSS we will link against pthreadpool_create_xnnpack -// or pthreadpool_create. This is only temporary. It will be unified soon. 
-#ifdef USE_INTERNAL_THREADPOOL_IMPL - static std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy_xnnpack)> - threadpool(pthreadpool_create_xnnpack(getDefaultNumThreads()), pthreadpool_destroy_xnnpack); -#else - static std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> - threadpool(pthreadpool_create(getDefaultNumThreads()), pthreadpool_destroy); -#endif - return threadpool.get(); -} - -} // namespace caffe2 diff --git a/caffe2/utils/threadpool/ThreadPoolXNNPACK.h b/caffe2/utils/threadpool/ThreadPoolXNNPACK.h deleted file mode 100644 index e6dc9495a5de..000000000000 --- a/caffe2/utils/threadpool/ThreadPoolXNNPACK.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once -// Creating a separate .h/.cc file for creating threadpool for XNNPACK -// to avoid touching existing internal builds. -// When we unify threadpools this should all go away. -namespace caffe2 { -pthreadpool_t xnnpack_threadpool(); -} // namespace caffe2 diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.cc b/caffe2/utils/threadpool/pthreadpool-cpp.cc new file mode 100644 index 000000000000..55a6ee462311 --- /dev/null +++ b/caffe2/utils/threadpool/pthreadpool-cpp.cc @@ -0,0 +1,71 @@ +#include <caffe2/utils/threadpool/pthreadpool-cpp.h> +#include <c10/util/Exception.h> + +namespace caffe2 { + +PThreadPool::PThreadPool(const size_t thread_count) + : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {} + +size_t PThreadPool::get_thread_count() const { + std::lock_guard<std::mutex> lock{mutex_}; + + TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!"); + return pthreadpool_get_threads_count(threadpool_.get()); +} + +void PThreadPool::set_thread_count(const size_t thread_count) { + std::lock_guard<std::mutex> lock{mutex_}; + + // As it stands, pthreadpool is an entirely data-parallel framework with no + // support for task parallelism. Hence, all functions are blocking, and no + // user-provided tasks can be in flight when control is returned to the + // user of the API, which means re-initializing the library, without the + // need to wait on any pending tasks, is all one needs to do to re-adjust + // the thread count. + threadpool_.reset(pthreadpool_create(thread_count)); +} + +void PThreadPool::run( + const std::function<void(size_t)>& fn, + const size_t range) { + std::lock_guard<std::mutex> lock{mutex_}; + + TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!"); + + struct Context final { + const std::function<void(size_t)>& fn; + } context{ + fn, + }; + + pthreadpool_parallelize_1d( + threadpool_.get(), + // Note: pthreadpool_parallelize_1d() is a blocking function. The + // stack-allocated context, and the std::function reference it holds, + // cannot go out of scope until pthreadpool_parallelize_1d() returns.
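+    // The C API accepts only a plain function pointer, so the callable is + // marshalled through the void* context argument below.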
+      [](void* const context, const size_t item) { + reinterpret_cast<Context*>(context)->fn(item); + }, + &context, + range, + 0u); +} + +// Forward declaration +size_t getDefaultNumThreads(); + +PThreadPool* pthreadpool() { + static std::unique_ptr<PThreadPool> threadpool = + std::make_unique<PThreadPool>(getDefaultNumThreads()); + return threadpool.get(); +} + +pthreadpool_t pthreadpool_() { + PThreadPool* const threadpool = pthreadpool(); + TORCH_INTERNAL_ASSERT( + threadpool, "Failed to acquire an instance of PThreadPool!"); + return threadpool->threadpool_.get(); +} + +} // namespace caffe2 diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.h b/caffe2/utils/threadpool/pthreadpool-cpp.h new file mode 100644 index 000000000000..99acff4df027 --- /dev/null +++ b/caffe2/utils/threadpool/pthreadpool-cpp.h @@ -0,0 +1,54 @@ +#pragma once + +#ifdef USE_PTHREADPOOL + +#ifdef USE_INTERNAL_PTHREADPOOL_IMPL +#include <caffe2/utils/threadpool/pthreadpool.h> +#else +#include <pthreadpool.h> +#endif + +#include <functional> +#include <memory> +#include <mutex> + +namespace caffe2 { + +class PThreadPool final { + public: + explicit PThreadPool(size_t thread_count); + ~PThreadPool() = default; + + PThreadPool(const PThreadPool&) = delete; + PThreadPool& operator=(const PThreadPool&) = delete; + + PThreadPool(PThreadPool&&) = delete; + PThreadPool& operator=(PThreadPool&&) = delete; + + size_t get_thread_count() const; + void set_thread_count(size_t thread_count); + + // Run, in parallel, function fn(task_id) over task_id in range [0, range). + // This function is blocking. All input is processed by the time it returns. + void run(const std::function<void(size_t)>& fn, size_t range); + + private: + friend pthreadpool_t pthreadpool_(); + + private: + mutable std::mutex mutex_; + std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_; +}; + +// Return a singleton instance of PThreadPool for ATen/TH multithreading. +PThreadPool* pthreadpool(); + +// Exposes the underlying implementation of PThreadPool. +// Only for use in external libraries so as to unify threading across +// internal (i.e. ATen, etc.) and external (e.g. NNPACK, QNNPACK, XNNPACK) +// use cases.
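+// The returned handle is owned by the PThreadPool singleton and is +// invalidated by any subsequent set_thread_count() call; callers must +// not destroy it.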
+pthreadpool_t pthreadpool_(); + +} // namespace caffe2 + +#endif /* USE_PTHREADPOOL */ diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc index d9d6c1583616..ac633d271a36 100644 --- a/caffe2/utils/threadpool/pthreadpool.cc +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -32,7 +32,7 @@ static inline size_t min(size_t a, size_t b) { } struct compute_1d_tiled_context { - pthreadpool_function_1d_tiled_t function; + legacy_pthreadpool_function_1d_tiled_t function; void* argument; size_t range; size_t tile; @@ -46,9 +46,9 @@ static void compute_1d_tiled(void* context_, size_t linear_index) { context->function(context->argument, index, tile); } -void pthreadpool_compute_1d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_1d_tiled_t function, +void legacy_pthreadpool_compute_1d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_tiled_t function, void* argument, size_t range, size_t tile) @@ -65,12 +65,12 @@ void pthreadpool_compute_1d_tiled( /*.argument = */ argument, /*.range = */ range, /*.tile = */ tile}; - pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range); + legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range); } } struct compute_2d_context { - pthreadpool_function_2d_t function; + legacy_pthreadpool_function_2d_t function; void* argument; caffe2::FixedDivisor range_j; }; @@ -85,9 +85,9 @@ static void compute_2d(void* context_, size_t linear_index) { context->function(context->argument, q, r); } -void pthreadpool_compute_2d( - struct pthreadpool* threadpool, - pthreadpool_function_2d_t function, +void legacy_pthreadpool_compute_2d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_t function, void* argument, size_t range_i, size_t range_j) @@ -106,12 +106,12 @@ void pthreadpool_compute_2d( /*.function = */ function, /*.argument = */ argument, /*.range_j = */ caffe2::FixedDivisor(range_j)}; - pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j); + legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j); } } struct compute_2d_tiled_context { - pthreadpool_function_2d_tiled_t function; + legacy_pthreadpool_function_2d_tiled_t function; void* argument; caffe2::FixedDivisor tile_range_j; size_t range_i; @@ -135,9 +135,9 @@ static void compute_2d_tiled(void* context_, size_t linear_index) { context->function(context->argument, index_i, index_j, tile_i, tile_j); } -void pthreadpool_compute_2d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_2d_tiled_t function, +void legacy_pthreadpool_compute_2d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_tiled_t function, void* argument, size_t range_i, size_t range_j, @@ -166,12 +166,12 @@ void pthreadpool_compute_2d_tiled( /*.range_j = */ range_j, /*.tile_i = */ tile_i, /*.tile_j = */ tile_j}; - pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j); + legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j); } } struct compute_3d_tiled_context { - pthreadpool_function_3d_tiled_t function; + legacy_pthreadpool_function_3d_tiled_t function; void* argument; caffe2::FixedDivisor tile_range_j; caffe2::FixedDivisor tile_range_k; @@ -205,9 +205,9 @@ static void 
compute_3d_tiled( context->argument, index_i, index_j, index_k, tile_i, tile_j, tile_k); } -void pthreadpool_compute_3d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_3d_tiled_t function, +void legacy_pthreadpool_compute_3d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_3d_tiled_t function, void* argument, size_t range_i, size_t range_j, @@ -251,16 +251,16 @@ void pthreadpool_compute_3d_tiled( /*.tile_i = */ tile_i, /*.tile_j = */ tile_j, /*.tile_k = */ tile_k}; - pthreadpool_compute_1d( + legacy_pthreadpool_compute_1d( threadpool, - (pthreadpool_function_1d_t)compute_3d_tiled, + (legacy_pthreadpool_function_1d_t)compute_3d_tiled, &context, tile_range_i * tile_range_j * tile_range_k); } } struct compute_4d_tiled_context { - pthreadpool_function_4d_tiled_t function; + legacy_pthreadpool_function_4d_tiled_t function; void* argument; caffe2::FixedDivisor tile_range_kl; caffe2::FixedDivisor tile_range_j; @@ -310,9 +310,9 @@ static void compute_4d_tiled( tile_l); } -void pthreadpool_compute_4d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_4d_tiled_t function, +void legacy_pthreadpool_compute_4d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_4d_tiled_t function, void* argument, size_t range_i, size_t range_j, @@ -367,9 +367,9 @@ void pthreadpool_compute_4d_tiled( /*.tile_j = */ tile_j, /*.tile_k = */ tile_k, /*.tile_l = */ tile_l}; - pthreadpool_compute_1d( + legacy_pthreadpool_compute_1d( threadpool, - (pthreadpool_function_1d_t)compute_4d_tiled, + (legacy_pthreadpool_function_1d_t)compute_4d_tiled, &context, tile_range_i * tile_range_j * tile_range_k * tile_range_l); } diff --git a/caffe2/utils/threadpool/pthreadpool.h b/caffe2/utils/threadpool/pthreadpool.h index 8fd2d01126ef..27935febe45e 100644 --- a/caffe2/utils/threadpool/pthreadpool.h +++ b/caffe2/utils/threadpool/pthreadpool.h @@ -5,49 +5,16 @@ #include "ThreadPoolCommon.h" - #include // for size_t - -typedef struct pthreadpool* pthreadpool_t; - -typedef void (*pthreadpool_function_1d_t)(void*, size_t); -typedef void (*pthreadpool_function_1d_tiled_t)(void*, size_t, size_t); -typedef void (*pthreadpool_function_2d_t)(void*, size_t, size_t); -typedef void (*pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t); -typedef void (*pthreadpool_function_3d_tiled_t)( - void*, - size_t, - size_t, - size_t, - size_t, - size_t, - size_t); -typedef void (*pthreadpool_function_4d_tiled_t)( - void*, - size_t, - size_t, - size_t, - size_t, - size_t, - size_t, - size_t, - size_t); - #include // for uint32_t -typedef void (*pthreadpool_task_1d_t)(void*, size_t); -typedef void (*pthreadpool_task_1d_tile_1d_t)(void*, size_t, size_t); -typedef void (*pthreadpool_task_2d_t)(void*, size_t, size_t); -typedef void (*pthreadpool_task_2d_tile_1d_t)(void*, size_t, size_t, size_t); -typedef void (*pthreadpool_task_2d_tile_2d_t)(void*, size_t, size_t, size_t, size_t); -typedef void (*pthreadpool_task_3d_tile_2d_t)( - void*, - size_t, - size_t, - size_t, - size_t, - size_t); -typedef void (*pthreadpool_task_4d_tile_2d_t)( +typedef struct pthreadpool* legacy_pthreadpool_t; + +typedef void (*legacy_pthreadpool_function_1d_t)(void*, size_t); +typedef void (*legacy_pthreadpool_function_1d_tiled_t)(void*, size_t, size_t); +typedef void (*legacy_pthreadpool_function_2d_t)(void*, size_t, size_t); +typedef void (*legacy_pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*legacy_pthreadpool_function_3d_tiled_t)( void*, size_t, size_t, @@ -55,16 
+22,7 @@ typedef void (*pthreadpool_task_4d_tile_2d_t)( size_t, size_t, size_t); -typedef void (*pthreadpool_task_5d_tile_2d_t)( - void*, - size_t, - size_t, - size_t, - size_t, - size_t, - size_t, - size_t); -typedef void (*pthreadpool_task_6d_tile_2d_t)( +typedef void (*legacy_pthreadpool_function_4d_tiled_t)( void*, size_t, size_t, @@ -90,8 +48,8 @@ extern "C" { * On error the function returns NULL and sets errno accordingly. */ -//Returns internal threadpool impl. -pthreadpool_t pthreadpool_create(size_t threads_count); +// Returns internal threadpool impl. +legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count); /** * Queries the number of threads in a thread pool. @@ -100,7 +58,7 @@ pthreadpool_t pthreadpool_create(size_t threads_count); * * @returns The number of threads in the thread pool. */ -size_t pthreadpool_get_threads_count(pthreadpool_t threadpool); +size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool); /** * Processes items in parallel using threads from a thread pool. @@ -117,38 +75,45 @@ size_t pthreadpool_get_threads_count(pthreadpool_t threadpool); * @param[in] items The number of items to process. The @a function * will be called once for each item. */ -void pthreadpool_compute_1d( - pthreadpool_t threadpool, - pthreadpool_function_1d_t function, +void legacy_pthreadpool_compute_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, void* argument, size_t range); -void pthreadpool_compute_1d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_1d_tiled_t function, +void legacy_pthreadpool_parallelize_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, + void* argument, + size_t range, + uint32_t flags); + +void legacy_pthreadpool_compute_1d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_tiled_t function, void* argument, size_t range, size_t tile); -void pthreadpool_compute_2d( - pthreadpool_t threadpool, - pthreadpool_function_2d_t function, +void legacy_pthreadpool_compute_2d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_t function, void* argument, size_t range_i, size_t range_j); -void pthreadpool_compute_2d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_2d_tiled_t function, +void legacy_pthreadpool_compute_2d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_tiled_t function, void* argument, size_t range_i, size_t range_j, size_t tile_i, size_t tile_j); -void pthreadpool_compute_3d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_3d_tiled_t function, +void legacy_pthreadpool_compute_3d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_3d_tiled_t function, void* argument, size_t range_i, size_t range_j, @@ -157,9 +122,9 @@ void pthreadpool_compute_3d_tiled( size_t tile_j, size_t tile_k); -void pthreadpool_compute_4d_tiled( - pthreadpool_t threadpool, - pthreadpool_function_4d_tiled_t function, +void legacy_pthreadpool_compute_4d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_4d_tiled_t function, void* argument, size_t range_i, size_t range_j, @@ -178,129 +143,29 @@ void pthreadpool_compute_4d_tiled( * * @param[in,out] threadpool The thread pool to destroy. */ -void pthreadpool_destroy(pthreadpool_t threadpool); +void legacy_pthreadpool_destroy(legacy_pthreadpool_t threadpool); -// New interface copy/pasted from pthreadpool. -// We will merge the internal and third-party/pthreadpool eventually. 
-// For now copy-paste to get past build issues. +#ifdef USE_INTERNAL_PTHREADPOOL_IMPL -#define PTHREADPOOL_FLAG_DISABLE_DENORMALS 0x00000001 +#define pthreadpool_t legacy_pthreadpool_t +#define pthreadpool_function_1d_t legacy_pthreadpool_function_1d_t +#define pthreadpool_function_1d_tiled_t legacy_pthreadpool_function_1d_tiled_t +#define pthreadpool_function_2d_t legacy_pthreadpool_function_2d_t +#define pthreadpool_function_2d_tiled_t legacy_pthreadpool_function_2d_tiled_t +#define pthreadpool_function_3d_tiled_t legacy_pthreadpool_function_3d_tiled_t +#define pthreadpool_function_4d_tiled_t legacy_pthreadpool_function_4d_tiled_t +#define pthreadpool_create legacy_pthreadpool_create +#define pthreadpool_destroy legacy_pthreadpool_destroy +#define pthreadpool_get_threads_count legacy_pthreadpool_get_threads_count +#define pthreadpool_compute_1d legacy_pthreadpool_compute_1d +#define pthreadpool_parallelize_1d legacy_pthreadpool_parallelize_1d +#define pthreadpool_compute_1d_tiled legacy_pthreadpool_compute_1d_tiled +#define pthreadpool_compute_2d legacy_pthreadpool_compute_2d +#define pthreadpool_compute_2d_tiled legacy_pthreadpool_compute_2d_tiled +#define pthreadpool_compute_3d_tiled legacy_pthreadpool_compute_3d_tiled +#define pthreadpool_compute_4d_tiled legacy_pthreadpool_compute_4d_tiled -// Returns the copied threadpool impl of third-party/pthreadpool -pthreadpool_t pthreadpool_create_xnnpack(size_t threads_count); - -// Copied third-party impl. -size_t pthreadpool_get_threads_count_xnnpack(pthreadpool_t threadpool); - -// Copied third-party impl. -void pthreadpool_destroy_xnnpack(pthreadpool_t threadpool); - -/** - * Processes items in parallel using threads from a thread pool. - * - * When the call returns, all items have been processed and the thread pool is - * ready for a new task. - * - * @note If multiple threads call this function with the same thread pool, the - * calls are serialized. - * - * @param[in] threadpool The thread pool to use for parallelisation. - * @param[in] function The function to call for each item. - * @param[in] argument The first argument passed to the @a function. - * @param[in] items The number of items to process. The @a function - * will be called once for each item. 
- */ -void pthreadpool_parallelize_1d( - pthreadpool_t threadpool, - pthreadpool_task_1d_t function, - void* argument, - size_t range, - uint32_t flags); - -void pthreadpool_parallelize_1d_tile_1d( - pthreadpool_t threadpool, - pthreadpool_task_1d_tile_1d_t function, - void* argument, - size_t range, - size_t tile, - uint32_t flags); - -void pthreadpool_parallelize_2d( - pthreadpool_t threadpool, - pthreadpool_task_2d_t function, - void* argument, - size_t range_i, - size_t range_j, - uint32_t flags); - -void pthreadpool_parallelize_2d_tile_1d( - pthreadpool_t threadpool, - pthreadpool_task_2d_tile_1d_t function, - void* argument, - size_t range_i, - size_t range_j, - size_t tile_j, - uint32_t flags); - -void pthreadpool_parallelize_2d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_2d_tile_2d_t function, - void* argument, - size_t range_i, - size_t range_j, - size_t tile_i, - size_t tile_j, - uint32_t flags); - -void pthreadpool_parallelize_3d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_3d_tile_2d_t function, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t tile_j, - size_t tile_k, - uint32_t flags); - -void pthreadpool_parallelize_4d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_4d_tile_2d_t function, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t range_l, - size_t tile_k, - size_t tile_l, - uint32_t flags); - -void pthreadpool_parallelize_5d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_5d_tile_2d_t function, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t range_l, - size_t range_m, - size_t tile_l, - size_t tile_m, - uint32_t flags); - -void pthreadpool_parallelize_6d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_6d_tile_2d_t function, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t range_l, - size_t range_m, - size_t range_n, - size_t tile_m, - size_t tile_n, - uint32_t flags); +#endif /* USE_INTERNAL_PTHREADPOOL_IMPL */ #ifdef __cplusplus } /* extern "C" */ diff --git a/caffe2/utils/threadpool/pthreadpool_impl.cc b/caffe2/utils/threadpool/pthreadpool_impl.cc index 3b284e50b9cd..66326eef7a7b 100644 --- a/caffe2/utils/threadpool/pthreadpool_impl.cc +++ b/caffe2/utils/threadpool/pthreadpool_impl.cc @@ -6,9 +6,9 @@ // External API // -void pthreadpool_compute_1d( - pthreadpool_t threadpool, - pthreadpool_function_1d_t function, +void legacy_pthreadpool_compute_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, void* argument, size_t range) { if (threadpool == nullptr) { @@ -27,30 +27,31 @@ void pthreadpool_compute_1d( range); } -size_t pthreadpool_get_threads_count(pthreadpool_t threadpool) { - // The current fix only useful when XNNPACK calls pthreadpool_get_threads_count with nullptr. +void legacy_pthreadpool_parallelize_1d( + const legacy_pthreadpool_t threadpool, + const legacy_pthreadpool_function_1d_t function, + void* const argument, + const size_t range, + uint32_t) { + legacy_pthreadpool_compute_1d(threadpool, function, argument, range); +} + +size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool) { + // The current fix is only useful when XNNPACK calls legacy_pthreadpool_get_threads_count with nullptr. if (threadpool == nullptr) { return 1; } return reinterpret_cast<caffe2::ThreadPool*>(threadpool)->getNumThreads(); - // TODO: Future fix: If we keep maintaining two different threadpools.
- // Old C2 and new one for XNNPACK, then the we have two different pthreadpool pointer - // types. One is caffe2::Thredpool*, the other is pthreadpool* (pthreadpool_new_if_impl.c) - // XNNPACK calls pthreadpool_get_threads_count during op setup using pthreadpool*, and - // uses _parallelize_ interface for for actual work. - // While NNPACK uses caffe2::Threadpool*. - // Thus if pthreadpool_get_threads_count is getting called from XNNPACK we cannot - // reinterpret_cast it to ThreadPool. It will seg fault or worse will have unedfined behavior. } -pthreadpool_t pthreadpool_create(size_t threads_count) { +legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count) { std::mutex thread_pool_creation_mutex_; std::lock_guard guard(thread_pool_creation_mutex_); - return reinterpret_cast(new caffe2::ThreadPool(threads_count)); + return reinterpret_cast(new caffe2::ThreadPool(threads_count)); } -void pthreadpool_destroy(pthreadpool_t pthreadpool) { +void legacy_pthreadpool_destroy(legacy_pthreadpool_t pthreadpool) { if (pthreadpool) { caffe2::ThreadPool* threadpool = reinterpret_cast(pthreadpool); diff --git a/caffe2/utils/threadpool/pthreadpool_new_if_impl.c b/caffe2/utils/threadpool/pthreadpool_new_if_impl.c deleted file mode 100644 index 6b2bcf14b394..000000000000 --- a/caffe2/utils/threadpool/pthreadpool_new_if_impl.c +++ /dev/null @@ -1,1209 +0,0 @@ -/* Standard C headers */ -#include -#include -#include -#include -#include - -/* POSIX headers */ -#include -#include - -/* Futex-specific headers */ -#ifndef PTHREADPOOL_USE_FUTEX - #if defined(__linux__) - #define PTHREADPOOL_USE_FUTEX 1 - #include - #include - - /* Old Android NDKs do not define SYS_futex and FUTEX_PRIVATE_FLAG */ - #ifndef SYS_futex - #define SYS_futex __NR_futex - #endif - #ifndef FUTEX_PRIVATE_FLAG - #define FUTEX_PRIVATE_FLAG 128 - #endif - #elif defined(__native_client__) - #define PTHREADPOOL_USE_FUTEX 1 - #include - #else - #define PTHREADPOOL_USE_FUTEX 0 - #endif -#endif - -/* Dependencies */ -#include - -/* Library header */ -#include "caffe2/utils/threadpool/pthreadpool.h" - -/* Internal headers */ -#include "caffe2/utils/threadpool/pthreadpool_utils_new_if.h" - -/* Number of iterations in spin-wait loop before going into futex/mutex wait */ -#define PTHREADPOOL_SPIN_WAIT_ITERATIONS 1000000 - -#define PTHREADPOOL_CACHELINE_SIZE 64 -#define PTHREADPOOL_CACHELINE_ALIGNED __attribute__((__aligned__(PTHREADPOOL_CACHELINE_SIZE))) - -#if defined(__clang__) - #if __has_extension(c_static_assert) || __has_feature(c_static_assert) - #define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message) - #else - #define PTHREADPOOL_STATIC_ASSERT(predicate, message) - #endif -#elif defined(__GNUC__) && ((__GNUC__ > 4) || (__GNUC__ == 4) && (__GNUC_MINOR__ >= 6)) - /* Static assert is supported by gcc >= 4.6 */ - #define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message) -#else - #define PTHREADPOOL_STATIC_ASSERT(predicate, message) -#endif - -static inline size_t multiply_divide(size_t a, size_t b, size_t d) { - #if defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 4) - return (size_t) (((uint64_t) a) * ((uint64_t) b)) / ((uint64_t) d); - #elif defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 8) - return (size_t) (((__uint128_t) a) * ((__uint128_t) b)) / ((__uint128_t) d); - #else - #error "Unsupported platform" - #endif -} - -static inline size_t divide_round_up(size_t dividend, size_t divisor) { - if (dividend % divisor == 0) { - return dividend / divisor; - } 
else { - return dividend / divisor + 1; - } -} - -static inline size_t min(size_t a, size_t b) { - return a < b ? a : b; -} - -#if PTHREADPOOL_USE_FUTEX - #if defined(__linux__) - static int futex_wait(_Atomic uint32_t* address, uint32_t value) { - return syscall(SYS_futex, address, FUTEX_WAIT | FUTEX_PRIVATE_FLAG, value, NULL); - } - - static int futex_wake_all(_Atomic uint32_t* address) { - return syscall(SYS_futex, address, FUTEX_WAKE | FUTEX_PRIVATE_FLAG, INT_MAX); - } - #elif defined(__native_client__) - static struct nacl_irt_futex nacl_irt_futex = { 0 }; - static pthread_once_t nacl_init_guard = PTHREAD_ONCE_INIT; - static void nacl_init(void) { - nacl_interface_query(NACL_IRT_FUTEX_v0_1, &nacl_irt_futex, sizeof(nacl_irt_futex)); - } - - static int futex_wait(_Atomic uint32_t* address, uint32_t value) { - return nacl_irt_futex.futex_wait_abs((_Atomic int*) address, (int) value, NULL); - } - - static int futex_wake_all(_Atomic uint32_t* address) { - int count; - return nacl_irt_futex.futex_wake((_Atomic int*) address, INT_MAX, &count); - } - #else - #error "Platform-specific implementation of futex_wait and futex_wake_all required" - #endif -#endif - -#define THREADPOOL_COMMAND_MASK UINT32_C(0x7FFFFFFF) - -enum threadpool_command { - threadpool_command_init, - threadpool_command_compute_1d, - threadpool_command_shutdown, -}; - -struct PTHREADPOOL_CACHELINE_ALIGNED thread_info { - /** - * Index of the first element in the work range. - * Before processing a new element the owning worker thread increments this value. - */ - atomic_size_t range_start; - /** - * Index of the element after the last element of the work range. - * Before processing a new element the stealing worker thread decrements this value. - */ - atomic_size_t range_end; - /** - * The number of elements in the work range. - * Due to race conditions range_length <= range_end - range_start. - * The owning worker thread must decrement this value before incrementing @a range_start. - * The stealing worker thread must decrement this value before decrementing @a range_end. - */ - atomic_size_t range_length; - /** - * Thread number in the 0..threads_count-1 range. - */ - size_t thread_number; - /** - * The pthread object corresponding to the thread. - */ - pthread_t thread_object; - /** - * Condition variable used to wake up the thread. - * When the thread is idle, it waits on this condition variable. - */ - pthread_cond_t wakeup_condvar; -}; - -PTHREADPOOL_STATIC_ASSERT(sizeof(struct thread_info) % PTHREADPOOL_CACHELINE_SIZE == 0, "thread_info structure must occupy an integer number of cache lines (64 bytes)"); - -struct PTHREADPOOL_CACHELINE_ALIGNED pthreadpool { - /** - * The number of threads that are processing an operation. - */ - atomic_size_t active_threads; -#if PTHREADPOOL_USE_FUTEX - /** - * Indicates if there are active threads. - * Only two values are possible: - * - has_active_threads == 0 if active_threads == 0 - * - has_active_threads == 1 if active_threads != 0 - */ - _Atomic uint32_t has_active_threads; -#endif - /** - * The last command submitted to the thread pool. - */ - _Atomic uint32_t command; - /** - * The function to call for each item. - */ - void *_Atomic task; - /** - * The first argument to the item processing function. - */ - void *_Atomic argument; - /** - * Copy of the flags passed to parallelization function. - */ - _Atomic uint32_t flags; - /** - * Serializes concurrent calls to @a pthreadpool_parallelize_* from different threads. 
- */ - pthread_mutex_t execution_mutex; -#if !PTHREADPOOL_USE_FUTEX - /** - * Guards access to the @a active_threads variable. - */ - pthread_mutex_t completion_mutex; - /** - * Condition variable to wait until all threads complete an operation (until @a active_threads is zero). - */ - pthread_cond_t completion_condvar; - /** - * Guards access to the @a command variable. - */ - pthread_mutex_t command_mutex; - /** - * Condition variable to wait for change of the @a command variable. - */ - pthread_cond_t command_condvar; -#endif - /** - * The number of threads in the thread pool. Never changes after initialization. - */ - size_t threads_count; - /** - * Thread information structures that immediately follow this structure. - */ - struct thread_info threads[]; -}; - -PTHREADPOOL_STATIC_ASSERT(sizeof(struct pthreadpool) % PTHREADPOOL_CACHELINE_SIZE == 0, "pthreadpool structure must occupy an integer number of cache lines (64 bytes)"); - -static void checkin_worker_thread(struct pthreadpool* threadpool) { - #if PTHREADPOOL_USE_FUTEX - if (atomic_fetch_sub_explicit(&threadpool->active_threads, 1, memory_order_relaxed) == 1) { - atomic_store_explicit(&threadpool->has_active_threads, 0, memory_order_release); - futex_wake_all(&threadpool->has_active_threads); - } - #else - pthread_mutex_lock(&threadpool->completion_mutex); - if (atomic_fetch_sub_explicit(&threadpool->active_threads, 1, memory_order_relaxed) == 1) { - pthread_cond_signal(&threadpool->completion_condvar); - } - pthread_mutex_unlock(&threadpool->completion_mutex); - #endif -} - -static void wait_worker_threads(struct pthreadpool* threadpool) { - /* Initial check */ - #if PTHREADPOOL_USE_FUTEX - uint32_t has_active_threads = atomic_load_explicit(&threadpool->has_active_threads, memory_order_relaxed); - if (has_active_threads == 0) { - return; - } - #else - size_t active_threads = atomic_load_explicit(&threadpool->active_threads, memory_order_relaxed); - if (active_threads == 0) { - return; - } - #endif - - /* Spin-wait */ - for (uint32_t i = PTHREADPOOL_SPIN_WAIT_ITERATIONS; i != 0; i--) { - /* This fence serves as a sleep instruction */ - atomic_thread_fence(memory_order_acquire); - - #if PTHREADPOOL_USE_FUTEX - has_active_threads = atomic_load_explicit(&threadpool->has_active_threads, memory_order_relaxed); - if (has_active_threads == 0) { - return; - } - #else - active_threads = atomic_load_explicit(&threadpool->active_threads, memory_order_relaxed); - if (active_threads == 0) { - return; - } - #endif - } - - /* Fall-back to mutex/futex wait */ - #if PTHREADPOOL_USE_FUTEX - while ((has_active_threads = atomic_load(&threadpool->has_active_threads)) != 0) { - futex_wait(&threadpool->has_active_threads, 1); - } - #else - pthread_mutex_lock(&threadpool->completion_mutex); - while (atomic_load_explicit(&threadpool->active_threads, memory_order_relaxed) != 0) { - pthread_cond_wait(&threadpool->completion_condvar, &threadpool->completion_mutex); - }; - pthread_mutex_unlock(&threadpool->completion_mutex); - #endif -} - -inline static bool atomic_decrement(atomic_size_t* value) { - size_t actual_value = atomic_load_explicit(value, memory_order_relaxed); - if (actual_value == 0) { - return false; - } - while (!atomic_compare_exchange_weak_explicit( - value, &actual_value, actual_value - 1, memory_order_relaxed, memory_order_relaxed)) - { - if (actual_value == 0) { - return false; - } - } - return true; -} - -inline static size_t modulo_decrement(uint32_t i, uint32_t n) { - /* Wrap modulo n, if needed */ - if (i == 0) { - i = n; - } - /* 
Decrement input variable */ - return i - 1; -} - -static void thread_parallelize_1d(struct pthreadpool* threadpool, struct thread_info* thread) { - const pthreadpool_task_1d_t task = (pthreadpool_task_1d_t) atomic_load_explicit(&threadpool->task, memory_order_relaxed); - void *const argument = atomic_load_explicit(&threadpool->argument, memory_order_relaxed); - /* Process thread's own range of items */ - size_t range_start = atomic_load_explicit(&thread->range_start, memory_order_relaxed); - while (atomic_decrement(&thread->range_length)) { - task(argument, range_start++); - } - - /* There still may be other threads with work */ - const size_t thread_number = thread->thread_number; - const size_t threads_count = threadpool->threads_count; - for (size_t tid = modulo_decrement(thread_number, threads_count); - tid != thread_number; - tid = modulo_decrement(tid, threads_count)) - { - struct thread_info* other_thread = &threadpool->threads[tid]; - while (atomic_decrement(&other_thread->range_length)) { - const size_t item_id = atomic_fetch_sub_explicit(&other_thread->range_end, 1, memory_order_relaxed) - 1; - task(argument, item_id); - } - } - atomic_thread_fence(memory_order_release); -} - -static uint32_t wait_for_new_command( - struct pthreadpool* threadpool, - uint32_t last_command) -{ - uint32_t command = atomic_load_explicit(&threadpool->command, memory_order_relaxed); - if (command != last_command) { - atomic_thread_fence(memory_order_acquire); - return command; - } - - /* Spin-wait loop */ - for (uint32_t i = PTHREADPOOL_SPIN_WAIT_ITERATIONS; i != 0; i--) { - /* This fence serves as a sleep instruction */ - atomic_thread_fence(memory_order_acquire); - - command = atomic_load_explicit(&threadpool->command, memory_order_relaxed); - if (command != last_command) { - atomic_thread_fence(memory_order_acquire); - return command; - } - } - - /* Spin-wait timed out, fall back to mutex/futex wait */ - #if PTHREADPOOL_USE_FUTEX - do { - futex_wait(&threadpool->command, last_command); - command = atomic_load_explicit(&threadpool->command, memory_order_relaxed); - } while (command == last_command); - #else - /* Lock the command mutex */ - pthread_mutex_lock(&threadpool->command_mutex); - /* Read the command */ - while ((command = atomic_load_explicit(&threadpool->command, memory_order_relaxed)) == last_command) { - /* Wait for new command */ - pthread_cond_wait(&threadpool->command_condvar, &threadpool->command_mutex); - } - /* Read a new command */ - pthread_mutex_unlock(&threadpool->command_mutex); - #endif - atomic_thread_fence(memory_order_acquire); - return command; -} - -static void* thread_main(void* arg) { - struct thread_info* thread = (struct thread_info*) arg; - struct pthreadpool* threadpool = ((struct pthreadpool*) (thread - thread->thread_number)) - 1; - uint32_t last_command = threadpool_command_init; - struct fpu_state saved_fpu_state = { 0 }; - - /* Check in */ - checkin_worker_thread(threadpool); - - /* Monitor new commands and act accordingly */ - for (;;) { - uint32_t command = wait_for_new_command(threadpool, last_command); - const uint32_t flags = atomic_load_explicit(&threadpool->flags, memory_order_relaxed); - - /* Process command */ - switch (command & THREADPOOL_COMMAND_MASK) { - case threadpool_command_compute_1d: - { - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - thread_parallelize_1d(threadpool, thread); - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - 
break; - } - case threadpool_command_shutdown: - /* Exit immediately: the master thread is waiting on pthread_join */ - return NULL; - case threadpool_command_init: - /* To inhibit compiler warning */ - break; - } - /* Notify the master thread that we finished processing */ - checkin_worker_thread(threadpool); - /* Update last command */ - last_command = command; - }; -} - -static struct pthreadpool* pthreadpool_allocate(size_t threads_count) { - const size_t threadpool_size = sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info); - struct pthreadpool* threadpool = NULL; - #if defined(__ANDROID__) - /* - * Android didn't get posix_memalign until API level 17 (Android 4.2). - * Use (otherwise obsolete) memalign function on Android platform. - */ - threadpool = memalign(PTHREADPOOL_CACHELINE_SIZE, threadpool_size); - if (threadpool == NULL) { - return NULL; - } - #else - if (posix_memalign((void**) &threadpool, PTHREADPOOL_CACHELINE_SIZE, threadpool_size) != 0) { - return NULL; - } - #endif - memset(threadpool, 0, threadpool_size); - return threadpool; -} - -struct pthreadpool* pthreadpool_create_xnnpack(size_t threads_count) { -#if defined(__native_client__) - pthread_once(&nacl_init_guard, nacl_init); -#endif - - if (threads_count == 0) { - threads_count = (size_t) sysconf(_SC_NPROCESSORS_ONLN); - } - struct pthreadpool* threadpool = pthreadpool_allocate(threads_count); - if (threadpool == NULL) { - return NULL; - } - threadpool->threads_count = threads_count; - for (size_t tid = 0; tid < threads_count; tid++) { - threadpool->threads[tid].thread_number = tid; - } - - /* Thread pool with a single thread computes everything on the caller thread. */ - if (threads_count > 1) { - pthread_mutex_init(&threadpool->execution_mutex, NULL); - #if !PTHREADPOOL_USE_FUTEX - pthread_mutex_init(&threadpool->completion_mutex, NULL); - pthread_cond_init(&threadpool->completion_condvar, NULL); - pthread_mutex_init(&threadpool->command_mutex, NULL); - pthread_cond_init(&threadpool->command_condvar, NULL); - #endif - - #if PTHREADPOOL_USE_FUTEX - atomic_store_explicit(&threadpool->has_active_threads, 1, memory_order_relaxed); - #endif - atomic_store_explicit( - &threadpool->active_threads, threadpool->threads_count - 1 /* caller thread */, memory_order_release); - - /* Caller thread serves as worker #0. Thus, we create system threads starting with worker #1. 
*/ - for (size_t tid = 1; tid < threads_count; tid++) { - pthread_create(&threadpool->threads[tid].thread_object, NULL, &thread_main, &threadpool->threads[tid]); - } - - /* Wait until all threads initialize */ - wait_worker_threads(threadpool); - } - return threadpool; -} - -size_t pthreadpool_get_threads_count_xnnpack(struct pthreadpool* threadpool) { - if (threadpool == NULL) { - return 1; - } else { - return threadpool->threads_count; - } -} - -void pthreadpool_parallelize_1d( - struct pthreadpool* threadpool, - pthreadpool_task_1d_t task, - void* argument, - size_t range, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range; i++) { - task(argument, i); - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Protect the global threadpool structures */ - pthread_mutex_lock(&threadpool->execution_mutex); - - #if !PTHREADPOOL_USE_FUTEX - /* Lock the command variables to ensure that threads don't start processing before they observe complete command with all arguments */ - pthread_mutex_lock(&threadpool->command_mutex); - #endif - - /* Setup global arguments */ - atomic_store_explicit(&threadpool->task, task, memory_order_relaxed); - atomic_store_explicit(&threadpool->argument, argument, memory_order_relaxed); - atomic_store_explicit(&threadpool->flags, flags, memory_order_relaxed); - - /* Locking of completion_mutex not needed: readers are sleeping on command_condvar */ - atomic_store_explicit( - &threadpool->active_threads, threadpool->threads_count - 1 /* caller thread */, memory_order_relaxed); - #if PTHREADPOOL_USE_FUTEX - atomic_store_explicit(&threadpool->has_active_threads, 1, memory_order_relaxed); - #endif - - /* Spread the work between threads */ - for (size_t tid = 0; tid < threadpool->threads_count; tid++) { - struct thread_info* thread = &threadpool->threads[tid]; - const size_t range_start = multiply_divide(range, tid, threadpool->threads_count); - const size_t range_end = multiply_divide(range, tid + 1, threadpool->threads_count); - atomic_store_explicit(&thread->range_start, range_start, memory_order_relaxed); - atomic_store_explicit(&thread->range_end, range_end, memory_order_relaxed); - atomic_store_explicit(&thread->range_length, range_end - range_start, memory_order_relaxed); - } - - #if PTHREADPOOL_USE_FUTEX - /* - * Make new command parameters globally visible. Having this fence before updating the command is important: it - * guarantees that if a worker thread observes a new command value, it also observes the updated command parameters. - */ - atomic_thread_fence(memory_order_release); - #endif - - /* - * Update the threadpool command. - * Importantly, do it after initializing command parameters (range, task, argument) - * ~(threadpool->command | THREADPOOL_COMMAND_MASK) flips the bits not in command mask - * to ensure the unmasked command is different than the last command, because worker threads - * monitor for change in the unmasked command. 
- */ - const uint32_t old_command = atomic_load_explicit(&threadpool->command, memory_order_relaxed); - const uint32_t new_command = ~(old_command | THREADPOOL_COMMAND_MASK) | threadpool_command_compute_1d; - - #if PTHREADPOOL_USE_FUTEX - atomic_store_explicit(&threadpool->command, new_command, memory_order_release); - - /* Wake up the threads */ - futex_wake_all(&threadpool->command); - #else - atomic_store_explicit(&threadpool->command, new_command, memory_order_relaxed); - - /* Unlock the command variables before waking up the threads for better performance */ - pthread_mutex_unlock(&threadpool->command_mutex); - - /* Wake up the threads */ - pthread_cond_broadcast(&threadpool->command_condvar); - #endif - - /* Save and modify FPU denormals control, if needed */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - - /* Do computations as worker #0 */ - thread_parallelize_1d(threadpool, &threadpool->threads[0]); - - /* Restore FPU denormals control, if needed */ - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - - /* Wait until the threads finish computation */ - wait_worker_threads(threadpool); - - /* Make changes by other threads visible to this thread */ - atomic_thread_fence(memory_order_acquire); - - /* Unprotect the global threadpool structures */ - pthread_mutex_unlock(&threadpool->execution_mutex); - } -} - -struct compute_1d_tile_1d_context { - pthreadpool_task_1d_tile_1d_t task; - void* argument; - size_t range; - size_t tile; -}; - -static void compute_1d_tile_1d(const struct compute_1d_tile_1d_context* context, size_t linear_index) { - const size_t tile_index = linear_index; - const size_t index = tile_index * context->tile; - const size_t tile = min(context->tile, context->range - index); - context->task(context->argument, index, tile); -} - -void pthreadpool_parallelize_1d_tile_1d( - pthreadpool_t threadpool, - pthreadpool_task_1d_tile_1d_t task, - void* argument, - size_t range, - size_t tile, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range; i += tile) { - task(argument, i, min(range - i, tile)); - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - const size_t tile_range = divide_round_up(range, tile); - struct compute_1d_tile_1d_context context = { - .task = task, - .argument = argument, - .range = range, - .tile = tile - }; - pthreadpool_parallelize_1d(threadpool, (pthreadpool_task_1d_t) compute_1d_tile_1d, &context, tile_range, flags); - } -} - -struct compute_2d_context { - pthreadpool_task_2d_t task; - void* argument; - struct fxdiv_divisor_size_t range_j; -}; - -static void compute_2d(const struct compute_2d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t range_j = context->range_j; - const struct fxdiv_result_size_t index = fxdiv_divide_size_t(linear_index, range_j); - context->task(context->argument, index.quotient, index.remainder); -} - -void pthreadpool_parallelize_2d( - struct pthreadpool* threadpool, - pthreadpool_task_2d_t task, - void* argument, - size_t 
range_i, - size_t range_j, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i++) { - for (size_t j = 0; j < range_j; j++) { - task(argument, i, j); - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - struct compute_2d_context context = { - .task = task, - .argument = argument, - .range_j = fxdiv_init_size_t(range_j) - }; - pthreadpool_parallelize_1d(threadpool, (pthreadpool_task_1d_t) compute_2d, &context, range_i * range_j, flags); - } -} - -struct compute_2d_tile_1d_context { - pthreadpool_task_2d_tile_1d_t task; - void* argument; - struct fxdiv_divisor_size_t tile_range_j; - size_t range_i; - size_t range_j; - size_t tile_j; -}; - -static void compute_2d_tile_1d(const struct compute_2d_tile_1d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t tile_range_j = context->tile_range_j; - const struct fxdiv_result_size_t tile_index = fxdiv_divide_size_t(linear_index, tile_range_j); - const size_t max_tile_j = context->tile_j; - const size_t index_i = tile_index.quotient; - const size_t index_j = tile_index.remainder * max_tile_j; - const size_t tile_j = min(max_tile_j, context->range_j - index_j); - context->task(context->argument, index_i, index_j, tile_j); -} - -void pthreadpool_parallelize_2d_tile_1d( - pthreadpool_t threadpool, - pthreadpool_task_2d_tile_1d_t task, - void* argument, - size_t range_i, - size_t range_j, - size_t tile_j, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i++) { - for (size_t j = 0; j < range_j; j += tile_j) { - task(argument, i, j, min(range_j - j, tile_j)); - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - const size_t tile_range_j = divide_round_up(range_j, tile_j); - struct compute_2d_tile_1d_context context = { - .task = task, - .argument = argument, - .tile_range_j = fxdiv_init_size_t(tile_range_j), - .range_i = range_i, - .range_j = range_j, - .tile_j = tile_j - }; - pthreadpool_parallelize_1d(threadpool, (pthreadpool_task_1d_t) compute_2d_tile_1d, &context, range_i * tile_range_j, flags); - } -} - -struct compute_2d_tile_2d_context { - pthreadpool_task_2d_tile_2d_t task; - void* argument; - struct fxdiv_divisor_size_t tile_range_j; - size_t range_i; - size_t range_j; - size_t tile_i; - size_t tile_j; -}; - -static void compute_2d_tile_2d(const struct compute_2d_tile_2d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t tile_range_j = context->tile_range_j; - const struct fxdiv_result_size_t tile_index = fxdiv_divide_size_t(linear_index, tile_range_j); - const size_t max_tile_i = context->tile_i; - const size_t max_tile_j = context->tile_j; - const size_t index_i = tile_index.quotient * max_tile_i; - const size_t index_j = 
tile_index.remainder * max_tile_j; - const size_t tile_i = min(max_tile_i, context->range_i - index_i); - const size_t tile_j = min(max_tile_j, context->range_j - index_j); - context->task(context->argument, index_i, index_j, tile_i, tile_j); -} - -void pthreadpool_parallelize_2d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_2d_tile_2d_t task, - void* argument, - size_t range_i, - size_t range_j, - size_t tile_i, - size_t tile_j, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i += tile_i) { - for (size_t j = 0; j < range_j; j += tile_j) { - task(argument, i, j, min(range_i - i, tile_i), min(range_j - j, tile_j)); - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - const size_t tile_range_i = divide_round_up(range_i, tile_i); - const size_t tile_range_j = divide_round_up(range_j, tile_j); - struct compute_2d_tile_2d_context context = { - .task = task, - .argument = argument, - .tile_range_j = fxdiv_init_size_t(tile_range_j), - .range_i = range_i, - .range_j = range_j, - .tile_i = tile_i, - .tile_j = tile_j - }; - pthreadpool_parallelize_1d(threadpool, (pthreadpool_task_1d_t) compute_2d_tile_2d, &context, tile_range_i * tile_range_j, flags); - } -} - -struct compute_3d_tile_2d_context { - pthreadpool_task_3d_tile_2d_t task; - void* argument; - struct fxdiv_divisor_size_t tile_range_j; - struct fxdiv_divisor_size_t tile_range_k; - size_t range_j; - size_t range_k; - size_t tile_j; - size_t tile_k; -}; - -static void compute_3d_tile_2d(const struct compute_3d_tile_2d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t tile_range_k = context->tile_range_k; - const struct fxdiv_result_size_t tile_index_ij_k = fxdiv_divide_size_t(linear_index, tile_range_k); - const struct fxdiv_divisor_size_t tile_range_j = context->tile_range_j; - const struct fxdiv_result_size_t tile_index_i_j = fxdiv_divide_size_t(tile_index_ij_k.quotient, tile_range_j); - const size_t max_tile_j = context->tile_j; - const size_t max_tile_k = context->tile_k; - const size_t index_i = tile_index_i_j.quotient; - const size_t index_j = tile_index_i_j.remainder * max_tile_j; - const size_t index_k = tile_index_ij_k.remainder * max_tile_k; - const size_t tile_j = min(max_tile_j, context->range_j - index_j); - const size_t tile_k = min(max_tile_k, context->range_k - index_k); - context->task(context->argument, index_i, index_j, index_k, tile_j, tile_k); -} - -void pthreadpool_parallelize_3d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_3d_tile_2d_t task, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t tile_j, - size_t tile_k, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i++) { - for (size_t j = 0; j < range_j; j += tile_j) { - for (size_t k = 0; k < range_k; k += tile_k) { - task(argument, i, j, k, min(range_j - j, 
tile_j), min(range_k - k, tile_k)); - } - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - const size_t tile_range_j = divide_round_up(range_j, tile_j); - const size_t tile_range_k = divide_round_up(range_k, tile_k); - struct compute_3d_tile_2d_context context = { - .task = task, - .argument = argument, - .tile_range_j = fxdiv_init_size_t(tile_range_j), - .tile_range_k = fxdiv_init_size_t(tile_range_k), - .range_j = range_j, - .range_k = range_k, - .tile_j = tile_j, - .tile_k = tile_k - }; - pthreadpool_parallelize_1d(threadpool, - (pthreadpool_task_1d_t) compute_3d_tile_2d, &context, - range_i * tile_range_j * tile_range_k, flags); - } -} - -struct compute_4d_tile_2d_context { - pthreadpool_task_4d_tile_2d_t task; - void* argument; - struct fxdiv_divisor_size_t tile_range_kl; - struct fxdiv_divisor_size_t range_j; - struct fxdiv_divisor_size_t tile_range_l; - size_t range_k; - size_t range_l; - size_t tile_k; - size_t tile_l; -}; - -static void compute_4d_tile_2d(const struct compute_4d_tile_2d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t tile_range_kl = context->tile_range_kl; - const struct fxdiv_result_size_t tile_index_ij_kl = fxdiv_divide_size_t(linear_index, tile_range_kl); - const struct fxdiv_divisor_size_t range_j = context->range_j; - const struct fxdiv_result_size_t tile_index_i_j = fxdiv_divide_size_t(tile_index_ij_kl.quotient, range_j); - const struct fxdiv_divisor_size_t tile_range_l = context->tile_range_l; - const struct fxdiv_result_size_t tile_index_k_l = fxdiv_divide_size_t(tile_index_ij_kl.remainder, tile_range_l); - const size_t max_tile_k = context->tile_k; - const size_t max_tile_l = context->tile_l; - const size_t index_i = tile_index_i_j.quotient; - const size_t index_j = tile_index_i_j.remainder; - const size_t index_k = tile_index_k_l.quotient * max_tile_k; - const size_t index_l = tile_index_k_l.remainder * max_tile_l; - const size_t tile_k = min(max_tile_k, context->range_k - index_k); - const size_t tile_l = min(max_tile_l, context->range_l - index_l); - context->task(context->argument, index_i, index_j, index_k, index_l, tile_k, tile_l); -} - -void pthreadpool_parallelize_4d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_4d_tile_2d_t task, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t range_l, - size_t tile_k, - size_t tile_l, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i++) { - for (size_t j = 0; j < range_j; j++) { - for (size_t k = 0; k < range_k; k += tile_k) { - for (size_t l = 0; l < range_l; l += tile_l) { - task(argument, i, j, k, l, - min(range_k - k, tile_k), min(range_l - l, tile_l)); - } - } - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - const size_t tile_range_k = divide_round_up(range_k, tile_k); - const size_t tile_range_l = divide_round_up(range_l, tile_l); - struct compute_4d_tile_2d_context context = { - .task = task, - .argument = argument, - .tile_range_kl = fxdiv_init_size_t(tile_range_k * 
tile_range_l), - .range_j = fxdiv_init_size_t(range_j), - .tile_range_l = fxdiv_init_size_t(tile_range_l), - .range_k = range_k, - .range_l = range_l, - .tile_k = tile_k, - .tile_l = tile_l - }; - pthreadpool_parallelize_1d(threadpool, - (pthreadpool_task_1d_t) compute_4d_tile_2d, &context, - range_i * range_j * tile_range_k * tile_range_l, flags); - } -} - -struct compute_5d_tile_2d_context { - pthreadpool_task_5d_tile_2d_t task; - void* argument; - struct fxdiv_divisor_size_t tile_range_lm; - struct fxdiv_divisor_size_t range_k; - struct fxdiv_divisor_size_t tile_range_m; - struct fxdiv_divisor_size_t range_j; - size_t range_l; - size_t range_m; - size_t tile_l; - size_t tile_m; -}; - -static void compute_5d_tile_2d(const struct compute_5d_tile_2d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t tile_range_lm = context->tile_range_lm; - const struct fxdiv_result_size_t tile_index_ijk_lm = fxdiv_divide_size_t(linear_index, tile_range_lm); - const struct fxdiv_divisor_size_t range_k = context->range_k; - const struct fxdiv_result_size_t tile_index_ij_k = fxdiv_divide_size_t(tile_index_ijk_lm.quotient, range_k); - const struct fxdiv_divisor_size_t tile_range_m = context->tile_range_m; - const struct fxdiv_result_size_t tile_index_l_m = fxdiv_divide_size_t(tile_index_ijk_lm.remainder, tile_range_m); - const struct fxdiv_divisor_size_t range_j = context->range_j; - const struct fxdiv_result_size_t tile_index_i_j = fxdiv_divide_size_t(tile_index_ij_k.quotient, range_j); - - const size_t max_tile_l = context->tile_l; - const size_t max_tile_m = context->tile_m; - const size_t index_i = tile_index_i_j.quotient; - const size_t index_j = tile_index_i_j.remainder; - const size_t index_k = tile_index_ij_k.remainder; - const size_t index_l = tile_index_l_m.quotient * max_tile_l; - const size_t index_m = tile_index_l_m.remainder * max_tile_m; - const size_t tile_l = min(max_tile_l, context->range_l - index_l); - const size_t tile_m = min(max_tile_m, context->range_m - index_m); - context->task(context->argument, index_i, index_j, index_k, index_l, index_m, tile_l, tile_m); -} - -void pthreadpool_parallelize_5d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_5d_tile_2d_t task, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t range_l, - size_t range_m, - size_t tile_l, - size_t tile_m, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i++) { - for (size_t j = 0; j < range_j; j++) { - for (size_t k = 0; k < range_k; k++) { - for (size_t l = 0; l < range_l; l += tile_l) { - for (size_t m = 0; m < range_m; m += tile_m) { - task(argument, i, j, k, l, m, - min(range_l - l, tile_l), min(range_m - m, tile_m)); - } - } - } - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - const size_t tile_range_l = divide_round_up(range_l, tile_l); - const size_t tile_range_m = divide_round_up(range_m, tile_m); - struct compute_5d_tile_2d_context context = { - .task = task, - .argument = argument, - .tile_range_lm = fxdiv_init_size_t(tile_range_l * tile_range_m), - .range_k = fxdiv_init_size_t(range_k), - .tile_range_m = 
fxdiv_init_size_t(tile_range_m), - .range_j = fxdiv_init_size_t(range_j), - .range_l = range_l, - .range_m = range_m, - .tile_l = tile_l, - .tile_m = tile_m, - }; - pthreadpool_parallelize_1d(threadpool, - (pthreadpool_task_1d_t) compute_5d_tile_2d, &context, - range_i * range_j * range_k * tile_range_l * tile_range_m, flags); - } -} - -struct compute_6d_tile_2d_context { - pthreadpool_task_6d_tile_2d_t task; - void* argument; - struct fxdiv_divisor_size_t tile_range_lmn; - struct fxdiv_divisor_size_t range_k; - struct fxdiv_divisor_size_t tile_range_n; - struct fxdiv_divisor_size_t range_j; - struct fxdiv_divisor_size_t tile_range_m; - size_t range_m; - size_t range_n; - size_t tile_m; - size_t tile_n; -}; - -static void compute_6d_tile_2d(const struct compute_6d_tile_2d_context* context, size_t linear_index) { - const struct fxdiv_divisor_size_t tile_range_lmn = context->tile_range_lmn; - const struct fxdiv_result_size_t tile_index_ijk_lmn = fxdiv_divide_size_t(linear_index, tile_range_lmn); - const struct fxdiv_divisor_size_t range_k = context->range_k; - const struct fxdiv_result_size_t tile_index_ij_k = fxdiv_divide_size_t(tile_index_ijk_lmn.quotient, range_k); - const struct fxdiv_divisor_size_t tile_range_n = context->tile_range_n; - const struct fxdiv_result_size_t tile_index_lm_n = fxdiv_divide_size_t(tile_index_ijk_lmn.remainder, tile_range_n); - const struct fxdiv_divisor_size_t range_j = context->range_j; - const struct fxdiv_result_size_t tile_index_i_j = fxdiv_divide_size_t(tile_index_ij_k.quotient, range_j); - const struct fxdiv_divisor_size_t tile_range_m = context->tile_range_m; - const struct fxdiv_result_size_t tile_index_l_m = fxdiv_divide_size_t(tile_index_lm_n.quotient, tile_range_m); - - const size_t max_tile_m = context->tile_m; - const size_t max_tile_n = context->tile_n; - const size_t index_i = tile_index_i_j.quotient; - const size_t index_j = tile_index_i_j.remainder; - const size_t index_k = tile_index_ij_k.remainder; - const size_t index_l = tile_index_l_m.quotient; - const size_t index_m = tile_index_l_m.remainder * max_tile_m; - const size_t index_n = tile_index_lm_n.remainder * max_tile_n; - const size_t tile_m = min(max_tile_m, context->range_m - index_m); - const size_t tile_n = min(max_tile_n, context->range_n - index_n); - context->task(context->argument, index_i, index_j, index_k, index_l, index_m, index_n, tile_m, tile_n); -} - -void pthreadpool_parallelize_6d_tile_2d( - pthreadpool_t threadpool, - pthreadpool_task_6d_tile_2d_t task, - void* argument, - size_t range_i, - size_t range_j, - size_t range_k, - size_t range_l, - size_t range_m, - size_t range_n, - size_t tile_m, - size_t tile_n, - uint32_t flags) -{ - if (threadpool == NULL || threadpool->threads_count <= 1) { - /* No thread pool used: execute task sequentially on the calling thread */ - struct fpu_state saved_fpu_state = { 0 }; - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - saved_fpu_state = get_fpu_state(); - disable_fpu_denormals(); - } - for (size_t i = 0; i < range_i; i++) { - for (size_t j = 0; j < range_j; j++) { - for (size_t k = 0; k < range_k; k++) { - for (size_t l = 0; l < range_l; l++) { - for (size_t m = 0; m < range_m; m += tile_m) { - for (size_t n = 0; n < range_n; n += tile_n) { - task(argument, i, j, k, l, m, n, - min(range_m - m, tile_m), min(range_n - n, tile_n)); - } - } - } - } - } - } - if (flags & PTHREADPOOL_FLAG_DISABLE_DENORMALS) { - set_fpu_state(saved_fpu_state); - } - } else { - /* Execute in parallel on the thread pool using linearized index */ - 
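All of the tiled entry points above follow the same recipe: round each tiled dimension up with divide_round_up(), hand pthreadpool_parallelize_1d() the flattened tile count, and let the per-item callback split the linear index back into coordinates. A minimal, standalone sketch of the 2-D tiled decomposition, using plain / and % where the code here uses fxdiv for fast division by a runtime divisor (the ranges and tile sizes are made-up values for illustration):

#include <stdio.h>
#include <stddef.h>

/* Same semantics as the deleted divide_round_up() above. */
static size_t divide_round_up(size_t dividend, size_t divisor) {
  return dividend % divisor == 0 ? dividend / divisor : dividend / divisor + 1;
}

static size_t min_size(size_t a, size_t b) {
  return a < b ? a : b;
}

int main(void) {
  const size_t range_i = 5, range_j = 7;  /* logical iteration space */
  const size_t tile_i = 2, tile_j = 3;    /* tile sizes */
  const size_t tile_range_i = divide_round_up(range_i, tile_i);
  const size_t tile_range_j = divide_round_up(range_j, tile_j);

  /* This loop stands in for the thread pool: each linear_index would be one
     work item handed to a callback such as compute_2d_tile_2d(). */
  for (size_t linear_index = 0; linear_index < tile_range_i * tile_range_j; linear_index++) {
    const size_t index_i = (linear_index / tile_range_j) * tile_i;
    const size_t index_j = (linear_index % tile_range_j) * tile_j;
    /* Edge tiles are clamped, exactly as min() is used in the code above. */
    printf("tile origin (%zu, %zu), size %zu x %zu\n",
           index_i, index_j,
           min_size(tile_i, range_i - index_i),
           min_size(tile_j, range_j - index_j));
  }
  return 0;
}

fxdiv precomputes a magic multiplier per divisor, so each work item avoids a hardware integer division; that matters here because the divisor (the number of tiles per row) is only known at run time.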
const size_t tile_range_m = divide_round_up(range_m, tile_m); - const size_t tile_range_n = divide_round_up(range_n, tile_n); - struct compute_6d_tile_2d_context context = { - .task = task, - .argument = argument, - .tile_range_lmn = fxdiv_init_size_t(range_l * tile_range_m * tile_range_n), - .range_k = fxdiv_init_size_t(range_k), - .tile_range_n = fxdiv_init_size_t(tile_range_n), - .range_j = fxdiv_init_size_t(range_j), - .tile_range_m = fxdiv_init_size_t(tile_range_m), - .range_m = range_m, - .range_n = range_n, - .tile_m = tile_m, - .tile_n = tile_n, - }; - pthreadpool_parallelize_1d(threadpool, - (pthreadpool_task_1d_t) compute_6d_tile_2d, &context, - range_i * range_j * range_k * range_l * tile_range_m * tile_range_n, flags); - } -} - -void pthreadpool_destroy_xnnpack(struct pthreadpool* threadpool) { - if (threadpool != NULL) { - if (threadpool->threads_count > 1) { - #if PTHREADPOOL_USE_FUTEX - atomic_store_explicit( - &threadpool->active_threads, threadpool->threads_count - 1 /* caller thread */, memory_order_relaxed); - atomic_store_explicit(&threadpool->has_active_threads, 1, memory_order_release); - - atomic_store_explicit(&threadpool->command, threadpool_command_shutdown, memory_order_release); - - /* Wake up worker threads */ - futex_wake_all(&threadpool->command); - #else - /* Lock the command variable to ensure that threads don't shutdown until both command and active_threads are updated */ - pthread_mutex_lock(&threadpool->command_mutex); - - /* Locking of completion_mutex not needed: readers are sleeping on command_condvar */ - atomic_store_explicit( - &threadpool->active_threads, threadpool->threads_count - 1 /* caller thread */, memory_order_release); - - /* Update the threadpool command. */ - atomic_store_explicit(&threadpool->command, threadpool_command_shutdown, memory_order_release); - - /* Wake up worker threads */ - pthread_cond_broadcast(&threadpool->command_condvar); - - /* Commit the state changes and let workers start processing */ - pthread_mutex_unlock(&threadpool->command_mutex); - #endif - - /* Wait until all threads return */ - for (size_t thread = 1; thread < threadpool->threads_count; thread++) { - pthread_join(threadpool->threads[thread].thread_object, NULL); - } - - /* Release resources */ - pthread_mutex_destroy(&threadpool->execution_mutex); - #if !PTHREADPOOL_USE_FUTEX - pthread_mutex_destroy(&threadpool->completion_mutex); - pthread_cond_destroy(&threadpool->completion_condvar); - pthread_mutex_destroy(&threadpool->command_mutex); - pthread_cond_destroy(&threadpool->command_condvar); - #endif - } - free(threadpool); - } -} diff --git a/caffe2/utils/threadpool/pthreadpool_utils_new_if.h b/caffe2/utils/threadpool/pthreadpool_utils_new_if.h deleted file mode 100644 index 940f53ed6a6a..000000000000 --- a/caffe2/utils/threadpool/pthreadpool_utils_new_if.h +++ /dev/null @@ -1,62 +0,0 @@ -#pragma once - -#include - -#if defined(__SSE__) || defined(__x86_64__) -#include -#endif - -struct fpu_state { -#if defined(__SSE__) || defined(__x86_64__) - uint32_t mxcsr; -#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0) - uint32_t fpscr; -#elif defined(__aarch64__) - uint64_t fpcr; -#else - char unused; -#endif -}; - -static inline struct fpu_state get_fpu_state() { - struct fpu_state state = { 0 }; -#if defined(__SSE__) || defined(__x86_64__) - state.mxcsr = (uint32_t) _mm_getcsr(); -#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0) - __asm__ __volatile__("VMRS %[fpscr], fpscr" : [fpscr] "=r" (state.fpscr)); -#elif defined(__aarch64__) - 
__asm__ __volatile__("MRS %[fpcr], fpcr" : [fpcr] "=r" (state.fpcr)); -#endif - return state; -} - -static inline void set_fpu_state(const struct fpu_state state) { -#if defined(__SSE__) || defined(__x86_64__) - _mm_setcsr((unsigned int) state.mxcsr); -#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0) - __asm__ __volatile__("VMSR fpscr, %[fpscr]" : : [fpscr] "r" (state.fpscr)); -#elif defined(__aarch64__) - __asm__ __volatile__("MSR fpcr, %[fpcr]" : : [fpcr] "r" (state.fpcr)); -#endif -} - -static inline void disable_fpu_denormals() { -#if defined(__SSE__) || defined(__x86_64__) - _mm_setcsr(_mm_getcsr() | 0x8040); -#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0) - uint32_t fpscr; - __asm__ __volatile__( - "VMRS %[fpscr], fpscr\n" - "ORR %[fpscr], #0x1000000\n" - "VMSR fpscr, %[fpscr]\n" - : [fpscr] "=r" (fpscr)); -#elif defined(__aarch64__) - uint64_t fpcr; - __asm__ __volatile__( - "MRS %[fpcr], fpcr\n" - "ORR %w[fpcr], %w[fpcr], 0x1000000\n" - "ORR %w[fpcr], %w[fpcr], 0x80000\n" - "MSR fpcr, %[fpcr]\n" - : [fpcr] "=r" (fpcr)); -#endif -} diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3bcb5cae680a..cfc6b2c96883 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -266,6 +266,8 @@ if(USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK) set(CPUINFO_LOG_LEVEL "error" CACHE STRING "") set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "") endif() +else() + set(DISABLE_NNPACK_AND_FAMILY ON) endif() set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs @@ -281,45 +283,45 @@ if(INTERN_BUILD_MOBILE AND INTERN_USE_EIGEN_BLAS) endif() # ---[ pthreadpool -# QNNPACK and NNPACK both depend on pthreadpool, but when building with libtorch -# they should use the pthreadpool implementation under caffe2/utils/threadpool -# instead of the default implementation. To avoid confusion, add pthreadpool -# subdirectory explicitly with EXCLUDE_FROM_ALL property prior to QNNPACK/NNPACK -# does so, which will prevent it from installing the default pthreadpool library. -if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK OR USE_XNNPACK)) - if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) - set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") - set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory") +if(NOT USE_SYSTEM_PTHREADPOOL AND (INTERN_BUILD_MOBILE OR NOT DISABLE_NNPACK_AND_FAMILY)) +  set(USE_PTHREADPOOL ON CACHE BOOL "" FORCE) +  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PTHREADPOOL") + + # Opt for the custom Caffe2 implementation on MSVC. Windows support seems to have + # been added to pthreadpool recently, but the current third party revision we are + # using right now does not support it. Should unify later after updating pthreadpool. + if(MSVC) + set(USE_INTERNAL_PTHREADPOOL_IMPL ON CACHE BOOL "" FORCE) + # XNNPACK cannot link against a custom implementation of pthreadpool + caffe2_update_option(USE_XNNPACK OFF) + else() + # We would like to maintain the ability to build against the internal C2 + # pthreadpool implementation for now, hence this flag. This flag is not + # exposed as a build option to the user and is purely internal. 
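The flags wired up in this block are consumed at compile time by the C++ sources. A hypothetical sketch of the resulting preprocessor dispatch; the include paths are illustrative assumptions, not a statement about the tree's actual headers:

/* Hypothetical illustration of dispatching on the flags defined above. */
#if defined(USE_PTHREADPOOL)
  #if defined(USE_INTERNAL_PTHREADPOOL_IMPL)
    /* Caffe2's custom shim, kept for MSVC where the bundled third-party
       pthreadpool revision is not usable yet. */
    #include "caffe2/utils/threadpool/pthreadpool.h"  /* assumed path */
  #else
    /* The third-party library built from third_party/pthreadpool. */
    #include <pthreadpool.h>
  #endif
#else
  /* No thread pool at all: callers fall back to inline, single-threaded
     execution. */
#endif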
+ set(USE_INTERNAL_PTHREADPOOL_IMPL OFF CACHE BOOL "" FORCE) + + if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) + set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") + set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory") + endif() + + if(NOT TARGET pthreadpool) + set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") + set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") + add_subdirectory( + "${PTHREADPOOL_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool") + set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) + endif() + + list(APPEND Caffe2_DEPENDENCY_LIBS pthreadpool) endif() - if(NOT TARGET pthreadpool) - set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") - set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") - add_subdirectory( - "${PTHREADPOOL_SOURCE_DIR}" - "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool" - EXCLUDE_FROM_ALL) - endif() -endif() - -# XNNPACK has no option like QNNPACK_CUSTOM_THREADPOOL -# that allows us to hijack the pthreadpool interface. -# Thus not doing this ends up building pthreadpool as well as -# the internal implementation of pthreadpool which results in symbol conflicts. -if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) - if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) - set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") - set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory") - endif() - - if(NOT TARGET pthreadpool) - set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") - set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") - add_subdirectory( - "${PTHREADPOOL_SOURCE_DIR}" - "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool" - EXCLUDE_FROM_ALL) + if(USE_INTERNAL_PTHREADPOOL_IMPL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_INTERNAL_PTHREADPOOL_IMPL") endif() +else() + set(USE_PTHREADPOOL OFF CACHE BOOL "" FORCE) endif() # ---[ Caffe2 uses cpuinfo library in the thread pool @@ -369,9 +371,12 @@ if(USE_QNNPACK) endif() if(NOT TARGET qnnpack) + if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL) + set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") + endif() + set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "") set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") - set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "") add_subdirectory( "${QNNPACK_SOURCE_DIR}" @@ -379,7 +384,6 @@ if(USE_QNNPACK) # We build static versions of QNNPACK and pthreadpool but link # them into a shared library for Caffe2, so they need PIC. set_property(TARGET qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) endif() @@ -400,9 +404,12 @@ if(USE_PYTORCH_QNNPACK) endif() if(NOT TARGET pytorch_qnnpack) + if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL) + set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") + endif() + set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "") set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") - set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "") add_subdirectory( "${PYTORCH_QNNPACK_SOURCE_DIR}" @@ -410,9 +417,6 @@ if(USE_PYTORCH_QNNPACK) # We build static versions of QNNPACK and pthreadpool but link # them into a shared library for Caffe2, so they need PIC. 
set_property(TARGET pytorch_qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON) - if(NOT USE_SYSTEM_PTHREADPOOL) - set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) - endif() set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) endif() @@ -447,7 +451,6 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) endif() if(NOT TARGET XNNPACK) - set(XNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") set(XNNPACK_LIBRARY_TYPE "static" CACHE STRING "") set(XNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") set(XNNPACK_BUILD_TESTS OFF CACHE BOOL "") @@ -457,15 +460,6 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) "${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK") set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON) - # Context: pthreadpool_get_threads_count implementation that is built in pytorch, uses - # implementation defined in caffe2/utils/threadpool/pthreadpool_impl.cc. This implementation - # assumes that the pthreadpool* passed is of type caffe2::ThreadPool and thus does a reinterpret cast. - # This is not valid when we create pthreadpool via caffe2::xnnpack_threadpool, which is of a type - # compatible with the new pthreadpool interface and is used in PT's XNNPACK integration. - # Thus all the calls for pthreadpool_get_threads_count originating from XNNPACK must be routed - # appropriately to pthreadpool_get_threads_count_xnnpack, which does not do the aforementioned - # casting to caffe2::ThreadPool. Once the threadpools are unified, we will not need this. - target_compile_definitions(XNNPACK PRIVATE -Dpthreadpool_get_threads_count=pthreadpool_get_threads_count_xnnpack) endif() include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR}) diff --git a/cmake/External/nnpack.cmake b/cmake/External/nnpack.cmake index d13048483091..24f54627c012 100644 --- a/cmake/External/nnpack.cmake +++ b/cmake/External/nnpack.cmake @@ -59,9 +59,12 @@ if(ANDROID OR IOS OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux" OR ${CMAKE_SYSTEM_NAM set(GOOGLETEST_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/googletest" CACHE STRING "Google Test source directory") if(NOT TARGET nnpack) + if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL) + set(NNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") + endif() + set(NNPACK_BUILD_TESTS OFF CACHE BOOL "") set(NNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") - set(NNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") set(NNPACK_LIBRARY_TYPE "static" CACHE STRING "") set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "") set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 66c4a43787b8..2c3b75941aa2 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -69,6 +69,11 @@ if(NOT @BUILD_SHARED_LIBS@) list(APPEND TORCH_LIBRARIES ${XNNPACK_LIBRARY}) endif() + if(NOT @USE_INTERNAL_PTHREADPOOL_IMPL@) + find_library(PTHREADPOOL_LIBRARY pthreadpool PATHS "${TORCH_INSTALL_PREFIX}/lib") + list(APPEND TORCH_LIBRARIES ${PTHREADPOOL_LIBRARY}) + endif() + if(@INTERN_USE_EIGEN_BLAS@) find_library(EIGEN_BLAS_LIBRARY eigen_blas PATHS "${TORCH_INSTALL_PREFIX}/lib") list(APPEND TORCH_LIBRARIES ${EIGEN_BLAS_LIBRARY}) diff --git a/ios/TestApp/benchmark/setup.rb b/ios/TestApp/benchmark/setup.rb index 14a64203d7d4..408145412eef 100644 --- a/ios/TestApp/benchmark/setup.rb +++ b/ios/TestApp/benchmark/setup.rb @@ -63,7 +63,7 @@ targets.each do |target| target.resources_build_phase.add_file_reference(config_file_ref, true) end puts "Linking static libraries..." 
-libs = ['libc10.a', 'libclog.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a'] +libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a'] targets.each do |target| target.frameworks_build_phases.clear for lib in libs do diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb index 214c80c99645..801ad34a64fd 100644 --- a/scripts/xcode_build.rb +++ b/scripts/xcode_build.rb @@ -51,7 +51,7 @@ end # link static libraries target.frameworks_build_phases.clear -libs = ['libc10.a', 'libclog.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a'] +libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a'] for lib in libs do path = "#{install_path}/lib/#{lib}" if File.exist?(path)
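For reference, a minimal usage sketch of the public pthreadpool API that the libraries above now link against. pthreadpool_create() and pthreadpool_destroy() correspond to the deleted _xnnpack variants, and pthreadpool_parallelize_1d() matches the signature shown earlier in this patch; error handling is elided:

#include <pthreadpool.h>
#include <stdio.h>

/* A pthreadpool_task_1d_t callback: invoked once per index in [0, range). */
static void square_item(void* argument, size_t i) {
  int* data = (int*) argument;
  data[i] *= data[i];
}

int main(void) {
  int data[8] = {0, 1, 2, 3, 4, 5, 6, 7};

  /* 0 threads means "one per online processor", matching the deleted
     pthreadpool_create_xnnpack() above. */
  pthreadpool_t pool = pthreadpool_create(0);

  /* Blocks until every index in [0, 8) has been processed; the calling
     thread participates as worker #0. A NULL pool would still work, since
     parallelize falls back to inline execution, as in the code above. */
  pthreadpool_parallelize_1d(pool, square_item, data, 8, 0 /* flags */);

  pthreadpool_destroy(pool);

  for (size_t i = 0; i < 8; i++) {
    printf("%d ", data[i]);
  }
  printf("\n");
  return 0;
}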