mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 17:54:55 +08:00
[caffe2][nomnigraph] Fix NNPACK relu fusion for inplace relu (#7124)
This commit is contained in:
@ -103,3 +103,49 @@ class TestTransformations(test_util.TestCase):
|
||||
has_activation_arg = True
|
||||
assert has_activation_arg
|
||||
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
|
||||
|
||||
def test_fuseNNPACKConvReluFollowedByMultipleInputOp(self):
|
||||
net = core.Net("net")
|
||||
net.Conv(
|
||||
["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW"
|
||||
)
|
||||
net.Relu(["Y"], ["Y2"])
|
||||
net.Conv(
|
||||
["Y2", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW"
|
||||
)
|
||||
net.Relu(["Y"], ["Y2"])
|
||||
addNNPACK(net) # get the NNPACK engine
|
||||
assert (net.Proto().op[0].engine == "NNPACK")
|
||||
fuseNNPACKConvRelu(net)
|
||||
assert (len(net.Proto().op) == 2)
|
||||
has_activation_arg = False
|
||||
for arg in net.Proto().op[0].arg:
|
||||
if arg.name == "activation":
|
||||
assert (arg.s == "Relu")
|
||||
has_activation_arg = True
|
||||
assert has_activation_arg
|
||||
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
|
||||
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
|
||||
|
||||
def test_fuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
|
||||
net = core.Net("net")
|
||||
net.Conv(
|
||||
["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW"
|
||||
)
|
||||
net.Relu(["Y"], ["Y"])
|
||||
net.Conv(
|
||||
["Y", "w", "b"], ["Y2"], stride=1, pad=0, kernel=3, order="NCHW"
|
||||
)
|
||||
net.Relu(["Y2"], ["Y2"])
|
||||
addNNPACK(net) # get the NNPACK engine
|
||||
assert (net.Proto().op[0].engine == "NNPACK")
|
||||
fuseNNPACKConvRelu(net)
|
||||
assert (len(net.Proto().op) == 2)
|
||||
has_activation_arg = False
|
||||
for arg in net.Proto().op[0].arg:
|
||||
if arg.name == "activation":
|
||||
assert (arg.s == "Relu")
|
||||
has_activation_arg = True
|
||||
assert has_activation_arg
|
||||
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
|
||||
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]
|
||||
|
||||
Reference in New Issue
Block a user