NNAPI: Support const values in binary ops

Summary:
NNAPI converter failed with 1 const value and one tensor earlier
Code suggestions from dreiss

Test Plan:
pytest test/test_nnapi.py::TestNNAPI::test_pointwise_binary

Imported from OSS

Reviewed By: anshuljain1

Differential Revision: D28893881

fbshipit-source-id: 59240373fb03c6fdafa4cb2fa4d8408dd20092f6
This commit is contained in:
Akshit Khurana
2021-08-20 21:08:59 -07:00
committed by Facebook GitHub Bot
parent b4f5809db8
commit 2d58f3f56d
2 changed files with 43 additions and 7 deletions

View File

@ -49,6 +49,7 @@ class TestNNAPI(TestCase):
convert_args=None,
atol_rtol=None,
limit=None,
expected_memory_format=None
):
with torch.no_grad():
if isinstance(arg_or_args, torch.Tensor):
@ -76,6 +77,8 @@ class TestNNAPI(TestCase):
# Too many mismatches. Re-run the check with no tolerance
# to get a nice message.
self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
if expected_memory_format:
self.assertTrue(nnapi_out.is_contiguous(memory_format=expected_memory_format))
def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
torch.manual_seed(29)
@ -319,6 +322,28 @@ class TestNNAPI(TestCase):
torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
])
def test_pointwise_binary_const(self):
const = torch.randn(1, 4, 6, 6)
class ArgPlusConst(torch.nn.Module):
def forward(self, arg):
return arg + const
class ConstPlusArg(torch.nn.Module):
def forward(self, arg):
return const + arg
arg_contig = torch.randn(2, 4, 6, 6)
arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))
for mod_class in [ArgPlusConst, ConstPlusArg]:
for use_nhwc in [False, True]:
with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
arg = arg_nhwc if use_nhwc else arg_contig
memory_format = torch.channels_last if use_nhwc else torch.contiguous_format
self.check(mod_class(), arg,
expected_memory_format=memory_format)
def test_hardtanh(self):
inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
self.check(torch.nn.Hardtanh(), inp)