mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch Edge][type] Add type check in compatibility api (#63129)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63129

1. Add an API to get `supported_types` from the runtime; exposed in C++ only.
2. Add an API to get `contained_types` from the model; exposed in both C++ and Python.
3. Add a field `contained_types_` in `type_parser.cpp` to track the contained types while parsing the Python type string.
4. Expand the `is_compatible` API to check types: the contained-type list from the model is checked against the supported-type list from the runtime.
5. Expand the compatibility unit tests to cover types.
6. Add a Python unit test to check the type list.

ghstack-source-id: 139826944

Test Plan:
```
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.GetContainTypes'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.isCompatibleSuccess'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.isCompatibleFail'
buck test //caffe2/test:mobile
```

Reviewed By: iseeyuan

Differential Revision: D30231419

fbshipit-source-id: 8427f423ec28cc5de56411f15fd960d8595d6947
committed by Facebook GitHub Bot
parent c75210face
commit a5895f85be
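The type check described in item 4 reduces to a set-containment test between the model's contained types and the runtime's supported types. A minimal pure-Python sketch of the idea (hypothetical names; the actual check lives in the C++ compatibility API):

```python
def is_type_compatible(contained_types, supported_types):
    # A model is type-compatible with a runtime when every type parsed
    # from the model's bytecode appears in the runtime's supported set.
    missing = set(contained_types) - set(supported_types)
    return len(missing) == 0, sorted(missing)

ok, missing = is_type_compatible(
    {"int", "List[int]"},
    {"int", "List[int]", "Dict[int, str]", "Tensor"},
)
print(ok, missing)  # True []
```

Returning the missing types alongside the boolean mirrors how a compatibility API can report *why* a model fails the check, not just that it fails.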
```
@@ -7,6 +7,7 @@ import torch.utils.show_pickle
 # from torch.utils.mobile_optimizer import optimize_for_mobile
 from torch.jit.mobile import (
     _load_for_lite_interpreter,
+    _get_mobile_model_contained_types,
     _get_model_bytecode_version,
     _get_model_ops_and_info,
     _backport_for_mobile_to_buffer,
```
```
@@ -306,5 +307,23 @@ class testVariousModelVersions(TestCase):
         assert(ops_v6["aten::add.int"].num_schema_args == 2)
         assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
 
+    def test_get_mobile_model_contained_types(self):
+        class MyTestModule(torch.nn.Module):
+            def __init__(self):
+                super(MyTestModule, self).__init__()
+
+            def forward(self, x):
+                return x + 10
+
+        sample_input = torch.tensor([1])
+
+        script_module = torch.jit.script(MyTestModule())
+        script_module_result = script_module(sample_input)
+
+        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        type_list = _get_mobile_model_contained_types(buffer)
+        assert(len(type_list) >= 0)
+
 if __name__ == '__main__':
     run_tests()
```
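The bookkeeping from item 3 collects every type name encountered while `type_parser.cpp` walks a Python-style type string. A rough stand-in for the idea, assuming a flat identifier scan rather than the parser's real recursive descent:

```python
import re

def contained_type_names(type_str):
    # Collect each type identifier in a Python-style type annotation,
    # e.g. "Dict[str, List[Tensor]]" yields Dict, str, List, and Tensor.
    return set(re.findall(r"[A-Za-z_][A-Za-z_0-9]*", type_str))

print(sorted(contained_type_names("Dict[str, List[Tensor]]")))
# ['Dict', 'List', 'Tensor', 'str']
```

A set is the natural accumulator here: the compatibility check only cares whether a type occurs at all, not how many times or where.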