Unify PyTorch mobile's threadpool usage. (#37243)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37243

*** Why ***

As it stands, we have two thread pool solutions concurrently in use in PyTorch mobile: (1) the open source pthreadpool library under third_party, and (2) Caffe2's implementation of pthreadpool under caffe2/utils/threadpool.  Since the primary use-case of the latter has been to act as a drop-in replacement for the third party version so as to enable integration and usage from within NNPACK and QNNPACK, Caffe2's implementation is intentionally written to the exact same interface as the third party version.

The original argument in favor of C2's implementation has been improved performance as a result of using spin locks, as opposed to relinquishing the thread's time slot and putting it to sleep - a less expensive operation up to a point.  That seems to have given C2's implementation the upper hand in performance, hence justifying the added maintenance complexity, until the third party version improved in parallel surpassing the efficiency of C2's implementation as I have verified in benchmarks.  With that advantage gone, there is no reason to continue using C2's implementation in PyTorch mobile either from the perspective of performance or code hygiene.  As a matter of fact, there is considerable performance benefit to be had as a result of using the third party version as it currently stands.

This is a tricky change though, mainly because in order to avoid potential performance regressions, of which I have witnessed none but just in abundance of caution, we have decided to continue using the internal C2's implementation whenever building for Caffe2.  Again, this is mainly to avoid potential performance regressions in production C2 use cases even if doing so results in reduced performance as far as I can tell.

So to summarize, today, and as it currently stands, we are using C2's implementation for (1) NNPACK, (2) PyTorch QNNPACK, and (3) ATen parallel_for on mobile builds, while using the third party version of pthreadpool for XNNPACK as XNNPACK does not provide any build options to link against an external implementation unlike NNPACK and QNNPACK do.

The goal of this PR then, is to unify all usage on mobile to the third party implementation both for improved performance and better code hygiene.  This applies to PyTorch's use of NNPACK, QNNPACK, XNNPACK, and mobile's implementation of ATen parallel_for, all getting routed to the
exact same third party implementation in this PR.

Considering that NNPACK, QNNPACK, and XNNPACK are not mobile specific, these benefits carry over to non-mobile builds of PyTorch (but not Caffe2) as well.  The implementation of ATen parallel_for on non-mobile builds remains unchanged.

*** How ***

This is where things get tricky.

A good deal of the build system complexity in this PR arises from our desire to maintain C2's implementation intact for C2's use.

pthreadpool is a C library with no concept of namespaces, which means two copies of the library cannot exist in the same binary or symbol collision will occur violating ODR.  This means that somehow, and based on some condition, we must decide on the choice of a pthreadpool implementation.  In practice, this has become more complicated as a result of all the possible combinations that USE_NNPACK, USE_QNNPACK, USE_PYTORCH_QNNPACK, USE_XNNPACK, USE_SYSTEM_XNNPACK, USE_SYSTEM_PTHREADPOOL and other variables can result in.  Having said that, I have done my best in this PR to surgically cut through this complexity in a way that minimizes the side effects, considering the significance of the performance we are leaving on the table, yet, as a result of this combinatorial explosion explained above I cannot guarantee that every single combination will work as expected on the first try.  I am heavily relying on CI to find any issues as local testing can only go that far.

Having said that, this PR provides a simple non mobile-specific C++ thread pool implementation on top of pthreadpool, namely caffe2::PThreadPool that automatically routes to C2's implementation or the third party version depending on the build configuration.  This simplifies the logic at the cost of pushing the complexity to the build scripts.  From there on, this thread pool is used in aten parallel_for, and NNPACK and family, again, routing all usage of threading to C2 or third party pthreadpool depending on the build configuration.

When it is all said or done, the layering will look like this:

a) aten::parallel_for, uses
b) caffe2::PThreadPool, which uses
c) pthreadpool C API, which delegates to
    c-1) third_party implementation of pthreadpool if that's what the build has requested, and the rabbit hole ends here.
    c-2) C2's implementation of pthreadpool if that's what the build has requested, which itself delegates to
    c-2-1) caffe2::ThreadPool, and the rabbit hole ends here.

NNPACK, and (PyTorch) QNNPACK directly hook into (c). They never go through (b).

Differential Revision: D21232894

Test Plan: Imported from OSS

Reviewed By: dreiss

Pulled By: AshkanAliabadi

fbshipit-source-id: 8b3de86247fbc3a327e811983e082f9d40081354
This commit is contained in:
Ashkan Aliabadi
2020-06-23 16:24:27 -07:00
committed by Facebook GitHub Bot
parent c7d79f35e3
commit b9d3869df3
44 changed files with 366 additions and 1756 deletions

View File

@ -14,7 +14,7 @@ mkdir -p ${ZIP_DIR}/src
cp -R ${ARTIFACTS_DIR}/arm64/include ${ZIP_DIR}/install/
# build a FAT bianry
cd ${ZIP_DIR}/install/lib
target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a)
target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpthreadpool.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a)
for lib in ${target_libs[*]}
do
if [ -f "${ARTIFACTS_DIR}/x86_64/lib/${lib}" ] && [ -f "${ARTIFACTS_DIR}/arm64/lib/${lib}" ]; then

View File

@ -1350,7 +1350,6 @@ filegroup(
"caffe2/utils/smart_tensor_printer.cc",
"caffe2/utils/string_utils.cc",
"caffe2/utils/threadpool/ThreadPool.cc",
"caffe2/utils/threadpool/ThreadPoolMobile.cc",
"caffe2/utils/threadpool/pthreadpool.cc",
"caffe2/utils/threadpool/pthreadpool_impl.cc",
],

View File

@ -481,7 +481,7 @@ if(USE_PYTORCH_QNNPACK)
endif()
if(USE_XNNPACK)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK")
endif()
if(USE_VULKAN)

View File

@ -99,6 +99,7 @@ if(ANDROID_ABI)
import_static_lib(libnnpack)
import_static_lib(libXNNPACK)
import_static_lib(libpytorch_qnnpack)
import_static_lib(libpthreadpool)
import_static_lib(libeigen_blas)
import_static_lib(libcpuinfo)
import_static_lib(libclog)
@ -115,6 +116,7 @@ if(ANDROID_ABI)
libnnpack
libXNNPACK
libpytorch_qnnpack
libpthreadpool
libeigen_blas
libcpuinfo
libclog
@ -129,6 +131,7 @@ else()
nnpack
XNNPACK
pytorch_qnnpack
pthreadpool
cpuinfo
clog
)

View File

@ -8,8 +8,10 @@
#include "pytorch_jni_common.h"
#if defined(__ANDROID__)
#include <caffe2/utils/threadpool/ThreadPool.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif
namespace pytorch_jni {
@ -605,7 +607,7 @@ class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::mobile_threadpool()->setNumThreads(numThreads);
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif

View File

@ -6,8 +6,7 @@
#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#else
#include <caffe2/utils/threadpool/ThreadPool.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif // C10_MOBILE
#include <atomic>
@ -88,15 +87,15 @@ void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
// Run the first task on the current thread directly.
fn(0, 0);
#else
caffe2::ThreadPool* pool = caffe2::mobile_threadpool();
if (pool) {
// caffe2::ThreadPool can utilize the current thread.
pool->run(fn, range);
} else {
for (size_t i = 0; i < range; ++i) {
fn(0, i);
}
}
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->run(
// PThreadPool::run() is blocking. A std::function [const] reference to
// this lambda cannot go out of scope before PThreadPool::run() returns.
[&fn](const size_t task_id) {
fn(0 /* unused */, task_id);
}, range);
#endif // C10_MOBILE
}
@ -184,7 +183,7 @@ void init_num_threads() {
#endif
#ifdef C10_MOBILE
caffe2::mobile_threadpool();
caffe2::pthreadpool();
#endif
}
@ -208,7 +207,9 @@ void set_num_threads(int nthreads) {
}
}
#else
TORCH_CHECK(false, "set_num_threads is not supported for mobile.");
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}
@ -226,9 +227,9 @@ int get_num_threads() {
return _get_intraop_pool().size() + 1;
}
#else
caffe2::ThreadPool* pool = caffe2::mobile_threadpool();
// caffe2::ThreadPool::getNumThreads() counts the current thread.
return !pool || in_parallel_region() ? 1 /* current thread */ : pool->getNumThreads();
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!")
return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
#endif // C10_MOBILE
}
@ -257,8 +258,8 @@ void intraop_launch(std::function<void()> func) {
func();
}
#else
// TODO: caffe2::ThreadPool doesn't support submitting tasks separately and
// running in parallel. Should fix it when this API becomes popular.
// TODO: caffe2::PThreadPool only provides a data-parallel API.
// Task parallelism is not currently supported.
func();
#endif // C10_MOBILE
}
@ -280,8 +281,8 @@ std::shared_ptr<c10::ivalue::Future> intraop_launch_future(
}
return future;
#else
// TODO: caffe2::ThreadPool doesn't support submitting tasks separately and
// running in parallel. Should fix it when this API becomes popular.
// TODO: caffe2::PThreadPool only provides a data-parallel API.
// Task parallelism is not currently supported.
auto future = std::make_shared<c10::ivalue::Future>(NoneType::get());
func();
future->markCompleted();

View File

@ -58,7 +58,7 @@ bool _nnpack_available() {
#include <nnpack.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <ATen/native/ConvUtils.h>
namespace at {
@ -87,35 +87,7 @@ static bool init_nnpack() {
}
static pthreadpool_t nnpack_threadpool() {
// Try initializing a threadpool for NNPACK's use. If we fail to
// successfully initialize an implementation, return nullptr which will
// instruct NNPACK to run single threaded.
#ifdef C10_MOBILE
// If building for mobile, use Caffe 2's mobile-friendly threadpool.
return caffe2::mobile_pthreadpool();
#else
// Otherwise, try using pthreadpool if we manage to initialize it successfully.
static pthreadpool_t nnpack_threadpool_ = nullptr;
static bool called_nnpack_threadpool_ = false;
if (!called_nnpack_threadpool_) {
called_nnpack_threadpool_ = true;
#ifdef INTRA_OP_PARALLEL
const uint32_t threads = at::get_num_threads();
#else
const uint32_t threads = std::thread::hardware_concurrency();
#endif
nnpack_threadpool_ = pthreadpool_create(threads);
if (!nnpack_threadpool_) {
LOG(WARNING) << "Failed to initialize pthreadpool! Running NNPACK in single-threaded mode.";
}
}
return nnpack_threadpool_;
#endif
return caffe2::pthreadpool_();
}
bool _nnpack_available() {

View File

@ -5,7 +5,7 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <c10/util/math_compat.h>
#include <algorithm>
@ -375,7 +375,7 @@ Tensor qnnpack_avg_pool2d(
CAFFE_ENFORCE(
setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Average Pooling operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
TORCH_INTERNAL_ASSERT(

View File

@ -5,7 +5,6 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <c10/util/math_compat.h>
#include <algorithm>

View File

@ -7,7 +7,7 @@
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -194,7 +194,7 @@ Tensor qnnpack_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Add operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(qnnpack_operator, threadpool);

View File

@ -8,7 +8,7 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <c10/core/TensorOptions.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -82,7 +82,7 @@ Tensor quantized_channel_shuffle_impl(
setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK ChannelShuffle operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
TORCH_INTERNAL_ASSERT(

View File

@ -7,7 +7,7 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/quantized/Quantizer.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -64,7 +64,7 @@ Tensor qnnpack_clamp(Tensor input, Scalar min, Scalar max) {
TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Clamp operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(clamp_op, threadpool);

View File

@ -10,7 +10,7 @@
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <ATen/native/quantized/cpu/conv_packed_params.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
template <int kSpatialDim = 2>
bool ConvDimChecks(
@ -603,7 +603,7 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
output_min,
output_max,
reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
caffe2::mobile_pthreadpool());
caffe2::pthreadpool_());
TORCH_INTERNAL_ASSERT(
run_status == pytorch_qnnp_status_success,

View File

@ -5,7 +5,7 @@
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -57,7 +57,7 @@ Tensor qnnpack_hardsigmoid(Tensor input) {
TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Hardsigmoid operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(hardsigmoid_op, threadpool);

View File

@ -5,7 +5,7 @@
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -51,7 +51,7 @@ Tensor qnnpack_hardswish(const Tensor& qx, Tensor& qy) {
TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Hardswish operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(hardswish_op, threadpool);

View File

@ -4,7 +4,7 @@
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/custom_class.h>
#include <torch/library.h>
@ -341,7 +341,9 @@ at::Tensor PackedLinearWeightsQnnp::apply_impl(
packB->getPackedWeights(),
(uint8_t*)output.data_ptr<c10::quint8>(),
rows_w /* output_stride */,
caffe2::mobile_pthreadpool() /* threadpool */);
// TODO (Ashkan): Disabling temporarily.
// Throws a floating point exception with OSS pthreadpool.
nullptr);
TORCH_INTERNAL_ASSERT(
runStatus == pytorch_qnnp_status_success,

View File

@ -5,7 +5,7 @@
#include <ATen/native/quantized/cpu/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>
#include <torch/custom_class.h>
@ -327,7 +327,7 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) {
bias_ptr,
output.data_ptr<float>(),
rows_w /* output_stride */,
caffe2::mobile_pthreadpool() /* threadpool */);
caffe2::pthreadpool_() /* threadpool */);
TORCH_INTERNAL_ASSERT(
runStatus == pytorch_qnnp_status_success,

View File

@ -9,7 +9,7 @@
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
#include <vector>
@ -346,7 +346,7 @@ void check_maxpool2d_params(
setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK MaxPool operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
TORCH_INTERNAL_ASSERT(

View File

@ -3,7 +3,7 @@
#include <ATen/NativeFunctions.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
namespace at {
namespace native {
@ -66,7 +66,7 @@ Tensor qnnpack_mean(const Tensor& input, IntArrayRef dim) {
CAFFE_ENFORCE(
setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Global Average Pooling operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
TORCH_INTERNAL_ASSERT(

View File

@ -6,7 +6,7 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>
#include <algorithm>
@ -69,7 +69,7 @@ Tensor qnnpack_relu(Tensor input) {
setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK Relu operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(qnnpack_operator, threadpool);

View File

@ -7,7 +7,7 @@
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -66,7 +66,7 @@ Tensor qnnpack_sigmoid(Tensor input) {
TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK sigmoid operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(sigmoid_op, threadpool);

View File

@ -7,7 +7,7 @@
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <algorithm>
@ -64,7 +64,7 @@ Tensor qnnpack_tanh(Tensor input) {
TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
"failed to setup QNNPACK TanH operator");
pthreadpool_t threadpool = caffe2::mobile_pthreadpool();
pthreadpool_t threadpool = caffe2::pthreadpool_();
const pytorch_qnnp_status runStatus =
pytorch_qnnp_run_operator(tanh_op, threadpool);

View File

@ -5,7 +5,7 @@
#ifdef USE_XNNPACK
#include <xnnpack.h>
#include <caffe2/utils/threadpool/ThreadPoolXNNPACK.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
namespace at {
namespace native {

View File

@ -208,15 +208,15 @@ Tensor run(
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
padded_input_nhwc.data_ptr<float>(), // input
output.data_ptr<float>(), // output
caffe2::xnnpack_threadpool()); // threadpool
caffe2::pthreadpool_()); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status,
"xnn_setup_convolution2d_nhwc_f32 failed!");
const xnn_status run_status = xnn_run_operator(
context.op.get(), // operator
caffe2::xnnpack_threadpool()); // threadpool
context.op.get(), // operator
caffe2::pthreadpool_()); // threadpool
TORCH_INTERNAL_ASSERT(
xnn_status_success == run_status,

View File

@ -137,15 +137,15 @@ Tensor run(
Layout::ActivationND::batch(padded_input.sizes()), // Batch,
padded_input.data_ptr<float>(), // input
output.data_ptr<float>(), // output
caffe2::xnnpack_threadpool()); // threadpool
caffe2::pthreadpool_()); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status,
"xnn_setup_fully_connected_nc_f32 failed!");
const xnn_status run_status = xnn_run_operator(
context.op.get(), // operator
caffe2::xnnpack_threadpool()); // threadpool
context.op.get(), // operator
caffe2::pthreadpool_()); // threadpool
TORCH_INTERNAL_ASSERT(
xnn_status_success == run_status,

View File

@ -219,15 +219,15 @@ Tensor max_pool2d(
input_padded_contig_nhwc.size(Layout::Activation4D::width), // input_width
input_padded_contig_nhwc.data_ptr<float>(), // input
output_padded_contig_nhwc.data_ptr<float>(), // output
caffe2::xnnpack_threadpool()); // threadpool
caffe2::pthreadpool_()); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status,
"xnn_setup_max_pooling2d_nhwc_f32 failed!");
const xnn_status run_status = xnn_run_operator(
max_pool_op, // operator
caffe2::xnnpack_threadpool()); // threadpool
max_pool_op, // operator
caffe2::pthreadpool_()); // threadpool
TORCH_INTERNAL_ASSERT(
xnn_status_success == run_status,

View File

@ -87,7 +87,6 @@ endif()
# Note: the folders that are being commented out have not been properly
# addressed yet.
# For pthreadpool_new_if_impl. TODO: Remove when threadpools are unitied.
if(NOT MSVC AND USE_XNNPACK)
if(NOT TARGET fxdiv)
set(FXDIV_BUILD_TESTS OFF CACHE BOOL "")
@ -96,10 +95,6 @@ if(NOT MSVC AND USE_XNNPACK)
"${FXDIV_SOURCE_DIR}"
"${CMAKE_BINARY_DIR}/FXdiv")
endif()
if(NOT (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE))
set_source_files_properties(
utils/threadpool/pthreadpool_new_if_impl.c PROPERTIES COMPILE_FLAGS -fno-openmp)
endif()
endif()
add_subdirectory(core)

View File

@ -1,15 +1,8 @@
# TODO: Add ThreadPoolXNNPACK.cc when XNNPACK integration is updated
# to pass the actual threadpool ptr instead of nullptr.
if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE)
add_definitions(-DUSE_INTERNAL_THREADPOOL_IMPL)
list(APPEND Caffe2_CPU_SRCS
utils/string_utils.cc
utils/threadpool/pthreadpool.cc
utils/threadpool/pthreadpool_impl.cc
utils/threadpool/pthreadpool_new_if_impl.c
utils/threadpool/pthreadpool-cpp.cc
utils/threadpool/ThreadPool.cc
utils/threadpool/ThreadPoolMobile.cc
utils/threadpool/ThreadPoolXNNPACK.cc
)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
return()
@ -28,23 +21,19 @@ list(APPEND Caffe2_CPU_SRCS
utils/proto_convert.cc
utils/proto_utils.cc
utils/proto_wrap.cc
utils/threadpool/ThreadPool.cc
utils/signal_handler.cc
utils/smart_tensor_printer.cc
utils/string_utils.cc
utils/threadpool/ThreadPool.cc)
utils/string_utils.cc)
# ---[ threadpool/pthreadpool* is a local modification of the NNPACK
# pthreadpool with a very similar interface. Neither NNPACK, nor this
# thread pool supports Windows.
if(NOT MSVC AND USE_XNNPACK)
add_definitions(-DUSE_INTERNAL_THREADPOOL_IMPL)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
utils/threadpool/pthreadpool.cc
utils/threadpool/pthreadpool_impl.cc
utils/threadpool/pthreadpool_new_if_impl.c
utils/threadpool/ThreadPoolMobile.cc
utils/threadpool/ThreadPoolXNNPACK.cc
)
if(USE_PTHREADPOOL)
list(APPEND Caffe2_CPU_SRCS
utils/threadpool/pthreadpool-cpp.cc)
if(USE_INTERNAL_PTHREADPOOL_IMPL)
list(APPEND Caffe2_CPU_SRCS
utils/threadpool/pthreadpool.cc
utils/threadpool/pthreadpool_impl.cc)
endif()
endif()
set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS}

View File

@ -1,21 +0,0 @@
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/ThreadPool.h>
#include <caffe2/utils/threadpool/pthreadpool.h>
namespace caffe2 {
caffe2::ThreadPool* mobile_threadpool() {
#ifdef C10_MOBILE
static std::unique_ptr<caffe2::ThreadPool> thread_pool =
caffe2::ThreadPool::defaultThreadPool();
return thread_pool.get();
#else
return nullptr;
#endif
}
pthreadpool_t mobile_pthreadpool() {
return reinterpret_cast<pthreadpool_t>(mobile_threadpool());
}
} // namespace caffe2

View File

@ -1,24 +0,0 @@
#pragma once
#include <caffe2/utils/threadpool/pthreadpool.h>
// TODO Implement a parallel_for version for Mobile here, add to Aten/Parallel.h
namespace caffe2 {
class ThreadPool;
// Return a singleton instance of caffe2::ThreadPool for ATen/TH multithreading.
ThreadPool* mobile_threadpool();
// NOTE: This interface is temporary and should not be used.
// Please use Aten/Parallel.h for parallel primitives in pytorch.
// This implementation will be used by pytorch mobile, specifically
// NNPACK/QNNPACK. For mobile we need to use caffe2::ThreadPool instead of the
// 3rd party pthreadpool. Future work (TODO) Implement a mobile version of
// "at::parallel_for" using caffe2::ThreadPool so all ATen/TH multithreading
// usage is mobile friendly; Refactor QNNPACK or pthreadpool to explicitly using
// "at::parallel_for" primitive to replace pthreadpool_compute_1d for Pytorch;
pthreadpool_t mobile_pthreadpool();
size_t getDefaultNumThreads();
} // namespace caffe2

View File

@ -1,22 +0,0 @@
#include <caffe2/utils/threadpool/pthreadpool.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>
#include <caffe2/utils/threadpool/ThreadPoolXNNPACK.h>
#include <memory>
namespace caffe2 {
// Will be unified.
pthreadpool_t xnnpack_threadpool() {
// Depending on internal implemenation vs. OSS we will link against pthreadpool_create_xnnpack
// or pthreadpool_create. This is only temporary. It will be unified soon.
#ifdef USE_INTERNAL_THREADPOOL_IMPL
static std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy_xnnpack)>
threadpool(pthreadpool_create_xnnpack(getDefaultNumThreads()), pthreadpool_destroy_xnnpack);
#else
static std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)>
threadpool(pthreadpool_create(getDefaultNumThreads()), pthreadpool_destroy);
#endif
return threadpool.get();
}
} // namespace caffe2

View File

@ -1,7 +0,0 @@
#pragma once
// Creating a separate .h/.cc file for creating threadpool for XNNPACK
// to avoid touching existing internal builds.
// When we unify threadpools this should all go away.
namespace caffe2 {
pthreadpool_t xnnpack_threadpool();
} // namespace caffe2

View File

@ -0,0 +1,71 @@
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <c10/util/Exception.h>
namespace caffe2 {
PThreadPool::PThreadPool(const size_t thread_count)
: threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
size_t PThreadPool::get_thread_count() const {
std::lock_guard<std::mutex> lock{mutex_};
TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
return pthreadpool_get_threads_count(threadpool_.get());
}
void PThreadPool::set_thread_count(const size_t thread_count) {
std::lock_guard<std::mutex> lock{mutex_};
// As it stands, pthreadpool is an entirely data parallel framework with no
// support for task parallelism. Hence, all functions are blocking, and no
// user-provided tasks can be in flight when the control is returned to the
// user of the API, which means re-initializing the library, without the
// need to wait on any pending tasks, is all one needs to do to re-adjust
// the thread count.
threadpool_.reset(pthreadpool_create(thread_count));
}
void PThreadPool::run(
const std::function<void(size_t)>& fn,
const size_t range) {
std::lock_guard<std::mutex> lock{mutex_};
TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
struct Context final {
const std::function<void(size_t)>& fn;
} context{
fn,
};
pthreadpool_parallelize_1d(
threadpool_.get(),
// Note: pthreadpool_parallelize_1d() is a blocking function. The
// function pointer to this lambda passed on to
// pthreadpool_parallelize_1d() cannot go out of scope until
// pthreadpool_parallelize_1d() returns.
[](void* const context, const size_t item) {
reinterpret_cast<Context*>(context)->fn(item);
},
&context,
range,
0u);
}
// Forward declaration
size_t getDefaultNumThreads();
PThreadPool* pthreadpool() {
static std::unique_ptr<PThreadPool> threadpool =
std::make_unique<PThreadPool>(getDefaultNumThreads());
return threadpool.get();
}
pthreadpool_t pthreadpool_() {
PThreadPool* const threadpool = pthreadpool();
TORCH_INTERNAL_ASSERT(
threadpool, "Failed to acquire an instance of PThreadPool!");
return threadpool->threadpool_.get();
}
} // namespace caffe2

View File

@ -0,0 +1,54 @@
#pragma once
#ifdef USE_PTHREADPOOL
#ifdef USE_INTERNAL_PTHREADPOOL_IMPL
#include <caffe2/utils/threadpool/pthreadpool.h>
#else
#include <pthreadpool.h>
#endif
#include <functional>
#include <memory>
#include <mutex>
namespace caffe2 {
class PThreadPool final {
public:
explicit PThreadPool(size_t thread_count);
~PThreadPool() = default;
PThreadPool(const PThreadPool&) = delete;
PThreadPool& operator=(const PThreadPool&) = delete;
PThreadPool(PThreadPool&&) = delete;
PThreadPool& operator=(PThreadPool&&) = delete;
size_t get_thread_count() const;
void set_thread_count(size_t thread_count);
// Run, in parallel, function fn(task_id) over task_id in range [0, range).
// This function is blocking. All input is processed by the time it returns.
void run(const std::function<void(size_t)>& fn, size_t range);
private:
friend pthreadpool_t pthreadpool_();
private:
mutable std::mutex mutex_;
std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_;
};
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
PThreadPool* pthreadpool();
// Exposes the underlying implementation of PThreadPool.
// Only for use in external libraries so as to unify threading across
// internal (i.e. ATen, etc.) and external (e.g. NNPACK, QNNPACK, XNNPACK)
// use cases.
pthreadpool_t pthreadpool_();
} // namespace caffe2
#endif /* USE_PTHREADPOOL */

View File

@ -32,7 +32,7 @@ static inline size_t min(size_t a, size_t b) {
}
struct compute_1d_tiled_context {
pthreadpool_function_1d_tiled_t function;
legacy_pthreadpool_function_1d_tiled_t function;
void* argument;
size_t range;
size_t tile;
@ -46,9 +46,9 @@ static void compute_1d_tiled(void* context_, size_t linear_index) {
context->function(context->argument, index, tile);
}
void pthreadpool_compute_1d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_1d_tiled_t function,
void legacy_pthreadpool_compute_1d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_1d_tiled_t function,
void* argument,
size_t range,
size_t tile)
@ -65,12 +65,12 @@ void pthreadpool_compute_1d_tiled(
/*.argument = */ argument,
/*.range = */ range,
/*.tile = */ tile};
pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range);
legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range);
}
}
struct compute_2d_context {
pthreadpool_function_2d_t function;
legacy_pthreadpool_function_2d_t function;
void* argument;
caffe2::FixedDivisor<int32_t> range_j;
};
@ -85,9 +85,9 @@ static void compute_2d(void* context_, size_t linear_index) {
context->function(context->argument, q, r);
}
void pthreadpool_compute_2d(
struct pthreadpool* threadpool,
pthreadpool_function_2d_t function,
void legacy_pthreadpool_compute_2d(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_2d_t function,
void* argument,
size_t range_i,
size_t range_j)
@ -106,12 +106,12 @@ void pthreadpool_compute_2d(
/*.function = */ function,
/*.argument = */ argument,
/*.range_j = */ caffe2::FixedDivisor<int32_t>(range_j)};
pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j);
legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j);
}
}
struct compute_2d_tiled_context {
pthreadpool_function_2d_tiled_t function;
legacy_pthreadpool_function_2d_tiled_t function;
void* argument;
caffe2::FixedDivisor<int32_t> tile_range_j;
size_t range_i;
@ -135,9 +135,9 @@ static void compute_2d_tiled(void* context_, size_t linear_index) {
context->function(context->argument, index_i, index_j, tile_i, tile_j);
}
void pthreadpool_compute_2d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_2d_tiled_t function,
void legacy_pthreadpool_compute_2d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_2d_tiled_t function,
void* argument,
size_t range_i,
size_t range_j,
@ -166,12 +166,12 @@ void pthreadpool_compute_2d_tiled(
/*.range_j = */ range_j,
/*.tile_i = */ tile_i,
/*.tile_j = */ tile_j};
pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j);
legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j);
}
}
struct compute_3d_tiled_context {
pthreadpool_function_3d_tiled_t function;
legacy_pthreadpool_function_3d_tiled_t function;
void* argument;
caffe2::FixedDivisor<int32_t> tile_range_j;
caffe2::FixedDivisor<int32_t> tile_range_k;
@ -205,9 +205,9 @@ static void compute_3d_tiled(
context->argument, index_i, index_j, index_k, tile_i, tile_j, tile_k);
}
void pthreadpool_compute_3d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_3d_tiled_t function,
void legacy_pthreadpool_compute_3d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_3d_tiled_t function,
void* argument,
size_t range_i,
size_t range_j,
@ -251,16 +251,16 @@ void pthreadpool_compute_3d_tiled(
/*.tile_i = */ tile_i,
/*.tile_j = */ tile_j,
/*.tile_k = */ tile_k};
pthreadpool_compute_1d(
legacy_pthreadpool_compute_1d(
threadpool,
(pthreadpool_function_1d_t)compute_3d_tiled,
(legacy_pthreadpool_function_1d_t)compute_3d_tiled,
&context,
tile_range_i * tile_range_j * tile_range_k);
}
}
struct compute_4d_tiled_context {
pthreadpool_function_4d_tiled_t function;
legacy_pthreadpool_function_4d_tiled_t function;
void* argument;
caffe2::FixedDivisor<int32_t> tile_range_kl;
caffe2::FixedDivisor<int32_t> tile_range_j;
@ -310,9 +310,9 @@ static void compute_4d_tiled(
tile_l);
}
void pthreadpool_compute_4d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_4d_tiled_t function,
void legacy_pthreadpool_compute_4d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_4d_tiled_t function,
void* argument,
size_t range_i,
size_t range_j,
@ -367,9 +367,9 @@ void pthreadpool_compute_4d_tiled(
/*.tile_j = */ tile_j,
/*.tile_k = */ tile_k,
/*.tile_l = */ tile_l};
pthreadpool_compute_1d(
legacy_pthreadpool_compute_1d(
threadpool,
(pthreadpool_function_1d_t)compute_4d_tiled,
(legacy_pthreadpool_function_1d_t)compute_4d_tiled,
&context,
tile_range_i * tile_range_j * tile_range_k * tile_range_l);
}

View File

@ -5,49 +5,16 @@
#include "ThreadPoolCommon.h"
#include <stddef.h> // for size_t
typedef struct pthreadpool* pthreadpool_t;
typedef void (*pthreadpool_function_1d_t)(void*, size_t);
typedef void (*pthreadpool_function_1d_tiled_t)(void*, size_t, size_t);
typedef void (*pthreadpool_function_2d_t)(void*, size_t, size_t);
typedef void (*pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t);
typedef void (*pthreadpool_function_3d_tiled_t)(
void*,
size_t,
size_t,
size_t,
size_t,
size_t,
size_t);
typedef void (*pthreadpool_function_4d_tiled_t)(
void*,
size_t,
size_t,
size_t,
size_t,
size_t,
size_t,
size_t,
size_t);
#include <stdint.h> // for uint32_t
typedef void (*pthreadpool_task_1d_t)(void*, size_t);
typedef void (*pthreadpool_task_1d_tile_1d_t)(void*, size_t, size_t);
typedef void (*pthreadpool_task_2d_t)(void*, size_t, size_t);
typedef void (*pthreadpool_task_2d_tile_1d_t)(void*, size_t, size_t, size_t);
typedef void (*pthreadpool_task_2d_tile_2d_t)(void*, size_t, size_t, size_t, size_t);
typedef void (*pthreadpool_task_3d_tile_2d_t)(
void*,
size_t,
size_t,
size_t,
size_t,
size_t);
typedef void (*pthreadpool_task_4d_tile_2d_t)(
typedef struct pthreadpool* legacy_pthreadpool_t;
typedef void (*legacy_pthreadpool_function_1d_t)(void*, size_t);
typedef void (*legacy_pthreadpool_function_1d_tiled_t)(void*, size_t, size_t);
typedef void (*legacy_pthreadpool_function_2d_t)(void*, size_t, size_t);
typedef void (*legacy_pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t);
typedef void (*legacy_pthreadpool_function_3d_tiled_t)(
void*,
size_t,
size_t,
@ -55,16 +22,7 @@ typedef void (*pthreadpool_task_4d_tile_2d_t)(
size_t,
size_t,
size_t);
typedef void (*pthreadpool_task_5d_tile_2d_t)(
void*,
size_t,
size_t,
size_t,
size_t,
size_t,
size_t,
size_t);
typedef void (*pthreadpool_task_6d_tile_2d_t)(
typedef void (*legacy_pthreadpool_function_4d_tiled_t)(
void*,
size_t,
size_t,
@ -90,8 +48,8 @@ extern "C" {
* On error the function returns NULL and sets errno accordingly.
*/
//Returns internal threadpool impl.
pthreadpool_t pthreadpool_create(size_t threads_count);
// Returns internal threadpool impl.
legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count);
/**
* Queries the number of threads in a thread pool.
@ -100,7 +58,7 @@ pthreadpool_t pthreadpool_create(size_t threads_count);
*
* @returns The number of threads in the thread pool.
*/
size_t pthreadpool_get_threads_count(pthreadpool_t threadpool);
size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool);
/**
* Processes items in parallel using threads from a thread pool.
@ -117,38 +75,45 @@ size_t pthreadpool_get_threads_count(pthreadpool_t threadpool);
* @param[in] items The number of items to process. The @a function
* will be called once for each item.
*/
void pthreadpool_compute_1d(
pthreadpool_t threadpool,
pthreadpool_function_1d_t function,
void legacy_pthreadpool_compute_1d(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_1d_t function,
void* argument,
size_t range);
void pthreadpool_compute_1d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_1d_tiled_t function,
void legacy_pthreadpool_parallelize_1d(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_1d_t function,
void* argument,
size_t range,
uint32_t flags);
void legacy_pthreadpool_compute_1d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_1d_tiled_t function,
void* argument,
size_t range,
size_t tile);
void pthreadpool_compute_2d(
pthreadpool_t threadpool,
pthreadpool_function_2d_t function,
void legacy_pthreadpool_compute_2d(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_2d_t function,
void* argument,
size_t range_i,
size_t range_j);
void pthreadpool_compute_2d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_2d_tiled_t function,
void legacy_pthreadpool_compute_2d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_2d_tiled_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t tile_i,
size_t tile_j);
void pthreadpool_compute_3d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_3d_tiled_t function,
void legacy_pthreadpool_compute_3d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_3d_tiled_t function,
void* argument,
size_t range_i,
size_t range_j,
@ -157,9 +122,9 @@ void pthreadpool_compute_3d_tiled(
size_t tile_j,
size_t tile_k);
void pthreadpool_compute_4d_tiled(
pthreadpool_t threadpool,
pthreadpool_function_4d_tiled_t function,
void legacy_pthreadpool_compute_4d_tiled(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_4d_tiled_t function,
void* argument,
size_t range_i,
size_t range_j,
@ -178,129 +143,29 @@ void pthreadpool_compute_4d_tiled(
*
* @param[in,out] threadpool The thread pool to destroy.
*/
void pthreadpool_destroy(pthreadpool_t threadpool);
void legacy_pthreadpool_destroy(legacy_pthreadpool_t threadpool);
// New interface copy/pasted from pthreadpool.
// We will merge the internal and third-party/pthreadpool eventually.
// For now copy-paste to get past build issues.
#ifdef USE_INTERNAL_PTHREADPOOL_IMPL
#define PTHREADPOOL_FLAG_DISABLE_DENORMALS 0x00000001
#define pthreadpool_t legacy_pthreadpool_t
#define pthreadpool_function_1d_t legacy_pthreadpool_function_1d_t
#define pthreadpool_function_1d_tiled_t legacy_pthreadpool_function_1d_tiled_t
#define pthreadpool_function_2d_t legacy_pthreadpool_function_2d_t
#define pthreadpool_function_2d_tiled_t legacy_pthreadpool_function_2d_tiled_t
#define pthreadpool_function_3d_tiled_t legacy_pthreadpool_function_3d_tiled_t
#define pthreadpool_function_4d_tiled_t legacy_pthreadpool_function_4d_tiled_t
#define pthreadpool_create legacy_pthreadpool_create
#define pthreadpool_destroy legacy_pthreadpool_destroy
#define pthreadpool_get_threads_count legacy_pthreadpool_get_threads_count
#define pthreadpool_compute_1d legacy_pthreadpool_compute_1d
#define pthreadpool_parallelize_1d legacy_pthreadpool_parallelize_1d
#define pthreadpool_compute_1d_tiled legacy_pthreadpool_compute_1d_tiled
#define pthreadpool_compute_2d legacy_pthreadpool_compute_2d
#define pthreadpool_compute_2d_tiled legacy_pthreadpool_compute_2d_tiled
#define pthreadpool_compute_3d_tiled legacy_pthreadpool_compute_3d_tiled
#define pthreadpool_compute_4d_tiled legacy_pthreadpool_compute_4d_tiled
// Returns the copied threadpool impl of third-party/pthreadpool
pthreadpool_t pthreadpool_create_xnnpack(size_t threads_count);
// Copied third-party impl.
size_t pthreadpool_get_threads_count_xnnpack(pthreadpool_t threadpool);
// Copied third-party impl.
void pthreadpool_destroy_xnnpack(pthreadpool_t threadpool);
/**
* Processes items in parallel using threads from a thread pool.
*
* When the call returns, all items have been processed and the thread pool is
* ready for a new task.
*
* @note If multiple threads call this function with the same thread pool, the
* calls are serialized.
*
* @param[in] threadpool The thread pool to use for parallelisation.
* @param[in] function The function to call for each item.
* @param[in] argument The first argument passed to the @a function.
* @param[in] items The number of items to process. The @a function
* will be called once for each item.
*/
void pthreadpool_parallelize_1d(
pthreadpool_t threadpool,
pthreadpool_task_1d_t function,
void* argument,
size_t range,
uint32_t flags);
void pthreadpool_parallelize_1d_tile_1d(
pthreadpool_t threadpool,
pthreadpool_task_1d_tile_1d_t function,
void* argument,
size_t range,
size_t tile,
uint32_t flags);
void pthreadpool_parallelize_2d(
pthreadpool_t threadpool,
pthreadpool_task_2d_t function,
void* argument,
size_t range_i,
size_t range_j,
uint32_t flags);
void pthreadpool_parallelize_2d_tile_1d(
pthreadpool_t threadpool,
pthreadpool_task_2d_tile_1d_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t tile_j,
uint32_t flags);
void pthreadpool_parallelize_2d_tile_2d(
pthreadpool_t threadpool,
pthreadpool_task_2d_tile_2d_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t tile_i,
size_t tile_j,
uint32_t flags);
void pthreadpool_parallelize_3d_tile_2d(
pthreadpool_t threadpool,
pthreadpool_task_3d_tile_2d_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t range_k,
size_t tile_j,
size_t tile_k,
uint32_t flags);
void pthreadpool_parallelize_4d_tile_2d(
pthreadpool_t threadpool,
pthreadpool_task_4d_tile_2d_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t range_k,
size_t range_l,
size_t tile_k,
size_t tile_l,
uint32_t flags);
void pthreadpool_parallelize_5d_tile_2d(
pthreadpool_t threadpool,
pthreadpool_task_5d_tile_2d_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t range_k,
size_t range_l,
size_t range_m,
size_t tile_l,
size_t tile_m,
uint32_t flags);
void pthreadpool_parallelize_6d_tile_2d(
pthreadpool_t threadpool,
pthreadpool_task_6d_tile_2d_t function,
void* argument,
size_t range_i,
size_t range_j,
size_t range_k,
size_t range_l,
size_t range_m,
size_t range_n,
size_t tile_m,
size_t tile_n,
uint32_t flags);
#endif /* USE_INTERNAL_PTHREADPOOL_IMPL */
#ifdef __cplusplus
} /* extern "C" */

View File

@ -6,9 +6,9 @@
// External API
//
void pthreadpool_compute_1d(
pthreadpool_t threadpool,
pthreadpool_function_1d_t function,
void legacy_pthreadpool_compute_1d(
legacy_pthreadpool_t threadpool,
legacy_pthreadpool_function_1d_t function,
void* argument,
size_t range) {
if (threadpool == nullptr) {
@ -27,30 +27,31 @@ void pthreadpool_compute_1d(
range);
}
size_t pthreadpool_get_threads_count(pthreadpool_t threadpool) {
// The current fix only useful when XNNPACK calls pthreadpool_get_threads_count with nullptr.
void legacy_pthreadpool_parallelize_1d(
const legacy_pthreadpool_t threadpool,
const legacy_pthreadpool_function_1d_t function,
void* const argument,
const size_t range,
uint32_t) {
legacy_pthreadpool_compute_1d(threadpool, function, argument, range);
}
size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool) {
// The current fix only useful when XNNPACK calls legacy_pthreadpool_get_threads_count with nullptr.
if (threadpool == nullptr) {
return 1;
}
return reinterpret_cast<caffe2::ThreadPool*>(threadpool)->getNumThreads();
// TODO: Future fix: If we keep maintaining two different threadpools.
// Old C2 and new one for XNNPACK, then the we have two different pthreadpool pointer
// types. One is caffe2::Thredpool*, the other is pthreadpool* (pthreadpool_new_if_impl.c)
// XNNPACK calls pthreadpool_get_threads_count during op setup using pthreadpool*, and
// uses _parallelize_ interface for for actual work.
// While NNPACK uses caffe2::Threadpool*.
// Thus if pthreadpool_get_threads_count is getting called from XNNPACK we cannot
// reinterpret_cast it to ThreadPool. It will seg fault or worse will have unedfined behavior.
}
pthreadpool_t pthreadpool_create(size_t threads_count) {
legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count) {
std::mutex thread_pool_creation_mutex_;
std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
return reinterpret_cast<pthreadpool_t>(new caffe2::ThreadPool(threads_count));
return reinterpret_cast<legacy_pthreadpool_t>(new caffe2::ThreadPool(threads_count));
}
void pthreadpool_destroy(pthreadpool_t pthreadpool) {
void legacy_pthreadpool_destroy(legacy_pthreadpool_t pthreadpool) {
if (pthreadpool) {
caffe2::ThreadPool* threadpool =
reinterpret_cast<caffe2::ThreadPool*>(pthreadpool);

File diff suppressed because it is too large Load Diff

View File

@ -1,62 +0,0 @@
#pragma once
#include <stdint.h>
#if defined(__SSE__) || defined(__x86_64__)
#include <xmmintrin.h>
#endif
struct fpu_state {
#if defined(__SSE__) || defined(__x86_64__)
uint32_t mxcsr;
#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0)
uint32_t fpscr;
#elif defined(__aarch64__)
uint64_t fpcr;
#else
char unused;
#endif
};
static inline struct fpu_state get_fpu_state() {
struct fpu_state state = { 0 };
#if defined(__SSE__) || defined(__x86_64__)
state.mxcsr = (uint32_t) _mm_getcsr();
#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0)
__asm__ __volatile__("VMRS %[fpscr], fpscr" : [fpscr] "=r" (state.fpscr));
#elif defined(__aarch64__)
__asm__ __volatile__("MRS %[fpcr], fpcr" : [fpcr] "=r" (state.fpcr));
#endif
return state;
}
static inline void set_fpu_state(const struct fpu_state state) {
#if defined(__SSE__) || defined(__x86_64__)
_mm_setcsr((unsigned int) state.mxcsr);
#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0)
__asm__ __volatile__("VMSR fpscr, %[fpscr]" : : [fpscr] "r" (state.fpscr));
#elif defined(__aarch64__)
__asm__ __volatile__("MSR fpcr, %[fpcr]" : : [fpcr] "r" (state.fpcr));
#endif
}
static inline void disable_fpu_denormals() {
#if defined(__SSE__) || defined(__x86_64__)
_mm_setcsr(_mm_getcsr() | 0x8040);
#elif defined(__arm__) && defined(__ARM_FP) && (__ARM_FP != 0)
uint32_t fpscr;
__asm__ __volatile__(
"VMRS %[fpscr], fpscr\n"
"ORR %[fpscr], #0x1000000\n"
"VMSR fpscr, %[fpscr]\n"
: [fpscr] "=r" (fpscr));
#elif defined(__aarch64__)
uint64_t fpcr;
__asm__ __volatile__(
"MRS %[fpcr], fpcr\n"
"ORR %w[fpcr], %w[fpcr], 0x1000000\n"
"ORR %w[fpcr], %w[fpcr], 0x80000\n"
"MSR fpcr, %[fpcr]\n"
: [fpcr] "=r" (fpcr));
#endif
}

View File

@ -266,6 +266,8 @@ if(USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK)
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
endif()
else()
set(DISABLE_NNPACK_AND_FAMILY ON)
endif()
set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs
@ -281,45 +283,45 @@ if(INTERN_BUILD_MOBILE AND INTERN_USE_EIGEN_BLAS)
endif()
# ---[ pthreadpool
# QNNPACK and NNPACK both depend on pthreadpool, but when building with libtorch
# they should use the pthreadpool implementation under caffe2/utils/threadpool
# instead of the default implementation. To avoid confusion, add pthreadpool
# subdirectory explicitly with EXCLUDE_FROM_ALL property prior to QNNPACK/NNPACK
# does so, which will prevent it from installing the default pthreadpool library.
if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK OR USE_XNNPACK))
if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
if(NOT USE_SYSTEM_PTHREADPOOL AND (INTERN_BUILD_MOBILE OR NOT DISABLE_NNPACK_AND_FAMILY))
set(USE_PTHREADPOOL ON CACHE BOOL "" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PTHREADPOOL")
# Opt for custom Caffe2 implementation on MSVC. Windows support seems to have
# been added to pthreadpool recently but the current third party revision we are
# using right now does not suppor it. Should unify later after updating pthreadpool.
if(MSVC)
set(USE_INTERNAL_PTHREADPOOL_IMPL ON CACHE BOOL "" FORCE)
# XNNPACK cannot link against a custom implementation of pthreadpool
caffe2_update_option(USE_XNNPACK OFF)
else()
# We would like to maintain the ability to build against the internal C2
# pthreadpool implementation for now, hence this flag. This flag is not
# exposed as a build option to the user and is purly internal.
set(USE_INTERNAL_PTHREADPOOL_IMPL OFF CACHE BOOL "" FORCE)
if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
endif()
if(NOT TARGET pthreadpool)
set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
add_subdirectory(
"${PTHREADPOOL_SOURCE_DIR}"
"${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool")
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
list(APPEND Caffe2_DEPENDENCY_LIBS pthreadpool)
endif()
if(NOT TARGET pthreadpool)
set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
add_subdirectory(
"${PTHREADPOOL_SOURCE_DIR}"
"${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool"
EXCLUDE_FROM_ALL)
endif()
endif()
# XNNPACK has not option of like QNNPACK_CUSTOM_THREADPOOL
# that allows us to hijack pthreadpool interface.
# Thus not doing this ends up building pthreadpool as well as
# the internal implemenation of pthreadpool which results in symbol conflicts.
if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK)
if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
endif()
if(NOT TARGET pthreadpool)
set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
add_subdirectory(
"${PTHREADPOOL_SOURCE_DIR}"
"${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool"
EXCLUDE_FROM_ALL)
if(USE_INTERNAL_PTHREADPOOL_IMPL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_INTERNAL_PTHREADPOOL_IMPL")
endif()
else()
set(USE_PTHREADPOOL OFF CACHE BOOL "" FORCE)
endif()
# ---[ Caffe2 uses cpuinfo library in the thread pool
@ -369,9 +371,12 @@ if(USE_QNNPACK)
endif()
if(NOT TARGET qnnpack)
if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL)
set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
endif()
set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
add_subdirectory(
"${QNNPACK_SOURCE_DIR}"
@ -379,7 +384,6 @@ if(USE_QNNPACK)
# We build static versions of QNNPACK and pthreadpool but link
# them into a shared library for Caffe2, so they need PIC.
set_property(TARGET qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
@ -400,9 +404,12 @@ if(USE_PYTORCH_QNNPACK)
endif()
if(NOT TARGET pytorch_qnnpack)
if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL)
set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
endif()
set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
add_subdirectory(
"${PYTORCH_QNNPACK_SOURCE_DIR}"
@ -410,9 +417,6 @@ if(USE_PYTORCH_QNNPACK)
# We build static versions of QNNPACK and pthreadpool but link
# them into a shared library for Caffe2, so they need PIC.
set_property(TARGET pytorch_qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
if(NOT USE_SYSTEM_PTHREADPOOL)
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
@ -447,7 +451,6 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK)
endif()
if(NOT TARGET XNNPACK)
set(XNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(XNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
set(XNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(XNNPACK_BUILD_TESTS OFF CACHE BOOL "")
@ -457,15 +460,6 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK)
"${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK")
set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON)
# Context: pthreadpool_get_threads_count implementation that is built in pytorch, uses
# implementation defined in caffe2/utils/threadpool/pthreadpool_impl.cc. This implementation
# assumes the the pthreadpool* passed is of type caffe2::ThradPool and thus does reinterpret cast.
# This is not valid when we create pthreadpool via caffe2::xnnpack_threadpool, which is of type
# compatible with new pthreadpool interface and is used in PT's XNNPACK integration.
# Thus all the calls for pthreadpool_get_threads_count originating from XNNPACK must be routed
# appropriately to pthreadpool_get_threads_count_xnnpack, which does not do the aforementioned
# casting to caffe2::ThradPool. Once the threadpools are unified, we will not need this.
target_compile_definitions(XNNPACK PRIVATE -Dpthreadpool_get_threads_count=pthreadpool_get_threads_count_xnnpack)
endif()
include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})

View File

@ -59,9 +59,12 @@ if(ANDROID OR IOS OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux" OR ${CMAKE_SYSTEM_NAM
set(GOOGLETEST_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/googletest" CACHE STRING "Google Test source directory")
if(NOT TARGET nnpack)
if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL)
set(NNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
endif()
set(NNPACK_BUILD_TESTS OFF CACHE BOOL "")
set(NNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(NNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(NNPACK_LIBRARY_TYPE "static" CACHE STRING "")
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")

View File

@ -69,6 +69,11 @@ if(NOT @BUILD_SHARED_LIBS@)
list(APPEND TORCH_LIBRARIES ${XNNPACK_LIBRARY})
endif()
if(NOT @USE_INTERNAL_PTHREADPOOL_IMPL@)
find_library(PTHREADPOOL_LIBRARY pthreadpool PATHS "${TORCH_INSTALL_PREFIX}/lib")
list(APPEND TORCH_LIBRARIES ${PTHREADPOOL_LIBRARY})
endif()
if(@INTERN_USE_EIGEN_BLAS@)
find_library(EIGEN_BLAS_LIBRARY eigen_blas PATHS "${TORCH_INSTALL_PREFIX}/lib")
list(APPEND TORCH_LIBRARIES ${EIGEN_BLAS_LIBRARY})

View File

@ -63,7 +63,7 @@ targets.each do |target|
target.resources_build_phase.add_file_reference(config_file_ref, true)
end
puts "Linking static libraries..."
libs = ['libc10.a', 'libclog.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
targets.each do |target|
target.frameworks_build_phases.clear
for lib in libs do

View File

@ -51,7 +51,7 @@ end
# link static libraries
target.frameworks_build_phases.clear
libs = ['libc10.a', 'libclog.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
for lib in libs do
path = "#{install_path}/lib/#{lib}"
if File.exist?(path)