mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
332 lines
12 KiB
Python
332 lines
12 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import control, core, test_util, workspace
|
|
|
|
import logging
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TestControl(test_util.TestCase):
    """Tests for the control-flow step builders in ``caffe2.python.control``.

    ``setUp`` builds a handful of small nets around one shared counter. Each
    test wraps some of those nets in a control step, runs that step inside a
    plan (always preceded by the init net), and then checks the blob(s) the
    relevant net exposes as external output.
    """

    def setUp(self):
        """Build the init, counting, condition and idle nets used by the tests."""
        super(TestControl, self).setUp()
        self.N_ = 10

        # Init net: creates the counter and the two INT64 constants (N_ and 0)
        # that the condition nets compare against.
        self.init_net_ = core.Net("init-net")
        cnt = self.init_net_.CreateCounter([], init_count=0)
        const_n = self.init_net_.ConstantFill(
            [], shape=[], value=self.N_, dtype=core.DataType.INT64)
        const_0 = self.init_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

        # cnt-net: one CountUp, then publish the retrieved count.
        self.cnt_net_ = core.Net("cnt-net")
        self.cnt_net_.CountUp([cnt])
        curr_cnt = self.cnt_net_.RetrieveCount([cnt])
        # The init net also fills the output blob with 0; the tests whose
        # branch never runs cnt-net expect to fetch 0 from it.
        self.init_net_.ConstantFill(
            [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_net_.AddExternalOutput(curr_cnt)

        # cnt-2-net: two CountUps per run, with its own zero-filled output.
        self.cnt_2_net_ = core.Net("cnt-2-net")
        self.cnt_2_net_.CountUp([cnt])
        self.cnt_2_net_.CountUp([cnt])
        curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt])
        self.init_net_.ConstantFill(
            [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_2_net_.AddExternalOutput(curr_cnt_2)

        # cond-net: curr_cnt < N_.
        self.cond_net_ = core.Net("cond-net")
        cond_blob = self.cond_net_.LT([curr_cnt, const_n])
        self.cond_net_.AddExternalOutput(cond_blob)

        # not-cond-net: curr_cnt >= N_ (the negation of cond-net).
        self.not_cond_net_ = core.Net("not-cond-net")
        cond_blob = self.not_cond_net_.GE([curr_cnt, const_n])
        self.not_cond_net_.AddExternalOutput(cond_blob)

        # true-cond-net: 0 < N_.
        self.true_cond_net_ = core.Net("true-cond-net")
        true_blob = self.true_cond_net_.LT([const_0, const_n])
        self.true_cond_net_.AddExternalOutput(true_blob)

        # false-cond-net: 0 > N_.
        self.false_cond_net_ = core.Net("false-cond-net")
        false_blob = self.false_cond_net_.GT([const_0, const_n])
        self.false_cond_net_.AddExternalOutput(false_blob)

        # idle-net: a single ConstantFill that does not touch the counter.
        self.idle_net_ = core.Net("idle-net")
        self.idle_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

    def CheckNetOutput(self, nets_and_expects):
        """
        Check that the last external output of each net has the expected value.

        nets_and_expects is a list of tuples (net, expect).
        """
        for net, expect in nets_and_expects:
            output = workspace.FetchBlob(
                net.Proto().external_output[-1])
            self.assertEqual(output, expect)

    def CheckNetAllOutput(self, net, expects):
        """
        Check that every external output of the net has the expected value.

        expects is a list with one expected value per external output, in the
        same order as the net's external outputs.
        """
        output_names = net.Proto().external_output
        # Lengths are asserted first, so zip() below cannot silently truncate.
        self.assertEqual(len(output_names), len(expects))
        for output_name, expect in zip(output_names, expects):
            output = workspace.FetchBlob(output_name)
            self.assertEqual(output, expect)

    def BuildAndRunPlan(self, step):
        """Run the init net followed by `step` in a new plan; assert success."""
        plan = core.Plan("test")
        plan.AddStep(control.Do('init', self.init_net_))
        plan.AddStep(step)
        self.assertEqual(workspace.RunPlan(plan), True)

    def _RunAndCheckCountIsN(self, step):
        """Run `step` in a plan and check that cnt-net's output equals N_."""
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def ForLoopTest(self, nets_or_steps):
        """Wrap `nets_or_steps` in control.For (given N_) and expect N_."""
        step = control.For('myFor', nets_or_steps, self.N_)
        self._RunAndCheckCountIsN(step)

    def testForLoopWithNets(self):
        self.ForLoopTest(self.cnt_net_)
        self.ForLoopTest([self.cnt_net_, self.idle_net_])

    def testForLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.ForLoopTest(step)
        self.ForLoopTest([step, self.idle_net_])

    def WhileLoopTest(self, nets_or_steps):
        """Wrap `nets_or_steps` in control.While on cond-net and expect N_."""
        step = control.While('myWhile', self.cond_net_, nets_or_steps)
        self._RunAndCheckCountIsN(step)

    def testWhileLoopWithNet(self):
        self.WhileLoopTest(self.cnt_net_)
        self.WhileLoopTest([self.cnt_net_, self.idle_net_])

    def testWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.WhileLoopTest(step)
        self.WhileLoopTest([step, self.idle_net_])

    def UntilLoopTest(self, nets_or_steps):
        """Wrap `nets_or_steps` in control.Until on not-cond-net; expect N_."""
        step = control.Until('myUntil', self.not_cond_net_, nets_or_steps)
        self._RunAndCheckCountIsN(step)

    def testUntilLoopWithNet(self):
        self.UntilLoopTest(self.cnt_net_)
        self.UntilLoopTest([self.cnt_net_, self.idle_net_])

    def testUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.UntilLoopTest(step)
        self.UntilLoopTest([step, self.idle_net_])

    def DoWhileLoopTest(self, nets_or_steps):
        """Wrap `nets_or_steps` in control.DoWhile on cond-net; expect N_."""
        step = control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps)
        self._RunAndCheckCountIsN(step)

    def testDoWhileLoopWithNet(self):
        self.DoWhileLoopTest(self.cnt_net_)
        self.DoWhileLoopTest([self.idle_net_, self.cnt_net_])

    def testDoWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoWhileLoopTest(step)
        self.DoWhileLoopTest([self.idle_net_, step])

    def DoUntilLoopTest(self, nets_or_steps):
        """Wrap `nets_or_steps` in control.DoUntil on not-cond-net; expect N_."""
        step = control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps)
        self._RunAndCheckCountIsN(step)

    def testDoUntilLoopWithNet(self):
        self.DoUntilLoopTest(self.cnt_net_)
        self.DoUntilLoopTest([self.cnt_net_, self.idle_net_])

    def testDoUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoUntilLoopTest(step)
        self.DoUntilLoopTest([self.idle_net_, step])

    def IfCondTest(self, cond_net, expect, cond_on_blob):
        """
        Guard cnt-net with control.If and check cnt-net's output.

        When cond_on_blob is true, cond_net is first run through its own Do
        step and the If is given the name of cond_net's last external output
        instead of the net itself.
        """
        if cond_on_blob:
            step = control.Do(
                'if-all',
                control.Do('count', cond_net),
                control.If('myIf', cond_net.Proto().external_output[-1],
                           self.cnt_net_))
        else:
            step = control.If('myIf', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfCondTrueOnNet(self):
        self.IfCondTest(self.true_cond_net_, 1, False)

    def testIfCondTrueOnBlob(self):
        self.IfCondTest(self.true_cond_net_, 1, True)

    def testIfCondFalseOnNet(self):
        self.IfCondTest(self.false_cond_net_, 0, False)

    def testIfCondFalseOnBlob(self):
        self.IfCondTest(self.false_cond_net_, 0, True)

    def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """
        Run control.If with cnt-net / cnt-2-net as its two branches.

        cond_value tells the helper which net's output to check: cnt-net when
        it is true, cnt-2-net otherwise. cond_on_blob selects between passing
        cond_net itself and passing the name of its last external output
        (after running cond_net in a preceding Do step).
        """
        run_net = self.cnt_net_ if cond_value else self.cnt_2_net_
        if cond_on_blob:
            step = control.Do(
                'if-else-all',
                control.Do('count', cond_net),
                control.If('myIfElse', cond_net.Proto().external_output[-1],
                           self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.If('myIfElse', cond_net,
                              self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfElseCondTrueOnNet(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, False)

    def testIfElseCondTrueOnBlob(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, True)

    def testIfElseCondFalseOnNet(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, False)

    def testIfElseCondFalseOnBlob(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, True)

    def IfNotCondTest(self, cond_net, expect, cond_on_blob):
        """
        Guard cnt-net with control.IfNot and check cnt-net's output.

        cond_on_blob has the same meaning as in IfCondTest: when true, the
        condition is supplied as the name of cond_net's last external output
        after cond_net has been run in a preceding Do step.
        """
        if cond_on_blob:
            step = control.Do(
                'if-not',
                control.Do('count', cond_net),
                control.IfNot('myIfNot', cond_net.Proto().external_output[-1],
                              self.cnt_net_))
        else:
            step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfNotCondTrueOnNet(self):
        self.IfNotCondTest(self.true_cond_net_, 0, False)

    def testIfNotCondTrueOnBlob(self):
        self.IfNotCondTest(self.true_cond_net_, 0, True)

    def testIfNotCondFalseOnNet(self):
        self.IfNotCondTest(self.false_cond_net_, 1, False)

    def testIfNotCondFalseOnBlob(self):
        self.IfNotCondTest(self.false_cond_net_, 1, True)

    def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """
        Run control.IfNot with cnt-net / cnt-2-net as its two branches.

        cond_value tells the helper which net's output to check: cnt-2-net
        when it is true, cnt-net otherwise (the reverse of IfElseCondTest).
        cond_on_blob selects between passing cond_net itself and passing the
        name of its last external output.
        """
        run_net = self.cnt_2_net_ if cond_value else self.cnt_net_
        if cond_on_blob:
            step = control.Do(
                'if-not-else',
                control.Do('count', cond_net),
                control.IfNot('myIfNotElse',
                              cond_net.Proto().external_output[-1],
                              self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.IfNot('myIfNotElse', cond_net,
                                 self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfNotElseCondTrueOnNet(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, False)

    def testIfNotElseCondTrueOnBlob(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, True)

    def testIfNotElseCondFalseOnNet(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, False)

    def testIfNotElseCondFalseOnBlob(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, True)

    def testSwitch(self):
        """Switch: cnt-net is paired with the false condition, cnt-2-net with the true one."""
        step = control.Switch(
            'mySwitch',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])

    def testSwitchNot(self):
        """SwitchNot with the same pairs as testSwitch; expectations are reversed."""
        step = control.SwitchNot(
            'mySwitchNot',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])

    def testBoolNet(self):
        """BoolNet accepts one pair, several pairs, or a list of pairs."""
        bool_net = control.BoolNet(('a', True))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True])

        bool_net = control.BoolNet(('a', True), ('b', False))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

        bool_net = control.BoolNet([('a', True), ('b', False)])
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

    def testCombineConditions(self):
        """CombineConditions over (true, false): 'Or' gives True, 'And' gives False."""
        # combined by 'Or'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        # Both condition nets are run in the same Do step, ahead of the
        # combining net.
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, True)])

        # combined by 'And'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, False)])

    def testMergeConditionNets(self):
        """MergeConditionNets over (true, false): 'Or' gives True, 'And' gives False."""
        # merged by 'Or'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        # Only the merged net is run here; the condition nets are not added
        # to the step separately.
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, True)])

        # merged by 'And'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, False)])
|