ScriptModuleOp in caffe2 (#18716)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18716

Might be useful as an intermediate stage for some systems that currently use Caffe2 nets as an execution mechanism.

Not sure it's a good idea altogether; please comment.

Limitations:
- only Tensor types are supported as inputs/outputs
- the entire module is serialized as a zip archive inside a proto in a Caffe2 db, so it is subject to the 4 GB protobuf limit and is likely very slow. For small models it works, though.
- no autograd, though it can be attached in principle
- no way to retrieve parameters inside the script module from the C2 runtime's perspective (though they could potentially be alias-fetched and stored as individual blobs)
- after deserialization, the returned Python wrappers don't have the correct type (as we don't do the module_lookup trick)
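
For a quick sense of the intended usage, here is a rough Python sketch (hedged: the new workspace test below is the authoritative version, and MyModule stands in for any torch.jit.ScriptModule):

    import numpy as np
    from caffe2.python import core, workspace

    workspace.FeedBlob('m', MyModule())   # blob now holds a shared_ptr<Module>
    workspace.FeedBlob('x', np.random.rand(5, 5).astype(np.float32))
    workspace.RunOperatorOnce(
        core.CreateOperator("ScriptModule", ["m", "x"], ["y"]))  # runs m.forward(x)
    y = workspace.FetchBlob("y")          # comes back as a numpy array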

Build-wise, I had to add a dependency from pybind_state to libtorch.so. I don't think we build the Caffe2 Python frontend independently anymore, so it should be fine.

Reviewed By: amirshim, houseroad

Differential Revision: D14339599

fbshipit-source-id: 88a37a8abd1f1c4703e5ef937031f222535d4080
Author: Dmytro Dzhulgakov
Date: 2019-04-05 01:04:58 -07:00
Committed by: Facebook Github Bot
Parent: 8bdd0c3a85
Commit: c34e5ff952
3 changed files with 321 additions and 58 deletions

@@ -0,0 +1,153 @@
#include <caffe2/core/context.h>
#include <caffe2/core/operator.h>
#include <caffe2/utils/math.h>
#include <torch/script.h>
#include "caffe2/core/blob_serialization.h"
namespace caffe2 {
using torch::jit::IValue;
using torch::jit::script::Method;
using torch::jit::script::Module;
namespace {
class ScriptModuleSerializer : public BlobSerializerBase {
 public:
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      SerializationAcceptor acceptor) override {
    CAFFE_ENFORCE(typeMeta.Match<std::shared_ptr<Module>>());

    std::stringstream ss;
    (*static_cast<const std::shared_ptr<Module>*>(pointer))->save(ss);

    // NB: wrapping the entire zip archive as one string is probably not a
    // good idea and might be slow. This is meant as a workaround; any proper
    // usage would require splitting out tensor data separately.
    //
    // In the future we can do that by introducing a different "type" string
    // for the more efficient serialization version (if we ever get to that
    // point).
    BlobProto blob_proto;
    blob_proto.set_name(name);
    blob_proto.set_type("torch::jit::script::Module");
    blob_proto.set_content(ss.str());
    acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
  }
};
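
// For illustration, the BlobProto produced above looks roughly like this
// (a hedged sketch of the layout, not a normative spec):
//   name:    "<blob name>"
//   type:    "torch::jit::script::Module"
//   content: <raw bytes of the TorchScript zip container written by save()>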
class ScriptModuleDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(const BlobProto& proto, Blob* blob) override {
    const auto& serialized = proto.content();
    // TODO: use adapter instead of istream?
    std::stringstream ss;
    ss << serialized;
    ss.seekg(0);
    *blob->GetMutable<std::shared_ptr<Module>>() = torch::jit::load(ss);
  }
};
class ScriptModuleLoadOp final : public Operator<CPUContext> {
 public:
  ScriptModuleLoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {
    CAFFE_ENFORCE(HasArgument("serialized_binary"));
  }

  bool RunOnDevice() override {
    auto moduleBinary = GetSingleArgument<string>("serialized_binary", "");
    // TODO: use adapter instead of istream?
    std::stringstream ss;
    ss << moduleBinary;
    ss.seekg(0);
    *OperatorBase::Output<std::shared_ptr<Module>>(0) = torch::jit::load(ss);
    return true;
  }
};
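
// Note: from Python, a module can be instantiated directly from
// torch.jit.save() bytes without going through a Caffe2 db. A hedged sketch
// (mirrors TestScriptModuleFromString in the test file below):
//   buffer = io.BytesIO()
//   torch.jit.save(MyModule(), buffer)
//   core.CreateOperator("ScriptModuleLoad", [], ["m"],
//                       serialized_binary=buffer.getvalue())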
template <class Context>
class ScriptModuleOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  ScriptModuleOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        method_name_(this->template GetSingleArgument<std::string>(
            "method",
            "forward")) {
    // TODO: we could also parse extra arguments here and allow passing
    // scalars to the method invocation. But there's probably less blocking
    // need for that.
  }

  static caffe2::Tensor castIValueToTensor(const IValue& v) {
    return caffe2::Tensor(torch::autograd::Variable(v.toTensor()).data());
  }

  bool RunOnDevice() override {
    const auto& module = OperatorBase::Input<std::shared_ptr<Module>>(0);
    Method& method = module->get_method(method_name_);
    // Assume all inputs are tensors for now.
    std::vector<IValue> inputs;
    const int num_inputs = InputSize();
    inputs.reserve(num_inputs);
    for (int i = 1; i < num_inputs; ++i) {
      // jit::Interpreter takes only autograd variables (which have
      // requires_grad=False in this case)
      inputs.emplace_back(torch::autograd::make_variable(at::Tensor(Input(i))));
    }
    // We just convert the specified inputs. If some inputs were omitted and
    // don't have default values, method::operator() is going to complain.
    IValue output = method(inputs);
    if (output.isTuple()) {
      const std::vector<IValue>& elems = output.toTuple()->elements();
      CAFFE_ENFORCE_EQ(elems.size(), OutputSize());
      for (size_t i = 0; i < elems.size(); ++i) {
        this->SetOutputTensor(i, castIValueToTensor(elems[i]));
      }
    } else if (output.isTensor()) {
      CAFFE_ENFORCE_EQ(1, OutputSize());
      this->SetOutputTensor(0, castIValueToTensor(output));
    } else {
      CAFFE_THROW("Unexpected return type: ", output.tagKind());
    }
    return true;
  }

 private:
  std::string method_name_;
};
} // namespace

CAFFE_KNOWN_TYPE(std::shared_ptr<Module>);

REGISTER_BLOB_SERIALIZER(
    (TypeMeta::Id<std::shared_ptr<Module>>()),
    ScriptModuleSerializer);
// NB: the first argument to the REGISTER_BLOB_DESERIALIZER macro doesn't
// really need to be a real type; it just gets converted to a string.
REGISTER_BLOB_DESERIALIZER(
    torch::jit::script::Module,
    ScriptModuleDeserializer);
OPERATOR_SCHEMA(ScriptModule)
    .NumInputs(1, INT_MAX)
    .NumOutputs(0, INT_MAX)
    .Input(0, "script_module_instance", "Instance of shared_ptr<Module>");
REGISTER_CPU_OPERATOR(ScriptModule, ScriptModuleOp<CPUContext>);
SHOULD_NOT_DO_GRADIENT(ScriptModule);

OPERATOR_SCHEMA(ScriptModuleLoad)
    .NumInputs(0)
    .NumOutputs(1)
    .DisallowInputFillers()
    .Output(0, "script_module_instance", "New instance of shared_ptr<Module>")
    .Arg(
        "serialized_binary",
        "Binary string representing contents of .pt file (zip container)");
REGISTER_CPU_OPERATOR(ScriptModuleLoad, ScriptModuleLoadOp);
NO_GRADIENT(ScriptModuleLoad);
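
// A hedged Python-side sketch of how the ScriptModule schema is exercised
// (input 0 is always the module blob; the remaining inputs are tensors
// passed positionally to the chosen method):
//   core.CreateOperator("ScriptModule", ["m", "w", "w2"], ["y"],
//                       method="multi_input")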
} // namespace caffe2

@@ -32,6 +32,14 @@
#include "caffe2/utils/string_utils.h"
#include "torch/csrc/autograd/variable.h"

// Because of the CMake setup, we can't depend on the script module here just
// yet - it pulls in generated files from a different directory and
// probabilistically breaks the build.
// TODO: enable this #ifdef once shared libraries are unified in CMake
#ifdef FBCODE_CAFFE2
#include "torch/script.h"
#endif

namespace caffe2 {
namespace python {
@@ -78,6 +86,19 @@ class StringFetcher : public BlobFetcherBase {
};
REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);

#ifdef FBCODE_CAFFE2
class ScriptModuleFetcher : public BlobFetcherBase {
 public:
  pybind11::object Fetch(const Blob& blob) override {
    return py::cast(blob.Get<std::shared_ptr<torch::jit::script::Module>>());
  }
};

REGISTER_BLOB_FETCHER(
    (TypeMeta::Id<std::shared_ptr<torch::jit::script::Module>>()),
    caffe2::python::ScriptModuleFetcher);
#endif
static_assert(
    sizeof(int) == sizeof(int32_t),
    "We make an assumption that int is always int32 for numpy "

@@ -194,6 +215,45 @@ py::object fetchBlob(Workspace* ws, const std::string& name) {
    return py::bytes(ss.str());
  }
}
// This function can only return true; the bool return type is kept for
// backward compatibility.
bool feedBlob(
    Blob* blob,
    const py::object& arg,
    const py::object device_option) {
  DeviceOption option;
  if (!device_option.is(py::none())) {
    // If we have a device option passed in, read it.
    CAFFE_ENFORCE(ParseProtoFromLargeString(
        py::bytes(device_option).cast<std::string>(), &option));
  }
#ifdef USE_NUMPY
  if (PyArray_Check(arg.ptr())) { // numpy array
    PyArrayObject* array = reinterpret_cast<PyArrayObject*>(arg.ptr());
    auto feeder = CreateFeeder(option.device_type());
    CAFFE_ENFORCE(feeder, "Unknown device type encountered in FeedBlob.");
    feeder->Feed(option, array, blob, true); /* default to inplace feed */
    return true;
  }
#else
  CAFFE_THROW("Caffe2 compiled without NumPy support.");
#endif // USE_NUMPY
  if (PyBytes_Check(arg.ptr()) || PyUnicode_Check(arg.ptr())) {
    *blob->GetMutable<std::string>() = arg.cast<std::string>();
    return true;
  }
#ifdef FBCODE_CAFFE2
  if (py::isinstance<torch::jit::script::Module>(arg)) {
    *blob->GetMutable<std::shared_ptr<torch::jit::script::Module>>() =
        arg.cast<std::shared_ptr<torch::jit::script::Module>>();
    return true;
  }
#endif
  CAFFE_THROW(
      "Unexpected type of argument - only numpy array or string are "
      "supported for feeding");
  return false;
}
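
// Python-side effect of the fetcher/feeder additions, as a hedged sketch:
// in fbcode builds a ScriptModule can round-trip through the workspace:
//   workspace.FeedBlob("m", MyModule())   # stores a shared_ptr<Module>
//   workspace.FetchBlob("m")              # returns the same module instance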
} // namespace python_detail

class GetPythonGradient : public GradientMakerBase {
@@ -352,36 +412,7 @@ void addObjectMethods(py::module& m) {
          py::return_value_policy::reference_internal)
      .def(
          "_feed",
          &python_detail::feedBlob,
          "Feed an input array or string, with the (optional) DeviceOption",
          py::arg("arg"),
          py::arg("device_option") = py::none())
@@ -1461,35 +1492,8 @@ void addGlobalMethods(py::module& m) {
  m.def(
      "feed_blob",
      [](const std::string& name, py::object arg, py::object device_option) {
        auto* blob = gWorkspace->CreateBlob(name);
        return python_detail::feedBlob(blob, arg, device_option);
      },
      "",
      py::arg("name"),

@@ -5,6 +5,9 @@ from __future__ import unicode_literals

import errno
import numpy as np
import os
import shutil
import tempfile
import unittest

import torch
@@ -701,5 +703,109 @@ class TestTransform(htu.HypothesisTestCase):
            workspace.RunNetOnce(proto.SerializeToString()), True)
class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.mult = torch.nn.Parameter(torch.tensor([[1, 2, 3, 4, 5.0]]))

    @torch.jit.script_method
    def forward(self, x):
        return self.mult.mm(x)

    @torch.jit.script_method
    def multi_input(self, x, y, z=2):
        # type: (Tensor, Tensor, int) -> Tensor
        return x + y + z

    @torch.jit.script_method
    def multi_output(self, x):
        return (x, x + 1)
@unittest.skipIf(
    "ScriptModule" not in core._REGISTERED_OPERATORS,
    "Script module integration in Caffe2 is not enabled")
class TestScriptModule(test_util.TestCase):
    def _createFeedModule(self):
        workspace.FeedBlob('m', MyModule())

    def testCreation(self):
        m = MyModule()
        workspace.FeedBlob('module', m)
        m2 = workspace.FetchBlob('module')
        self.assertTrue(m is m2)

    def testForward(self):
        self._createFeedModule()
        val = np.random.rand(5, 5).astype(np.float32)
        param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
        workspace.FeedBlob('w', val)
        workspace.RunOperatorOnce(
            core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)

    def testMultiInputOutput(self):
        self._createFeedModule()
        val = np.random.rand(5, 5).astype(np.float32)
        workspace.FeedBlob('w', val)
        val2 = np.random.rand(5, 5).astype(np.float32)
        workspace.FeedBlob('w2', val2)
        workspace.RunOperatorOnce(core.CreateOperator(
            "ScriptModule", ["m", "w", "w2"], ["y"], method="multi_input"))
        workspace.RunOperatorOnce(core.CreateOperator(
            "ScriptModule", ["m", "w"], ["y1", "y2"], method="multi_output"))
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y"), val + val2 + 2, decimal=5)
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y1"), val, decimal=5)
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y2"), val + 1, decimal=5)
    def testSerialization(self):
        tmpdir = tempfile.mkdtemp()
        try:
            self._createFeedModule()
            workspace.RunOperatorOnce(core.CreateOperator(
                "Save",
                ["m"], [],
                absolute_path=1,
                db=os.path.join(tmpdir, "db"), db_type="minidb"))
            workspace.ResetWorkspace()
            self.assertFalse(workspace.HasBlob('m'))
            workspace.RunOperatorOnce(core.CreateOperator(
                "Load",
                [], [],
                absolute_path=1,
                db=os.path.join(tmpdir, "db"), db_type="minidb",
                load_all=1))
            self.assertTrue(workspace.HasBlob('m'))
            # TODO: make the Caffe2-side load return a Python-sided module;
            # right now it returns the base class (torch._C.ScriptModule).
            # self.assertTrue(isinstance(workspace.FetchBlob('m'), torch.jit.ScriptModule))

            # do something with the module
            val = np.random.rand(5, 5).astype(np.float32)
            param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
            workspace.FeedBlob('w', val)
            workspace.RunOperatorOnce(
                core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
            np.testing.assert_almost_equal(
                workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)
        finally:
            # Clean up the temp folder.
            try:
                shutil.rmtree(tmpdir)
            except OSError as e:
                if e.errno != errno.ENOENT:
                    raise
class TestScriptModuleFromString(TestScriptModule):
    def _createFeedModule(self):
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "ScriptModuleLoad", [], ["m"],
                serialized_binary=self._get_modules_bytes(MyModule())))

    def _get_modules_bytes(self, the_module):
        import io
        buffer = io.BytesIO()
        torch.jit.save(the_module, buffer)
        return buffer.getvalue()
if __name__ == '__main__':
    unittest.main()