mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-30 11:44:59 +08:00 
			
		
		
		
	Follows #130509 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130674 Approved by: https://github.com/Skylion007
		
			
				
	
	
		
			366 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			366 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #include <gtest/gtest.h>
 | |
| 
 | |
| #include <test/cpp/jit/test_utils.h>
 | |
| #include <cstdlib>
 | |
| #include <iostream>
 | |
| #include <sstream>
 | |
| 
 | |
| #include <caffe2/serialize/inline_container.h>
 | |
| #include <torch/csrc/jit/mobile/module.h>
 | |
| #include <torch/csrc/jit/runtime/calculate_necessary_args.h>
 | |
| #include <torch/csrc/jit/serialization/export.h>
 | |
| #include <torch/csrc/jit/serialization/export_bytecode.h>
 | |
| #include <torch/csrc/jit/serialization/import.h>
 | |
| #include <torch/csrc/jit/serialization/import_source.h>
 | |
| #include <torch/script.h>
 | |
| #include <torch/torch.h>
 | |
| 
 | |
| #include "caffe2/serialize/istream_adapter.h"
 | |
| 
 | |
| namespace torch {
 | |
| namespace jit {
 | |
| 
 | |
| namespace {
 | |
| 
 | |
| Module roundtripThroughMobile(const Module& m) {
 | |
|   ExtraFilesMap files;
 | |
|   std::vector<IValue> constants;
 | |
|   jitModuleToPythonCodeAndConstants(m, &files, &constants);
 | |
|   CompilationOptions options;
 | |
|   mobile::Module mobilem = jitModuleToMobile(m, options);
 | |
|   return jitModuleFromSourceAndConstants(
 | |
|       mobilem._ivalue(), files, constants, 8);
 | |
| }
 | |
| 
 | |
| template <class Functor>
 | |
| inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
 | |
|   try {
 | |
|     std::forward<Functor>(functor)();
 | |
|   } catch (const Error& e) {
 | |
|     EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
 | |
|     return;
 | |
|   }
 | |
|   ADD_FAILURE() << "Expected to throw exception with message \""
 | |
|                 << expectedMessage << "\" but didn't throw";
 | |
| }
 | |
| 
 | |
| } // namespace
 | |
| 
 | |
| TEST(SerializationTest, ExtraFilesHookPreference) {
 | |
|   // Tests that an extra file written explicitly has precedence over
 | |
|   //   extra files written by a hook
 | |
|   // TODO: test for the warning, too
 | |
|   const auto script = R"JIT(
 | |
|     def forward(self):
 | |
|         x = torch.rand(5, 5)
 | |
|         x = x.mm(x)
 | |
|         return x
 | |
|   )JIT";
 | |
| 
 | |
|   auto module =
 | |
|       std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
 | |
|   module->define(script);
 | |
|   std::ostringstream oss;
 | |
|   std::unordered_map<std::string, std::string> extra_files;
 | |
|   extra_files["metadata.json"] = "abc";
 | |
|   SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
 | |
|     return {{"metadata.json", "def"}};
 | |
|   });
 | |
|   module->save(oss, extra_files);
 | |
|   SetExportModuleExtraFilesHook(nullptr);
 | |
| 
 | |
|   std::istringstream iss(oss.str());
 | |
|   caffe2::serialize::IStreamAdapter adapter{&iss};
 | |
|   std::unordered_map<std::string, std::string> loaded_extra_files;
 | |
|   loaded_extra_files["metadata.json"] = "";
 | |
|   auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files);
 | |
|   ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
 | |
| }
 | |
| 
 | |
| TEST(SerializationTest, ExtraFileHooksNoSecret) {
 | |
|   // no secrets
 | |
|   std::stringstream ss;
 | |
|   {
 | |
|     Module m("__torch__.m");
 | |
|     ExtraFilesMap extra;
 | |
|     extra["metadata.json"] = "abc";
 | |
|     m.save(ss, extra);
 | |
|   }
 | |
|   ss.seekg(0);
 | |
|   {
 | |
|     ExtraFilesMap extra;
 | |
|     extra["metadata.json"] = "";
 | |
|     extra["secret.json"] = "";
 | |
|     jit::load(ss, std::nullopt, extra);
 | |
|     ASSERT_EQ(extra["metadata.json"], "abc");
 | |
|     ASSERT_EQ(extra["secret.json"], "");
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST(SerializationTest, ExtraFileHooksWithSecret) {
 | |
|   std::stringstream ss;
 | |
|   {
 | |
|     SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
 | |
|       return {{"secret.json", "topsecret"}};
 | |
|     });
 | |
|     Module m("__torch__.m");
 | |
|     ExtraFilesMap extra;
 | |
|     extra["metadata.json"] = "abc";
 | |
|     m.save(ss, extra);
 | |
|     SetExportModuleExtraFilesHook(nullptr);
 | |
|   }
 | |
|   ss.seekg(0);
 | |
|   {
 | |
|     ExtraFilesMap extra;
 | |
|     extra["metadata.json"] = "";
 | |
|     extra["secret.json"] = "";
 | |
|     jit::load(ss, std::nullopt, extra);
 | |
|     ASSERT_EQ(extra["metadata.json"], "abc");
 | |
|     ASSERT_EQ(extra["secret.json"], "topsecret");
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST(SerializationTest, TypeTags) {
 | |
|   auto list = c10::List<c10::List<int64_t>>();
 | |
|   list.push_back(c10::List<int64_t>({1, 2, 3}));
 | |
|   list.push_back(c10::List<int64_t>({4, 5, 6}));
 | |
|   auto dict = c10::Dict<std::string, at::Tensor>();
 | |
|   dict.insert("Hello", torch::ones({2, 2}));
 | |
|   auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
 | |
|   for (size_t i = 0; i < 5; i++) {
 | |
|     auto another_dict = c10::Dict<std::string, at::Tensor>();
 | |
|     another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
 | |
|     dict_list.push_back(another_dict);
 | |
|   }
 | |
|   auto tuple = std::tuple<int, std::string>(2, "hi");
 | |
|   struct TestItem {
 | |
|     IValue value;
 | |
|     TypePtr expected_type;
 | |
|   };
 | |
|   std::vector<TestItem> items = {
 | |
|       {list, ListType::create(ListType::create(IntType::get()))},
 | |
|       {2, IntType::get()},
 | |
|       {dict, DictType::create(StringType::get(), TensorType::get())},
 | |
|       {dict_list,
 | |
|        ListType::create(
 | |
|            DictType::create(StringType::get(), TensorType::get()))},
 | |
|       {tuple, TupleType::create({IntType::get(), StringType::get()})}};
 | |
|   // NOLINTNEXTLINE(performance-for-range-copy)
 | |
|   for (auto item : items) {
 | |
|     auto bytes = torch::pickle_save(item.value);
 | |
|     auto loaded = torch::pickle_load(bytes);
 | |
|     ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type));
 | |
|     ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type()));
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST(SerializationTest, TestJitStream_CUDA) {
 | |
|   torch::jit::Module model;
 | |
|   std::vector<torch::jit::IValue> inputs;
 | |
|   // Deserialize the ScriptModule from a file using torch::jit::load().
 | |
|   // Load the scripted model. This should have been generated by tests_setup.py
 | |
|   // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
 | |
|   model = torch::jit::load("saved_stream_model.pt");
 | |
| 
 | |
|   auto output = model.forward(inputs);
 | |
|   const auto& list_of_elements = output.toTupleRef().elements();
 | |
|   auto is_stream_s = list_of_elements[0].toBool();
 | |
| 
 | |
|   // a,b: These are the two input tensors
 | |
|   // c: This is output tensor generated by the operation torch.cat(a,b)
 | |
|   auto a = list_of_elements[1].toTensor();
 | |
|   auto b = list_of_elements[2].toTensor();
 | |
|   auto c = list_of_elements[3].toTensor();
 | |
|   // op: this is used to verify if the cat operation produced the same results
 | |
|   // as that on the GPU with torch.cat
 | |
|   auto op = at::cat({a, b}, 0);
 | |
| 
 | |
|   // Check if the stream is set
 | |
|   ASSERT_TRUE(is_stream_s);
 | |
|   // Check if the sizes of the outputs (op and c) is same on the GPU and CPU
 | |
|   ASSERT_EQ(op.sizes(), c.sizes());
 | |
|   // Check if both the output tensors are equal
 | |
|   ASSERT_TRUE(op.equal(c));
 | |
| }
 | |
| 
 | |
| TEST(TestSourceRoundTrip, UpsampleNearest2d) {
 | |
|   Module m("m");
 | |
|   m.define(R"(
 | |
|     def forward(self, input: Tensor, scale:float):
 | |
|       return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
 | |
|   )");
 | |
| 
 | |
|   std::vector<IValue> inputs;
 | |
|   inputs.emplace_back(torch::rand({1, 3, 128, 128}));
 | |
|   inputs.emplace_back(at::Scalar(2.0));
 | |
|   auto ref = m.forward(inputs);
 | |
| 
 | |
|   Module m2 = roundtripThroughMobile(m);
 | |
|   auto res = m2.forward(inputs);
 | |
| 
 | |
|   auto resd = res.toTensor();
 | |
|   auto refd = ref.toTensor();
 | |
|   ASSERT_TRUE(resd.equal(refd));
 | |
| }
 | |
| 
 | |
| TEST(TestSourceRoundTrip, CheckAttrAccess) {
 | |
|   Module m("m");
 | |
|   m.register_attribute("mobile_optimized", BoolType::get(), true);
 | |
|   Module m2 = roundtripThroughMobile(m);
 | |
|   bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
 | |
|   AT_ASSERT(mobile_optimized);
 | |
| }
 | |
| 
 | |
| TEST(TestSourceRoundTrip,
 | |
|      MethodInvocation) { // NOLINT (use =delete in gtest)
 | |
|   const std::vector<std::string> test_programs{
 | |
|       // test invoking a method with default parameter
 | |
|       R"(
 | |
|       def test_func(self, x, b : int = 4):
 | |
|         return self.foo + x + b
 | |
|       )",
 | |
|       // inner method call with default parameter (gets inlined)
 | |
|       R"(
 | |
|       def add_with_default_arg(self, x, b : int = 4):
 | |
|         return self.foo + x + b
 | |
|       def test_func(self, x):
 | |
|         return self.add_with_default_arg(x)  # invoke method w/ default arg
 | |
|       )",
 | |
|       // simple method call
 | |
|       R"(
 | |
|       def test_func(self, x):
 | |
|         b = 4
 | |
|         return self.foo + x + b
 | |
|       )",
 | |
|   };
 | |
|   for (const auto& test_program : test_programs) {
 | |
|     Module m("m");
 | |
|     m.register_parameter("foo", torch::ones({}), false);
 | |
|     m.define(test_program);
 | |
| 
 | |
|     const int fortyTwo = 42; // (keep linter happy)
 | |
|     auto minput = fortyTwo * torch::ones({});
 | |
|     auto ref = m.run_method("test_func", minput);
 | |
| 
 | |
|     Module m2 = roundtripThroughMobile(m);
 | |
|     const auto& test_func = m2.get_method("test_func");
 | |
|     IValue res;
 | |
|     for (int i = 0; i < 3; ++i) {
 | |
|       res = test_func({minput});
 | |
|     }
 | |
| 
 | |
|     auto resd = res.toTensor().item<float>();
 | |
|     auto refd = ref.toTensor().item<float>();
 | |
|     AT_ASSERT(resd == refd);
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST(SerializationTest, ParentDirNotExist) {
 | |
|   expectThrowsEq(
 | |
|       []() {
 | |
|         auto t = torch::nn::Linear(5, 5);
 | |
|         torch::save(t, "./doesnotexist/file.pt");
 | |
|       },
 | |
|       "Parent directory ./doesnotexist does not exist.");
 | |
| }
 | |
| 
 | |
| #ifdef WIN32
 | |
| TEST(SerializationTest, WindowsDrivePathTest) {
 | |
|   // "ZZZ" is typically not a valid drive letter.
 | |
|   // We expect to see "ZZZ:\\" or "ZZZ:/" in the error message.
 | |
|   // Note: slash should be included for the drive letter parent in Windows.
 | |
|   expectThrowsEq(
 | |
|       []() {
 | |
|         auto t = torch::nn::Linear(5, 5);
 | |
|         torch::save(t, "ZZZ:\\file.pt");
 | |
|       },
 | |
|       "Parent directory ZZZ:\\ does not exist.");
 | |
|   expectThrowsEq(
 | |
|       []() {
 | |
|         auto t = torch::nn::Linear(5, 5);
 | |
|         torch::save(t, "ZZZ:/file.pt");
 | |
|       },
 | |
|       "Parent directory ZZZ:/ does not exist.");
 | |
| }
 | |
| 
 | |
| TEST(SerializationTest, WindowsTempPathTest) {
 | |
|   // Test for verifying file saving and loading in the temporary folder
 | |
|   std::string temp_dir = std::getenv("TEMP");
 | |
|   std::string file_path = temp_dir + "/file.pt";
 | |
|   auto t1 = torch::tensor(1.0);
 | |
|   torch::save(t1, file_path);
 | |
|   torch::Tensor t2;
 | |
|   torch::load(t2, file_path);
 | |
|   ASSERT_TRUE(t1.allclose(t2, 0.0, 0.0));
 | |
| }
 | |
| #endif
 | |
| 
 | |
| TEST(SerializationTest, CalculateNecessaryArgsTest) {
 | |
|   auto schema = torch::schema(
 | |
|       "sync_stream(int stream_id = -1) -> ()",
 | |
|       c10::AliasAnalysisKind::CONSERVATIVE);
 | |
| 
 | |
|   auto graph = std::make_shared<Graph>();
 | |
|   auto one_val = graph->insertConstant(-1);
 | |
|   auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true);
 | |
|   EXPECT_EQ(0, necessary.first);
 | |
|   EXPECT_EQ(0, necessary.second);
 | |
| }
 | |
| 
 | |
// Saves a module with a method that always raises. Checks that:
//  - the debug pickle record is visible to PyTorchStreamReader only while
//    debug-symbol loading is enabled;
//  - loading normally yields an error that points at the original source;
//  - loading with the third jit::load argument set to false yields an error
//    that points at the generated TorchScript source instead.
TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def test_func(self, x):
      b = 4
      return self.foo + x + b
    )");
  m.define(
      R"(
    def exception(self):
      assert False, "message"
    )");
  std::stringstream ss;
  m.save(ss);
  ss.seekg(0);
  // Inspect the archive directly: the test expects hasRecord() to report the
  // debug pickle only while setShouldLoadDebugSymbol(true) is in effect.
  caffe2::serialize::PyTorchStreamReader reader(&ss);
  reader.setShouldLoadDebugSymbol(true);
  EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  // Rewind: the reader above consumed the stream.
  ss.seekg(0);
  Module m2 = torch::jit::load(ss);
  // With debug info, the error highlights the original `assert` line.
  std::string error_msg = R"(
    def exception(self):
      assert False, "message"
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)";
  ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg);

  ss.seekg(0);
  // NO DEBUG trace so error message points to torchscript generated
  // source instead of original python source.
  std::string error2 = R"(
    def exception(self: __torch__.m) -> NoneType:
      _0 = uninitialized(NoneType)
      ops.prim.RaiseException("AssertionError: message")
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      return _0
  )";
  // NOTE(review): the third argument `false` presumably disables loading
  // debug files, consistent with the comment above — confirm against the
  // jit::load signature.
  Module m3 = torch::jit::load(ss, std::nullopt, false);
  ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
}
 | |
| 
 | |
| TEST(SerializationTest, TestPickleAppend) {
 | |
|   auto data = std::vector<char>({'\x80', char(2), ']', 'K', char(2), 'a', '.'});
 | |
| 
 | |
|   torch::IValue actual = torch::jit::unpickle(data.data(), data.size());
 | |
| 
 | |
|   torch::IValue expected = c10::impl::GenericList(at::AnyType::get());
 | |
|   expected.toList().push_back(2);
 | |
|   ASSERT_EQ(expected, actual);
 | |
| }
 | |
| 
 | |
| } // namespace jit
 | |
| } // namespace torch
 |