[NNAPI] Handle binary ops combining NHWC+NCHW in some cases (#48812)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48812

This came up in a squeeze-and-excitation model.  Starting with an NHWC
tensor T, we perform a mean operation across H and W, giving an NxC
tensor, which (after some fully connected layers) is reshaped to
NxCx1x1, then multiplied with T.  To handle this, we detect the specific
case of a binary op with one NHWC input and one contiguous input with
H,W == 1,1 and allow the op to be applied (after transposing the
contiguous input).

Test Plan: Unit test.

Reviewed By: axitkhurana

Differential Revision: D25317939

Pulled By: dreiss

fbshipit-source-id: b4c17ab3b874d1a7defa04664010ba82115f1c20
This commit is contained in:
David Reiss
2021-04-06 13:40:04 -07:00
committed by Facebook GitHub Bot
parent b057d27b0b
commit 476c597ae6
2 changed files with 59 additions and 1 deletions

View File

@ -331,6 +331,18 @@ class TestNNAPI(TestCase):
inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
self.check(mod, inp)
def test_seblock_mul(self):
class MulModel(torch.nn.Module):
def forward(self, lhs, rhs):
return lhs * rhs
self.check(
MulModel(),
[
nhwc(torch.randn(2, 3, 4, 4)),
torch.randn(1, 3, 1, 1),
])
def test_multi_output(self):
class MultiModel(torch.nn.Module):
def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]: