Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57481

This diff adds a function name to InlinedCallStack. Since InlinedCallStack is used for debug information in the lite interpreter as well as in delegate backends, where it cannot be reconstructed from model source code, the function name needs to be saved. Without a stored name, Function* is used to get the name of the function, which only works when the JIT compiles the code at runtime. When that is not possible, this diff provides a way to obtain the function name.

Test Plan:
test_backend
test_cs_debug_info_serialization

Imported from OSS

Differential Revision: D28159097

Reviewed By: raziel, ZolotukhinM
Pulled By: kimishpatel
fbshipit-source-id: deacaea3325e27273f92ae96cf0cd0789bbd6e72
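A minimal sketch of the lookup the summary describes, assuming a hypothetical helper name (function_name_of); it is not part of the file below and only uses the InlinedCallStack accessors exercised by the test:

// Sketch only. Prefer the Function* name when the JIT compiled the code at
// runtime; otherwise use the name stored on the InlinedCallStack, which is
// the lite interpreter / delegate backend case.
std::string function_name_of(
    const c10::intrusive_ptr<torch::jit::InlinedCallStack>& frame) {
  if (frame->function()) {
    return frame->function()->name();
  }
  return frame->function_name();
}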
156 lines · 4.9 KiB · C++
#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <stack>
#include <unordered_set>

// Tests go in torch::jit
namespace torch {
namespace jit {

namespace {
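// Checks that a DebugInfoTuple survives serialization: the top-level source
// range must match, and both inlined call stacks must agree frame by frame
// on source range, module instance (class type and instance name), and
// function name.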
bool validate_debug_info(
    const DebugInfoTuple& pre_serialize,
    const DebugInfoTuple& post_serialize) {
  auto sr1 = std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
  auto sr2 = std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
  if (sr1 != sr2) {
    return false;
  }
  auto csptr1 = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
  auto csptr2 = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
  if (!csptr1.defined()) {
    return !csptr2.defined();
  }
  if (!csptr2.defined()) {
    return false;
  }
  auto vec1 = csptr1->vec();
  auto vec2 = csptr2->vec();
  if (vec1.size() != vec2.size()) {
    return false;
  }
  while (csptr1) {
    auto rhs_sr = csptr1->source_range();
    auto lhs_sr = csptr2->source_range();
    auto rhs_module = csptr1->module_instance();
    auto lhs_module = csptr2->module_instance();
    std::string rhs_fn_name, lhs_fn_name;
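    // Prefer the name from Function* (available when the JIT compiled the
    // code at runtime); otherwise fall back to the function name stored on
    // the InlinedCallStack itself, as used by the lite interpreter and
    // delegate backends.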
    if (csptr1->function()) {
      rhs_fn_name = csptr1->function()->name();
    } else {
      rhs_fn_name = csptr1->function_name();
    }
    if (csptr2->function()) {
      lhs_fn_name = csptr2->function()->name();
    } else {
      lhs_fn_name = csptr2->function_name();
    }
    if (!((rhs_module.has_value() == lhs_module.has_value()) &&
          (rhs_module.has_value() &&
           (rhs_module.value().class_type()->name().value() ==
            lhs_module.value().class_type()->name().value()) &&
           (rhs_module.value().instance_name() ==
            lhs_module.value().instance_name())) &&
          (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) {
      return false;
    }
    if (csptr1->callee()) {
      csptr1 = csptr1->callee().value();
      csptr2 = csptr2->callee().value();
    } else {
      csptr1 = c10::intrusive_ptr<InlinedCallStack>();
    }
  }
  return true;
}

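// Builds a module "C" holding two submodules "A0" and "B0", inlines
// C.forward, records per-node debug handles, pickles the collected callstack
// debug info, unpickles it, and verifies the round trip frame by frame.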
TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) {
  std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
  Module a("A", cu);
  a.define(R"JIT(
    def forward(self, x):
      return x + 1
  )JIT");
  Module b("B", cu);
  b.define(R"JIT(
    def forward(self, x):
      return x + 2
  )JIT");
  Module c("C", cu);
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"JIT(
    def forward(self, x):
      return self.A0.forward(x) + self.B0.forward(x)
  )JIT");

  BackendDebugInfoRecorder debug_info_recorder;
  auto graph = c.get_method("forward").graph();
  Inline(*graph);
  std::stack<Block*> blocks_to_visit;

  // Maps from source range to a unique source range tag.
  SourceRangeTagMap source_range_tags;
  // Maps from source range tag back to the source range.
  ska::flat_hash_map<int64_t, SourceRange> source_range_map;
  int64_t source_range_tag{0};

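  // Walk the inlined graph and give every source range a unique tag: the
  // node's own range plus each range appearing in its inlined call stack.
  // Each node also gets a debug handle from the recorder.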
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      source_range_tags[n->sourceRange()] = source_range_tag;
      source_range_map[source_range_tag] = n->sourceRange();
      source_range_tag++;
      debug_info_recorder.getNextDebugHandle(n);
      if (n->callstack().has_value()) {
        for (const auto& e : n->callstack().value()->vec()) {
          auto sr = std::get<1>(e);
          source_range_tags[sr] = source_range_tag;
          source_range_map[source_range_tag] = sr;
          source_range_tag++;
        }
      }
    }
  }
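  // Stop recording and round trip the collected debug info: pickle the
  // debug-handle-to-callstack map using the source range tags, then unpickle
  // it from the raw buffer with the tag-to-source-range map and the same
  // compilation unit.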
  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
  CallStackDebugInfoPickler cs_debug_info_pickler;
  auto cs_data =
      cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags);
  at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU);
  CallStackDebugInfoUnpickler unpickler;
  auto deserialized_cs_map = unpickler.unpickle(
      std::move(data_ptr), cs_data.size(), source_range_map, cu);
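  // Every debug handle recorded before serialization must come back with
  // identical debug info.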
  for (const auto& it : debug_handle_cs_ptr_map) {
    auto handle = it.first;
    auto debug_info_one = it.second;
    TORCH_CHECK(
        deserialized_cs_map.count(handle),
        "Serialized debug handle must be in deserialized map.");
    auto debug_info_two = deserialized_cs_map[handle];
    ASSERT_TRUE(validate_debug_info(debug_info_one, debug_info_two));
  }
}

} // namespace

} // namespace jit
} // namespace torch