import unittest
from caffe2.python import core, test_util, workspace
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
from caffe2.python.model_helper import ModelHelper
import numpy as np


class TestControl(test_util.TestCase):
    """Tests for the gradient helpers of control-flow operators."""

    def test_disambiguate_grad_if_op_output(self):
        """Check that renaming output 0 of an ``If`` grad op is propagated.

        After ``disambiguate_grad_if_op_output`` is called, the operator's
        own first output and the matching output inside the ``else_net``
        argument must both carry the new blob name.
        """
        # Feed the condition and the per-branch gradient blobs.
        for blob_name, blob_value in (
            ("cond", np.array(True)),
            ("then_grad", np.array(1)),
            ("else_grad", np.array(2)),
        ):
            workspace.FeedBlob(blob_name, blob_value)

        # "then" branch: a single Copy that writes input_grad.
        then_branch = ModelHelper(name="then_test_model")
        then_branch.net.Copy("then_grad", "input_grad")

        # "else" branch: two Copy ops; the second one writes input_grad.
        # NOTE(review): the second Copy reads "else_temp" while the first
        # writes "else_temp_grad" -- the nets are never run in this test, so
        # this may be intentional; confirm before changing.
        else_branch = ModelHelper(name="else_test_model")
        else_branch.net.Copy("else_grad", "else_temp_grad")
        else_branch.net.Copy("else_temp", "input_grad")

        # For BuildGradientGenerators the forward pass must expose the else
        # temporary as an output, which later yields a grad op shaped like
        # the one built here.
        op_inputs = ["cond", "then_grad", "else_grad"]
        op_outputs = ["input_grad", "else_temp_grad"]
        grad_op = core.CreateOperator(
            "If",
            op_inputs,
            op_outputs,
            then_net=then_branch.net.Proto(),
            else_net=else_branch.net.Proto(),
        )

        # In some cases another branch of the net also produces input_grad,
        # and core.py calls _DisambiguateGradOpOutput to rename it.
        renamed_output = "input_grad" + "_autosplit_" + "0"
        disambiguate_grad_if_op_output(grad_op, 0, renamed_output)

        self.assertEqual(grad_op.output[0], renamed_output)
        for net_arg in grad_op.arg:
            if net_arg.name != "else_net":
                # The only other argument expected on the op is then_net.
                self.assertEqual(net_arg.name, "then_net")
                continue
            # Second op of the else net is the one that wrote input_grad.
            self.assertEqual(net_arg.n.op[1].output[0], renamed_output)


if __name__ == '__main__':
    unittest.main()