Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Re-do D48544397: [TGIF Inplace] [xlv2][1/n] Expose a couple APIs from inline_container that will be used for chunk read" (#109183)
Summary:
Original commit changeset: 4a5f31518ad0
Original Phabricator Diff: D48544397

fix easycla

Differential Revision: D49221088

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109183
Approved by: https://github.com/wqfish
commit a6fadf643f
parent 9cd4548f01
committed by PyTorch MergeBot
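Taken together, the exposed APIs (getRecordSize, createChunkReaderIter, and ChunkRecordIterator::next) let a caller stream a record out of an archive in fixed-size chunks instead of reading it in one shot. A minimal usage sketch, assuming a hypothetical archive "model.zip" containing a record "key1":

    #include <string>
    #include <vector>
    #include <caffe2/serialize/inline_container.h>

    int main() {
      // Hypothetical archive and record names, for illustration only.
      caffe2::serialize::PyTorchStreamReader reader("model.zip");
      const std::string name = "key1";
      const size_t recordSize = reader.getRecordSize(name);
      const size_t chunkSize = 4096;
      std::vector<uint8_t> buf(chunkSize);
      auto it = reader.createChunkReaderIter(name, recordSize, chunkSize);
      while (size_t n = it.next(buf.data())) {
        // consume the n bytes just read from buf ...
      }
      return 0;
    }

next() returns 0 once the whole record has been consumed, so the loop terminates naturally.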
caffe2/serialize/inline_container.cc
@@ -29,6 +29,35 @@ namespace caffe2 {
 namespace serialize {
 constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
 
+struct MzZipReaderIterWrapper {
+  MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
+  mz_zip_reader_extract_iter_state* impl;
+};
+
+ChunkRecordIterator::ChunkRecordIterator(
+    size_t recordSize,
+    size_t chunkSize,
+    std::unique_ptr<MzZipReaderIterWrapper> iter)
+    : recordSize_(recordSize),
+      chunkSize_(chunkSize),
+      offset_(0),
+      iter_(std::move(iter)) {}
+
+ChunkRecordIterator::~ChunkRecordIterator() {
+  mz_zip_reader_extract_iter_free(iter_->impl);
+}
+
+size_t ChunkRecordIterator::next(void* buf){
+  size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
+  if (want_size == 0) {
+    return 0;
+  }
+  size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
+  TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
+  offset_ += read_size;
+  return read_size;
+}
+
 size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
   auto self = static_cast<PyTorchStreamReader*>(pOpaque);
   return self->read(file_ofs, static_cast<char*>(pBuf), n);
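The contract of next() is worth spelling out: each call requests min(chunkSize_, recordSize_ - offset_) bytes and returns the number actually read. For a 1000-byte record read with a 150-byte chunk size, for example, the iterator yields six reads of 150 bytes, one read of 100 bytes, and then 0 — exactly the shape the parametrized test below exercises.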
@@ -362,34 +391,41 @@ size_t PyTorchStreamReader::getRecord(
       n);
   valid("retrieving file meta-data for ", name.c_str());
 
-  mz_zip_reader_extract_iter_state* iter =
-      mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
-  TORCH_CHECK(
-      iter != nullptr,
-      "Failed to create zip reader iter: ",
-      mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
   std::vector<uint8_t> buffer;
   if (buf == nullptr) {
     buffer.resize(chunk_size);
     buf = buffer.data();
   }
-  for (size_t offset = 0; offset < stat.m_uncomp_size; offset += chunk_size) {
-    size_t want_size =
-        std::min(chunk_size, (size_t)stat.m_uncomp_size - offset);
-    size_t read_size =
-        mz_zip_reader_extract_iter_read(iter, buf, want_size);
-    TORCH_CHECK(
-        read_size == want_size,
-        "Failed to advance zip reader iter: ",
-        mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
-    memcpy_func((char*)dst + offset, buf, read_size);
+
+  auto chunkIterator =
+      createChunkReaderIter(name, (size_t)stat.m_uncomp_size, chunk_size);
+  while (auto readSize = chunkIterator.next(buf)) {
+    memcpy_func((char*)dst + chunkIterator.offset_ - readSize, buf, readSize);
   }
   valid("reading file ", name.c_str());
-  mz_zip_reader_extract_iter_free(iter);
 
   return stat.m_uncomp_size;
 }
 
+ChunkRecordIterator PyTorchStreamReader::createChunkReaderIter(
+    const std::string& name,
+    const size_t recordSize,
+    const size_t chunkSize) {
+  // Create zip reader iterator
+  size_t key = getRecordID(name);
+  mz_zip_reader_extract_iter_state* zipReaderIter =
+      mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
+  TORCH_CHECK(
+      zipReaderIter != nullptr,
+      "Failed to create zip reader iter: ",
+      mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
+
+  return ChunkRecordIterator(
+      recordSize,
+      chunkSize,
+      std::make_unique<MzZipReaderIterWrapper>(zipReaderIter));
+}
+
 static int64_t read_le_16(uint8_t* buf) {
   return buf[0] + (buf[1] << 8);
 }
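A note on the rewritten copy loop: next() advances offset_ past the bytes it just returned, so (char*)dst + chunkIterator.offset_ - readSize points at the start of the current chunk, reproducing the dst + offset arithmetic of the deleted for-loop.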
@@ -411,6 +447,11 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
   return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
 }
 
+size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
+  mz_zip_archive_file_stat stat;
+  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
+  return stat.m_uncomp_size;
+}
+
 PyTorchStreamReader::~PyTorchStreamReader() {
   mz_zip_clear_last_error(ar_.get());
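getRecordSize is the second of the "couple APIs" named in the title: it reports a record's uncompressed size, which is exactly the recordSize argument that createChunkReaderIter expects (see the usage sketch above).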
caffe2/serialize/inline_container.h
@@ -95,6 +95,29 @@ namespace serialize {
 
 static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";
 
+struct MzZipReaderIterWrapper;
+
+class TORCH_API ChunkRecordIterator {
+ public:
+  ~ChunkRecordIterator();
+
+  // Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
+  size_t next(void* buf);
+
+ private:
+  ChunkRecordIterator(
+      size_t recordSize,
+      size_t chunkSize,
+      std::unique_ptr<MzZipReaderIterWrapper> iter);
+
+  const size_t recordSize_;
+  const size_t chunkSize_;
+  size_t offset_;
+  std::unique_ptr<MzZipReaderIterWrapper> iter_;
+
+  friend class PyTorchStreamReader;
+};
+
 class TORCH_API PyTorchStreamReader final {
  public:
   explicit PyTorchStreamReader(const std::string& file_name);
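Design note: the ChunkRecordIterator constructor is private and PyTorchStreamReader is declared a friend, so the only way to obtain an iterator is through PyTorchStreamReader::createChunkReaderIter. MzZipReaderIterWrapper is only forward-declared here, keeping the miniz types out of the public header.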
@@ -111,11 +134,19 @@ class TORCH_API PyTorchStreamReader final {
       size_t n,
       size_t chunk_size,
       void* buf,
-      const std::function<void(void*, const void*, size_t)>& memcpy_func);
+      const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
+
+  size_t getRecordSize(const std::string& name);
+
   size_t getRecordOffset(const std::string& name);
   bool hasRecord(const std::string& name);
   std::vector<std::string> getAllRecords();
 
+  ChunkRecordIterator createChunkReaderIter(
+      const std::string& name,
+      const size_t recordSize,
+      const size_t chunkSize);
+
   ~PyTorchStreamReader();
   uint64_t version() const {
     return version_;
caffe2/serialize/inline_container_test.cc
@@ -340,6 +340,59 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
       const std::map<std::string, std::string>& metadata_map) {});
 }
 
+class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
+INSTANTIATE_TEST_SUITE_P(
+    ChunkRecordIteratorTestGroup,
+    ChunkRecordIteratorTest,
+    testing::Values(100, 150, 1010));
+
+TEST_P(ChunkRecordIteratorTest, ChunkRead) {
+  auto chunkSize = GetParam();
+  std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
+  const char* fileName = zipFileName.c_str();
+  const std::string recordName = "key1";
+  const size_t tensorDataSizeInBytes = 1000;
+
+  // write records through writers
+  std::ostringstream oss;
+  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
+    oss.write(static_cast<const char*>(b), n);
+    return oss ? n : 0;
+  });
+
+  auto tensorData = std::vector<uint8_t>(tensorDataSizeInBytes, 1);
+  auto dataPtr = tensorData.data();
+  writer.writeRecord(recordName, dataPtr, tensorDataSizeInBytes);
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 1);
+  ASSERT_EQ(written_records.count(recordName), 1);
+  writer.writeEndOfFile();
+  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
+
+  std::string the_file = oss.str();
+  std::ofstream foo(fileName);
+  foo.write(the_file.c_str(), the_file.size());
+  foo.close();
+  LOG(INFO) << "Finished saving tensor into zip file " << fileName;
+
+  LOG(INFO) << "Testing chunk size " << chunkSize;
+  PyTorchStreamReader reader(fileName);
+  ASSERT_TRUE(reader.hasRecord(recordName));
+  auto chunkIterator = reader.createChunkReaderIter(
+      recordName, tensorDataSizeInBytes, chunkSize);
+  std::vector<uint8_t> buffer(chunkSize);
+  size_t totalReadSize = 0;
+  while (auto readSize = chunkIterator.next(buffer.data())) {
+    auto expectedData = std::vector<uint8_t>(readSize, 1);
+    ASSERT_EQ(memcmp(expectedData.data(), buffer.data(), readSize), 0);
+    totalReadSize += readSize;
+  }
+  ASSERT_EQ(totalReadSize, tensorDataSizeInBytes);
+  // clean up
+  remove(fileName);
+}
+
 } // namespace
 } // namespace serialize
 } // namespace caffe2
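The parametrized test exercises the three interesting regimes for the 1000-byte record: a chunk size that divides it evenly (100), one that leaves a 100-byte tail (150), and one larger than the whole record (1010).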