mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Pytorch Edge] Model Ops compatibility api (#57501)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57501 Add an api _get_model_ops_and_info to get root operators and versioning info of a model in both cxx and python, and the input can be from a file path or buffer. ghstack-source-id: 129620112 Test Plan: unit test. Reviewed By: xcheng16, raziel Differential Revision: D28162765 fbshipit-source-id: 4413c1e906b8a872e4a717d849da37347adbbea4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2a456e4f49
commit
1c5f63d86d
@ -8,12 +8,13 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
from torch.jit.mobile import (
|
||||
_load_for_lite_interpreter,
|
||||
_get_model_bytecode_version,
|
||||
_get_model_ops_and_info,
|
||||
_backport_for_mobile_to_buffer,
|
||||
_backport_for_mobile)
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from pathlib import Path
|
||||
|
||||
pytorch_test_dri = Path(__file__).resolve().parents[1]
|
||||
pytorch_test_dir = Path(__file__).resolve().parents[1]
|
||||
|
||||
# script_module_v4.ptl and script_module_v5.ptl source code
|
||||
# class TestModule(torch.nn.Module):
|
||||
@ -97,6 +98,38 @@ SCRIPT_MODULE_V5_BYTECODE_PKL = '''
|
||||
((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
|
||||
'''
|
||||
|
||||
SCRIPT_MODULE_V6_BYTECODE_PKL = '''
|
||||
(6,
|
||||
('__torch__.*.TestModule.forward',
|
||||
(('instructions',
|
||||
(('STOREN', 1, 2),
|
||||
('DROPR', 1, 0),
|
||||
('LOADC', 0, 0),
|
||||
('LOADC', 1, 0),
|
||||
('MOVE', 2, 0),
|
||||
('OP', 0, 0),
|
||||
('OP', 1, 0),
|
||||
('RET', 0, 0))),
|
||||
('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
|
||||
('constants',
|
||||
(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
|
||||
0,
|
||||
(2, 4),
|
||||
(4, 1),
|
||||
False,
|
||||
collections.OrderedDict()),
|
||||
1)),
|
||||
('types', ()),
|
||||
('register_size', 2)),
|
||||
(('arguments',
|
||||
((('name', 'self'),
|
||||
('type', '__torch__.*.TestModule'),
|
||||
('default_value', None)),
|
||||
(('name', 'y'), ('type', 'int'), ('default_value', None)))),
|
||||
('returns',
|
||||
((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
|
||||
'''
|
||||
|
||||
SCRIPT_MODULE_BYTECODE_PKL = {
|
||||
4: {
|
||||
"bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL,
|
||||
@ -113,7 +146,7 @@ class testVariousModelVersions(TestCase):
|
||||
actual_version = _get_model_bytecode_version(model_path)
|
||||
assert(actual_version == expect_version)
|
||||
for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
|
||||
model_path = pytorch_test_dri / "cpp" / "jit" / model_info["model_name"]
|
||||
model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
|
||||
check_model_version(model_path, version)
|
||||
|
||||
def test_bytecode_values_for_all_backport_functions(self):
|
||||
@ -130,7 +163,7 @@ class testVariousModelVersions(TestCase):
|
||||
while current_from_version > MINIMUM_TO_VERSION:
|
||||
# Load model v5 and run forward method
|
||||
model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
|
||||
input_model_path = pytorch_test_dri / "cpp" / "jit" / model_name
|
||||
input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name
|
||||
|
||||
# A temporary model file will be export to this path, and run through bytecode.pkl
|
||||
# content check.
|
||||
@ -205,7 +238,7 @@ class testVariousModelVersions(TestCase):
|
||||
# Check just the test_backport_bytecode_from_file_to_file mechanism but not the function implementations
|
||||
def test_backport_bytecode_from_file_to_file(self):
|
||||
maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
|
||||
script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
|
||||
script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
|
||||
maximum_checked_in_model_version]["model_name"]
|
||||
|
||||
if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
|
||||
@ -241,7 +274,7 @@ class testVariousModelVersions(TestCase):
|
||||
# Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
|
||||
def test_backport_bytecode_from_file_to_buffer(self):
|
||||
maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
|
||||
script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
|
||||
script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
|
||||
maximum_checked_in_model_version]["model_name"]
|
||||
|
||||
if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
|
||||
@ -264,5 +297,12 @@ class testVariousModelVersions(TestCase):
|
||||
torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)
|
||||
|
||||
|
||||
def test_get_model_ops_and_info(self):
|
||||
# TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
|
||||
script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
|
||||
ops_v6 = _get_model_ops_and_info(script_module_v6)
|
||||
assert(ops_v6["aten::add.int"].num_schema_args == 2)
|
||||
assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user