#include #include #include #include #include #include namespace torch { namespace jit { TEST(ScriptProfileTest, Basic) { const std::string source_string = R"V0G0N( def foo(a, b): return a + b # )V0G0N"; auto begin = source_string.find("return"); auto end = source_string.find(" #"); Graph g; const auto graph_string = R"IR( graph(%a : Tensor, %b : Tensor): %2 : int = prim::Constant[value=1]() %3 : Tensor = aten::add(%a, %b, %2) return (%3))IR"; torch::jit::parseIR(graph_string, &g); auto source = std::make_shared(source_string, "", 0); auto node = *g.nodes().begin(); node->setSourceRange(SourceRange{source, begin, end}); ScriptProfile p; p.enable(); { profiling::InstructionSpan g0(*node); profiling::InstructionSpan g1(*node); profiling::InstructionSpan g2(*node); } p.disable(); auto stats = p.dumpStats(); EXPECT_EQ(stats.size(), 1); auto it = stats.find(*source.get()); EXPECT_NE(it, stats.end()); auto& lines = it->second; EXPECT_EQ(lines.size(), 1); const auto& stat = lines.at(source->lineno_for_offset(begin)); EXPECT_EQ(stat.count, 3); } TEST(ScriptProfileTest, CallingOrder) { ScriptProfile p; p.enable(); EXPECT_THROW(p.dumpStats(), c10::Error); p.disable(); auto dp = std::make_shared(SourceRange{}); EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error); } } // namespace jit } // namespace torch