From f558bb6f76f17d76b54275dc46cee5ee77706195 Mon Sep 17 00:00:00 2001 From: Hongyi Jia Date: Thu, 4 May 2023 01:30:59 +0000 Subject: [PATCH] inplace PyTorchStreamReader getRecord() (#100418) Summary: Sometimes we want to getRecord into an pre-allocated memory to save cpu memory. Adding new API to support the inplace memory writing. Test Plan: caffe2/serialize/inline_container_test Reviewed By: zyan0 Differential Revision: D45439517 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100418 Approved by: https://github.com/davidberard98, https://github.com/houseroad --- caffe2/serialize/inline_container.cc | 23 +++++++++++++++++++++++ caffe2/serialize/inline_container.h | 2 ++ caffe2/serialize/inline_container_test.cc | 15 +++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index dfda31801a14..606b7f42cc28 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -297,6 +297,29 @@ std::tuple PyTorchStreamReader::getRecord(const std::string return std::make_tuple(std::move(retval), stat.m_uncomp_size); } +// inplace memory writing +size_t +PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) { + std::lock_guard guard(reader_lock_); + if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) { + return 0; + } + size_t key = getRecordID(name); + mz_zip_archive_file_stat stat; + mz_zip_reader_file_stat(ar_.get(), key, &stat); + TORCH_CHECK( + n == stat.m_uncomp_size, + "record size ", + stat.m_uncomp_size, + " mismatch with dst size ", + n); + valid("retrieving file meta-data for ", name.c_str()); + mz_zip_reader_extract_to_mem(ar_.get(), key, dst, stat.m_uncomp_size, 0); + valid("reading file ", name.c_str()); + + return stat.m_uncomp_size; +} + static int64_t read_le_16(uint8_t* buf) { return buf[0] + (buf[1] << 8); } diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 3e6bc1e76b55..30436ea55f6e 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -101,6 +101,8 @@ class TORCH_API PyTorchStreamReader final { // return dataptr, size std::tuple getRecord(const std::string& name); + // inplace memory writing + size_t getRecord(const std::string& name, void* dst, size_t n); size_t getRecordOffset(const std::string& name); bool hasRecord(const std::string& name); std::vector getAllRecords(); diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 4d68f4f1f985..157ff7228427 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -64,6 +64,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0); ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0); ASSERT_EQ(off1 % kFieldAlignment, 0); + // inplace getRecord() test + std::vector dst(size); + size_t ret = reader.getRecord("key1", dst.data(), size); + ASSERT_EQ(ret, size); + ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0); std::tie(data_ptr, size) = reader.getRecord("key2"); size_t off2 = reader.getRecordOffset("key2"); @@ -72,6 +77,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { ASSERT_EQ(size, data2.size()); ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0); ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0); + // inplace getRecord() test + dst.resize(size); + ret = reader.getRecord("key2", dst.data(), size); + ASSERT_EQ(ret, size); + ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0); } TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) { @@ -115,6 +125,8 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) { PyTorchStreamReader reader(&iss); // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) EXPECT_THROW(reader.getRecord("key3"), c10::Error); + std::vector dst(data1.size()); + EXPECT_THROW(reader.getRecord("key3", dst.data(), data1.size()), c10::Error); // Reader should still work after throwing EXPECT_TRUE(reader.hasRecord("key1")); @@ -165,6 +177,9 @@ TEST(PytorchStreamWriterAndReader, SkipDebugRecords) { size_t size; std::tie(ptr, size) = reader.getRecord("key1.debug_pkl"); EXPECT_EQ(size, 0); + std::vector dst(data1.size()); + size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size()); + EXPECT_EQ(ret, 0); } } // namespace