#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "caffe2/serialize/istream_adapter.h" namespace torch { namespace jit { namespace { Module roundtripThroughMobile(const Module& m) { ExtraFilesMap files; std::vector constants; jitModuleToPythonCodeAndConstants(m, &files, &constants); CompilationOptions options; mobile::Module mobilem = jitModuleToMobile(m, options); return jitModuleFromSourceAndConstants( mobilem._ivalue(), files, constants, 8); } template inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) { try { std::forward(functor)(); } catch (const Error& e) { EXPECT_STREQ(e.what_without_backtrace(), expectedMessage); return; } ADD_FAILURE() << "Expected to throw exception with message \"" << expectedMessage << "\" but didn't throw"; } } // namespace TEST(SerializationTest, ExtraFilesHookPreference) { // Tests that an extra file written explicitly has precedence over // extra files written by a hook // TODO: test for the warning, too const auto script = R"JIT( def forward(self): x = torch.rand(5, 5) x = x.mm(x) return x )JIT"; auto module = std::make_shared("Module", std::make_shared()); module->define(script); std::ostringstream oss; std::unordered_map extra_files; extra_files["metadata.json"] = "abc"; SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"metadata.json", "def"}}; }); module->save(oss, extra_files); SetExportModuleExtraFilesHook(nullptr); std::istringstream iss(oss.str()); caffe2::serialize::IStreamAdapter adapter{&iss}; std::unordered_map loaded_extra_files; loaded_extra_files["metadata.json"] = ""; auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files); ASSERT_EQ(loaded_extra_files["metadata.json"], "abc"); } TEST(SerializationTest, ExtraFileHooksNoSecret) { // no secrets std::stringstream ss; { Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, std::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], ""); } } TEST(SerializationTest, ExtraFileHooksWithSecret) { std::stringstream ss; { SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"secret.json", "topsecret"}}; }); Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); SetExportModuleExtraFilesHook(nullptr); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, std::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], "topsecret"); } } TEST(SerializationTest, TypeTags) { auto list = c10::List>(); list.push_back(c10::List({1, 2, 3})); list.push_back(c10::List({4, 5, 6})); auto dict = c10::Dict(); dict.insert("Hello", torch::ones({2, 2})); auto dict_list = c10::List>(); for (size_t i = 0; i < 5; i++) { auto another_dict = c10::Dict(); another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2})); dict_list.push_back(another_dict); } auto tuple = std::tuple(2, "hi"); struct TestItem { IValue value; TypePtr expected_type; }; std::vector items = { {list, ListType::create(ListType::create(IntType::get()))}, {2, IntType::get()}, {dict, DictType::create(StringType::get(), TensorType::get())}, {dict_list, ListType::create( DictType::create(StringType::get(), TensorType::get()))}, {tuple, TupleType::create({IntType::get(), StringType::get()})}}; // NOLINTNEXTLINE(performance-for-range-copy) for (auto item : items) { auto bytes = torch::pickle_save(item.value); auto loaded = torch::pickle_load(bytes); ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type)); ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type())); } } TEST(SerializationTest, TestJitStream_CUDA) { torch::jit::Module model; std::vector inputs; // Deserialize the ScriptModule from a file using torch::jit::load(). // Load the scripted model. This should have been generated by tests_setup.py // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py model = torch::jit::load("saved_stream_model.pt"); auto output = model.forward(inputs); const auto& list_of_elements = output.toTupleRef().elements(); auto is_stream_s = list_of_elements[0].toBool(); // a,b: These are the two input tensors // c: This is output tensor generated by the operation torch.cat(a,b) auto a = list_of_elements[1].toTensor(); auto b = list_of_elements[2].toTensor(); auto c = list_of_elements[3].toTensor(); // op: this is used to verify if the cat operation produced the same results // as that on the GPU with torch.cat auto op = at::cat({a, b}, 0); // Check if the stream is set ASSERT_TRUE(is_stream_s); // Check if the sizes of the outputs (op and c) is same on the GPU and CPU ASSERT_EQ(op.sizes(), c.sizes()); // Check if both the output tensors are equal ASSERT_TRUE(op.equal(c)); } TEST(TestSourceRoundTrip, UpsampleNearest2d) { Module m("m"); m.define(R"( def forward(self, input: Tensor, scale:float): return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) )"); std::vector inputs; inputs.emplace_back(torch::rand({1, 3, 128, 128})); inputs.emplace_back(at::Scalar(2.0)); auto ref = m.forward(inputs); Module m2 = roundtripThroughMobile(m); auto res = m2.forward(inputs); auto resd = res.toTensor(); auto refd = ref.toTensor(); ASSERT_TRUE(resd.equal(refd)); } TEST(TestSourceRoundTrip, CheckAttrAccess) { Module m("m"); m.register_attribute("mobile_optimized", BoolType::get(), true); Module m2 = roundtripThroughMobile(m); bool mobile_optimized = m2.attr("mobile_optimized", false).toBool(); AT_ASSERT(mobile_optimized); } TEST(TestSourceRoundTrip, MethodInvocation) { // NOLINT (use =delete in gtest) const std::vector test_programs{ // test invoking a method with default parameter R"( def test_func(self, x, b : int = 4): return self.foo + x + b )", // inner method call with default parameter (gets inlined) R"( def add_with_default_arg(self, x, b : int = 4): return self.foo + x + b def test_func(self, x): return self.add_with_default_arg(x) # invoke method w/ default arg )", // simple method call R"( def test_func(self, x): b = 4 return self.foo + x + b )", }; for (const auto& test_program : test_programs) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(test_program); const int fortyTwo = 42; // (keep linter happy) auto minput = fortyTwo * torch::ones({}); auto ref = m.run_method("test_func", minput); Module m2 = roundtripThroughMobile(m); const auto& test_func = m2.get_method("test_func"); IValue res; for (int i = 0; i < 3; ++i) { res = test_func({minput}); } auto resd = res.toTensor().item(); auto refd = ref.toTensor().item(); AT_ASSERT(resd == refd); } } TEST(SerializationTest, ParentDirNotExist) { expectThrowsEq( []() { auto t = torch::nn::Linear(5, 5); torch::save(t, "./doesnotexist/file.pt"); }, "Parent directory ./doesnotexist does not exist."); } #ifdef WIN32 TEST(SerializationTest, WindowsDrivePathTest) { // "ZZZ" is typically not a valid drive letter. // We expect to see "ZZZ:\\" or "ZZZ:/" in the error message. // Note: slash should be included for the drive letter parent in Windows. expectThrowsEq( []() { auto t = torch::nn::Linear(5, 5); torch::save(t, "ZZZ:\\file.pt"); }, "Parent directory ZZZ:\\ does not exist."); expectThrowsEq( []() { auto t = torch::nn::Linear(5, 5); torch::save(t, "ZZZ:/file.pt"); }, "Parent directory ZZZ:/ does not exist."); } TEST(SerializationTest, WindowsTempPathTest) { // Test for verifying file saving and loading in the temporary folder std::string temp_dir = std::getenv("TEMP"); std::string file_path = temp_dir + "/file.pt"; auto t1 = torch::tensor(1.0); torch::save(t1, file_path); torch::Tensor t2; torch::load(t2, file_path); ASSERT_TRUE(t1.allclose(t2, 0.0, 0.0)); } #endif TEST(SerializationTest, CalculateNecessaryArgsTest) { auto schema = torch::schema( "sync_stream(int stream_id = -1) -> ()", c10::AliasAnalysisKind::CONSERVATIVE); auto graph = std::make_shared(); auto one_val = graph->insertConstant(-1); auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true); EXPECT_EQ(0, necessary.first); EXPECT_EQ(0, necessary.second); } TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest) Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define( R"( def test_func(self, x): b = 4 return self.foo + x + b )"); m.define( R"( def exception(self): assert False, "message" )"); std::stringstream ss; m.save(ss); ss.seekg(0); caffe2::serialize::PyTorchStreamReader reader(&ss); reader.setShouldLoadDebugSymbol(true); EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl")); reader.setShouldLoadDebugSymbol(false); EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl")); ss.seekg(0); Module m2 = torch::jit::load(ss); std::string error_msg = R"( def exception(self): assert False, "message" ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)"; ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg); ss.seekg(0); // NO DEBUG trace so error message points to torchscript generated // source instead of original python source. std::string error2 = R"( def exception(self: __torch__.m) -> NoneType: _0 = uninitialized(NoneType) ops.prim.RaiseException("AssertionError: message") ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return _0 )"; Module m3 = torch::jit::load(ss, std::nullopt, false); ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2); } TEST(SerializationTest, TestPickleAppend) { auto data = std::vector({'\x80', char(2), ']', 'K', char(2), 'a', '.'}); torch::IValue actual = torch::jit::unpickle(data.data(), data.size()); torch::IValue expected = c10::impl::GenericList(at::AnyType::get()); expected.toList().push_back(2); ASSERT_EQ(expected, actual); } } // namespace jit } // namespace torch