Enable Detectron model inference for CPU and MKL-DNN paths (#10157)

Summary:
1. Support the ops needed for Faster-RCNN/Mask-RCNN inference in Detectron, mostly as direct CPU fallbacks.
2. Use the CPU device to hold 0-dim tensors and integer tensors in both the fallback op and the blob feeder, as Detectron models require (a usage sketch follows this list).
3. Ignore 0-dim tensors in the MKL-DNN concat operator.
4. Build the Detectron module as a dynamic library for the CPU device.
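A minimal usage sketch of item 2 (illustrative only, not part of this change; blob names are made up): float arrays fed to the IDEEP device become ideep tensors, while integer and 0-dim arrays are held as plain CPU tensors.

# Hedged sketch: feeding blobs to the IDEEP device (assumes a build with USE_IDEEP).
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

ideep_opt = core.DeviceOption(caffe2_pb2.IDEEP)

# Float data is stored as an ideep (MKL-DNN) tensor.
workspace.FeedBlob('scores', np.random.rand(2, 3).astype(np.float32), ideep_opt)

# Integer and 0-dim arrays are assumed to fall back to plain CPU tensors (item 2).
workspace.FeedBlob('im_info_int', np.array([600, 800, 1], dtype=np.int32), ideep_opt)
workspace.FeedBlob('iter', np.array(0, dtype=np.int64), ideep_opt)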

This PR obsoletes #9164.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10157

Differential Revision: D9276837

Pulled By: yinghai

fbshipit-source-id: dc364932ae4a2e7fcefdee70b5fce3c0cee91b6f
jgong5
2018-08-29 14:56:55 -07:00
committed by Facebook Github Bot
parent 89834dfe64
commit c755616e00
9 changed files with 299 additions and 93 deletions


@ -25,13 +25,21 @@ class IDEEPConcatOp final : public IDEEPOperator {
virtual ~IDEEPConcatOp() {}
bool RunOnDevice() override {
const auto& input_zero = Input(INPUT0);
auto* output = Output(OUTPUT);
TensorCPU* axis_info = OperatorBase::Output<TensorCPU>(AXIS_INFO, CPU);
vector<itensor> inputs;
for (int i = 0; i < InputSize(); ++i) {
if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
inputs.emplace_back(Input(i));
} else {
CAFFE_ENFORCE(OperatorBase::InputBlob(i).IsType<Tensor>(CPU),
"Expect cpu tensor if not itensor");
auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
CAFFE_ENFORCE(tensor_cpu.dims().size() == 0 ||
tensor_cpu.size_from_dim(0) == 0,
"Expect zero dim tensor");
}
}
auto axis_vdata = ideep::concat::compute(inputs, axis_, add_axis_, *output);


@ -32,6 +32,8 @@
#include <caffe2/operators/tanh_op.h>
#include <caffe2/operators/transpose_op.h>
#include <caffe2/operators/utility_ops.h>
#include <caffe2/operators/affine_channel_op.h>
#include <caffe2/operators/stop_gradient.h>
#include <caffe2/sgd/adam_op.h>
#include <caffe2/sgd/iter_op.h>
#include <caffe2/sgd/learning_rate_op.h>
@ -116,6 +118,12 @@ REGISTER_IDEEP_OPERATOR(
REGISTER_IDEEP_OPERATOR(
BBoxTransform,
IDEEPFallbackOp<BBoxTransformOp<float, CPUContext>>);
REGISTER_IDEEP_OPERATOR(
AffineChannel,
IDEEPFallbackOp<AffineChannelOp<float, CPUContext>>);
REGISTER_IDEEP_OPERATOR(
StopGradient,
IDEEPFallbackOp<StopGradientOp<CPUContext>>);
REGISTER_IDEEP_OPERATOR(
PadImage,


@ -53,6 +53,8 @@ class IDEEPFallbackOp final : public IDEEPOperator {
// then forward output blobs to local workspace.
std::unordered_map<string, string> forwarded_output_blobs;
for (int i = 0; i < base_def_.output_size(); i++) {
// For the in-place case, the input/output tensor for local_ws must be
// re-created instead of forwarded from the current workspace.
string parent_name(base_def_.output(i));
if (!SkipOutputCopy::Contains(i)) {
parent_name += "_cpu_output_blob_" + base_def_.type();
@ -60,6 +62,13 @@ class IDEEPFallbackOp final : public IDEEPOperator {
local_output_blobs_.push_back(ws->CreateBlob(parent_name));
CHECK_NOTNULL(local_output_blobs_.back());
forwarded_output_blobs[base_def_.output(i)] = parent_name;
output_inplace_.push_back(false);
for (const string &input_name : base_def_.input()) {
if (input_name == base_def_.output(i)) {
output_inplace_[i] = true;
break;
}
}
}
local_ws_.reset(new Workspace(ws, forwarded_output_blobs));
// Set up the symbols for the local workspace.
@ -67,31 +76,26 @@ class IDEEPFallbackOp final : public IDEEPOperator {
local_input_blobs_.push_back(local_ws_->CreateBlob(name));
CHECK_NOTNULL(local_input_blobs_.back());
}
input_share_.resize(local_input_blobs_.size(), false);
base_op_.reset(new CPUOp(base_def_, local_ws_.get()));
}
bool RunOnDevice() override {
for (int i = 0; i < InputSize(); ++i) {
if (InputIsType<itensor>(i) &&
    Input(i).get_data_type() == itensor::data_type::f32) {
  auto& input = Input(i);
  if (input_share_[i]) {
    local_input_blobs_[i]->Reset();
  }
  input_share_[i] = false;
  auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
  dtensor->Resize(input.get_dims());
  if (input.is_public_format()) {
    dtensor->ShareExternalPointer(
        static_cast<float*>(input.get_data_handle()));
  } else {
    input.reorder_to(dtensor->template mutable_data<float>());
  }
} else {
VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
@ -99,8 +103,9 @@ class IDEEPFallbackOp final : public IDEEPOperator {
// local_input_blobs will only be used as const blob input for the
// base op so we are still fine.
local_input_blobs_[i]->ShareExternal(
const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
OperatorBase::Inputs()[i]->meta());
input_share_[i] = true;
}
}
@ -120,21 +125,16 @@ class IDEEPFallbackOp final : public IDEEPOperator {
"IDEEP fallback op currently does not support non-TensorCPU "
"output type who needs copying.");
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
auto src_dims = src.dims();
if (src.template IsType<float>() &&
    src.dims().size() != 0 && src.size_from_dim(0) != 0 &&
    base_op_->type() != "Python") {
  Blob* dst = OperatorBase::OutputBlob(i);
  // The output tensor must be ideep tensor with public format.
  // If reusing ideep tensor with non-public format, the tensor buffer
  // will be interpreted incorrectly.
  if (!dst->template IsType<itensor>() ||
      !dst->template Get<itensor>().is_public_format()) {
dst->Reset(new itensor());
}
@ -143,7 +143,12 @@ class IDEEPFallbackOp final : public IDEEPOperator {
if (dtensor->get_dims() != dst_dims) {
dtensor->resize(dst_dims, itensor::data_type::f32);
}
if (output_inplace_[i]) {
dtensor->reorder_from(dst_dims, itensor::data_type::f32,
const_cast<void*>(src.raw_data()));
} else {
dtensor->set_data_handle(const_cast<void *>(src.raw_data()));
}
} else {
VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
Blob* dst = OperatorBase::OutputBlob(i);
@ -159,6 +164,8 @@ class IDEEPFallbackOp final : public IDEEPOperator {
protected:
vector<Blob*> local_input_blobs_;
vector<Blob*> local_output_blobs_;
vector<bool> output_inplace_;
vector<bool> input_share_;
std::unique_ptr<CPUOp> base_op_;
std::unique_ptr<Workspace> local_ws_;
OperatorDef base_def_;


@ -0,0 +1,99 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import hypothesis.strategies as st
from hypothesis import given
import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu
@unittest.skipIf(not workspace.C.use_ideep, "No IDEEP support.")
class TestFallbackOps(hu.HypothesisTestCase):
@given(stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(3, 5),
size=st.integers(8, 10),
input_channels=st.integers(1, 3),
output_channels=st.integers(1, 5),
batch_size=st.integers(1, 3),
use_bias=st.booleans(),
**mu.gcs)
def test_in_place(self, stride, pad, kernel, size,
input_channels, output_channels,
batch_size, use_bias, gc, dc):
# To expose the potential in-place issue in the fallback path, the fallback op
# following the ideep op must be run for at least two iterations.
conv = core.CreateOperator(
"Conv",
["X", "w", "b"] if use_bias else ["X", "w"],
["Y"],
stride=stride,
pad=pad,
kernel=kernel,
device_option=dc[0]
)
X = np.random.rand(
batch_size, input_channels, size, size).astype(np.float32) - 0.5
w = np.random.rand(output_channels, input_channels, kernel, kernel) \
.astype(np.float32) - 0.5
b = np.random.rand(output_channels).astype(np.float32) - 0.5
old_ws_name = workspace.CurrentWorkspace()
workspace.SwitchWorkspace("_device_check_", True)
workspace.FeedBlob('X', X, dc[0])
workspace.FeedBlob('w', w, dc[0])
workspace.FeedBlob('b', b, dc[0])
workspace.RunOperatorOnce(conv)
Y = workspace.FetchBlob('Y')
scale = np.random.randn(Y.shape[1]).astype(np.float32)
bias = np.random.randn(Y.shape[1]).astype(np.float32)
ac = core.CreateOperator(
"AffineChannel",
["Y", "scale", "bias"],
["Y"],
is_learnable=False,
device_option=dc[0]
)
workspace.FeedBlob('scale', scale, dc[0])
workspace.FeedBlob('bias', bias, dc[0])
workspace.RunOperatorOnce(ac)
workspace.RunOperatorOnce(conv)
workspace.RunOperatorOnce(ac)
Y0 = workspace.FetchBlob('Y')
workspace.ResetWorkspace()
dev_net = caffe2_pb2.NetDef()
conv_dev = caffe2_pb2.OperatorDef()
conv_dev.CopyFrom(conv)
conv_dev.device_option.CopyFrom(dc[1])
ac_dev = caffe2_pb2.OperatorDef()
ac_dev.CopyFrom(ac)
ac_dev.device_option.CopyFrom(dc[1])
dev_net.op.extend([conv_dev, ac_dev])
workspace.FeedBlob('X', X, dc[1])
workspace.FeedBlob('w', w, dc[1])
workspace.FeedBlob('b', b, dc[1])
workspace.FeedBlob('scale', scale, dc[1])
workspace.FeedBlob('bias', bias, dc[1])
workspace.RunNetOnce(dev_net)
workspace.RunNetOnce(dev_net)
Y1 = workspace.FetchBlob('Y')
if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
print(Y1.flatten())
print(Y0.flatten())
print(np.max(np.abs(Y1 - Y0)))
self.assertTrue(False)
workspace.SwitchWorkspace(old_ws_name)
if __name__ == "__main__":
unittest.main()


@ -9,6 +9,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "caffe2/ideep/operators/operator_fallback_ideep.h"
#include <caffe2/ideep/ideep_utils.h>
namespace caffe2 {
@ -19,42 +20,42 @@ USE_IDEEP_DEF_ALIASES();
class IDeepFetcher;
class IDeepFeeder;
REGISTER_IDEEP_OPERATOR(Python, IDEEPFallbackOp<PythonOp<CPUContext, false>>);
REGISTER_BLOB_FETCHER((TypeMeta::Id<itensor>()), IDeepFetcher);
REGISTER_BLOB_FEEDER(IDEEP, IDeepFeeder);
class IDeepFetcher : public BlobFetcherBase {
TypeMeta type_transform(const itensor &atensor) {
switch (atensor.get_data_type()) {
  case itensor::data_type::f32:
    return TypeMeta::Make<float>();
  case itensor::data_type::s32:
    return TypeMeta::Make<int>();
  case itensor::data_type::s8:
    return TypeMeta::Make<int8_t>();
  case itensor::data_type::u8:
    return TypeMeta::Make<uint8_t>();
  default:
    // Should we throw exception?
    return TypeMeta();
}
}
public:
pybind11::object Fetch(const Blob &blob) override {
try {
return FetchTensor(blob.Get<itensor>(), true).obj;
} catch (ideep::error &e) {
  LOG(ERROR) << "IDEEP error: " << e.message;
throw;
}
}
FetchedBlob FetchTensor(const itensor &atensor, bool force_copy) {
FetchedBlob result;
CAFFE_ENFORCE(atensor.materialized(),
"Trying to fetch uninitialized tensor");
"Trying to fetch uninitialized tensor");
const int numpy_type = CaffeToNumpyType(type_transform(atensor));
CAFFE_ENFORCE(
numpy_type != -1,
@ -64,17 +65,16 @@ class IDeepFetcher : public BlobFetcherBase {
std::vector<npy_intp> npy_dims(dims.begin(), dims.end());
result.copied = force_copy || atensor.need_reorder();
void *outPtr;
if (result.copied) {
result.obj = py::reinterpret_steal<py::object>(
PyArray_SimpleNew(atensor.ndims(), npy_dims.data(), numpy_type));
outPtr = static_cast<void *>(
PyArray_DATA(reinterpret_cast<PyArrayObject *>(result.obj.ptr())));
} else {
outPtr = atensor.get_data_handle();
result.obj = py::reinterpret_steal<py::object>(PyArray_SimpleNewFromData(
atensor.ndims(), npy_dims.data(), numpy_type, outPtr));
}
if (numpy_type == NPY_OBJECT) {
@ -95,8 +95,6 @@ class IDeepFeeder : public BlobFeederBase {
return itensor::data_type::f32;
else if (meta == TypeMeta::Make<int>())
return itensor::data_type::s32;
else if (meta == TypeMeta::Make<int8_t>())
return itensor::data_type::s8;
else if (meta == TypeMeta::Make<uint8_t>())
@ -105,53 +103,74 @@ class IDeepFeeder : public BlobFeederBase {
return itensor::data_type::data_undef;
}
public:
void FeedTensor(
    const DeviceOption &option,
    PyArrayObject *original_array,
    itensor *tensor) {
  PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array);
  auto g = MakeGuard([&]() { Py_XDECREF(array); });
  const auto npy_type = PyArray_TYPE(array);
  const TypeMeta &meta = NumpyTypeToCaffe(npy_type);
  CAFFE_ENFORCE_NE(
      meta.id(),
      TypeIdentifier::uninitialized(),
      "This numpy data type is not supported: ",
      PyArray_TYPE(array), ".");
  int ndim = PyArray_NDIM(array);
  npy_intp *npy_dims = PyArray_DIMS(array);
  itensor::dims adims;
  for (int i = 0; i < ndim; i++) {
    adims.push_back(static_cast<itensor::dims::value_type>(npy_dims[i]));
  }
  switch (npy_type) {
    case NPY_OBJECT:
    case NPY_UNICODE:
      CAFFE_THROW("IDeep doesn't support string");
      break;
    default:
      auto type = type_transform(meta);
      if (tensor->get_dims() != adims || type != tensor->get_data_type()) {
        tensor->resize(adims, type);
      }
      tensor->reorder_from(adims, type,
                           static_cast<void *>(PyArray_DATA(array)));
  }
}
bool ZeroDim(PyArrayObject *array) {
  int ndim = PyArray_NDIM(array);
  npy_intp *npy_dims = PyArray_DIMS(array);
  return ndim == 0 ||
      std::find(npy_dims, npy_dims + ndim, 0) != npy_dims + ndim;
}

void Feed(const DeviceOption &option, PyArrayObject *original_array,
          Blob *blob) {
  try {
    PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array);
    auto g = MakeGuard([&]() { Py_XDECREF(array); });
    const auto npy_type = PyArray_TYPE(array);
    const TypeMeta &meta = NumpyTypeToCaffe(npy_type);
    // TODO: if necessary, use dispatcher.
    if (meta.Match<float>() && !ZeroDim(original_array)) {
      FeedTensor(option, original_array, blob->GetMutable<itensor>());
    } else {
      DeviceOption cpu_option(option);
      cpu_option.set_device_type(DeviceType::CPU);
      TensorFeeder<CPUContext> cpu_tensor_feeder;
      cpu_tensor_feeder.FeedTensor(cpu_option, original_array,
                                   blob->GetMutableTensor(CPU));
    }
} catch (ideep::error &e) {
LOG(ERROR) << "IDEEP error: " << e.message;
throw;
}
}
};
} // namespace python


@ -11,4 +11,8 @@ if (USE_CUDA)
target_link_libraries(caffe2_detectron_ops_gpu caffe2_gpu)
install(TARGETS caffe2_detectron_ops_gpu DESTINATION lib)
elseif(NOT IOS_PLATFORM)
add_library(caffe2_detectron_ops SHARED ${Detectron_CPU_SRCS})
target_link_libraries(caffe2_detectron_ops caffe2)
install(TARGETS caffe2_detectron_ops DESTINATION lib)
endif()
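For reference, a hedged usage sketch (the install path is an assumption, not part of this change) of loading the CPU library built above so that its operators register with Caffe2:

# Hypothetical: load libcaffe2_detectron_ops so UpsampleNearest, BatchPermutation,
# etc. become available to Caffe2 nets; adjust the path to your install location.
from caffe2.python import dyndep
dyndep.InitOpsLibrary('/usr/local/lib/libcaffe2_detectron_ops.so')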


@ -15,9 +15,19 @@
*/
#include "batch_permutation_op.h"
#ifdef CAFFE2_USE_IDEEP
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif
namespace caffe2 {
#ifdef CAFFE2_USE_IDEEP
REGISTER_IDEEP_OPERATOR(
BatchPermutation,
IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif
REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
BatchPermutationGradient,


@ -15,8 +15,17 @@
*/
#include "upsample_nearest_op.h"
#ifdef CAFFE2_USE_IDEEP
#include "caffe2/ideep/operators/operator_fallback_ideep.h"
#include "caffe2/ideep/utils/ideep_operator.h"
#endif
namespace caffe2 {
#ifdef CAFFE2_USE_IDEEP
REGISTER_IDEEP_OPERATOR(
UpsampleNearest,
IDEEPFallbackOp<UpsampleNearestOp<float, CPUContext>>);
#endif
REGISTER_CPU_OPERATOR(UpsampleNearest, UpsampleNearestOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(


@ -35,8 +35,50 @@ class UpsampleNearestOp final : public Operator<Context> {
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {
auto translate_idx = [](int ii, int d1, int d2, int d3, int scale_factor) {
int x, y, z, w;
w = ii % d3;
ii = ii/d3;
z = ii % d2;
ii = ii/d2;
y = ii % d1;
ii = ii/d1;
x = ii;
w = w/scale_factor;
z = z/scale_factor;
d2 /= scale_factor;
d3 /= scale_factor;
return (((x*d1+y)*d2)+z)*d3+w;
};
auto& X = Input(0);
auto* Y = Output(0);
auto out_shape = X.dims();
out_shape[X.ndim() - 1] *= scale_;
out_shape[X.ndim() - 2] *= scale_;
Y->Resize(out_shape);
int d1;
int d2;
int d3;
if (X.ndim() == 3) {
d1 = Y->dim32(0);
d2 = Y->dim32(1);
d3 = Y->dim32(2);
} else {
d1 = Y->dim32(1);
d2 = Y->dim32(2);
d3 = Y->dim32(3);
}
const T *input_data = X.template data<T>();
T *output_data = Y->template mutable_data<T>();
for (int ii = 0; ii < Y->size(); ii++) {
int ipidx = translate_idx(ii, d1, d2, d3, scale_);
output_data[ii] = input_data[ipidx];
}
return true;
}
protected:
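As a cross-check of the index arithmetic in RunOnDevice above, a NumPy reference sketch (an equivalence assumption, not code from this change) of nearest-neighbor upsampling on an NCHW tensor:

# Reference sketch: repeat each spatial element `scale` times along H and W,
# which should reproduce the translate_idx loop above.
import numpy as np

def upsample_nearest_ref(x, scale):
    return x.repeat(scale, axis=-2).repeat(scale, axis=-1)

x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
y = upsample_nearest_ref(x, 2)  # shape (2, 3, 8, 8)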