mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make nnapi flatten converter accept flex inputs (#61024)
Summary: As title Pull Request resolved: https://github.com/pytorch/pytorch/pull/61024 Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_flatten Reviewed By: anshuljain1 Differential Revision: D29480748 fbshipit-source-id: c334b09600a64d3e552cec843d6da3de28e7d27c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
028e438d6c
commit
ae65f63971
@ -154,18 +154,17 @@ class TestNNAPI(TestCase):
|
||||
]:
|
||||
self.check(mod, torch.randn(4, 2, 1, 3, 7))
|
||||
|
||||
# TODO(axit): To add support for runtime
|
||||
# self.check(
|
||||
# torch.nn.Flatten(),
|
||||
# torch.randn(4, 2, 1, 3, 7),
|
||||
# convert_args=[torch.zeros(0, 2, 1, 3, 7)]
|
||||
# )
|
||||
# with self.assertRaisesRegex(Exception, "dims can't be flexible"):
|
||||
# self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
|
||||
# with self.assertRaisesRegex(Exception, "Only 1 dim"):
|
||||
# self.check(
|
||||
# torch.nn.Flatten(start_dim=1, end_dim=-2),
|
||||
# torch.randn(0, 2, 1, 3, 0))
|
||||
self.check(
|
||||
torch.nn.Flatten(),
|
||||
torch.randn(4, 2, 1, 3, 7),
|
||||
convert_args=[torch.zeros(0, 2, 1, 3, 7)]
|
||||
)
|
||||
with self.assertRaisesRegex(Exception, "dims can't be flexible"):
|
||||
self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
|
||||
with self.assertRaisesRegex(Exception, "Only 1 dim"):
|
||||
self.check(
|
||||
torch.nn.Flatten(start_dim=1, end_dim=-2),
|
||||
torch.randn(0, 2, 1, 3, 0))
|
||||
|
||||
def test_slice(self):
|
||||
class SliceModule(torch.nn.Module):
|
||||
|
Reference in New Issue
Block a user