mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
sync
This commit is contained in:
@ -8,7 +8,8 @@ from caffe2.python import core, test_util, workspace
|
||||
class TestWorkspace(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.net = core.Net("test-net")
|
||||
self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0)
|
||||
self.testblob_ref = self.net.ConstantFill(
|
||||
[], "testblob", shape=[1, 2, 3, 4], value=1.0)
|
||||
workspace.ResetWorkspace()
|
||||
|
||||
def testRootFolder(self):
|
||||
@ -64,6 +65,20 @@ class TestWorkspace(unittest.TestCase):
|
||||
self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
|
||||
np.testing.assert_array_equal(fetched_again, 2.0)
|
||||
|
||||
def testFetchFeedBlobViaBlobReference(self):
|
||||
self.assertEqual(
|
||||
workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
|
||||
fetched = workspace.FetchBlob(self.testblob_ref)
|
||||
# check if fetched is correct.
|
||||
self.assertEqual(fetched.shape, (1, 2, 3, 4))
|
||||
np.testing.assert_array_equal(fetched, 1.0)
|
||||
fetched[:] = 2.0
|
||||
self.assertEqual(workspace.FeedBlob(self.testblob_ref, fetched), True)
|
||||
fetched_again = workspace.FetchBlob("testblob") # fetch by name now
|
||||
self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
|
||||
np.testing.assert_array_equal(fetched_again, 2.0)
|
||||
|
||||
|
||||
def testFetchFeedBlobTypes(self):
|
||||
for dtype in [np.float16, np.float32, np.float64, np.bool,
|
||||
np.int8, np.int16, np.int32, np.int64,
|
||||
@ -101,7 +116,8 @@ class TestWorkspace(unittest.TestCase):
|
||||
strs = np.array([
|
||||
' '.join(10 * ['long string']),
|
||||
' '.join(128 * ['very long string']),
|
||||
'small string'])
|
||||
'small \0\1\2 string',
|
||||
"Hello, world! I have special \0 symbols \1!"])
|
||||
workspace.FeedBlob('my_str_tensor', strs)
|
||||
strs2 = workspace.FetchBlob('my_str_tensor')
|
||||
self.assertEqual(strs.shape, strs2.shape)
|
||||
@ -117,6 +133,32 @@ class TestWorkspace(unittest.TestCase):
|
||||
for i in range(0, strs.shape[0]):
|
||||
self.assertEqual(strs[i], strs2[i])
|
||||
|
||||
def testFetchFeedPlainString(self):
|
||||
# this is actual string, not a tensor of strings
|
||||
s = "Hello, world! I have special \0 symbols \1!"
|
||||
workspace.FeedBlob('my_plain_string', s)
|
||||
s2 = workspace.FetchBlob('my_plain_string')
|
||||
self.assertEqual(s, s2)
|
||||
|
||||
def testFetchFeedViaBlobDict(self):
|
||||
self.assertEqual(
|
||||
workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
|
||||
fetched = workspace.blobs["testblob"]
|
||||
# check if fetched is correct.
|
||||
self.assertEqual(fetched.shape, (1, 2, 3, 4))
|
||||
np.testing.assert_array_equal(fetched, 1.0)
|
||||
fetched[:] = 2.0
|
||||
workspace.blobs["testblob"] = fetched
|
||||
fetched_again = workspace.blobs["testblob"]
|
||||
self.assertEqual(fetched_again.shape, (1, 2, 3, 4))
|
||||
np.testing.assert_array_equal(fetched_again, 2.0)
|
||||
|
||||
self.assertTrue("testblob" in workspace.blobs)
|
||||
self.assertFalse("non_existant" in workspace.blobs)
|
||||
self.assertEqual(len(workspace.blobs), 1)
|
||||
for key in workspace.blobs:
|
||||
self.assertEqual(key, "testblob")
|
||||
|
||||
|
||||
class TestMultiWorkspaces(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
Reference in New Issue
Block a user