Format caffe2/serialize (#141850)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141850
Approved by: https://github.com/cpuhrsch
This commit is contained in:
cyy
2024-12-04 01:14:24 +00:00
committed by PyTorch MergeBot
parent 941da90e8a
commit bffaddf9ea
6 changed files with 185 additions and 132 deletions

View File

@ -21,7 +21,8 @@ FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
auto error_msg =
std::system_category().default_error_condition(old_errno).message();
#endif
TORCH_CHECK(false,
TORCH_CHECK(
false,
"open file failed because of errno ",
old_errno,
" on fopen: ",

View File

@ -1,8 +1,8 @@
#pragma once
#include <c10/macros/Macros.h>
#include <fstream>
#include <memory>
#include <c10/macros/Macros.h>
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

View File

@ -1,7 +1,6 @@
#pragma once
#include <cstring>
#include <caffe2/serialize/read_adapter_interface.h>
#include <cstring>
namespace caffe2 {
namespace serialize {
@ -17,7 +16,7 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override {
(void) what;
(void)what;
memcpy(buf, (int8_t*)(data_) + pos, n);
return n;
}
@ -27,6 +26,5 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
off_t size_;
};
} // namespace serialize
} // namespace caffe2

View File

@ -1,13 +1,13 @@
#include <cstdio>
#include <cstring>
#include <cerrno>
#include <istream>
#include <ostream>
#include <fstream>
#include <algorithm>
#include <sstream>
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <ostream>
#include <sstream>
#include <thread>
#include <c10/core/Allocator.h>
@ -48,25 +48,27 @@ ChunkRecordIterator::~ChunkRecordIterator() {
mz_zip_reader_extract_iter_free(iter_->impl);
}
size_t ChunkRecordIterator::next(void* buf){
size_t ChunkRecordIterator::next(void* buf) {
size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
if (want_size == 0) {
return 0;
}
size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
size_t read_size =
mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
offset_ += read_size;
return read_size;
}
size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
size_t
istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
return self->read(file_ofs, static_cast<char*>(pBuf), n);
}
static std::string basename(const std::string& name) {
size_t start = 0;
for(size_t i = 0; i < name.size(); ++i) {
for (size_t i = 0; i < name.size(); ++i) {
if (name[i] == '\\' || name[i] == '/') {
start = i + 1;
}
@ -77,7 +79,7 @@ static std::string basename(const std::string& name) {
}
size_t end = name.size();
for(size_t i = end; i > start; --i) {
for (size_t i = end; i > start; --i) {
if (name[i - 1] == '.') {
end = i - 1;
break;
@ -92,13 +94,13 @@ static std::string parentdir(const std::string& name) {
end = name.find_last_of('\\');
}
#ifdef WIN32
#ifdef WIN32
if (end != std::string::npos && end > 1 && name[end - 1] == ':') {
// This is a Windows root directory, so include the slash in
// the parent directory
end++;
}
#endif
#endif
if (end == std::string::npos) {
return "";
@ -157,8 +159,8 @@ void PyTorchStreamReader::init() {
mz_zip_reader_init(ar_.get(), size, 0);
valid("reading zip archive");
// figure out the archive_name (i.e. the zip folder all the other files are in)
// all lookups to getRecord will be prefixed by this folder
// figure out the archive_name (i.e. the zip folder all the other files are
// in) all lookups to getRecord will be prefixed by this folder
mz_uint n = mz_zip_reader_get_num_files(ar_.get());
if (n == 0) {
CAFFE_THROW("archive does not contain any files");
@ -201,15 +203,15 @@ void PyTorchStreamReader::init() {
TORCH_CHECK(hasRecord("version"))
std::tie(version_ptr, version_size) = getRecord("version");
}
std::string version(static_cast<const char*>(version_ptr.get()), version_size);
std::string version(
static_cast<const char*>(version_ptr.get()), version_size);
try {
version_ = std::stoull(version);
} catch (const std::invalid_argument& e) {
CAFFE_THROW("Couldn't parse the version ",
version,
" as Long Long.");
CAFFE_THROW("Couldn't parse the version ", version, " as Long Long.");
}
if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
if (version_ <
static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
CAFFE_THROW(
"Attempted to read a PyTorch file with version ",
std::to_string(version_),
@ -219,7 +221,8 @@ void PyTorchStreamReader::init() {
" with latest version of PyTorch to mitigate this issue.");
}
if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
if (version_ >
static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
CAFFE_THROW(
"Attempted to read a PyTorch file with version ",
version_,
@ -277,12 +280,13 @@ size_t getPadding(
padding_buf[3] = (uint8_t)(padding_size >> 8);
return padding_size_plus_fbxx;
}
}
} // namespace detail
bool PyTorchStreamReader::hasRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
if ((!load_debug_symbol_) &&
c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
return false;
}
std::string ss = archive_name_plus_slash_ + name;
@ -307,7 +311,8 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
for (size_t i = 0; i < num_files; i++) {
mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
mz_zip_reader_get_filename(
ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
if (strncmp(
buf,
archive_name_plus_slash_.data(),
@ -319,7 +324,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
buf);
}
if ((load_debug_symbol_) ||
(!c10::ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) {
(!c10::ends_with(
std::string_view(buf + archive_name_plus_slash_.size()),
kDebugPklSuffix))) {
// NOLINTNEXTLINE(modernize-use-emplace)
out.push_back(buf + archive_name_plus_slash_.size());
}
@ -340,7 +347,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
}
// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
at::DataPtr retval;
@ -351,45 +359,57 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
mz_zip_reader_file_stat(ar_.get(), key, &stat);
valid("retrieving file meta-data for ", name.c_str());
at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
mz_zip_reader_extract_to_mem(
ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
valid("reading file ", name.c_str());
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
size_t
PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
void *dst, size_t n){
size_t nthread = additionalReaders.size()+1;
size_t PyTorchStreamReader::getRecordMultiReaders(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
void* dst,
size_t n) {
size_t nthread = additionalReaders.size() + 1;
size_t recordOff = getRecordOffset(name);
std::vector<std::thread> loaderThreads;
size_t perThreadSize = (n+nthread-1)/nthread;
size_t perThreadSize = (n + nthread - 1) / nthread;
std::vector<size_t> readSizes(nthread, 0);
std::lock_guard<std::mutex> guard(reader_lock_);
for(size_t i = 0; i < nthread ; i++){
loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
size_t startPos = i*perThreadSize;
size_t endPos = std::min((i+1)*perThreadSize,n);
if (startPos < endPos){
for (size_t i = 0; i < nthread; i++) {
loaderThreads.emplace_back([this,
name,
i,
n,
recordOff,
perThreadSize,
dst,
&additionalReaders,
&readSizes] {
size_t startPos = i * perThreadSize;
size_t endPos = std::min((i + 1) * perThreadSize, n);
if (startPos < endPos) {
size_t threadReadSize = endPos - startPos;
size_t size = 0;
if (i==0){
size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
}else{
auto reader = additionalReaders[i-1];
size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
if (i == 0) {
size =
read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
} else {
auto reader = additionalReaders[i - 1];
size = reader->read(
recordOff + startPos, (char*)dst + startPos, threadReadSize);
}
readSizes[i] = size;
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
<< "from " << name << " of size " << n;
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
<< "] "
<< "from " << name << " of size " << n;
TORCH_CHECK(
threadReadSize == size,
"record size ",
threadReadSize,
" mismatch with read size ",
size);
threadReadSize == size,
"record size ",
threadReadSize,
" mismatch with read size ",
size);
}
});
}
@ -400,7 +420,7 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
loaderThreads.clear();
size_t total_read_n = 0;
for (auto& r : readSizes){
for (auto& r : readSizes) {
total_read_n += r;
}
@ -415,10 +435,10 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
}
// read record with multi clients
std::tuple<at::DataPtr, size_t>
PyTorchStreamReader::getRecord(const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
if(additionalReaders.empty()){
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
if (additionalReaders.empty()) {
// No additional readers or record too small, use single threaded version
return getRecord(name);
}
@ -432,7 +452,7 @@ PyTorchStreamReader::getRecord(const std::string& name,
mz_zip_reader_file_stat(ar_.get(), key, &stat);
auto n = stat.m_uncomp_size;
valid("retrieving file meta-data for ", name.c_str());
if(n < additional_reader_size_threshold_){
if (n < additional_reader_size_threshold_) {
// Reader size too small, use single threaded version
return getRecord(name);
}
@ -466,17 +486,20 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
return stat.m_uncomp_size;
}
// inplace memory writing, in-tensor multi-threads, can be used for large tensor.
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
if(additionalReaders.empty()){
// inplace memory writing, in-tensor multi-threads, can be used for large
// tensor.
size_t PyTorchStreamReader::getRecord(
const std::string& name,
void* dst,
size_t n,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
if (additionalReaders.empty()) {
// No additional readers, use single threaded version
return getRecord(name, dst, n);
}
if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
if ((!load_debug_symbol_) &&
c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
return 0;
}
size_t key = getRecordID(name);
@ -490,7 +513,7 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
n);
valid("retrieving file meta-data for ", name.c_str());
if(n < additional_reader_size_threshold_){
if (n < additional_reader_size_threshold_) {
// Reader size too small, use single threaded version
return getRecord(name, dst, n);
}
@ -577,7 +600,8 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
"reading file header");
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len +
extra_len;
}
size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
@ -620,14 +644,16 @@ size_t ostream_write_func(
return ret;
}
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32)
: archive_name_(basename(file_name)),
compute_crc32_(compute_crc32) {
PyTorchStreamWriter::PyTorchStreamWriter(
const std::string& file_name,
bool compute_crc32)
: archive_name_(basename(file_name)), compute_crc32_(compute_crc32) {
setup(file_name);
}
PyTorchStreamWriter::PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32)
const std::function<size_t(const void*, size_t)> writer_func,
bool compute_crc32)
: archive_name_("archive"),
writer_func_(writer_func),
compute_crc32_(compute_crc32) {
@ -649,10 +675,12 @@ void PyTorchStreamWriter::setup(const string& file_name) {
valid("opening archive ", file_name.c_str());
const std::string dir_name = parentdir(file_name);
if(!dir_name.empty()) {
if (!dir_name.empty()) {
struct stat st;
bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
bool dir_exists =
(stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
TORCH_CHECK(
dir_exists, "Parent directory ", dir_name, " does not exist.");
}
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
@ -728,17 +756,20 @@ void PyTorchStreamWriter::writeEndOfFile() {
// destructor would would result in `std::terminate()`
// See https://github.com/pytorch/pytorch/issues/87997/
struct Finalizer {
Finalizer(bool& var): var_(var) {}
Finalizer(bool& var) : var_(var) {}
~Finalizer() {
var_ = true;
}
private:
bool& var_;
} f(finalized_);
auto allRecords = getAllWrittenRecords();
// If no ".data/version" or "version" record in the output model, rewrites version info
if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
// If no ".data/version" or "version" record in the output model, rewrites
// version info
if (allRecords.find(".data/version") == allRecords.end() &&
allRecords.find("version") == allRecords.end()) {
std::string version = std::to_string(version_);
version.push_back('\n');
if (version_ >= 0x6L) {
@ -749,7 +780,7 @@ void PyTorchStreamWriter::writeEndOfFile() {
}
// If no "byteorder" record in the output model, rewrites byteorder info
if(allRecords.find("byteorder") == allRecords.end()) {
if (allRecords.find("byteorder") == allRecords.end()) {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::string byteorder = "little";
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
@ -808,9 +839,8 @@ void PyTorchStreamWriter::writeSerializationId() {
}
std::ostringstream serialization_id_oss;
serialization_id_oss << std::setfill('0') << std::setw(20)
<< combined_record_name_hash
<< std::setfill('0') << std::setw(20)
<< combined_uncomp_crc32_;
<< combined_record_name_hash << std::setfill('0')
<< std::setw(20) << combined_uncomp_crc32_;
serialization_id_ = serialization_id_oss.str();
writeRecord(
kSerializationIdRecordName,

View File

@ -16,7 +16,6 @@
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"
extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
@ -94,7 +93,8 @@ typedef struct mz_zip_archive mz_zip_archive;
namespace caffe2 {
namespace serialize {
static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";
static constexpr const char* kSerializationIdRecordName =
".data/serialization_id";
struct MzZipReaderIterWrapper;
@ -102,12 +102,15 @@ class TORCH_API ChunkRecordIterator {
public:
~ChunkRecordIterator();
// Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
// Read at most `chunkSize` into `buf`. Return the number of actual bytes
// read.
size_t next(void* buf);
size_t recordSize() const { return recordSize_; }
size_t recordSize() const {
return recordSize_;
}
private:
ChunkRecordIterator(
ChunkRecordIterator(
size_t recordSize,
size_t chunkSize,
std::unique_ptr<MzZipReaderIterWrapper> iter);
@ -129,35 +132,44 @@ class TORCH_API PyTorchStreamReader final {
// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
// multi-thread getRecord
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
std::tuple<at::DataPtr, size_t> getRecord(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
// inplace memory writing
size_t getRecord(const std::string& name, void* dst, size_t n);
// inplace memory writing, multi-threads.
// When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader
// This approach can be used for reading large tensors.
size_t getRecord(const std::string& name, void* dst, size_t n,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
// When additionalReaders is empty, the default behavior is call
// getRecord(name, dst, n) with default reader This approach can be used for
// reading large tensors.
size_t getRecord(
const std::string& name,
void* dst,
size_t n,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
size_t getRecord(
const std::string& name,
void* dst,
size_t n,
size_t chunk_size,
void* buf,
const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
const std::function<void(void*, const void*, size_t)>& memcpy_func =
nullptr);
// Concurrent reading records with multiple readers.
// additionalReaders are additional clients to access the underlying record at different offsets
// and write to different trunks of buffers.
// If the overall size of the tensor is 10, and size of additionalReader is 2.
// The default thread will read [0,4), the additional reader will read [4,8).
// The default reader will read [8,10).
// The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8),
// the additional reader will write to buffer[8,10).
// When additionalReaders is empty, the default behavior is call getRecord(name) with default reader
// This approach can be used for reading large tensors.
size_t getRecordMultiReaders(const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
void *dst, size_t n);
// additionalReaders are additional clients to access the underlying record at
// different offsets and write to different trunks of buffers. If the overall
// size of the tensor is 10, and size of additionalReader is 2. The default
// thread will read [0,4), the additional reader will read [4,8). The default
// reader will read [8,10). The default reader will write to buffer[0,4), the
// additional reader will write to buffer[4,8), the additional reader will
// write to buffer[8,10). When additionalReaders is empty, the default
// behavior is call getRecord(name) with default reader This approach can be
// used for reading large tensors.
size_t getRecordMultiReaders(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
void* dst,
size_t n);
size_t getRecordSize(const std::string& name);
@ -181,9 +193,10 @@ class TORCH_API PyTorchStreamReader final {
void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
load_debug_symbol_ = should_load_debug_symbol;
}
void setAdditionalReaderSizeThreshold(const size_t& size){
void setAdditionalReaderSizeThreshold(const size_t& size) {
additional_reader_size_threshold_ = size;
}
private:
void init();
size_t read(uint64_t pos, char* buf, size_t n);
@ -205,9 +218,12 @@ class TORCH_API PyTorchStreamReader final {
class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32 = true);
const std::string& archive_name,
bool compute_crc32 = true);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func,
bool compute_crc32 = true);
void setMinVersion(const uint64_t version);

View File

@ -5,9 +5,9 @@
#include <gtest/gtest.h>
#include "caffe2/serialize/inline_container.h"
#include <c10/util/Logging.h>
#include "c10/util/irange.h"
#include "caffe2/serialize/inline_container.h"
namespace caffe2 {
namespace serialize {
@ -77,9 +77,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
// chunked getRecord() test
ret = reader.getRecord(
"key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
memcpy(dst, src, n);
});
"key1",
dst.data(),
size,
3,
buf.data(),
[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
ASSERT_EQ(ret, size);
ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
@ -97,9 +100,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
// chunked getRecord() test
ret = reader.getRecord(
"key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
memcpy(dst, src, n);
});
"key2",
dst.data(),
size,
3,
buf.data(),
[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
ASSERT_EQ(ret, size);
ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
// clean up
@ -107,7 +113,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
}
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
std::ostringstream oss;
// write records through writers
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
@ -156,7 +161,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
// Test getRecord(name, additional_readers)
std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
for(int i=0; i<10; ++i){
for (int i = 0; i < 10; ++i) {
// Test various sized additional readers.
std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
ASSERT_EQ(ret, size1);
@ -170,7 +175,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
// Inplace multi-threading getRecord(name, dst, n, additional_readers) test
additionalReader.clear();
std::vector<uint8_t> dst1(size1), dst2(size2);
for(int i=0; i<10; ++i){
for (int i = 0; i < 10; ++i) {
// Test various sizes of read threads
additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
@ -324,7 +329,7 @@ TEST(PytorchStreamWriterAndReader, ValidSerializationId) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
std::array<char, 127> data1;
for (auto i: c10::irange(data1.size())) {
for (auto i : c10::irange(data1.size())) {
data1[i] = data1.size() - i;
}
writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
@ -361,7 +366,10 @@ TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
});
std::string dup_serialization_id = "dup-serialization-id";
writer.writeRecord(kSerializationIdRecordName, dup_serialization_id.c_str(), dup_serialization_id.size());
writer.writeRecord(
kSerializationIdRecordName,
dup_serialization_id.c_str(),
dup_serialization_id.size());
const std::unordered_set<std::string>& written_records =
writer.getAllWrittenRecords();
@ -410,13 +418,12 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
std::map<std::string, std::map<std::string, std::string>> expected_logs = {
{"pytorch.stream.writer.metadata",
{{"serialization_id", writer.serializationId()},
{"file_name", "archive"},
{"file_size", str(oss.str().length())}}},
{"file_name", "archive"},
{"file_size", str(oss.str().length())}}},
{"pytorch.stream.reader.metadata",
{{"serialization_id", writer.serializationId()},
{"file_name", "archive"},
{"file_size", str(iss.str().length())}}}
};
{"file_name", "archive"},
{"file_size", str(iss.str().length())}}}};
ASSERT_EQ(expected_logs, logs);
// reset logger
@ -433,7 +440,8 @@ INSTANTIATE_TEST_SUITE_P(
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
auto chunkSize = GetParam();
std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
std::string zipFileName =
"output_chunk_" + std::to_string(chunkSize) + ".zip";
const char* fileName = zipFileName.c_str();
const std::string recordName = "key1";
const size_t tensorDataSizeInBytes = 1000;