diff --git a/CMakeLists.txt b/CMakeLists.txt index cc9be90f2c5a..790feb2cb913 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -254,6 +254,7 @@ if(NOT DEFINED USE_VULKAN) "ANDROID" OFF) endif() +option(USE_SOURCE_DEBUG_ON_MOBILE "Enable source-level debug information on mobile builds" ON) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF) @@ -647,6 +648,10 @@ if(USE_PYTORCH_METAL) string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL") endif() +if(USE_SOURCE_DEBUG_ON_MOBILE) + string(APPEND CMAKE_CXX_FLAGS " -DSYMBOLICATE_MOBILE_DEBUG_HANDLE") +endif() + # ---[ Allowlist file if allowlist is specified include(cmake/Allowlist.cmake) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 39ebf0e7cfd2..d816a837f248 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -516,9 +516,22 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) list(APPEND TORCH_SRCS ${GENERATED_H_TORCH}) list(APPEND LIBTORCH_CMAKE_SRCS "") + list(APPEND LITE_EAGER_SYMBOLICATION_SRCS "") + if(USE_SOURCE_DEBUG_ON_MOBILE) + append_filelist("libtorch_lite_eager_symbolication" LITE_EAGER_SYMBOLICATION_SRCS) + # For source debug on the lite interpreter we have to add a dependency on + # pickling, but references to read/writeArchiveAndTensor are not built for + # mobile, so this condition specifically says we are building for source + # debug on mobile. + if(BUILD_LITE_INTERPRETER) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/serialization/pickle.cpp PROPERTIES COMPILE_FLAGS "-DC10_MOBILE -DFEATURE_TORCH_MOBILE") + endif() + endif() + # Switch between the full jit interpreter and lite interpreter if(BUILD_LITE_INTERPRETER) append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS) + list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMBOLICATION_SRCS}) else() append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS) @@ -565,6 +578,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/mobile/sequential.cpp ) list(APPEND TORCH_SRCS ${MOBILE_SRCS}) + list(APPEND TORCH_SRCS ${LITE_EAGER_SYMBOLICATION_SRCS}) endif() # This one needs to be unconditionally added as Functions.cpp is also unconditionally added diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 6369203203c9..c3bdda4ce420 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -4,12 +4,27 @@ from torch.utils.mobile_optimizer import optimize_for_mobile import io from typing import Dict, List, NamedTuple from collections import namedtuple +import inspect from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, TemporaryFileName, run_tests class TestLiteScriptModule(TestCase): + def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False): + m_scripted = torch.jit.script(m) + + if not also_test_file: + buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info)) + buffer.seek(0) + mobile_module = _load_for_lite_interpreter(buffer) + return mobile_module + + with TemporaryFileName() as fname: + m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info) + mobile_module = _load_for_lite_interpreter(fname) + return mobile_module + def
test_load_mobile_module(self): class MyTestModule(torch.nn.Module): def __init__(self): @@ -374,5 +389,69 @@ class TestLiteScriptModule(TestCase): actual_ops = _export_operator_list(mobile_module) self.assertEqual(actual_ops, expected_ops) + def test_source_range_simple(self): + + class FooTest(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x, w): + return torch.mm(x, w.t()) + + ft = FooTest() + loaded = self.getScriptExportImportCopy(ft) + _, lineno = inspect.getsourcelines(FooTest) + + with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): + loaded(torch.rand(3, 4), torch.rand(30, 40)) + + def test_source_range_raise_exception(self): + + class FooTest2(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self): + raise RuntimeError('foo') + + _, lineno = inspect.getsourcelines(FooTest2) + + with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): + ft = FooTest2() + loaded = self.getScriptExportImportCopy(ft) + loaded() + + def test_source_range_function_call(self): + class FooTest3(torch.jit.ScriptModule): + @torch.jit.script_method + def add_method(self, x, w): + return x + w + + @torch.jit.script_method + def forward(self, x, y, w): + x = x * y + x = x + 2 + return self.add_method(x, w) + + ft = FooTest3() + loaded = self.getScriptExportImportCopy(ft) + _, lineno = inspect.getsourcelines(FooTest3) + + with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): + loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) + + def test_source_range_no_debug_info(self): + + class FooTest4(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x, w): + return torch.mm(x, w.t()) + + ft = FooTest4() + loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False) + + try: + loaded(torch.rand(3, 4), torch.rand(30, 40)) + except RuntimeError as e: + error_message = f"{e}" + self.assertTrue("test_lite_script_module.py" not in error_message) + + if __name__ == '__main__': run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index 5d8dd095d3db..8b15ad69c6a6 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4182,8 +4182,8 @@ def foo(xyz): debug_files = debug_records_from_mod(ft3) for debug_file in debug_files: for i in range(len(debug_file) - 1): - offset, source_range = debug_file[i] - offset2, source_range2 = debug_file[i + 1] + offset, source_range, source_range_tag = debug_file[i] + offset2, source_range2, source_range_tag2 = debug_file[i + 1] self.assertNotEqual(source_range, source_range2) def test_circular_dependency(self): diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 4e31c573eb4e..80ec81c3496d 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -356,6 +356,17 @@ torch_mobile_core = [ "torch/csrc/jit/runtime/register_special_ops.cpp", ] +libtorch_lite_eager_symbolication = [ + "torch/csrc/jit/frontend/source_range.cpp", + "torch/csrc/jit/mobile/debug_info.cpp", + "torch/csrc/jit/serialization/source_range_serialization.cpp", + # Later we can split serialization and deserialization logic + # so that the build has better separation and includes only the relevant parts.
+ "torch/csrc/jit/serialization/pickle.cpp", + "torch/csrc/jit/serialization/pickler.cpp", + "torch/csrc/jit/serialization/unpickler.cpp", +] + # TODO: core_trainer_sources is not necessary for libtorch lite libtorch_lite_cmake_sources = sorted(core_trainer_sources + core_sources_common + torch_mobile_core) @@ -368,6 +379,9 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ "torch/csrc/jit/api/module_save.cpp", "torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp", "torch/csrc/jit/mobile/export_data.cpp", + # To be included for eager symbolication in lite interpreter + # when it is built in libtorch + "torch/csrc/jit/mobile/debug_info.cpp", "torch/csrc/jit/mobile/function.cpp", "torch/csrc/jit/mobile/import.cpp", "torch/csrc/jit/mobile/import_data.cpp", diff --git a/torch/csrc/jit/frontend/source_range.cpp b/torch/csrc/jit/frontend/source_range.cpp index 9f3a12336813..c1deda9c2ee9 100644 --- a/torch/csrc/jit/frontend/source_range.cpp +++ b/torch/csrc/jit/frontend/source_range.cpp @@ -3,6 +3,11 @@ namespace torch { namespace jit { +size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const { + return ( + std::hash()(reinterpret_cast(key.source().get())) ^ + std::hash()(key.start()) ^ std::hash()(key.end())); +} c10::optional Source::findSourceRangeThatGenerated( const SourceRange& range) { diff --git a/torch/csrc/jit/frontend/source_range.h b/torch/csrc/jit/frontend/source_range.h index 36772807ca8b..da32ed238939 100644 --- a/torch/csrc/jit/frontend/source_range.h +++ b/torch/csrc/jit/frontend/source_range.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -178,6 +179,11 @@ struct TORCH_API SourceRange { size_t end_; }; +struct SourceRangeHasher { + public: + size_t operator()(const torch::jit::SourceRange& key) const; +}; + struct StackEntry { std::string filename; SourceRange range; @@ -201,6 +207,8 @@ struct TaggedRange { SourceRange range; }; using SourceRangeRecords = std::vector; +using SourceRangeTagMap = + std::unordered_map; } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp new file mode 100644 index 000000000000..403db00298f1 --- /dev/null +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -0,0 +1,53 @@ +#include +#include + +#include +#include + +#include + +namespace torch { +namespace jit { + +MobileDebugTable::MobileDebugTable( + std::unique_ptr& reader) { + const std::vector& record_names = reader->getAllRecords(); + const c10::string_view suffix(".debug_pkl"); + for (const auto& record_name : record_names) { + if (c10::string_view(record_name).ends_with(suffix)) { + at::DataPtr debug_data; + size_t debug_size{0}; + std::tie(debug_data, debug_size) = reader->getRecord(record_name); + auto ivalues = + jit::unpickle( + reinterpret_cast(debug_data.get()), debug_size) + .toTuple() + ->elements(); + SourceRangeDeserializer deserializer; + for (auto& val : ivalues) { + auto tup_elems = val.toTuple()->elements(); + // For BC we decode only tuples with 3 elements + // assuming it contains + // byte_offset, debug_handle (=source range tag), source range + if (tup_elems.size() == 3) { + int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt(); + auto source_range = + deserializer.deserialize(tup_elems[kSourceRangeIndex]); + source_range_map_.emplace(debug_handle, std::move(source_range)); + } + } + } + } +} + +std::string MobileDebugTable::getSourceDebugString( + const int64_t debug_handle) const { + const auto it = 
source_range_map_.find(debug_handle); + if (it == source_range_map_.end()) { + return ""; + } + return it->second.str(); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h new file mode 100644 index 000000000000..a5d04cdace2b --- /dev/null +++ b/torch/csrc/jit/mobile/debug_info.h @@ -0,0 +1,31 @@ +#pragma once +#include <c10/util/flat_hash_map.h> +#include <caffe2/serialize/inline_container.h> +#include <torch/csrc/jit/frontend/source_range.h> + +namespace torch { +namespace jit { +/* + * MobileDebugTable: + * Deserializes debug_pkl records from a PT model's zip archive and + * stores them in a map from debug handle to source range. + * Debug handles are unique per model, and the runtime, be it the lite + * interpreter or a delegate, raises exceptions that carry debug handles. + * The getSourceDebugString method is responsible for translating debug + * handles to the corresponding debug information. + * At the moment this only contains information about the model source. + * Later diffs will include the entire callstack corresponding to a debug + * handle. + */ +class MobileDebugTable { + public: + MobileDebugTable() = default; + MobileDebugTable( + std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader); + std::string getSourceDebugString(const int64_t debug_handle) const; + + private: + ska::flat_hash_map<int64_t, SourceRange> source_range_map_; +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 6edb98ca942c..60db9e3d07e4 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -22,12 +22,13 @@ const std::string& Function::name() const { return name_.name(); } -void Function::append_instruction(OpCode op, int X, int N) { +void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) { TORCH_CHECK( isOpSupportedInMobile(op), toString(op), " is not supported in mobile module."); code_->instructions_.emplace_back(op, X, N); + code_->debug_handles_.emplace_back(dbg_handle); } bool Function::append_operator( @@ -130,6 +131,11 @@ const std::shared_ptr<Code> Function::get_code() const { return code_; } +int64_t Function::getExceptionDebugHandle() const { + size_t pc = getInterpretersExceptionPC(); + return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1; +} + } // namespace mobile } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h index eb229f5abbc7..b276f048e873 100644 --- a/torch/csrc/jit/mobile/function.h +++ b/torch/csrc/jit/mobile/function.h @@ -19,7 +19,7 @@ class Function { c10::IValue operator()(Stack& stack) const; const std::string& name() const; const c10::QualifiedName& qualname() const; - void append_instruction(OpCode op, int X, int N); + void append_instruction(OpCode op, int X, int N, int64_t dbg_handle = -1); bool append_operator( const std::string& name, const std::string& overload_name, @@ -37,6 +37,11 @@ class Function { void setSchema(c10::FunctionSchema schema); const at::optional<c10::FunctionSchema>& getSchema() const; + // Returns the debug handle corresponding to where the execution + // is halted due to an exception. + // If no corresponding debug handle is found, -1 is returned.
+ int64_t getExceptionDebugHandle() const; + private: c10::QualifiedName name_; std::shared_ptr<Code> code_; diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 8f2cb7072e4a..53e55217647b 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -305,12 +305,25 @@ void BytecodeDeserializer::parseMethods( OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str()); int X = ins_item[1].toInt(); int N = ins_item[2].toInt(); - function->append_instruction(op_code, X, N); + // TODO: Save debug handles for all instructions, not just for OP if (op_code == OP) { - std::string module_debug_info = (has_debug_info) - ? module_debug_info_list[X].toString()->string() - : ""; - function->set_module_info(module_debug_info, i); + // In later PRs we will refactor this to always save debug handles. + // Debug info, source range and inlined callstack ptr saving will become + // optional. + if (has_debug_info) { + auto module_debug_tuple = + module_debug_info_list[X].toTuple()->elements(); + std::string module_debug_info = + module_debug_tuple[0].toString()->string(); + int64_t debug_handle = module_debug_tuple[1].toInt(); + function->set_module_info(module_debug_info, i); + function->append_instruction(op_code, X, N, debug_handle); + } else { + function->set_module_info("", i); + function->append_instruction(op_code, X, N); + } + } else { + function->append_instruction(op_code, X, N); } } @@ -441,7 +454,12 @@ mobile::Module BytecodeDeserializer::deserialize( } parseMethods(bvals, debug_info_bvals, *mcu); auto meta_dict = readMobileMetadata(mcu); - return mobile::Module(readArchive("data", mcu).toObject(), meta_dict, mcu); + auto m = mobile::Module(readArchive("data", mcu).toObject(), meta_dict, mcu); +#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) + MobileDebugTable debug_table = MobileDebugTable(reader_); + m.setDebugTable(std::move(debug_table)); +#endif + return m; } std::unordered_map<std::string, std::string> BytecodeDeserializer:: diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index c3318542bb46..63ee5a7dd2eb 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,8 @@ InterpreterState::InterpreterState(std::shared_ptr<Code> code) } namespace { +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static thread_local int64_t exception_pc_{-1}; void createObject(Stack& stack, const at::ClassTypePtr& type) { auto userObj = c10::ivalue::Object::create( c10::StrongTypePtr(type->compilation_unit(), type), @@ -41,186 +44,200 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) { using namespace at; +int64_t getInterpretersExceptionPC() { + return exception_pc_; +} bool InterpreterState::run(Stack& stack) { size_t pc = 0; while (true) { - Instruction inst = code_->instructions_[pc]; + try { + Instruction inst = code_->instructions_[pc]; - // std::cout << "RUNNING " << pc << " " << code_->instructions_[pc]; - // if (inst.op == OP) { - // std::cout << ", " << code_->op_names_[inst.X].name; - // if (!code_->op_names_[inst.X].overload_name.empty()) { - // std::cout << "."
<< code_->op_names_[inst.X].overload_name; - // } - // } - // std::cout << std::endl; - switch (inst.op) { - case OP: { - if (at::hasGlobalCallbacks()) { - if (auto* mobile_debug_info = - static_cast<MobileDebugInfo*>(c10::ThreadLocalDebugInfo::get( - c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) { - mobile_debug_info->setOpIdx(pc); + // } + // } + // std::cout << std::endl; + switch (inst.op) { + case OP: { + if (at::hasGlobalCallbacks()) { + if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>( + c10::ThreadLocalDebugInfo::get( + c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) { + mobile_debug_info->setOpIdx(pc); + } } - } - // TODO(iliacher): remove the workaround after RecordFunction is in - // Dispatcher - bool prev_value = isRecordFunctionEnabled(); - if (!prev_value) { - // enable only for the RecordFunction - enableRecordFunction(true); - } - RECORD_USER_SCOPE_WITH_INPUTS(code_->op_names_[inst.X].name, stack); - if (!prev_value) { - enableRecordFunction(false); - } - code_->operators_[inst.X](stack); - ++pc; - } break; - case OPN: { - stack.push_back(inst.N); - code_->operators_[inst.X](stack); - ++pc; - } break; - case INTERFACE_CALL: { - torch::jit::Function& method = - peek(stack, 0, inst.N) - .toObject() - ->type() - ->getMethod(code_->constants_[inst.X].toStringRef()); - method.run(stack); - ++pc; - } break; - case LOAD: - stack.emplace_back(reg(inst.X)); - ++pc; - break; - case MOVE: - stack.emplace_back(std::move(reg(inst.X))); - ++pc; - break; - case STORE: - reg(inst.X) = pop(stack); - ++pc; - break; - case STOREN: - for (size_t i = inst.N; i > 0; --i) { - reg(inst.X + i - 1) = pop(stack); - } - ++pc; - break; - case DROP: - pop(stack); - ++pc; - break; - case DROPR: - reg(inst.X) = IValue(); - ++pc; - break; - case LOADC: - stack.emplace_back(code_->constants_[inst.X]); - ++pc; - break; - case GET_ATTR: { - auto userObj = pop(stack).toObject(); - auto value = userObj->getSlot(inst.X); - push(stack, std::move(value)); - ++pc; - } break; - case SET_ATTR: { - auto v = pop(stack); - auto userObj = pop(stack).toObject(); - // Mobile only: since the number of slots is not known, resize the - // numAttributes before setSlot. - // NOLINTNEXTLINE(clang-diagnostic-sign-compare) - while (userObj->type()->numAttributes() <= inst.X) { - std::stringstream ss; - ss << userObj->type()->numAttributes(); - userObj->type()->addAttribute(ss.str(), c10::NoneType::get()); - } - userObj->setSlot(inst.X, std::move(v)); - ++pc; - } break; - case JF: - pc += (pop(stack).toBool()) ? 1 : inst.X; - break; - case JMP: - pc += inst.X; - break; - case LOOP: { - // stack: iteration_count, max_iter, cond, loop_carried_deps...
- auto frame = stack.end() - (inst.N + 1); - int64_t trip_count = frame[0].toInt(); - int64_t max_trip_count = frame[1].toInt(); - bool cond = frame[2].toBool(); - if (trip_count < max_trip_count && cond) { - frame[2] = trip_count; - frame[0] = trip_count + 1; + // TODO(iliacher): remove the workaround after RecordFunction is in + // Dispatcher + bool prev_value = isRecordFunctionEnabled(); + if (!prev_value) { + // enable only for the RecordFunction + enableRecordFunction(true); + } + RECORD_USER_SCOPE_WITH_INPUTS(code_->op_names_[inst.X].name, stack); + if (!prev_value) { + enableRecordFunction(false); + } + code_->operators_[inst.X](stack); ++pc; - } else { - size_t n_loop_carried = inst.N - 2; - for (size_t i = 0; i < n_loop_carried; ++i) { - frame[i] = std::move(frame[i + 3]); + } break; + case OPN: { + stack.push_back(inst.N); + code_->operators_[inst.X](stack); + ++pc; + } break; + case INTERFACE_CALL: { + torch::jit::Function& method = + peek(stack, 0, inst.N) + .toObject() + ->type() + ->getMethod(code_->constants_[inst.X].toStringRef()); + method.run(stack); + ++pc; + } break; + case LOAD: + stack.emplace_back(reg(inst.X)); + ++pc; + break; + case MOVE: + stack.emplace_back(std::move(reg(inst.X))); + ++pc; + break; + case STORE: + reg(inst.X) = pop(stack); + ++pc; + break; + case STOREN: + for (size_t i = inst.N; i > 0; --i) { + reg(inst.X + i - 1) = pop(stack); } - drop(stack, 3); // iteration_count, max_iter, cond + ++pc; + break; + case DROP: + pop(stack); + ++pc; + break; + case DROPR: + reg(inst.X) = IValue(); + ++pc; + break; + case LOADC: + stack.emplace_back(code_->constants_[inst.X]); + ++pc; + break; + case GET_ATTR: { + auto userObj = pop(stack).toObject(); + auto value = userObj->getSlot(inst.X); + push(stack, std::move(value)); + ++pc; + } break; + case SET_ATTR: { + auto v = pop(stack); + auto userObj = pop(stack).toObject(); + // Mobile only: since the number of slots is not known, resize the + // numAttributes before setSlot. + // NOLINTNEXTLINE(clang-diagnostic-sign-compare) + while (userObj->type()->numAttributes() <= inst.X) { + std::stringstream ss; + ss << userObj->type()->numAttributes(); + userObj->type()->addAttribute(ss.str(), c10::NoneType::get()); + } + userObj->setSlot(inst.X, std::move(v)); + ++pc; + } break; + case JF: + pc += (pop(stack).toBool()) ? 
+ 1 : inst.X; + break; + case JMP: - pc += inst.X; - } - } break; - case RET: - return false; - case LIST_CONSTRUCT: { - const auto& type = code_->types_[inst.X]->expectRef<at::ListType>(); - listConstruct(stack, type, inst.N); - ++pc; - } break; - case LIST_UNPACK: { - listUnpack(stack, inst.X); - ++pc; - } break; - case TUPLE_CONSTRUCT: { - tupleConstruct(stack, inst.X); - ++pc; - } break; - case TUPLE_SLICE: { - tupleSlice(stack, inst.X, inst.X + inst.N); - ++pc; - } break; - case DICT_CONSTRUCT: { - const auto& type = code_->types_[inst.X]->expectRef<at::DictType>(); - dictConstruct(stack, type, inst.N); - ++pc; - } break; - case NAMED_TUPLE_CONSTRUCT: { - namedTupleConstruct( - stack, code_->types_[inst.X]->expect<at::TupleType>(), inst.N); - ++pc; - } break; - case CREATE_OBJECT: { - auto type = code_->types_[inst.X]->expect<c10::ClassType>(); - createObject(stack, type); - ++pc; - } break; - case ISINSTANCE: { - at::ArrayRef<at::TypePtr> types( - &(code_->types_[inst.X]), &(code_->types_[inst.X + inst.N])); - isinstance(stack, types); - ++pc; - } break; - case WARN: { - drop(stack, 1); - // Note: Please don't move the pop(stack) code below into the TORCH_WARN - // macro since TORCH_WARN fails to evaluate its arguments when - // STRIP_ERROR_MESSAGES is defined (which happens for production - // mobile builds). This will cause the stack to be in an inconsistent - // state. It has previously resulted in a SEV (S22350). - const auto& sref = stack.back().toStringRef(); - TORCH_WARN(sref); - stack.pop_back(); - ++pc; - } break; - default: - AT_ERROR(toString(inst.op), " is invalid."); + pc += inst.X; + break; + case LOOP: { + // stack: iteration_count, max_iter, cond, loop_carried_deps... + auto frame = stack.end() - (inst.N + 1); + int64_t trip_count = frame[0].toInt(); + int64_t max_trip_count = frame[1].toInt(); + bool cond = frame[2].toBool(); + if (trip_count < max_trip_count && cond) { + frame[2] = trip_count; + frame[0] = trip_count + 1; + ++pc; + } else { + size_t n_loop_carried = inst.N - 2; + for (size_t i = 0; i < n_loop_carried; ++i) { + frame[i] = std::move(frame[i + 3]); + } + drop(stack, 3); // iteration_count, max_iter, cond + pc += inst.X; + } + } break; + case RET: + return false; + case LIST_CONSTRUCT: { + const auto& type = code_->types_[inst.X]->expectRef<at::ListType>(); + listConstruct(stack, type, inst.N); + ++pc; + } break; + case LIST_UNPACK: { + listUnpack(stack, inst.X); + ++pc; + } break; + case TUPLE_CONSTRUCT: { + tupleConstruct(stack, inst.X); + ++pc; + } break; + case TUPLE_SLICE: { + tupleSlice(stack, inst.X, inst.X + inst.N); + ++pc; + } break; + case DICT_CONSTRUCT: { + const auto& type = code_->types_[inst.X]->expectRef<at::DictType>(); + dictConstruct(stack, type, inst.N); + ++pc; + } break; + case NAMED_TUPLE_CONSTRUCT: { + namedTupleConstruct( + stack, code_->types_[inst.X]->expect<at::TupleType>(), inst.N); + ++pc; + } break; + case CREATE_OBJECT: { + auto type = code_->types_[inst.X]->expect<c10::ClassType>(); + createObject(stack, type); + ++pc; + } break; + case ISINSTANCE: { + at::ArrayRef<at::TypePtr> types( + &(code_->types_[inst.X]), &(code_->types_[inst.X + inst.N])); + isinstance(stack, types); + ++pc; + } break; + case WARN: { + drop(stack, 1); + // Note: Please don't move the pop(stack) code below into the + // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments + // when STRIP_ERROR_MESSAGES is defined (which happens for production + // mobile builds). This will cause the stack to be in an inconsistent + // state. It has previously resulted in a SEV (S22350).
+ const auto& sref = stack.back().toStringRef(); + TORCH_WARN(sref); + stack.pop_back(); + ++pc; + } break; + default: + AT_ERROR(toString(inst.op), " is invalid."); + } + } catch (c10::Error& error) { + // We catch and rethrow the error so that we can set the exception + // PC, which is queried later via getInterpretersExceptionPC() + exception_pc_ = pc; + TORCH_RETHROW(error); + } catch (...) { + exception_pc_ = pc; + throw; } // for (auto val : stack) { // if (val.isTensor()) { diff --git a/torch/csrc/jit/mobile/interpreter.h b/torch/csrc/jit/mobile/interpreter.h index c80ded3bc6b2..b89b73bc5c2a 100644 --- a/torch/csrc/jit/mobile/interpreter.h +++ b/torch/csrc/jit/mobile/interpreter.h @@ -8,9 +8,13 @@ namespace torch { namespace jit { namespace mobile { using Stack = std::vector<c10::IValue>; +using DebugHandle = int64_t; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct Code { + // TODO: Combine the instructions and debug handles vectors + // into a single vector of (instruction, debug handle) pairs std::vector<Instruction> instructions_; + std::vector<DebugHandle> debug_handles_; std::vector<c10::OperatorName> op_names_; std::vector<std::function<void(Stack&)>> operators_; std::vector<c10::IValue> constants_; @@ -28,6 +32,14 @@ struct InterpreterState { std::vector<c10::IValue> registers_; }; +// The interpreter executes instructions one by one from a list; the PC +// (program counter) points to the instruction currently being executed. +// This function returns the PC at which an exception was raised. +// Note that it is set only when an exception occurs: exception_pc_ is a +// thread-local variable, and updating it on every instruction would add +// the overhead of a thread-local access. +int64_t getInterpretersExceptionPC(); } // namespace mobile } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 16194de7e09a..006a9f915acb 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -166,6 +166,11 @@ void Method::run(Stack& stack) const { observer->onExitRunMethod(instance_key); } } catch (c10::Error& error) { +#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) + auto debug_string = owner_->getDebugTable().getSourceDebugString( + function_->getExceptionDebugHandle()); + error.add_context(debug_string); +#endif if (observer) { observer->onFailRunMethod(instance_key, error.what()); } @@ -183,6 +188,11 @@ void Method::run(Stack& stack) const { } } } catch (c10::Error& error) { +#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) + auto debug_string = owner_->getDebugTable().getSourceDebugString( + function_->getExceptionDebugHandle()); + error.add_context(debug_string); +#endif if (observer) { observer->onFailRunMethod(instance_key, error.what()); } diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index c95b6cd4a21d..7b27da7d0166 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -1,5 +1,6 @@ #pragma once #include +#include <torch/csrc/jit/mobile/debug_info.h> #include #include @@ -109,10 +110,18 @@ class TORCH_API Module { return or_else; } + void setDebugTable(MobileDebugTable&& debug_table) { + debug_table_ = std::move(debug_table); + } + const MobileDebugTable& getDebugTable() const { + return debug_table_; + } + private: c10::intrusive_ptr<c10::ivalue::Object> object_; std::unordered_map<std::string, std::string> metadata_; std::shared_ptr<CompilationUnit> cu_; + MobileDebugTable debug_table_; }; } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 39524b7a43eb..0c2d1b554e95 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++
b/torch/csrc/jit/serialization/export_module.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -110,7 +111,8 @@ std::string getModuleTypeName(const Module& module, const std::string& prefix) { std::pair<c10::IValue, c10::optional<c10::IValue>> getFunctionTuple( const Module& module, const Function& func, - bool save_mobile_debug_info) { + const bool save_mobile_debug_info, + const SourceRangeTagMap& source_range_tag_map) { auto graph = func.graph()->copy(); Inline(*graph); @@ -122,6 +124,7 @@ std::pair<c10::IValue, c10::optional<c10::IValue>> getFunctionTuple( std::vector<c10::OperatorName> opnames; std::vector<std::string> method_names; std::vector<std::string> op_module_paths; + std::vector<int64_t> op_source_debug_tags; for (size_t i = 0; i < instructions_copy.size(); ++i) { Instruction ins = instructions_copy[i]; if (ins.op == OP || ins.op == OPN) { @@ -129,7 +132,30 @@ std::pair<c10::IValue, c10::optional<c10::IValue>> getFunctionTuple( opnames.emplace_back(node->schema().operator_name()); if (save_mobile_debug_info) { std::string root_scope_string = getModuleTypeName(module, "top"); + // A little explanation as to why node->sourceRange() is not enough + // and we need to do node->sourceRange().findSourceRangeThatGenerated(): + // when you do m = torch.jit.script/trace(model), the scripted model + // has graphs for its methods, and the nodes of those graphs are + // annotated with a sourceRange that points to the original Python + // code. However, when such a model is serialized via torch.jit.save, + // what is saved is the compiled Python code of the model. The compiled + // code for the methods is effectively the compiled graph, which + // contains prim/aten etc. ops, and is not the same as the original + // Python code. When such a serialized model is loaded via + // torch.jit.load, node->sourceRange() does not point to the original + // Python code but to the transformed/compiled code. So in order to get + // the original Python code, which is serialized in debug_pkl, it is + // necessary to call node->sourceRange().findSourceRangeThatGenerated() + auto source_range = node->sourceRange().findSourceRangeThatGenerated() + ? node->sourceRange().findSourceRangeThatGenerated().value() + : node->sourceRange(); + int64_t source_range_tag{-1}; + const auto& it = source_range_tag_map.find(source_range); + if (it != source_range_tag_map.end()) { + source_range_tag = it->second; + } op_module_paths.emplace_back(getModulePath(node, root_scope_string)); + op_source_debug_tags.emplace_back(source_range_tag); } } // CALL nodes at this point represent built-in (i.e. non-Graph) @@ -280,12 +306,22 @@ std::pair<c10::IValue, c10::optional<c10::IValue>> getFunctionTuple( c10::optional<c10::IValue> debug_info_vals; if (save_mobile_debug_info) { // module debug info - std::vector<c10::IValue> module_paths; - module_paths.reserve(op_module_paths.size()); - for (auto& path : op_module_paths) { - module_paths.emplace_back(std::move(path)); + // Temporarily adding source debug tag here. + // In the diffs to follow we should move to serializing + // InlinedCallStack.
+ // Then we will just serialize a vector or a dictionary mapping debug + // handles (a.k.a. PCs for the lite interpreter) to InlinedCallStack + std::vector<c10::IValue> module_debug_tuples; + module_debug_tuples.reserve(op_module_paths.size()); + for (size_t i = 0; i < op_module_paths.size(); ++i) { + auto& path = op_module_paths[i]; + int64_t source_debug_tag = op_source_debug_tags[i]; + module_debug_tuples.emplace_back( + c10::ivalue::Tuple::create({std::move(path), source_debug_tag})); } - auto module_debug_info = Table({{"module_debug_info", Tup(module_paths)}}); + auto module_debug_info = + Table({{"module_debug_info", Tup(module_debug_tuples)}}); debug_info_vals = Tup({func.qualname().qualifiedName(), module_debug_info}); } return std::make_pair(bytecode_vals, debug_info_vals); @@ -297,7 +333,8 @@ void setstateTuple( std::vector<c10::IValue>& elements, std::unordered_set<std::string>& qn_cache, c10::optional<std::vector<c10::IValue>>& debug_info_elements, - bool save_mobile_debug_info) { + const bool save_mobile_debug_info, + const SourceRangeTagMap& source_range_map) { if (!ivalue.isObject()) return; auto obj = ivalue.toObject(); @@ -309,8 +346,8 @@ void setstateTuple( return; } if (setstate.isGraphFunction()) { - auto func_tuple = - getFunctionTuple(module, setstate, save_mobile_debug_info); + auto func_tuple = getFunctionTuple( + module, setstate, save_mobile_debug_info, source_range_map); elements.push_back(func_tuple.first); qn_cache.emplace(qn); if (save_mobile_debug_info) { @@ -325,7 +362,8 @@ void setstateTuple( elements, qn_cache, debug_info_elements, - save_mobile_debug_info); + save_mobile_debug_info, + source_range_map); } } } @@ -335,7 +373,8 @@ void moduleMethodsTuple( const Module& module, std::vector<c10::IValue>& elements, // note: appended to in-place c10::optional<std::vector<c10::IValue>>& debug_info_elements, - bool save_mobile_debug_info) { + const bool save_mobile_debug_info, + const SourceRangeTagMap& source_range_map = SourceRangeTagMap()) { auto methods = module.get_methods(); std::unordered_set<std::string> qn_cache; // top level methods @@ -344,8 +383,8 @@ void moduleMethodsTuple( if (qn_cache.find(qn) != qn_cache.end()) { continue; } - auto func_tuple = - getFunctionTuple(module, method.function(), save_mobile_debug_info); + auto func_tuple = getFunctionTuple( + module, method.function(), save_mobile_debug_info, source_range_map); elements.push_back(func_tuple.first); qn_cache.emplace(qn); if (save_mobile_debug_info) { @@ -360,7 +399,8 @@ void moduleMethodsTuple( elements, qn_cache, debug_info_elements, - save_mobile_debug_info); + save_mobile_debug_info, + source_range_map); } void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) { @@ -488,6 +528,15 @@ class ScriptModuleSerializer { } } + void updateSourceRangeTags(const SourceRangeRecords& ranges) { + for (const auto& range : ranges) { + if (source_range_tags_.find(range.range) == source_range_tags_.end()) { + source_range_tags_[range.range] = current_source_range_tag_; + current_source_range_tag_++; + } + } + } + void writeCode(const at::NamedTypePtr& root_type) { class_deps_.add(root_type); for (size_t i = 0; i < class_deps_.size(); ++i) { @@ -496,6 +545,7 @@ class ScriptModuleSerializer { convertNamedType(class_deps_[i]); } + current_source_range_tag_ = 0; // Mapping of filename => src. We need this because multiple classes may go // in the same file (e.g.
foo.bar.Baz and foo.bar.Qux) for (auto& item : file_streams_) { @@ -517,7 +567,9 @@ class ScriptModuleSerializer { // Write out the debug information std::string debugFilename = filename + ".debug_pkl"; SourceRangePickler source_range_pickler; - auto range_data = source_range_pickler.pickle(item.value().ranges()); + updateSourceRangeTags(item.value().ranges()); + auto range_data = source_range_pickler.pickle( + item.value().ranges(), source_range_tags_); writer_.writeRecord( debugFilename, range_data.data(), @@ -526,7 +578,7 @@ class ScriptModuleSerializer { } } - void writeByteCode(const Module& module, bool save_mobile_debug_info) { + void writeByteCode(const Module& module, const bool save_mobile_debug_info) { std::vector<c10::IValue> elements; elements.emplace_back( static_cast<int64_t>(caffe2::serialize::kProducedBytecodeVersion)); @@ -538,7 +590,11 @@ class ScriptModuleSerializer { } moduleMethodsTuple( - module, elements, debug_info_elements, save_mobile_debug_info); + module, + elements, + debug_info_elements, + save_mobile_debug_info, + source_range_tags_); auto telements = Tup(std::move(elements)); writeArchive("bytecode", telements); if (save_mobile_debug_info) { @@ -585,6 +641,32 @@ class ScriptModuleSerializer { // qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be // created OrderedDict<std::string, PythonPrint> file_streams_; + + // Uniquely identifies a SourceRange in a model. + // SourceRanges are associated with Nodes of Graphs. + // However, for mobile deployment we don't intend to ship the full JIT + // with its ability to read code and construct graphs. + // Instead we serialize the Code generated from the graphs of the methods. + // Code is serialized in bytecode format, with instructions corresponding + // to the nodes of the graph. Since the original graph is gone, the + // question is how we identify where the ops in the serialized bytecode + // come from in the original model code. We do this in two parts. + // 1. Associate a unique tag with each SourceRange. + // 2. Serialize this unique tag. + // 2.1 Meaning save (byte_offset, source_range, source_range_tag) + // instead of (byte_offset, source_range). + // 3. When serializing the model for mobile, i.e. during bytecode + // generation, save the unique tag of the SourceRange corresponding + // to the Node. + // 4. During deserialization, read all the debug_pkl records to construct + // a map of (source_range_tag -> source_range), and use the tag saved + // with OPs in the bytecode to look up the source range. + // Strictly speaking, we will later serialize InlinedCallStack directly, + // which contains SourceRange. That way we have access to the entire + // callstack and not just the source information about where the node is, + // since bytecode inlines the graph before saving it.
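+ // Illustrative example (hypothetical values, not taken from this change): + // if two bytecode OPs originate from the same Python line, that line's + // SourceRange is assigned a single tag, say 7, the first time + // updateSourceRangeTags() sees it; the .debug_pkl row for that range is + // then serialized as (byte_offset, source_range, 7), and both OPs record + // 7 as their source_range_tag, i.e. their debug handle.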
+ SourceRangeTagMap source_range_tags_; + int64_t current_source_range_tag_; }; void ExportModule( diff --git a/torch/csrc/jit/serialization/source_range_serialization.cpp b/torch/csrc/jit/serialization/source_range_serialization.cpp index 9f158e48f0e3..b9e2df8764b1 100644 --- a/torch/csrc/jit/serialization/source_range_serialization.cpp +++ b/torch/csrc/jit/serialization/source_range_serialization.cpp @@ -1,7 +1,6 @@ #include #include -#include #include namespace torch { namespace jit { @@ -23,42 +22,33 @@ class SourceRangeSerializer { std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources; }; -class SourceRangeDeserializer { - public: - SourceRange deserialize(const c10::IValue& iv) { - auto tup_elems = iv.toTuple()->elements(); - TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); - std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]); - int64_t start_ = tup_elems[1].toInt(); - int64_t end_ = tup_elems[2].toInt(); - return SourceRange(source_, start_, end_); +SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) { + auto tup_elems = iv.toTuple()->elements(); + TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); + std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]); + int64_t start_ = tup_elems[1].toInt(); + int64_t end_ = tup_elems[2].toInt(); + return SourceRange(source_, start_, end_); +} + +std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source( + const c10::IValue& iv) { + auto tup = iv.toTuple(); + if (cached_sources.count(tup)) { + return cached_sources.at(tup); } - private: - std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) { - auto tup = iv.toTuple(); - if (cached_sources.count(tup)) { - return cached_sources.at(tup); - } + auto tup_elems = tup->elements(); + TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); + std::string text_ = tup_elems[0].toString()->string(); + c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>(); + int64_t starting_line_no_ = tup_elems[2].toInt(); - auto tup_elems = tup->elements(); - TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); - std::string text_ = tup_elems[0].toString()->string(); - c10::optional<std::string> filename_ = - tup_elems[1].toOptional<std::string>(); - int64_t starting_line_no_ = tup_elems[2].toInt(); - - auto source = std::make_shared<Source>( - std::move(text_), std::move(filename_), starting_line_no_); - cached_sources[tup] = source; - return source; - } - - std::unordered_map< - c10::intrusive_ptr<c10::ivalue::Tuple>, - std::shared_ptr<Source>> - cached_sources; -}; + auto source = std::make_shared<Source>( + std::move(text_), std::move(filename_), starting_line_no_); + cached_sources[tup] = source; + return source; +} c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) { std::vector<c10::IValue> elements = { @@ -84,11 +74,20 @@ c10::IValue SourceRangeSerializer::serialize_source( SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {} -std::vector<char> SourceRangePickler::pickle(const SourceRangeRecords& ranges) { +std::vector<char> SourceRangePickler::pickle( + const SourceRangeRecords& ranges, + const SourceRangeTagMap& source_range_tags) { std::vector<c10::IValue> ivalues; for (const auto& range : ranges) { + int64_t source_range_tag{-1}; + const auto& it = source_range_tags.find(range.range); + if (it != source_range_tags.end()) { + source_range_tag = it->second; + } std::vector<c10::IValue> row_elems{ - (int64_t)range.bytes, srs->serialize(range.range)}; + (int64_t)range.bytes, + srs->serialize(range.range), + static_cast<int64_t>(source_range_tag)}; ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems))); } std::vector<at::Tensor> table; @@ -118,8 +117,8 @@ void ConcreteSourceRangeUnpickler::unpickle() {
unpickled_records = std::make_shared<SourceRangeRecords>(); for (auto& val : ivalues) { auto tup_elems = val.toTuple()->elements(); - int64_t offset = tup_elems[0].toInt(); - auto source_range = deserializer->deserialize(tup_elems[1]); + int64_t offset = tup_elems[kByteOffsetIndex].toInt(); + auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]); unpickled_records->emplace_back(offset, std::move(source_range)); } } diff --git a/torch/csrc/jit/serialization/source_range_serialization.h b/torch/csrc/jit/serialization/source_range_serialization.h index 3e47b14cd997..85f42b865d34 100644 --- a/torch/csrc/jit/serialization/source_range_serialization.h +++ b/torch/csrc/jit/serialization/source_range_serialization.h @@ -3,6 +3,8 @@ #include #include +#include <unordered_map> + #include #include @@ -15,18 +17,34 @@ namespace jit { class Pickler; class SourceRangeSerializer; -class SourceRangeDeserializer; +static constexpr size_t kByteOffsetIndex = 0; +static constexpr size_t kSourceRangeIndex = 1; +static constexpr size_t kSourceRangeTagIndex = 2; class SourceRangePickler { public: SourceRangePickler(); - std::vector<char> pickle(const SourceRangeRecords& ranges); + std::vector<char> pickle( + const SourceRangeRecords& ranges, + const SourceRangeTagMap& source_range_tags); private: std::shared_ptr<SourceRangeSerializer> srs; }; +class SourceRangeDeserializer { + public: + SourceRange deserialize(const c10::IValue& iv); + + private: + std::shared_ptr<Source> deserialize_source(const c10::IValue& iv); + std::unordered_map< + c10::intrusive_ptr<c10::ivalue::Tuple>, + std::shared_ptr<Source>> + cached_sources; +}; + class SourceRangeUnpickler { public: virtual c10::optional<SourceRange> findSourceRangeThatGenerated( diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index ebc522f00bd3..fe5728fdc2d9 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -240,7 +240,8 @@ def get_model_info( code_parts = [] for di, di_next in zip(debug_info, debug_info[1:]): - start, source_range = di + # accounting for source range serialization format change + start, source_range, _ = di end = di_next[0] assert end > start source, s_start, s_end = source_range
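Note for reviewers: a minimal sketch, not part of this diff, of how a consumer can handle both the old and the new .debug_pkl row layouts introduced above. The helper name normalize_debug_records is hypothetical, and records is assumed to be the already-unpickled list of rows; the -1 sentinel mirrors the "no debug handle" value used on the C++ side.

    def normalize_debug_records(records):
        # Old rows: (byte_offset, source_range)
        # New rows: (byte_offset, source_range, source_range_tag)
        normalized = []
        for row in records:
            if len(row) == 2:
                offset, source_range = row
                tag = -1  # no tag recorded; matches the C++ sentinel
            else:
                offset, source_range, tag = row
            normalized.append((offset, source_range, tag))
        return normalized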