mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
expose flop annotation to python
Summary: expose the flop annotation framework to python functions Reviewed By: Maratyszcza, Yangqing Differential Revision: D6135705 fbshipit-source-id: 2eed80b6cbda7b3ee3fe0e019a0f1fc4b0aa320b
This commit is contained in:
committed by
Facebook Github Bot
parent
388a1b1e66
commit
a0aa6d0e24
@ -58,6 +58,25 @@ class TestWorkspace(unittest.TestCase):
|
||||
self.assertEqual(len(blobs), 1)
|
||||
self.assertEqual(blobs[0], "testblob")
|
||||
|
||||
def testGetOperatorCost(self):
|
||||
op = core.CreateOperator(
|
||||
"Conv2D",
|
||||
["X", "W"], ["Y"],
|
||||
stride_h=1,
|
||||
stride_w=1,
|
||||
pad_t=1,
|
||||
pad_l=1,
|
||||
pad_b=1,
|
||||
pad_r=1,
|
||||
kernel=3,
|
||||
)
|
||||
X = np.zeros((1, 8, 8, 8))
|
||||
W = np.zeros((1, 1, 3, 3))
|
||||
workspace.FeedBlob("X", X)
|
||||
workspace.FeedBlob("W", W)
|
||||
flops, _ = workspace.GetOperatorCost(op.SerializeToString(), ["X", "W"])
|
||||
self.assertEqual(flops, 648)
|
||||
|
||||
def testRunNetOnce(self):
|
||||
self.assertEqual(
|
||||
workspace.RunNetOnce(self.net.Proto().SerializeToString()), True)
|
||||
|
Reference in New Issue
Block a user