[PyTorch Edge][type] Add type check in compatibility api (#63129)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63129

1. Add an api to get `supported_types` from runtime, expose in c++ only.
2. Add an api to get `contained_types` from model, expose in both c++ and PyThon.
3. Add a field `contained_types_` in `type_parser.cpp` to track the contained types when parsing python string.
4. Expand `is_compatible` api to check type. When checking type, it will check the contained type list from the model with the support type list from runtime.
5. Expand the unittest for compatibility to cover type
6. Add unit test in python to check type list
ghstack-source-id: 139826944

Test Plan:
```
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.GetContainTypes'

buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.isCompatibleSuccess'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.isCompatibleFail'

buck test //caffe2/test:mobile
```

Reviewed By: iseeyuan

Differential Revision: D30231419

fbshipit-source-id: 8427f423ec28cc5de56411f15fd960d8595d6947
This commit is contained in:
Chen Lai
2021-10-06 02:20:54 -07:00
committed by Facebook GitHub Bot
parent c75210face
commit a5895f85be
11 changed files with 293 additions and 24 deletions

View File

@ -7,6 +7,7 @@ import torch.utils.show_pickle
# from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import (
_load_for_lite_interpreter,
_get_mobile_model_contained_types,
_get_model_bytecode_version,
_get_model_ops_and_info,
_backport_for_mobile_to_buffer,
@ -306,5 +307,23 @@ class testVariousModelVersions(TestCase):
assert(ops_v6["aten::add.int"].num_schema_args == 2)
assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
def test_get_mobile_model_contained_types(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 10
sample_input = torch.tensor([1])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(sample_input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
type_list = _get_mobile_model_contained_types(buffer)
assert(len(type_list) >= 0)
if __name__ == '__main__':
run_tests()