Format caffe2/serialize (#141850)

This is a formatting-only change (applies clang-format to caffe2/serialize); no functional changes intended.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141850
Approved by: https://github.com/cpuhrsch
This commit is contained in:
cyy
2024-12-04 01:14:24 +00:00
committed by PyTorch MergeBot
parent 941da90e8a
commit bffaddf9ea
6 changed files with 185 additions and 132 deletions

View File

@ -21,7 +21,8 @@ FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
auto error_msg =
std::system_category().default_error_condition(old_errno).message();
#endif
TORCH_CHECK(false,
TORCH_CHECK(
false,
"open file failed because of errno ",
old_errno,
" on fopen: ",

View File

@ -1,8 +1,8 @@
#pragma once
#include <c10/macros/Macros.h>
#include <fstream>
#include <memory>
#include <c10/macros/Macros.h>
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

View File

@ -1,7 +1,6 @@
#pragma once
#include <cstring>
#include <caffe2/serialize/read_adapter_interface.h>
#include <cstring>
namespace caffe2 {
namespace serialize {
@ -27,6 +26,5 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
off_t size_;
};
} // namespace serialize
} // namespace caffe2

View File

@ -1,13 +1,13 @@
#include <cstdio>
#include <cstring>
#include <cerrno>
#include <istream>
#include <ostream>
#include <fstream>
#include <algorithm>
#include <sstream>
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <ostream>
#include <sstream>
#include <thread>
#include <c10/core/Allocator.h>
@ -53,13 +53,15 @@ size_t ChunkRecordIterator::next(void* buf){
if (want_size == 0) {
return 0;
}
size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
size_t read_size =
mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
offset_ += read_size;
return read_size;
}
size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
size_t
istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
return self->read(file_ofs, static_cast<char*>(pBuf), n);
}
@ -157,8 +159,8 @@ void PyTorchStreamReader::init() {
mz_zip_reader_init(ar_.get(), size, 0);
valid("reading zip archive");
// figure out the archive_name (i.e. the zip folder all the other files are in)
// all lookups to getRecord will be prefixed by this folder
// figure out the archive_name (i.e. the zip folder all the other files are
// in) all lookups to getRecord will be prefixed by this folder
mz_uint n = mz_zip_reader_get_num_files(ar_.get());
if (n == 0) {
CAFFE_THROW("archive does not contain any files");
@ -201,15 +203,15 @@ void PyTorchStreamReader::init() {
TORCH_CHECK(hasRecord("version"))
std::tie(version_ptr, version_size) = getRecord("version");
}
std::string version(static_cast<const char*>(version_ptr.get()), version_size);
std::string version(
static_cast<const char*>(version_ptr.get()), version_size);
try {
version_ = std::stoull(version);
} catch (const std::invalid_argument& e) {
CAFFE_THROW("Couldn't parse the version ",
version,
" as Long Long.");
CAFFE_THROW("Couldn't parse the version ", version, " as Long Long.");
}
if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
if (version_ <
static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
CAFFE_THROW(
"Attempted to read a PyTorch file with version ",
std::to_string(version_),
@ -219,7 +221,8 @@ void PyTorchStreamReader::init() {
" with latest version of PyTorch to mitigate this issue.");
}
if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
if (version_ >
static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
CAFFE_THROW(
"Attempted to read a PyTorch file with version ",
version_,
@ -277,12 +280,13 @@ size_t getPadding(
padding_buf[3] = (uint8_t)(padding_size >> 8);
return padding_size_plus_fbxx;
}
}
} // namespace detail
bool PyTorchStreamReader::hasRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
if ((!load_debug_symbol_) &&
c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
return false;
}
std::string ss = archive_name_plus_slash_ + name;
@ -307,7 +311,8 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
for (size_t i = 0; i < num_files; i++) {
mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
mz_zip_reader_get_filename(
ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
if (strncmp(
buf,
archive_name_plus_slash_.data(),
@ -319,7 +324,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
buf);
}
if ((load_debug_symbol_) ||
(!c10::ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) {
(!c10::ends_with(
std::string_view(buf + archive_name_plus_slash_.size()),
kDebugPklSuffix))) {
// NOLINTNEXTLINE(modernize-use-emplace)
out.push_back(buf + archive_name_plus_slash_.size());
}
@ -340,7 +347,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
}
// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
at::DataPtr retval;
@ -351,17 +359,18 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
mz_zip_reader_file_stat(ar_.get(), key, &stat);
valid("retrieving file meta-data for ", name.c_str());
at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
mz_zip_reader_extract_to_mem(
ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
valid("reading file ", name.c_str());
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
size_t
PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
size_t PyTorchStreamReader::getRecordMultiReaders(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
void *dst, size_t n){
void* dst,
size_t n) {
size_t nthread = additionalReaders.size() + 1;
size_t recordOff = getRecordOffset(name);
std::vector<std::thread> loaderThreads;
@ -369,20 +378,31 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
std::vector<size_t> readSizes(nthread, 0);
std::lock_guard<std::mutex> guard(reader_lock_);
for (size_t i = 0; i < nthread; i++) {
loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
loaderThreads.emplace_back([this,
name,
i,
n,
recordOff,
perThreadSize,
dst,
&additionalReaders,
&readSizes] {
size_t startPos = i * perThreadSize;
size_t endPos = std::min((i + 1) * perThreadSize, n);
if (startPos < endPos) {
size_t threadReadSize = endPos - startPos;
size_t size = 0;
if (i == 0) {
size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
size =
read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
} else {
auto reader = additionalReaders[i - 1];
size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
size = reader->read(
recordOff + startPos, (char*)dst + startPos, threadReadSize);
}
readSizes[i] = size;
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
<< "] "
<< "from " << name << " of size " << n;
TORCH_CHECK(
threadReadSize == size,
@ -415,8 +435,8 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
}
// read record with multi clients
std::tuple<at::DataPtr, size_t>
PyTorchStreamReader::getRecord(const std::string& name,
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
if (additionalReaders.empty()) {
// No additional readers or record too small, use single threaded version
@ -466,17 +486,20 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
return stat.m_uncomp_size;
}
// inplace memory writing, in-tensor multi-threads, can be used for large tensor.
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
// inplace memory writing, in-tensor multi-threads, can be used for large
// tensor.
size_t PyTorchStreamReader::getRecord(
const std::string& name,
void* dst,
size_t n,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
if (additionalReaders.empty()) {
// No additional readers, use single threaded version
return getRecord(name, dst, n);
}
if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
if ((!load_debug_symbol_) &&
c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
return 0;
}
size_t key = getRecordID(name);
@ -577,7 +600,8 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
"reading file header");
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len +
extra_len;
}
size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
@ -620,14 +644,16 @@ size_t ostream_write_func(
return ret;
}
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32)
: archive_name_(basename(file_name)),
compute_crc32_(compute_crc32) {
PyTorchStreamWriter::PyTorchStreamWriter(
const std::string& file_name,
bool compute_crc32)
: archive_name_(basename(file_name)), compute_crc32_(compute_crc32) {
setup(file_name);
}
PyTorchStreamWriter::PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32)
const std::function<size_t(const void*, size_t)> writer_func,
bool compute_crc32)
: archive_name_("archive"),
writer_func_(writer_func),
compute_crc32_(compute_crc32) {
@ -651,8 +677,10 @@ void PyTorchStreamWriter::setup(const string& file_name) {
const std::string dir_name = parentdir(file_name);
if (!dir_name.empty()) {
struct stat st;
bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
bool dir_exists =
(stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
TORCH_CHECK(
dir_exists, "Parent directory ", dir_name, " does not exist.");
}
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
@ -732,13 +760,16 @@ void PyTorchStreamWriter::writeEndOfFile() {
~Finalizer() {
var_ = true;
}
private:
bool& var_;
} f(finalized_);
auto allRecords = getAllWrittenRecords();
// If no ".data/version" or "version" record in the output model, rewrites version info
if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
// If no ".data/version" or "version" record in the output model, rewrites
// version info
if (allRecords.find(".data/version") == allRecords.end() &&
allRecords.find("version") == allRecords.end()) {
std::string version = std::to_string(version_);
version.push_back('\n');
if (version_ >= 0x6L) {
@ -808,9 +839,8 @@ void PyTorchStreamWriter::writeSerializationId() {
}
std::ostringstream serialization_id_oss;
serialization_id_oss << std::setfill('0') << std::setw(20)
<< combined_record_name_hash
<< std::setfill('0') << std::setw(20)
<< combined_uncomp_crc32_;
<< combined_record_name_hash << std::setfill('0')
<< std::setw(20) << combined_uncomp_crc32_;
serialization_id_ = serialization_id_oss.str();
writeRecord(
kSerializationIdRecordName,

View File

@ -16,7 +16,6 @@
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"
extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
@ -94,7 +93,8 @@ typedef struct mz_zip_archive mz_zip_archive;
namespace caffe2 {
namespace serialize {
static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";
static constexpr const char* kSerializationIdRecordName =
".data/serialization_id";
struct MzZipReaderIterWrapper;
@ -102,9 +102,12 @@ class TORCH_API ChunkRecordIterator {
public:
~ChunkRecordIterator();
// Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
// Read at most `chunkSize` into `buf`. Return the number of actual bytes
// read.
size_t next(void* buf);
size_t recordSize() const { return recordSize_; }
size_t recordSize() const {
return recordSize_;
}
private:
ChunkRecordIterator(
@ -129,13 +132,19 @@ class TORCH_API PyTorchStreamReader final {
// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
// multi-thread getRecord
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
std::tuple<at::DataPtr, size_t> getRecord(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
// inplace memory writing
size_t getRecord(const std::string& name, void* dst, size_t n);
// inplace memory writing, multi-threads.
// When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader
// This approach can be used for reading large tensors.
size_t getRecord(const std::string& name, void* dst, size_t n,
// When additionalReaders is empty, the default behavior is call
// getRecord(name, dst, n) with default reader This approach can be used for
// reading large tensors.
size_t getRecord(
const std::string& name,
void* dst,
size_t n,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
size_t getRecord(
const std::string& name,
@ -143,21 +152,24 @@ class TORCH_API PyTorchStreamReader final {
size_t n,
size_t chunk_size,
void* buf,
const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
const std::function<void(void*, const void*, size_t)>& memcpy_func =
nullptr);
// Concurrent reading records with multiple readers.
// additionalReaders are additional clients to access the underlying record at different offsets
// and write to different trunks of buffers.
// If the overall size of the tensor is 10, and size of additionalReader is 2.
// The default thread will read [0,4), the additional reader will read [4,8).
// The default reader will read [8,10).
// The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8),
// the additional reader will write to buffer[8,10).
// When additionalReaders is empty, the default behavior is call getRecord(name) with default reader
// This approach can be used for reading large tensors.
size_t getRecordMultiReaders(const std::string& name,
// additionalReaders are additional clients to access the underlying record at
// different offsets and write to different trunks of buffers. If the overall
// size of the tensor is 10, and size of additionalReader is 2. The default
// thread will read [0,4), the additional reader will read [4,8). The default
// reader will read [8,10). The default reader will write to buffer[0,4), the
// additional reader will write to buffer[4,8), the additional reader will
// write to buffer[8,10). When additionalReaders is empty, the default
// behavior is call getRecord(name) with default reader This approach can be
// used for reading large tensors.
size_t getRecordMultiReaders(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
void *dst, size_t n);
void* dst,
size_t n);
size_t getRecordSize(const std::string& name);
@ -184,6 +196,7 @@ class TORCH_API PyTorchStreamReader final {
void setAdditionalReaderSizeThreshold(const size_t& size) {
additional_reader_size_threshold_ = size;
}
private:
void init();
size_t read(uint64_t pos, char* buf, size_t n);
@ -205,9 +218,12 @@ class TORCH_API PyTorchStreamReader final {
class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32 = true);
const std::string& archive_name,
bool compute_crc32 = true);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func,
bool compute_crc32 = true);
void setMinVersion(const uint64_t version);

View File

@ -5,9 +5,9 @@
#include <gtest/gtest.h>
#include "caffe2/serialize/inline_container.h"
#include <c10/util/Logging.h>
#include "c10/util/irange.h"
#include "caffe2/serialize/inline_container.h"
namespace caffe2 {
namespace serialize {
@ -77,9 +77,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
// chunked getRecord() test
ret = reader.getRecord(
"key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
memcpy(dst, src, n);
});
"key1",
dst.data(),
size,
3,
buf.data(),
[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
ASSERT_EQ(ret, size);
ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
@ -97,9 +100,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
// chunked getRecord() test
ret = reader.getRecord(
"key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
memcpy(dst, src, n);
});
"key2",
dst.data(),
size,
3,
buf.data(),
[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
ASSERT_EQ(ret, size);
ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
// clean up
@ -107,7 +113,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
}
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
std::ostringstream oss;
// write records through writers
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
@ -361,7 +366,10 @@ TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
});
std::string dup_serialization_id = "dup-serialization-id";
writer.writeRecord(kSerializationIdRecordName, dup_serialization_id.c_str(), dup_serialization_id.size());
writer.writeRecord(
kSerializationIdRecordName,
dup_serialization_id.c_str(),
dup_serialization_id.size());
const std::unordered_set<std::string>& written_records =
writer.getAllWrittenRecords();
@ -415,8 +423,7 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
{"pytorch.stream.reader.metadata",
{{"serialization_id", writer.serializationId()},
{"file_name", "archive"},
{"file_size", str(iss.str().length())}}}
};
{"file_size", str(iss.str().length())}}}};
ASSERT_EQ(expected_logs, logs);
// reset logger
@ -433,7 +440,8 @@ INSTANTIATE_TEST_SUITE_P(
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
auto chunkSize = GetParam();
std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
std::string zipFileName =
"output_chunk_" + std::to_string(chunkSize) + ".zip";
const char* fileName = zipFileName.c_str();
const std::string recordName = "key1";
const size_t tensorDataSizeInBytes = 1000;