Files
pytorch/test/cpp/jit/test_cs_debug_info_serialization.cpp
Kimish Patel ede3f5421f [Pytorch Delegated Backend] Save function name in debug info (#57481)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57481

This diff adds a function name field to InlinedCallStack.
InlinedCallStack is used for debug information in the lite interpreter
as well as in delegated backends, where it cannot be constructed from
the model source code, so the function name has to be saved explicitly.
When JIT compiles code at runtime, no function name is stored and
Function* is used to get the name of the function instead.
When Function* is not available, the saved function name introduced by
this diff is used.

Test Plan:
test_backend
test_cs_debug_info_serialization

Imported from OSS

Differential Revision:
D28159097

Reviewed By: raziel, ZolotukhinM

Pulled By: kimishpatel

fbshipit-source-id: deacaea3325e27273f92ae96cf0cd0789bbd6e72
2021-05-25 13:19:02 -07:00

156 lines
4.9 KiB
C++

#include <test/cpp/jit/test_utils.h>
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <stack>
#include <unordered_set>
// Tests go in torch::jit
namespace torch {
namespace jit {
namespace {
// Compares a debug-info tuple recorded before serialization with the one
// obtained after deserialization.
//
// Returns true when both tuples carry the same source range and their inlined
// call stacks match frame by frame. For every frame the source range, the
// function name and the module instance info (class type name and instance
// name) are compared. Two frames that both lack module instance info are
// treated as matching on that field.
bool validate_debug_info(
    const DebugInfoTuple& pre_serialize,
    const DebugInfoTuple& post_serialize) {
  const auto& pre_range =
      std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
  const auto& post_range =
      std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
  if (pre_range != post_range) {
    return false;
  }

  auto pre_cs = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
  auto post_cs = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
  // A missing call stack only matches another missing call stack.
  if (!pre_cs.defined()) {
    return !post_cs.defined();
  }
  if (!post_cs.defined()) {
    return false;
  }
  // Both stacks must have the same number of frames before walking them in
  // lock step.
  if (pre_cs->vec().size() != post_cs->vec().size()) {
    return false;
  }

  // Prefer the name of the attached Function*; fall back to the function name
  // stored on the call stack entry when no Function* is present.
  const auto function_name_of =
      [](const c10::intrusive_ptr<InlinedCallStack>& cs) -> std::string {
    if (cs->function()) {
      return cs->function()->name();
    }
    return cs->function_name();
  };

  while (pre_cs.defined()) {
    const auto pre_module = pre_cs->module_instance();
    const auto post_module = post_cs->module_instance();
    if (pre_module.has_value() != post_module.has_value()) {
      return false;
    }
    // Only compare module details when they exist; requiring a value here
    // would reject frames that legitimately carry no module instance info.
    if (pre_module.has_value()) {
      if (!(pre_module.value().class_type()->name().value() ==
            post_module.value().class_type()->name().value())) {
        return false;
      }
      if (!(pre_module.value().instance_name() ==
            post_module.value().instance_name())) {
        return false;
      }
    }
    if (!(function_name_of(pre_cs) == function_name_of(post_cs))) {
      return false;
    }
    if (!(pre_cs->source_range() == post_cs->source_range())) {
      return false;
    }

    if (pre_cs->callee()) {
      // Never dereference a callee the second stack does not have.
      if (!post_cs->callee()) {
        return false;
      }
      pre_cs = pre_cs->callee().value();
      post_cs = post_cs->callee().value();
    } else {
      // End of the chain: an undefined pointer terminates the loop.
      pre_cs = c10::intrusive_ptr<InlinedCallStack>();
    }
  }
  return true;
}
// Builds module C with submodules A0 and B0, inlines C.forward, records a
// debug handle for every node, and round-trips the resulting call-stack debug
// info through CallStackDebugInfoPickler / CallStackDebugInfoUnpickler.
// Every handle recorded before serialization must be present afterwards and
// must compare equal under validate_debug_info.
TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) {
  std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
  // TorchScript is indentation sensitive: each function body must be indented
  // relative to its `def` line.
  Module a("A", cu);
  a.define(R"JIT(
    def forward(self, x):
      return x + 1
  )JIT");
  Module b("B", cu);
  b.define(R"JIT(
    def forward(self, x):
      return x + 2
  )JIT");
  Module c("C", cu);
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"JIT(
    def forward(self, x):
      return self.A0.forward(x) + self.B0.forward(x)
  )JIT");

  BackendDebugInfoRecorder debug_info_recorder;
  auto graph = c.get_method("forward").graph();
  Inline(*graph);

  // Maps from source range to debug handle
  SourceRangeTagMap source_range_tags;
  // Maps from debug handle to source range
  ska::flat_hash_map<int64_t, SourceRange> source_range_map;
  int64_t source_range_tag{0};

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* block = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : block->nodes()) {
      // Tag the node's own source range and request a debug handle for it.
      source_range_tags[n->sourceRange()] = source_range_tag;
      source_range_map[source_range_tag] = n->sourceRange();
      source_range_tag++;
      debug_info_recorder.getNextDebugHandle(n);
      // Tag the source range of every frame of the node's inlined call stack
      // as well, so the pickler and unpickler can translate them.
      if (n->callstack().has_value()) {
        for (const auto& e : n->callstack().value()->vec()) {
          const auto& sr = std::get<1>(e);
          source_range_tags[sr] = source_range_tag;
          source_range_map[source_range_tag] = sr;
          source_range_tag++;
        }
      }
      // Queue nested blocks so nodes inside them are visited too.
      for (Block* sub_block : n->blocks()) {
        blocks_to_visit.push(sub_block);
      }
    }
  }

  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
  CallStackDebugInfoPickler cs_debug_info_pickler;
  auto cs_data =
      cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags);
  // data_ptr is built from the raw buffer pointer; cs_data stays in scope for
  // the whole unpickle call below.
  at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU);
  CallStackDebugInfoUnpickler unpickler;
  auto deserialized_cs_map = unpickler.unpickle(
      std::move(data_ptr), cs_data.size(), source_range_map, cu);

  for (const auto& it : debug_handle_cs_ptr_map) {
    const auto handle = it.first;
    const auto deserialized_it = deserialized_cs_map.find(handle);
    ASSERT_TRUE(deserialized_it != deserialized_cs_map.end())
        << "Serialized debug handle must be in deserialized map.";
    ASSERT_TRUE(validate_debug_info(it.second, deserialized_it->second));
  }
}
} // namespace
} // namespace jit
} // namespace torch