mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
ScriptModuleOp in caffe2 (#18716)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18716 Might be useful as an intermediate stage for some systems that currently use Caffe2 nets as an execution mechanism. Not sure it's a good idea all together, please comment. Limitations: - only Tensor types as inputs/outputs - the entire module is serialized as a zip archive inside a proto in Caffe2 db, it'd be subject to 4Gb limit and is likely very slow. For small models it'd work though. - no autograd, though it can be attached in principle - no way to retrieve parameters inside the script module from C2 runtime perspective (though they potentially can be alias-fetched and stored as individual blobs) - after deserialization, python wrappers returned don't have correct type (as we don't do module_lookup trick) Build-wise, I had to add dependency from pybind_state to libtorch.so. I don't think we build Caffe2 python frontend independently anymore, so it should be fine. Reviewed By: amirshim, houseroad Differential Revision: D14339599 fbshipit-source-id: 88a37a8abd1f1c4703e5ef937031f222535d4080
This commit is contained in:
committed by
Facebook Github Bot
parent
8bdd0c3a85
commit
c34e5ff952
153
caffe2/contrib/pytorch/script_module_op.cc
Normal file
153
caffe2/contrib/pytorch/script_module_op.cc
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
#include <caffe2/core/context.h>
|
||||||
|
#include <caffe2/core/operator.h>
|
||||||
|
#include <caffe2/utils/math.h>
|
||||||
|
#include <torch/script.h>
|
||||||
|
#include "caffe2/core/blob_serialization.h"
|
||||||
|
|
||||||
|
namespace caffe2 {
|
||||||
|
|
||||||
|
using torch::jit::IValue;
|
||||||
|
using torch::jit::script::Method;
|
||||||
|
using torch::jit::script::Module;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ScriptModuleSerializer : public BlobSerializerBase {
|
||||||
|
public:
|
||||||
|
void Serialize(
|
||||||
|
const void* pointer,
|
||||||
|
TypeMeta typeMeta,
|
||||||
|
const string& name,
|
||||||
|
SerializationAcceptor acceptor) override {
|
||||||
|
CAFFE_ENFORCE(typeMeta.Match<std::shared_ptr<Module>>());
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
(*static_cast<const std::shared_ptr<Module>*>(pointer))->save(ss);
|
||||||
|
|
||||||
|
// NB: wrapping the entire zip archive as one string is probably not a
|
||||||
|
// good idea and might be slow. This is meant as a workaround, any proper
|
||||||
|
// usage would require splitting out tensor data separately.
|
||||||
|
//
|
||||||
|
// In the future we can do it by introducing a different "type" string for
|
||||||
|
// the more efficient serialization version (if we ever get to that point)
|
||||||
|
BlobProto blob_proto;
|
||||||
|
blob_proto.set_name(name);
|
||||||
|
blob_proto.set_type("torch::jit::script::Module");
|
||||||
|
blob_proto.set_content(ss.str());
|
||||||
|
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScriptModuleDeserializer : public BlobDeserializerBase {
|
||||||
|
public:
|
||||||
|
void Deserialize(const BlobProto& proto, Blob* blob) override {
|
||||||
|
const auto& serialized = proto.content();
|
||||||
|
// TODO: use adapter instead of istream?
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << serialized;
|
||||||
|
ss.seekg(0);
|
||||||
|
*blob->GetMutable<std::shared_ptr<Module>>() = torch::jit::load(ss);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScriptModuleLoadOp final : public Operator<CPUContext> {
|
||||||
|
public:
|
||||||
|
ScriptModuleLoadOp(const OperatorDef& operator_def, Workspace* ws)
|
||||||
|
: Operator<CPUContext>(operator_def, ws) {
|
||||||
|
CAFFE_ENFORCE(HasArgument("serialized_binary"));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RunOnDevice() override {
|
||||||
|
auto moduleBinary = GetSingleArgument<string>("serialized_binary", "");
|
||||||
|
// TODO: use adapter instead of istream?
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << moduleBinary;
|
||||||
|
ss.seekg(0);
|
||||||
|
*OperatorBase::Output<std::shared_ptr<Module>>(0) = torch::jit::load(ss);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Context>
|
||||||
|
class ScriptModuleOp final : public Operator<Context> {
|
||||||
|
public:
|
||||||
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||||
|
|
||||||
|
ScriptModuleOp(const OperatorDef& operator_def, Workspace* ws)
|
||||||
|
: Operator<Context>(operator_def, ws),
|
||||||
|
method_name_(this->template GetSingleArgument<std::string>(
|
||||||
|
"method",
|
||||||
|
"forward")) {
|
||||||
|
// TODO: we could also parse extra arguments here and allow to pass in
|
||||||
|
// scalars to the method invocation. But there's probably less blocking need
|
||||||
|
// for that.
|
||||||
|
}
|
||||||
|
|
||||||
|
static caffe2::Tensor castIValueToTensor(const IValue& v) {
|
||||||
|
return caffe2::Tensor(torch::autograd::Variable(v.toTensor()).data());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RunOnDevice() override {
|
||||||
|
const auto& module = OperatorBase::Input<std::shared_ptr<Module>>(0);
|
||||||
|
Method& method = module->get_method(method_name_);
|
||||||
|
// Assume all inputs are tensor for now
|
||||||
|
std::vector<IValue> inputs;
|
||||||
|
const int num_inputs = InputSize();
|
||||||
|
inputs.reserve(num_inputs);
|
||||||
|
for (int i = 1; i < num_inputs; ++i) {
|
||||||
|
// jit::Interpreter takes only autograd variables (which have
|
||||||
|
// require_grad=False in this case)
|
||||||
|
inputs.emplace_back(torch::autograd::make_variable(at::Tensor(Input(i))));
|
||||||
|
}
|
||||||
|
// We just convert specified inputs. If some of the inputs were omitted and
|
||||||
|
// don't have default values, method::operator() is going to complain.
|
||||||
|
IValue output = method(inputs);
|
||||||
|
if (output.isTuple()) {
|
||||||
|
const std::vector<IValue>& elems = output.toTuple()->elements();
|
||||||
|
CAFFE_ENFORCE_EQ(elems.size(), OutputSize());
|
||||||
|
for (int i = 0; i < elems.size(); ++i) {
|
||||||
|
this->SetOutputTensor(i, castIValueToTensor(elems[i]));
|
||||||
|
}
|
||||||
|
} else if (output.isTensor()) {
|
||||||
|
CAFFE_ENFORCE_EQ(1, OutputSize());
|
||||||
|
this->SetOutputTensor(0, castIValueToTensor(output));
|
||||||
|
} else {
|
||||||
|
CAFFE_THROW("Unexpected return type: ", output.tagKind());
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string method_name_;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
CAFFE_KNOWN_TYPE(std::shared_ptr<Module>);
|
||||||
|
|
||||||
|
REGISTER_BLOB_SERIALIZER(
|
||||||
|
(TypeMeta::Id<std::shared_ptr<Module>>()),
|
||||||
|
ScriptModuleSerializer);
|
||||||
|
// NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
|
||||||
|
// need to be a real type, it just get converted to string
|
||||||
|
REGISTER_BLOB_DESERIALIZER(
|
||||||
|
torch::jit::script::Module,
|
||||||
|
ScriptModuleDeserializer);
|
||||||
|
|
||||||
|
OPERATOR_SCHEMA(ScriptModule)
|
||||||
|
.NumInputs(1, INT_MAX)
|
||||||
|
.NumOutputs(0, INT_MAX)
|
||||||
|
.Input(0, "script_module_instance", "Instance of shared_ptr<Module>");
|
||||||
|
REGISTER_CPU_OPERATOR(ScriptModule, ScriptModuleOp<CPUContext>);
|
||||||
|
SHOULD_NOT_DO_GRADIENT(ScriptModule);
|
||||||
|
|
||||||
|
OPERATOR_SCHEMA(ScriptModuleLoad)
|
||||||
|
.NumInputs(0)
|
||||||
|
.NumOutputs(1)
|
||||||
|
.DisallowInputFillers()
|
||||||
|
.Output(0, "script_module_instance", "New instance of shared_ptr<Module>")
|
||||||
|
.Arg(
|
||||||
|
"serialized_binary",
|
||||||
|
"Binary string representing contents of .pt file (zip container)");
|
||||||
|
REGISTER_CPU_OPERATOR(ScriptModuleLoad, ScriptModuleLoadOp);
|
||||||
|
NO_GRADIENT(ScriptModuleLoad);
|
||||||
|
|
||||||
|
} // namespace caffe2
|
||||||
@ -32,6 +32,14 @@
|
|||||||
#include "caffe2/utils/string_utils.h"
|
#include "caffe2/utils/string_utils.h"
|
||||||
#include "torch/csrc/autograd/variable.h"
|
#include "torch/csrc/autograd/variable.h"
|
||||||
|
|
||||||
|
// Because of CMake setup, we can't depend on script module here just yet -
|
||||||
|
// it pulls in generated files from a different directory and it
|
||||||
|
// probabilistically breaks the build.
|
||||||
|
// TODO: enable if once shared libraries are unified in CMake
|
||||||
|
#ifdef FBCODE_CAFFE2
|
||||||
|
#include "torch/script.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
namespace python {
|
namespace python {
|
||||||
|
|
||||||
@ -78,6 +86,19 @@ class StringFetcher : public BlobFetcherBase {
|
|||||||
};
|
};
|
||||||
REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
|
REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
|
||||||
|
|
||||||
|
#ifdef FBCODE_CAFFE2
|
||||||
|
class ScriptModuleFetcher : public BlobFetcherBase {
|
||||||
|
public:
|
||||||
|
pybind11::object Fetch(const Blob& blob) override {
|
||||||
|
return py::cast(blob.Get<std::shared_ptr<torch::jit::script::Module>>());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_BLOB_FETCHER(
|
||||||
|
(TypeMeta::Id<std::shared_ptr<torch::jit::script::Module>>()),
|
||||||
|
caffe2::python::ScriptModuleFetcher);
|
||||||
|
#endif
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
sizeof(int) == sizeof(int32_t),
|
sizeof(int) == sizeof(int32_t),
|
||||||
"We make an assumption that int is always int32 for numpy "
|
"We make an assumption that int is always int32 for numpy "
|
||||||
@ -194,6 +215,45 @@ py::object fetchBlob(Workspace* ws, const std::string& name) {
|
|||||||
return py::bytes(ss.str());
|
return py::bytes(ss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function can only return true, but keeping it for backward compatibility
|
||||||
|
bool feedBlob(
|
||||||
|
Blob* blob,
|
||||||
|
const py::object& arg,
|
||||||
|
const py::object device_option) {
|
||||||
|
DeviceOption option;
|
||||||
|
if (!device_option.is(py::none())) {
|
||||||
|
// If we have a device option passed in, read it.
|
||||||
|
CAFFE_ENFORCE(ParseProtoFromLargeString(
|
||||||
|
py::bytes(device_option).cast<std::string>(), &option));
|
||||||
|
}
|
||||||
|
#ifdef USE_NUMPY
|
||||||
|
if (PyArray_Check(arg.ptr())) { // numpy array
|
||||||
|
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(arg.ptr());
|
||||||
|
auto feeder = CreateFeeder(option.device_type());
|
||||||
|
CAFFE_ENFORCE(feeder, "Unknown device type encountered in FeedBlob.");
|
||||||
|
feeder->Feed(option, array, blob, true); /* default to inplace feed */
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
CAFFE_THROW("Caffe2 compiled without NumPy support.");
|
||||||
|
#endif // USE_NUMPY
|
||||||
|
if (PyBytes_Check(arg.ptr()) || PyUnicode_Check(arg.ptr())) {
|
||||||
|
*blob->GetMutable<std::string>() = arg.cast<std::string>();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#ifdef FBCODE_CAFFE2
|
||||||
|
if (py::isinstance<torch::jit::script::Module>(arg)) {
|
||||||
|
*blob->GetMutable<std::shared_ptr<torch::jit::script::Module>>() =
|
||||||
|
arg.cast<std::shared_ptr<torch::jit::script::Module>>();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
CAFFE_THROW(
|
||||||
|
"Unexpected type of argument - only numpy array or string are "
|
||||||
|
"supported for feeding");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} // namespace python_detail
|
} // namespace python_detail
|
||||||
|
|
||||||
class GetPythonGradient : public GradientMakerBase {
|
class GetPythonGradient : public GradientMakerBase {
|
||||||
@ -352,36 +412,7 @@ void addObjectMethods(py::module& m) {
|
|||||||
py::return_value_policy::reference_internal)
|
py::return_value_policy::reference_internal)
|
||||||
.def(
|
.def(
|
||||||
"_feed",
|
"_feed",
|
||||||
[](Blob* blob,
|
&python_detail::feedBlob,
|
||||||
const py::object& arg,
|
|
||||||
const py::object device_option) {
|
|
||||||
DeviceOption option;
|
|
||||||
if (!device_option.is(py::none())) {
|
|
||||||
// If we have a device option passed in, read it.
|
|
||||||
CAFFE_ENFORCE(ParseProtoFromLargeString(
|
|
||||||
py::bytes(device_option).cast<std::string>(), &option));
|
|
||||||
}
|
|
||||||
#ifdef USE_NUMPY
|
|
||||||
if (PyArray_Check(arg.ptr())) { // numpy array
|
|
||||||
PyArrayObject* array
|
|
||||||
= reinterpret_cast<PyArrayObject*>(arg.ptr());
|
|
||||||
auto feeder = CreateFeeder(option.device_type());
|
|
||||||
CAFFE_ENFORCE(
|
|
||||||
feeder, "Unknown device type encountered in FeedBlob.");
|
|
||||||
feeder->Feed(option, array, blob, true); /* default to inplace feed */
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
CAFFE_THROW("Caffe2 compiled without NumPy support.");
|
|
||||||
#endif // USE_NUMPY
|
|
||||||
if (PyBytes_Check(arg.ptr()) || PyUnicode_Check(arg.ptr())) {
|
|
||||||
*blob->GetMutable<std::string>() = arg.cast<std::string>();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
CAFFE_THROW(
|
|
||||||
"Unexpected type of argument - only numpy array or string are "
|
|
||||||
"supported for feeding");
|
|
||||||
},
|
|
||||||
"Feed an input array or string, with the (optional) DeviceOption",
|
"Feed an input array or string, with the (optional) DeviceOption",
|
||||||
py::arg("arg"),
|
py::arg("arg"),
|
||||||
py::arg("device_option") = py::none())
|
py::arg("device_option") = py::none())
|
||||||
@ -1461,35 +1492,8 @@ void addGlobalMethods(py::module& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"feed_blob",
|
"feed_blob",
|
||||||
[](const std::string& name, py::object arg, py::object device_option) {
|
[](const std::string& name, py::object arg, py::object device_option) {
|
||||||
DeviceOption option;
|
|
||||||
if (!device_option.is(py::none())) {
|
|
||||||
// If we have a device option passed in, read it.
|
|
||||||
CAFFE_ENFORCE(ParseProtoFromLargeString(
|
|
||||||
py::bytes(device_option).cast<std::string>(), &option));
|
|
||||||
}
|
|
||||||
auto* blob = gWorkspace->CreateBlob(name);
|
auto* blob = gWorkspace->CreateBlob(name);
|
||||||
#ifdef USE_NUMPY
|
return python_detail::feedBlob(blob, arg, device_option);
|
||||||
if (PyArray_Check(arg.ptr())) { // numpy array
|
|
||||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(arg.ptr());
|
|
||||||
auto feeder = CreateFeeder(option.device_type());
|
|
||||||
CAFFE_ENFORCE(
|
|
||||||
feeder,
|
|
||||||
"Unknown device type encountered in FeedBlob: ",
|
|
||||||
option.device_type());
|
|
||||||
feeder->Feed(option, array, blob);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
CAFFE_THROW("Caffe2 was compiled without NumPy support.");
|
|
||||||
#endif // USE_NUMPY
|
|
||||||
if (PyBytes_Check(arg.ptr()) || PyUnicode_Check(arg.ptr())) { // string
|
|
||||||
*blob->GetMutable<std::string>() = arg.cast<std::string>();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
CAFFE_THROW(
|
|
||||||
"Unexpected type of argument - only numpy array or string are "
|
|
||||||
"supported for feeding");
|
|
||||||
return false;
|
|
||||||
},
|
},
|
||||||
"",
|
"",
|
||||||
py::arg("name"),
|
py::arg("name"),
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from __future__ import unicode_literals
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -701,5 +703,109 @@ class TestTransform(htu.HypothesisTestCase):
|
|||||||
workspace.RunNetOnce(proto.SerializeToString()), True)
|
workspace.RunNetOnce(proto.SerializeToString()), True)
|
||||||
|
|
||||||
|
|
||||||
|
class MyModule(torch.jit.ScriptModule):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
self.mult = torch.nn.Parameter(torch.tensor([[1, 2, 3, 4, 5.0]]))
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def forward(self, x):
|
||||||
|
return self.mult.mm(x)
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def multi_input(self, x, y, z=2):
|
||||||
|
# type: (Tensor, Tensor, int) -> Tensor
|
||||||
|
return x + y + z
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def multi_output(self, x):
|
||||||
|
return (x, x + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
"ScriptModule" not in core._REGISTERED_OPERATORS,
|
||||||
|
"Script module integration in Caffe2 is not enabled")
|
||||||
|
class TestScriptModule(test_util.TestCase):
|
||||||
|
def _createFeedModule(self):
|
||||||
|
workspace.FeedBlob('m', MyModule())
|
||||||
|
|
||||||
|
def testCreation(self):
|
||||||
|
m = MyModule()
|
||||||
|
workspace.FeedBlob('module', m)
|
||||||
|
m2 = workspace.FetchBlob('module')
|
||||||
|
self.assertTrue(m is m2)
|
||||||
|
|
||||||
|
def testForward(self):
|
||||||
|
self._createFeedModule()
|
||||||
|
val = np.random.rand(5, 5).astype(np.float32)
|
||||||
|
param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
|
||||||
|
workspace.FeedBlob('w', val)
|
||||||
|
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
|
||||||
|
np.testing.assert_almost_equal(workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)
|
||||||
|
|
||||||
|
def testMultiInputOutput(self):
|
||||||
|
self._createFeedModule()
|
||||||
|
val = np.random.rand(5, 5).astype(np.float32)
|
||||||
|
workspace.FeedBlob('w', val)
|
||||||
|
val2 = np.random.rand(5, 5).astype(np.float32)
|
||||||
|
workspace.FeedBlob('w2', val2)
|
||||||
|
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w", "w2"], ["y"], method="multi_input"))
|
||||||
|
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w"], ["y1", "y2"], method="multi_output"))
|
||||||
|
np.testing.assert_almost_equal(workspace.FetchBlob("y"), val + val2 + 2, decimal=5)
|
||||||
|
np.testing.assert_almost_equal(workspace.FetchBlob("y1"), val, decimal=5)
|
||||||
|
np.testing.assert_almost_equal(workspace.FetchBlob("y2"), val + 1, decimal=5)
|
||||||
|
|
||||||
|
def testSerialization(self):
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
self._createFeedModule()
|
||||||
|
workspace.RunOperatorOnce(core.CreateOperator(
|
||||||
|
"Save",
|
||||||
|
["m"], [],
|
||||||
|
absolute_path=1,
|
||||||
|
db=os.path.join(tmpdir, "db"), db_type="minidb"))
|
||||||
|
workspace.ResetWorkspace()
|
||||||
|
|
||||||
|
self.assertFalse(workspace.HasBlob('m'))
|
||||||
|
workspace.RunOperatorOnce(core.CreateOperator(
|
||||||
|
"Load",
|
||||||
|
[], [],
|
||||||
|
absolute_path=1,
|
||||||
|
db=os.path.join(tmpdir, "db"), db_type="minidb",
|
||||||
|
load_all=1))
|
||||||
|
self.assertTrue(workspace.HasBlob('m'))
|
||||||
|
# TODO: make caffe2 side load return python-sided module
|
||||||
|
# right now it returns the base class (torch._C.ScriptModule)
|
||||||
|
# self.assertTrue(isinstance(workspace.FetchBlob('m'), torch.jit.ScriptModule))
|
||||||
|
|
||||||
|
# do something with the module
|
||||||
|
val = np.random.rand(5, 5).astype(np.float32)
|
||||||
|
param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
|
||||||
|
workspace.FeedBlob('w', val)
|
||||||
|
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
|
||||||
|
np.testing.assert_almost_equal(workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)
|
||||||
|
finally:
|
||||||
|
# clean up temp folder.
|
||||||
|
try:
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.ENOENT:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestScriptModuleFromString(TestScriptModule):
|
||||||
|
def _createFeedModule(self):
|
||||||
|
workspace.RunOperatorOnce(
|
||||||
|
core.CreateOperator(
|
||||||
|
"ScriptModuleLoad", [], ["m"],
|
||||||
|
serialized_binary=self._get_modules_bytes(MyModule())))
|
||||||
|
|
||||||
|
def _get_modules_bytes(self, the_module):
|
||||||
|
import io
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.jit.save(the_module, buffer)
|
||||||
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user