mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
ScriptModuleOp in caffe2 (#18716)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18716 Might be useful as an intermediate stage for some systems that currently use Caffe2 nets as an execution mechanism. Not sure it's a good idea all together, please comment. Limitations: - only Tensor types as inputs/outputs - the entire module is serialized as a zip archive inside a proto in Caffe2 db, it'd be subject to 4Gb limit and is likely very slow. For small models it'd work though. - no autograd, though it can be attached in principle - no way to retrieve parameters inside the script module from C2 runtime perspective (though they potentially can be alias-fetched and stored as individual blobs) - after deserialization, python wrappers returned don't have correct type (as we don't do module_lookup trick) Build-wise, I had to add dependency from pybind_state to libtorch.so. I don't think we build Caffe2 python frontend independently anymore, so it should be fine. Reviewed By: amirshim, houseroad Differential Revision: D14339599 fbshipit-source-id: 88a37a8abd1f1c4703e5ef937031f222535d4080
This commit is contained in:
committed by
Facebook Github Bot
parent
8bdd0c3a85
commit
c34e5ff952
@ -5,6 +5,8 @@ from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@ -701,5 +703,109 @@ class TestTransform(htu.HypothesisTestCase):
|
||||
workspace.RunNetOnce(proto.SerializeToString()), True)
|
||||
|
||||
|
||||
class MyModule(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
self.mult = torch.nn.Parameter(torch.tensor([[1, 2, 3, 4, 5.0]]))
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
return self.mult.mm(x)
|
||||
|
||||
@torch.jit.script_method
|
||||
def multi_input(self, x, y, z=2):
|
||||
# type: (Tensor, Tensor, int) -> Tensor
|
||||
return x + y + z
|
||||
|
||||
@torch.jit.script_method
|
||||
def multi_output(self, x):
|
||||
return (x, x + 1)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
"ScriptModule" not in core._REGISTERED_OPERATORS,
|
||||
"Script module integration in Caffe2 is not enabled")
|
||||
class TestScriptModule(test_util.TestCase):
|
||||
def _createFeedModule(self):
|
||||
workspace.FeedBlob('m', MyModule())
|
||||
|
||||
def testCreation(self):
|
||||
m = MyModule()
|
||||
workspace.FeedBlob('module', m)
|
||||
m2 = workspace.FetchBlob('module')
|
||||
self.assertTrue(m is m2)
|
||||
|
||||
def testForward(self):
|
||||
self._createFeedModule()
|
||||
val = np.random.rand(5, 5).astype(np.float32)
|
||||
param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
|
||||
workspace.FeedBlob('w', val)
|
||||
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
|
||||
np.testing.assert_almost_equal(workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)
|
||||
|
||||
def testMultiInputOutput(self):
|
||||
self._createFeedModule()
|
||||
val = np.random.rand(5, 5).astype(np.float32)
|
||||
workspace.FeedBlob('w', val)
|
||||
val2 = np.random.rand(5, 5).astype(np.float32)
|
||||
workspace.FeedBlob('w2', val2)
|
||||
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w", "w2"], ["y"], method="multi_input"))
|
||||
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w"], ["y1", "y2"], method="multi_output"))
|
||||
np.testing.assert_almost_equal(workspace.FetchBlob("y"), val + val2 + 2, decimal=5)
|
||||
np.testing.assert_almost_equal(workspace.FetchBlob("y1"), val, decimal=5)
|
||||
np.testing.assert_almost_equal(workspace.FetchBlob("y2"), val + 1, decimal=5)
|
||||
|
||||
def testSerialization(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
try:
|
||||
self._createFeedModule()
|
||||
workspace.RunOperatorOnce(core.CreateOperator(
|
||||
"Save",
|
||||
["m"], [],
|
||||
absolute_path=1,
|
||||
db=os.path.join(tmpdir, "db"), db_type="minidb"))
|
||||
workspace.ResetWorkspace()
|
||||
|
||||
self.assertFalse(workspace.HasBlob('m'))
|
||||
workspace.RunOperatorOnce(core.CreateOperator(
|
||||
"Load",
|
||||
[], [],
|
||||
absolute_path=1,
|
||||
db=os.path.join(tmpdir, "db"), db_type="minidb",
|
||||
load_all=1))
|
||||
self.assertTrue(workspace.HasBlob('m'))
|
||||
# TODO: make caffe2 side load return python-sided module
|
||||
# right now it returns the base class (torch._C.ScriptModule)
|
||||
# self.assertTrue(isinstance(workspace.FetchBlob('m'), torch.jit.ScriptModule))
|
||||
|
||||
# do something with the module
|
||||
val = np.random.rand(5, 5).astype(np.float32)
|
||||
param = np.array([[1, 2, 3, 4, 5]]).astype(np.float32)
|
||||
workspace.FeedBlob('w', val)
|
||||
workspace.RunOperatorOnce(core.CreateOperator("ScriptModule", ["m", "w"], ["y"]))
|
||||
np.testing.assert_almost_equal(workspace.FetchBlob("y"), np.matmul(param, val), decimal=5)
|
||||
finally:
|
||||
# clean up temp folder.
|
||||
try:
|
||||
shutil.rmtree(tmpdir)
|
||||
except OSError as e:
|
||||
if e.errno != errno.ENOENT:
|
||||
raise
|
||||
|
||||
|
||||
class TestScriptModuleFromString(TestScriptModule):
|
||||
def _createFeedModule(self):
|
||||
workspace.RunOperatorOnce(
|
||||
core.CreateOperator(
|
||||
"ScriptModuleLoad", [], ["m"],
|
||||
serialized_binary=self._get_modules_bytes(MyModule())))
|
||||
|
||||
def _get_modules_bytes(self, the_module):
|
||||
import io
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(the_module, buffer)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user