mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add aten::avgpool2d NNAPI converter (#58538)
Summary: Add support for aten::avgpool2d op in the NNAPI model converter with var size support Pull Request resolved: https://github.com/pytorch/pytorch/pull/58538 Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_avgpool2d Reviewed By: anshuljain1 Differential Revision: D28531944 fbshipit-source-id: 43ff8c9389365698c282f204042b49c7ec84d824
This commit is contained in:
committed by
Facebook GitHub Bot
parent
19b6ee4d4e
commit
369802a504
@ -251,6 +251,34 @@ class TestNNAPI(TestCase):
|
||||
self.check(torch.nn.MaxPool2d((3, 4)), inp)
|
||||
self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
|
||||
with self.subTest(name):
|
||||
atol_rtol = None
|
||||
limit = None
|
||||
convert_dims = (2, 3, 0, 0)
|
||||
convert_arg = torch.zeros(*convert_dims)
|
||||
|
||||
for model in (
|
||||
torch.nn.AvgPool2d(2),
|
||||
torch.nn.AvgPool2d((3, 4)),
|
||||
torch.nn.AvgPool2d((3, 4), (1, 2))):
|
||||
if "quant" in name:
|
||||
atol_rtol = (1, 0)
|
||||
limit = model(inp).numel()
|
||||
convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
|
||||
if "nhwc" in name:
|
||||
convert_arg = nhwc(convert_arg)
|
||||
|
||||
self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
|
||||
self.check(
|
||||
model,
|
||||
inp,
|
||||
convert_args=[convert_arg],
|
||||
atol_rtol=atol_rtol,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
def test_adaptive_avg_pool2d(self):
|
||||
for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
|
||||
with self.subTest(name):
|
||||
|
Reference in New Issue
Block a user