Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Enable Detectron model inference for CPU and MKL-DNN paths (#10157)
Summary:
1. Support the ops needed for inference of Faster-RCNN/Mask-RCNN in Detectron, mostly as direct CPU fallbacks.
2. Use the CPU device to hold 0-dim tensors and integer tensors in both the fallback op and the blob feeder, as required by Detectron models.
3. Ignore 0-dim tensors in the MKL-DNN concat operator.
4. Generate a dynamic library of the Detectron module for the CPU device.

This PR obsoletes #9164.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10157
Differential Revision: D9276837
Pulled By: yinghai
fbshipit-source-id: dc364932ae4a2e7fcefdee70b5fce3c0cee91b6f
This commit (c755616e00, parent 89834dfe64) was committed by Facebook Github Bot.
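As a rough illustration of what the change enables (not code from the PR): a net placed on the IDEEP device can now mix native MKL-DNN ops with CPU fallbacks such as AffineChannel. The sketch below assumes an IDEEP-enabled Caffe2 build (workspace.C.use_ideep) and mirrors the idioms of the new test file; blob names and shapes are made up.

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

# Hypothetical device option; requires a build with IDEEP support.
ideep = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.IDEEP)
net = core.Net('detectron_infer_sketch')
with core.DeviceScope(ideep):
    net.Conv(['data', 'w', 'b'], 'conv1', kernel=3, pad=1, stride=1)  # native MKL-DNN op
    net.AffineChannel(['conv1', 'scale', 'bias'], 'conv1',            # CPU fallback, in-place
                      is_learnable=False)

workspace.FeedBlob('data', np.random.rand(1, 3, 32, 32).astype(np.float32), ideep)
workspace.FeedBlob('w', np.random.rand(8, 3, 3, 3).astype(np.float32), ideep)
workspace.FeedBlob('b', np.zeros(8, dtype=np.float32), ideep)
workspace.FeedBlob('scale', np.ones(8, dtype=np.float32), ideep)
workspace.FeedBlob('bias', np.zeros(8, dtype=np.float32), ideep)
workspace.RunNetOnce(net)
print(workspace.FetchBlob('conv1').shape)  # expected: (1, 8, 32, 32)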
@ -25,13 +25,21 @@ class IDEEPConcatOp final : public IDEEPOperator {
   virtual ~IDEEPConcatOp() {}

   bool RunOnDevice() override {
     const auto& input_zero = Input(INPUT0);
     auto* output = Output(OUTPUT);
     TensorCPU* axis_info = OperatorBase::Output<TensorCPU>(AXIS_INFO, CPU);

     vector<itensor> inputs;
     for (int i = 0; i < InputSize(); ++i) {
-      inputs.emplace_back(Input(i));
+      if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
+        inputs.emplace_back(Input(i));
+      } else {
+        CAFFE_ENFORCE(OperatorBase::InputBlob(i).IsType<Tensor>(CPU),
+                      "Expect cpu tensor if not itensor");
+        auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
+        CAFFE_ENFORCE(tensor_cpu.dims().size() == 0 ||
+                      tensor_cpu.size_from_dim(0) == 0,
+                      "Expect zero dim tensor");
+      }
     }

     auto axis_vdata = ideep::concat::compute(inputs, axis_, add_axis_, *output);
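The practical effect of this concat change, sketched below under the assumption of an IDEEP-enabled build: a zero-sized tensor mixed into the inputs (Detectron nets can produce empty blobs, e.g. when no RoIs survive) is validated and skipped instead of aborting the MKL-DNN path. Blob names are illustrative only.

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

ideep = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.IDEEP)
workspace.FeedBlob('a', np.random.rand(2, 3, 4, 4).astype(np.float32), ideep)
workspace.FeedBlob('b', np.random.rand(2, 3, 4, 4).astype(np.float32), ideep)
# The feeder parks zero-sized arrays in plain CPU tensors; the concat op
# then enforces that they are empty and ignores them.
workspace.FeedBlob('empty', np.zeros((0,), dtype=np.float32), ideep)
op = core.CreateOperator('Concat', ['a', 'b', 'empty'], ['c', 'split_info'],
                         axis=1, device_option=ideep)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob('c').shape)  # expected: (2, 6, 4, 4)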
@ -32,6 +32,8 @@
 #include <caffe2/operators/tanh_op.h>
 #include <caffe2/operators/transpose_op.h>
 #include <caffe2/operators/utility_ops.h>
+#include <caffe2/operators/affine_channel_op.h>
+#include <caffe2/operators/stop_gradient.h>
 #include <caffe2/sgd/adam_op.h>
 #include <caffe2/sgd/iter_op.h>
 #include <caffe2/sgd/learning_rate_op.h>
@ -116,6 +118,12 @@ REGISTER_IDEEP_OPERATOR(
 REGISTER_IDEEP_OPERATOR(
     BBoxTransform,
     IDEEPFallbackOp<BBoxTransformOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    AffineChannel,
+    IDEEPFallbackOp<AffineChannelOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    StopGradient,
+    IDEEPFallbackOp<StopGradientOp<CPUContext>>);

 REGISTER_IDEEP_OPERATOR(
     PadImage,
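With these registrations in place, ops such as StopGradient execute under the IDEEP device by round-tripping through their CPU implementations. A minimal sketch, assuming an IDEEP build:

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

ideep = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.IDEEP)
X = np.random.rand(1, 3, 8, 8).astype(np.float32)
workspace.FeedBlob('X', X, ideep)
op = core.CreateOperator('StopGradient', ['X'], ['Y'], device_option=ideep)
workspace.RunOperatorOnce(op)
assert np.allclose(workspace.FetchBlob('Y'), X)  # identity at inference time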
@ -53,6 +53,8 @@ class IDEEPFallbackOp final : public IDEEPOperator {
     // then forward output blobs to local workspace.
     std::unordered_map<string, string> forwarded_output_blobs;
     for (int i = 0; i < base_def_.output_size(); i++) {
+      // For in-place case, the in/output tensor for local_ws must be
+      // re-created, instead of forwarding from current workspace.
       string parent_name(base_def_.output(i));
       if (!SkipOutputCopy::Contains(i)) {
         parent_name += "_cpu_output_blob_" + base_def_.type();
@ -60,6 +62,13 @@ class IDEEPFallbackOp final : public IDEEPOperator {
       local_output_blobs_.push_back(ws->CreateBlob(parent_name));
       CHECK_NOTNULL(local_output_blobs_.back());
       forwarded_output_blobs[base_def_.output(i)] = parent_name;
+      output_inplace_.push_back(false);
+      for (const string &input_name : base_def_.input()) {
+        if (input_name == base_def_.output(i)) {
+          output_inplace_[i] = true;
+          break;
+        }
+      }
     }
     local_ws_.reset(new Workspace(ws, forwarded_output_blobs));
     // Set up the symbols for the local workspace.
@ -67,31 +76,26 @@ class IDEEPFallbackOp final : public IDEEPOperator {
       local_input_blobs_.push_back(local_ws_->CreateBlob(name));
       CHECK_NOTNULL(local_input_blobs_.back());
     }
+    input_share_.resize(local_input_blobs_.size(), false);
     base_op_.reset(new CPUOp(base_def_, local_ws_.get()));
   }

   bool RunOnDevice() override {
     for (int i = 0; i < InputSize(); ++i) {
-      if (InputIsType<itensor>(i) && Input(i).get_data_type() == itensor::data_type::f32) {
+      if (InputIsType<itensor>(i) &&
+          Input(i).get_data_type() == itensor::data_type::f32) {
         auto& input = Input(i);
+        if (input_share_[i]) {
+          local_input_blobs_[i]->Reset();
+        }
+        input_share_[i] = false;
         auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
         dtensor->Resize(input.get_dims());
         if (input.is_public_format()) {
-          dtensor->ShareExternalPointer(static_cast<float*>(input.get_data_handle()));
+          dtensor->ShareExternalPointer(
+              static_cast<float*>(input.get_data_handle()));
         } else {
           input.reorder_to(dtensor->template mutable_data<float>());
         }
-      } else if (
-          InputIsType<itensor>(i) &&
-          Input(i).get_data_type() == itensor::data_type::s32) {
-        auto& input = Input(i);
-        auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
-        dtensor->Resize(input.get_dims());
-        if (input.is_public_format()) {
-          dtensor->ShareExternalPointer(
-              static_cast<long*>(input.get_data_handle()));
-        } else {
-          input.reorder_to(dtensor->template mutable_data<long>());
-        }
       } else {
         VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
@ -99,8 +103,9 @@ class IDEEPFallbackOp final : public IDEEPOperator {
         // local_input_blobs will only be used as const blob input for the
         // base op so we are still fine.
         local_input_blobs_[i]->ShareExternal(
-            const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
+            const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
             OperatorBase::Inputs()[i]->meta());
+        input_share_[i] = true;
       }
     }

@ -120,21 +125,16 @@ class IDEEPFallbackOp final : public IDEEPOperator {
           "IDEEP fallback op currently does not support non-TensorCPU "
           "output type who needs copying.");
       const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();

       auto src_dims = src.dims();
-      if (src.ndim() == 0) {
-        VLOG(1) << "Copy output: index " << i << " skipped.";
+      if (src.template IsType<float>() &&
+          src.dims().size() != 0 && src.size_from_dim(0) != 0 &&
+          base_op_->type() != "Python") {
         Blob* dst = OperatorBase::OutputBlob(i);
-        dst->Reset(new Tensor(CPU));
-        auto dtensor = dst->GetMutableTensor(CPU);
-        dtensor->Resize(src_dims);
-        dtensor->ShareData(src);
-        continue;
-      }
-
-      if (src.template IsType<float>()) {
-        Blob* dst = OperatorBase::OutputBlob(i);
-        if (!dst->template IsType<itensor>()) {
+        // The output tensor must be ideep tensor with public format.
+        // If reusing ideep tensor with non-public format, the tensor buffer
+        // will be interpreted incorrectly.
+        if (!dst->template IsType<itensor>() ||
+            !dst->template Get<itensor>().is_public_format()) {
           dst->Reset(new itensor());
         }

@ -143,7 +143,12 @@ class IDEEPFallbackOp final : public IDEEPOperator {
         if (dtensor->get_dims() != dst_dims) {
           dtensor->resize(dst_dims, itensor::data_type::f32);
         }
-        dtensor->set_data_handle(const_cast<void*>(src.raw_data()));
+        if (output_inplace_[i]) {
+          dtensor->reorder_from(dst_dims, itensor::data_type::f32,
+                                const_cast<void*>(src.raw_data()));
+        } else {
+          dtensor->set_data_handle(const_cast<void *>(src.raw_data()));
+        }
       } else {
         VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
         Blob* dst = OperatorBase::OutputBlob(i);
@ -159,6 +164,8 @@ class IDEEPFallbackOp final : public IDEEPOperator {
  protected:
   vector<Blob*> local_input_blobs_;
   vector<Blob*> local_output_blobs_;
+  vector<bool> output_inplace_;
+  vector<bool> input_share_;
   std::unique_ptr<CPUOp> base_op_;
   std::unique_ptr<Workspace> local_ws_;
   OperatorDef base_def_;
caffe2/python/ideep/operator_fallback_op_test.py (new file, 99 lines)
@ -0,0 +1,99 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import hypothesis.strategies as st
from hypothesis import given
import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu


@unittest.skipIf(not workspace.C.use_ideep, "No IDEEP support.")
class TestFallbackOps(hu.HypothesisTestCase):
    @given(stride=st.integers(1, 3),
           pad=st.integers(0, 3),
           kernel=st.integers(3, 5),
           size=st.integers(8, 10),
           input_channels=st.integers(1, 3),
           output_channels=st.integers(1, 5),
           batch_size=st.integers(1, 3),
           use_bias=st.booleans(),
           **mu.gcs)
    def test_in_place(self, stride, pad, kernel, size,
                      input_channels, output_channels,
                      batch_size, use_bias, gc, dc):
        # To expose the potential in-place issue in the fallback, a fallback
        # op that follows an ideep op must run for at least two iterations.
        conv = core.CreateOperator(
            "Conv",
            ["X", "w", "b"] if use_bias else ["X", "w"],
            ["Y"],
            stride=stride,
            pad=pad,
            kernel=kernel,
            device_option=dc[0]
        )
        X = np.random.rand(
            batch_size, input_channels, size, size).astype(np.float32) - 0.5
        w = np.random.rand(output_channels, input_channels, kernel, kernel) \
            .astype(np.float32) - 0.5
        b = np.random.rand(output_channels).astype(np.float32) - 0.5

        old_ws_name = workspace.CurrentWorkspace()
        workspace.SwitchWorkspace("_device_check_", True)
        workspace.FeedBlob('X', X, dc[0])
        workspace.FeedBlob('w', w, dc[0])
        workspace.FeedBlob('b', b, dc[0])
        workspace.RunOperatorOnce(conv)
        Y = workspace.FetchBlob('Y')

        scale = np.random.randn(Y.shape[1]).astype(np.float32)
        bias = np.random.randn(Y.shape[1]).astype(np.float32)
        ac = core.CreateOperator(
            "AffineChannel",
            ["Y", "scale", "bias"],
            ["Y"],
            is_learnable=False,
            device_option=dc[0]
        )
        workspace.FeedBlob('scale', scale, dc[0])
        workspace.FeedBlob('bias', bias, dc[0])
        workspace.RunOperatorOnce(ac)
        workspace.RunOperatorOnce(conv)
        workspace.RunOperatorOnce(ac)
        Y0 = workspace.FetchBlob('Y')

        workspace.ResetWorkspace()
        dev_net = caffe2_pb2.NetDef()
        conv_dev = caffe2_pb2.OperatorDef()
        conv_dev.CopyFrom(conv)
        conv_dev.device_option.CopyFrom(dc[1])
        ac_dev = caffe2_pb2.OperatorDef()
        ac_dev.CopyFrom(ac)
        ac_dev.device_option.CopyFrom(dc[1])
        dev_net.op.extend([conv_dev, ac_dev])
        workspace.FeedBlob('X', X, dc[1])
        workspace.FeedBlob('w', w, dc[1])
        workspace.FeedBlob('b', b, dc[1])
        workspace.FeedBlob('scale', scale, dc[1])
        workspace.FeedBlob('bias', bias, dc[1])
        workspace.RunNetOnce(dev_net)
        workspace.RunNetOnce(dev_net)
        Y1 = workspace.FetchBlob('Y')

        if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
            print(Y1.flatten())
            print(Y0.flatten())
            print(np.max(np.abs(Y1 - Y0)))
            self.assertTrue(False)

        workspace.SwitchWorkspace(old_ws_name)


if __name__ == "__main__":
    unittest.main()
@ -9,6 +9,7 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

+#include "caffe2/ideep/operators/operator_fallback_ideep.h"
 #include <caffe2/ideep/ideep_utils.h>

 namespace caffe2 {
@ -19,42 +20,42 @@ USE_IDEEP_DEF_ALIASES();
 class IDeepFetcher;
 class IDeepFeeder;

-REGISTER_BLOB_FETCHER((TypeMeta::Id<itensor>()),IDeepFetcher);
+REGISTER_IDEEP_OPERATOR(Python, IDEEPFallbackOp<PythonOp<CPUContext, false>>);
+
+REGISTER_BLOB_FETCHER((TypeMeta::Id<itensor>()), IDeepFetcher);
 REGISTER_BLOB_FEEDER(IDEEP, IDeepFeeder);

 class IDeepFetcher : public BlobFetcherBase {
   TypeMeta type_transform(const itensor &atensor) {
-    switch(atensor.get_data_type()) {
+    switch (atensor.get_data_type()) {
       case itensor::data_type::f32:
         return TypeMeta::Make<float>();
-      case itensor::data_type::s16:
-        return TypeMeta::Make<float16>();
       case itensor::data_type::s32:
         return TypeMeta::Make<int>();
       case itensor::data_type::s8:
         return TypeMeta::Make<int8_t>();
       case itensor::data_type::u8:
         return TypeMeta::Make<uint8_t>();
       default:
         // Should we throw exception?
         return TypeMeta();
     }
   }

  public:
-  pybind11::object Fetch(const Blob& blob) override {
+  pybind11::object Fetch(const Blob &blob) override {
     try {
       return FetchTensor(blob.Get<itensor>(), true).obj;
-    } catch (ideep::error& e) {
-      VLOG(1) << "IDEEP error: " << e.message;
+    } catch (ideep::error &e) {
+      LOG(ERROR) << "IDEEP error: " << e.message;
       throw;
     }
   }

-  FetchedBlob FetchTensor(const itensor& atensor, bool force_copy) {
+  FetchedBlob FetchTensor(const itensor &atensor, bool force_copy) {
     FetchedBlob result;
     CAFFE_ENFORCE(atensor.materialized(),
-                  "Trying to fetch uninitialized tensor");
+        "Trying to fetch uninitialized tensor");
     const int numpy_type = CaffeToNumpyType(type_transform(atensor));
     CAFFE_ENFORCE(
         numpy_type != -1,
@ -64,17 +65,16 @@ class IDeepFetcher : public BlobFetcherBase {
     std::vector<npy_intp> npy_dims(dims.begin(), dims.end());

     result.copied = force_copy || atensor.need_reorder();
-    void* outPtr;
+    void *outPtr;
     if (result.copied) {
       result.obj = py::reinterpret_steal<py::object>(
           PyArray_SimpleNew(atensor.ndims(), npy_dims.data(), numpy_type));
       outPtr = static_cast<void *>(
-          PyArray_DATA(reinterpret_cast<PyArrayObject*>(result.obj.ptr())));
+          PyArray_DATA(reinterpret_cast<PyArrayObject *>(result.obj.ptr())));
     } else {
       outPtr = atensor.get_data_handle();
-      result.obj = py::reinterpret_steal<py::object>(
-          PyArray_SimpleNewFromData(
-              atensor.ndims(), npy_dims.data(), numpy_type, outPtr));
+      result.obj = py::reinterpret_steal<py::object>(PyArray_SimpleNewFromData(
+          atensor.ndims(), npy_dims.data(), numpy_type, outPtr));
     }

     if (numpy_type == NPY_OBJECT) {
@ -95,8 +95,6 @@ class IDeepFeeder : public BlobFeederBase {
       return itensor::data_type::f32;
     else if (meta == TypeMeta::Make<int>())
       return itensor::data_type::s32;
-    else if (meta == TypeMeta::Make<float16>())
-      return itensor::data_type::s16;
     else if (meta == TypeMeta::Make<int8_t>())
       return itensor::data_type::s8;
     else if (meta == TypeMeta::Make<uint8_t>())
@ -105,53 +103,74 @@ class IDeepFeeder : public BlobFeederBase {
       return itensor::data_type::data_undef;
   }

  public:
   void FeedTensor(
-      const DeviceOption& option,
+      const DeviceOption &option,
       PyArrayObject *original_array,
       itensor *tensor) {
     PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array);
-    auto g = MakeGuard([&]() {Py_XDECREF(array); });
-
+    auto g = MakeGuard([&]() { Py_XDECREF(array); });
     const auto npy_type = PyArray_TYPE(array);
-    const TypeMeta& meta = NumpyTypeToCaffe(npy_type);
-    CAFFE_ENFORCE(
-        meta.id() != TypeIdentifier::uninitialized(),
+    const TypeMeta &meta = NumpyTypeToCaffe(npy_type);
+    CAFFE_ENFORCE_NE(
+        meta.id(),
+        TypeIdentifier::uninitialized(),
         "This numpy data type is not supported: ",
-        PyArray_TYPE(array),
-        ".");
+        PyArray_TYPE(array), ".");

     int ndim = PyArray_NDIM(array);
-    npy_intp* npy_dims = PyArray_DIMS(array);
+    npy_intp *npy_dims = PyArray_DIMS(array);

     itensor::dims adims;
     for (int i = 0; i < ndim; i++) {
-      adims.push_back(static_cast<itensor::dims::value_type>(
-          npy_dims[i]));
+      adims.push_back(static_cast<itensor::dims::value_type>(npy_dims[i]));
     }

     switch (npy_type) {
       case NPY_OBJECT:
       case NPY_UNICODE:
         CAFFE_THROW("IDeep doesn't support string");
         break;
       default:
         auto type = type_transform(meta);
-        tensor->resize(adims, type);
+        if (tensor->get_dims() != adims || type != tensor->get_data_type()) {
+          tensor->resize(adims, type);
+        }
         tensor->reorder_from(adims, type,
-            static_cast<void *>(PyArray_DATA(array)));
+                             static_cast<void *>(PyArray_DATA(array)));
     }
   }

-  void Feed(const DeviceOption& option, PyArrayObject* original_array,
-            Blob* blob) {
+  bool ZeroDim(PyArrayObject *array) {
+    int ndim = PyArray_NDIM(array);
+    npy_intp *npy_dims = PyArray_DIMS(array);
+    return ndim == 0 ||
+        std::find(npy_dims, npy_dims + ndim, 0) != npy_dims + ndim;
+  }
+
+  void Feed(const DeviceOption &option, PyArrayObject *original_array,
+            Blob *blob) {
     try {
-      FeedTensor(option, original_array, blob->GetMutable<itensor>());
-    } catch (ideep::error& e) {
-      VLOG(1) << "IDEEP error: " << e.message;
+      PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array);
+      auto g = MakeGuard([&]() { Py_XDECREF(array); });
+
+      const auto npy_type = PyArray_TYPE(array);
+      const TypeMeta &meta = NumpyTypeToCaffe(npy_type);
+      // TODO: if necessary, use dispatcher.
+      if (meta.Match<float>() && !ZeroDim(original_array)) {
+        FeedTensor(option, original_array, blob->GetMutable<itensor>());
+      } else {
+        DeviceOption cpu_option(option);
+        cpu_option.set_device_type(DeviceType::CPU);
+        TensorFeeder<CPUContext> cpu_tensor_feeder;
+        cpu_tensor_feeder.FeedTensor(cpu_option, original_array,
+                                     blob->GetMutableTensor(CPU));
+      }
+    } catch (ideep::error &e) {
+      LOG(ERROR) << "IDEEP error: " << e.message;
       throw;
     }
   }
 };

 } // namespace python
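In short, the feeder now routes by dtype and shape: float arrays with real extents become ideep tensors, while integer and zero-sized arrays are held as plain CPU tensors (summary point 2). A hedged sketch of the observable behavior, assuming an IDEEP build:

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace

ideep = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.IDEEP)
workspace.FeedBlob('f', np.ones((2, 2), dtype=np.float32), ideep)   # -> itensor
workspace.FeedBlob('i', np.arange(4, dtype=np.int32), ideep)        # -> CPU tensor
workspace.FeedBlob('s', np.array(1.0, dtype=np.float32), ideep)     # 0-dim -> CPU tensor
print(workspace.FetchBlob('i'))  # integers round-trip through the CPU path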
@ -11,4 +11,8 @@ if (USE_CUDA)

   target_link_libraries(caffe2_detectron_ops_gpu caffe2_gpu)
   install(TARGETS caffe2_detectron_ops_gpu DESTINATION lib)
+elseif(NOT IOS_PLATFORM)
+  add_library(caffe2_detectron_ops SHARED ${Detectron_CPU_SRCS})
+  target_link_libraries(caffe2_detectron_ops caffe2)
+  install(TARGETS caffe2_detectron_ops DESTINATION lib)
 endif()
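The new caffe2_detectron_ops target means a CPU-only (or MKL-DNN) build now also produces a loadable Detectron module, not just the CUDA one. A hypothetical usage sketch; the library path is illustrative and depends on your build and install layout:

from caffe2.python import dyndep

# Path is an assumption; see install(TARGETS caffe2_detectron_ops DESTINATION lib).
dyndep.InitOpsLibrary('/usr/local/lib/libcaffe2_detectron_ops.so')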
@ -15,9 +15,19 @@
 */

 #include "batch_permutation_op.h"
+#ifdef CAFFE2_USE_IDEEP
+#include <caffe2/ideep/operators/operator_fallback_ideep.h>
+#include <caffe2/ideep/utils/ideep_operator.h>
+#endif

 namespace caffe2 {

+#ifdef CAFFE2_USE_IDEEP
+REGISTER_IDEEP_OPERATOR(
+    BatchPermutation,
+    IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
+#endif
+
 REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(
     BatchPermutationGradient,
@ -15,8 +15,17 @@
 */

 #include "upsample_nearest_op.h"
+#ifdef CAFFE2_USE_IDEEP
+#include "caffe2/ideep/operators/operator_fallback_ideep.h"
+#include "caffe2/ideep/utils/ideep_operator.h"
+#endif

 namespace caffe2 {
+#ifdef CAFFE2_USE_IDEEP
+REGISTER_IDEEP_OPERATOR(
+    UpsampleNearest,
+    IDEEPFallbackOp<UpsampleNearestOp<float, CPUContext>>);
+#endif

 REGISTER_CPU_OPERATOR(UpsampleNearest, UpsampleNearestOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(
@ -35,8 +35,50 @@ class UpsampleNearestOp final : public Operator<Context> {
   USE_OPERATOR_CONTEXT_FUNCTIONS;

   bool RunOnDevice() override {
-    // No CPU implementation for now
-    CAFFE_NOT_IMPLEMENTED;
+    auto translate_idx = [](int ii, int d1, int d2, int d3, int scale_factor) {
+      int x, y, z, w;
+      w = ii % d3;
+      ii = ii/d3;
+      z = ii % d2;
+      ii = ii/d2;
+      y = ii % d1;
+      ii = ii/d1;
+      x = ii;
+      w = w/scale_factor;
+      z = z/scale_factor;
+      d2 /= scale_factor;
+      d3 /= scale_factor;
+      return (((x*d1+y)*d2)+z)*d3+w;
+    };
+
+    auto& X = Input(0);
+    auto* Y = Output(0);
+    auto out_shape = X.dims();
+    out_shape[X.ndim() - 1] *= scale_;
+    out_shape[X.ndim() - 2] *= scale_;
+    Y->Resize(out_shape);
+
+    int d1;
+    int d2;
+    int d3;
+    if (X.ndim() == 3) {
+      d1 = Y->dim32(0);
+      d2 = Y->dim32(1);
+      d3 = Y->dim32(2);
+    } else {
+      d1 = Y->dim32(1);
+      d2 = Y->dim32(2);
+      d3 = Y->dim32(3);
+    }
+
+    const T *input_data = X.template data<T>();
+    T *output_data = Y->template mutable_data<T>();
+
+    for (int ii = 0; ii < Y->size(); ii++) {
+      int ipidx = translate_idx(ii, d1, d2, d3, scale_);
+      output_data[ii] = input_data[ipidx];
+    }
+    return true;
   }

  protected:
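The translate_idx lambda maps each flat output index back to the flat input index of its nearest-neighbor source pixel: decompose the index against the output extents, integer-divide the two innermost coordinates by the scale, then re-flatten against the scaled-down extents. A quick cross-check of that arithmetic (not PR code) against numpy's repeat-based nearest-neighbor upsampling:

import numpy as np

def translate_idx(ii, d1, d2, d3, s):
    # Same decomposition as the C++ lambda, for a 4D (N, d1, d2, d3) output.
    w = ii % d3; ii //= d3
    z = ii % d2; ii //= d2
    y = ii % d1; x = ii // d1
    return (((x * d1 + y) * (d2 // s)) + z // s) * (d3 // s) + w // s

x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
s = 2
ref = np.repeat(np.repeat(x, s, axis=2), s, axis=3)  # nearest-neighbor 2x
flat_in, flat_out = x.ravel(), np.empty(ref.size, dtype=np.float32)
d1, d2, d3 = ref.shape[1:]
for ii in range(ref.size):
    flat_out[ii] = flat_in[translate_idx(ii, d1, d2, d3, s)]
assert np.array_equal(flat_out.reshape(ref.shape), ref)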