Make nnapi flatten converter accept flex inputs (#61024)

Summary:
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61024

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_flatten

Reviewed By: anshuljain1

Differential Revision: D29480748

fbshipit-source-id: c334b09600a64d3e552cec843d6da3de28e7d27c
This commit is contained in:
Akshit Khurana
2021-07-09 15:08:54 -07:00
committed by Facebook GitHub Bot
parent 028e438d6c
commit ae65f63971
2 changed files with 26 additions and 24 deletions

View File

@ -154,18 +154,17 @@ class TestNNAPI(TestCase):
]:
self.check(mod, torch.randn(4, 2, 1, 3, 7))
# TODO(axit): To add support for runtime
# self.check(
# torch.nn.Flatten(),
# torch.randn(4, 2, 1, 3, 7),
# convert_args=[torch.zeros(0, 2, 1, 3, 7)]
# )
# with self.assertRaisesRegex(Exception, "dims can't be flexible"):
# self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
# with self.assertRaisesRegex(Exception, "Only 1 dim"):
# self.check(
# torch.nn.Flatten(start_dim=1, end_dim=-2),
# torch.randn(0, 2, 1, 3, 0))
self.check(
torch.nn.Flatten(),
torch.randn(4, 2, 1, 3, 7),
convert_args=[torch.zeros(0, 2, 1, 3, 7)]
)
with self.assertRaisesRegex(Exception, "dims can't be flexible"):
self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
with self.assertRaisesRegex(Exception, "Only 1 dim"):
self.check(
torch.nn.Flatten(start_dim=1, end_dim=-2),
torch.randn(0, 2, 1, 3, 0))
def test_slice(self):
class SliceModule(torch.nn.Module):