mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-21 05:34:18 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57501 Add an API, `_get_model_ops_and_info`, that returns the root operators and versioning info of a model and is available from both C++ and Python; the input can be a file path or a buffer. ghstack-source-id: 129620112 Test Plan: unit test. Reviewed By: xcheng16, raziel Differential Revision: D28162765 fbshipit-source-id: 4413c1e906b8a872e4a717d849da37347adbbea4
		
			
				
	
	
		
			309 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			309 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import fnmatch
 | |
| import io
 | |
| import shutil
 | |
| import tempfile
 | |
| import torch
 | |
| import torch.utils.show_pickle
 | |
| from torch.utils.mobile_optimizer import optimize_for_mobile
 | |
| from torch.jit.mobile import (
 | |
|     _load_for_lite_interpreter,
 | |
|     _get_model_bytecode_version,
 | |
|     _get_model_ops_and_info,
 | |
|     _backport_for_mobile_to_buffer,
 | |
|     _backport_for_mobile)
 | |
| from torch.testing._internal.common_utils import TestCase, run_tests
 | |
| from pathlib import Path
 | |
| 
 | |
# Two directory levels above this (resolved) file; the tests below look up
# checked-in fixture models relative to it under ``cpp/jit``.
pytorch_test_dir = (
    Path(__file__)
    .resolve()
    .parents[1]
)
 | |
| 
 | |
# Source used to generate the checked-in script_module_v4.ptl and
# script_module_v5.ptl models:
#
# class TestModule(torch.nn.Module):
#     def __init__(self, v):
#         super().__init__()
#         self.x = v
#
#     def forward(self, y: int):
#         increment = torch.ones([2, 4], dtype=torch.float64)
#         return self.x + y + increment
#
# output_model_path = Path(tmpdirname, "script_module_v5.ptl")
# script_module = torch.jit.script(TestModule(1))
# optimized_scripted_module = optimize_for_mobile(script_module)
# exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
#   str(output_model_path))

# Expected `torch.utils.show_pickle` dump of bytecode.pkl for the v4 model.
# The tests compare it with fnmatch after stripping ALL whitespace from both
# sides, so indentation/line breaks here do not matter and `*` is a wildcard
# (here it stands in for whatever sits between `__torch__.` and `.TestModule`).
# Keep the literal text exact otherwise: any other character is matched as-is.
SCRIPT_MODULE_V4_BYTECODE_PKL = '''
(4,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        '''
 | |
| 
 | |
# Expected `torch.utils.show_pickle` dump of bytecode.pkl for a v5 model, written
# in the same format as SCRIPT_MODULE_V4_BYTECODE_PKL (a whitespace-insensitive
# fnmatch pattern in which `*` is a wildcard). It differs from the v4 pattern
# only in the leading version number and in the tensor storage record name
# ('constants/0' instead of '0').
# NOTE(review): not referenced by SCRIPT_MODULE_BYTECODE_PKL or by any test in
# this file; registering it would also require a script_module_v5.ptl fixture —
# confirm whether that is intended.
SCRIPT_MODULE_V5_BYTECODE_PKL = '''
(5,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, 'constants/0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        '''
 | |
| 
 | |
# Expected `torch.utils.show_pickle` dump of bytecode.pkl for a v6 model, written
# in the same format as SCRIPT_MODULE_V4_BYTECODE_PKL (a whitespace-insensitive
# fnmatch pattern in which `*` is a wildcard). Compared with the v4 pattern:
#   * each 'operators' entry carries a third element (2 for both aten::add
#     overloads);
#   * the second ('LOADC', 1, 0) that precedes ('OP', 1, 0) in v4 is absent.
# NOTE(review): the third operator element presumably is the number of schema
# arguments (test_get_model_ops_and_info expects num_schema_args == 2 for both
# ops) — confirm. Not referenced by SCRIPT_MODULE_BYTECODE_PKL or any test here.
SCRIPT_MODULE_V6_BYTECODE_PKL = '''
(6,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
    '''
 | |
| 
 | |
# Registry of checked-in lite-interpreter fixtures, keyed by bytecode version.
# Each entry holds the expected bytecode.pkl dump ("bytecode_pkl") and the
# fixture's file name, which the tests resolve under
# pytorch_test_dir / "cpp" / "jit" ("model_name").
SCRIPT_MODULE_BYTECODE_PKL = {
    4: dict(
        bytecode_pkl=SCRIPT_MODULE_V4_BYTECODE_PKL,
        model_name="script_module_v4.ptl",
    ),
}

# Lowest bytecode version a model may be backported to; requests below it are
# expected to fail. Bump this once a bytecode version is retired completely.
MINIMUM_TO_VERSION = 4
 | |
| 
 | |
class testVariousModelVersions(TestCase):
    """Tests for lite-interpreter bytecode versioning and backporting.

    Checked-in fixture models are resolved as
    ``pytorch_test_dir / "cpp" / "jit" / <model_name>`` using the entries of
    ``SCRIPT_MODULE_BYTECODE_PKL``; every file a test writes goes into a
    ``tempfile.TemporaryDirectory`` that is removed when the ``with`` exits.
    """

    def _assert_bytecode_pkl_matches(self, model_path, expected_bytecode_pkl):
        """Assert that the ``bytecode.pkl`` member of ``model_path`` matches a pattern.

        The member is dumped with ``torch.utils.show_pickle`` and compared with
        ``fnmatch`` after all whitespace is removed from both sides. Layout
        differences are therefore ignored, and ``*`` in the expectation acts as
        a wildcard. This is not a byte-for-byte comparison.

        Args:
            model_path: path to a lite-interpreter model archive.
            expected_bytecode_pkl: expected dump text, used as an fnmatch pattern.
        """
        buf = io.StringIO()
        torch.utils.show_pickle.main(
            ["", str(model_path) + "@*/bytecode.pkl"],
            output_stream=buf)
        output = buf.getvalue()

        actual_result_clean = "".join(output.split())
        expect_result_clean = "".join(expected_bytecode_pkl.split())
        self.assertTrue(
            fnmatch.fnmatch(actual_result_clean, expect_result_clean),
            f"bytecode.pkl of {model_path} does not match the expected pattern; got:\n{output}")

    def test_get_model_bytecode_version(self):
        """Every registered fixture reports the bytecode version it is keyed under."""
        for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
            model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
            self.assertEqual(
                _get_model_bytecode_version(model_path), version,
                f"unexpected bytecode version for {model_path}")

    def test_bytecode_values_for_all_backport_functions(self):
        """Backport each checked-in fixture one version down and validate bytecode.pkl.

        The test starts from the newest registered fixture. Every version above
        ``MINIMUM_TO_VERSION`` is backported to the previous version, and the
        resulting ``bytecode.pkl`` is matched against the pattern registered
        for that previous version.

        This cannot be merged into ``test_all_backport_functions``. That test
        exports a freshly optimized model, and its content may change whenever
        the optimizer changes. This one pins ``bytecode.pkl`` content instead.
        """
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())

        with tempfile.TemporaryDirectory() as tmpdirname:
            # The same output path is reused for every backport step.
            tmp_output_model_path_backport = Path(tmpdirname, "tmp_script_module_backport.ptl")

            # from-versions: maximum down to MINIMUM_TO_VERSION + 1 (inclusive).
            for current_from_version in range(maximum_checked_in_model_version, MINIMUM_TO_VERSION, -1):
                model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
                input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name
                current_to_version = current_from_version - 1

                backport_success = _backport_for_mobile(
                    input_model_path, tmp_output_model_path_backport, current_to_version)
                self.assertTrue(
                    backport_success,
                    f"backport v{current_from_version} -> v{current_to_version} failed")

                expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"]
                self._assert_bytecode_pkl_matches(tmp_output_model_path_backport, expect_bytecode_pkl)
            # No explicit rmtree: TemporaryDirectory deletes tmpdirname on exit, and
            # removing it beforehand makes that cleanup raise on Python < 3.8.

    def test_all_backport_functions(self):
        """Backport a freshly exported model down to every supported version.

        Each target runs from (exported version - 1) down to
        ``MINIMUM_TO_VERSION``. For each one the test backports the model,
        checks the reported bytecode version, then loads the *backported*
        model and runs ``forward``. Finally it checks that backporting below
        ``MINIMUM_TO_VERSION`` reports failure.
        """
        class TestModule(torch.nn.Module):
            def __init__(self, v):
                super().__init__()
                self.x = v

            def forward(self, y: int):
                increment = torch.ones([2, 4], dtype=torch.float64)
                return self.x + y + increment

        module_input = 1
        # TestModule(1) applied to 1: 1 + 1 + ones == 3 * ones.
        expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

        # Input and output models are written into a temporary directory that
        # the context manager removes on exit.
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
            script_module = torch.jit.script(TestModule(1))
            optimized_scripted_module = optimize_for_mobile(script_module)
            optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

            current_from_version = _get_model_bytecode_version(tmp_input_model_path)
            tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

            # to-versions: current_from_version - 1 down to MINIMUM_TO_VERSION (inclusive).
            for current_to_version in range(current_from_version - 1, MINIMUM_TO_VERSION - 1, -1):
                backport_success = _backport_for_mobile(
                    tmp_input_model_path, tmp_output_model_path, current_to_version)
                self.assertTrue(
                    backport_success,
                    f"backport v{current_from_version} -> v{current_to_version} failed")

                backport_version = _get_model_bytecode_version(tmp_output_model_path)
                self.assertEqual(backport_version, current_to_version)

                # Run the backported output (not the input model) so the
                # backport result itself is what gets exercised.
                mobile_module = _load_for_lite_interpreter(str(tmp_output_model_path))
                mobile_module_result = mobile_module(module_input)
                torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)

            # Backporting below the minimum supported version must be rejected.
            backport_success = _backport_for_mobile(
                tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1)
            self.assertFalse(backport_success)

    def test_backport_bytecode_from_file_to_file(self):
        """Check the file -> file ``_backport_for_mobile`` mechanism.

        Only the mechanism is exercised: one step down from the newest
        registered fixture, not every backport implementation.
        """
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
        if maximum_checked_in_model_version <= MINIMUM_TO_VERSION:
            # Report explicitly instead of passing without checking anything.
            self.skipTest(
                f"no checked-in model above the minimum backport version {MINIMUM_TO_VERSION}")

        to_version = maximum_checked_in_model_version - 1
        input_model_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
            maximum_checked_in_model_version]["model_name"]

        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_backport_model_path = Path(
                tmpdirname,
                f"tmp_script_module_v{maximum_checked_in_model_version}_backported_to_v{to_version}.ptl")
            success = _backport_for_mobile(input_model_path, tmp_backport_model_path, to_version)
            self.assertTrue(success)

            # Compare against the pattern registered for the target version,
            # not a hard-coded one, so the check follows the registry.
            self._assert_bytecode_pkl_matches(
                tmp_backport_model_path,
                SCRIPT_MODULE_BYTECODE_PKL[to_version]["bytecode_pkl"])

            # Load the backported model and run forward.
            mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
            module_input = 1
            mobile_module_result = mobile_module(module_input)
            expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
            torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)

    def test_backport_bytecode_from_file_to_buffer(self):
        """Check the ``_backport_for_mobile_to_buffer`` mechanism.

        Only the mechanism is exercised: one step down from the newest
        registered fixture, not every backport implementation.
        """
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
        if maximum_checked_in_model_version <= MINIMUM_TO_VERSION:
            # Report explicitly instead of passing without checking anything.
            self.skipTest(
                f"no checked-in model above the minimum backport version {MINIMUM_TO_VERSION}")

        to_version = maximum_checked_in_model_version - 1
        input_model_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
            maximum_checked_in_model_version]["model_name"]

        # Backport one version down into an in-memory buffer.
        backported_buffer = _backport_for_mobile_to_buffer(input_model_path, to_version)

        # Each consumer below gets its own BytesIO over the same bytes.
        backport_version = _get_model_bytecode_version(io.BytesIO(backported_buffer))
        self.assertEqual(backport_version, to_version)

        mobile_module = _load_for_lite_interpreter(io.BytesIO(backported_buffer))
        module_input = 1
        mobile_module_result = mobile_module(module_input)
        expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
        torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)

    def test_get_model_ops_and_info(self):
        """``_get_model_ops_and_info`` maps operator names to info exposing ``num_schema_args``."""
        # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
        script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
        ops_v6 = _get_model_ops_and_info(script_module_v6)
        self.assertEqual(ops_v6["aten::add.int"].num_schema_args, 2)
        self.assertEqual(ops_v6["aten::add.Scalar"].num_schema_args, 2)
 | |
| 
 | |
# When this file is executed directly (not imported), run its tests through
# run_tests, imported from torch.testing._internal.common_utils.
if __name__ == '__main__':
    run_tests()
 |