mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[NNAPI] Handle binary ops combining NHWC+NCHW in some cases (#48812)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48812 This came up in a squeeze-and-excitation model. Starting with an NHWC tensor T, we perform a mean operation across H and W, giving an NxC tensor, which (after some fully connected layers) is reshaped to NxCx1x1, then multiplied with T. To handle this, we detect the specific case of a binary op with one NHWC input and one contiguous input with H,W == 1,1 and allow the op to be applied (after transposing the contiguous input). Test Plan: Unit test. Reviewed By: axitkhurana Differential Revision: D25317939 Pulled By: dreiss fbshipit-source-id: b4c17ab3b874d1a7defa04664010ba82115f1c20
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b057d27b0b
commit
476c597ae6
@ -331,6 +331,18 @@ class TestNNAPI(TestCase):
|
||||
inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
|
||||
self.check(mod, inp)
|
||||
|
||||
def test_seblock_mul(self):
|
||||
class MulModel(torch.nn.Module):
|
||||
def forward(self, lhs, rhs):
|
||||
return lhs * rhs
|
||||
|
||||
self.check(
|
||||
MulModel(),
|
||||
[
|
||||
nhwc(torch.randn(2, 3, 4, 4)),
|
||||
torch.randn(1, 3, 1, 1),
|
||||
])
|
||||
|
||||
def test_multi_output(self):
|
||||
class MultiModel(torch.nn.Module):
|
||||
def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
Reference in New Issue
Block a user