Sync mobile codebase changes back to fbcode

Summary: Rather chunky sync of changes made exclusively to mobile codebases back to fbcode.

Reviewed By: ajtulloch

Differential Revision: D5314405

fbshipit-source-id: c4d0a7244468f953eb63288306bc9bc78eb9e1be
Jon Morton
2017-07-18 17:40:33 -07:00
committed by Facebook Github Bot
parent 4a81b0f24a
commit 9b9df3fbeb
9 changed files with 533 additions and 261 deletions

View File

@ -53,7 +53,7 @@ struct DefaultCPUAllocator final : CPUAllocator {
#else
CAFFE_ENFORCE_EQ(posix_memalign(&data, gCaffe2Alignment, nbytes), 0);
#endif
CHECK(data) << "Failed to allocate " << nbytes << " bytes.";
CAFFE_ENFORCE(data);
memset(data, 0, nbytes);
return {data, Delete};
}
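
The switch from CHECK to CAFFE_ENFORCE turns an allocation failure into a recoverable C++ exception instead of a process abort, which matters on mobile. A minimal sketch of guarding such a path, assuming the caffe2::EnforceNotMet exception type thrown by CAFFE_ENFORCE (the helpers below are hypothetical):

#include <cstddef>
#include <cstdlib>
#include "caffe2/core/logging.h"

// Hypothetical helper: CAFFE_ENFORCE throws caffe2::EnforceNotMet on failure
// instead of aborting the process the way CHECK does.
void* allocateOrThrow(size_t nbytes) {
  void* data = malloc(nbytes);
  CAFFE_ENFORCE(data, "Failed to allocate ", nbytes, " bytes.");
  return data;
}

bool tryAllocate(size_t nbytes) {
  try {
    free(allocateOrThrow(nbytes));
    return true;
  } catch (const caffe2::EnforceNotMet&) {
    // Recoverable on mobile: log and degrade instead of crashing the app.
    return false;
  }
}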

View File

@ -40,17 +40,22 @@ Predictor::Predictor(
CAFFE_ENFORCE(ws_.CreateNet(run_net));
}
void Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
Predictor::~Predictor() {}
bool Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
for (auto i = 0; i < inputs.size(); ++i) {
shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
}
CAFFE_ENFORCE(ws_.RunNet(run_net_.name()));
if (!ws_.RunNet(run_net_.name())) {
return false;
}
outputs->resize(run_net_.external_output_size());
for (auto i = 0; i < outputs->size(); ++i) {
(*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
}
return true;
}
}

View File

@ -14,6 +14,7 @@ class Predictor {
const NetDef& init_net,
const NetDef& run_net,
Workspace* parent = nullptr);
~Predictor();
// Executes `run_net` on the inputs.
// The first `inputs.size()` inputs from run_net::external_inputs
@ -24,7 +25,9 @@ class Predictor {
// Postcondition:
// outputs->size() == run_net.external_inputs.size()
void run(const TensorVector& inputs, TensorVector* outputs);
// Returns true on success
bool run(const TensorVector& inputs, TensorVector* outputs);
const NetDef& def() const {
return run_net_;

View File

@ -16,17 +16,6 @@ CAFFE2_DEFINE_bool(
false,
"If true, workspace destructor will print all blob shapes");
#if CAFFE2_MOBILE
// Threadpool restrictions
// Whether or not threadpool caps apply to Android
CAFFE2_DEFINE_int(caffe2_threadpool_android_cap, true, "");
// Whether or not threadpool caps apply to iOS
CAFFE2_DEFINE_int(caffe2_threadpool_ios_cap, false, "");
#endif // CAFFE2_MOBILE
namespace caffe2 {
void Workspace::PrintBlobSizes() {
@ -143,8 +132,7 @@ const Blob* Workspace::GetBlob(const string& name) const {
}
Blob* Workspace::GetBlob(const string& name) {
return const_cast<Blob*>(
static_cast<const Workspace*>(this)->GetBlob(name));
return const_cast<Blob*>(static_cast<const Workspace*>(this)->GetBlob(name));
}
NetBase* Workspace::CreateNet(const NetDef& net_def, bool overwrite) {
@ -232,51 +220,18 @@ bool Workspace::RunNetOnce(const NetDef& net_def) {
return true;
}
bool Workspace::RunPlan(const PlanDef& plan,
ShouldContinue shouldContinue) {
bool Workspace::RunPlan(const PlanDef& plan, ShouldContinue shouldContinue) {
return RunPlanOnWorkspace(this, plan, shouldContinue);
}
#if CAFFE2_MOBILE
ThreadPool* Workspace::GetThreadPool() {
std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
if (!thread_pool_) {
int numThreads = std::thread::hardware_concurrency();
bool applyCap = false;
#if CAFFE2_ANDROID
applyCap = caffe2::FLAGS_caffe2_threadpool_android_cap;
#elif CAFFE2_IOS
applyCap = caffe2::FLAGS_caffe2_threadpool_ios_cap;
#else
#error Undefined architecture
#endif
if (applyCap) {
// 1 core -> 1 thread
// 2 cores -> 2 threads
// 4 cores -> 3 threads
// 8 cores -> 4 threads
// more, continue limiting to half of available cores
if (numThreads <= 3) {
// no change
} else if (numThreads <= 5) {
// limit to 3
numThreads = 3;
} else {
// Use half the cores
numThreads = numThreads / 2;
}
}
LOG(INFO) << "Constructing thread pool with " << numThreads << " threads";
thread_pool_.reset(new ThreadPool(numThreads));
thread_pool_ = ThreadPool::defaultThreadPool();
}
return thread_pool_.get();
}
#endif // CAFFE2_MOBILE
} // namespace caffe2
} // namespace caffe2
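
The workspace now defers to ThreadPool::defaultThreadPool() for sizing, so callers simply ask for the pool and submit ranged work. A minimal sketch for a CAFFE2_MOBILE build, using the run(fn, range) signature shown later in this diff; scaleBuffer is a hypothetical caller:

#include <cstddef>
#include "caffe2/core/workspace.h"
#include "caffe2/utils/threadpool/ThreadPool.h"

// Hypothetical caller: doubles n floats using the workspace's shared pool.
void scaleBuffer(caffe2::Workspace& ws, float* data, size_t n) {
  caffe2::ThreadPool* pool = ws.GetThreadPool(); // lazily constructed, capped
  pool->run(
      [data](int /*threadId*/, size_t i) { data[i] *= 2.0f; },
      n); // ranges below the minimum work size run inline on the caller
}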

View File

@ -548,14 +548,20 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
// need buffers for the worker threads
size_t colBlockSize = W + kernel_w_ / stride_w_;
size_t threadYBufferSize = C * outputH * colBlockSize * stride_w_;
// Require 16 byte alignment, so 4-element alignment as these are floats.
size_t threadYBufferSizeAligned =
((C * outputH * colBlockSize * stride_w_ + 3) / 4) * 4;
size_t threadColBufferSize = C * kernel_h_ * kernel_w_ * W;
// Work around GCC 4.9 bug when this is declared inside the inner lambda.
auto runLocalTile = [&](TensorCPU* threadBuffer, int threadId, size_t tileId) {
auto localYData = threadBuffer->template mutable_data<T>() + threadId * threadYBufferSize;
auto runLocalTile = [&](TensorCPU* threadBuffer,
int threadId,
size_t tileId) {
auto localYData = threadBuffer->template mutable_data<T>() +
threadId * threadYBufferSizeAligned;
auto localColBufferData = threadBuffer->template mutable_data<T>() +
numThreads * threadYBufferSize + threadId * threadColBufferSize;
numThreads * threadYBufferSizeAligned + threadId * threadColBufferSize;
runTileContiguous<T, Context>(tileId,
N,
@ -578,11 +584,14 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
};
auto f = [&](Tensor<Context>* threadBuffer) {
threadBuffer->Resize(numThreads * threadYBufferSize + numThreads * threadColBufferSize);
threadBuffer->Resize(
numThreads * threadYBufferSizeAligned +
numThreads * threadColBufferSize);
// Group together thread buffers for accumulation
std::vector<T*> toSum(numThreads - 1);
for (int i = 1; i < numThreads; ++i) {
toSum[i - 1] = threadBuffer->template mutable_data<T>() + i * threadYBufferSize;
toSum[i - 1] = threadBuffer->template mutable_data<T>() +
i * threadYBufferSizeAligned;
}
for (auto image_id = 0; image_id < N; ++image_id) {
@ -591,7 +600,10 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
// The column buffers are overwritten by the matrix multiplication
// each time, so we need not clear them out each round
math::Set<T, Context>(
numThreads * threadYBufferSize, 0, threadBuffer->template mutable_data<T>(), &context_);
numThreads * threadYBufferSizeAligned,
0,
threadBuffer->template mutable_data<T>(),
&context_);
// Run tiled gemm and col2im in our threadpool; all of these tiles
// are guaranteed to be full tiles

View File

@ -9,82 +9,25 @@ using std::min;
namespace {
template <typename T>
class AveragePool {
public:
static float initialize() {
return 0.0;
}
static void process(
const int x_col,
const int y_col,
ConstEigenMatrixMap<float>& x_mat,
EigenMatrixMap<float>& y_mat) {
y_mat.col(y_col) += x_mat.col(x_col);
}
static void process(const T& x_data, T& y_data) {
y_data += x_data;
}
static void finalize(const int size, T& y_data) {
y_data /= size;
}
static void
finalize(const int size, const int col, EigenMatrixMap<float>& y_mat) {
y_mat.col(col) /= size;
}
static bool runNeon() {
return true;
}
};
template <typename T>
class MaxPool {
public:
static float initialize() {
return std::numeric_limits<float>::lowest();
}
static void process(
const int x_col,
const int y_col,
ConstEigenMatrixMap<float>& x_mat,
EigenMatrixMap<float>& y_mat) {
y_mat.col(y_col) = y_mat.col(y_col).cwiseMax(x_mat.col(x_col));
}
static void process(const T& x_data, T& y_data) {
if (x_data > y_data) {
y_data = x_data;
}
}
static void finalize(const int /*size*/, T& /*y_data*/) {}
static void finalize(
const int /*size*/,
const int col,
EigenMatrixMap<float>& /*y_mat*/) {}
static bool runNeon() {
return false;
}
};
#ifdef __ARM_NEON__
bool isNeonEligible(int inputH, int inputW,
int outputH, int outputW,
int kH, int kW,
int strideH, int strideW,
int padT, int padL, int padB, int padR,
int dilationH, int dilationW,
const float* input,
float* output) {
bool isNeon4x4p0s0Eligible(
int inputH,
int inputW,
int outputH,
int outputW,
int kH,
int kW,
int strideH,
int strideW,
int padT,
int padL,
int padB,
int padR,
int dilationH,
int dilationW,
const float* input,
float* output) {
// Use this kernel only if:
// Kernel width is 4x4
// Kernel stride is 4x4
@ -103,20 +46,21 @@ bool isNeonEligible(int inputH, int inputW,
bool outputOk = ((inputH % outputH) == 0) && ((inputW % outputW) == 0);
bool inputOk = (inputW % 4 == 0) && (inputH % 4 == 0);
bool alignOk = isPointerAligned(input, sizeof(float32x4_t)) &&
isPointerAligned(output, sizeof(float32x4_t));
isPointerAligned(output, sizeof(float32x4_t));
return kernelOk && strideOk && padOk && dilationOk &&
outputOk && inputOk && alignOk;
return kernelOk && strideOk && padOk && dilationOk && outputOk && inputOk &&
alignOk;
}
// Vectorizes 4x4p0s0 average pooling for ARM NEON
void avgPoolNeon4x4p0s0Plane(int inputH, int inputW,
const float* input,
float* output) {
void avgPoolNeon4x4p0s0Plane(
int inputH,
int inputW,
const float* input,
float* output) {
constexpr int kKernelHeight = 4;
constexpr int kKernelWidth = 4;
constexpr float kDiv =
(1.0f / ((float) kKernelHeight * (float) kKernelWidth));
constexpr float kDiv = (1.0f / ((float)kKernelHeight * (float)kKernelWidth));
// Handle portion that can be unrolled by 4
constexpr int kUnroll = 4;
@ -202,10 +146,13 @@ void avgPoolNeon4x4p0s0Plane(int inputH, int inputW,
}
}
void
runNeonAveragePool4x4p0s0NCHW(int N, int C, int inputH, int inputW,
const float* input,
float* output) {
void runNeonAveragePool4x4p0s0NCHW(
int N,
int C,
int inputH,
int inputW,
const float* input,
float* output) {
// We only have the 4x4p0s0 implementation at present, which is
// checked at a higher level
int outputH = inputH / 4;
@ -220,9 +167,291 @@ runNeonAveragePool4x4p0s0NCHW(int N, int C, int inputH, int inputW,
}
}
}
bool isNeon2x2p0s0Eligible(
int inputH,
int inputW,
int outputH,
int outputW,
int kH,
int kW,
int strideH,
int strideW,
int padT,
int padL,
int padB,
int padR,
int dilationH,
int dilationW,
const float* input,
float* output) {
// Use this kernel only if:
// Kernel width is 2x2
// Kernel stride is 2x2
// Padding is 0
// Dilation is 1
// Output width and height are even divisors of input width
// Input width and height are divisible by 4 (should be implied by
// all of the above, but just check again)
// Input and output pointers are aligned by float32x4_t
bool kernelOk = (kH == 2) && (kW == 2);
bool strideOk = (strideH == 2) && (strideW == 2);
bool padOk = (padT == 0) && (padL == 0) && (padB == 0) && (padR == 0);
bool dilationOk = (dilationH == 1) && (dilationW == 1);
bool outputOk = ((inputH % outputH) == 0) && ((inputW % outputW) == 0);
bool inputOk = (inputW % 4 == 0) && (inputH % 4 == 0);
bool alignOk = isPointerAligned(input, sizeof(float32x4_t)) &&
isPointerAligned(output, sizeof(float32x4_t));
return kernelOk && strideOk && padOk && dilationOk && outputOk && inputOk &&
alignOk;
}
// Vectorizes 2x2p0s0 max pooling for ARM NEON
void maxPoolNeon2x2p0s0Plane(
int inputH,
int inputW,
const float* input,
float* output) {
constexpr int kKernelHeight = 2;
constexpr int kKernelWidth = 2;
// Handle portion that can be unrolled by 4
constexpr int kUnroll = 4;
constexpr int kLoadSizeFloat = (sizeof(float32x4_t) / sizeof(float));
constexpr int kLoadCols = kUnroll * kLoadSizeFloat;
if (inputW % kLoadCols == 0) {
for (int h = 0; h < inputH; h += kKernelHeight) {
float* outputRow = output + (h / kKernelHeight) * (inputW / kKernelWidth);
const float* curInput = input + h * inputW;
for (int w = 0; w < inputW; w += kLoadCols) {
float32x2_t hmax_0, hmax_1, hmax_2, hmax_3;
{
float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
hmax_0 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
}
curInput += kLoadSizeFloat;
{
float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
hmax_1 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
}
curInput += kLoadSizeFloat;
{
float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
hmax_2 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
}
curInput += kLoadSizeFloat;
{
float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
hmax_3 = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
}
curInput += kLoadSizeFloat;
float32x4_t out_0 = vcombine_f32(hmax_0, hmax_1);
float32x4_t out_1 = vcombine_f32(hmax_2, hmax_3);
vst1q_f32_aligned(&outputRow[w / kKernelWidth + 0], out_0);
vst1q_f32_aligned(&outputRow[w / kKernelWidth + 4], out_1);
}
}
} else {
// Not unrolled
for (int h = 0; h < inputH; h += kKernelHeight) {
const float* inputRow = input + h * inputW;
float* outputRow = output + (h / kKernelHeight) * (inputW / kKernelWidth);
for (int w = 0; w < inputW; w += kKernelWidth * 2) {
const float* curInput = inputRow + w;
float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW);
float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW);
float32x4_t vmax = vmaxq_f32(v0_0, v0_1);
float32x2_t hmax = vpmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
vst1_f32(&outputRow[w / kKernelWidth], hmax);
}
}
}
}
void runNeonMaxPool2x2p0s0NCHW(
int N,
int C,
int inputH,
int inputW,
const float* input,
float* output) {
// We only have the 2x2p0s0 implementation at present, which is
// checked at a higher level
int outputH = inputH / 2;
int outputW = inputW / 2;
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const float* curInput = input + (n * C + c) * inputH * inputW;
float* curOutput = output + (n * C + c) * outputH * outputW;
maxPoolNeon2x2p0s0Plane(inputH, inputW, curInput, curOutput);
}
}
}
#endif // __ARM_NEON__
} // namespace
} // namespace
template <typename T>
class AveragePool {
public:
static float initialize() {
return 0.0;
}
static void process(
const int x_col,
const int y_col,
ConstEigenMatrixMap<float>& x_mat,
EigenMatrixMap<float>& y_mat) {
y_mat.col(y_col) += x_mat.col(x_col);
}
static void process(const T& x_data, T& y_data) {
y_data += x_data;
}
static void finalize(const int size, T& y_data) {
y_data /= size;
}
static void
finalize(const int size, const int col, EigenMatrixMap<float>& y_mat) {
y_mat.col(col) /= size;
}
static bool runSpecialized(
int N,
int C,
int inputH,
int inputW,
int outputH,
int outputW,
int kH,
int kW,
int strideH,
int strideW,
int padT,
int padL,
int padB,
int padR,
int dilationH,
int dilationW,
const float* input,
float* output) {
#ifdef __ARM_NEON__
if (isNeon4x4p0s0Eligible(
inputH,
inputW,
outputH,
outputW,
kH,
kW,
strideH,
strideW,
padT,
padL,
padB,
padR,
dilationH,
dilationW,
input,
output)) {
runNeonAveragePool4x4p0s0NCHW(N, C, inputH, inputW, input, output);
return true;
}
#endif
return false;
}
};
template <typename T>
class MaxPool {
public:
static float initialize() {
return std::numeric_limits<float>::lowest();
}
static void process(
const int x_col,
const int y_col,
ConstEigenMatrixMap<float>& x_mat,
EigenMatrixMap<float>& y_mat) {
y_mat.col(y_col) = y_mat.col(y_col).cwiseMax(x_mat.col(x_col));
}
static void process(const T& x_data, T& y_data) {
if (x_data > y_data) {
y_data = x_data;
}
}
static void finalize(const int /*size*/, T& /*y_data*/) {}
static void finalize(
const int /*size*/,
const int col,
EigenMatrixMap<float>& /*y_mat*/) {}
static bool runSpecialized(
int N,
int C,
int inputH,
int inputW,
int outputH,
int outputW,
int kH,
int kW,
int strideH,
int strideW,
int padT,
int padL,
int padB,
int padR,
int dilationH,
int dilationW,
const float* input,
float* output) {
#ifdef __ARM_NEON__
if (isNeon2x2p0s0Eligible(
inputH,
inputW,
outputH,
outputW,
kH,
kW,
strideH,
strideW,
padT,
padL,
padB,
padR,
dilationH,
dilationW,
input,
output)) {
runNeonMaxPool2x2p0s0NCHW(N, C, inputH, inputW, input, output);
return true;
}
#endif
return false;
}
};
template <typename T, class Context, typename PoolType>
bool PoolOp<T, Context, PoolType>::RunOnDeviceWithOrderNCHW() {
@ -243,30 +472,30 @@ bool PoolOp<T, Context, PoolType>::RunOnDeviceWithOrderNCHW() {
int pooled_width = kernel_.size() > 1 ? Y->dim32(3) : 1;
int pooled_depth = kernel_.size() > 2 ? Y->dim32(4) : 1;
#ifdef __ARM_NEON__
// We specialize certain variants on ARM for vectorization
if (PoolType::runNeon() && isNeonEligible(
X.dim32(2),
X.dim32(3),
Y->dim32(2),
Y->dim32(3),
kernel_h(),
kernel_w(),
stride_h(),
stride_w(),
pad_t(),
pad_l(),
pad_b(),
pad_r(),
dilation_h(),
dilation_w(),
Xdata,
Ydata)) {
runNeonAveragePool4x4p0s0NCHW(
X.dim32(0), X.dim32(1), X.dim32(2), X.dim32(3), Xdata, Ydata);
if (kernel_.size() == 2 &&
PoolType::runSpecialized(
X.dim32(0),
X.dim32(1),
X.dim32(2),
X.dim32(3),
Y->dim32(2),
Y->dim32(3),
kernel_h(),
kernel_w(),
stride_h(),
stride_w(),
pad_t(),
pad_l(),
pad_b(),
pad_r(),
dilation_h(),
dilation_w(),
Xdata,
Ydata)) {
return true;
}
#endif // __ARM_NEON__
switch (kernel_.size()) {
case 1:
for (int n = 0; n < X.dim32(0); ++n) {
@ -450,8 +679,8 @@ bool PoolOp<T, Context, PoolType>::RunOnDeviceWithOrderNHWC() {
}
}
}
}
break;
}
break;
default:
CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
return false;
@ -465,48 +694,60 @@ REGISTER_CPU_OPERATOR(
PoolOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool)
.NumInputs(1)
.NumOutputs(1)
.TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
.SetDoc(R"DOC(
.NumInputs(1)
.NumOutputs(1)
.TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
.SetDoc(R"DOC(
AveragePool consumes an input blob X and applies average pooling across
the blob according to kernel sizes, stride sizes, and pad lengths defined by the
ConvPoolOpBase operator. Average pooling consists of averaging all values of a
subset of the input tensor according to the kernel size and downsampling the
data into the output blob Y for further processing.
)DOC")
.Input(0, "X", "Input data tensor from the previous operator; dimensions "
"depend on whether the NCHW or NHWC operators are being used. For example, "
"in the former, the input has size (N x C x H x W), where N is the batch "
"size, C is the number of channels, and H and W are the height and the width "
"of the data. The corresponding permutation of dimensions is used in the "
"latter case. ")
.Output(0, "Y", "Output data tensor from average pooling across the input "
"tensor. Dimensions will vary based on various kernel, stride, and pad "
"sizes.");
.Input(
0,
"X",
"Input data tensor from the previous operator; dimensions depend on "
"whether the NCHW or NHWC operators are being used. For example, in "
"the former, the input has size (N x C x H x W), where N is the batch "
"size, C is the number of channels, and H and W are the height and the "
"width of the data. The corresponding permutation of dimensions is "
"used in the latter case.")
.Output(
0,
"Y",
"Output data tensor from average pooling across the input "
"tensor. Dimensions will vary based on various kernel, stride, and pad "
"sizes.");
REGISTER_CPU_OPERATOR(MaxPool, PoolOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool)
.NumInputs(1)
.NumOutputs(1)
.TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
.SetDoc(R"DOC(
.NumInputs(1)
.NumOutputs(1)
.TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
.SetDoc(R"DOC(
MaxPool consumes an input blob X and applies max pooling across
the blob according to kernel sizes, stride sizes, and pad lengths defined by the
ConvPoolOpBase operator. Max pooling consists of taking the maximum value of a
subset of the input tensor according to the kernel size and downsampling the
data into the output blob Y for further processing.
)DOC")
.Input(0, "X", "Input data tensor from the previous operator; dimensions "
"depend on whether the NCHW or NHWC operators are being used. For example, "
"in the former, the input has size (N x C x H x W), where N is the batch "
"size, C is the number of channels, and H and W are the height and the width "
"of the data. The corresponding permutation of dimensions is used in the "
"latter case. ")
.Output(0, "Y", "Output data tensor from max pooling across the input "
"tensor. Dimensions will vary based on various kernel, stride, and pad "
"sizes.");
.Input(
0,
"X",
"Input data tensor from the previous operator; dimensions depend on "
"whether the NCHW or NHWC operators are being used. For example, in "
"the former, the input has size (N x C x H x W), where N is the batch "
"size, C is the number of channels, and H and W are the height and the "
"width of the data. The corresponding permutation of dimensions is "
"used in the latter case.")
.Output(
0,
"Y",
"Output data tensor from max pooling across the input "
"tensor. Dimensions will vary based on various kernel, stride, and pad "
"sizes.");
} // namespace
} // namespace caffe2

View File

@ -22,6 +22,15 @@
#include "caffe2/core/init.h"
#if CAFFE2_ANDROID
#ifndef SYS_gettid
#define SYS_gettid __NR_gettid
#endif
#ifndef SYS_tgkill
#define SYS_tgkill __NR_tgkill
#endif
#endif
namespace {
struct sigaction previousSighup;

View File

@ -1,8 +1,20 @@
#include "caffe2/utils/threadpool/ThreadPool.h"
#include "caffe2/core/logging.h"
CAFFE2_DEFINE_bool(caffe2_threadpool_force_inline, false,
"Force to always run jobs on the calling thread");
#if CAFFE2_ANDROID
#include <cpu-features.h>
#endif
CAFFE2_DEFINE_bool(
caffe2_threadpool_force_inline,
false,
"Force to always run jobs on the calling thread");
// Whether or not threadpool caps apply to Android
CAFFE2_DEFINE_int(caffe2_threadpool_android_cap, true, "");
// Whether or not threadpool caps apply to iOS
CAFFE2_DEFINE_int(caffe2_threadpool_ios_cap, false, "");
#if CAFFE2_THREADPOOL_MOBILE
@ -20,6 +32,50 @@ constexpr size_t kDefaultMinWorkSize = 80;
constexpr float kDefaultImbalanceRatio = 1.0f;
#endif
std::unique_ptr<ThreadPool> ThreadPool::defaultThreadPool() {
int numThreads = std::thread::hardware_concurrency();
#ifdef CAFFE2_ANDROID
// std::thread::hardware_concurrency returns online cores
// (sysconf(_SC_NPROCESSORS_ONLN)), but we want the total number of CPUs. In
// most cases they will match, but since the threadpool is instantiated once,
// we want the number of threads for each device to be predictable.
int numCpus = android_getCpuCount();
LOG(INFO) << "Android cpu count: " << numCpus
<< ", hardware_concurrency: " << numThreads;
numThreads = numCpus;
#endif
bool applyCap = false;
#if CAFFE2_ANDROID
applyCap = caffe2::FLAGS_caffe2_threadpool_android_cap;
#elif CAFFE2_IOS
applyCap = caffe2::FLAGS_caffe2_threadpool_ios_cap;
#else
#error Undefined architecture
#endif
if (applyCap) {
// 1 core -> 1 thread
// 2 cores -> 2 threads
// 4 cores -> 2 threads
// 8 cores -> 4 threads
// more, continue limiting to half of available cores
if (numThreads <= 3) {
// no change
} else if (numThreads <= 5) {
// limit to 2
numThreads = 2;
} else {
// Use half the cores
numThreads = numThreads / 2;
}
}
LOG(INFO) << "Constructing thread pool with " << numThreads << " threads";
return caffe2::make_unique<ThreadPool>(numThreads);
}
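
The cap leaves small devices alone and limits larger ones to roughly half of their cores, which avoids oversubscribing big.LITTLE parts. A standalone sketch of the same mapping (helper name is illustrative):

#include <cstdio>

// Mirrors the capping rule in defaultThreadPool():
// <= 3 cores: unchanged; 4-5 cores: 2 threads; 6+ cores: half the cores.
static int capThreads(int numCpus) {
  if (numCpus <= 3) {
    return numCpus;   // 1 -> 1, 2 -> 2, 3 -> 3
  } else if (numCpus <= 5) {
    return 2;         // 4 -> 2, 5 -> 2
  }
  return numCpus / 2; // 8 -> 4, and so on
}

int main() {
  for (int c : {1, 2, 4, 8}) {
    printf("%d cores -> %d threads\n", c, capThreads(c));
  }
  return 0;
}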
ThreadPool::ThreadPool(int numThreads)
: fn_(nullptr),
workItemsPending_(0),
@ -27,7 +83,8 @@ ThreadPool::ThreadPool(int numThreads)
threadsReady_(0),
minWorkSize_(kDefaultMinWorkSize)
#ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
, imbalanceRatio_(kDefaultImbalanceRatio)
,
imbalanceRatio_(kDefaultImbalanceRatio)
#endif
{
std::lock_guard<std::mutex> guard(mutex_);
@ -35,15 +92,13 @@ ThreadPool::ThreadPool(int numThreads)
// All worker threads (and the main thread) have a ThreadInfo
for (auto i = 0; i < numThreads; ++i) {
threadInfo_.emplace_back(
MakeAligned<ThreadInfo>::make(kCacheLineSize, i, numThreads));
MakeAligned<ThreadInfo>::make(kCacheLineSize, i, numThreads));
}
// The first ThreadInfo is for the main thread
for (auto i = 1; i < numThreads; ++i) {
auto pInfo = &(threadInfo_[i]);
auto fn = [pInfo, this, i]() {
(*pInfo)->threadMain(i, this);
};
auto fn = [pInfo, this, i]() { (*pInfo)->threadMain(i, this); };
threads_.emplace_back(std::thread(std::move(fn)));
}
@ -65,26 +120,23 @@ ThreadPool::~ThreadPool() {
}
}
int
ThreadPool::getNumThreads() const {
int ThreadPool::getNumThreads() const {
std::lock_guard<std::mutex> guard(executionMutex_);
return threadInfo_.size();
}
// Sets the minimum work size (range) for which to invoke the
// threadpool; work sizes smaller than this will just be run on the
// main (calling) thread
void
ThreadPool::setMinWorkSize(size_t size) {
// Sets the minimum work size (range) for which to invoke the
// threadpool; work sizes smaller than this will just be run on the
// main (calling) thread
void ThreadPool::setMinWorkSize(size_t size) {
std::lock_guard<std::mutex> guard(executionMutex_);
minWorkSize_ = size;
}
#ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
void
ThreadPool::setImbalanceRatio(float ratio) {
void ThreadPool::setImbalanceRatio(float ratio) {
std::lock_guard<std::mutex> guard(executionMutex_);
imbalanceRatio_ = ratio;
@ -92,8 +144,7 @@ ThreadPool::setImbalanceRatio(float ratio) {
#endif
#ifdef CAFFE2_THREADPOOL_STATS
std::vector<ThreadStats>
ThreadPool::getStats(bool reset) {
std::vector<ThreadStats> ThreadPool::getStats(bool reset) {
std::lock_guard<std::mutex> guard(executionMutex_);
// Set up thread state
@ -123,14 +174,13 @@ ThreadPool::getStats(bool reset) {
}
#endif
void
ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
void ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
std::lock_guard<std::mutex> guard(executionMutex_);
// If there are no worker threads, or if the range is too small (too
// little work), just run locally
bool runLocally = threads_.empty() || range < minWorkSize_ ||
FLAGS_caffe2_threadpool_force_inline;
FLAGS_caffe2_threadpool_force_inline;
auto numThreads = threadInfo_.size();
size_t workUnitsPerThread = 0;
@ -140,25 +190,24 @@ ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
if (!runLocally) {
size_t workUnitsPerThread = (numThreads + range - 1) / numThreads;
// On mobile devices (especially big.LITTLE cores), there is
// significant lag in getting other threads to participate versus
// the current thread, which is likely already running on a big
// core.
// Based on tests, the main thread will execute (through its own
// work and stealing others) about 25% more work than other
// threads.
// To reduce the work stealing overhead, give the main thread 25%
// more work to start with.
// On mobile devices (especially big.LITTLE cores), there is
// significant lag in getting other threads to participate versus
// the current thread, which is likely already running on a big
// core.
// Based on tests, the main thread will execute (through its own
// work and stealing others) about 25% more work than other
// threads.
// To reduce the work stealing overhead, give the main thread 25%
// more work to start with.
#ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
firstThreadWork = (size_t) (imbalanceRatio_ * workUnitsPerThread);
firstThreadWork = (size_t)(imbalanceRatio_ * workUnitsPerThread);
if (firstThreadWork >= range) {
// give all to first thread
runLocally = true;
}
size_t remainderWork = range - firstThreadWork;
otherThreadWork =
((numThreads - 1) + remainderWork - 1) / (numThreads - 1);
otherThreadWork = ((numThreads - 1) + remainderWork - 1) / (numThreads - 1);
#else
firstThreadWork = workUnitsPerThread;
otherThreadWork = workUnitsPerThread;
@ -251,8 +300,7 @@ ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) {
}
}
void
ThreadInfo::threadMain(int threadId, ThreadPool* pool) {
void ThreadInfo::threadMain(int threadId, ThreadPool* pool) {
long lastProcessedWorkId = 0;
while (true) {
@ -289,8 +337,7 @@ ThreadInfo::threadMain(int threadId, ThreadPool* pool) {
}
}
bool
ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
bool ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
auto lambdaFunctionToRun = pool->fn_;
int localItemsCompleted = 0;
int localItemsStolen = 0;
@ -312,8 +359,7 @@ ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
}
// Done, now look for other threads' items to steal
for (auto i = (threadId_ + 1) % numThreads_;
i != threadId_;
for (auto i = (threadId_ + 1) % numThreads_; i != threadId_;
i = (i + 1) % numThreads_) {
auto& otherThread = pool->threadInfo_[i];
@ -339,7 +385,7 @@ ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) {
if (localItemsCompleted > 0) {
auto numRemaining =
(pool->workItemsPending_ -= localItemsCompleted); // atomic
(pool->workItemsPending_ -= localItemsCompleted); // atomic
DCHECK_GE(numRemaining, 0);
if (numRemaining == 0) {

View File

@ -16,14 +16,14 @@
// Compile-time flag to control usage of main thread work imbalance
// #define CAFFE2_THREADPOOL_MAIN_IMBALANCE
#include <stdlib.h> // posix_memalign
#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include <functional>
#include <stdlib.h> // posix_memalign
//
// A work-stealing threadpool loosely based off of pthreadpool
@ -39,15 +39,15 @@ struct AllocAligned {
template <typename... Args>
static T* alloc(size_t align, Args&&... args) {
void* p = nullptr;
// FIXME: we should just be able to use std::align
// FIXME: we should just be able to use std::align
#if !defined(__ANDROID__)
posix_memalign((void**) &p, align, sizeof(T));
posix_memalign((void**)&p, align, sizeof(T));
#else
p = memalign(align, sizeof(T));
#endif
if (p) {
return new(p) T(std::forward<Args>(args)...);
return new (p) T(std::forward<Args>(args)...);
}
return nullptr;
@ -57,7 +57,7 @@ struct AllocAligned {
static void release(T* p) {
if (p) {
p->~T();
free((void*) p);
free((void*)p);
}
}
};
@ -74,10 +74,11 @@ struct AlignedDeleter {
template <typename T>
struct MakeAligned {
template <typename... Args>
static std::unique_ptr<T, AlignedDeleter<T>> make(size_t align,
Args&&... args) {
static std::unique_ptr<T, AlignedDeleter<T>> make(
size_t align,
Args&&... args) {
return std::unique_ptr<T, AlignedDeleter<T>>(
AllocAligned<T>::alloc(align, std::forward<Args>(args)...));
AllocAligned<T>::alloc(align, std::forward<Args>(args)...));
}
};
@ -85,9 +86,7 @@ struct ThreadPool;
#ifdef CAFFE2_THREADPOOL_STATS
struct ThreadStats {
inline ThreadStats() :
numAssigned(0), numWorkedOn(0), numStolen(0) {
}
inline ThreadStats() : numAssigned(0), numWorkedOn(0), numStolen(0) {}
inline void reset() {
numAssigned = 0;
@ -102,14 +101,13 @@ struct ThreadStats {
#endif
struct alignas(kCacheLineSize) ThreadInfo {
inline ThreadInfo(int threadId, int numThreads) :
rangeStart_(0),
rangeEnd_(0),
rangeLength_(0),
wantExit_(false),
threadId_(threadId),
numThreads_(numThreads) {
}
inline ThreadInfo(int threadId, int numThreads)
: rangeStart_(0),
rangeEnd_(0),
rangeLength_(0),
wantExit_(false),
threadId_(threadId),
numThreads_(numThreads) {}
// Entry point for all worker threads
void threadMain(int threadId, ThreadPool* pool);
@ -154,9 +152,10 @@ struct alignas(kCacheLineSize) ThreadInfo {
};
class alignas(kCacheLineSize) ThreadPool {
public:
public:
// Constructs a work-stealing threadpool with the given number of
// threads
static std::unique_ptr<ThreadPool> defaultThreadPool();
ThreadPool(int numThreads);
// Shuts down all worker threads (if any) before destroying ourselves
@ -169,7 +168,9 @@ class alignas(kCacheLineSize) ThreadPool {
// threadpool; work sizes smaller than this will just be run on the
// main (calling) thread
void setMinWorkSize(size_t size);
size_t getMinWorkSize() const { return minWorkSize_; }
size_t getMinWorkSize() const {
return minWorkSize_;
}
#ifdef CAFFE2_THREADPOOL_MAIN_IMBALANCE
// Set imbalance factor for the main thread versus other threads;
@ -186,7 +187,7 @@ class alignas(kCacheLineSize) ThreadPool {
std::vector<ThreadStats> getStats(bool reset = false);
#endif
protected:
protected:
friend struct ThreadInfo;
// What we are currently working on
@ -220,8 +221,8 @@ class alignas(kCacheLineSize) ThreadPool {
size_t threadsReady_;
// The first entry is always for the main thread
std::vector<
std::unique_ptr<ThreadInfo, AlignedDeleter<ThreadInfo>>> threadInfo_;
std::vector<std::unique_ptr<ThreadInfo, AlignedDeleter<ThreadInfo>>>
threadInfo_;
// Set of threads that we are managing
std::vector<std::thread> threads_;
@ -243,4 +244,4 @@ class alignas(kCacheLineSize) ThreadPool {
#endif // CAFFE2_THREADPOOL_MOBILE
#endif // CAFFE2_UTILS_THREADPOOL_H_
#endif // CAFFE2_UTILS_THREADPOOL_H_
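
For reference, MakeAligned pairs an aligned allocation (posix_memalign, or memalign on Android) and a placement-new construction with a deleter that runs the destructor and frees the raw storage. A minimal usage sketch for a mobile build where CAFFE2_THREADPOOL_MOBILE is set; the Counter type is illustrative and the caffe2 namespace is assumed:

#include "caffe2/utils/threadpool/ThreadPool.h"

// Illustrative cache-line-aligned type.
struct alignas(64) Counter {
  explicit Counter(int start) : value(start) {}
  int value;
};

void example() {
  // Allocates 64-byte-aligned storage, constructs Counter in place, and
  // returns a unique_ptr whose AlignedDeleter calls ~Counter() then free().
  auto c = caffe2::MakeAligned<Counter>::make(64, 42);
  if (c) { // alloc returns nullptr on failure, so check before use
    c->value += 1;
  }
}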