[caffe2][nomnigraph] Fix NNPACK relu fusion for inplace relu (#7124)

This commit is contained in:
Bram Wasti
2018-05-01 16:26:54 -07:00
committed by GitHub
parent 20666feb2c
commit 967c4a0c18
3 changed files with 75 additions and 11 deletions

View File

@ -103,3 +103,49 @@ class TestTransformations(test_util.TestCase):
has_activation_arg = True
assert has_activation_arg
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
def test_fuseNNPACKConvReluFollowedByMultipleInputOp(self):
net = core.Net("net")
net.Conv(
["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW"
)
net.Relu(["Y"], ["Y2"])
net.Conv(
["Y2", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW"
)
net.Relu(["Y"], ["Y2"])
addNNPACK(net) # get the NNPACK engine
assert (net.Proto().op[0].engine == "NNPACK")
fuseNNPACKConvRelu(net)
assert (len(net.Proto().op) == 2)
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if arg.name == "activation":
assert (arg.s == "Relu")
has_activation_arg = True
assert has_activation_arg
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
def test_fuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
net = core.Net("net")
net.Conv(
["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW"
)
net.Relu(["Y"], ["Y"])
net.Conv(
["Y", "w", "b"], ["Y2"], stride=1, pad=0, kernel=3, order="NCHW"
)
net.Relu(["Y2"], ["Y2"])
addNNPACK(net) # get the NNPACK engine
assert (net.Proto().op[0].engine == "NNPACK")
fuseNNPACKConvRelu(net)
assert (len(net.Proto().op) == 2)
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if arg.name == "activation":
assert (arg.s == "Relu")
has_activation_arg = True
assert has_activation_arg
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]