This commit is contained in:
Yangqing Jia
2016-07-28 15:06:04 -07:00
parent f01f2063dd
commit bcea409c82
70 changed files with 5124 additions and 383 deletions

View File

@ -8,7 +8,8 @@ from caffe2.python import core, test_util, workspace
class TestWorkspace(unittest.TestCase):
def setUp(self):
self.net = core.Net("test-net")
self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
self.testblob_ref = self.net.ConstantFill(
[], "testblob", shape=[1, 2, 3, 4], value=1.0)
workspace.ResetWorkspace()
def testRootFolder(self):
@ -64,6 +65,20 @@ class TestWorkspace(unittest.TestCase):
self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
np.testing.assert_array_equal(fetched_again, 2.0)
def testFetchFeedBlobViaBlobReference(self):
self.assertEqual(
workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
fetched = workspace.FetchBlob(self.testblob_ref)
# check if fetched is correct.
self.assertEqual(fetched.shape, (1, 2, 3, 4))
np.testing.assert_array_equal(fetched, 1.0)
fetched[:] = 2.0
self.assertEqual(workspace.FeedBlob(self.testblob_ref, fetched), True)
fetched_again = workspace.FetchBlob("testblob") # fetch by name now
self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
np.testing.assert_array_equal(fetched_again, 2.0)
def testFetchFeedBlobTypes(self):
for dtype in [np.float16, np.float32, np.float64, np.bool,
np.int8, np.int16, np.int32, np.int64,
@ -101,7 +116,8 @@ class TestWorkspace(unittest.TestCase):
strs = np.array([
' '.join(10 * ['long string']),
' '.join(128 * ['very long string']),
'small string'])
'small \0\1\2 string',
"Hello, world! I have special \0 symbols \1!"])
workspace.FeedBlob('my_str_tensor', strs)
strs2 = workspace.FetchBlob('my_str_tensor')
self.assertEqual(strs.shape, strs2.shape)
@ -117,6 +133,32 @@ class TestWorkspace(unittest.TestCase):
for i in range(0, strs.shape[0]):
self.assertEqual(strs[i], strs2[i])
def testFetchFeedPlainString(self):
# this is actual string, not a tensor of strings
s = "Hello, world! I have special \0 symbols \1!"
workspace.FeedBlob('my_plain_string', s)
s2 = workspace.FetchBlob('my_plain_string')
self.assertEqual(s, s2)
def testFetchFeedViaBlobDict(self):
self.assertEqual(
workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
fetched = workspace.blobs["testblob"]
# check if fetched is correct.
self.assertEqual(fetched.shape, (1, 2, 3, 4))
np.testing.assert_array_equal(fetched, 1.0)
fetched[:] = 2.0
workspace.blobs["testblob"] = fetched
fetched_again = workspace.blobs["testblob"]
self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
np.testing.assert_array_equal(fetched_again, 2.0)
self.assertTrue("testblob" in workspace.blobs)
self.assertFalse("non_existant" in workspace.blobs)
self.assertEqual(len(workspace.blobs), 1)
for key in workspace.blobs:
self.assertEqual(key, "testblob")
class TestMultiWorkspaces(unittest.TestCase):
def setUp(self):