#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // Tests go in torch::jit namespace torch { namespace jit { namespace { bool validate_debug_info( const DebugInfoTuple& pre_serialize, const DebugInfoTuple& post_serialize) { auto sr1 = std::get(pre_serialize); auto sr2 = std::get(post_serialize); if (sr1 != sr2) { return false; } auto csptr1 = std::get(pre_serialize); auto csptr2 = std::get(post_serialize); if (!csptr1.defined()) { return !csptr2.defined(); } if (!csptr2.defined()) { return false; } auto vec1 = csptr1->vec(); auto vec2 = csptr2->vec(); if (vec1.size() != vec2.size()) { return false; } while (csptr1) { auto rhs_sr = csptr1->source_range(); auto lhs_sr = csptr2->source_range(); auto rhs_module = csptr1->module_instance(); auto lhs_module = csptr2->module_instance(); std::string rhs_fn_name, lhs_fn_name; if (csptr1->function()) { rhs_fn_name = csptr1->function()->name(); } else { rhs_fn_name = csptr1->function_name(); } if (csptr2->function()) { lhs_fn_name = csptr2->function()->name(); } else { lhs_fn_name = csptr2->function_name(); } if (!((rhs_module.has_value() == lhs_module.has_value()) && (rhs_module.has_value() && (rhs_module.value().class_type()->name().value() == lhs_module.value().class_type()->name().value()) && (rhs_module.value().instance_name() == lhs_module.value().instance_name())) && (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) { return false; } if (csptr1->callee()) { csptr1 = csptr1->callee().value(); csptr2 = csptr2->callee().value(); } else { csptr1 = c10::intrusive_ptr(); } } return true; } TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) { std::shared_ptr cu = std::make_shared(); Module a("A", cu); a.define(R"JIT( def forward(self, x): return x + 1 )JIT"); Module b("B", cu); b.define(R"JIT( def forward(self, x): return x + 2 )JIT"); Module c("C", cu); c.register_module("A0", a); c.register_module("B0", b); c.define(R"JIT( def forward(self, x): return self.A0.forward(x) + self.B0.forward(x) )JIT"); BackendDebugInfoRecorder debug_info_recorder; auto graph = c.get_method("forward").graph(); Inline(*graph); std::stack blocks_to_visit; // maps from source range to debug handle SourceRangeTagMap source_range_tags; // Maps from debug handle to source range ska::flat_hash_map source_range_map; int64_t source_range_tag{0}; blocks_to_visit.push(graph->block()); while (!blocks_to_visit.empty()) { Block* b = blocks_to_visit.top(); blocks_to_visit.pop(); for (Node* n : b->nodes()) { source_range_tags[n->sourceRange()] = source_range_tag; source_range_map[source_range_tag] = n->sourceRange(); source_range_tag++; debug_info_recorder.getNextDebugHandle(n); if (n->callstack().has_value()) { for (const auto& e : n->callstack().value()->vec()) { auto sr = std::get<1>(e); source_range_tags[sr] = source_range_tag; source_range_map[source_range_tag] = sr; source_range_tag++; } } } } auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording(); CallStackDebugInfoPickler cs_debug_info_pickler; auto cs_data = cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags); at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU); CallStackDebugInfoUnpickler unpickler; auto deserialized_cs_map = unpickler.unpickle( std::move(data_ptr), cs_data.size(), source_range_map, cu); for (const auto& it : debug_handle_cs_ptr_map) { auto handle = it.first; auto debug_info_one = it.second; TORCH_CHECK( deserialized_cs_map.count(handle), "Serialized debug handle must be in deserialized map."); auto debug_info_two = deserialized_cs_map[handle]; ASSERT_TRUE(validate_debug_info(debug_info_one, debug_info_two)); } } } // namespace } // namespace jit } // namespace torch