Mirror of https://github.com/pytorch/pytorch.git
Summary: Fix auto grad summing for IfOp where an intermediate output needs renaming.

Bug before this diff:
- we only renamed the output of the IfOp without changing the outputs of the subnet ops
- this resulted in a blob-not-found error

The unittest provides an example; this diff fixes the issue for IfOp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14772
Differential Revision: D13327090
Pulled By: harouwu
fbshipit-source-id: ec40ee88526ace3619c54551e223dd71158a02f8
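In other words, when a grad IfOp output is renamed to resolve a collision, the same blob must also be renamed inside the then/else subnets the op carries; otherwise the subnets keep writing the old name and a downstream consumer hits the blob-not-found error. A rough sketch of that renaming, assuming the subnets are stored as NetDef arguments named then_net/else_net on the op (the helper below is illustrative only; the real implementation is disambiguate_grad_if_op_output in caffe2/python/control_ops_grad.py):

def _rename_if_op_output_sketch(grad_op, idx, new_name):
    # Sketch only, not the real implementation: rename output `idx`
    # of an If grad op and keep its subnets consistent.
    old_name = grad_op.output[idx]
    grad_op.output[idx] = new_name
    for arg in grad_op.arg:
        # then_net/else_net are NetDef-valued arguments (proto field `n`)
        if arg.name in ("then_net", "else_net") and arg.HasField("n"):
            for op in arg.n.op:
                for i, out in enumerate(op.output):
                    if out == old_name:
                        op.output[i] = new_name

The unittest below exercises exactly this: it renames output 0 of an If grad op to input_grad_autosplit_0 and checks that the Copy op inside else_net was updated as well.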
41 lines
1.7 KiB
Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, test_util, workspace
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
from caffe2.python.model_helper import ModelHelper
import numpy as np

class TestControl(test_util.TestCase):
    def test_disambiguate_grad_if_op_output(self):
        workspace.FeedBlob("cond", np.array(True))
        workspace.FeedBlob("then_grad", np.array(1))
        workspace.FeedBlob("else_grad", np.array(2))

        then_model = ModelHelper(name="then_test_model")
        then_model.net.Copy("then_grad", "input_grad")

        else_model = ModelHelper(name="else_test_model")
        else_model.net.Copy("else_grad", "else_temp_grad")
        else_model.net.Copy("else_temp", "input_grad")

        # For BuildGradientGenerators, the forward pass needs else_temp
        # as one of its outputs, which later results in a grad op like this:
        grad_op = core.CreateOperator(
            "If",
            ["cond", "then_grad", "else_grad"],
            ["input_grad", "else_temp_grad"],
            then_net=then_model.net.Proto(),
            else_net=else_model.net.Proto(),
        )

        # In certain cases another branch of the net also generates input_grad,
        # and core.py calls _DisambiguateGradOpOutput, which renames the
        # colliding output with an "_autosplit_<n>" suffix:
        new_grad_output = "input_grad" + "_autosplit_" + "0"
        disambiguate_grad_if_op_output(grad_op, 0, new_grad_output)
        # The If op's own output must carry the new name...
        self.assertEqual(grad_op.output[0], new_grad_output)
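        # ...and so must the matching output inside the else subnet
        # (arg[1] holds else_net; op[1] is its second Copy op).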
        self.assertEqual(grad_op.arg[1].n.op[1].output[0], new_grad_output)