[PyTorch, Mobile] Serialization format change for source range (#54284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54284

In order to bring mobile deployment, via lite interpreter, on feature
parity with JIT, with respect model level debug information we must make
model level debug information available to mobile runtime.
At the moment, model level debug information is stored in SourceRange
which associates node's of graph to where the come from in original
python source code.
This information is serialized as part of debug_pkl and deserialized
when JIT loads the model and reads the model code.
On lite interpreter, we do not have access to all the functionality of
JIT and hence we cannot load model in the same way as JIT, by reading
code, constructing module hierarchy and graph corresponding module
methods etc. Instead in, lite interpreter, only bytecode corresonding to
the compiled graph, Code, is saved.
Thus in order to annotate OPs in the bytecode with equivalent
SourceRange information we do the following:
1. During model serialization, we create a unique tag for each source
range of the model.
2. Create a map of <SourceRange, tag>
3. During debug_pkl serialization we save tag along with SourceRange, on
top of byte offset.
4. During bytecode generation, the methods of the top module are
lowered. During this process methods are inlined. In the inlined graph,
when the node of a graph is lowered to bytecode, we query node's source
range and look it up against the map.
5. Resulting source range tag is serialized in module_debug_info.
6. During model deserialization, we read all the debug_pkl records in
the archieve and create a map of <tag, SourceRange>
7. This map can be used to find source code information.

During mobile runtime:
1. We read all the debug_pkl records and create <tag=debug_handle,
SourceRange> map.
   1.1 This map, MobileDebugInfo, is a member of mobile Module.
2. Interpreter catches appropriate exceptions and sets the thread local
debug handle and rethrows the exception.
3. In Function's run method we catch exception and query current debug
handle where the exception happened.
4. Query MobileDebugInfo with debug handle to retrieve source range and
augment error with source range info.

This information is still incomplete as it does not contain entire
callstack.

In the following diffs we will serialize InlinedCallStack directly.

Note that compilation is gated by SYMBOLICATE_MOBILE_DEBUG_HANDLE macro,
so that mobile builds can avoid building MobileDebugInfo, source range
and source range pickler/unpickler. Later we will add path where, if
building without debug support stack trace will contain only debug
handles. They can be symbolicated later.

Test Plan:
Ported bunch of source range tests from test_jit.py. Added on more test
in test_lite_interpreter.py

Imported from OSS

Reviewed By: raziel

Differential Revision: D27174722

fbshipit-source-id: a7b7c6088ce16dec37e823c7fefa4f0b61047e12
This commit is contained in:
Kimish Patel
2021-05-04 09:17:43 -07:00
committed by Facebook GitHub Bot
parent aa5ff7cc91
commit f4a921600a
20 changed files with 625 additions and 239 deletions

View File

@ -254,6 +254,7 @@ if(NOT DEFINED USE_VULKAN)
"ANDROID" OFF)
endif()
option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON)
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
@ -647,6 +648,10 @@ if(USE_PYTORCH_METAL)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL")
endif()
if(USE_SOURCE_DEBUG_ON_MOBILE)
string(APPEND CMAKE_CXX_FLAGS " -DSYMBOLICATE_MOBILE_DEBUG_HANDLE")
endif()
# ---[ Allowlist file if allowlist is specified
include(cmake/Allowlist.cmake)

View File

@ -516,9 +516,22 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
list(APPEND TORCH_SRCS ${GENERATED_H_TORCH})
list(APPEND LIBTORCH_CMAKE_SRCS "")
list(APPEND LITE_EAGER_SYMOBLICATION_SRCS "")
if(USE_SOURCE_DEBUG_ON_MOBILE)
append_filelist("libtorch_lite_eager_symbolication" LITE_EAGER_SYMOBLICATION_SRCS)
# For source debug on lite interpreter, we have to add dependency on pickling
# but references to read/writeArchiveAndTensor is not built for mobile
# so this condition specifically says we are building for source debug
# on mobile.
if(BUILD_LITE_INTERPRETER)
set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/serialization/pickle.cpp PROPERTIES COMPILE_FLAGS "-DC10_MOBILE -DFEATURE_TORCH_MOBILE")
endif()
endif()
# Switch between the full jit interpreter and lite interpreter
if(BUILD_LITE_INTERPRETER)
append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS)
list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS})
else()
append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS)
@ -565,6 +578,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/sequential.cpp
)
list(APPEND TORCH_SRCS ${MOBILE_SRCS})
list(APPEND TORCH_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS})
endif()
# This one needs to be unconditionally added as Functions.cpp is also unconditionally added

View File

@ -4,12 +4,27 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
import io
from typing import Dict, List, NamedTuple
from collections import namedtuple
import inspect
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing._internal.common_utils import TestCase, run_tests
class TestLiteScriptModule(TestCase):
def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False):
m_scripted = torch.jit.script(m)
if not also_test_file:
buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
return mobile_module
with TemporaryFileName() as fname:
m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info)
mobile_module = _load_for_lite_interpreter(fname)
return mobile_module
def test_load_mobile_module(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
@ -374,5 +389,69 @@ class TestLiteScriptModule(TestCase):
actual_ops = _export_operator_list(mobile_module)
self.assertEqual(actual_ops, expected_ops)
def test_source_range_simple(self):
class FooTest(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest)
with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
loaded(torch.rand(3, 4), torch.rand(30, 40))
def test_source_range_raise_exception(self):
class FooTest2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self):
raise RuntimeError('foo')
_, lineno = inspect.getsourcelines(FooTest2)
with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
ft = FooTest2()
loaded = self.getScriptExportImportCopy(ft)
loaded()
def test_source_range_function_call(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
def add_method(self, x, w):
return x + w
@torch.jit.script_method
def forward(self, x, y, w):
x = x * y
x = x + 2
return self.add_method(x, w)
ft = FooTest3()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest3)
with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
def test_source_range_no_debug_info(self):
class FooTest4(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest4()
loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)
try:
loaded(torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue("test_lite_script_module.py" not in error_message)
if __name__ == '__main__':
run_tests()

View File

@ -4182,8 +4182,8 @@ def foo(xyz):
debug_files = debug_records_from_mod(ft3)
for debug_file in debug_files:
for i in range(len(debug_file) - 1):
offset, source_range = debug_file[i]
offset2, source_range2 = debug_file[i + 1]
offset, source_range_tag, source_range = debug_file[i]
offset2, source_range_tag2, source_range2 = debug_file[i + 1]
self.assertNotEqual(source_range, source_range2)
def test_circular_dependency(self):

View File

@ -356,6 +356,17 @@ torch_mobile_core = [
"torch/csrc/jit/runtime/register_special_ops.cpp",
]
libtorch_lite_eager_symbolication = [
"torch/csrc/jit/frontend/source_range.cpp",
"torch/csrc/jit/mobile/debug_info.cpp",
"torch/csrc/jit/serialization/source_range_serialization.cpp",
# Later we can split serialization and deserialization logic
# to have better separation within build and only build relevant parts.
"torch/csrc/jit/serialization/pickle.cpp",
"torch/csrc/jit/serialization/pickler.cpp",
"torch/csrc/jit/serialization/unpickler.cpp",
]
# TODO: core_trainer_sources is not necessary for libtorch lite
libtorch_lite_cmake_sources = sorted(core_trainer_sources + core_sources_common + torch_mobile_core)
@ -368,6 +379,9 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/jit/api/module_save.cpp",
"torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
"torch/csrc/jit/mobile/export_data.cpp",
# To be included for eager symbolication in lite interpreter
# when it is built in libtorch
"torch/csrc/jit/mobile/debug_info.cpp",
"torch/csrc/jit/mobile/function.cpp",
"torch/csrc/jit/mobile/import.cpp",
"torch/csrc/jit/mobile/import_data.cpp",

View File

@ -3,6 +3,11 @@
namespace torch {
namespace jit {
size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
return (
std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
}
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
const SourceRange& range) {

View File

@ -5,6 +5,7 @@
#include <algorithm>
#include <iostream>
#include <memory>
#include <unordered_map>
namespace torch {
namespace jit {
@ -178,6 +179,11 @@ struct TORCH_API SourceRange {
size_t end_;
};
struct SourceRangeHasher {
public:
size_t operator()(const torch::jit::SourceRange& key) const;
};
struct StackEntry {
std::string filename;
SourceRange range;
@ -201,6 +207,8 @@ struct TaggedRange {
SourceRange range;
};
using SourceRangeRecords = std::vector<TaggedRange>;
using SourceRangeTagMap =
std::unordered_map<SourceRange, int64_t, SourceRangeHasher>;
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,53 @@
#include <torch/csrc/jit/mobile/debug_info.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <c10/util/string_view.h>
namespace torch {
namespace jit {
MobileDebugTable::MobileDebugTable(
std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader) {
const std::vector<std::string>& record_names = reader->getAllRecords();
const c10::string_view suffix(".debug_pkl");
for (const auto& record_name : record_names) {
if (c10::string_view(record_name).ends_with(suffix)) {
at::DataPtr debug_data;
size_t debug_size{0};
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
auto ivalues =
jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()), debug_size)
.toTuple()
->elements();
SourceRangeDeserializer deserializer;
for (auto& val : ivalues) {
auto tup_elems = val.toTuple()->elements();
// For BC we decode only tuples with 3 elements
// assuming it contains
// byte_offset, debug_handle (=source range tag), source range
if (tup_elems.size() == 3) {
int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
auto source_range =
deserializer.deserialize(tup_elems[kSourceRangeIndex]);
source_range_map_.emplace(debug_handle, std::move(source_range));
}
}
}
}
}
std::string MobileDebugTable::getSourceDebugString(
const int64_t debug_handle) const {
const auto it = source_range_map_.find(debug_handle);
if (it == source_range_map_.end()) {
return "";
}
return source_range_map_.at(debug_handle).str();
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,31 @@
#pragma once
#include <c10/util/flat_hash_map.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/source_range.h>
namespace torch {
namespace jit {
/*
* MobileDebugTable:
* Deserializes debug_pkl records from PT model's zip archive and
* stores them in a map of debug handles to source range.
* Debug handles are unique per model and runtime, be in lite interpreter
* or delegate, raises exception using debug handles.
* getSourceDebugString method is responsible for translating debug
* handles to correspond debug information.
* At the moment this only contains information about model source.
* But later diffs will include entire stack corresponding to debug handle.
*/
class MobileDebugTable {
public:
MobileDebugTable() = default;
MobileDebugTable(
std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader);
std::string getSourceDebugString(const int64_t debug_handle) const;
private:
ska::flat_hash_map<int64_t, SourceRange> source_range_map_;
};
} // namespace jit
} // namespace torch

View File

@ -22,12 +22,13 @@ const std::string& Function::name() const {
return name_.name();
}
void Function::append_instruction(OpCode op, int X, int N) {
void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
TORCH_CHECK(
isOpSupportedInMobile(op),
toString(op),
" is not supported in mobile module.");
code_->instructions_.emplace_back(op, X, N);
code_->debug_handles_.emplace_back(dbg_handle);
}
bool Function::append_operator(
@ -130,6 +131,11 @@ const std::shared_ptr<Code> Function::get_code() const {
return code_;
}
int64_t Function::getExceptionDebugHandle() const {
size_t pc = getInterpretersExceptionPC();
return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1;
}
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@ -19,7 +19,7 @@ class Function {
c10::IValue operator()(Stack& stack) const;
const std::string& name() const;
const c10::QualifiedName& qualname() const;
void append_instruction(OpCode op, int X, int N);
void append_instruction(OpCode op, int X, int N, int64_t dbg_handle = -1);
bool append_operator(
const std::string& name,
const std::string& overload_name,
@ -37,6 +37,11 @@ class Function {
void setSchema(c10::FunctionSchema schema);
const at::optional<c10::FunctionSchema>& getSchema() const;
// Returns the debug handle corresponding to where the execution
// is halted due to exception.
// If no corresponding debug handle is found then -1 is returned.
int64_t getExceptionDebugHandle() const;
private:
c10::QualifiedName name_;
std::shared_ptr<Code> code_;

View File

@ -305,12 +305,25 @@ void BytecodeDeserializer::parseMethods(
OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
int X = ins_item[1].toInt();
int N = ins_item[2].toInt();
function->append_instruction(op_code, X, N);
// TODO: Save debug handles for all instructions, not just for OP
if (op_code == OP) {
std::string module_debug_info = (has_debug_info)
? module_debug_info_list[X].toString()->string()
: "";
// In later PRs we will refactor this to always save debug handles.
// Debug info, source range and inlined callstack ptr saving will become
// optional.
if (has_debug_info) {
auto module_debug_tuple =
module_debug_info_list[X].toTuple()->elements();
std::string module_debug_info =
module_debug_tuple[0].toString()->string();
int64_t debug_handle = module_debug_tuple[1].toInt();
function->set_module_info(module_debug_info, i);
function->append_instruction(op_code, X, N, debug_handle);
} else {
function->set_module_info("", i);
function->append_instruction(op_code, X, N);
}
} else {
function->append_instruction(op_code, X, N);
}
}
@ -441,7 +454,12 @@ mobile::Module BytecodeDeserializer::deserialize(
}
parseMethods(bvals, debug_info_bvals, *mcu);
auto meta_dict = readMobileMetadata(mcu);
return mobile::Module(readArchive("data", mcu).toObject(), meta_dict, mcu);
auto m = mobile::Module(readArchive("data", mcu).toObject(), meta_dict, mcu);
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
MobileDebugTable debug_table = MobileDebugTable(reader_);
m.setDebugTable(std::move(debug_table));
#endif
return m;
}
std::unordered_map<std::string, std::string> BytecodeDeserializer::

View File

@ -4,6 +4,7 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <ATen/record_function.h>
@ -20,6 +21,8 @@ InterpreterState::InterpreterState(std::shared_ptr<Code> code)
}
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static thread_local int64_t exception_pc_{-1};
void createObject(Stack& stack, const at::ClassTypePtr& type) {
auto userObj = c10::ivalue::Object::create(
c10::StrongTypePtr(type->compilation_unit(), type),
@ -41,9 +44,14 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
using namespace at;
int64_t getInterpretersExceptionPC() {
return exception_pc_;
}
bool InterpreterState::run(Stack& stack) {
size_t pc = 0;
while (true) {
try {
Instruction inst = code_->instructions_[pc];
// std::cout << "RUNNING " << pc << " " << code_->instructions_[pc];
@ -57,8 +65,8 @@ bool InterpreterState::run(Stack& stack) {
switch (inst.op) {
case OP: {
if (at::hasGlobalCallbacks()) {
if (auto* mobile_debug_info =
static_cast<MobileDebugInfo*>(c10::ThreadLocalDebugInfo::get(
if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
c10::ThreadLocalDebugInfo::get(
c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
mobile_debug_info->setOpIdx(pc);
}
@ -209,9 +217,9 @@ bool InterpreterState::run(Stack& stack) {
} break;
case WARN: {
drop(stack, 1);
// Note: Please don't move the pop(stack) code below into the TORCH_WARN
// macro since TORCH_WARN fails to evaluate its arguments when
// STRIP_ERROR_MESSAGES is defined (which happens for production
// Note: Please don't move the pop(stack) code below into the
// TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments
// when STRIP_ERROR_MESSAGES is defined (which happens for production
// mobile builds). This will cause the stack to be in an inconsistent
// state. It has previously resulted in a SEV (S22350).
const auto& sref = stack.back().toStringRef();
@ -222,6 +230,15 @@ bool InterpreterState::run(Stack& stack) {
default:
AT_ERROR(toString(inst.op), " is invalid.");
}
} catch (c10::Error& error) {
// Reason for catching and rethrowing the error is so that we can
// set the exception pc that is queried later
exception_pc_ = pc;
TORCH_RETHROW(error);
} catch (...) {
exception_pc_ = pc;
throw;
}
// for (auto val : stack) {
// if (val.isTensor()) {
// std::cout << val.toTensor().sizes() << std::endl;

View File

@ -8,9 +8,13 @@ namespace torch {
namespace jit {
namespace mobile {
using Stack = std::vector<c10::IValue>;
using DebugHandle = int64_t;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct Code {
// TODO: Combine instructions and debug handles vector
// into std::vector<<std::pair<Instruction, DebugHandle>>
std::vector<Instruction> instructions_;
std::vector<DebugHandle> debug_handles_;
std::vector<c10::OperatorName> op_names_;
std::vector<std::function<void(Stack&)>> operators_;
std::vector<c10::IValue> constants_;
@ -28,6 +32,14 @@ struct InterpreterState {
std::vector<c10::IValue> registers_;
};
// Interpreter executes instruction in a loop one by one
// from a list of instructions. PC is a program counter pointer
// pointing to the current instruction being executed.
// This function returns the current PC.
// Note that this is set only when exception occurs.
// since this is a thread local variable and setting it for
// every instruction will add overhead of thread local variable access.
int64_t getInterpretersExceptionPC();
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@ -166,6 +166,11 @@ void Method::run(Stack& stack) const {
observer->onExitRunMethod(instance_key);
}
} catch (c10::Error& error) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
auto debug_string = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle());
error.add_context(debug_string);
#endif
if (observer) {
observer->onFailRunMethod(instance_key, error.what());
}
@ -183,6 +188,11 @@ void Method::run(Stack& stack) const {
}
}
} catch (c10::Error& error) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
auto debug_string = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle());
error.add_context(debug_string);
#endif
if (observer) {
observer->onFailRunMethod(instance_key, error.what());
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/mobile/debug_info.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/method.h>
@ -109,10 +110,18 @@ class TORCH_API Module {
return or_else;
}
void setDebugTable(MobileDebugTable&& debug_table) {
debug_table_ = std::move(debug_table);
}
const MobileDebugTable& getDebugTable() const {
return debug_table_;
}
private:
c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
MobileDebugTable debug_table_;
};
} // namespace mobile
} // namespace jit

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/serialization/export.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
@ -110,7 +111,8 @@ std::string getModuleTypeName(const Module& module, const std::string& prefix) {
std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
const Module& module,
const Function& func,
bool save_mobile_debug_info) {
const bool save_mobile_debug_info,
const SourceRangeTagMap& source_range_tag_map) {
auto graph = func.graph()->copy();
Inline(*graph);
@ -122,6 +124,7 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
std::vector<c10::OperatorName> opnames;
std::vector<std::string> method_names;
std::vector<std::string> op_module_paths;
std::vector<int64_t> op_source_debug_tags;
for (size_t i = 0; i < instructions_copy.size(); ++i) {
Instruction ins = instructions_copy[i];
if (ins.op == OP || ins.op == OPN) {
@ -129,7 +132,30 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
opnames.emplace_back(node->schema().operator_name());
if (save_mobile_debug_info) {
std::string root_scope_string = getModuleTypeName(module, "top");
// A little explanation as to why node->sourceRange() is not enough
// and we need to do node->sourceRange().findSourceRangeThatGenerated()
// When you do m = torch.jit.script/trace(model)
// the scripted model has graphs for methods with nodes.
// The nodes of this graph are annotated with sourceRange that is
// original python code. However when such a model is serialized via
// torch.jit.save, what is saved is compiled python code of the model.
// The compiled code for methods is effectively compiled graph which
// contains prim/aten etc. ops. This is not the same as original python
// code. When such a serialized model is loaded via torch.jit.load,
// node->sourceRange() does not point to original python code but points
// to transformed/compiled python code. So in order to get original
// python code which is serialized in debug_pkl, it is necessary to do
// node->sourceRange().findSourceRangeThatGenerated()
auto source_range = node->sourceRange().findSourceRangeThatGenerated()
? node->sourceRange().findSourceRangeThatGenerated().value()
: node->sourceRange();
int64_t source_range_tag{-1};
const auto& it = source_range_tag_map.find(source_range);
if (it != source_range_tag_map.end()) {
source_range_tag = it->second;
}
op_module_paths.emplace_back(getModulePath(node, root_scope_string));
op_source_debug_tags.emplace_back(source_range_tag);
}
}
// CALL nodes at this point represent built-in (i.e. non-Graph)
@ -280,12 +306,22 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
c10::optional<IValue> debug_info_vals;
if (save_mobile_debug_info) {
// module debug info
std::vector<IValue> module_paths;
module_paths.reserve(op_module_paths.size());
for (auto& path : op_module_paths) {
module_paths.emplace_back(std::move(path));
// Temporarily adding source debug tag here.
// In the diffs to follow we should move to serializing
// InlinedCallStack.
// Then we will just serialize either a vector or dictionary
// of debug handles, a.k.a PCs for lite interpreter,
// to InlinedCallStack
std::vector<IValue> module_debug_tuples;
module_debug_tuples.reserve(op_module_paths.size());
for (size_t i = 0; i < op_module_paths.size(); ++i) {
auto& path = op_module_paths[i];
int64_t source_debug_tag = op_source_debug_tags[i];
module_debug_tuples.emplace_back(
c10::ivalue::Tuple::create({std::move(path), source_debug_tag}));
}
auto module_debug_info = Table({{"module_debug_info", Tup(module_paths)}});
auto module_debug_info =
Table({{"module_debug_info", Tup(module_debug_tuples)}});
debug_info_vals = Tup({func.qualname().qualifiedName(), module_debug_info});
}
return std::make_pair(bytecode_vals, debug_info_vals);
@ -297,7 +333,8 @@ void setstateTuple(
std::vector<c10::IValue>& elements,
std::unordered_set<std::string>& qn_cache,
c10::optional<std::vector<c10::IValue>>& debug_info_elements,
bool save_mobile_debug_info) {
const bool save_mobile_debug_info,
const SourceRangeTagMap& source_range_map) {
if (!ivalue.isObject())
return;
auto obj = ivalue.toObject();
@ -309,8 +346,8 @@ void setstateTuple(
return;
}
if (setstate.isGraphFunction()) {
auto func_tuple =
getFunctionTuple(module, setstate, save_mobile_debug_info);
auto func_tuple = getFunctionTuple(
module, setstate, save_mobile_debug_info, source_range_map);
elements.push_back(func_tuple.first);
qn_cache.emplace(qn);
if (save_mobile_debug_info) {
@ -325,7 +362,8 @@ void setstateTuple(
elements,
qn_cache,
debug_info_elements,
save_mobile_debug_info);
save_mobile_debug_info,
source_range_map);
}
}
}
@ -335,7 +373,8 @@ void moduleMethodsTuple(
const Module& module,
std::vector<c10::IValue>& elements, // note: appended to in-place
c10::optional<std::vector<c10::IValue>>& debug_info_elements,
bool save_mobile_debug_info) {
const bool save_mobile_debug_info,
const SourceRangeTagMap& source_range_map = SourceRangeTagMap()) {
auto methods = module.get_methods();
std::unordered_set<std::string> qn_cache;
// top level methods
@ -344,8 +383,8 @@ void moduleMethodsTuple(
if (qn_cache.find(qn) != qn_cache.end()) {
continue;
}
auto func_tuple =
getFunctionTuple(module, method.function(), save_mobile_debug_info);
auto func_tuple = getFunctionTuple(
module, method.function(), save_mobile_debug_info, source_range_map);
elements.push_back(func_tuple.first);
qn_cache.emplace(qn);
if (save_mobile_debug_info) {
@ -360,7 +399,8 @@ void moduleMethodsTuple(
elements,
qn_cache,
debug_info_elements,
save_mobile_debug_info);
save_mobile_debug_info,
source_range_map);
}
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
@ -488,6 +528,15 @@ class ScriptModuleSerializer {
}
}
void updateSourceRangeTags(const SourceRangeRecords& ranges) {
for (const auto& range : ranges) {
if (source_range_tags_.find(range.range) == source_range_tags_.end()) {
source_range_tags_[range.range] = current_source_range_tag_;
current_source_range_tag_++;
}
}
}
void writeCode(const at::NamedTypePtr& root_type) {
class_deps_.add(root_type);
for (size_t i = 0; i < class_deps_.size(); ++i) {
@ -496,6 +545,7 @@ class ScriptModuleSerializer {
convertNamedType(class_deps_[i]);
}
current_source_range_tag_ = 0;
// Mapping of filename => src. We need this because multiple classes may go
// in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
for (auto& item : file_streams_) {
@ -517,7 +567,9 @@ class ScriptModuleSerializer {
// Write out the debug information
std::string debugFilename = filename + ".debug_pkl";
SourceRangePickler source_range_pickler;
auto range_data = source_range_pickler.pickle(item.value().ranges());
updateSourceRangeTags(item.value().ranges());
auto range_data = source_range_pickler.pickle(
item.value().ranges(), source_range_tags_);
writer_.writeRecord(
debugFilename,
range_data.data(),
@ -526,7 +578,7 @@ class ScriptModuleSerializer {
}
}
void writeByteCode(const Module& module, bool save_mobile_debug_info) {
void writeByteCode(const Module& module, const bool save_mobile_debug_info) {
std::vector<c10::IValue> elements;
elements.emplace_back(
static_cast<int64_t>(caffe2::serialize::kProducedBytecodeVersion));
@ -538,7 +590,11 @@ class ScriptModuleSerializer {
}
moduleMethodsTuple(
module, elements, debug_info_elements, save_mobile_debug_info);
module,
elements,
debug_info_elements,
save_mobile_debug_info,
source_range_tags_);
auto telements = Tup(std::move(elements));
writeArchive("bytecode", telements);
if (save_mobile_debug_info) {
@ -585,6 +641,32 @@ class ScriptModuleSerializer {
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
// created
OrderedDict<std::string, PythonPrint> file_streams_;
// Uniquely identifies a SourceRange in a model.
// SourceRanges are associated with Nodes of Graphs.
// However for mobile deployment we dont intend to ship
// full JIT with capabilities of reading code and constructing
// graphs.
// Instead we serialize the Code generated from graph of the methods.
// Code is serialized in bytecode format that contains instructions
// corresponding to the nodes of the graph. Since original graph is gone, the
// question is how do we identify where the ops, in serialized bytecode, come
// from in original model code. We do this in two parts.
// 1. Associate a unique tag to SourceRange.
// 2. Serialize this unique_tag.
// 2.1 Meaning save <byte_offset, source_range_tag, source range> instead of
// <byte_offset, source range>
// 3. During serializing model for mobile, i.e. bytecode generation,
// save unique tag of SourceRange corresponding to the Node.
// 4. During deserialization, read all the debug_pkl, to construct a map
// of <unique_tag, SourceRange> and use tag saved with OPs in bytecode
// to lookup the source range.
// Strictly speaking we will serialize InlinedCallStack directly, which
// contains SourceRange. This way we have access to entire callstack and not
// just source information about where the node is, since bytecode inlines the
// graph before saving it.
SourceRangeTagMap source_range_tags_;
int64_t current_source_range_tag_;
};
void ExportModule(

View File

@ -1,7 +1,6 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
@ -23,9 +22,7 @@ class SourceRangeSerializer {
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
};
class SourceRangeDeserializer {
public:
SourceRange deserialize(const c10::IValue& iv) {
SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
auto tup_elems = iv.toTuple()->elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
@ -34,8 +31,8 @@ class SourceRangeDeserializer {
return SourceRange(source_, start_, end_);
}
private:
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) {
std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
const c10::IValue& iv) {
auto tup = iv.toTuple();
if (cached_sources.count(tup)) {
return cached_sources.at(tup);
@ -44,8 +41,7 @@ class SourceRangeDeserializer {
auto tup_elems = tup->elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ =
tup_elems[1].toOptional<std::string>();
c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>();
int64_t starting_line_no_ = tup_elems[2].toInt();
auto source = std::make_shared<Source>(
@ -54,12 +50,6 @@ class SourceRangeDeserializer {
return source;
}
std::unordered_map<
c10::intrusive_ptr<c10::ivalue::Tuple>,
std::shared_ptr<Source>>
cached_sources;
};
c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
std::vector<c10::IValue> elements = {
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
@ -84,11 +74,20 @@ c10::IValue SourceRangeSerializer::serialize_source(
SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {}
std::vector<char> SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
std::vector<char> SourceRangePickler::pickle(
const SourceRangeRecords& ranges,
const SourceRangeTagMap& source_range_tags) {
std::vector<c10::IValue> ivalues;
for (const auto& range : ranges) {
int64_t source_range_tag{-1};
const auto& it = source_range_tags.find(range.range);
if (it != source_range_tags.end()) {
source_range_tag = it->second;
}
std::vector<c10::IValue> row_elems{
(int64_t)range.bytes, srs->serialize(range.range)};
(int64_t)range.bytes,
srs->serialize(range.range),
static_cast<int64_t>(source_range_tag)};
ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems)));
}
std::vector<at::Tensor> table;
@ -118,8 +117,8 @@ void ConcreteSourceRangeUnpickler::unpickle() {
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
auto tup_elems = val.toTuple()->elements();
int64_t offset = tup_elems[0].toInt();
auto source_range = deserializer->deserialize(tup_elems[1]);
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
unpickled_records->emplace_back(offset, std::move(source_range));
}
}

View File

@ -3,6 +3,8 @@
#include <c10/core/Allocator.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <ATen/core/ivalue.h>
#include <unordered_map>
#include <vector>
@ -15,18 +17,34 @@ namespace jit {
class Pickler;
class SourceRangeSerializer;
class SourceRangeDeserializer;
static constexpr size_t kByteOffsetIndex = 0;
static constexpr size_t kSourceRangeIndex = 1;
static constexpr size_t kSourceRangeTagIndex = 2;
class SourceRangePickler {
public:
SourceRangePickler();
std::vector<char> pickle(const SourceRangeRecords& ranges);
std::vector<char> pickle(
const SourceRangeRecords& ranges,
const SourceRangeTagMap& source_range_tags);
private:
std::shared_ptr<SourceRangeSerializer> srs;
};
class SourceRangeDeserializer {
public:
SourceRange deserialize(const c10::IValue& iv);
private:
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv);
std::unordered_map<
c10::intrusive_ptr<c10::ivalue::Tuple>,
std::shared_ptr<Source>>
cached_sources;
};
class SourceRangeUnpickler {
public:
virtual c10::optional<SourceRange> findSourceRangeThatGenerated(

View File

@ -240,7 +240,8 @@ def get_model_info(
code_parts = []
for di, di_next in zip(debug_info, debug_info[1:]):
start, source_range = di
# accounting for source range serialization format change
start, source_range, _ = di
end = di_next[0]
assert end > start
source, s_start, s_end = source_range