Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-26 16:44:54 +08:00).
			
		
		
		
	Summary:
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/58783

	This reverts commit fc804b5def5e7d7ecad24c4d1ca4ac575e588ae8.

	Test Plan: Imported from OSS

	Reviewed By: gmagogsfm

	Differential Revision: D28617037

	Pulled By: zhxchen17

	fbshipit-source-id: 645de2ede20500a5c218d6ec3c7faae94de37a14
		
			
				
	
	
		
			63 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			63 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <gtest/gtest.h>

#include <memory>
#include <string>

#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/script_profile.h>

| namespace torch {
 | |
| namespace jit {
 | |
| 
 | |
| TEST(ScriptProfileTest, Basic) {
 | |
|   const std::string source_string = R"V0G0N(
 | |
|     def foo(a, b):
 | |
|       return a + b #
 | |
|   )V0G0N";
 | |
|   auto begin = source_string.find("return");
 | |
|   auto end = source_string.find(" #");
 | |
| 
 | |
|   Graph g;
 | |
|   const auto graph_string = R"IR(
 | |
|     graph(%a : Tensor,
 | |
|           %b : Tensor):
 | |
|       %2 : int = prim::Constant[value=1]()
 | |
|       %3 : Tensor = aten::add(%a, %b, %2)
 | |
|       return (%3))IR";
 | |
| 
 | |
|   torch::jit::parseIR(graph_string, &g);
 | |
|   auto source = std::make_shared<Source>(source_string, "", 0);
 | |
|   auto node = *g.nodes().begin();
 | |
|   node->setSourceRange(SourceRange{source, begin, end});
 | |
| 
 | |
|   ScriptProfile p;
 | |
|   p.enable();
 | |
|   {
 | |
|     profiling::InstructionSpan g0(*node);
 | |
|     profiling::InstructionSpan g1(*node);
 | |
|     profiling::InstructionSpan g2(*node);
 | |
|   }
 | |
|   p.disable();
 | |
| 
 | |
|   auto stats = p.dumpStats();
 | |
|   EXPECT_EQ(stats.size(), 1);
 | |
|   auto it = stats.find(*source.get());
 | |
|   EXPECT_NE(it, stats.end());
 | |
|   auto& lines = it->second;
 | |
|   EXPECT_EQ(lines.size(), 1);
 | |
|   const auto& stat = lines.at(source->lineno_for_offset(begin));
 | |
|   EXPECT_EQ(stat.count, 3);
 | |
| }
 | |
| 
 | |
| TEST(ScriptProfileTest, CallingOrder) {
 | |
|   ScriptProfile p;
 | |
|   p.enable();
 | |
|   EXPECT_THROW(p.dumpStats(), c10::Error);
 | |
|   p.disable();
 | |
|   auto dp = std::make_shared<profiling::Datapoint>(SourceRange{});
 | |
|   EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error);
 | |
| }
 | |
| 
 | |
| } // namespace jit
 | |
| } // namespace torch
 |