ScriptModuleOp in caffe2 (#18716)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18716

Might be useful as an intermediate stage for some systems that currently use Caffe2 nets as an execution mechanism.

Not sure it's a good idea altogether; please comment.

Limitations:
- only Tensor types as inputs/outputs
- the entire module is serialized as a zip archive inside a proto in the Caffe2 db, so it is subject to the 4GB limit and is likely very slow. It would work for small models, though.
- no autograd, though it can be attached in principle
- no way to retrieve parameters inside the script module from C2 runtime perspective (though they potentially can be alias-fetched and stored as individual blobs)
- after deserialization, the Python wrappers returned don't have the correct type (as we don't do the module_lookup trick)

Build-wise, I had to add a dependency from pybind_state to libtorch.so. I don't think we build the Caffe2 Python frontend independently anymore, so it should be fine.

Reviewed By: amirshim, houseroad

Differential Revision: D14339599

fbshipit-source-id: 88a37a8abd1f1c4703e5ef937031f222535d4080
This commit is contained in:
Dmytro Dzhulgakov
2019-04-05 01:04:58 -07:00
committed by Facebook Github Bot
parent 8bdd0c3a85
commit c34e5ff952
3 changed files with 321 additions and 58 deletions

View File

@ -5,6 +5,8 @@ from __future__ import unicode_literals
import numpy as np
import os
import shutil
import tempfile
import unittest
import torch
@ -701,5 +703,109 @@ class TestTransform(htu.HypothesisTestCase):
workspace.RunNetOnce(proto.SerializeToString()), True)
class MyModule(torch.jit.ScriptModule):
    """Small script module used as a fixture by the ScriptModule tests.

    It holds one 1x5 parameter, ``mult``, and exposes three script methods:
    ``forward``, ``multi_input`` and ``multi_output``.
    """

    def __init__(self):
        super(MyModule, self).__init__()
        weights = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
        self.mult = torch.nn.Parameter(weights)

    # Matrix-multiply the 1x5 parameter by ``x``.
    @torch.jit.script_method
    def forward(self, x):
        return torch.mm(self.mult, x)

    # Add two tensors and an integer offset (``z`` defaults to 2).
    @torch.jit.script_method
    def multi_input(self, x, y, z=2):
        # type: (Tensor, Tensor, int) -> Tensor
        total = x + y
        return total + z

    # Return the input together with the input incremented by one.
    @torch.jit.script_method
    def multi_output(self, x):
        incremented = x + 1
        return (x, incremented)
@unittest.skipIf(
    "ScriptModule" not in core._REGISTERED_OPERATORS,
    "Script module integration in Caffe2 is not enabled")
class TestScriptModule(test_util.TestCase):
    """Exercise the Caffe2 "ScriptModule" operator against ``MyModule``.

    The whole class is skipped when "ScriptModule" is not among the
    registered Caffe2 operators. Tests that need the module in the
    workspace obtain it through ``_createFeedModule``, which stores it as
    blob 'm'; subclasses can override that hook to create the blob another
    way.
    """

    def _createFeedModule(self):
        """Feed a fresh ``MyModule`` into the workspace as blob 'm'."""
        workspace.FeedBlob('m', MyModule())

    def testCreation(self):
        """Feeding and then fetching a module yields the same object."""
        m = MyModule()
        workspace.FeedBlob('module', m)
        m2 = workspace.FetchBlob('module')
        self.assertIs(m, m2)

    def testForward(self):
        """Running the operator without ``method`` matches ``mult @ input``."""
        self._createFeedModule()
        val = np.random.rand(5, 5).astype(np.float32)
        # Same values as MyModule.mult, used to compute the expected result.
        param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
        workspace.FeedBlob('w', val)
        workspace.RunOperatorOnce(
            core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)

    def testMultiInputOutput(self):
        """Methods with several inputs or several outputs can be invoked."""
        self._createFeedModule()
        val = np.random.rand(5, 5).astype(np.float32)
        workspace.FeedBlob('w', val)
        val2 = np.random.rand(5, 5).astype(np.float32)
        workspace.FeedBlob('w2', val2)
        workspace.RunOperatorOnce(core.CreateOperator(
            "ScriptModule", ["m", "w", "w2"], ["y"], method="multi_input"))
        workspace.RunOperatorOnce(core.CreateOperator(
            "ScriptModule", ["m", "w"], ["y1", "y2"], method="multi_output"))
        # multi_input is called without ``z``, so its default of 2 applies.
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y"), val + val2 + 2, decimal=5)
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y1"), val, decimal=5)
        np.testing.assert_almost_equal(
            workspace.FetchBlob("y2"), val + 1, decimal=5)

    def testSerialization(self):
        """Save blob 'm' to a minidb, reload it, and run the loaded module."""
        # Function-scope import: errno is only needed by the cleanup handler
        # below, which previously referenced it without any import.
        import errno

        tmpdir = tempfile.mkdtemp()
        db_path = os.path.join(tmpdir, "db")
        try:
            self._createFeedModule()
            workspace.RunOperatorOnce(core.CreateOperator(
                "Save",
                ["m"], [],
                absolute_path=1,
                db=db_path, db_type="minidb"))

            workspace.ResetWorkspace()
            self.assertFalse(workspace.HasBlob('m'))

            workspace.RunOperatorOnce(core.CreateOperator(
                "Load",
                [], [],
                absolute_path=1,
                db=db_path, db_type="minidb",
                load_all=1))
            self.assertTrue(workspace.HasBlob('m'))

            # TODO: make caffe2 side load return python-sided module
            # right now it returns the base class (torch._C.ScriptModule)
            # self.assertTrue(isinstance(workspace.FetchBlob('m'), torch.jit.ScriptModule))

            # Do something with the reloaded module: run its forward method.
            val = np.random.rand(5, 5).astype(np.float32)
            param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
            workspace.FeedBlob('w', val)
            workspace.RunOperatorOnce(
                core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
            np.testing.assert_almost_equal(
                workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)
        finally:
            # Clean up the temp folder. A folder that no longer exists is
            # tolerated; any other OSError is re-raised.
            try:
                shutil.rmtree(tmpdir)
            except OSError as e:
                if e.errno != errno.ENOENT:
                    raise
class TestScriptModuleFromString(TestScriptModule):
    """Variant of ``TestScriptModule`` that loads blob 'm' from bytes.

    Only ``_createFeedModule`` is overridden: instead of feeding the Python
    object directly, the module is saved with ``torch.jit.save`` and handed
    to the "ScriptModuleLoad" operator through ``serialized_binary``.
    """

    def _createFeedModule(self):
        """Create blob 'm' by running "ScriptModuleLoad" on saved bytes."""
        payload = self._get_modules_bytes(MyModule())
        load_op = core.CreateOperator(
            "ScriptModuleLoad", [], ["m"],
            serialized_binary=payload)
        workspace.RunOperatorOnce(load_op)

    def _get_modules_bytes(self, the_module):
        """Return the bytes ``torch.jit.save`` writes for ``the_module``."""
        import io
        stream = io.BytesIO()
        torch.jit.save(the_module, stream)
        return stream.getvalue()
# Run the unittest runner only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()