Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
inplace PyTorchStreamReader getRecord() (#100418)
Summary: Sometimes we want to getRecord() into pre-allocated memory to save CPU memory. This adds a new API to support in-place memory writing.

Test Plan: caffe2/serialize/inline_container_test

Reviewed By: zyan0

Differential Revision: D45439517

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100418
Approved by: https://github.com/davidberard98, https://github.com/houseroad
Committed by: PyTorch MergeBot
Parent: e6c0164f1c
Commit: f558bb6f76
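For context, here is a minimal caller-side sketch of the new overload introduced by this diff. It is not part of the patch; the archive path "archive.zip" and record name "key1" are placeholders, and error handling is omitted.

#include "caffe2/serialize/inline_container.h"

#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

int main() {
  caffe2::serialize::PyTorchStreamReader reader("archive.zip");

  // Existing API: returns an owning DataPtr plus the record size.
  at::DataPtr ptr;
  size_t size = 0;
  std::tie(ptr, size) = reader.getRecord("key1");

  // New API: read the same record into caller-owned, pre-allocated memory.
  // The destination size must equal the record size; otherwise the
  // TORCH_CHECK in the new overload raises c10::Error.
  std::vector<uint8_t> dst(size);
  size_t nread = reader.getRecord("key1", dst.data(), dst.size());

  return nread == size ? 0 : 1;
}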
@@ -297,6 +297,29 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
   return std::make_tuple(std::move(retval), stat.m_uncomp_size);
 }
 
+// inplace memory writing
+size_t
+PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
+  std::lock_guard<std::mutex> guard(reader_lock_);
+  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
+    return 0;
+  }
+  size_t key = getRecordID(name);
+  mz_zip_archive_file_stat stat;
+  mz_zip_reader_file_stat(ar_.get(), key, &stat);
+  TORCH_CHECK(
+      n == stat.m_uncomp_size,
+      "record size ",
+      stat.m_uncomp_size,
+      " mismatch with dst size ",
+      n);
+  valid("retrieving file meta-data for ", name.c_str());
+  mz_zip_reader_extract_to_mem(ar_.get(), key, dst, stat.m_uncomp_size, 0);
+  valid("reading file ", name.c_str());
+
+  return stat.m_uncomp_size;
+}
+
 static int64_t read_le_16(uint8_t* buf) {
   return buf[0] + (buf[1] << 8);
 }
@@ -101,6 +101,8 @@ class TORCH_API PyTorchStreamReader final {
 
   // return dataptr, size
   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
+  // inplace memory writing
+  size_t getRecord(const std::string& name, void* dst, size_t n);
   size_t getRecordOffset(const std::string& name);
   bool hasRecord(const std::string& name);
   std::vector<std::string> getAllRecords();
@@ -64,6 +64,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
   ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
   ASSERT_EQ(off1 % kFieldAlignment, 0);
+  // inplace getRecord() test
+  std::vector<uint8_t> dst(size);
+  size_t ret = reader.getRecord("key1", dst.data(), size);
+  ASSERT_EQ(ret, size);
+  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
 
   std::tie(data_ptr, size) = reader.getRecord("key2");
   size_t off2 = reader.getRecordOffset("key2");
@@ -72,6 +77,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   ASSERT_EQ(size, data2.size());
   ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
   ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
+  // inplace getRecord() test
+  dst.resize(size);
+  ret = reader.getRecord("key2", dst.data(), size);
+  ASSERT_EQ(ret, size);
+  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
 }
 
 TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
@@ -115,6 +125,8 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
   PyTorchStreamReader reader(&iss);
   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
   EXPECT_THROW(reader.getRecord("key3"), c10::Error);
+  std::vector<uint8_t> dst(data1.size());
+  EXPECT_THROW(reader.getRecord("key3", dst.data(), data1.size()), c10::Error);
 
   // Reader should still work after throwing
   EXPECT_TRUE(reader.hasRecord("key1"));
|
|||||||
size_t size;
|
size_t size;
|
||||||
std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
|
std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
|
||||||
EXPECT_EQ(size, 0);
|
EXPECT_EQ(size, 0);
|
||||||
|
std::vector<uint8_t> dst(data1.size());
|
||||||
|
size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size());
|
||||||
|
EXPECT_EQ(ret, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||