Sync mobile codebase changes back to fbcode

Summary: Rather chunky sync of changes made exclusively to mobile codebases back to fbcode.

Reviewed By: ajtulloch

Differential Revision: D5314405

fbshipit-source-id: c4d0a7244468f953eb63288306bc9bc78eb9e1be
Committed by: Facebook Github Bot
Parent: 4a81b0f24a
Commit: 9b9df3fbeb
@@ -53,7 +53,7 @@ struct DefaultCPUAllocator final : CPUAllocator {
 #else
     CAFFE_ENFORCE_EQ(posix_memalign(&data, gCaffe2Alignment, nbytes), 0);
 #endif
-    CHECK(data) << "Failed to allocate " << nbytes << " bytes.";
+    CAFFE_ENFORCE(data);
     memset(data, 0, nbytes);
     return {data, Delete};
   }
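Note on the hunk above: unlike glog's CHECK, which aborts the process, CAFFE_ENFORCE throws a caffe2::EnforceNotMet exception that a mobile host app can catch and recover from. A minimal caller sketch under that assumption (the allocator call shape here is hypothetical, not code from this commit):

    #include "caffe2/core/logging.h"

    void* tryAllocate(size_t nbytes) {
      try {
        // A CAFFE_ENFORCE failure inside the allocator surfaces here as an
        // exception rather than a process abort.
        return caffe2::GetCPUAllocator()->New(nbytes).first; // hypothetical call shape
      } catch (const caffe2::EnforceNotMet& e) {
        LOG(ERROR) << "allocation failed: " << e.msg();
        return nullptr;
      }
    }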
@@ -40,17 +40,22 @@ Predictor::Predictor(
   CAFFE_ENFORCE(ws_.CreateNet(run_net));
 }
 
-void Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
+Predictor::~Predictor() {}
+
+bool Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
   CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
   for (auto i = 0; i < inputs.size(); ++i) {
     shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
   }
 
-  CAFFE_ENFORCE(ws_.RunNet(run_net_.name()));
+  if (!ws_.RunNet(run_net_.name())) {
+    return false;
+  }
+
   outputs->resize(run_net_.external_output_size());
   for (auto i = 0; i < outputs->size(); ++i) {
     (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
   }
+  return true;
 }
 }
@@ -14,6 +14,7 @@ class Predictor {
       const NetDef& init_net,
       const NetDef& run_net,
       Workspace* parent = nullptr);
+  ~Predictor();
 
   // Executes `run_net` on the inputs.
   // The first `inputs.size()` inputs from run_net::external_inputs
@@ -24,7 +25,9 @@ class Predictor {
 
   // Postcondition:
   //   outputs->size() == run_net.external_inputs.size()
-  void run(const TensorVector& inputs, TensorVector* outputs);
+
+  // Returns true on success
+  bool run(const TensorVector& inputs, TensorVector* outputs);
 
   const NetDef& def() const {
     return run_net_;
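With the header change above, run() reports net-execution failure through its return value instead of throwing. A hypothetical caller sketch (the surrounding setup and header path are assumptions, not part of this commit):

    #include "caffe2/core/predictor.h" // header location assumed

    bool infer(caffe2::Predictor& predictor,
               const caffe2::Predictor::TensorVector& inputs,
               caffe2::Predictor::TensorVector* outputs) {
      // Input-count violations still CAFFE_ENFORCE (throw); RunNet failures
      // now come back as `false`.
      if (!predictor.run(inputs, outputs)) {
        LOG(ERROR) << "run_net failed";
        return false;
      }
      return true;
    }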
@@ -16,17 +16,6 @@ CAFFE2_DEFINE_bool(
     false,
     "If true, workspace destructor will print all blob shapes");
 
-#if CAFFE2_MOBILE
-// Threadpool restrictions
-
-// Whether or not threadpool caps apply to Android
-CAFFE2_DEFINE_int(caffe2_threadpool_android_cap, true, "");
-
-// Whether or not threadpool caps apply to iOS
-CAFFE2_DEFINE_int(caffe2_threadpool_ios_cap, false, "");
-
-#endif // CAFFE2_MOBILE
-
 namespace caffe2 {
 
 void Workspace::PrintBlobSizes() {
@@ -143,8 +132,7 @@ const Blob* Workspace::GetBlob(const string& name) const {
 }
 
 Blob* Workspace::GetBlob(const string& name) {
-  return const_cast<Blob*>(
-      static_cast<const Workspace*>(this)->GetBlob(name));
+  return const_cast<Blob*>(static_cast<const Workspace*>(this)->GetBlob(name));
 }
 
 NetBase* Workspace::CreateNet(const NetDef& net_def, bool overwrite) {
@@ -232,51 +220,18 @@ bool Workspace::RunNetOnce(const NetDef& net_def) {
   return true;
 }
 
-bool Workspace::RunPlan(const PlanDef& plan,
-                        ShouldContinue shouldContinue) {
+bool Workspace::RunPlan(const PlanDef& plan, ShouldContinue shouldContinue) {
   return RunPlanOnWorkspace(this, plan, shouldContinue);
 }
 
 #if CAFFE2_MOBILE
 ThreadPool* Workspace::GetThreadPool() {
   std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
 
   if (!thread_pool_) {
-    int numThreads = std::thread::hardware_concurrency();
-
-    bool applyCap = false;
-#if CAFFE2_ANDROID
-    applyCap = caffe2::FLAGS_caffe2_threadpool_android_cap;
-#elif CAFFE2_IOS
-    applyCap = caffe2::FLAGS_caffe2_threadpool_ios_cap;
-#else
-#error Undefined architecture
-#endif
-
-    if (applyCap) {
-      // 1 core -> 1 thread
-      // 2 cores -> 2 threads
-      // 4 cores -> 3 threads
-      // 8 cores -> 4 threads
-      // more, continue limiting to half of available cores
-
-      if (numThreads <= 3) {
-        // no change
-      } else if (numThreads <= 5) {
-        // limit to 3
-        numThreads = 3;
-      } else {
-        // Use half the cores
-        numThreads = numThreads / 2;
-      }
-    }
-
-    LOG(INFO) << "Constructing thread pool with " << numThreads << " threads";
-    thread_pool_.reset(new ThreadPool(numThreads));
+    thread_pool_ = ThreadPool::defaultThreadPool();
   }
 
   return thread_pool_.get();
 }
 #endif // CAFFE2_MOBILE
 
 } // namespace caffe2
@@ -548,14 +548,20 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   // need buffers for the worker threads
   size_t colBlockSize = W + kernel_w_ / stride_w_;
-  size_t threadYBufferSize = C * outputH * colBlockSize * stride_w_;
+  // Require 16 byte alignment, so 4-element alignment as these are floats.
+  size_t threadYBufferSizeAligned =
+      ((C * outputH * colBlockSize * stride_w_ + 3) / 4) * 4;
   size_t threadColBufferSize = C * kernel_h_ * kernel_w_ * W;
 
   // Work around GCC 4.9 bug when this is declared inside the inner lambda.
-  auto runLocalTile = [&](TensorCPU* threadBuffer, int threadId, size_t tileId) {
-    auto localYData = threadBuffer->template mutable_data<T>() + threadId * threadYBufferSize;
+  auto runLocalTile = [&](TensorCPU* threadBuffer,
+                          int threadId,
+                          size_t tileId) {
+    auto localYData = threadBuffer->template mutable_data<T>() +
+        threadId * threadYBufferSizeAligned;
 
     auto localColBufferData = threadBuffer->template mutable_data<T>() +
-        numThreads * threadYBufferSize + threadId * threadColBufferSize;
+        numThreads * threadYBufferSizeAligned + threadId * threadColBufferSize;
 
     runTileContiguous<T, Context>(tileId,
                                   N,
@@ -578,11 +584,14 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   };
 
   auto f = [&](Tensor<Context>* threadBuffer) {
-    threadBuffer->Resize(numThreads * threadYBufferSize + numThreads * threadColBufferSize);
+    threadBuffer->Resize(
+        numThreads * threadYBufferSizeAligned +
+        numThreads * threadColBufferSize);
     // Group together thread buffers for accumulation
     std::vector<T*> toSum(numThreads - 1);
     for (int i = 1; i < numThreads; ++i) {
-      toSum[i - 1] = threadBuffer->template mutable_data<T>() + i * threadYBufferSize;
+      toSum[i - 1] = threadBuffer->template mutable_data<T>() +
+          i * threadYBufferSizeAligned;
     }
 
     for (auto image_id = 0; image_id < N; ++image_id) {
@@ -591,7 +600,10 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
       // The column buffers are overwritten by the matrix multiplication
       // each time, so we need not clear them out each round
      math::Set<T, Context>(
-          numThreads * threadYBufferSize, 0, threadBuffer->template mutable_data<T>(), &context_);
+          numThreads * threadYBufferSizeAligned,
+          0,
+          threadBuffer->template mutable_data<T>(),
+          &context_);
 
       // Run tiled gemm and col2im in our threadpool; all of these tiles
       // are guaranteed to be full tiles
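The alignment arithmetic introduced above rounds each per-thread Y buffer up to a multiple of four floats, i.e. 16 bytes. A standalone restatement of that expression (a sketch for verification, not repo code):

    #include <cstddef>

    // Mirrors ((C * outputH * colBlockSize * stride_w_ + 3) / 4) * 4 above:
    // round x up to the next multiple of 4 elements; with 4-byte floats this
    // yields a 16-byte-aligned size for each thread's slice.
    constexpr size_t roundUp4(size_t x) {
      return ((x + 3) / 4) * 4;
    }

    static_assert(roundUp4(13) == 16, "13 floats -> 16 floats");
    static_assert(roundUp4(16) == 16, "already-aligned sizes are unchanged");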
@@ -9,82 +9,25 @@ using std::min;
 
 namespace {
 
-template <typename T>
-class AveragePool {
- public:
-  static float initialize() {
-    return 0.0;
-  }
-
-  static void process(
-      const int x_col,
-      const int y_col,
-      ConstEigenMatrixMap<float>& x_mat,
-      EigenMatrixMap<float>& y_mat) {
-    y_mat.col(y_col) += x_mat.col(x_col);
-  }
-
-  static void process(const T& x_data, T& y_data) {
-    y_data += x_data;
-  }
-
-  static void finalize(const int size, T& y_data) {
-    y_data /= size;
-  }
-
-  static void
-  finalize(const int size, const int col, EigenMatrixMap<float>& y_mat) {
-    y_mat.col(col) /= size;
-  }
-
-  static bool runNeon() {
-    return true;
-  }
-};
-
-template <typename T>
-class MaxPool {
- public:
-  static float initialize() {
-    return std::numeric_limits<float>::lowest();
-  }
-
-  static void process(
-      const int x_col,
-      const int y_col,
-      ConstEigenMatrixMap<float>& x_mat,
-      EigenMatrixMap<float>& y_mat) {
-    y_mat.col(y_col) = y_mat.col(y_col).cwiseMax(x_mat.col(x_col));
-  }
-
-  static void process(const T& x_data, T& y_data) {
-    if (x_data > y_data) {
-      y_data = x_data;
-    }
-  }
-
-  static void finalize(const int /*size*/, T& /*y_data*/) {}
-
-  static void finalize(
-      const int /*size*/,
-      const int col,
-      EigenMatrixMap<float>& /*y_mat*/) {}
-
-  static bool runNeon() {
-    return false;
-  }
-};
-
 #ifdef __ARM_NEON__
 
-bool isNeonEligible(int inputH, int inputW,
-                    int outputH, int outputW,
-                    int kH, int kW,
-                    int strideH, int strideW,
-                    int padT, int padL, int padB, int padR,
-                    int dilationH, int dilationW,
-                    const float* input,
-                    float* output) {
+bool isNeon4x4p0s0Eligible(
+    int inputH,
+    int inputW,
+    int outputH,
+    int outputW,
+    int kH,
+    int kW,
+    int strideH,
+    int strideW,
+    int padT,
+    int padL,
+    int padB,
+    int padR,
+    int dilationH,
+    int dilationW,
+    const float* input,
+    float* output) {
   // Use this kernel only if:
   //   Kernel width is 4x4
   //   Kernel stride is 4x4
@@ -103,20 +46,21 @@ bool isNeonEligible(int inputH, int inputW,
   bool outputOk = ((inputH % outputH) == 0) && ((inputW % outputW) == 0);
   bool inputOk = (inputW % 4 == 0) && (inputH % 4 == 0);
   bool alignOk = isPointerAligned(input, sizeof(float32x4_t)) &&
-                 isPointerAligned(output, sizeof(float32x4_t));
+      isPointerAligned(output, sizeof(float32x4_t));
 
-  return kernelOk && strideOk && padOk && dilationOk &&
-         outputOk && inputOk && alignOk;
+  return kernelOk && strideOk && padOk && dilationOk && outputOk && inputOk &&
+      alignOk;
 }
 
 // Vectorizes 4x4p0s0 averge pooling for ARM NEON
-void avgPoolNeon4x4p0s0Plane(int inputH, int inputW,
-                             const float* input,
-                             float* output) {
+void avgPoolNeon4x4p0s0Plane(
+    int inputH,
+    int inputW,
+    const float* input,
+    float* output) {
   constexpr int kKernelHeight = 4;
   constexpr int kKernelWidth = 4;
-  constexpr float kDiv =
-    (1.0f / ((float) kKernelHeight * (float) kKernelWidth));
+  constexpr float kDiv = (1.0f / ((float)kKernelHeight * (float)kKernelWidth));
 
   // Handle portion that can be unrolled by 4
   constexpr int kUnroll = 4;
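The eligibility checks rely on an isPointerAligned helper whose definition is outside this diff. A plausible sketch of what such a predicate checks (an assumption, not the repo's definition):

    #include <cstddef>
    #include <cstdint>

    // True when p is aligned to `align` bytes; used above with
    // sizeof(float32x4_t) == 16 to guard the aligned NEON loads/stores.
    inline bool isPointerAligned(const void* p, size_t align) {
      return (reinterpret_cast<uintptr_t>(p) % align) == 0;
    }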
@@ -202,10 +146,13 @@ void avgPoolNeon4x4p0s0Plane(int inputH, int inputW,
   }
 }
 
-void
-runNeonAveragePool4x4p0s0NCHW(int N, int C, int inputH, int inputW,
-                              const float* input,
-                              float* output) {
+void runNeonAveragePool4x4p0s0NCHW(
+    int N,
+    int C,
+    int inputH,
+    int inputW,
+    const float* input,
+    float* output) {
   // We only have the 4x4p0s0 implementation at present, which is
   // checked at a higher level
   int outputH = inputH / 4;
@@ -220,9 +167,291 @@ runNeonAveragePool4x4p0s0NCHW(int N, int C, int inputH, int inputW,
     }
   }
 }
 
+bool isNeon2x2p0s0Eligible(
+    int inputH,
+    int inputW,
+    int outputH,
+    int outputW,
+    int kH,
+    int kW,
+    int strideH,
+    int strideW,
+    int padT,
+    int padL,
+    int padB,
+    int padR,
+    int dilationH,
+    int dilationW,
+    const float* input,
+    float* output) {
+  // Use this kernel only if:
+  //   Kernel width is 2x2
+  //   Kernel stride is 2x2
+  //   Padding is 0
+  //   Dilation is 1
+  //   Output width and height are even divisors of input width
+  //   Input width and height are divisible by 4 (should be implied by
+  //   all of the above, but just check again)
+  //   Input and output pointers are aligned by float32x4_t
+
+  bool kernelOk = (kH == 2) && (kW == 2);
+  bool strideOk = (strideH == 2) && (strideW == 2);
+  bool padOk = (padT == 0) && (padL == 0) && (padB == 0) && (padR == 0);
+  bool dilationOk = (dilationH == 1) && (dilationW == 1);
+
+  bool outputOk = ((inputH % outputH) == 0) && ((inputW % outputW) == 0);
+  bool inputOk = (inputW % 4 == 0) && (inputH % 4 == 0);
+  bool alignOk = isPointerAligned(input, sizeof(float32x4_t)) &&
+      isPointerAligned(output, sizeof(float32x4_t));
+
+  return kernelOk && strideOk && padOk && dilationOk && outputOk && inputOk &&
+      alignOk;
+}
+
+// Vectorizes 2x2p0s0 averge pooling for ARM NEON
+void maxPoolNeon2x2p0s0Plane(
+    int inputH,
+    int inputW,
+    const float* input,
+    float* output) {
+  constexpr int kKernelHeight = 2;
+  constexpr int kKernelWidth = 2;
+
+  // Handle portion that can be unrolled by 4
+  constexpr int kUnroll = 4;
+  constexpr int kLoadSizeFloat = (sizeof(float32x4_t) / sizeof(float));
+  constexpr int kLoadCols = kUnroll * kLoadSizeFloat;
+
+  if (inputW % kLoadCols == 0) {
+    for (int h = 0; h < inputH; h += kKernelHeight) {
+      float* outputRow = output + (h / kKernelHeight) * (inputW / kKernelWidth);
+      const float* curInput = input + h * inputW;
+
+      for (int w = 0; w < inputW; w += kLoadCols) {
+        float32x2_t hmax_0, hmax_1, hmax_2, hmax_3;
+        {
+          float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
+          float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
+          float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
+          hmax_0 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
+        }
+        curInput += kLoadSizeFloat;
+        {
+          float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
+          float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
+          float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
+          hmax_1 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
+        }
+        curInput += kLoadSizeFloat;
+        {
+          float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
+          float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
+          float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
+          hmax_2 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
+        }
+        curInput += kLoadSizeFloat;
+        {
+          float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
+          float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
+          float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
+          hmax_3 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
+        }
+        curInput += kLoadSizeFloat;
+
+        float32x4_t out_0 = vcombine_f32(hmax_0, hmax_1);
+        float32x4_t out_1 = vcombine_f32(hmax_2, hmax_3);
+        vst1q_f32_aligned(&outputRow[w / kKernelWidth + 0], out_0);
+        vst1q_f32_aligned(&outputRow[w / kKernelWidth + 4], out_1);
+      }
+    }
+  } else {
+    // Not unrolled
+    for (int h = 0; h < inputH; h += kKernelHeight) {
+      const float* inputRow = input + h * inputW;
+      float* outputRow = output + (h / kKernelHeight) * (inputW / kKernelWidth);
+
+      for (int w = 0; w < inputW; w += kKernelWidth * 2) {
+        const float* curInput = inputRow + w;
+        float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
+        float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
+        float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
+        float32x2_t hmax = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
+        vst1_f32(&outputRow[w / kKernelWidth], hmax);
+      }
+    }
+  }
+}
+
+void runNeonMaxPool2x2p0s0NCHW(
+    int N,
+    int C,
+    int inputH,
+    int inputW,
+    const float* input,
+    float* output) {
+  // We only have the 2x2p0s0 implementation at present, which is
+  // checked at a higher level
+  int outputH = inputH / 2;
+  int outputW = inputW / 2;
+
+  for (int n = 0; n < N; ++n) {
+    for (int c = 0; c < C; ++c) {
+      const float* curInput = input + (n * C + c) * inputH * inputW;
+      float* curOutput = output + (n * C + c) * outputH * outputW;
+      maxPoolNeon2x2p0s0Plane(inputH, inputW, curInput, curOutput);
+    }
+  }
+}
 #endif // __ARM_NEON__
 
 } // namespace
 
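For readers without a NEON background: the unrolled intrinsics above (vld1q_f32_aligned loads, vmaxq_f32 vertical max, vpmax_f32 horizontal pairwise max) compute exactly a 2x2, stride-2, zero-padding max pool. A scalar reference sketch of the same result (illustrative, not repo code):

    #include <algorithm>

    // Each output element is the max of a 2x2 input window, stride 2, pad 0.
    void maxPool2x2p0s0PlaneScalar(int inputH, int inputW,
                                   const float* input, float* output) {
      const int outputW = inputW / 2;
      for (int h = 0; h < inputH; h += 2) {
        for (int w = 0; w < inputW; w += 2) {
          const float* in = input + h * inputW + w;
          output[(h / 2) * outputW + w / 2] =
              std::max(std::max(in[0], in[1]),
                       std::max(in[inputW], in[inputW + 1]));
        }
      }
    }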
+template <typename T>
+class AveragePool {
+ public:
+  static float initialize() {
+    return 0.0;
+  }
+
+  static void process(
+      const int x_col,
+      const int y_col,
+      ConstEigenMatrixMap<float>& x_mat,
+      EigenMatrixMap<float>& y_mat) {
+    y_mat.col(y_col) += x_mat.col(x_col);
+  }
+
+  static void process(const T& x_data, T& y_data) {
+    y_data += x_data;
+  }
+
+  static void finalize(const int size, T& y_data) {
+    y_data /= size;
+  }
+
+  static void
+  finalize(const int size, const int col, EigenMatrixMap<float>& y_mat) {
+    y_mat.col(col) /= size;
+  }
+
+  static bool runSpecialized(
+      int N,
+      int C,
+      int inputH,
+      int inputW,
+      int outputH,
+      int outputW,
+      int kH,
+      int kW,
+      int strideH,
+      int strideW,
+      int padT,
+      int padL,
+      int padB,
+      int padR,
+      int dilationH,
+      int dilationW,
+      const float* input,
+      float* output) {
+#ifdef __ARM_NEON__
+    if (isNeon4x4p0s0Eligible(
+            inputH,
+            inputW,
+            outputH,
+            outputW,
+            kH,
+            kW,
+            strideH,
+            strideW,
+            padT,
+            padL,
+            padB,
+            padR,
+            dilationH,
+            dilationW,
+            input,
+            output)) {
+      runNeonAveragePool4x4p0s0NCHW(N, C, inputH, inputW, input, output);
+      return true;
+    }
+#endif
+    return false;
+  }
+};
+
+template <typename T>
+class MaxPool {
+ public:
+  static float initialize() {
+    return std::numeric_limits<float>::lowest();
+  }
+
+  static void process(
+      const int x_col,
+      const int y_col,
+      ConstEigenMatrixMap<float>& x_mat,
+      EigenMatrixMap<float>& y_mat) {
+    y_mat.col(y_col) = y_mat.col(y_col).cwiseMax(x_mat.col(x_col));
+  }
+
+  static void process(const T& x_data, T& y_data) {
+    if (x_data > y_data) {
+      y_data = x_data;
+    }
+  }
+
+  static void finalize(const int /*size*/, T& /*y_data*/) {}
+
+  static void finalize(
+      const int /*size*/,
+      const int col,
+      EigenMatrixMap<float>& /*y_mat*/) {}
+
+  static bool runSpecialized(
+      int N,
+      int C,
+      int inputH,
+      int inputW,
+      int outputH,
+      int outputW,
+      int kH,
+      int kW,
+      int strideH,
+      int strideW,
+      int padT,
+      int padL,
+      int padB,
+      int padR,
+      int dilationH,
+      int dilationW,
+      const float* input,
+      float* output) {
+#ifdef __ARM_NEON__
+    if (isNeon2x2p0s0Eligible(
+            inputH,
+            inputW,
+            outputH,
+            outputW,
+            kH,
+            kW,
+            strideH,
+            strideW,
+            padT,
+            padL,
+            padB,
+            padR,
+            dilationH,
+            dilationW,
+            input,
+            output)) {
+      runNeonMaxPool2x2p0s0NCHW(N, C, inputH, inputW, input, output);
+      return true;
+    }
+#endif
+    return false;
+  }
+};
+
 template <typename T, class Context, typename PoolType>
 bool PoolOp<T, Context, PoolType>::RunOnDeviceWithOrderNCHW() {
@@ -243,30 +472,30 @@ bool PoolOp<T, Context, PoolType>::RunOnDeviceWithOrderNCHW() {
   int pooled_width = kernel_.size() > 1 ? Y->dim32(3) : 1;
   int pooled_depth = kernel_.size() > 2 ? Y->dim32(4) : 1;
 
 #ifdef __ARM_NEON__
   // We specialize certain variants on ARM for vectorization
-  if (PoolType::runNeon() && isNeonEligible(
-        X.dim32(2),
-        X.dim32(3),
-        Y->dim32(2),
-        Y->dim32(3),
-        kernel_h(),
-        kernel_w(),
-        stride_h(),
-        stride_w(),
-        pad_t(),
-        pad_l(),
-        pad_b(),
-        pad_r(),
-        dilation_h(),
-        dilation_w(),
-        Xdata,
-        Ydata)) {
-    runNeonAveragePool4x4p0s0NCHW(
-        X.dim32(0), X.dim32(1), X.dim32(2), X.dim32(3), Xdata, Ydata);
+  if (kernel_.size() == 2 &&
+      PoolType::runSpecialized(
+          X.dim32(0),
+          X.dim32(1),
+          X.dim32(2),
+          X.dim32(3),
+          Y->dim32(2),
+          Y->dim32(3),
+          kernel_h(),
+          kernel_w(),
+          stride_h(),
+          stride_w(),
+          pad_t(),
+          pad_l(),
+          pad_b(),
+          pad_r(),
+          dilation_h(),
+          dilation_w(),
+          Xdata,
+          Ydata)) {
     return true;
   }
 #endif // __ARM_NEON__
 
   switch (kernel_.size()) {
     case 1:
       for (int n = 0; n < X.dim32(0); ++n) {
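The dispatch above shows the pattern this file now uses throughout: PoolOp is parameterized by a pooling policy class, and each policy may claim the whole plane via runSpecialized() before the generic loops run (the old runNeon() flag is gone). A toy sketch of the mechanism (illustrative names, not repo code):

    #include <iostream>

    struct GenericOnlyPool {
      static bool runSpecialized() { return false; } // no fast path
    };
    struct NeonLikePool {
      static bool runSpecialized() { return true; } // pretends to vectorize
    };

    // The operator tries the policy's specialized kernel first and falls
    // back to the generic implementation when the policy declines.
    template <class PoolType>
    void runPool() {
      if (PoolType::runSpecialized()) {
        std::cout << "used specialized kernel\n";
        return;
      }
      std::cout << "used generic kernel\n";
    }

    int main() {
      runPool<GenericOnlyPool>(); // generic
      runPool<NeonLikePool>();    // specialized
    }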
@@ -450,8 +679,8 @@ bool PoolOp<T, Context, PoolType>::RunOnDeviceWithOrderNHWC() {
             }
           }
         }
-      }
-      break;
+        }
+        break;
     default:
       CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
       return false;
@@ -465,48 +694,60 @@ REGISTER_CPU_OPERATOR(
     PoolOp<float, CPUContext, AveragePool<float>>);
 
 OPERATOR_SCHEMA(AveragePool)
-  .NumInputs(1)
-  .NumOutputs(1)
-  .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-  .SetDoc(R"DOC(
+    .NumInputs(1)
+    .NumOutputs(1)
+    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
+    .SetDoc(R"DOC(
 AveragePool consumes an input blob X and applies average pooling across the
 the blob according to kernel sizes, stride sizes, and pad lengths defined by the
 ConvPoolOpBase operator. Average pooling consisting of averaging all values of a
 subset of the input tensor according to the kernel size and downsampling the
 data into the output blob Y for further processing.
 )DOC")
-  .Input(0, "X", "Input data tensor from the previous operator; dimensions "
-  "depend on whether the NCHW or NHWC operators are being used. For example, "
-  "in the former, the input has size (N x C x H x W), where N is the batch "
-  "size, C is the number of channels, and H and W are the height and the width "
-  "of the data. The corresponding permutation of dimensions is used in the "
-  "latter case. ")
-  .Output(0, "Y", "Output data tensor from average pooling across the input "
-  "tensor. Dimensions will vary based on various kernel, stride, and pad "
-  "sizes.");
+    .Input(
+        0,
+        "X",
+        "Input data tensor from the previous operator; dimensions depend on "
+        "whether the NCHW or NHWC operators are being used. For example, in "
+        "the former, the input has size (N x C x H x W), where N is the batch "
+        "size, C is the number of channels, and H and W are the height and the "
+        "width of the data. The corresponding permutation of dimensions is "
+        "used in the latter case.")
+    .Output(
+        0,
+        "Y",
+        "Output data tensor from average pooling across the input "
+        "tensor. Dimensions will vary based on various kernel, stride, and pad "
+        "sizes.");
 
 REGISTER_CPU_OPERATOR(MaxPool, PoolOp<float, CPUContext, MaxPool<float>>);
 
 OPERATOR_SCHEMA(MaxPool)
-  .NumInputs(1)
-  .NumOutputs(1)
-  .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-  .SetDoc(R"DOC(
+    .NumInputs(1)
+    .NumOutputs(1)
+    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
+    .SetDoc(R"DOC(
 MaxPool consumes an input blob X and applies max pooling across the
 the blob according to kernel sizes, stride sizes, and pad lengths defined by the
 ConvPoolOpBase operator. Max pooling consisting of taking the maximumvalue of a
 subset of the input tensor according to the kernel size and downsampling the
 data into the output blob Y for further processing.
 )DOC")
-  .Input(0, "X", "Input data tensor from the previous operator; dimensions "
-  "depend on whether the NCHW or NHWC operators are being used. For example, "
-  "in the former, the input has size (N x C x H x W), where N is the batch "
-  "size, C is the number of channels, and H and W are the height and the width "
-  "of the data. The corresponding permutation of dimensions is used in the "
-  "latter case. ")
-  .Output(0, "Y", "Output data tensor from max pooling across the input "
-  "tensor. Dimensions will vary based on various kernel, stride, and pad "
-  "sizes.");
+    .Input(
+        0,
+        "X",
+        "Input data tensor from the previous operator; dimensions depend on "
+        "whether the NCHW or NHWC operators are being used. For example, in "
+        "the former, the input has size (N x C x H x W), where N is the batch "
+        "size, C is the number of channels, and H and W are the height and the "
+        "width of the data. The corresponding permutation of dimensions is "
+        "used in the latter case.")
+    .Output(
+        0,
+        "Y",
+        "Output data tensor from max pooling across the input "
+        "tensor. Dimensions will vary based on various kernel, stride, and pad "
+        "sizes.");
 
 } // namespace
 } // namespace caffe2
@@ -22,6 +22,15 @@
 
 #include "caffe2/core/init.h"
 
+#if CAFFE2_ANDROID
+#ifndef SYS_gettid
+#define SYS_gettid __NR_gettid
+#endif
+#ifndef SYS_tgkill
+#define SYS_tgkill __NR_tgkill
+#endif
+#endif
+
 namespace {
 
 struct sigaction previousSighup;
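These fallbacks matter on older Android NDKs, whose libc headers ship no gettid/tgkill wrappers: signal-handling code has to target a specific thread through raw syscalls. A sketch of the intended use (an assumption, not shown in this hunk):

    #include <sys/syscall.h>
    #include <unistd.h>

    // Deliver `signum` to one specific thread of this process. SYS_gettid /
    // SYS_tgkill are the raw syscall numbers defined above when missing.
    static void raiseOnThread(pid_t tid, int signum) {
      syscall(SYS_tgkill, getpid(), tid, signum);
    }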
@@ -1,8 +1,20 @@
 #include "caffe2/utils/threadpool/ThreadPool.h"
 #include "caffe2/core/logging.h"
 
-CAFFE2_DEFINE_bool(caffe2_threadpool_force_inline, false,
-                   "Force to always run jobs on the calling thread");
+#if CAFFE2_ANDROID
+#include <cpu-features.h>
+#endif
+
+CAFFE2_DEFINE_bool(
+    caffe2_threadpool_force_inline,
+    false,
+    "Force to always run jobs on the calling thread");
+
+// Whether or not threadpool caps apply to Android
+CAFFE2_DEFINE_int(caffe2_threadpool_android_cap, true, "");
+
+// Whether or not threadpool caps apply to iOS
+CAFFE2_DEFINE_int(caffe2_threadpool_ios_cap, false, "");
 
 #if CAFFE2_THREADPOOL_MOBILE
 
@@ -20,6 +32,50 @@ constexpr size_t kDefaultMinWorkSize = 80;
 constexpr float kDefaultImbalanceRatio = 1.0f;
 #endif
 
+std::unique_ptr<ThreadPool> ThreadPool::defaultThreadPool() {
+  int numThreads = std::thread::hardware_concurrency();
+
+#ifdef CAFFE2_ANDROID
+  // std::thread::hardware_concurrency returns online cores
+  // (sysconf(_SC_NPROCESSORS_ONLN)), but we want the total number of CPUs. In
+  // most cases they will match, but since the threadpool is instantiated once,
+  // we want the number of threads for each device to be predictable.
+  int numCpus = android_getCpuCount();
+  LOG(INFO) << "Android cpu count: " << numCpus
+            << ", hardware_concurrency: " << numThreads;
+  numThreads = numCpus;
+#endif
+
+  bool applyCap = false;
+#if CAFFE2_ANDROID
+  applyCap = caffe2::FLAGS_caffe2_threadpool_android_cap;
+#elif CAFFE2_IOS
+  applyCap = caffe2::FLAGS_caffe2_threadpool_ios_cap;
+#else
+#error Undefined architecture
+#endif
+
+  if (applyCap) {
+    // 1 core -> 1 thread
+    // 2 cores -> 2 threads
+    // 4 cores -> 2 threads
+    // 8 cores -> 4 threads
+    // more, continue limiting to half of available cores
+
+    if (numThreads <= 3) {
+      // no change
+    } else if (numThreads <= 5) {
+      // limit to 2
+      numThreads = 2;
+    } else {
+      // Use half the cores
+      numThreads = numThreads / 2;
+    }
+  }
+  LOG(INFO) << "Constructing thread pool with " << numThreads << " threads";
+  return caffe2::make_unique<ThreadPool>(numThreads);
+}
+
 ThreadPool::ThreadPool(int numThreads)
     : fn_(nullptr),
       workItemsPending_(0),
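The capping rule in defaultThreadPool() can be restated as a pure function for sanity-checking the mapping in the comments (1 -> 1, 2 -> 2, 4 -> 2, 8 -> 4); this is a verification sketch, not repo code:

    #include <iostream>

    int cappedThreads(int numCpus) {
      if (numCpus <= 3) return numCpus; // no change
      if (numCpus <= 5) return 2;       // limit to 2
      return numCpus / 2;               // half of the available cores
    }

    int main() {
      for (int cpus : {1, 2, 3, 4, 5, 6, 8, 10}) {
        std::cout << cpus << " cores -> " << cappedThreads(cpus) << " threads\n";
      }
    }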
@@ -27,7 +83,8 @@ ThreadPool::ThreadPool(int numThreads)
       threadsReady_(0),
       minWorkSize_(kDefaultMinWorkSize)
 #ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
-    , imbalanceRatio_(kDefaultImbalanceRatio)
+      ,
+      imbalanceRatio_(kDefaultImbalanceRatio)
 #endif
 {
   std::lock_guard<std::mutex> guard(mutex_);
@@ -35,15 +92,13 @@ ThreadPool::ThreadPool(int numThreads)
   // All worker threads (and the main thread) have a ThreadInfo
   for (auto i = 0; i < numThreads; ++i) {
     threadInfo_.emplace_back(
-      MakeAligned<ThreadInfo>::make(kCacheLineSize, i, numThreads));
+        MakeAligned<ThreadInfo>::make(kCacheLineSize, i, numThreads));
   }
 
   // The first ThreadInfo is for the main thread
   for (auto i = 1; i < numThreads; ++i) {
     auto pInfo = &(threadInfo_[i]);
-    auto fn = [pInfo, this, i]() {
-      (*pInfo)->threadMain(i, this);
-    };
+    auto fn = [pInfo, this, i]() { (*pInfo)->threadMain(i, this); };
 
     threads_.emplace_back(std::thread(std::move(fn)));
   }
@@ -65,26 +120,23 @@ ThreadPool::~ThreadPool() {
   }
 }
 
-int
-ThreadPool::getNumThreads() const {
+int ThreadPool::getNumThreads() const {
   std::lock_guard<std::mutex> guard(executionMutex_);
 
   return threadInfo_.size();
 }
 
 // Sets the minimum work size (range) for which to invoke the
 // threadpool; work sizes smaller than this will just be run on the
 // main (calling) thread
-void
-ThreadPool::setMinWorkSize(size_t size) {
+void ThreadPool::setMinWorkSize(size_t size) {
   std::lock_guard<std::mutex> guard(executionMutex_);
 
   minWorkSize_ = size;
 }
 
 #ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
-void
-ThreadPool::setImbalanceRatio(float ratio) {
+void ThreadPool::setImbalanceRatio(float ratio) {
   std::lock_guard<std::mutex> guard(executionMutex_);
 
   imbalanceRatio_ = ratio;
@@ -92,8 +144,7 @@ ThreadPool::setImbalanceRatio(float ratio) {
 #endif
 
 #ifdef CAFFE2_THREADPOOL_STATS
-std::vector<ThreadStats>
-ThreadPool::getStats(bool reset) {
+std::vector<ThreadStats> ThreadPool::getStats(bool reset) {
   std::lock_guard<std::mutex> guard(executionMutex_);
 
   // Set up thread state
@@ -123,14 +174,13 @@ ThreadPool::getStats(bool reset) {
 }
 #endif
 
-void
-ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
+void ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
   std::lock_guard<std::mutex> guard(executionMutex_);
 
   // If there are no worker threads, or if the range is too small (too
   // little work), just run locally
   bool runLocally = threads_.empty() || range < minWorkSize_ ||
-                    FLAGS_caffe2_threadpool_force_inline;
+      FLAGS_caffe2_threadpool_force_inline;
 
   auto numThreads = threadInfo_.size();
   size_t workUnitsPerThread = 0;
@@ -140,25 +190,24 @@ ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
   if (!runLocally) {
     size_t workUnitsPerThread = (numThreads + range - 1) / numThreads;
 
     // On mobile devices (especially big.LITTLE cores), there is
     // significant lag in getting other threads to participate versus
     // the current thread, which is likely already running on a big
     // core.
     // Based on tests, the main thread will execute (through its own
     // work and stealing others) about 25% more work than other
     // threads.
     // To reduce the work stealing overhead, give the main thread 25%
     // more work to start with.
 #ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
-    firstThreadWork = (size_t) (imbalanceRatio_ * workUnitsPerThread);
+    firstThreadWork = (size_t)(imbalanceRatio_ * workUnitsPerThread);
     if (firstThreadWork >= range) {
       // give all to first thread
       runLocally = true;
     }
 
     size_t remainderWork = range - firstThreadWork;
-    otherThreadWork =
-        ((numThreads - 1) + remainderWork - 1) / (numThreads - 1);
+    otherThreadWork = ((numThreads - 1) + remainderWork - 1) / (numThreads - 1);
 #else
     firstThreadWork = workUnitsPerThread;
     otherThreadWork = workUnitsPerThread;
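Worked example of the split above (values assumed for illustration): with range = 100, numThreads = 4, and an imbalance ratio of 1.25, ceiling division gives 25 units per thread, the main thread takes floor(1.25 * 25) = 31, and the remaining 69 units split as ceil(69 / 3) = 23 per worker:

    #include <cstdio>

    int main() {
      const size_t range = 100, numThreads = 4;
      const double imbalanceRatio = 1.25; // main thread takes ~25% extra
      size_t perThread = (range + numThreads - 1) / numThreads;      // 25
      size_t firstThreadWork = (size_t)(imbalanceRatio * perThread); // 31
      size_t remainder = range - firstThreadWork;                    // 69
      size_t otherThreadWork =
          (remainder + (numThreads - 1) - 1) / (numThreads - 1);     // 23
      std::printf("main: %zu, others: %zu each\n",
                  firstThreadWork, otherThreadWork);
    }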
@@ -251,8 +300,7 @@ ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
   }
 }
 
-void
-ThreadInfo::threadMain(int threadId, ThreadPool* pool) {
+void ThreadInfo::threadMain(int threadId, ThreadPool* pool) {
   long lastProcessedWorkId = 0;
 
   while (true) {
@@ -289,8 +337,7 @@ ThreadInfo::threadMain(int threadId, ThreadPool* pool) {
   }
 }
 
-bool
-ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
+bool ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
   auto lambdaFunctionToRun = pool->fn_;
   int localItemsCompleted = 0;
   int localItemsStolen = 0;
@@ -312,8 +359,7 @@ ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
   }
 
   // Done, now look for other threads' items to steal
-  for (auto i = (threadId_ + 1) % numThreads_;
-       i != threadId_;
+  for (auto i = (threadId_ + 1) % numThreads_; i != threadId_;
        i = (i + 1) % numThreads_) {
     auto& otherThread = pool->threadInfo_[i];
 
@@ -339,7 +385,7 @@ ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
 
   if (localItemsCompleted > 0) {
     auto numRemaining =
-      (pool->workItemsPending_ -= localItemsCompleted); // atomic
+        (pool->workItemsPending_ -= localItemsCompleted); // atomic
     DCHECK_GE(numRemaining, 0);
 
     if (numRemaining == 0) {
@@ -16,14 +16,14 @@
 // Compile-time flag to control usage of main thread work imbalance
 // #define CAFFE2_THREADPOOL_MAIN_IMBALANCE
 
+#include <stdlib.h> // posix_memalign
 #include <atomic>
 #include <condition_variable>
+#include <functional>
 #include <memory>
 #include <mutex>
 #include <thread>
 #include <vector>
-#include <functional>
-#include <stdlib.h> // posix_memalign
 
 //
 // A work-stealing threadpool loosely based off of pthreadpool
@@ -39,15 +39,15 @@ struct AllocAligned {
   template <typename... Args>
   static T* alloc(size_t align, Args&&... args) {
     void* p = nullptr;
-    // FIXME: we should just be able to use std::align
+// FIXME: we should just be able to use std::align
 #if !defined(__ANDROID__)
-    posix_memalign((void**) &p, align, sizeof(T));
+    posix_memalign((void**)&p, align, sizeof(T));
 #else
     p = memalign(align, sizeof(T));
 #endif
 
     if (p) {
-      return new(p) T(std::forward<Args>(args)...);
+      return new (p) T(std::forward<Args>(args)...);
     }
 
     return nullptr;
@@ -57,7 +57,7 @@ struct AllocAligned {
   static void release(T* p) {
     if (p) {
       p->~T();
-      free((void*) p);
+      free((void*)p);
     }
   }
 };
@@ -74,10 +74,11 @@ struct AlignedDeleter {
 template <typename T>
 struct MakeAligned {
   template <typename... Args>
-  static std::unique_ptr<T, AlignedDeleter<T>> make(size_t align,
-                                                    Args&&... args) {
+  static std::unique_ptr<T, AlignedDeleter<T>> make(
+      size_t align,
+      Args&&... args) {
     return std::unique_ptr<T, AlignedDeleter<T>>(
         AllocAligned<T>::alloc(align, std::forward<Args>(args)...));
   }
 };
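Usage sketch for the helper above (the alignment value and Widget type are assumptions for illustration): MakeAligned pairs AllocAligned's placement-new allocation with AlignedDeleter, so the resulting unique_ptr both destroys the object and frees the aligned storage, matching AllocAligned<T>::release:

    struct Widget {
      int id;
      explicit Widget(int i) : id(i) {}
    };

    void example() {
      // 64 is a typical cache line size; kCacheLineSize plays this role in
      // the real code (e.g. for the per-thread ThreadInfo records).
      auto w = MakeAligned<Widget>::make(64, 42);
      // w.get() is 64-byte aligned; AlignedDeleter runs ~Widget() then free().
    }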
@@ -85,9 +86,7 @@ struct ThreadPool;
 
 #ifdef CAFFE2_THREADPOOL_STATS
 struct ThreadStats {
-  inline ThreadStats() :
-      numAssigned(0), numWorkedOn(0), numStolen(0) {
-  }
+  inline ThreadStats() : numAssigned(0), numWorkedOn(0), numStolen(0) {}
 
   inline void reset() {
     numAssigned = 0;
@@ -102,14 +101,13 @@ struct ThreadStats {
 #endif
 
 struct alignas(kCacheLineSize) ThreadInfo {
-  inline ThreadInfo(int threadId, int numThreads) :
-      rangeStart_(0),
-      rangeEnd_(0),
-      rangeLength_(0),
-      wantExit_(false),
-      threadId_(threadId),
-      numThreads_(numThreads) {
-  }
+  inline ThreadInfo(int threadId, int numThreads)
+      : rangeStart_(0),
+        rangeEnd_(0),
+        rangeLength_(0),
+        wantExit_(false),
+        threadId_(threadId),
+        numThreads_(numThreads) {}
 
   // Entry point for all worker threads
   void threadMain(int threadId, ThreadPool* pool);
@@ -154,9 +152,10 @@ struct alignas(kCacheLineSize) ThreadInfo {
 };
 
 class alignas(kCacheLineSize) ThreadPool {
-  public:
+ public:
   // Constructs a work-stealing threadpool with the given number of
   // threads
+  static std::unique_ptr<ThreadPool> defaultThreadPool();
   ThreadPool(int numThreads);
 
   // Shuts down all worker threads (if any) before destroying ourselves
@@ -169,7 +168,9 @@ class alignas(kCacheLineSize) ThreadPool {
   // threadpool; work sizes smaller than this will just be run on the
   // main (calling) thread
   void setMinWorkSize(size_t size);
-  size_t getMinWorkSize() const { return minWorkSize_; }
+  size_t getMinWorkSize() const {
+    return minWorkSize_;
+  }
 
 #ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
   // Set imbalance factor for the main thread versus other threads;
@@ -186,7 +187,7 @@ class alignas(kCacheLineSize) ThreadPool {
   std::vector<ThreadStats> getStats(bool reset = false);
 #endif
 
-  protected:
+ protected:
   friend struct ThreadInfo;
 
   // What we are currently working on
@@ -220,8 +221,8 @@ class alignas(kCacheLineSize) ThreadPool {
   size_t threadsReady_;
 
   // The first entry is always for the main thread
-  std::vector<
-    std::unique_ptr<ThreadInfo, AlignedDeleter<ThreadInfo>>> threadInfo_;
+  std::vector<std::unique_ptr<ThreadInfo, AlignedDeleter<ThreadInfo>>>
+      threadInfo_;
 
   // Set of threads that we are managing
   std::vector<std::thread> threads_;
@@ -243,4 +244,4 @@ class alignas(kCacheLineSize) ThreadPool {
 
 #endif // CAFFE2_THREADPOOL_MOBILE
 
-#endif  // CAFFE2_UTILS_THREADPOOL_H_
+#endif // CAFFE2_UTILS_THREADPOOL_H_