Make nnapi flatten converter accept flex inputs (#61024)

Summary:
Teach the NNAPI flatten converter to accept inputs with flexible ("flex") dims:
* Load the input operand with get_tensor_operand_by_jitval instead of the fixed-size variant.
* Reject flattening a flexible dim, and allow at most one flexible dim among the non-flattened dims.
* Forward the flexible output dim from the corresponding input dim via forward_operand_shape.
* Pass -1 for the flexible dim in the shape vector given to the NNAPI RESHAPE op so the runtime infers it.
Also enable the previously commented-out flex-input cases in test_flatten.
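
For context, a minimal usage sketch (not part of this change) of converting a Flatten model with a flexible batch dim, mirroring what TestNNAPI.check does internally; it assumes the internal torch.backends._nnapi.prepare.convert_model_to_nnapi entry point used by the tests, where a size-0 dim in the conversion template marks that dim as flexible:

    import torch
    from torch.backends._nnapi.prepare import convert_model_to_nnapi

    model = torch.nn.Flatten().eval()
    # Trace with a concrete shape; convert with a template whose batch dim is 0,
    # which the serializer treats as flexible (sized at run time).
    traced = torch.jit.trace(model, torch.randn(4, 2, 1, 3, 7))
    nnapi_model = convert_model_to_nnapi(traced, [torch.zeros(0, 2, 1, 3, 7)])
    # nnapi_model is a script module that executes through NNAPI on an Android device.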

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61024

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_flatten

Reviewed By: anshuljain1

Differential Revision: D29480748

fbshipit-source-id: c334b09600a64d3e552cec843d6da3de28e7d27c
Author: Akshit Khurana
Date: 2021-07-09 15:08:54 -07:00
Committed by: Facebook GitHub Bot
Parent: 028e438d6c
Commit: ae65f63971

2 changed files with 26 additions and 24 deletions

test/test_nnapi.py

@@ -154,18 +154,17 @@ class TestNNAPI(TestCase):
         ]:
             self.check(mod, torch.randn(4, 2, 1, 3, 7))
-        # TODO(axit): To add support for runtime
-        # self.check(
-        #     torch.nn.Flatten(),
-        #     torch.randn(4, 2, 1, 3, 7),
-        #     convert_args=[torch.zeros(0, 2, 1, 3, 7)]
-        # )
-        # with self.assertRaisesRegex(Exception, "dims can't be flexible"):
-        #     self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
-        # with self.assertRaisesRegex(Exception, "Only 1 dim"):
-        #     self.check(
-        #         torch.nn.Flatten(start_dim=1, end_dim=-2),
-        #         torch.randn(0, 2, 1, 3, 0))
+        self.check(
+            torch.nn.Flatten(),
+            torch.randn(4, 2, 1, 3, 7),
+            convert_args=[torch.zeros(0, 2, 1, 3, 7)]
+        )
+        with self.assertRaisesRegex(Exception, "dims can't be flexible"):
+            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
+        with self.assertRaisesRegex(Exception, "Only 1 dim"):
+            self.check(
+                torch.nn.Flatten(start_dim=1, end_dim=-2),
+                torch.randn(0, 2, 1, 3, 0))
 
     def test_slice(self):
         class SliceModule(torch.nn.Module):

torch/backends/_nnapi/serializer.py

@@ -935,7 +935,7 @@ class _NnapiSerializer(object):
         assert node.inputsSize() == 3
         assert node.outputsSize() == 1
 
-        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
 
         start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
         end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
@@ -956,23 +956,26 @@ class _NnapiSerializer(object):
             in_oper.shape[end_dim + 1:]
         )
 
-        # TODO(axit): To add support for runtime
-        # if any(dim == 0 for dim in in_oper.shape[start_dim: end_dim + 1]):
-        #     raise Exception("Flattened dims can't be flexible")
-        # non_flattened_dims = in_oper.shape[: start_dim] + in_oper.shape[end_dim + 1:]
-        # if non_flattened_dims.count(0) > 1:
-        #     raise Exception("Only 1 dim can be flexible")
-        # out_shape = tuple(
-        #     dim if dim != 0 else -1
-        #     for dim in out_shape
-        # )
+        if any(dim == 0 for dim in in_oper.shape[start_dim: end_dim + 1]):
+            raise Exception("Flattening flexible dims is not supported yet")
+        non_flattened_dims = in_oper.shape[: start_dim] + in_oper.shape[end_dim + 1:]
+        if non_flattened_dims.count(0) > 1:
+            raise Exception("Only 1 dim can be flexible")
 
         out_oper = in_oper._replace(shape=out_shape)
         out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
 
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))
+
+        inputs_1 = tuple(
+            dim if dim != 0 else -1
+            for dim in out_shape
+        )
         inputs = [None] * 2
         inputs[0] = in_id
-        inputs[1] = self.add_immediate_int_vector(out_shape)
+        inputs[1] = self.add_immediate_int_vector(inputs_1)
 
         outputs = [None] * 1
         outputs[0] = out_id
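
For reference, a hedged illustration (not code from this commit) of the shape bookkeeping add_flatten now performs for torch.nn.Flatten() over an input declared as (0, 2, 1, 3, 7), where a 0 marks a flexible dim:

    in_shape = (0, 2, 1, 3, 7)      # 0 = flexible (runtime-sized) batch dim
    start_dim, end_dim = 1, 4       # Flatten() defaults, with end_dim=-1 normalized

    flattened = 1
    for d in in_shape[start_dim:end_dim + 1]:
        flattened *= d              # 2 * 1 * 3 * 7 = 42
    out_shape = in_shape[:start_dim] + (flattened,) + in_shape[end_dim + 1:]
    # out_shape == (0, 42): the flexible dim stays flexible and is forwarded from
    # the input at run time via forward_operand_shape.

    # The dim vector handed to the NNAPI RESHAPE op replaces each flexible dim
    # with -1 so the runtime infers it from the actual input:
    reshape_dims = tuple(d if d != 0 else -1 for d in out_shape)
    # reshape_dims == (-1, 42)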