mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[PyTorch, Mobile] Serialization format change for source range (#54284)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54284 In order to bring mobile deployment, via lite interpreter, on feature parity with JIT, with respect model level debug information we must make model level debug information available to mobile runtime. At the moment, model level debug information is stored in SourceRange which associates node's of graph to where the come from in original python source code. This information is serialized as part of debug_pkl and deserialized when JIT loads the model and reads the model code. On lite interpreter, we do not have access to all the functionality of JIT and hence we cannot load model in the same way as JIT, by reading code, constructing module hierarchy and graph corresponding module methods etc. Instead in, lite interpreter, only bytecode corresonding to the compiled graph, Code, is saved. Thus in order to annotate OPs in the bytecode with equivalent SourceRange information we do the following: 1. During model serialization, we create a unique tag for each source range of the model. 2. Create a map of <SourceRange, tag> 3. During debug_pkl serialization we save tag along with SourceRange, on top of byte offset. 4. During bytecode generation, the methods of the top module are lowered. During this process methods are inlined. In the inlined graph, when the node of a graph is lowered to bytecode, we query node's source range and look it up against the map. 5. Resulting source range tag is serialized in module_debug_info. 6. During model deserialization, we read all the debug_pkl records in the archieve and create a map of <tag, SourceRange> 7. This map can be used to find source code information. During mobile runtime: 1. We read all the debug_pkl records and create <tag=debug_handle, SourceRange> map. 1.1 This map, MobileDebugInfo, is a member of mobile Module. 2. Interpreter catches appropriate exceptions and sets the thread local debug handle and rethrows the exception. 3. In Function's run method we catch exception and query current debug handle where the exception happened. 4. Query MobileDebugInfo with debug handle to retrieve source range and augment error with source range info. This information is still incomplete as it does not contain entire callstack. In the following diffs we will serialize InlinedCallStack directly. Note that compilation is gated by SYMBOLICATE_MOBILE_DEBUG_HANDLE macro, so that mobile builds can avoid building MobileDebugInfo, source range and source range pickler/unpickler. Later we will add path where, if building without debug support stack trace will contain only debug handles. They can be symbolicated later. Test Plan: Ported bunch of source range tests from test_jit.py. Added on more test in test_lite_interpreter.py Imported from OSS Reviewed By: raziel Differential Revision: D27174722 fbshipit-source-id: a7b7c6088ce16dec37e823c7fefa4f0b61047e12
This commit is contained in:
		
				
					committed by
					
						 Facebook GitHub Bot
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							aa5ff7cc91
						
					
				
				
					commit
					f4a921600a
				
			| @ -254,6 +254,7 @@ if(NOT DEFINED USE_VULKAN) | ||||
|       "ANDROID" OFF) | ||||
| endif() | ||||
|  | ||||
| option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON) | ||||
| option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) | ||||
| option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) | ||||
| option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF) | ||||
| @ -647,6 +648,10 @@ if(USE_PYTORCH_METAL) | ||||
|   string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL") | ||||
| endif() | ||||
|  | ||||
| if(USE_SOURCE_DEBUG_ON_MOBILE) | ||||
|   string(APPEND CMAKE_CXX_FLAGS " -DSYMBOLICATE_MOBILE_DEBUG_HANDLE") | ||||
| endif() | ||||
|  | ||||
| # ---[ Allowlist file if allowlist is specified | ||||
| include(cmake/Allowlist.cmake) | ||||
|  | ||||
|  | ||||
| @ -516,9 +516,22 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) | ||||
|   list(APPEND TORCH_SRCS ${GENERATED_H_TORCH}) | ||||
|   list(APPEND LIBTORCH_CMAKE_SRCS "") | ||||
|  | ||||
|   list(APPEND LITE_EAGER_SYMOBLICATION_SRCS "") | ||||
|   if(USE_SOURCE_DEBUG_ON_MOBILE) | ||||
|     append_filelist("libtorch_lite_eager_symbolication" LITE_EAGER_SYMOBLICATION_SRCS) | ||||
|     # For source debug on lite interpreter, we have to add dependency on pickling | ||||
|     # but references to read/writeArchiveAndTensor is not built for mobile | ||||
|     # so this condition specifically says we are building for source debug | ||||
|     # on mobile. | ||||
|     if(BUILD_LITE_INTERPRETER) | ||||
|       set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/serialization/pickle.cpp PROPERTIES COMPILE_FLAGS "-DC10_MOBILE -DFEATURE_TORCH_MOBILE") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
|   # Switch between the full jit interpreter and lite interpreter | ||||
|   if(BUILD_LITE_INTERPRETER) | ||||
|     append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS) | ||||
|     list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS}) | ||||
|   else() | ||||
|     append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS) | ||||
|  | ||||
| @ -565,6 +578,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) | ||||
|        ${TORCH_SRC_DIR}/csrc/jit/mobile/sequential.cpp | ||||
|        ) | ||||
|     list(APPEND TORCH_SRCS ${MOBILE_SRCS}) | ||||
|     list(APPEND TORCH_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS}) | ||||
|   endif() | ||||
|  | ||||
|   # This one needs to be unconditionally added as Functions.cpp is also unconditionally added | ||||
|  | ||||
| @ -4,12 +4,27 @@ from torch.utils.mobile_optimizer import optimize_for_mobile | ||||
| import io | ||||
| from typing import Dict, List, NamedTuple | ||||
| from collections import namedtuple | ||||
| import inspect | ||||
|  | ||||
| from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list | ||||
| from torch.testing._internal.common_utils import TestCase, run_tests | ||||
|  | ||||
| class TestLiteScriptModule(TestCase): | ||||
|  | ||||
|     def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False): | ||||
|         m_scripted = torch.jit.script(m) | ||||
|  | ||||
|         if not also_test_file: | ||||
|             buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info)) | ||||
|             buffer.seek(0) | ||||
|             mobile_module = _load_for_lite_interpreter(buffer) | ||||
|             return mobile_module | ||||
|  | ||||
|         with TemporaryFileName() as fname: | ||||
|             m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info) | ||||
|             mobile_module = _load_for_lite_interpreter(fname) | ||||
|             return mobile_module | ||||
|  | ||||
|     def test_load_mobile_module(self): | ||||
|         class MyTestModule(torch.nn.Module): | ||||
|             def __init__(self): | ||||
| @ -374,5 +389,69 @@ class TestLiteScriptModule(TestCase): | ||||
|         actual_ops = _export_operator_list(mobile_module) | ||||
|         self.assertEqual(actual_ops, expected_ops) | ||||
|  | ||||
|     def test_source_range_simple(self): | ||||
|  | ||||
|         class FooTest(torch.jit.ScriptModule): | ||||
|             @torch.jit.script_method | ||||
|             def forward(self, x, w): | ||||
|                 return torch.mm(x, w.t()) | ||||
|  | ||||
|         ft = FooTest() | ||||
|         loaded = self.getScriptExportImportCopy(ft) | ||||
|         _, lineno = inspect.getsourcelines(FooTest) | ||||
|  | ||||
|         with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): | ||||
|             loaded(torch.rand(3, 4), torch.rand(30, 40)) | ||||
|  | ||||
|     def test_source_range_raise_exception(self): | ||||
|  | ||||
|         class FooTest2(torch.jit.ScriptModule): | ||||
|             @torch.jit.script_method | ||||
|             def forward(self): | ||||
|                 raise RuntimeError('foo') | ||||
|  | ||||
|         _, lineno = inspect.getsourcelines(FooTest2) | ||||
|  | ||||
|         with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): | ||||
|             ft = FooTest2() | ||||
|             loaded = self.getScriptExportImportCopy(ft) | ||||
|             loaded() | ||||
|  | ||||
|     def test_source_range_function_call(self): | ||||
|         class FooTest3(torch.jit.ScriptModule): | ||||
|             @torch.jit.script_method | ||||
|             def add_method(self, x, w): | ||||
|                 return x + w | ||||
|  | ||||
|             @torch.jit.script_method | ||||
|             def forward(self, x, y, w): | ||||
|                 x = x * y | ||||
|                 x = x + 2 | ||||
|                 return self.add_method(x, w) | ||||
|  | ||||
|         ft = FooTest3() | ||||
|         loaded = self.getScriptExportImportCopy(ft) | ||||
|         _, lineno = inspect.getsourcelines(FooTest3) | ||||
|  | ||||
|         with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): | ||||
|             loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) | ||||
|  | ||||
|     def test_source_range_no_debug_info(self): | ||||
|  | ||||
|         class FooTest4(torch.jit.ScriptModule): | ||||
|             @torch.jit.script_method | ||||
|             def forward(self, x, w): | ||||
|                 return torch.mm(x, w.t()) | ||||
|  | ||||
|         ft = FooTest4() | ||||
|         loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False) | ||||
|  | ||||
|         try: | ||||
|             loaded(torch.rand(3, 4), torch.rand(30, 40)) | ||||
|         except RuntimeError as e: | ||||
|             error_message = f"{e}" | ||||
|         self.assertTrue("test_lite_script_module.py" not in error_message) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run_tests() | ||||
|  | ||||
| @ -4182,8 +4182,8 @@ def foo(xyz): | ||||
|         debug_files = debug_records_from_mod(ft3) | ||||
|         for debug_file in debug_files: | ||||
|             for i in range(len(debug_file) - 1): | ||||
|                 offset, source_range = debug_file[i] | ||||
|                 offset2, source_range2 = debug_file[i + 1] | ||||
|                 offset, source_range_tag, source_range = debug_file[i] | ||||
|                 offset2, source_range_tag2, source_range2 = debug_file[i + 1] | ||||
|                 self.assertNotEqual(source_range, source_range2) | ||||
|  | ||||
|     def test_circular_dependency(self): | ||||
|  | ||||
| @ -356,6 +356,17 @@ torch_mobile_core = [ | ||||
|     "torch/csrc/jit/runtime/register_special_ops.cpp", | ||||
| ] | ||||
|  | ||||
| libtorch_lite_eager_symbolication = [ | ||||
|     "torch/csrc/jit/frontend/source_range.cpp", | ||||
|     "torch/csrc/jit/mobile/debug_info.cpp", | ||||
|     "torch/csrc/jit/serialization/source_range_serialization.cpp", | ||||
|     # Later we can split serialization and deserialization logic | ||||
|     # to have better separation within build and only build relevant parts. | ||||
|     "torch/csrc/jit/serialization/pickle.cpp", | ||||
|     "torch/csrc/jit/serialization/pickler.cpp", | ||||
|     "torch/csrc/jit/serialization/unpickler.cpp", | ||||
| ] | ||||
|  | ||||
| # TODO: core_trainer_sources is not necessary for libtorch lite | ||||
| libtorch_lite_cmake_sources = sorted(core_trainer_sources + core_sources_common + torch_mobile_core) | ||||
|  | ||||
| @ -368,6 +379,9 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ | ||||
|     "torch/csrc/jit/api/module_save.cpp", | ||||
|     "torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp", | ||||
|     "torch/csrc/jit/mobile/export_data.cpp", | ||||
|     # To be included for eager symbolication in lite interpreter | ||||
|     # when it is built in libtorch | ||||
|     "torch/csrc/jit/mobile/debug_info.cpp", | ||||
|     "torch/csrc/jit/mobile/function.cpp", | ||||
|     "torch/csrc/jit/mobile/import.cpp", | ||||
|     "torch/csrc/jit/mobile/import_data.cpp", | ||||
|  | ||||
| @ -3,6 +3,11 @@ | ||||
|  | ||||
| namespace torch { | ||||
| namespace jit { | ||||
| size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const { | ||||
|   return ( | ||||
|       std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^ | ||||
|       std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end())); | ||||
| } | ||||
|  | ||||
| c10::optional<SourceRange> Source::findSourceRangeThatGenerated( | ||||
|     const SourceRange& range) { | ||||
|  | ||||
| @ -5,6 +5,7 @@ | ||||
| #include <algorithm> | ||||
| #include <iostream> | ||||
| #include <memory> | ||||
| #include <unordered_map> | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| @ -178,6 +179,11 @@ struct TORCH_API SourceRange { | ||||
|   size_t end_; | ||||
| }; | ||||
|  | ||||
| struct SourceRangeHasher { | ||||
|  public: | ||||
|   size_t operator()(const torch::jit::SourceRange& key) const; | ||||
| }; | ||||
|  | ||||
| struct StackEntry { | ||||
|   std::string filename; | ||||
|   SourceRange range; | ||||
| @ -201,6 +207,8 @@ struct TaggedRange { | ||||
|   SourceRange range; | ||||
| }; | ||||
| using SourceRangeRecords = std::vector<TaggedRange>; | ||||
| using SourceRangeTagMap = | ||||
|     std::unordered_map<SourceRange, int64_t, SourceRangeHasher>; | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
							
								
								
									
										53
									
								
								torch/csrc/jit/mobile/debug_info.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								torch/csrc/jit/mobile/debug_info.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | ||||
| #include <torch/csrc/jit/mobile/debug_info.h> | ||||
| #include <torch/csrc/jit/serialization/source_range_serialization.h> | ||||
|  | ||||
| #include <ATen/core/ivalue.h> | ||||
| #include <torch/csrc/jit/serialization/pickle.h> | ||||
|  | ||||
| #include <c10/util/string_view.h> | ||||
|  | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| MobileDebugTable::MobileDebugTable( | ||||
|     std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader) { | ||||
|   const std::vector<std::string>& record_names = reader->getAllRecords(); | ||||
|   const c10::string_view suffix(".debug_pkl"); | ||||
|   for (const auto& record_name : record_names) { | ||||
|     if (c10::string_view(record_name).ends_with(suffix)) { | ||||
|       at::DataPtr debug_data; | ||||
|       size_t debug_size{0}; | ||||
|       std::tie(debug_data, debug_size) = reader->getRecord(record_name); | ||||
|       auto ivalues = | ||||
|           jit::unpickle( | ||||
|               reinterpret_cast<const char*>(debug_data.get()), debug_size) | ||||
|               .toTuple() | ||||
|               ->elements(); | ||||
|       SourceRangeDeserializer deserializer; | ||||
|       for (auto& val : ivalues) { | ||||
|         auto tup_elems = val.toTuple()->elements(); | ||||
|         // For BC we decode only tuples with 3 elements | ||||
|         // assuming it contains | ||||
|         // byte_offset, debug_handle (=source range tag), source range | ||||
|         if (tup_elems.size() == 3) { | ||||
|           int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt(); | ||||
|           auto source_range = | ||||
|               deserializer.deserialize(tup_elems[kSourceRangeIndex]); | ||||
|           source_range_map_.emplace(debug_handle, std::move(source_range)); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::string MobileDebugTable::getSourceDebugString( | ||||
|     const int64_t debug_handle) const { | ||||
|   const auto it = source_range_map_.find(debug_handle); | ||||
|   if (it == source_range_map_.end()) { | ||||
|     return ""; | ||||
|   } | ||||
|   return source_range_map_.at(debug_handle).str(); | ||||
| } | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
							
								
								
									
										31
									
								
								torch/csrc/jit/mobile/debug_info.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								torch/csrc/jit/mobile/debug_info.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,31 @@ | ||||
| #pragma once | ||||
| #include <c10/util/flat_hash_map.h> | ||||
| #include <caffe2/serialize/inline_container.h> | ||||
| #include <torch/csrc/jit/frontend/source_range.h> | ||||
|  | ||||
| namespace torch { | ||||
| namespace jit { | ||||
| /* | ||||
|  * MobileDebugTable: | ||||
|  * Deserializes debug_pkl records from PT model's zip archive and | ||||
|  * stores them in a map of debug handles to source range. | ||||
|  * Debug handles are unique per model and runtime, be in lite interpreter | ||||
|  * or delegate, raises exception using debug handles. | ||||
|  * getSourceDebugString method is responsible for translating debug | ||||
|  * handles to correspond debug information. | ||||
|  * At the moment this only contains information about model source. | ||||
|  * But later diffs will include entire stack corresponding to debug handle. | ||||
|  */ | ||||
| class MobileDebugTable { | ||||
|  public: | ||||
|   MobileDebugTable() = default; | ||||
|   MobileDebugTable( | ||||
|       std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader); | ||||
|   std::string getSourceDebugString(const int64_t debug_handle) const; | ||||
|  | ||||
|  private: | ||||
|   ska::flat_hash_map<int64_t, SourceRange> source_range_map_; | ||||
| }; | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
| @ -22,12 +22,13 @@ const std::string& Function::name() const { | ||||
|   return name_.name(); | ||||
| } | ||||
|  | ||||
| void Function::append_instruction(OpCode op, int X, int N) { | ||||
| void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) { | ||||
|   TORCH_CHECK( | ||||
|       isOpSupportedInMobile(op), | ||||
|       toString(op), | ||||
|       " is not supported in mobile module."); | ||||
|   code_->instructions_.emplace_back(op, X, N); | ||||
|   code_->debug_handles_.emplace_back(dbg_handle); | ||||
| } | ||||
|  | ||||
| bool Function::append_operator( | ||||
| @ -130,6 +131,11 @@ const std::shared_ptr<Code> Function::get_code() const { | ||||
|   return code_; | ||||
| } | ||||
|  | ||||
| int64_t Function::getExceptionDebugHandle() const { | ||||
|   size_t pc = getInterpretersExceptionPC(); | ||||
|   return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1; | ||||
| } | ||||
|  | ||||
| } // namespace mobile | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -19,7 +19,7 @@ class Function { | ||||
|   c10::IValue operator()(Stack& stack) const; | ||||
|   const std::string& name() const; | ||||
|   const c10::QualifiedName& qualname() const; | ||||
|   void append_instruction(OpCode op, int X, int N); | ||||
|   void append_instruction(OpCode op, int X, int N, int64_t dbg_handle = -1); | ||||
|   bool append_operator( | ||||
|       const std::string& name, | ||||
|       const std::string& overload_name, | ||||
| @ -37,6 +37,11 @@ class Function { | ||||
|   void setSchema(c10::FunctionSchema schema); | ||||
|   const at::optional<c10::FunctionSchema>& getSchema() const; | ||||
|  | ||||
|   // Returns the debug handle corresponding to where the execution | ||||
|   // is halted due to exception. | ||||
|   // If no corresponding debug handle is found then -1 is returned. | ||||
|   int64_t getExceptionDebugHandle() const; | ||||
|  | ||||
|  private: | ||||
|   c10::QualifiedName name_; | ||||
|   std::shared_ptr<Code> code_; | ||||
|  | ||||
| @ -305,12 +305,25 @@ void BytecodeDeserializer::parseMethods( | ||||
|       OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str()); | ||||
|       int X = ins_item[1].toInt(); | ||||
|       int N = ins_item[2].toInt(); | ||||
|       function->append_instruction(op_code, X, N); | ||||
|       // TODO: Save debug handles for all instructions, not just for OP | ||||
|       if (op_code == OP) { | ||||
|         std::string module_debug_info = (has_debug_info) | ||||
|             ? module_debug_info_list[X].toString()->string() | ||||
|             : ""; | ||||
|         // In later PRs we will refactor this to always save debug handles. | ||||
|         // Debug info, source range and inlined callstack ptr saving will become | ||||
|         // optional. | ||||
|         if (has_debug_info) { | ||||
|           auto module_debug_tuple = | ||||
|               module_debug_info_list[X].toTuple()->elements(); | ||||
|           std::string module_debug_info = | ||||
|               module_debug_tuple[0].toString()->string(); | ||||
|           int64_t debug_handle = module_debug_tuple[1].toInt(); | ||||
|           function->set_module_info(module_debug_info, i); | ||||
|           function->append_instruction(op_code, X, N, debug_handle); | ||||
|         } else { | ||||
|           function->set_module_info("", i); | ||||
|           function->append_instruction(op_code, X, N); | ||||
|         } | ||||
|       } else { | ||||
|         function->append_instruction(op_code, X, N); | ||||
|       } | ||||
|     } | ||||
|  | ||||
| @ -441,7 +454,12 @@ mobile::Module BytecodeDeserializer::deserialize( | ||||
|   } | ||||
|   parseMethods(bvals, debug_info_bvals, *mcu); | ||||
|   auto meta_dict = readMobileMetadata(mcu); | ||||
|   return mobile::Module(readArchive("data", mcu).toObject(), meta_dict, mcu); | ||||
|   auto m = mobile::Module(readArchive("data", mcu).toObject(), meta_dict, mcu); | ||||
| #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) | ||||
|   MobileDebugTable debug_table = MobileDebugTable(reader_); | ||||
|   m.setDebugTable(std::move(debug_table)); | ||||
| #endif | ||||
|   return m; | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, std::string> BytecodeDeserializer:: | ||||
|  | ||||
| @ -4,6 +4,7 @@ | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <ATen/core/operator_name.h> | ||||
| #include <torch/csrc/jit/mobile/function.h> | ||||
| #include <torch/csrc/jit/runtime/jit_exception.h> | ||||
| #include <torch/csrc/jit/runtime/vararg_functions.h> | ||||
|  | ||||
| #include <ATen/record_function.h> | ||||
| @ -20,6 +21,8 @@ InterpreterState::InterpreterState(std::shared_ptr<Code> code) | ||||
| } | ||||
|  | ||||
| namespace { | ||||
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | ||||
| static thread_local int64_t exception_pc_{-1}; | ||||
| void createObject(Stack& stack, const at::ClassTypePtr& type) { | ||||
|   auto userObj = c10::ivalue::Object::create( | ||||
|       c10::StrongTypePtr(type->compilation_unit(), type), | ||||
| @ -41,9 +44,14 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) { | ||||
|  | ||||
| using namespace at; | ||||
|  | ||||
| int64_t getInterpretersExceptionPC() { | ||||
|   return exception_pc_; | ||||
| } | ||||
|  | ||||
| bool InterpreterState::run(Stack& stack) { | ||||
|   size_t pc = 0; | ||||
|   while (true) { | ||||
|     try { | ||||
|       Instruction inst = code_->instructions_[pc]; | ||||
|  | ||||
|       //    std::cout << "RUNNING " << pc << " " << code_->instructions_[pc]; | ||||
| @ -57,8 +65,8 @@ bool InterpreterState::run(Stack& stack) { | ||||
|       switch (inst.op) { | ||||
|         case OP: { | ||||
|           if (at::hasGlobalCallbacks()) { | ||||
|           if (auto* mobile_debug_info = | ||||
|                   static_cast<MobileDebugInfo*>(c10::ThreadLocalDebugInfo::get( | ||||
|             if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>( | ||||
|                     c10::ThreadLocalDebugInfo::get( | ||||
|                         c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) { | ||||
|               mobile_debug_info->setOpIdx(pc); | ||||
|             } | ||||
| @ -209,9 +217,9 @@ bool InterpreterState::run(Stack& stack) { | ||||
|         } break; | ||||
|         case WARN: { | ||||
|           drop(stack, 1); | ||||
|         // Note: Please don't move the pop(stack) code below into the TORCH_WARN | ||||
|         // macro since TORCH_WARN fails to evaluate its arguments when | ||||
|         // STRIP_ERROR_MESSAGES is defined (which happens for production | ||||
|           // Note: Please don't move the pop(stack) code below into the | ||||
|           // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments | ||||
|           // when STRIP_ERROR_MESSAGES is defined (which happens for production | ||||
|           // mobile builds). This will cause the stack to be in an inconsistent | ||||
|           // state. It has previously resulted in a SEV (S22350). | ||||
|           const auto& sref = stack.back().toStringRef(); | ||||
| @ -222,6 +230,15 @@ bool InterpreterState::run(Stack& stack) { | ||||
|         default: | ||||
|           AT_ERROR(toString(inst.op), " is invalid."); | ||||
|       } | ||||
|     } catch (c10::Error& error) { | ||||
|       // Reason for catching and rethrowing the error is so that we can | ||||
|       // set the exception pc that is queried later | ||||
|       exception_pc_ = pc; | ||||
|       TORCH_RETHROW(error); | ||||
|     } catch (...) { | ||||
|       exception_pc_ = pc; | ||||
|       throw; | ||||
|     } | ||||
|     //  for (auto val : stack) { | ||||
|     //    if (val.isTensor()) { | ||||
|     //      std::cout << val.toTensor().sizes() << std::endl; | ||||
|  | ||||
| @ -8,9 +8,13 @@ namespace torch { | ||||
| namespace jit { | ||||
| namespace mobile { | ||||
| using Stack = std::vector<c10::IValue>; | ||||
| using DebugHandle = int64_t; | ||||
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) | ||||
| struct Code { | ||||
|   // TODO: Combine instructions and debug handles vector | ||||
|   // into std::vector<<std::pair<Instruction, DebugHandle>> | ||||
|   std::vector<Instruction> instructions_; | ||||
|   std::vector<DebugHandle> debug_handles_; | ||||
|   std::vector<c10::OperatorName> op_names_; | ||||
|   std::vector<std::function<void(Stack&)>> operators_; | ||||
|   std::vector<c10::IValue> constants_; | ||||
| @ -28,6 +32,14 @@ struct InterpreterState { | ||||
|   std::vector<c10::IValue> registers_; | ||||
| }; | ||||
|  | ||||
| // Interpreter executes instruction in a loop one by one | ||||
| // from a list of instructions. PC is a program counter pointer | ||||
| // pointing to the current instruction being executed. | ||||
| // This function returns the current PC. | ||||
| // Note that this is set only when exception occurs. | ||||
| // since this is a thread local variable and setting it for | ||||
| // every instruction will add overhead of thread local variable access. | ||||
| int64_t getInterpretersExceptionPC(); | ||||
| } // namespace mobile | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -166,6 +166,11 @@ void Method::run(Stack& stack) const { | ||||
|       observer->onExitRunMethod(instance_key); | ||||
|     } | ||||
|   } catch (c10::Error& error) { | ||||
| #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) | ||||
|     auto debug_string = owner_->getDebugTable().getSourceDebugString( | ||||
|         function_->getExceptionDebugHandle()); | ||||
|     error.add_context(debug_string); | ||||
| #endif | ||||
|     if (observer) { | ||||
|       observer->onFailRunMethod(instance_key, error.what()); | ||||
|     } | ||||
| @ -183,6 +188,11 @@ void Method::run(Stack& stack) const { | ||||
|         } | ||||
|       } | ||||
|     } catch (c10::Error& error) { | ||||
| #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) | ||||
|       auto debug_string = owner_->getDebugTable().getSourceDebugString( | ||||
|           function_->getExceptionDebugHandle()); | ||||
|       error.add_context(debug_string); | ||||
| #endif | ||||
|       if (observer) { | ||||
|         observer->onFailRunMethod(instance_key, error.what()); | ||||
|       } | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| #pragma once | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <torch/csrc/jit/mobile/debug_info.h> | ||||
| #include <torch/csrc/jit/mobile/function.h> | ||||
| #include <torch/csrc/jit/mobile/method.h> | ||||
|  | ||||
| @ -109,10 +110,18 @@ class TORCH_API Module { | ||||
|     return or_else; | ||||
|   } | ||||
|  | ||||
|   void setDebugTable(MobileDebugTable&& debug_table) { | ||||
|     debug_table_ = std::move(debug_table); | ||||
|   } | ||||
|   const MobileDebugTable& getDebugTable() const { | ||||
|     return debug_table_; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   c10::intrusive_ptr<c10::ivalue::Object> object_; | ||||
|   std::unordered_map<std::string, std::string> metadata_; | ||||
|   std::shared_ptr<CompilationUnit> cu_; | ||||
|   MobileDebugTable debug_table_; | ||||
| }; | ||||
| } // namespace mobile | ||||
| } // namespace jit | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| #include <torch/csrc/jit/serialization/export.h> | ||||
|  | ||||
| #include <c10/util/Exception.h> | ||||
| #include <torch/csrc/jit/frontend/source_range.h> | ||||
| #include <torch/csrc/jit/ir/attributes.h> | ||||
| #include <torch/csrc/jit/ir/ir.h> | ||||
| #include <torch/csrc/jit/ir/type_hashing.h> | ||||
| @ -110,7 +111,8 @@ std::string getModuleTypeName(const Module& module, const std::string& prefix) { | ||||
| std::pair<IValue, c10::optional<IValue>> getFunctionTuple( | ||||
|     const Module& module, | ||||
|     const Function& func, | ||||
|     bool save_mobile_debug_info) { | ||||
|     const bool save_mobile_debug_info, | ||||
|     const SourceRangeTagMap& source_range_tag_map) { | ||||
|   auto graph = func.graph()->copy(); | ||||
|  | ||||
|   Inline(*graph); | ||||
| @ -122,6 +124,7 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple( | ||||
|   std::vector<c10::OperatorName> opnames; | ||||
|   std::vector<std::string> method_names; | ||||
|   std::vector<std::string> op_module_paths; | ||||
|   std::vector<int64_t> op_source_debug_tags; | ||||
|   for (size_t i = 0; i < instructions_copy.size(); ++i) { | ||||
|     Instruction ins = instructions_copy[i]; | ||||
|     if (ins.op == OP || ins.op == OPN) { | ||||
| @ -129,7 +132,30 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple( | ||||
|       opnames.emplace_back(node->schema().operator_name()); | ||||
|       if (save_mobile_debug_info) { | ||||
|         std::string root_scope_string = getModuleTypeName(module, "top"); | ||||
|         // A little explanation as to why node->sourceRange() is not enough | ||||
|         // and we need to do node->sourceRange().findSourceRangeThatGenerated() | ||||
|         // When you do m = torch.jit.script/trace(model) | ||||
|         // the scripted model has graphs for methods with nodes. | ||||
|         // The nodes of this graph are annotated with sourceRange that is | ||||
|         // original python code. However when such a model is serialized via | ||||
|         // torch.jit.save, what is saved is compiled python code of the model. | ||||
|         // The compiled code for methods is effectively compiled graph which | ||||
|         // contains prim/aten etc. ops. This is not the same as original python | ||||
|         // code. When such a serialized model is loaded via torch.jit.load, | ||||
|         // node->sourceRange() does not point to original python code but points | ||||
|         // to transformed/compiled python code. So in order to get original | ||||
|         // python code which is serialized in debug_pkl, it is necessary to do | ||||
|         // node->sourceRange().findSourceRangeThatGenerated() | ||||
|         auto source_range = node->sourceRange().findSourceRangeThatGenerated() | ||||
|             ? node->sourceRange().findSourceRangeThatGenerated().value() | ||||
|             : node->sourceRange(); | ||||
|         int64_t source_range_tag{-1}; | ||||
|         const auto& it = source_range_tag_map.find(source_range); | ||||
|         if (it != source_range_tag_map.end()) { | ||||
|           source_range_tag = it->second; | ||||
|         } | ||||
|         op_module_paths.emplace_back(getModulePath(node, root_scope_string)); | ||||
|         op_source_debug_tags.emplace_back(source_range_tag); | ||||
|       } | ||||
|     } | ||||
|     // CALL nodes at this point represent built-in (i.e. non-Graph) | ||||
| @ -280,12 +306,22 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple( | ||||
|   c10::optional<IValue> debug_info_vals; | ||||
|   if (save_mobile_debug_info) { | ||||
|     // module debug info | ||||
|     std::vector<IValue> module_paths; | ||||
|     module_paths.reserve(op_module_paths.size()); | ||||
|     for (auto& path : op_module_paths) { | ||||
|       module_paths.emplace_back(std::move(path)); | ||||
|     // Temporarily adding source debug tag here. | ||||
|     // In the diffs to follow we should move to serializing | ||||
|     // InlinedCallStack. | ||||
|     // Then we will just serialize either a vector or dictionary | ||||
|     // of debug handles, a.k.a PCs for lite interpreter, | ||||
|     // to InlinedCallStack | ||||
|     std::vector<IValue> module_debug_tuples; | ||||
|     module_debug_tuples.reserve(op_module_paths.size()); | ||||
|     for (size_t i = 0; i < op_module_paths.size(); ++i) { | ||||
|       auto& path = op_module_paths[i]; | ||||
|       int64_t source_debug_tag = op_source_debug_tags[i]; | ||||
|       module_debug_tuples.emplace_back( | ||||
|           c10::ivalue::Tuple::create({std::move(path), source_debug_tag})); | ||||
|     } | ||||
|     auto module_debug_info = Table({{"module_debug_info", Tup(module_paths)}}); | ||||
|     auto module_debug_info = | ||||
|         Table({{"module_debug_info", Tup(module_debug_tuples)}}); | ||||
|     debug_info_vals = Tup({func.qualname().qualifiedName(), module_debug_info}); | ||||
|   } | ||||
|   return std::make_pair(bytecode_vals, debug_info_vals); | ||||
| @ -297,7 +333,8 @@ void setstateTuple( | ||||
|     std::vector<c10::IValue>& elements, | ||||
|     std::unordered_set<std::string>& qn_cache, | ||||
|     c10::optional<std::vector<c10::IValue>>& debug_info_elements, | ||||
|     bool save_mobile_debug_info) { | ||||
|     const bool save_mobile_debug_info, | ||||
|     const SourceRangeTagMap& source_range_map) { | ||||
|   if (!ivalue.isObject()) | ||||
|     return; | ||||
|   auto obj = ivalue.toObject(); | ||||
| @ -309,8 +346,8 @@ void setstateTuple( | ||||
|       return; | ||||
|     } | ||||
|     if (setstate.isGraphFunction()) { | ||||
|       auto func_tuple = | ||||
|           getFunctionTuple(module, setstate, save_mobile_debug_info); | ||||
|       auto func_tuple = getFunctionTuple( | ||||
|           module, setstate, save_mobile_debug_info, source_range_map); | ||||
|       elements.push_back(func_tuple.first); | ||||
|       qn_cache.emplace(qn); | ||||
|       if (save_mobile_debug_info) { | ||||
| @ -325,7 +362,8 @@ void setstateTuple( | ||||
|           elements, | ||||
|           qn_cache, | ||||
|           debug_info_elements, | ||||
|           save_mobile_debug_info); | ||||
|           save_mobile_debug_info, | ||||
|           source_range_map); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| @ -335,7 +373,8 @@ void moduleMethodsTuple( | ||||
|     const Module& module, | ||||
|     std::vector<c10::IValue>& elements, // note: appended to in-place | ||||
|     c10::optional<std::vector<c10::IValue>>& debug_info_elements, | ||||
|     bool save_mobile_debug_info) { | ||||
|     const bool save_mobile_debug_info, | ||||
|     const SourceRangeTagMap& source_range_map = SourceRangeTagMap()) { | ||||
|   auto methods = module.get_methods(); | ||||
|   std::unordered_set<std::string> qn_cache; | ||||
|   // top level methods | ||||
| @ -344,8 +383,8 @@ void moduleMethodsTuple( | ||||
|     if (qn_cache.find(qn) != qn_cache.end()) { | ||||
|       continue; | ||||
|     } | ||||
|     auto func_tuple = | ||||
|         getFunctionTuple(module, method.function(), save_mobile_debug_info); | ||||
|     auto func_tuple = getFunctionTuple( | ||||
|         module, method.function(), save_mobile_debug_info, source_range_map); | ||||
|     elements.push_back(func_tuple.first); | ||||
|     qn_cache.emplace(qn); | ||||
|     if (save_mobile_debug_info) { | ||||
| @ -360,7 +399,8 @@ void moduleMethodsTuple( | ||||
|       elements, | ||||
|       qn_cache, | ||||
|       debug_info_elements, | ||||
|       save_mobile_debug_info); | ||||
|       save_mobile_debug_info, | ||||
|       source_range_map); | ||||
| } | ||||
|  | ||||
| void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) { | ||||
| @ -488,6 +528,15 @@ class ScriptModuleSerializer { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void updateSourceRangeTags(const SourceRangeRecords& ranges) { | ||||
|     for (const auto& range : ranges) { | ||||
|       if (source_range_tags_.find(range.range) == source_range_tags_.end()) { | ||||
|         source_range_tags_[range.range] = current_source_range_tag_; | ||||
|         current_source_range_tag_++; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void writeCode(const at::NamedTypePtr& root_type) { | ||||
|     class_deps_.add(root_type); | ||||
|     for (size_t i = 0; i < class_deps_.size(); ++i) { | ||||
| @ -496,6 +545,7 @@ class ScriptModuleSerializer { | ||||
|       convertNamedType(class_deps_[i]); | ||||
|     } | ||||
|  | ||||
|     current_source_range_tag_ = 0; | ||||
|     // Mapping of filename => src. We need this because multiple classes may go | ||||
|     // in the same file (e.g. foo.bar.Baz and foo.bar.Qux) | ||||
|     for (auto& item : file_streams_) { | ||||
| @ -517,7 +567,9 @@ class ScriptModuleSerializer { | ||||
|       // Write out the debug information | ||||
|       std::string debugFilename = filename + ".debug_pkl"; | ||||
|       SourceRangePickler source_range_pickler; | ||||
|       auto range_data = source_range_pickler.pickle(item.value().ranges()); | ||||
|       updateSourceRangeTags(item.value().ranges()); | ||||
|       auto range_data = source_range_pickler.pickle( | ||||
|           item.value().ranges(), source_range_tags_); | ||||
|       writer_.writeRecord( | ||||
|           debugFilename, | ||||
|           range_data.data(), | ||||
| @ -526,7 +578,7 @@ class ScriptModuleSerializer { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void writeByteCode(const Module& module, bool save_mobile_debug_info) { | ||||
|   void writeByteCode(const Module& module, const bool save_mobile_debug_info) { | ||||
|     std::vector<c10::IValue> elements; | ||||
|     elements.emplace_back( | ||||
|         static_cast<int64_t>(caffe2::serialize::kProducedBytecodeVersion)); | ||||
| @ -538,7 +590,11 @@ class ScriptModuleSerializer { | ||||
|     } | ||||
|  | ||||
|     moduleMethodsTuple( | ||||
|         module, elements, debug_info_elements, save_mobile_debug_info); | ||||
|         module, | ||||
|         elements, | ||||
|         debug_info_elements, | ||||
|         save_mobile_debug_info, | ||||
|         source_range_tags_); | ||||
|     auto telements = Tup(std::move(elements)); | ||||
|     writeArchive("bytecode", telements); | ||||
|     if (save_mobile_debug_info) { | ||||
| @ -585,6 +641,32 @@ class ScriptModuleSerializer { | ||||
|   // qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be | ||||
|   // created | ||||
|   OrderedDict<std::string, PythonPrint> file_streams_; | ||||
|  | ||||
|   // Uniquely identifies a SourceRange in a model. | ||||
|   // SourceRanges are associated with Nodes of Graphs. | ||||
|   // However for mobile deployment we dont intend to ship | ||||
|   // full JIT with capabilities of reading code and constructing | ||||
|   // graphs. | ||||
|   // Instead we serialize the Code generated from graph of the methods. | ||||
|   // Code is serialized in bytecode format that contains instructions | ||||
|   // corresponding to the nodes of the graph. Since original graph is gone, the | ||||
|   // question is how do we identify where the ops, in serialized bytecode, come | ||||
|   // from in original model code. We do this in two parts. | ||||
|   // 1. Associate a unique tag to SourceRange. | ||||
|   // 2. Serialize this unique_tag. | ||||
|   //  2.1 Meaning save <byte_offset, source_range_tag, source range> instead of | ||||
|   //      <byte_offset, source range> | ||||
|   // 3. During serializing model for mobile, i.e. bytecode generation, | ||||
|   //    save unique tag of SourceRange corresponding to the Node. | ||||
|   // 4. During deserialization, read all the debug_pkl, to construct a map | ||||
|   //    of <unique_tag, SourceRange> and use tag saved with OPs in bytecode | ||||
|   //    to lookup the source range. | ||||
|   // Strictly speaking we will serialize InlinedCallStack directly, which | ||||
|   // contains SourceRange. This way we have access to entire callstack and not | ||||
|   // just source information about where the node is, since bytecode inlines the | ||||
|   // graph before saving it. | ||||
|   SourceRangeTagMap source_range_tags_; | ||||
|   int64_t current_source_range_tag_; | ||||
| }; | ||||
|  | ||||
| void ExportModule( | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| #include <torch/csrc/jit/serialization/source_range_serialization.h> | ||||
| #include <torch/csrc/jit/serialization/source_range_serialization_impl.h> | ||||
|  | ||||
| #include <ATen/core/ivalue.h> | ||||
| #include <torch/csrc/jit/serialization/pickle.h> | ||||
|  | ||||
| namespace torch { | ||||
| @ -23,9 +22,7 @@ class SourceRangeSerializer { | ||||
|   std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources; | ||||
| }; | ||||
|  | ||||
| class SourceRangeDeserializer { | ||||
|  public: | ||||
|   SourceRange deserialize(const c10::IValue& iv) { | ||||
| SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) { | ||||
|   auto tup_elems = iv.toTuple()->elements(); | ||||
|   TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); | ||||
|   std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]); | ||||
| @ -34,8 +31,8 @@ class SourceRangeDeserializer { | ||||
|   return SourceRange(source_, start_, end_); | ||||
| } | ||||
|  | ||||
|  private: | ||||
|   std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) { | ||||
| std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source( | ||||
|     const c10::IValue& iv) { | ||||
|   auto tup = iv.toTuple(); | ||||
|   if (cached_sources.count(tup)) { | ||||
|     return cached_sources.at(tup); | ||||
| @ -44,8 +41,7 @@ class SourceRangeDeserializer { | ||||
|   auto tup_elems = tup->elements(); | ||||
|   TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); | ||||
|   std::string text_ = tup_elems[0].toString()->string(); | ||||
|     c10::optional<std::string> filename_ = | ||||
|         tup_elems[1].toOptional<std::string>(); | ||||
|   c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>(); | ||||
|   int64_t starting_line_no_ = tup_elems[2].toInt(); | ||||
|  | ||||
|   auto source = std::make_shared<Source>( | ||||
| @ -54,12 +50,6 @@ class SourceRangeDeserializer { | ||||
|   return source; | ||||
| } | ||||
|  | ||||
|   std::unordered_map< | ||||
|       c10::intrusive_ptr<c10::ivalue::Tuple>, | ||||
|       std::shared_ptr<Source>> | ||||
|       cached_sources; | ||||
| }; | ||||
|  | ||||
| c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) { | ||||
|   std::vector<c10::IValue> elements = { | ||||
|       serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()}; | ||||
| @ -84,11 +74,20 @@ c10::IValue SourceRangeSerializer::serialize_source( | ||||
|  | ||||
| SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {} | ||||
|  | ||||
| std::vector<char> SourceRangePickler::pickle(const SourceRangeRecords& ranges) { | ||||
| std::vector<char> SourceRangePickler::pickle( | ||||
|     const SourceRangeRecords& ranges, | ||||
|     const SourceRangeTagMap& source_range_tags) { | ||||
|   std::vector<c10::IValue> ivalues; | ||||
|   for (const auto& range : ranges) { | ||||
|     int64_t source_range_tag{-1}; | ||||
|     const auto& it = source_range_tags.find(range.range); | ||||
|     if (it != source_range_tags.end()) { | ||||
|       source_range_tag = it->second; | ||||
|     } | ||||
|     std::vector<c10::IValue> row_elems{ | ||||
|         (int64_t)range.bytes, srs->serialize(range.range)}; | ||||
|         (int64_t)range.bytes, | ||||
|         srs->serialize(range.range), | ||||
|         static_cast<int64_t>(source_range_tag)}; | ||||
|     ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems))); | ||||
|   } | ||||
|   std::vector<at::Tensor> table; | ||||
| @ -118,8 +117,8 @@ void ConcreteSourceRangeUnpickler::unpickle() { | ||||
|   unpickled_records = std::make_shared<SourceRangeRecords>(); | ||||
|   for (auto& val : ivalues) { | ||||
|     auto tup_elems = val.toTuple()->elements(); | ||||
|     int64_t offset = tup_elems[0].toInt(); | ||||
|     auto source_range = deserializer->deserialize(tup_elems[1]); | ||||
|     int64_t offset = tup_elems[kByteOffsetIndex].toInt(); | ||||
|     auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]); | ||||
|     unpickled_records->emplace_back(offset, std::move(source_range)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| @ -3,6 +3,8 @@ | ||||
| #include <c10/core/Allocator.h> | ||||
| #include <torch/csrc/jit/frontend/source_range.h> | ||||
|  | ||||
| #include <ATen/core/ivalue.h> | ||||
|  | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
|  | ||||
| @ -15,18 +17,34 @@ namespace jit { | ||||
|  | ||||
| class Pickler; | ||||
| class SourceRangeSerializer; | ||||
| class SourceRangeDeserializer; | ||||
| static constexpr size_t kByteOffsetIndex = 0; | ||||
| static constexpr size_t kSourceRangeIndex = 1; | ||||
| static constexpr size_t kSourceRangeTagIndex = 2; | ||||
|  | ||||
| class SourceRangePickler { | ||||
|  public: | ||||
|   SourceRangePickler(); | ||||
|  | ||||
|   std::vector<char> pickle(const SourceRangeRecords& ranges); | ||||
|   std::vector<char> pickle( | ||||
|       const SourceRangeRecords& ranges, | ||||
|       const SourceRangeTagMap& source_range_tags); | ||||
|  | ||||
|  private: | ||||
|   std::shared_ptr<SourceRangeSerializer> srs; | ||||
| }; | ||||
|  | ||||
| class SourceRangeDeserializer { | ||||
|  public: | ||||
|   SourceRange deserialize(const c10::IValue& iv); | ||||
|  | ||||
|  private: | ||||
|   std::shared_ptr<Source> deserialize_source(const c10::IValue& iv); | ||||
|   std::unordered_map< | ||||
|       c10::intrusive_ptr<c10::ivalue::Tuple>, | ||||
|       std::shared_ptr<Source>> | ||||
|       cached_sources; | ||||
| }; | ||||
|  | ||||
| class SourceRangeUnpickler { | ||||
|  public: | ||||
|   virtual c10::optional<SourceRange> findSourceRangeThatGenerated( | ||||
|  | ||||
| @ -240,7 +240,8 @@ def get_model_info( | ||||
|  | ||||
|             code_parts = [] | ||||
|             for di, di_next in zip(debug_info, debug_info[1:]): | ||||
|                 start, source_range = di | ||||
|                 # accounting for source range serialization format change | ||||
|                 start, source_range, _ = di | ||||
|                 end = di_next[0] | ||||
|                 assert end > start | ||||
|                 source, s_start, s_end = source_range | ||||
|  | ||||
		Reference in New Issue
	
	Block a user