mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Added predictor bindings to python interface
Summary: from caffe2.python import workspace; p = workspace.Predictor(init_net, predict_net); outputs = p.run(inputs) Reviewed By: Yangqing Differential Revision: D4576793 fbshipit-source-id: b829bbcaf2e7c34dad85024177433207bd96a234
This commit is contained in:
committed by
Facebook Github Bot
parent
61dd35f1d6
commit
56f324d191
@ -536,11 +536,38 @@ void addObjectMethods(py::module& m) {
|
||||
.def(
|
||||
"__init__",
|
||||
[](Predictor& instance, py::bytes init_net, py::bytes predict_net) {
|
||||
CAFFE_ENFORCE(gWorkspace);
|
||||
NetDef init_net_, predict_net_;
|
||||
CAFFE_ENFORCE(ParseProtobufFromLargeString(init_net, &init_net_));
|
||||
CAFFE_ENFORCE(
|
||||
ParseProtobufFromLargeString(predict_net, &predict_net_));
|
||||
new (&instance) Predictor(init_net_, predict_net_);
|
||||
new (&instance) Predictor(init_net_, predict_net_, gWorkspace);
|
||||
})
|
||||
.def(
|
||||
"run",
|
||||
[](Predictor& instance,
|
||||
std::vector<py::object> inputs) -> std::vector<py::object> {
|
||||
std::vector<TensorCPU*> tensors;
|
||||
std::vector<TensorCPU> tensors_data(inputs.size());
|
||||
for (auto i = 0; i < inputs.size(); ++i) {
|
||||
auto input = inputs[i];
|
||||
CAFFE_ENFORCE(
|
||||
PyArray_Check(input.ptr()),
|
||||
"Input must be of type numpy array.");
|
||||
PyArrayObject* array =
|
||||
reinterpret_cast<PyArrayObject*>(input.ptr());
|
||||
TensorFeeder<CPUContext>().FeedTensor(
|
||||
DeviceOption(), array, &(tensors_data[i]));
|
||||
tensors.push_back(&(tensors_data[i]));
|
||||
}
|
||||
std::vector<TensorCPU*> out;
|
||||
instance.run(tensors, &out);
|
||||
std::vector<py::object> pyout;
|
||||
for (auto t : out) {
|
||||
pyout.push_back(
|
||||
TensorFetcher<CPUContext>().FetchTensor(*t, true).obj);
|
||||
}
|
||||
return pyout;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,7 @@ SwitchWorkspace = C.switch_workspace
|
||||
RootFolder = C.root_folder
|
||||
Workspaces = C.workspaces
|
||||
BenchmarkNet = C.benchmark_net
|
||||
Predictor = C.Predictor
|
||||
|
||||
is_asan = C.is_asan
|
||||
has_gpu_support = C.has_gpu_support
|
||||
|
@ -3,12 +3,11 @@ import os
|
||||
import unittest
|
||||
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from caffe2.python import core, test_util, workspace
|
||||
from caffe2.python import core, test_util, workspace, cnn
|
||||
|
||||
import caffe2.python.hypothesis_test_util as htu
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
|
||||
class TestWorkspace(unittest.TestCase):
|
||||
@ -469,5 +468,48 @@ class TestCWorkspace(htu.HypothesisTestCase):
|
||||
ws.create_net("...")
|
||||
|
||||
|
||||
class TestPredictor(unittest.TestCase):
|
||||
def _create_model(self):
|
||||
m = cnn.CNNModelHelper()
|
||||
y = m.FC("data", "y",
|
||||
dim_in=4, dim_out=2,
|
||||
weight_init=m.ConstantInit(1.0),
|
||||
bias_init=m.ConstantInit(0.0),
|
||||
axis=0)
|
||||
m.net.AddExternalOutput(y)
|
||||
return m
|
||||
|
||||
# Use this test with a bigger model to see how using Predictor allows to
|
||||
# avoid issues with low protobuf size limit in Python
|
||||
#
|
||||
# def test_predictor_predefined(self):
|
||||
# workspace.ResetWorkspace()
|
||||
# path = 'caffe2/caffe2/test/assets/'
|
||||
# with open(path + 'squeeze_predict_net.pb') as f:
|
||||
# self.predict_net = f.read()
|
||||
# with open(path + 'squeeze_init_net.pb') as f:
|
||||
# self.init_net = f.read()
|
||||
# self.predictor = workspace.Predictor(self.init_net, self.predict_net)
|
||||
|
||||
# inputs = [np.zeros((1, 3, 256, 256), dtype='f')]
|
||||
# outputs = self.predictor.run(inputs)
|
||||
# self.assertEqual(len(outputs), 1)
|
||||
# self.assertEqual(outputs[0].shape, (1, 1000, 1, 1))
|
||||
# self.assertAlmostEqual(outputs[0][0][0][0][0], 5.19026289e-05)
|
||||
|
||||
|
||||
def test_predictor_memory_model(self):
|
||||
workspace.ResetWorkspace()
|
||||
m = self._create_model()
|
||||
workspace.FeedBlob("data", np.zeros([4], dtype='float32'))
|
||||
self.predictor = workspace.Predictor(
|
||||
workspace.StringifyProto(m.param_init_net.Proto()),
|
||||
workspace.StringifyProto(m.net.Proto()))
|
||||
|
||||
inputs = np.array([1, 3, 256, 256], dtype='float32')
|
||||
outputs = self.predictor.run([inputs])
|
||||
np.testing.assert_array_almost_equal(np.array([[516, 516]], dtype='float32'), outputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
BIN
caffe2/test/assets/squeeze_predict_net.pb
Normal file
BIN
caffe2/test/assets/squeeze_predict_net.pb
Normal file
Binary file not shown.
Reference in New Issue
Block a user