mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
NNAPI: Support const values in binary ops
Summary: NNAPI converter failed with 1 const value and one tensor earlier Code suggestions from dreiss Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_pointwise_binary Imported from OSS Reviewed By: anshuljain1 Differential Revision: D28893881 fbshipit-source-id: 59240373fb03c6fdafa4cb2fa4d8408dd20092f6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b4f5809db8
commit
2d58f3f56d
@ -49,6 +49,7 @@ class TestNNAPI(TestCase):
|
||||
convert_args=None,
|
||||
atol_rtol=None,
|
||||
limit=None,
|
||||
expected_memory_format=None
|
||||
):
|
||||
with torch.no_grad():
|
||||
if isinstance(arg_or_args, torch.Tensor):
|
||||
@ -76,6 +77,8 @@ class TestNNAPI(TestCase):
|
||||
# Too many mismatches. Re-run the check with no tolerance
|
||||
# to get a nice message.
|
||||
self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
|
||||
if expected_memory_format:
|
||||
self.assertTrue(nnapi_out.is_contiguous(memory_format=expected_memory_format))
|
||||
|
||||
def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
|
||||
torch.manual_seed(29)
|
||||
@ -319,6 +322,28 @@ class TestNNAPI(TestCase):
|
||||
torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
|
||||
])
|
||||
|
||||
def test_pointwise_binary_const(self):
|
||||
const = torch.randn(1, 4, 6, 6)
|
||||
|
||||
class ArgPlusConst(torch.nn.Module):
|
||||
def forward(self, arg):
|
||||
return arg + const
|
||||
|
||||
class ConstPlusArg(torch.nn.Module):
|
||||
def forward(self, arg):
|
||||
return const + arg
|
||||
|
||||
arg_contig = torch.randn(2, 4, 6, 6)
|
||||
arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))
|
||||
|
||||
for mod_class in [ArgPlusConst, ConstPlusArg]:
|
||||
for use_nhwc in [False, True]:
|
||||
with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
|
||||
arg = arg_nhwc if use_nhwc else arg_contig
|
||||
memory_format = torch.channels_last if use_nhwc else torch.contiguous_format
|
||||
self.check(mod_class(), arg,
|
||||
expected_memory_format=memory_format)
|
||||
|
||||
def test_hardtanh(self):
|
||||
inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
|
||||
self.check(torch.nn.Hardtanh(), inp)
|
||||
|
||||
Reference in New Issue
Block a user