expose flop annotation to python

Summary: expose the flop annotation framework to python functions

Reviewed By: Maratyszcza, Yangqing

Differential Revision: D6135705

fbshipit-source-id: 2eed80b6cbda7b3ee3fe0e019a0f1fc4b0aa320b
This commit is contained in:
Bram Wasti
2017-10-24 11:23:48 -07:00
committed by Facebook Github Bot
parent 388a1b1e66
commit a0aa6d0e24
5 changed files with 43 additions and 2 deletions

View File

@ -58,6 +58,25 @@ class TestWorkspace(unittest.TestCase):
self.assertEqual(len(blobs), 1)
self.assertEqual(blobs[0], "testblob")
def testGetOperatorCost(self):
op = core.CreateOperator(
"Conv2D",
["X", "W"], ["Y"],
stride_h=1,
stride_w=1,
pad_t=1,
pad_l=1,
pad_b=1,
pad_r=1,
kernel=3,
)
X = np.zeros((1, 8, 8, 8))
W = np.zeros((1, 1, 3, 3))
workspace.FeedBlob("X", X)
workspace.FeedBlob("W", W)
flops, _ = workspace.GetOperatorCost(op.SerializeToString(), ["X", "W"])
self.assertEqual(flops, 648)
def testRunNetOnce(self):
self.assertEqual(
workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)