Re-do D48544397: [TGIF Inplace] [xlv2][1/n] Expose a couple APIs from inline_container that will be used for chunk read" (#109183)

Summary:
Original commit changeset: 4a5f31518ad0

Original Phabricator Diff: D48544397

fix easycla

Differential Revision: D49221088

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109183
Approved by: https://github.com/wqfish
This commit is contained in:
Lujia Zhang
2023-09-14 08:17:14 +00:00
committed by PyTorch MergeBot
parent 9cd4548f01
commit a6fadf643f
3 changed files with 143 additions and 18 deletions

View File

@ -29,6 +29,35 @@ namespace caffe2 {
namespace serialize {
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
struct MzZipReaderIterWrapper {
MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
mz_zip_reader_extract_iter_state* impl;
};
ChunkRecordIterator::ChunkRecordIterator(
size_t recordSize,
size_t chunkSize,
std::unique_ptr<MzZipReaderIterWrapper> iter)
: recordSize_(recordSize),
chunkSize_(chunkSize),
offset_(0),
iter_(std::move(iter)) {}
ChunkRecordIterator::~ChunkRecordIterator() {
mz_zip_reader_extract_iter_free(iter_->impl);
}
size_t ChunkRecordIterator::next(void* buf){
size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
if (want_size == 0) {
return 0;
}
size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
offset_ += read_size;
return read_size;
}
size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
return self->read(file_ofs, static_cast<char*>(pBuf), n);
@ -362,34 +391,41 @@ size_t PyTorchStreamReader::getRecord(
n);
valid("retrieving file meta-data for ", name.c_str());
mz_zip_reader_extract_iter_state* iter =
mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
TORCH_CHECK(
iter != nullptr,
"Failed to create zip reader iter: ",
mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
std::vector<uint8_t> buffer;
if (buf == nullptr) {
buffer.resize(chunk_size);
buf = buffer.data();
}
for (size_t offset = 0; offset < stat.m_uncomp_size; offset += chunk_size) {
size_t want_size =
std::min(chunk_size, (size_t)stat.m_uncomp_size - offset);
size_t read_size =
mz_zip_reader_extract_iter_read(iter, buf, want_size);
TORCH_CHECK(
read_size == want_size,
"Failed to advance zip reader iter: ",
mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
memcpy_func((char*)dst + offset, buf, read_size);
auto chunkIterator =
createChunkReaderIter(name, (size_t)stat.m_uncomp_size, chunk_size);
while (auto readSize = chunkIterator.next(buf)) {
memcpy_func((char*)dst + chunkIterator.offset_ - readSize, buf, readSize);
}
valid("reading file ", name.c_str());
mz_zip_reader_extract_iter_free(iter);
return stat.m_uncomp_size;
}
ChunkRecordIterator PyTorchStreamReader::createChunkReaderIter(
const std::string& name,
const size_t recordSize,
const size_t chunkSize) {
// Create zip reader iterator
size_t key = getRecordID(name);
mz_zip_reader_extract_iter_state* zipReaderIter =
mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
TORCH_CHECK(
zipReaderIter != nullptr,
"Failed to create zip reader iter: ",
mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
return ChunkRecordIterator(
recordSize,
chunkSize,
std::make_unique<MzZipReaderIterWrapper>(zipReaderIter));
}
static int64_t read_le_16(uint8_t* buf) {
return buf[0] + (buf[1] << 8);
}
@ -411,6 +447,11 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
}
size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
return stat.m_uncomp_size;
}
PyTorchStreamReader::~PyTorchStreamReader() {
mz_zip_clear_last_error(ar_.get());

View File

@ -95,6 +95,29 @@ namespace serialize {
static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";
struct MzZipReaderIterWrapper;
class TORCH_API ChunkRecordIterator {
public:
~ChunkRecordIterator();
// Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
size_t next(void* buf);
private:
ChunkRecordIterator(
size_t recordSize,
size_t chunkSize,
std::unique_ptr<MzZipReaderIterWrapper> iter);
const size_t recordSize_;
const size_t chunkSize_;
size_t offset_;
std::unique_ptr<MzZipReaderIterWrapper> iter_;
friend class PyTorchStreamReader;
};
class TORCH_API PyTorchStreamReader final {
public:
explicit PyTorchStreamReader(const std::string& file_name);
@ -111,11 +134,19 @@ class TORCH_API PyTorchStreamReader final {
size_t n,
size_t chunk_size,
void* buf,
const std::function<void(void*, const void*, size_t)>& memcpy_func);
const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
size_t getRecordSize(const std::string& name);
size_t getRecordOffset(const std::string& name);
bool hasRecord(const std::string& name);
std::vector<std::string> getAllRecords();
ChunkRecordIterator createChunkReaderIter(
const std::string& name,
const size_t recordSize,
const size_t chunkSize);
~PyTorchStreamReader();
uint64_t version() const {
return version_;

View File

@ -340,6 +340,59 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
const std::map<std::string, std::string>& metadata_map) {});
}
class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
INSTANTIATE_TEST_SUITE_P(
ChunkRecordIteratorTestGroup,
ChunkRecordIteratorTest,
testing::Values(100, 150, 1010));
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
auto chunkSize = GetParam();
std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
const char* fileName = zipFileName.c_str();
const std::string recordName = "key1";
const size_t tensorDataSizeInBytes = 1000;
// write records through writers
std::ostringstream oss;
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
oss.write(static_cast<const char*>(b), n);
return oss ? n : 0;
});
auto tensorData = std::vector<uint8_t>(tensorDataSizeInBytes, 1);
auto dataPtr = tensorData.data();
writer.writeRecord(recordName, dataPtr, tensorDataSizeInBytes);
const std::unordered_set<std::string>& written_records =
writer.getAllWrittenRecords();
ASSERT_EQ(written_records.size(), 1);
ASSERT_EQ(written_records.count(recordName), 1);
writer.writeEndOfFile();
ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
std::string the_file = oss.str();
std::ofstream foo(fileName);
foo.write(the_file.c_str(), the_file.size());
foo.close();
LOG(INFO) << "Finished saving tensor into zip file " << fileName;
LOG(INFO) << "Testing chunk size " << chunkSize;
PyTorchStreamReader reader(fileName);
ASSERT_TRUE(reader.hasRecord(recordName));
auto chunkIterator = reader.createChunkReaderIter(
recordName, tensorDataSizeInBytes, chunkSize);
std::vector<uint8_t> buffer(chunkSize);
size_t totalReadSize = 0;
while (auto readSize = chunkIterator.next(buffer.data())) {
auto expectedData = std::vector<uint8_t>(readSize, 1);
ASSERT_EQ(memcmp(expectedData.data(), buffer.data(), readSize), 0);
totalReadSize += readSize;
}
ASSERT_EQ(totalReadSize, tensorDataSizeInBytes);
// clean up
remove(fileName);
}
} // namespace
} // namespace serialize
} // namespace caffe2