mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Make nnapi flatten converter accept flex inputs (#61024)
Summary: As title Pull Request resolved: https://github.com/pytorch/pytorch/pull/61024 Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_flatten Reviewed By: anshuljain1 Differential Revision: D29480748 fbshipit-source-id: c334b09600a64d3e552cec843d6da3de28e7d27c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
028e438d6c
commit
ae65f63971
@ -154,18 +154,17 @@ class TestNNAPI(TestCase):
|
||||
]:
|
||||
self.check(mod, torch.randn(4, 2, 1, 3, 7))
|
||||
|
||||
# TODO(axit): To add support for runtime
|
||||
# self.check(
|
||||
# torch.nn.Flatten(),
|
||||
# torch.randn(4, 2, 1, 3, 7),
|
||||
# convert_args=[torch.zeros(0, 2, 1, 3, 7)]
|
||||
# )
|
||||
# with self.assertRaisesRegex(Exception, "dims can't be flexible"):
|
||||
# self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
|
||||
# with self.assertRaisesRegex(Exception, "Only 1 dim"):
|
||||
# self.check(
|
||||
# torch.nn.Flatten(start_dim=1, end_dim=-2),
|
||||
# torch.randn(0, 2, 1, 3, 0))
|
||||
self.check(
|
||||
torch.nn.Flatten(),
|
||||
torch.randn(4, 2, 1, 3, 7),
|
||||
convert_args=[torch.zeros(0, 2, 1, 3, 7)]
|
||||
)
|
||||
with self.assertRaisesRegex(Exception, "dims can't be flexible"):
|
||||
self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
|
||||
with self.assertRaisesRegex(Exception, "Only 1 dim"):
|
||||
self.check(
|
||||
torch.nn.Flatten(start_dim=1, end_dim=-2),
|
||||
torch.randn(0, 2, 1, 3, 0))
|
||||
|
||||
def test_slice(self):
|
||||
class SliceModule(torch.nn.Module):
|
||||
|
@ -935,7 +935,7 @@ class _NnapiSerializer(object):
|
||||
assert node.inputsSize() == 3
|
||||
assert node.outputsSize() == 1
|
||||
|
||||
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
|
||||
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
|
||||
|
||||
start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
|
||||
end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
|
||||
@ -956,23 +956,26 @@ class _NnapiSerializer(object):
|
||||
in_oper.shape[end_dim + 1:]
|
||||
)
|
||||
|
||||
# TODO(axit): To add support for runtime
|
||||
# if any(dim == 0 for dim in in_oper.shape[start_dim: end_dim + 1]):
|
||||
# raise Exception("Flattened dims can't be flexible")
|
||||
# non_flattened_dims = in_oper.shape[: start_dim] + in_oper.shape[end_dim + 1:]
|
||||
# if non_flattened_dims.count(0) > 1:
|
||||
# raise Exception("Only 1 dim can be flexible")
|
||||
# out_shape = tuple(
|
||||
# dim if dim != 0 else -1
|
||||
# for dim in out_shape
|
||||
# )
|
||||
if any(dim == 0 for dim in in_oper.shape[start_dim: end_dim + 1]):
|
||||
raise Exception("Flattening flexible dims is not supported yet")
|
||||
non_flattened_dims = in_oper.shape[: start_dim] + in_oper.shape[end_dim + 1:]
|
||||
if non_flattened_dims.count(0) > 1:
|
||||
raise Exception("Only 1 dim can be flexible")
|
||||
|
||||
out_oper = in_oper._replace(shape=out_shape)
|
||||
out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
|
||||
|
||||
for idx, dim in enumerate(out_shape):
|
||||
if dim == 0:
|
||||
self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))
|
||||
|
||||
inputs_1 = tuple(
|
||||
dim if dim != 0 else -1
|
||||
for dim in out_shape
|
||||
)
|
||||
inputs = [None] * 2
|
||||
inputs[0] = in_id
|
||||
inputs[1] = self.add_immediate_int_vector(out_shape)
|
||||
inputs[1] = self.add_immediate_int_vector(inputs_1)
|
||||
|
||||
outputs = [None] * 1
|
||||
outputs[0] = out_id
|
||||
|
Reference in New Issue
Block a user