mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Refactor saving jit::Module to mobile .pt in 2 steps: (#66494)
Summary:
1. Convert Function -> mobile::Function.
2. Serialize mobile::Function.
This also opens up the opportunity to create a mobile::Module without saving and reloading.
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66494
Reviewed By: zhxchen17
Differential Revision: D32293022
Pulled By: qihqi
fbshipit-source-id: 29b43d47ff86071d5e2f9d6ca4dba4445711ce3d
			
			
This commit is contained in:
		
				
					committed by
					
						 Facebook GitHub Bot
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							e2aeb4a7af
						
					
				
				
					commit
					4eb772fde6
				
			| @ -67,6 +67,7 @@ set(JIT_TEST_SRCS | ||||
|   ${JIT_TEST_ROOT}/test_irparser.cpp | ||||
|   ${JIT_TEST_ROOT}/test_jit_type.cpp | ||||
|   ${JIT_TEST_ROOT}/test_lite_interpreter.cpp | ||||
|   ${JIT_TEST_ROOT}/test_lite_interpreter_direct.cpp | ||||
|   ${JIT_TEST_ROOT}/test_lite_trainer.cpp | ||||
|   ${JIT_TEST_ROOT}/test_memory_dag.cpp | ||||
|   ${JIT_TEST_ROOT}/test_misc.cpp | ||||
|  | ||||
							
								
								
									
										921
									
								
								test/cpp/jit/test_lite_interpreter_direct.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										921
									
								
								test/cpp/jit/test_lite_interpreter_direct.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,921 @@ | ||||
| #include <test/cpp/jit/test_utils.h> | ||||
|  | ||||
| #include <gtest/gtest.h> | ||||
|  | ||||
| #include <c10/core/TensorOptions.h> | ||||
| #include <torch/csrc/autograd/generated/variable_factories.h> | ||||
| #include <torch/csrc/jit/api/module.h> | ||||
| #include <torch/csrc/jit/frontend/resolver.h> | ||||
| #include <torch/csrc/jit/mobile/backport.h> | ||||
| #include <torch/csrc/jit/mobile/backport_manager.h> | ||||
| #include <torch/csrc/jit/mobile/import.h> | ||||
| #include <torch/csrc/jit/mobile/interpreter.h> | ||||
| #include <torch/csrc/jit/mobile/model_compatibility.h> | ||||
| #include <torch/csrc/jit/mobile/module.h> | ||||
| #include <torch/csrc/jit/mobile/parse_bytecode.h> | ||||
| #include <torch/csrc/jit/mobile/parse_operators.h> | ||||
| #include <torch/csrc/jit/mobile/runtime_compatibility.h> | ||||
| #include <torch/csrc/jit/serialization/export.h> | ||||
| #include <torch/csrc/jit/serialization/export_bytecode.h> | ||||
| #include <torch/csrc/jit/serialization/import.h> | ||||
| #include <torch/custom_class.h> | ||||
| #include <torch/torch.h> | ||||
|  | ||||
| #include <unordered_set> | ||||
|  | ||||
| // Tests go in torch::jit | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, UpsampleNearest2d) { | ||||
|   Module m("m"); | ||||
|   m.define(R"( | ||||
|     def forward(self, input: Tensor, scale:float): | ||||
|       return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) | ||||
|   )"); | ||||
|  | ||||
|   std::vector<IValue> inputs; | ||||
|   inputs.emplace_back(torch::rand({1, 3, 128, 128})); | ||||
|   inputs.emplace_back(at::Scalar(2.0)); | ||||
|   auto ref = m.forward(inputs); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   res = bc.forward(inputs); | ||||
|  | ||||
|   auto resd = res.toTensor(); | ||||
|   auto refd = ref.toTensor(); | ||||
|   ASSERT_TRUE(resd.equal(refd)); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, CheckAttrAccess) { | ||||
|   Module m("m"); | ||||
|   m.register_attribute("mobile_optimized", BoolType::get(), true); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   bool mobile_optimized = bc.attr("mobile_optimized", false).toBool(); | ||||
|  | ||||
|   AT_ASSERT(mobile_optimized); | ||||
|   m.setattr("mobile_optimized", false); | ||||
|   bc = jitModuleToMobile(m, options); | ||||
|   mobile_optimized = bc.attr("mobile_optimized", false).toBool(); | ||||
|   AT_ASSERT(!mobile_optimized); | ||||
| } | ||||
|  | ||||
| TEST( | ||||
|     LiteInterpreterDirectTest, | ||||
|     MethodInvocation) { // NOLINT (use =delete in gtest) | ||||
|   const std::vector<std::string> test_programs{ | ||||
|       // test invoking a method with default parameter | ||||
|       R"( | ||||
|       def test_func(self, x, b : int = 4): | ||||
|         return self.foo + x + b | ||||
|       )", | ||||
|       // inner method call with default parameter (gets inlined) | ||||
|       R"( | ||||
|       def add_with_default_arg(self, x, b : int = 4): | ||||
|         return self.foo + x + b | ||||
|       def test_func(self, x): | ||||
|         return self.add_with_default_arg(x)  # invoke method w/ default arg | ||||
|       )", | ||||
|       // simple method call | ||||
|       R"( | ||||
|       def test_func(self, x): | ||||
|         b = 4 | ||||
|         return self.foo + x + b | ||||
|       )", | ||||
|   }; | ||||
|   for (const auto& test_program : test_programs) { | ||||
|     Module m("m"); | ||||
|     m.register_parameter("foo", torch::ones({}), false); | ||||
|     m.define(test_program); | ||||
|  | ||||
|     const int fortyTwo = 42; // (keep linter happy) | ||||
|     auto minput = fortyTwo * torch::ones({}); | ||||
|     auto ref = m.run_method("test_func", minput); | ||||
|  | ||||
|     CompilationOptions options; | ||||
|     mobile::Module bc = jitModuleToMobile(m, options); | ||||
|     const auto& test_func = bc.get_method("test_func"); | ||||
|     std::cerr << "hello " << std::endl; | ||||
|     IValue res; | ||||
|     for (int i = 0; i < 3; ++i) { | ||||
|       res = test_func({minput}); | ||||
|     } | ||||
|     std::cerr << "hello 3" << std::endl; | ||||
|  | ||||
|     auto resd = res.toTensor().item<float>(); | ||||
|     auto refd = ref.toTensor().item<float>(); | ||||
|     AT_ASSERT(resd == refd); | ||||
|   } | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, Conv) { | ||||
|   auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); | ||||
|   if (s && strcmp(s, "1") == 0) | ||||
|     return; | ||||
|  | ||||
|   std::vector<torch::jit::IValue> inputs; | ||||
|  | ||||
|   Module m("m"); | ||||
|   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); | ||||
|   m.register_parameter("bias", torch::ones({20}), false); | ||||
|   m.define(R"( | ||||
|     def forward(self, input): | ||||
|       return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) | ||||
|   )"); | ||||
|  | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) | ||||
|   inputs.push_back(torch::ones({1, 1, 28, 28})); | ||||
|  | ||||
|   auto outputref = m.forward(inputs).toTensor(); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     res = bc.get_method("forward")(inputs); | ||||
|   } | ||||
|   auto output = res.toTensor(); | ||||
|   AT_ASSERT(outputref.dim() == output.dim()); | ||||
|   AT_ASSERT( | ||||
|       outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, Inline) { | ||||
|   Module m("m"); | ||||
|   m.define(R"JIT( | ||||
|   def foo1(self, x): | ||||
|       return x + 1 | ||||
|  | ||||
|   def foo2(self, x): | ||||
|       return self.foo1(x) + 2 | ||||
|  | ||||
|   def foo3(self, x): | ||||
|       return self.foo2(x) + 3 | ||||
|   )JIT"); | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   std::vector<torch::jit::IValue> inputs({torch::ones({})}); | ||||
|   auto output = bc.get_method("foo3")(inputs); | ||||
|   AT_ASSERT(output.toTensor().item<float>() == 7.0); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, Tuple) { | ||||
|   Module m("m"); | ||||
|   m.define(R"JIT( | ||||
|   def foo(self, x): | ||||
|       return (1, 2, x + 3) | ||||
|  | ||||
|   def forward(self, x): | ||||
|       tuple = self.foo(x) | ||||
|       return tuple | ||||
|   )JIT"); | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   std::vector<torch::jit::IValue> inputs({torch::ones({})}); | ||||
|   auto output = bc.get_method("forward")(inputs); | ||||
|   AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, Dict) { | ||||
|   Module m("m"); | ||||
|   m.define(R"JIT( | ||||
|   def foo(self, x): | ||||
|       return {"result": x + 1} | ||||
|  | ||||
|   def forward(self, x): | ||||
|       d = self.foo(x) | ||||
|       return d | ||||
|   )JIT"); | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   std::vector<torch::jit::IValue> inputs({torch::ones({})}); | ||||
|   auto output = bc.get_method("forward")(inputs); | ||||
|   AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, Prim) { | ||||
|   Module m("m"); | ||||
|   m.define(R"JIT( | ||||
|         def forward(self, x): | ||||
|             return int(x) | ||||
|   )JIT"); | ||||
|  | ||||
|   std::vector<IValue> inputs; | ||||
|   auto minput = 3.5 * torch::ones({}); | ||||
|   inputs.emplace_back(minput); | ||||
|   auto ref = m.run_method("forward", minput); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|  | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) | ||||
|     auto bcinputs = inputs; | ||||
|     res = bc.get_method("forward")(bcinputs); | ||||
|   } | ||||
|  | ||||
|   auto resi = res.toInt(); | ||||
|   auto refi = ref.toInt(); | ||||
|   AT_ASSERT(resi == refi); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, PrimScalar) { | ||||
|   Module m("m"); | ||||
|   m.define(R"JIT( | ||||
|         def forward(self, x): | ||||
|             return int(x.item()) | ||||
|   )JIT"); | ||||
|  | ||||
|   std::vector<IValue> inputs; | ||||
|   auto minput = 3.5 * torch::ones({}); | ||||
|   inputs.emplace_back(minput); | ||||
|   auto ref = m.run_method("forward", minput); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) | ||||
|     auto bcinputs = inputs; | ||||
|     res = bc.get_method("forward")(bcinputs); | ||||
|   } | ||||
|  | ||||
|   auto resi = res.toInt(); | ||||
|   auto refi = ref.toInt(); | ||||
|   AT_ASSERT(resi == refi); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, WrongMethodName) { | ||||
|   Module m("m"); | ||||
|   m.register_parameter("foo", torch::ones({}), false); | ||||
|   m.define(R"( | ||||
|     def add(self, x): | ||||
|       b = 4 | ||||
|       return self.foo + x + b | ||||
|   )"); | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   std::vector<IValue> inputs; | ||||
|   auto minput = 5 * torch::ones({}); | ||||
|   inputs.emplace_back(minput); | ||||
|   ASSERT_THROWS_WITH_MESSAGE( | ||||
|       bc.get_method("forward")(inputs), "is not defined"); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, SetState) { | ||||
|   Module m("m"); | ||||
|   m.register_parameter("foo", torch::ones({}), false); | ||||
|   m.define(R"( | ||||
|     def __getstate__(self): | ||||
|       return self.foo | ||||
|     def __setstate__(self, a): | ||||
|       self.foo = a | ||||
|     def forward(self, x): | ||||
|       b = 4 | ||||
|       return self.foo + x + b | ||||
|   )"); | ||||
|  | ||||
|   std::vector<IValue> inputs; | ||||
|   auto minput = 5 * torch::ones({}); | ||||
|   inputs.emplace_back(minput); | ||||
|  | ||||
|   std::stringstream ms; | ||||
|   m.save(ms); | ||||
|   auto loaded_m = load(ms); | ||||
|   auto ref = loaded_m.run_method("forward", minput); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) | ||||
|     auto bcinputs = inputs; | ||||
|     res = bc.get_method("forward")(bcinputs); | ||||
|   } | ||||
|  | ||||
|   auto resd = res.toTensor().item<float>(); | ||||
|   auto refd = ref.toTensor().item<float>(); | ||||
|   AT_ASSERT(resd == refd); | ||||
| } | ||||
|  | ||||
| class TorchBindLiteInterpreterDirectTestStruct | ||||
|     : public torch::jit::CustomClassHolder { | ||||
|  public: | ||||
|   std::string get(at::Tensor t) { | ||||
|     std::stringstream ss; | ||||
|     ss << "Hello! Your tensor has "; | ||||
|     ss << t.numel(); | ||||
|     ss << " elements!"; | ||||
|     return ss.str(); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| namespace { | ||||
| struct ClassNamespaceValue : public SugaredValue { | ||||
|   explicit ClassNamespaceValue(c10::QualifiedName name) | ||||
|       : basename_(std::move(name)) {} | ||||
|  | ||||
|   std::shared_ptr<SugaredValue> attr( | ||||
|       const SourceRange&, | ||||
|       GraphFunction&, | ||||
|       const std::string& name) override { | ||||
|     const auto fullName = c10::QualifiedName(basename_, name); | ||||
|  | ||||
|     // Check to see if it is a custom class. | ||||
|     if (auto custom_class = getCustomClass(fullName.qualifiedName())) { | ||||
|       return std::make_shared<ClassValue>(custom_class); | ||||
|     } | ||||
|  | ||||
|     // If it's not a custom class, assume it's another namespace | ||||
|     // NOLINTNEXTLINE(performance-move-const-arg) | ||||
|     return std::make_shared<ClassNamespaceValue>(fullName); | ||||
|   } | ||||
|  | ||||
|   std::string kind() const override { | ||||
|     return "Class Namespace"; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   c10::QualifiedName basename_; | ||||
| }; | ||||
|  | ||||
| struct TestModuleResolver : public Resolver { | ||||
|   std::shared_ptr<SugaredValue> resolveValue( | ||||
|       const std::string& name, | ||||
|       GraphFunction&, | ||||
|       const SourceRange&) override { | ||||
|     if (name == "torch") { | ||||
|       return std::make_shared<BuiltinModule>("aten"); | ||||
|     } else if (name == "__torch__") { | ||||
|       return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name)); | ||||
|     } | ||||
|  | ||||
|     return nullptr; | ||||
|   } | ||||
|  | ||||
|   TypePtr resolveType(const std::string&, const SourceRange&) override { | ||||
|     return nullptr; | ||||
|   } | ||||
| }; | ||||
| } // namespace | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, BuiltinFunction) { | ||||
|   script::Module m("m"); | ||||
|   auto custom_class_obj = | ||||
|       make_custom_class<TorchBindLiteInterpreterDirectTestStruct>(); | ||||
|   m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj); | ||||
|   m.define(R"( | ||||
|     def forward(self, x) -> str: | ||||
|       return self.my_obj.get(x) | ||||
|   )"); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   auto res = | ||||
|       bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})}); | ||||
|   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) | ||||
|   auto str = res.toStringRef(); | ||||
|   std::string expected = "Hello! Your tensor has 12 elements!"; | ||||
|   AT_ASSERT(str == expected); | ||||
| } | ||||
|  | ||||
| #if !defined FB_XPLAT_BUILD | ||||
| TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) { | ||||
|   auto runtime_bytecode_version = _get_runtime_bytecode_version(); | ||||
|   AT_ASSERT( | ||||
|       runtime_bytecode_version == | ||||
|       caffe2::serialize::kMaxSupportedBytecodeVersion); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) { | ||||
|   auto runtime_operators_version = _get_runtime_operators_min_max_versions(); | ||||
|   AT_ASSERT( | ||||
|       runtime_operators_version.first == | ||||
|           caffe2::serialize::kMinSupportedFileFormatVersion && | ||||
|       runtime_operators_version.second == | ||||
|           caffe2::serialize::kMaxSupportedFileFormatVersion); | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * The test below is disarmed for FB internal xplat builds since | ||||
|  * BUCK requires us to pass in the script_module_v4.ptl file in | ||||
|  * as a resource dependency of the build rule for this file, and | ||||
|  * we would need to access it via the C++ Resources API instead | ||||
|  * of directly reading from disk (which is what the open source | ||||
|  * build/run does). | ||||
|  */ | ||||
| TEST(LiteInterpreterDirectTest, GetByteCodeVersion) { | ||||
|   std::string filePath(__FILE__); | ||||
|   auto test_model_file_v4 = | ||||
|       filePath.substr(0, filePath.find_last_of("/\\") + 1); | ||||
|   test_model_file_v4.append("script_module_v4.ptl"); | ||||
|  | ||||
|   auto version_v4 = _get_model_bytecode_version(test_model_file_v4); | ||||
|   AT_ASSERT(version_v4 == 4); | ||||
| } | ||||
|  | ||||
| #endif // !defined(FB_XPLAT_BUILD) | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) { | ||||
|   auto runtime_ops = _get_runtime_ops_and_info(); | ||||
|   // Ballpark estimate of the minimal number of ops; just used to | ||||
|   // verify API returns a reasonably large number. | ||||
|   AT_ASSERT(runtime_ops.size() > 2900); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, Eval) { | ||||
|   std::vector<torch::jit::IValue> inputs; | ||||
|  | ||||
|   Module m("m"); | ||||
|   m.define(R"( | ||||
|     def __init__(self, x): | ||||
|       self.training = True | ||||
|  | ||||
|     def forward(self, input): | ||||
|       return torch.dropout(input, 1.0, self.training) | ||||
|   )"); | ||||
|  | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) | ||||
|   inputs.push_back(torch::ones({1, 1, 28, 28})); | ||||
|   m.eval(); | ||||
|   auto outputref = m.forward(inputs).toTensor(); | ||||
|  | ||||
|   // save m in training mode to make sure that mobile eval() will correctly | ||||
|   // change back to eval mode | ||||
|   m.train(); | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   bc.eval(); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     res = bc.get_method("forward")(inputs); | ||||
|   } | ||||
|   auto output = res.toTensor(); | ||||
|   AT_ASSERT(outputref.dim() == output.dim()); | ||||
|   AT_ASSERT( | ||||
|       outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>()); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, FindWrongMethodName) { | ||||
|   Module m("m"); | ||||
|   m.register_parameter("foo", torch::ones({}), false); | ||||
|   m.define(R"( | ||||
|     def add(self, x): | ||||
|       b = 4 | ||||
|       return self.foo + x + b | ||||
|   )"); | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   ASSERT_TRUE(bc.find_method("forward") == c10::nullopt); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, FindAndRunMethod) { | ||||
|   Module m("m"); | ||||
|   m.register_parameter("foo", torch::ones({}), false); | ||||
|   m.define(R"( | ||||
|     def add_it(self, x): | ||||
|       b = 4 | ||||
|       return self.foo + x + b | ||||
|   )"); | ||||
|  | ||||
|   std::vector<IValue> inputs; | ||||
|   auto minput = 5 * torch::ones({}); | ||||
|   inputs.emplace_back(minput); | ||||
|   auto ref = m.get_method("add_it")(inputs); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     auto bcinputs = inputs; | ||||
|     auto method = bc.find_method("add_it"); | ||||
|     AT_ASSERT(method != c10::nullopt); | ||||
|     res = (*method)(std::move(bcinputs)); | ||||
|   } | ||||
|  | ||||
|   auto resd = res.toTensor().item<float>(); | ||||
|   auto refd = ref.toTensor().item<float>(); | ||||
|   AT_ASSERT(resd == refd); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, RunMethodVariadic) { | ||||
|   Module m("m"); | ||||
|   m.register_parameter("foo", torch::ones({}), false); | ||||
|   m.define(R"( | ||||
|     def add_three(self, x, y): | ||||
|       return self.foo + x + y | ||||
|   )"); | ||||
|  | ||||
|   std::vector<IValue> inputs; | ||||
|   auto inputx = 5 * torch::ones({}); | ||||
|   auto inputy = 4 * torch::ones({}); | ||||
|   auto ref = m.run_method("add_three", inputx, inputy); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res = bc.run_method("add_three", inputx, inputy); | ||||
|  | ||||
|   auto resd = res.toTensor().item<float>(); | ||||
|   auto refd = ref.toTensor().item<float>(); | ||||
|   AT_ASSERT(resd == refd); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, DuplicateSetState) { | ||||
|   Module m("M"); | ||||
|   m.register_parameter("foo", torch::ones({}), false); | ||||
|   m.define(R"( | ||||
|     def __getstate__(self): | ||||
|       return self.foo + self.foo | ||||
|     def __setstate__(self, a): | ||||
|       self.foo = a | ||||
|     def forward(self, x): | ||||
|       b = 4 | ||||
|       return self.foo + x + b | ||||
|   )"); | ||||
|  | ||||
|   Module b("B"); | ||||
|   b.register_module("M0", m); | ||||
|   b.register_module("M1", m); | ||||
|   b.define(R"( | ||||
|     def forward(self, x): | ||||
|       return self.M0.forward(x) + self.M1.forward(x) | ||||
|   )"); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   const auto methods = bc.get_methods(); | ||||
|   const size_t expected_n = 3; | ||||
|   ASSERT_EQ(methods.size(), expected_n); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) { | ||||
|   torch::jit::Module m("m"); | ||||
|   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); | ||||
|   m.register_parameter("bias", torch::ones({20}), false); | ||||
|   m.define(R"( | ||||
|     def forward(self, input): | ||||
|       x1 = torch.zeros(2, 2) | ||||
|       x2 = torch.empty_like(torch.empty(2, 2)) | ||||
|       x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) | ||||
|       return (x1, x2, x3) | ||||
|   )"); | ||||
|   m.eval(); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module ptl_model = jitModuleToMobile(m, options); | ||||
|   std::set<std::string> operator_names = | ||||
|       torch::jit::mobile::_export_operator_list(ptl_model); | ||||
|   std::set<std::string> expected_operator_names = { | ||||
|       "aten::_convolution", | ||||
|       "aten::empty.memory_format", | ||||
|       "aten::empty_like", | ||||
|       "aten::zeros", | ||||
|   }; | ||||
|   EXPECT_EQ(operator_names, expected_operator_names) | ||||
|       << "Expected the root operator lists to be the same"; | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, DefaultArgsConv) { | ||||
|   auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); | ||||
|   if (s && strcmp(s, "1") == 0) | ||||
|     return; | ||||
|  | ||||
|   std::vector<torch::jit::IValue> inputs; | ||||
|  | ||||
|   Module m("m"); | ||||
|   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); | ||||
|   m.register_parameter("bias", torch::ones({20}), false); | ||||
|   m.define(R"( | ||||
|     def forward(self, input): | ||||
|       return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1) | ||||
|   )"); | ||||
|  | ||||
|   inputs.emplace_back(torch::ones({1, 1, 28, 28})); | ||||
|  | ||||
|   auto outputref = m.forward(inputs).toTensor(); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 1; ++i) { | ||||
|     res = bc.get_method("forward")(inputs); | ||||
|   } | ||||
|   auto output = res.toTensor(); | ||||
|   AT_ASSERT(outputref.dim() == output.dim()); | ||||
|   AT_ASSERT(output.equal(outputref)); | ||||
| } | ||||
|  | ||||
| namespace { | ||||
| void testLiteModuleCompareResultTensors( | ||||
|     Module& m, | ||||
|     const std::vector<torch::jit::IValue>& inputs, | ||||
|     const std::string& method_name = "forward") { | ||||
|   auto outputref = m.get_method(method_name)(inputs).toTensor(); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   IValue res; | ||||
|   for (int i = 0; i < 3; ++i) { | ||||
|     res = bc.get_method(method_name)(inputs); | ||||
|   } | ||||
|   auto output = res.toTensor(); | ||||
|   AT_ASSERT(outputref.dim() == output.dim()); | ||||
|   AT_ASSERT(output.equal(outputref)); | ||||
| } | ||||
|  | ||||
| void testDefaultArgsPinv2(int num_args) { | ||||
|   Module m("m"); | ||||
|   if (num_args == 1) { | ||||
|     m.define(R"( | ||||
|       def forward(self, input): | ||||
|         return torch.linalg_pinv(input) | ||||
|     )"); | ||||
|   } else if (num_args == 2) { | ||||
|     m.define(R"( | ||||
|       def forward(self, input): | ||||
|         return torch.linalg_pinv(input, 1e-5) | ||||
|     )"); | ||||
|   } else if (num_args == 3) { | ||||
|     m.define(R"( | ||||
|       def forward(self, input): | ||||
|         return torch.linalg_pinv(input, 1e-5, True) | ||||
|     )"); | ||||
|   } | ||||
|  | ||||
|   std::vector<torch::jit::IValue> inputs; | ||||
|   const int N = 28; | ||||
|   auto input = torch::range(1, N * N, 1); | ||||
|   input[0] = 1; // a more stable matrix | ||||
|   input = input.view({N, N}); | ||||
|   inputs.emplace_back(input); | ||||
|   testLiteModuleCompareResultTensors(m, inputs); | ||||
| } | ||||
| } // namespace | ||||
|  | ||||
| #if !defined FB_XPLAT_BUILD | ||||
| TEST(LiteInterpreterDirectTest, DefaultArgsPinv) { | ||||
|   // Test with different number of specified arguments. | ||||
|   // Arguments not specified take default value. | ||||
|   for (int num_args = 1; num_args <= 3; ++num_args) { | ||||
|     testDefaultArgsPinv2(num_args); | ||||
|   } | ||||
|  | ||||
|   //  bytecode with one specified argument: | ||||
|   //  (6, | ||||
|   //      ('__torch__.m.forward', | ||||
|   //          (('instructions', | ||||
|   //              (('STOREN', 1, 2), | ||||
|   //                  ('DROPR', 1, 0), | ||||
|   //                  ('MOVE', 2, 0), | ||||
|   //                  ('OP', 0, 0), | ||||
|   //                  ('RET', 0, 0))), | ||||
|   //              ('operators', (('aten::linalg_pinv', '', 1),)), | ||||
|   //              ('constants', (False, 1e-15)), # default constants are not | ||||
|   //              used | ||||
|   //              ('types', ()), | ||||
|   //              ('register_size', 2)), | ||||
|   //          (('arguments', | ||||
|   //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', | ||||
|   //              None)), | ||||
|   //                  (('name', 'input'), ('type', 'Tensor'), ('default_value', | ||||
|   //                  None)))), | ||||
|   //              ('returns', | ||||
|   //                  ((('name', ''), ('type', 'Tensor'), ('default_value', | ||||
|   //                  None)),))))) | ||||
|  | ||||
|   //  bytecode with 2 specified argument: | ||||
|   //  (6, | ||||
|   //      ('__torch__.m.forward', | ||||
|   //          (('instructions', | ||||
|   //              (('STOREN', 1, 2), | ||||
|   //                  ('DROPR', 1, 0), | ||||
|   //                  ('MOVE', 2, 0), | ||||
|   //                  ('LOADC', 1, 0), # added LOADC for specified argument | ||||
|   //                  ('OP', 0, 0), | ||||
|   //                  ('RET', 0, 0))), | ||||
|   //              ('operators', (('aten::linalg_pinv', '', 2),)), | ||||
|   //              ('constants', (False, 1e-05)), # updated constant table | ||||
|   //              ('types', ()), | ||||
|   //              ('register_size', 2)), | ||||
|   //          (('arguments', | ||||
|   //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', | ||||
|   //              None)), | ||||
|   //                  (('name', 'input'), ('type', 'Tensor'), ('default_value', | ||||
|   //                  None)))), | ||||
|   //              ('returns', | ||||
|   //                  ((('name', ''), ('type', 'Tensor'), ('default_value', | ||||
|   //                  None)),))))) | ||||
|  | ||||
|   //  bytecode with 3 specified arguments: | ||||
|   //  (6, | ||||
|   //      ('__torch__.m.forward', | ||||
|   //          (('instructions', | ||||
|   //              (('STOREN', 1, 2), | ||||
|   //                  ('DROPR', 1, 0), | ||||
|   //                  ('MOVE', 2, 0), | ||||
|   //                  ('LOADC', 1, 0), | ||||
|   //                  ('LOADC', 0, 0), | ||||
|   //                  ('OP', 0, 0), | ||||
|   //                  ('RET', 0, 0))), | ||||
|   //              ('operators', (('aten::linalg_pinv', '', 3),)), | ||||
|   //              ('constants', (True, 1e-05)), | ||||
|   //              ('types', ()), | ||||
|   //              ('register_size', 2)), | ||||
|   //          (('arguments', | ||||
|   //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', | ||||
|   //              None)), | ||||
|   //                  (('name', 'input'), ('type', 'Tensor'), ('default_value', | ||||
|   //                  None)))), | ||||
|   //              ('returns', | ||||
|   //                  ((('name', ''), ('type', 'Tensor'), ('default_value', | ||||
|   //                  None)),))))) | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) { | ||||
|   // The second argument is specified, but the value is the same as the default | ||||
|   // value. It's treated as "not specified" since the value can be fetched from | ||||
|   // schema. | ||||
|   Module m("m"); | ||||
|   m.define(R"( | ||||
|     def forward(self, input): | ||||
|       return torch.linalg_tensorinv(input, 2) | ||||
|   )"); | ||||
|   torch::jit::MobileCode code(m.get_method("forward").graph(), "forward"); | ||||
|   auto arg_nums = code.op_to_num_specified_args(); | ||||
|   ASSERT_EQ(arg_nums.size(), 1); | ||||
|   ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1); | ||||
|   std::vector<torch::jit::IValue> inputs; | ||||
|   const int N = 4; | ||||
|   auto input = torch::rand({N, N, N, N}); | ||||
|   inputs.emplace_back(input); | ||||
|   testLiteModuleCompareResultTensors(m, inputs); | ||||
| } | ||||
|  | ||||
| void testDefaultArgsPinvWithOutArg2(int num_args) { | ||||
|   // Define forward() so that linalg_pinv receives `num_args` positional | ||||
|   // arguments in addition to the out argument. For any other value of | ||||
|   // num_args no forward() method is defined here. | ||||
|   Module m("m"); | ||||
|   switch (num_args) { | ||||
|     case 1: | ||||
|       m.define(R"( | ||||
|       def forward(self, input): | ||||
|         return torch.linalg_pinv(input, out=input) | ||||
|     )"); | ||||
|       break; | ||||
|     case 2: | ||||
|       m.define(R"( | ||||
|       def forward(self, input): | ||||
|         return torch.linalg_pinv(input, 1e-5, out=input) | ||||
|     )"); | ||||
|       break; | ||||
|     case 3: | ||||
|       m.define(R"( | ||||
|       def forward(self, input): | ||||
|         return torch.linalg_pinv(input, 1e-5, True, out=input) | ||||
|     )"); | ||||
|       break; | ||||
|     default: | ||||
|       break; | ||||
|   } | ||||
|  | ||||
|   const int kSide = 28; | ||||
|   auto input = torch::range(1, kSide * kSide, 1); | ||||
|   input[0] = 10000; // a more stable matrix | ||||
|   input = input.view({kSide, kSide}); | ||||
|   auto reference = m.run_method("forward", input); | ||||
|   // `input` doubles as the out argument: it must no longer equal the plain | ||||
|   // range and must equal the tensor returned by forward(). | ||||
|   TORCH_CHECK(!input.equal(torch::range(1, kSide * kSide, 1))); | ||||
|   TORCH_CHECK(input.equal(reference.toTensor())); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) { | ||||
|   // Exercise linalg_pinv with one, two and three explicitly specified | ||||
|   // arguments plus the out argument; arguments left unspecified take their | ||||
|   // default value. | ||||
|   for (const int num_args : {1, 2, 3}) { | ||||
|     testDefaultArgsPinvWithOutArg2(num_args); | ||||
|   } | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) { | ||||
|   // Runs an op with an out argument through the full module and then through | ||||
|   // the mobile module produced by jitModuleToMobile, checking the in-place | ||||
|   // result on the shared input tensor. | ||||
|   Module m("m"); | ||||
|   m.define(R"( | ||||
|     def forward(self, x, h): | ||||
|       torch.add(x, h, out=x) | ||||
|   )"); | ||||
|  | ||||
|   auto input_x = 2 * torch::ones({}); | ||||
|   auto input_h = torch::ones({}); | ||||
|   // forward() writes x + h back into x, so this first run already turns | ||||
|   // input_x from 2 into 3. The returned value is not needed. | ||||
|   m.run_method("forward", input_x, input_h); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   mobile::Module bc = jitModuleToMobile(m, options); | ||||
|   // Second in-place add, now through the mobile module: 3 + 1 = 4. | ||||
|   bc.run_method("forward", input_x, input_h); | ||||
|   AT_ASSERT(input_x.equal(4 * torch::ones({}))); | ||||
| } | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) { | ||||
|   // Builds the module chain C -> B0 (type B) -> A0 (type A), converts it to a | ||||
|   // mobile module and checks the module hierarchy and TorchScript traceback | ||||
|   // carried by the exception thrown from the innermost call. | ||||
|   Module a("A"); | ||||
|   a.define(R"( | ||||
|     def bar(self, x, y): | ||||
|       return x + y | ||||
|   )"); | ||||
|   Module b("B"); | ||||
|   b.register_module("A0", a); | ||||
|   b.define(R"( | ||||
|     def foo(self, x, y): | ||||
|       return self.A0.bar(x, y) + 2 | ||||
|   )"); | ||||
|   Module c("C"); | ||||
|   c.register_module("B0", b); | ||||
|   c.define(R"( | ||||
|     def forward(self, x, y): | ||||
|       return self.B0.foo(x, y) + 3 | ||||
|   )"); | ||||
|  | ||||
|   // Shapes {2, 4} and {13, 9} cannot be broadcast together, so the add in | ||||
|   // A.bar is the operation that fails. | ||||
|   std::vector<IValue> inputs; | ||||
|   inputs.emplace_back(torch::rand({2, 4})); | ||||
|   inputs.emplace_back(torch::rand({13, 9})); | ||||
|  | ||||
|   CompilationOptions options; | ||||
|   auto lite_m = jitModuleToMobile(c, options); | ||||
|   // Expected text: the hierarchy line names each module instance with its | ||||
|   // type, followed by one traceback frame per function (forward, foo, bar). | ||||
|   std::string error_pattern = R"( | ||||
|   Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add | ||||
| Traceback of TorchScript (most recent call last): | ||||
|   File "<string>", line 3, in <unknown> | ||||
|  | ||||
|     def forward(self, x, y): | ||||
|       return self.B0.foo(x, y) + 3 | ||||
|              ~~~~~~~~~~~ <--- HERE | ||||
|  | ||||
|   File "<string>", line 3, in foo | ||||
|  | ||||
|     def foo(self, x, y): | ||||
|       return self.A0.bar(x, y) + 2 | ||||
|              ~~~~~~~~~~~ <--- HERE | ||||
|  | ||||
|   File "<string>", line 3, in bar | ||||
|  | ||||
|     def bar(self, x, y): | ||||
|       return x + y | ||||
|              ~~~~~ <--- HERE | ||||
|   )"; | ||||
|   ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern); | ||||
| } | ||||
| #endif // !defined(FB_XPLAT_BUILD) | ||||
|  | ||||
| namespace { | ||||
| // Registers TorchBindLiteInterpreterDirectTestStruct as the custom class | ||||
| // _TorchScriptTesting._LiteInterpreterDirectTest with a default constructor, | ||||
| // a `get` method and pickling support. | ||||
| static auto reg = | ||||
|     torch::class_<TorchBindLiteInterpreterDirectTestStruct>( | ||||
|         "_TorchScriptTesting", | ||||
|         "_LiteInterpreterDirectTest") | ||||
|         .def(torch::init<>()) | ||||
|         .def("get", &TorchBindLiteInterpreterDirectTestStruct::get) | ||||
|         .def_pickle( | ||||
|             // __getstate__: the serialized state is always the integer 0 | ||||
|             [](const c10::intrusive_ptr< | ||||
|                 TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t { | ||||
|               return 0; | ||||
|             }, | ||||
|             // __setstate__: the stored integer is ignored and a fresh | ||||
|             // default-constructed object is returned | ||||
|             [](int64_t) { | ||||
|               return c10::make_intrusive< | ||||
|                   TorchBindLiteInterpreterDirectTestStruct>(); | ||||
|             }); | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) { | ||||
|   // Three methods call new_empty with different argument lists: | ||||
|   // | ||||
|   // - forward():  dtype=4 (torch.int64) is passed explicitly | ||||
|   // - forward2(): dtype=6 (torch.float32) is passed explicitly | ||||
|   // - forward3(): no dtype is passed, so it is inferred from the input | ||||
|   //               tensor's dtype (torch.float32) | ||||
|   // | ||||
|   // The full-jit module and the lite module agree on every method only if a | ||||
|   // cached operator entry is not reused for a call of the same operator | ||||
|   // that has a different number of arguments. | ||||
|   Module m("m"); | ||||
|   m.define(R"( | ||||
|     def forward(self): | ||||
|       ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4) | ||||
|       return ret1.fill_(25) | ||||
|   )"); | ||||
|   m.define(R"( | ||||
|     def forward2(self): | ||||
|       ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6) | ||||
|       return ret1.fill_(32.0) | ||||
|   )"); | ||||
|   m.define(R"( | ||||
|     def forward3(self): | ||||
|       ret1 = torch.new_empty(torch.zeros(10), [10]) | ||||
|       return ret1.fill_(12.0) | ||||
|   )"); | ||||
|  | ||||
|   // None of the methods takes an input besides self. | ||||
|   std::vector<torch::jit::IValue> inputs; | ||||
|   for (const char* method_name : {"forward", "forward2", "forward3"}) { | ||||
|     testLiteModuleCompareResultTensors(m, inputs, method_name); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
| @ -24,6 +24,10 @@ struct TORCH_API GraphFunction : public Function { | ||||
|  | ||||
|   void run(Stack& stack) override; | ||||
|  | ||||
|   std::function<void(GraphFunction&)> function_creator() const { | ||||
|     return function_creator_; | ||||
|   } | ||||
|  | ||||
|   c10::intrusive_ptr<c10::ivalue::Future> runAsync( | ||||
|       Stack& stack, | ||||
|       TaskLauncher taskLauncher = at::launch) override; | ||||
|  | ||||
| @ -20,6 +20,7 @@ struct Code { | ||||
|   std::vector<Instruction> instructions_; | ||||
|   std::vector<DebugHandle> debug_handles_; | ||||
|   std::vector<c10::OperatorName> op_names_; | ||||
|   std::vector<int> operator_input_sizes_; | ||||
|   std::vector<std::function<void(Stack&)>> operators_; | ||||
|   std::vector<c10::IValue> constants_; | ||||
|   std::vector<c10::TypePtr> types_; | ||||
|  | ||||
| @ -23,6 +23,10 @@ class MobileDebugTable { | ||||
|   MobileDebugTable( | ||||
|       std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader, | ||||
|       const std::shared_ptr<CompilationUnit>& cu); | ||||
|  | ||||
|   template <typename It> | ||||
|   MobileDebugTable(It begin, It end) : callstack_ptr_map_(begin, end) {} | ||||
|  | ||||
|   std::string getSourceDebugString( | ||||
|       const int64_t debug_handle, | ||||
|       const std::string& top_module_type_name = "ModuleTypeUnknown") const; | ||||
| @ -36,6 +40,11 @@ class MobileDebugTable { | ||||
|       const std::vector<int64_t>& debug_handles, | ||||
|       const std::string& top_module_type_name = "ModuleTypeUnknown") const; | ||||
|  | ||||
|   const ska::flat_hash_map<int64_t, DebugInfoTuple>& getCallStackPtrMap() | ||||
|       const { | ||||
|     return callstack_ptr_map_; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo( | ||||
|       const std::vector<int64_t>& debug_handles, | ||||
|  | ||||
| @ -13,6 +13,14 @@ namespace mobile { | ||||
| // Creates a function called `name` that owns a newly default-constructed | ||||
| // Code object. | ||||
| Function::Function(c10::QualifiedName name) | ||||
|     : name_(std::move(name)), code_(std::make_shared<Code>()) {} | ||||
|  | ||||
| // Creates a function from a caller-supplied Code object and an optional | ||||
| // schema. All three arguments are moved into the corresponding members. | ||||
| Function::Function( | ||||
|     c10::QualifiedName name, | ||||
|     std::shared_ptr<Code> code, | ||||
|     at::optional<c10::FunctionSchema> schema) | ||||
|     : name_(std::move(name)), | ||||
|       code_(std::move(code)), | ||||
|       schema_(std::move(schema)) {} | ||||
|  | ||||
| // Returns a reference to the stored qualified name of this function. | ||||
| const c10::QualifiedName& Function::qualname() const { | ||||
|   return name_; | ||||
| } | ||||
| @ -43,89 +51,11 @@ bool Function::append_operator( | ||||
|   // Keep the original opname in code_ | ||||
|   code_->op_names_.emplace_back(name, overload_name); | ||||
|   const auto& opname = code_->op_names_.back(); | ||||
|   const auto full_name = c10::toString(opname); | ||||
|  | ||||
|   std::function<void(Stack&)> fn; | ||||
|  | ||||
|   const std::vector<c10::Argument>* pArgs = nullptr; | ||||
|   bool promoted_op = mobile::hasPrimOpsFn(full_name); | ||||
|   if (promoted_op) { | ||||
|     fn = mobile::getPrimOpsFn(full_name); | ||||
|   } else { | ||||
|     std::shared_ptr<Operator> jit_op = findOperatorFor(opname); | ||||
|     if (jit_op) { | ||||
|       fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); }; | ||||
|       pArgs = &jit_op->schema().arguments(); | ||||
|     } else { | ||||
|       auto op = c10::Dispatcher::singleton().findSchema(opname); | ||||
|       if (op.has_value()) { | ||||
|         fn = [op](Stack& stack) { op->callBoxed(&stack); }; | ||||
|         if (op->hasSchema()) { | ||||
|           pArgs = &op->schema().arguments(); | ||||
|         } else { | ||||
|           TORCH_CHECK(false, "arguments are missing for operator ", opname); | ||||
|         } | ||||
|       } else { | ||||
|         return false; | ||||
|       } | ||||
|     } | ||||
|   auto func = makeOperatorFunction(opname, num_specified_args, model_version); | ||||
|   if (!func.has_value()) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   if (!promoted_op) { | ||||
|     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs); | ||||
|     const auto& args = *pArgs; | ||||
|     if (model_version == 0x3LL && opname.name == "aten::_convolution" && | ||||
|         opname.overload_name.empty()) { | ||||
|       // Since byte-code versions 0x4L, convolution has an additional | ||||
|       // default-value argument (allow_tf32=True, see | ||||
|       // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles | ||||
|       // backward compatibility with models of byte-code version <= 0x3L, where | ||||
|       // this bool argument does not yet exist. | ||||
|       fn = [fn](Stack& stack) { | ||||
|         stack.push_back(true); | ||||
|         fn(stack); | ||||
|       }; | ||||
|     } else { | ||||
|       // num_specified_args >= 0 indicates number of arguments are available | ||||
|       // from model. We can use it to handle backward compatibility. | ||||
|       if (num_specified_args && | ||||
|           num_specified_args.value() < static_cast<int64_t>(args.size())) { | ||||
|         fn = [fn, num_specified_args, &args](Stack& stack) { | ||||
|           std::vector<IValue> out_args; | ||||
|           // The following logic pops and temporarily stores all out arguments | ||||
|           // from the stack (which can be 0 or more, and always appended to the | ||||
|           // schema), in order to push the necessary default values. Finally, | ||||
|           // the out arguments are pushed back into the stack. | ||||
|           for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) { | ||||
|             out_args.push_back(stack.back()); | ||||
|             stack.pop_back(); | ||||
|           } | ||||
|           size_t start_index = num_specified_args.value() - out_args.size(); | ||||
|           TORCH_CHECK( | ||||
|               start_index >= 0, | ||||
|               "The number of output arguments is: ", | ||||
|               out_args.size(), | ||||
|               ", which is more then the number of specified arguments: ", | ||||
|               num_specified_args.value()); | ||||
|           for (size_t i = start_index; i < (args.size() - out_args.size()); | ||||
|                ++i) { | ||||
|             TORCH_CHECK( | ||||
|                 args[i].default_value().has_value(), | ||||
|                 "Error happened at preparing for default values for the argument. The ", | ||||
|                 i, | ||||
|                 "th argument ", | ||||
|                 args[i].name(), | ||||
|                 " does not have a specified value or default value. "); | ||||
|  | ||||
|             stack.push_back(args[i].default_value()); | ||||
|           } | ||||
|           stack.insert(stack.end(), out_args.rbegin(), out_args.rend()); | ||||
|           fn(stack); | ||||
|         }; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   code_->operators_.emplace_back(fn); | ||||
|   code_->operators_.emplace_back(*func); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| @ -197,6 +127,93 @@ const std::vector<int64_t>& Function::getExceptionDebugHandles() const { | ||||
|   return getInterpretersExceptionDebugHandles(); | ||||
| } | ||||
|  | ||||
| // Builds the callable that runs operator `opname` on an interpreter stack. | ||||
| // | ||||
| // Lookup order: promoted prim ops, then JIT operators, then the c10 | ||||
| // dispatcher. Returns c10::nullopt when none of them knows `opname`; raises | ||||
| // when the dispatcher entry has no schema. | ||||
| // | ||||
| // For non-promoted ops the callable may be wrapped for backward | ||||
| // compatibility: | ||||
| //  - aten::_convolution (no overload) of model_version 0x3 gets the extra | ||||
| //    allow_tf32=true argument pushed before the call; | ||||
| //  - otherwise, when `num_specified_args` is set and smaller than the number | ||||
| //    of schema arguments, the missing arguments are filled in from the | ||||
| //    schema defaults, ahead of any trailing out arguments. | ||||
| c10::optional<std::function<void(Stack&)>> makeOperatorFunction( | ||||
|     c10::OperatorName opname, | ||||
|     c10::optional<int> num_specified_args, | ||||
|     int64_t model_version) { | ||||
|   std::function<void(Stack&)> fn; | ||||
|   const auto full_name = c10::toString(opname); | ||||
|   const std::vector<c10::Argument>* pArgs = nullptr; | ||||
|   bool promoted_op = mobile::hasPrimOpsFn(full_name); | ||||
|   if (promoted_op) { | ||||
|     fn = mobile::getPrimOpsFn(full_name); | ||||
|   } else { | ||||
|     std::shared_ptr<Operator> jit_op = findOperatorFor(opname); | ||||
|     if (jit_op) { | ||||
|       fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); }; | ||||
|       pArgs = &jit_op->schema().arguments(); | ||||
|     } else { | ||||
|       auto op = c10::Dispatcher::singleton().findSchema(opname); | ||||
|       if (op.has_value()) { | ||||
|         fn = [op](Stack& stack) { op->callBoxed(&stack); }; | ||||
|         if (op->hasSchema()) { | ||||
|           pArgs = &op->schema().arguments(); | ||||
|         } else { | ||||
|           TORCH_CHECK(false, "arguments are missing for operator ", opname); | ||||
|         } | ||||
|       } else { | ||||
|         return c10::nullopt; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (!promoted_op) { | ||||
|     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs); | ||||
|     const auto& args = *pArgs; | ||||
|     if (model_version == 0x3LL && opname.name == "aten::_convolution" && | ||||
|         opname.overload_name.empty()) { | ||||
|       // Since byte-code versions 0x4L, convolution has an additional | ||||
|       // default-value argument (allow_tf32=True, see | ||||
|       // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles | ||||
|       // backward compatibility with models of byte-code version <= 0x3L, where | ||||
|       // this bool argument does not yet exist. | ||||
|       fn = [fn](Stack& stack) { | ||||
|         stack.push_back(true); | ||||
|         fn(stack); | ||||
|       }; | ||||
|     } else { | ||||
|       // A set num_specified_args means the number of arguments is available | ||||
|       // from the model. We can use it to handle backward compatibility. | ||||
|       if (num_specified_args && | ||||
|           num_specified_args.value() < static_cast<int64_t>(args.size())) { | ||||
|         fn = [fn, num_specified_args, &args](Stack& stack) { | ||||
|           std::vector<IValue> out_args; | ||||
|           // The following logic pops and temporarily stores all out arguments | ||||
|           // from the stack (which can be 0 or more, and always appended to the | ||||
|           // schema), in order to push the necessary default values. Finally, | ||||
|           // the out arguments are pushed back into the stack. | ||||
|           for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) { | ||||
|             out_args.push_back(stack.back()); | ||||
|             stack.pop_back(); | ||||
|           } | ||||
|           // Compare before subtracting: both operands are unsigned, so a | ||||
|           // check of the difference against 0 could never fail and a count | ||||
|           // smaller than the number of out arguments would silently wrap. | ||||
|           TORCH_CHECK( | ||||
|               static_cast<size_t>(num_specified_args.value()) >= | ||||
|                   out_args.size(), | ||||
|               "The number of output arguments is: ", | ||||
|               out_args.size(), | ||||
|               ", which is more then the number of specified arguments: ", | ||||
|               num_specified_args.value()); | ||||
|           size_t start_index = num_specified_args.value() - out_args.size(); | ||||
|           for (size_t i = start_index; i < (args.size() - out_args.size()); | ||||
|                ++i) { | ||||
|             TORCH_CHECK( | ||||
|                 args[i].default_value().has_value(), | ||||
|                 "Error happened at preparing for default values for the argument. The ", | ||||
|                 i, | ||||
|                 "th argument ", | ||||
|                 args[i].name(), | ||||
|                 " does not have a specified value or default value. "); | ||||
|  | ||||
|             stack.push_back(args[i].default_value()); | ||||
|           } | ||||
|           // out_args was filled back to front; reverse to restore the order. | ||||
|           stack.insert(stack.end(), out_args.rbegin(), out_args.rend()); | ||||
|           fn(stack); | ||||
|         }; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return fn; | ||||
| } | ||||
|  | ||||
| } // namespace mobile | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -17,6 +17,10 @@ struct Code; | ||||
| class TORCH_API Function : public torch::jit::Function { | ||||
|  public: | ||||
|   explicit Function(c10::QualifiedName name); | ||||
|   Function( | ||||
|       c10::QualifiedName name, | ||||
|       std::shared_ptr<Code> code, | ||||
|       at::optional<c10::FunctionSchema> schema); | ||||
|   void run(Stack& stack) override; | ||||
|   at::IValue operator()(Stack& stack); | ||||
|   void ensure_defined() override {} | ||||
| @ -24,6 +28,9 @@ class TORCH_API Function : public torch::jit::Function { | ||||
|   const c10::QualifiedName& qualname() const override; | ||||
|   bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override; | ||||
|  | ||||
|   // NOTE: the APIs below are dangerous: if you call append_instruction with | ||||
|   // dbg_handle and then call it without, the dbg_handles will become | ||||
|   // misaligned. Therefore only use ONE variant at a time. | ||||
|   void append_instruction(OpCode op, int X, int N, int64_t dbg_handle); | ||||
|   void append_instruction(OpCode op, int X, int N); | ||||
|   bool append_operator( | ||||
| @ -56,6 +63,11 @@ class TORCH_API Function : public torch::jit::Function { | ||||
|   at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+) | ||||
| }; | ||||
|  | ||||
| // Builds the callable that runs operator `opname` on an interpreter stack. | ||||
| // Returns c10::nullopt when no operator with that name is found. | ||||
| // `num_specified_args`, when set, is the number of arguments the model | ||||
| // supplies for the operator; `model_version` is the byte-code version used | ||||
| // for backward-compatibility handling. | ||||
| c10::optional<std::function<void(Stack&)>> makeOperatorFunction( | ||||
|     c10::OperatorName opname, | ||||
|     c10::optional<int> num_specified_args, | ||||
|     int64_t model_version); | ||||
|  | ||||
| } // namespace mobile | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -94,15 +94,15 @@ bool InterpreterState::run(Stack& stack) { | ||||
|         debug_handle = *handle; | ||||
|       } | ||||
|  | ||||
|       // std::cout << "RUNNING " << pc << " " | ||||
|       //           << code_->instructions_with_handles_[pc].instruction; | ||||
|       // std::cout << "RUNNING " << pc << " " << code.instructions_[pc]; | ||||
|       // if (inst.op == OP) { | ||||
|       //   std::cout << ", " << code_->op_names_[inst.X].name; | ||||
|       //   if (!code_->op_names_[inst.X].overload_name.empty()) { | ||||
|       //     std::cout << "." << code_->op_names_[inst.X].overload_name; | ||||
|       //   std::cout << ", " << code.op_names_[inst.X].name; | ||||
|       //   if (!code.op_names_[inst.X].overload_name.empty()) { | ||||
|       //     std::cout << "." << code.op_names_[inst.X].overload_name; | ||||
|       //   } | ||||
|       // } | ||||
|       // std::cout << std::endl; | ||||
|       // std::cout << "top " << stack.back().tagKind() << std::endl; | ||||
|  | ||||
|       // TODO(iliacher): remove the workaround after RecordFunction is in | ||||
|       // Dispatcher | ||||
|  | ||||
| @ -135,7 +135,7 @@ class TORCH_API Module { | ||||
|   std::unordered_map<std::string, std::string> metadata_; | ||||
|   std::shared_ptr<CompilationUnit> cu_; | ||||
|   MobileDebugTable debug_table_; | ||||
|   bool has_debug_handles_; | ||||
|   bool has_debug_handles_ = false; | ||||
| }; | ||||
| } // namespace mobile | ||||
| } // namespace jit | ||||
|  | ||||
| @ -33,7 +33,6 @@ using torch::distributed::autograd::DistAutogradContainer; | ||||
| #endif | ||||
|  | ||||
| #include <exception> | ||||
| #include <iostream> | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
| #include <ostream> | ||||
|  | ||||
| @ -1,23 +1,333 @@ | ||||
| #include <torch/csrc/jit/serialization/export_bytecode.h> | ||||
| #include <utility> | ||||
|  | ||||
| #include <torch/csrc/jit/runtime/instruction.h> | ||||
| #include <torch/csrc/jit/serialization/export.h> | ||||
|  | ||||
| #include <c10/util/Exception.h> | ||||
| #include <torch/csrc/jit/api/function_impl.h> | ||||
| #include <torch/csrc/jit/api/method.h> | ||||
| #include <torch/csrc/jit/backends/backend_debug_handler.h> | ||||
| #include <torch/csrc/jit/backends/backend_debug_info.h> | ||||
| #include <torch/csrc/jit/frontend/source_range.h> | ||||
| #include <torch/csrc/jit/ir/attributes.h> | ||||
| #include <torch/csrc/jit/ir/ir.h> | ||||
| #include <torch/csrc/jit/ir/type_hashing.h> | ||||
| #include <torch/csrc/jit/mobile/function.h> | ||||
| #include <torch/csrc/jit/mobile/interpreter.h> | ||||
| #include <torch/csrc/jit/mobile/method.h> | ||||
| #include <torch/csrc/jit/mobile/module.h> | ||||
| #include <torch/csrc/jit/passes/inliner.h> | ||||
| #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h> | ||||
| #include <torch/csrc/jit/serialization/import_export_constants.h> | ||||
| #include <torch/csrc/jit/serialization/import_export_functions.h> | ||||
| #include <torch/csrc/jit/serialization/import_export_helpers.h> | ||||
| #include <torch/csrc/jit/serialization/pickle.h> | ||||
| #include <torch/csrc/jit/serialization/python_print.h> | ||||
| #include <torch/csrc/jit/serialization/source_range_serialization.h> | ||||
| #include <torch/csrc/jit/serialization/type_name_uniquer.h> | ||||
|  | ||||
| #include <caffe2/serialize/inline_container.h> | ||||
|  | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| void BytecodeExportSet::add( | ||||
|     const c10::QualifiedName& qn, | ||||
|     ExportedFunction exported) { | ||||
|   items_.emplace(qn, std::move(exported)); | ||||
| std::vector<Method> gatherGetSetStates(ObjectPtr obj) { | ||||
|   // Walk the object graph rooted at `obj` depth-first. An object whose type | ||||
|   // defines both __getstate__ and __setstate__ contributes those methods | ||||
|   // (graph functions only) and is not descended into; any other object has | ||||
|   // its object-valued attributes queued for a visit. | ||||
|   std::vector<Method> collected; | ||||
|   std::vector<ObjectPtr> pending; | ||||
|   pending.emplace_back(obj); | ||||
|   while (!pending.empty()) { | ||||
|     ObjectPtr current = pending.back(); | ||||
|     pending.pop_back(); | ||||
|     auto type = current->type(); | ||||
|     Function* setstate = type->findMethod("__setstate__"); | ||||
|     Function* getstate = type->findMethod("__getstate__"); | ||||
|     if (getstate && setstate) { | ||||
|       // __setstate__ is recorded ahead of __getstate__. | ||||
|       for (Function* fn : {setstate, getstate}) { | ||||
|         if (fn->isGraphFunction()) { | ||||
|           collected.emplace_back(current, fn); | ||||
|         } | ||||
|       } | ||||
|       continue; | ||||
|     } | ||||
|     const size_t num_attrs = type->numAttributes(); | ||||
|     for (size_t slot = 0; slot < num_attrs; ++slot) { | ||||
|       IValue field = current->getSlot(slot); | ||||
|       if (field.isObject()) { | ||||
|         pending.emplace_back(field.toObject()); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return collected; | ||||
| } | ||||
|  | ||||
| void BytecodeExportSet::update(const c10::QualifiedName& qn, bool toplevel) { | ||||
|   items_.at(qn).toplevel = toplevel; | ||||
| std::vector<Method> findAllDependentFunctions( | ||||
|     const Module& module, | ||||
|     Graph& graph) { | ||||
|   // Names of the methods that prim::CallMethod nodes of `graph` invoke on an | ||||
|   // interface-typed first input. | ||||
|   std::unordered_set<c10::string_view> interface_call_names; | ||||
|   for (Node* call : findAllNodes(graph, c10::prim::CallMethod, true)) { | ||||
|     auto iface = call->input(0)->type()->castRaw<InterfaceType>(); | ||||
|     if (!iface) { | ||||
|       continue; | ||||
|     } | ||||
|     const FunctionSchema* schema = iface->getMethod(call->s(attr::name)); | ||||
|     interface_call_names.insert(schema->name()); | ||||
|   } | ||||
|  | ||||
|   // Keep every method, over all modules yielded by module.modules(), whose | ||||
|   // name is one of the collected names. | ||||
|   std::vector<Method> dependents; | ||||
|   for (const auto& submodule : module.modules()) { | ||||
|     for (const auto& method : submodule.get_methods()) { | ||||
|       const bool is_called = | ||||
|           interface_call_names.count(method.function().qualname().name()) > 0; | ||||
|       if (is_called) { | ||||
|         dependents.emplace_back(method); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return dependents; | ||||
| } | ||||
|  | ||||
| bool BytecodeExportSet::contains(const c10::QualifiedName& qn) const { | ||||
|   return items_.find(qn) != items_.end(); | ||||
| // Produces an inlined copy of every method in `initial_methods` and, when | ||||
| // `incl_dependent_functions` is set, of the methods found for each of them | ||||
| // by findAllDependentFunctions. | ||||
| // | ||||
| // NOTE: order of functions returned will be: | ||||
| // 1. functions originated from the methods passed in will be first | ||||
| // 2. All the dependent functions will come afterwards. | ||||
| // This order is meaningful because currently mobile Module looks up | ||||
| // methods with linear search. | ||||
| std::vector<std::unique_ptr<GraphFunction>> inlineFunctions( | ||||
|     const std::vector<Method>& initial_methods, | ||||
|     bool incl_dependent_functions) { | ||||
|   // A method is identified by its owner's qualified type name plus the | ||||
|   // Function pointer; each such pair is processed once. | ||||
|   std::set<std::pair<std::string, Function*>> visited; | ||||
|   // Despite its name, `stack` is consumed front-first (FIFO): entries are | ||||
|   // appended at the back and taken from the front. | ||||
|   std::deque<Method> stack; | ||||
|   std::copy( | ||||
|       initial_methods.begin(), | ||||
|       initial_methods.end(), | ||||
|       std::back_inserter(stack)); | ||||
|   std::vector<std::unique_ptr<GraphFunction>> inlined_functions; | ||||
|   while (!stack.empty()) { | ||||
|     Method cur = stack.front(); | ||||
|     stack.pop_front(); | ||||
|     auto tup = std::make_pair( | ||||
|         cur.owner()._ivalue()->type()->name()->qualifiedName(), | ||||
|         &cur.function()); | ||||
|     if (visited.find(tup) != visited.end()) { | ||||
|       continue; | ||||
|     } | ||||
|     visited.insert(tup); | ||||
|     const auto& f = toGraphFunction(cur.function()); | ||||
|     // Inline runs on the copy returned by copyUnique, not on f.graph() itself. | ||||
|     auto graph = f.graph()->copyUnique(); | ||||
|     Inline(*graph); | ||||
|     // The new function is named <owner type name>.<function name>. | ||||
|     c10::QualifiedName qn(*cur.owner()._ivalue()->type()->name(), f.name()); | ||||
|  | ||||
|     if (incl_dependent_functions) { | ||||
|       std::vector<Method> dependent_methods = | ||||
|           findAllDependentFunctions(cur.owner(), *graph); | ||||
|       std::copy( | ||||
|           dependent_methods.begin(), | ||||
|           dependent_methods.end(), | ||||
|           std::back_inserter(stack)); | ||||
|     } | ||||
|     auto inlined_func = std::make_unique<GraphFunction>( | ||||
|         qn, std::move(graph), f.function_creator()); | ||||
|     inlined_func->setSchema(f.getSchema()); | ||||
|     inlined_functions.emplace_back(std::move(inlined_func)); | ||||
|   } | ||||
|   return inlined_functions; | ||||
| } | ||||
|  | ||||
| std::unique_ptr<mobile::Code> compileGraphToMobileCode( | ||||
|     const std::string& name, | ||||
|     const std::shared_ptr<Graph>& graph, | ||||
|     const CompilationOptions& compilation_options, | ||||
|     BackendDebugInfoRecorder& debug_info_recorder) { | ||||
|   MobileCode code( | ||||
|       graph, | ||||
|       name, | ||||
|       compilation_options.enable_default_value_for_unspecified_arg, | ||||
|       compilation_options.enable_default_args_before_out_args); | ||||
|  | ||||
|   std::unique_ptr<mobile::Code> mobile_code_ptr = | ||||
|       std::make_unique<mobile::Code>(); | ||||
|   mobile::Code& mobile_code = *mobile_code_ptr; | ||||
|  | ||||
|   // operator names | ||||
|   std::vector<std::string> method_names; | ||||
|   std::vector<int64_t> op_debug_handles; | ||||
|   int next_new_op_index = 0; | ||||
|  | ||||
|   auto op_to_specified_args = code.op_to_num_specified_args(); | ||||
|  | ||||
|   for (size_t i = 0; i < code.instructions().size(); ++i) { | ||||
|     Instruction ins = code.instructions()[i]; | ||||
|  | ||||
|     if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) { | ||||
|       // Found a new op (assumes new operators ordered by ascending ins.X) | ||||
|       auto node = code.instructions_source()[i]; | ||||
|       const c10::OperatorName& opname = node->schema().operator_name(); | ||||
|       auto unique_name = c10::toString(opname); | ||||
|       // For operator with vararg, adding default arguments would be confusing | ||||
|       // and is not allowed. For an operator with num_args = -1, it means the | ||||
|       // number of arguments is not available for this operator, we don't do any | ||||
|       // backward compatibility adaptation at runtime. | ||||
|       c10::optional<int> num_args = c10::nullopt; | ||||
|       auto it = op_to_specified_args.find(unique_name); | ||||
|       if (it != op_to_specified_args.end()) { | ||||
|         num_args = it->second; | ||||
|       } | ||||
|       mobile_code.operator_input_sizes_.emplace_back(num_args.value_or(-1)); | ||||
|       mobile_code.op_names_.emplace_back(opname); | ||||
|       auto func = mobile::makeOperatorFunction( | ||||
|           opname, num_args, compilation_options.model_version); | ||||
|       TORCH_INTERNAL_ASSERT( | ||||
|           func.has_value(), | ||||
|           "Operator with name: ", | ||||
|           toString(opname), | ||||
|           " not found"); | ||||
|       mobile_code.operators_.emplace_back(*func); | ||||
|       next_new_op_index++; | ||||
|     } | ||||
|     // CALL nodes at this point represent built-in (i.e. non-Graph) | ||||
|     // functions that were not inlined. Here we convert the CALL | ||||
|     // instructions for these functions into INTERFACE_CALL instructions | ||||
|     // s.t. at runtime, we will look up the Function* on the Type of the | ||||
|     // 0th argument in the stack and call that directly. | ||||
|     if (ins.op == CALL) { | ||||
|       auto node = code.instructions_source()[i]; | ||||
|       if (node->kind() == prim::CallMethod) { | ||||
|         // NB: replacing instruction | ||||
|         auto method_name_idx = | ||||
|             code.constant_table().size() + method_names.size(); | ||||
|         method_names.emplace_back(node->s(attr::name)); | ||||
|         ins = Instruction{ | ||||
|             INTERFACE_CALL, | ||||
|             static_cast<int32_t>(method_name_idx), | ||||
|             static_cast<uint16_t>(node->inputs().size())}; | ||||
|       } else { | ||||
|         TORCH_INTERNAL_ASSERT( | ||||
|             false, "Unsupported node kind on CALL opcode for mobile"); | ||||
|       } | ||||
|     } else if (ins.op == RET) { | ||||
|       auto node = code.instructions_source()[i]; | ||||
|       for (const auto& input : node->inputs()) { | ||||
|         const auto& input_type = input->type(); | ||||
|         if (input_type->kind() == TypeKind::ListType || | ||||
|             input_type->kind() == TypeKind::DictType) { | ||||
|           for (const TypePtr& element_type : input_type->containedTypes()) { | ||||
|             TORCH_CHECK( | ||||
|                 element_type->kind() != TypeKind::ClassType, | ||||
|                 "Returining a list or dictionary with pytorch class type ", | ||||
|                 "is not supported in mobile module " | ||||
|                 "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). " | ||||
|                 "Workaround: instead of using pytorch class as their element type, ", | ||||
|                 "use a combination of list, dictionary, and single types."); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } else { | ||||
|       TORCH_CHECK( | ||||
|           isOpSupportedInMobile(ins.op), | ||||
|           toString(ins.op), | ||||
|           " is not supported in mobile module."); | ||||
|     } | ||||
|     auto node = code.instructions_source()[i]; | ||||
|     int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node); | ||||
|     // Note 1-to-1 correspondence between instructions and debug handles | ||||
|     mobile_code.instructions_.emplace_back(ins); | ||||
|     mobile_code.debug_handles_.emplace_back(debug_handle); | ||||
|   } | ||||
|  | ||||
|   // copy constants | ||||
|   mobile_code.constants_ = code.constant_table(); | ||||
|  | ||||
|   // Append to the constants copied above the method names | ||||
|   // that we emitted for the converted INTERFACE_CALL nodes. | ||||
|   for (auto& method_name : method_names) { | ||||
|     mobile_code.constants_.emplace_back(method_name); | ||||
|   } | ||||
|  | ||||
|   mobile_code.types_ = code.type_table(); | ||||
|   mobile_code.register_size_ = code.register_size(); | ||||
|   return mobile_code_ptr; | ||||
| } | ||||
|  | ||||
| // Rejects function schemas that cannot be represented in a mobile module. | ||||
| // Fails (via TORCH_CHECK) when the schema has an overload name, accepts | ||||
| // Python-style *args, or returns a variable number of values. The checks | ||||
| // run in that order, so the first violated rule determines the message. | ||||
| void checkSchema(const FunctionSchema& schema) { | ||||
|   // @TODO: is the overload-name check correct? | ||||
|   const bool has_overload_name = !schema.overload_name().empty(); | ||||
|   TORCH_CHECK( | ||||
|       !has_overload_name, "Overloads are not supported in mobile modules."); | ||||
|  | ||||
|   const bool takes_varargs = schema.is_vararg(); | ||||
|   TORCH_CHECK( | ||||
|       !takes_varargs, "Python *args are not supported in mobile modules."); | ||||
|  | ||||
|   const bool returns_varargs = schema.is_varret(); | ||||
|   TORCH_CHECK( | ||||
|       !returns_varargs, | ||||
|       "A variable number of return values is not supported in mobile modules."); | ||||
| } | ||||
|  | ||||
| // Returns true when any atom of the module type's qualified name is | ||||
| // exactly "LoweredModule". A module whose type has no name is reported | ||||
| // as not lowered. | ||||
| bool isLoweredModule(const Module& m) { | ||||
|   const auto module_type = m.type(); | ||||
|   const auto& maybe_name = module_type->name(); | ||||
|   if (!maybe_name) { | ||||
|     return false; | ||||
|   } | ||||
|   for (const auto& atom : maybe_name.value().atoms()) { | ||||
|     if (atom == "LoweredModule") { | ||||
|       return true; | ||||
|     } | ||||
|   } | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| // Gathers backend debug info for `m` and, recursively, for all of its | ||||
| // children. For each module that isLoweredModule() accepts, the map held | ||||
| // by its "__backend_debug_info" attribute (when one is set) is inserted | ||||
| // into `debug_map`. The combined result is delivered through the | ||||
| // `debug_map` output parameter; the function itself returns nothing. | ||||
| void getBackendDebugInfoMap( | ||||
|     const Module& m, | ||||
|     BackendDebugInfoMapType& debug_map) { | ||||
|   if (isLoweredModule(m)) { | ||||
|     const auto debug_info_holder = | ||||
|         m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>(); | ||||
|     const auto& lowered_map = debug_info_holder->getDebugInfoMap(); | ||||
|     if (lowered_map.has_value()) { | ||||
|       const auto& entries = lowered_map.value(); | ||||
|       debug_map.insert(entries.begin(), entries.end()); | ||||
|     } | ||||
|   } | ||||
|   // Descend regardless of whether this module itself was lowered. | ||||
|   for (const auto& child : m.children()) { | ||||
|     getBackendDebugInfoMap(child, debug_map); | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Converts a jit::Module into a mobile::Module without serializing it. | ||||
| // | ||||
| // Steps: | ||||
| //   1. Collect the module's methods plus the methods returned by | ||||
| //      gatherGetSetStates() for the module's ivalue. | ||||
| //   2. Inline them (options.incl_interface_call is forwarded to | ||||
| //      inlineFunctions) and compile each graph to mobile::Code. | ||||
| //   3. Validate each schema with checkSchema() and register the resulting | ||||
| //      mobile::Function in a fresh mobile::CompilationUnit. | ||||
| //   4. Attach a debug table built from the recorded debug handles merged | ||||
| //      with the backend debug info found on the module tree. | ||||
| mobile::Module jitModuleToMobile( | ||||
|     const Module& module, | ||||
|     const CompilationOptions& options) { | ||||
|   auto mobile_cu = std::make_shared<mobile::CompilationUnit>(); | ||||
|   BackendDebugInfoRecorder debug_info_recorder; | ||||
|  | ||||
|   // Regular methods first, then the get/set-state methods appended after. | ||||
|   std::vector<Method> methods_to_export = module.get_methods(); | ||||
|   const std::vector<Method> state_methods = | ||||
|       gatherGetSetStates(module._ivalue()); | ||||
|   methods_to_export.insert( | ||||
|       methods_to_export.end(), state_methods.begin(), state_methods.end()); | ||||
|  | ||||
|   for (const auto& inlined : | ||||
|        inlineFunctions(methods_to_export, options.incl_interface_call)) { | ||||
|     // Compilation happens before schema validation, as in the exporter. | ||||
|     std::shared_ptr<mobile::Code> code = compileGraphToMobileCode( | ||||
|         inlined->name(), inlined->graph(), options, debug_info_recorder); | ||||
|     const auto& schema = inlined->getSchema(); | ||||
|     checkSchema(schema); | ||||
|     mobile_cu->register_function(std::make_unique<mobile::Function>( | ||||
|         inlined->qualname(), code, schema)); | ||||
|   } | ||||
|  | ||||
|   mobile::Module mobile_module(module._ivalue(), mobile_cu); | ||||
|   mobile_module.setHasDebugHandles(true); | ||||
|  | ||||
|   BackendDebugInfoMapType backend_debug_info_map; | ||||
|   getBackendDebugInfoMap(module, backend_debug_info_map); | ||||
|  | ||||
|   // Merge the backend entries into the handles recorded during compilation. | ||||
|   auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording(); | ||||
|   debug_handle_cs_ptr_map.insert( | ||||
|       backend_debug_info_map.begin(), backend_debug_info_map.end()); | ||||
|  | ||||
|   MobileDebugTable debug_table( | ||||
|       debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()); | ||||
|   mobile_module.setDebugTable(std::move(debug_table)); | ||||
|   return mobile_module; | ||||
| } | ||||
|  | ||||
| } // namespace jit | ||||
|  | ||||
| @ -1,59 +1,31 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <tuple> | ||||
| #include <unordered_map> | ||||
|  | ||||
| #include <ATen/core/function_schema.h> | ||||
| #include <ATen/core/ivalue.h> | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <ATen/core/qualified_name.h> | ||||
| #include <torch/csrc/jit/backends/backend_debug_handler.h> | ||||
| #include <torch/csrc/jit/mobile/function.h> | ||||
| #include <torch/csrc/jit/mobile/module.h> | ||||
| #include <torch/csrc/jit/runtime/interpreter.h> | ||||
| #include <torch/csrc/jit/serialization/type_name_uniquer.h> | ||||
|  | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| struct ExportedFunction { | ||||
|   ExportedFunction( | ||||
|       const Module& m, | ||||
|       const Function& f, | ||||
|       std::unique_ptr<Graph> g, | ||||
|       bool t) | ||||
|       : mod(m), function(f), optimizedGraph(std::move(g)), toplevel(t) {} | ||||
|   Module mod; | ||||
|   const Function& function; | ||||
|   std::unique_ptr<Graph> optimizedGraph; | ||||
|   bool toplevel; | ||||
| // Options consumed by jitModuleToMobile when lowering a jit::Module. | ||||
| struct TORCH_API CompilationOptions { | ||||
|   // Forwarded to inlineFunctions() by jitModuleToMobile; | ||||
|   // getOptionsFromGlobal() fills it from getMobileInterfaceCallExport(). | ||||
|   bool incl_interface_call{false}; | ||||
|   // getOptionsFromGlobal() fills this from | ||||
|   // BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled(). | ||||
|   bool enable_default_value_for_unspecified_arg{false}; | ||||
|   // getOptionsFromGlobal() fills this from | ||||
|   // BytecodeEmitMode::is_default_args_before_out_args_enabled(). | ||||
|   bool enable_default_args_before_out_args{true}; | ||||
|   // Defaults to caffe2::serialize::kProducedBytecodeVersion. | ||||
|   int model_version = | ||||
|       static_cast<int>(caffe2::serialize::kProducedBytecodeVersion); | ||||
| }; | ||||
|  | ||||
| // Collection of ExportedFunction entries keyed by qualified name. | ||||
| // The set can be moved but not copied. | ||||
| class TORCH_API BytecodeExportSet { | ||||
|  public: | ||||
|   BytecodeExportSet() = default; | ||||
|   // Copying is disabled. | ||||
|   BytecodeExportSet(const BytecodeExportSet&) = delete; | ||||
|   BytecodeExportSet& operator=(const BytecodeExportSet&) = delete; | ||||
|   // Moving transfers the stored entries. | ||||
|   BytecodeExportSet(BytecodeExportSet&&) = default; | ||||
|   BytecodeExportSet& operator=(BytecodeExportSet&&) = default; | ||||
|  | ||||
|   void add(const c10::QualifiedName& qn, ExportedFunction); | ||||
|   void update(const c10::QualifiedName& qn, bool toplevel); | ||||
|   bool contains(const c10::QualifiedName& qn) const; | ||||
|  | ||||
|   // Invokes f(name, exported) in two passes over the stored entries: | ||||
|   // first for every entry whose `toplevel` flag is set, then for every | ||||
|   // entry whose flag is clear. | ||||
|   template <typename F> | ||||
|   void visit(F&& f) { | ||||
|     for (const bool wanted_toplevel : {true, false}) { | ||||
|       for (auto& entry : items_) { | ||||
|         if (entry.second.toplevel == wanted_toplevel) { | ||||
|           f(entry.first, entry.second); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   std::unordered_map<c10::QualifiedName, ExportedFunction> items_; | ||||
| }; | ||||
| TORCH_API mobile::Module jitModuleToMobile( | ||||
|     const Module& module, | ||||
|     const CompilationOptions& options); | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -38,6 +38,18 @@ | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| // Builds a CompilationOptions snapshot from the global settings: the | ||||
| // mobile interface-call export flag, the two BytecodeEmitMode flags, and | ||||
| // the produced bytecode version. | ||||
| CompilationOptions getOptionsFromGlobal() { | ||||
|   CompilationOptions opts; | ||||
|   opts.incl_interface_call = getMobileInterfaceCallExport(); | ||||
|   opts.model_version = caffe2::serialize::kProducedBytecodeVersion; | ||||
|   opts.enable_default_value_for_unspecified_arg = | ||||
|       BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled(); | ||||
|   opts.enable_default_args_before_out_args = | ||||
|       BytecodeEmitMode::is_default_args_before_out_args_enabled(); | ||||
|   return opts; | ||||
| } | ||||
|  | ||||
| // Packs the given values into a tuple and returns it as an IValue. | ||||
| IValue to_tuple(std::initializer_list<IValue> ivalues) { | ||||
|   IValue packed = c10::ivalue::Tuple::create(ivalues); | ||||
|   return packed; | ||||
| } | ||||
| @ -63,138 +75,49 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() { | ||||
| } | ||||
|  | ||||
| std::pair<IValue, IValue> getFunctionTuple( | ||||
|     const Module& module, | ||||
|     const Function& func, | ||||
|     std::unique_ptr<Graph> optimizedGraph, | ||||
|     const CompilationUnit& compilation_unit, | ||||
|     const mobile::Function& func, | ||||
|     BackendDebugInfoRecorder& debug_info_recorder, | ||||
|     const std::string& qn, | ||||
|     TypeNameUniquer& type_name_uniquer_) { | ||||
|   TORCH_INTERNAL_ASSERT(optimizedGraph); | ||||
|   std::shared_ptr<MobileCode> code; | ||||
|   code = std::make_shared<MobileCode>( | ||||
|       std::move(optimizedGraph), func.name(), BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */, BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */); | ||||
|   auto instructions_copy = code->instructions(); | ||||
|  | ||||
|   // operator names | ||||
|   std::vector<c10::OperatorName> opnames; | ||||
|   std::vector<std::string> method_names; | ||||
|   std::vector<int64_t> op_debug_handles; | ||||
|   int next_new_op_index = 0; | ||||
|   for (size_t i = 0; i < instructions_copy.size(); ++i) { | ||||
|     Instruction ins = instructions_copy[i]; | ||||
|     if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) { | ||||
|       // Found a new op (assumes new operators ordered by ascending ins.X) | ||||
|       auto node = code->instructions_source()[i]; | ||||
|       opnames.emplace_back(node->schema().operator_name()); | ||||
|       next_new_op_index++; | ||||
|     } | ||||
|     // CALL nodes at this point represent built-in (i.e. non-Graph) | ||||
|     // functions that were not inlined. Here we convert the CALL | ||||
|     // instructions for these functions into INTERFACE_CALL instructions | ||||
|     // s.t. at runtime, we will look up the Function* on the Type of the | ||||
|     // 0th argument in the stack and call that directly. | ||||
|     if (ins.op == CALL) { | ||||
|       auto node = code->instructions_source()[i]; | ||||
|       if (node->kind() == prim::CallMethod) { | ||||
|         // NB: replacing instruction | ||||
|         auto method_name_idx = | ||||
|             code->constant_table().size() + method_names.size(); | ||||
|         method_names.emplace_back(node->s(attr::name)); | ||||
|         Instruction new_instr{ | ||||
|             INTERFACE_CALL, | ||||
|             static_cast<int32_t>(method_name_idx), | ||||
|             static_cast<uint16_t>(node->inputs().size())}; | ||||
|         instructions_copy[i] = new_instr; | ||||
|       } else { | ||||
|         TORCH_INTERNAL_ASSERT( | ||||
|             false, "Unsupported node kind on CALL opcode for mobile"); | ||||
|       } | ||||
|     } else if (ins.op == RET) { | ||||
|       auto node = code->instructions_source()[i]; | ||||
|       for (const auto& input : node->inputs()) { | ||||
|         const auto& input_type = input->type(); | ||||
|         if (input_type->kind() == TypeKind::ListType || | ||||
|             input_type->kind() == TypeKind::DictType) { | ||||
|           for (const TypePtr& element_type : input_type->containedTypes()) { | ||||
|             TORCH_CHECK( | ||||
|                 element_type->kind() != TypeKind::ClassType, | ||||
|                 "Returining a list or dictionary with pytorch class type ", | ||||
|                 "is not supported in mobile module " | ||||
|                 "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). " | ||||
|                 "Workaround: instead of using pytorch class as their element type, ", | ||||
|                 "use a combination of list, dictionary, and single types."); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } else { | ||||
|       TORCH_CHECK( | ||||
|           isOpSupportedInMobile(ins.op), | ||||
|           toString(ins.op), | ||||
|           " is not supported in mobile module."); | ||||
|     } | ||||
|     auto node = code->instructions_source()[i]; | ||||
|     int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node); | ||||
|     // Note 1-to-1 correspondence between instructions and debug handles | ||||
|     op_debug_handles.emplace_back(debug_handle); | ||||
|   } | ||||
|   const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code(); | ||||
|  | ||||
|   // instructions | ||||
|   std::vector<IValue> instructions; | ||||
|   instructions.reserve(instructions_copy.size()); | ||||
|   for (Instruction ins : instructions_copy) { | ||||
|   instructions.reserve(mobile_code_ptr->instructions_.size()); | ||||
|   for (Instruction ins : mobile_code_ptr->instructions_) { | ||||
|     instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N})); | ||||
|   } | ||||
|  | ||||
|   // operators | ||||
|   std::vector<IValue> operators; | ||||
|   auto op_to_specified_args = code->op_to_num_specified_args(); | ||||
|   operators.reserve(opnames.size()); | ||||
|   for (const auto& opname : opnames) { | ||||
|     auto unique_name = c10::toString(opname); | ||||
|     // For operator with vararg, adding default arguments would be confusing and | ||||
|     // is not allowed. For an operator with num_args = -1, it means the number | ||||
|     // of arguments is not available for this operator, we don't do any backward | ||||
|     // compatibility adaptation at runtime. | ||||
|     int num_args = -1; | ||||
|     auto it = op_to_specified_args.find(unique_name); | ||||
|     if (it != op_to_specified_args.end()) { | ||||
|       num_args = it->second; | ||||
|     } | ||||
|   operators.reserve(mobile_code_ptr->op_names_.size()); | ||||
|   for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) { | ||||
|     const auto& opname = mobile_code_ptr->op_names_[i]; | ||||
|     const int size = mobile_code_ptr->operator_input_sizes_[i]; | ||||
|     if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) { | ||||
|       operators.emplace_back(to_tuple({opname.name, opname.overload_name})); | ||||
|     } else { | ||||
|       operators.emplace_back( | ||||
|           to_tuple({opname.name, opname.overload_name, num_args})); | ||||
|           to_tuple({opname.name, opname.overload_name, size})); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // constants | ||||
|   // | ||||
|   // Make a copy of the constants and append the method names | ||||
|   // that we emitted for the converted INTERFACE_CALL nodes above. | ||||
|   auto constants = code->constant_table(); | ||||
|   for (auto& method_name : method_names) { | ||||
|     constants.emplace_back(std::move(method_name)); | ||||
|   } | ||||
|  | ||||
|   // types | ||||
|   std::vector<IValue> types; | ||||
|   types.reserve(code->type_table().size()); | ||||
|   types.reserve(mobile_code_ptr->types_.size()); | ||||
|   static const std::string torch_prefix("__torch__"); | ||||
|   static const std::string class_prefix("__torch__.torch.classes"); | ||||
|   std::shared_ptr<torch::jit::CompilationUnit> cu = | ||||
|       module._ivalue()->compilation_unit(); | ||||
|  | ||||
|   for (const TypePtr& t : code->type_table()) { | ||||
|   for (const TypePtr& t : mobile_code_ptr->types_) { | ||||
|     std::string type_str = t->annotation_str(); | ||||
|     if (t->kind() == TypeKind::TupleType) { | ||||
|       TORCH_CHECK( | ||||
|           cu->get_named_tuple(t->str()), | ||||
|           compilation_unit.get_named_tuple(t->str()), | ||||
|           "Can't find definition for the qualified name: ", | ||||
|           t->str(), | ||||
|           "(TypeKind::TupleType)  in compilation unit.", | ||||
|           "Please report a bug to PyTorch."); | ||||
|       auto named_tuple_type = cu->get_named_tuple(t->str()); | ||||
|       auto named_tuple_type = compilation_unit.get_named_tuple(t->str()); | ||||
|       if (named_tuple_type != nullptr) { | ||||
|         std::string named_tuple_str = t->str(); | ||||
|         named_tuple_str.append("[NamedTuple, ["); | ||||
| @ -254,12 +177,12 @@ std::pair<IValue, IValue> getFunctionTuple( | ||||
|  | ||||
|   // since the register location is embedded into the bytecode, pass the | ||||
|   // register size | ||||
|   auto register_size = static_cast<int>(code->register_size()); | ||||
|   auto register_size = static_cast<int>(mobile_code_ptr->register_size_); | ||||
|  | ||||
|   auto codeTable = Table( | ||||
|       {{"instructions", to_tuple(instructions)}, | ||||
|        {"operators", to_tuple(operators)}, | ||||
|        {"constants", to_tuple(constants)}, | ||||
|        {"constants", to_tuple(mobile_code_ptr->constants_)}, | ||||
|        {"types", to_tuple(types)}, | ||||
|        {"register_size", register_size}}); | ||||
|  | ||||
| @ -273,14 +196,7 @@ std::pair<IValue, IValue> getFunctionTuple( | ||||
|     } | ||||
|     return c10::nullopt; | ||||
|   }; | ||||
|   TORCH_CHECK( | ||||
|       schema.overload_name().empty(), // @TODO: is this check correct? | ||||
|       "Overloads are not supported in mobile modules."); | ||||
|   TORCH_CHECK( | ||||
|       !schema.is_vararg(), "Python *args are not supported in mobile modules."); | ||||
|   TORCH_CHECK( | ||||
|       !schema.is_varret(), | ||||
|       "A variable number of return values is not supported in mobile modules."); | ||||
|  | ||||
|   auto makeArgTuple = [&](const std::vector<Argument>& args) { | ||||
|     std::vector<IValue> argTables; | ||||
|     for (auto&& arg : args) { | ||||
| @ -315,6 +231,17 @@ std::pair<IValue, IValue> getFunctionTuple( | ||||
|   }); | ||||
|  | ||||
|   // function tuple | ||||
|   std::string qn; | ||||
|   if (func.name() == "__setstate__" || func.name() == "__getstate__") { | ||||
|     auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>(); | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         classtype, "class is null ", func.qualname().qualifiedName()); | ||||
|     qn = c10::QualifiedName( | ||||
|              type_name_uniquer_.getUniqueName(classtype), func.name()) | ||||
|              .qualifiedName(); | ||||
|   } else { | ||||
|     qn = func.qualname().qualifiedName(); | ||||
|   } | ||||
|   auto bytecode_vals = to_tuple({qn, codeTable, schemaTable}); | ||||
|  | ||||
|   c10::optional<IValue> debug_info_vals; | ||||
| @ -324,41 +251,27 @@ std::pair<IValue, IValue> getFunctionTuple( | ||||
|   // debug handles generated by debug_handle_manager | ||||
|   // will correspond to {source_range, inlinedCallStackPtr} which we will | ||||
|   // serialize separately. | ||||
|   IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles); | ||||
|   IValue module_debug_tuple = | ||||
|       c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_); | ||||
|   auto function_debug_info = | ||||
|       Table({{"function_debug_handles", module_debug_tuple}}); | ||||
|   debug_info_vals = to_tuple({qn, function_debug_info}); | ||||
|   return std::make_pair(bytecode_vals, debug_info_vals); | ||||
| } | ||||
|  | ||||
| void pushFunctionToIValues( | ||||
|     BytecodeExportSet exportSet, | ||||
| void pushMobileFunctionsToIValues( | ||||
|     const CompilationUnit& compilation_unit, | ||||
|     const mobile::Module& module, | ||||
|     std::vector<c10::IValue>& elements, | ||||
|     std::vector<c10::IValue>& debugInfoElements, | ||||
|     BackendDebugInfoRecorder& recorder, | ||||
|     TypeNameUniquer& uniquer) { | ||||
|   exportSet.visit( | ||||
|       [&](const c10::QualifiedName& qn, ExportedFunction& exported) { | ||||
|         auto tuple = getFunctionTuple( | ||||
|             exported.mod, | ||||
|             exported.function, | ||||
|             std::move(exported.optimizedGraph), | ||||
|             recorder, | ||||
|             qn.qualifiedName(), | ||||
|             uniquer); | ||||
|         elements.push_back(std::move(tuple.first)); | ||||
|         debugInfoElements.push_back(std::move(tuple.second)); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void pushFunctionToIValues( | ||||
|     BytecodeExportSet exportSet, | ||||
|     std::vector<c10::IValue>& elements, | ||||
|     BackendDebugInfoRecorder& recorder, | ||||
|     TypeNameUniquer& uniquer) { | ||||
|   std::vector<c10::IValue> debugInfoElements; | ||||
|   pushFunctionToIValues( | ||||
|       std::move(exportSet), elements, debugInfoElements, recorder, uniquer); | ||||
|   for (const auto& method : module.get_methods()) { | ||||
|     auto tuple = getFunctionTuple( | ||||
|         compilation_unit, method.function(), recorder, uniquer); | ||||
|     elements.push_back(std::move(tuple.first)); | ||||
|     debugInfoElements.push_back(std::move(tuple.second)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) { | ||||
| @ -402,61 +315,6 @@ std::vector<ModuleMethod> getModuleInterfaceExports( | ||||
|   return ret; | ||||
| } | ||||
|  | ||||
| // Adds `method` to `exportSet` under its export name, together with an | ||||
| // inlined private copy of its graph. If the name is already present, the | ||||
| // only action is to forward a true `toplevel` flag to update(). When | ||||
| // mobile interface-call export is enabled, the module's interface | ||||
| // exports reachable from the inlined graph are added recursively | ||||
| // (with the default toplevel = false). | ||||
| void exportFunction( | ||||
|     BytecodeExportSet& exportSet, | ||||
|     const ModuleMethod& method, | ||||
|     bool toplevel = false) { | ||||
|   const auto& export_name = method.exportName; | ||||
|   if (exportSet.contains(export_name)) { | ||||
|     if (toplevel) { | ||||
|       exportSet.update(export_name, toplevel); | ||||
|     } | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   const auto& source_fn = method.function; | ||||
|   auto inlined_graph = source_fn.graph()->copyUnique(); | ||||
|   Inline(*inlined_graph); | ||||
|   // Collected before the graph is moved into the export set below. | ||||
|   auto interface_calls = getInterfaceCalls(*inlined_graph); | ||||
|   exportSet.add( | ||||
|       export_name, | ||||
|       ExportedFunction{ | ||||
|           method.module, source_fn, std::move(inlined_graph), toplevel}); | ||||
|  | ||||
|   if (getMobileInterfaceCallExport()) { | ||||
|     for (const auto& dependency : | ||||
|          getModuleInterfaceExports(method.module, interface_calls)) { | ||||
|       exportFunction(exportSet, dependency); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Searches `ivalue` for objects whose type passes | ||||
| // checkHasValidSetGetState() and exports their __setstate__ under the | ||||
| // name "<unique type name>.__setstate__". Objects whose type does not | ||||
| // pass are searched recursively through their attribute slots, always | ||||
| // with the default toplevel = false. Non-object values are ignored, and | ||||
| // a __setstate__ that tryToGraphFunction() cannot convert is skipped. | ||||
| void setstateTuple( | ||||
|     BytecodeExportSet& exportSet, | ||||
|     const Module& module, | ||||
|     const IValue& ivalue, | ||||
|     TypeNameUniquer& type_name_uniquer_, | ||||
|     bool toplevel = false) { | ||||
|   if (!ivalue.isObject()) { | ||||
|     return; | ||||
|   } | ||||
|   const auto obj = ivalue.toObject(); | ||||
|   const auto type = obj->type(); | ||||
|  | ||||
|   if (!checkHasValidSetGetState(type)) { | ||||
|     const size_t num_attrs = type->numAttributes(); | ||||
|     for (size_t slot = 0; slot < num_attrs; ++slot) { | ||||
|       setstateTuple( | ||||
|           exportSet, module, obj->getSlot(slot), type_name_uniquer_); | ||||
|     } | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   Function& setstate = type->getMethod("__setstate__"); | ||||
|   const auto qn = | ||||
|       type_name_uniquer_.getUniqueName(obj->type()).qualifiedName() + "." + | ||||
|       setstate.name(); | ||||
|   if (exportSet.contains(qn)) { | ||||
|     return; | ||||
|   } | ||||
|   if (auto graph_fn = tryToGraphFunction(setstate)) { | ||||
|     exportFunction(exportSet, ModuleMethod{module, *graph_fn, qn}, toplevel); | ||||
|   } | ||||
| } | ||||
|  | ||||
| bool isLoweredModule(const Module& m) { | ||||
|   c10::QualifiedName type_name; | ||||
|   if (m.type()->name()) { | ||||
| @ -544,24 +402,6 @@ bool getMobileInterfaceCallExport() { | ||||
|   return mobileInterfaceCallExport().load(std::memory_order_relaxed); | ||||
| } | ||||
|  | ||||
| // Builds the set of functions to export for `module`: every method of | ||||
| // the module as a top-level entry, followed by the __setstate__ methods | ||||
| // that setstateTuple() finds starting from the module's ivalue. | ||||
| BytecodeExportSet moduleMethodsTuple( | ||||
|     const Module& module, | ||||
|     TypeNameUniquer& type_name_uniquer_) { | ||||
|   constexpr bool kTopLevel = true; | ||||
|   BytecodeExportSet collected; | ||||
|  | ||||
|   // top level methods | ||||
|   const auto module_methods = module.get_methods(); | ||||
|   for (const auto& method : module_methods) { | ||||
|     const auto& graph_fn = toGraphFunction(method.function()); | ||||
|     exportFunction( | ||||
|         collected, | ||||
|         ModuleMethod{module, graph_fn, graph_fn.qualname()}, | ||||
|         kTopLevel); | ||||
|   } | ||||
|  | ||||
|   // __setstate__ of all components | ||||
|   setstateTuple( | ||||
|       collected, module, module._ivalue(), type_name_uniquer_, kTopLevel); | ||||
|  | ||||
|   return collected; | ||||
| } | ||||
|  | ||||
| // Stores `hook` in the slot returned by GetExtraFilesHook(), replacing | ||||
| // whatever was stored there before. | ||||
| void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) { | ||||
|   auto& hook_slot = GetExtraFilesHook(); | ||||
|   hook_slot = std::move(hook); | ||||
| } | ||||
| @ -774,9 +614,12 @@ void ScriptModuleSerializer::writeByteCode( | ||||
|   // Always save debug handles | ||||
|   debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write)); | ||||
|  | ||||
|   BytecodeExportSet exportSet = moduleMethodsTuple(module, type_name_uniquer_); | ||||
|   pushFunctionToIValues( | ||||
|       std::move(exportSet), | ||||
|   mobile::Module mobile_module = | ||||
|       jitModuleToMobile(module, getOptionsFromGlobal()); | ||||
|  | ||||
|   pushMobileFunctionsToIValues( | ||||
|       *module._ivalue()->compilation_unit(), | ||||
|       mobile_module, | ||||
|       elements, | ||||
|       debug_info_elements, | ||||
|       debug_info_recorder, | ||||
| @ -840,9 +683,9 @@ void ScriptModuleSerializer::writeByteCode( | ||||
|     getBackendDebugInfoMap(module, backend_debug_info_map); | ||||
|     // Now get the debug-handles-to-inlined-cs-ptr-map | ||||
|     // And serialize that in a separate archive | ||||
|     auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording(); | ||||
|     debug_handle_cs_ptr_map.insert( | ||||
|         backend_debug_info_map.begin(), backend_debug_info_map.end()); | ||||
|     const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap(); | ||||
|     BackendDebugInfoMapType debug_handle_cs_ptr_map( | ||||
|         debug_info.begin(), debug_info.end()); | ||||
|     CallStackDebugInfoPickler cs_debug_info_pickler; | ||||
|     auto cs_data = cs_debug_info_pickler.pickle( | ||||
|         debug_handle_cs_ptr_map, source_range_tags_); | ||||
| @ -962,31 +805,13 @@ void ExportModule( | ||||
|  | ||||
| namespace { | ||||
| void export_opnames(const script::Module& m, std::set<std::string>& opnames) { | ||||
|   std::vector<c10::IValue> elements; | ||||
|   BackendDebugInfoRecorder dummy; | ||||
|   TypeNameUniquer dummy_uniquer = TypeNameUniquer(); | ||||
|   BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer); | ||||
|   pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer); | ||||
|   for (const auto& element : elements) { | ||||
|     auto table = element.toTupleRef().elements()[1]; | ||||
|     auto row = | ||||
|         table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple(); | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         row->elements().at(0).toStringRef() == "operators", | ||||
|         "Expected operators but found ", | ||||
|         row->elements().at(0).toStringRef()); | ||||
|     const auto& ops_list = row->elements().at(1).toTupleRef().elements(); | ||||
|     for (const auto& op : ops_list) { | ||||
|       const auto& op_item = op.toTupleRef().elements(); | ||||
|       TORCH_CHECK( | ||||
|           op_item.size() >= 2, | ||||
|           "There should be either two parts (name and overload name), ", | ||||
|           "or three parts (name, overload name and number of specified args) ", | ||||
|           "for an operator."); | ||||
|       auto opname = op_item[0].toString()->string(); | ||||
|       auto overload = op_item[1].toString()->string(); | ||||
|   mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal()); | ||||
|   for (const auto& method : mobile_m.get_methods()) { | ||||
|     for (const auto& op : method.function().get_code()->op_names_) { | ||||
|       // NOLINTNEXTLINE(performance-inefficient-string-concatenation) | ||||
|       opnames.emplace(overload.empty() ? opname : opname + "." + overload); | ||||
|       opnames.emplace( | ||||
|           op.overload_name.empty() ? op.name | ||||
|                                    : op.name + "." + op.overload_name); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
		Reference in New Issue
	
	Block a user