mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Format caffe2/serialize (#141850)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/141850 Approved by: https://github.com/cpuhrsch
This commit is contained in:
@ -21,7 +21,8 @@ FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
|
||||
auto error_msg =
|
||||
std::system_category().default_error_condition(old_errno).message();
|
||||
#endif
|
||||
TORCH_CHECK(false,
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"open file failed because of errno ",
|
||||
old_errno,
|
||||
" on fopen: ",
|
||||
|
@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
#include <cstring>
|
||||
#include <caffe2/serialize/read_adapter_interface.h>
|
||||
|
||||
#include <cstring>
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
@ -27,6 +26,5 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
|
||||
off_t size_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
@ -1,13 +1,13 @@
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cerrno>
|
||||
#include <istream>
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <algorithm>
|
||||
#include <cerrno>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
@ -53,13 +53,15 @@ size_t ChunkRecordIterator::next(void* buf){
|
||||
if (want_size == 0) {
|
||||
return 0;
|
||||
}
|
||||
size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
|
||||
size_t read_size =
|
||||
mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
|
||||
TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
|
||||
offset_ += read_size;
|
||||
return read_size;
|
||||
}
|
||||
|
||||
size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
|
||||
size_t
|
||||
istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
|
||||
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
|
||||
return self->read(file_ofs, static_cast<char*>(pBuf), n);
|
||||
}
|
||||
@ -157,8 +159,8 @@ void PyTorchStreamReader::init() {
|
||||
mz_zip_reader_init(ar_.get(), size, 0);
|
||||
valid("reading zip archive");
|
||||
|
||||
// figure out the archive_name (i.e. the zip folder all the other files are in)
|
||||
// all lookups to getRecord will be prefixed by this folder
|
||||
// figure out the archive_name (i.e. the zip folder all the other files are
|
||||
// in) all lookups to getRecord will be prefixed by this folder
|
||||
mz_uint n = mz_zip_reader_get_num_files(ar_.get());
|
||||
if (n == 0) {
|
||||
CAFFE_THROW("archive does not contain any files");
|
||||
@ -201,15 +203,15 @@ void PyTorchStreamReader::init() {
|
||||
TORCH_CHECK(hasRecord("version"))
|
||||
std::tie(version_ptr, version_size) = getRecord("version");
|
||||
}
|
||||
std::string version(static_cast<const char*>(version_ptr.get()), version_size);
|
||||
std::string version(
|
||||
static_cast<const char*>(version_ptr.get()), version_size);
|
||||
try {
|
||||
version_ = std::stoull(version);
|
||||
} catch (const std::invalid_argument& e) {
|
||||
CAFFE_THROW("Couldn't parse the version ",
|
||||
version,
|
||||
" as Long Long.");
|
||||
CAFFE_THROW("Couldn't parse the version ", version, " as Long Long.");
|
||||
}
|
||||
if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
|
||||
if (version_ <
|
||||
static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
|
||||
CAFFE_THROW(
|
||||
"Attempted to read a PyTorch file with version ",
|
||||
std::to_string(version_),
|
||||
@ -219,7 +221,8 @@ void PyTorchStreamReader::init() {
|
||||
" with latest version of PyTorch to mitigate this issue.");
|
||||
}
|
||||
|
||||
if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
|
||||
if (version_ >
|
||||
static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
|
||||
CAFFE_THROW(
|
||||
"Attempted to read a PyTorch file with version ",
|
||||
version_,
|
||||
@ -277,12 +280,13 @@ size_t getPadding(
|
||||
padding_buf[3] = (uint8_t)(padding_size >> 8);
|
||||
return padding_size_plus_fbxx;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
bool PyTorchStreamReader::hasRecord(const std::string& name) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
|
||||
if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) &&
|
||||
c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
|
||||
return false;
|
||||
}
|
||||
std::string ss = archive_name_plus_slash_ + name;
|
||||
@ -307,7 +311,8 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
|
||||
for (size_t i = 0; i < num_files; i++) {
|
||||
mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
|
||||
mz_zip_reader_get_filename(
|
||||
ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
|
||||
if (strncmp(
|
||||
buf,
|
||||
archive_name_plus_slash_.data(),
|
||||
@ -319,7 +324,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
|
||||
buf);
|
||||
}
|
||||
if ((load_debug_symbol_) ||
|
||||
(!c10::ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) {
|
||||
(!c10::ends_with(
|
||||
std::string_view(buf + archive_name_plus_slash_.size()),
|
||||
kDebugPklSuffix))) {
|
||||
// NOLINTNEXTLINE(modernize-use-emplace)
|
||||
out.push_back(buf + archive_name_plus_slash_.size());
|
||||
}
|
||||
@ -340,7 +347,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
|
||||
}
|
||||
|
||||
// return dataptr, size
|
||||
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
|
||||
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
|
||||
const std::string& name) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
|
||||
at::DataPtr retval;
|
||||
@ -351,17 +359,18 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
|
||||
mz_zip_reader_file_stat(ar_.get(), key, &stat);
|
||||
valid("retrieving file meta-data for ", name.c_str());
|
||||
at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
|
||||
mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
|
||||
mz_zip_reader_extract_to_mem(
|
||||
ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
|
||||
valid("reading file ", name.c_str());
|
||||
|
||||
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
|
||||
}
|
||||
|
||||
size_t
|
||||
PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
|
||||
size_t PyTorchStreamReader::getRecordMultiReaders(
|
||||
const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
|
||||
void *dst, size_t n){
|
||||
|
||||
void* dst,
|
||||
size_t n) {
|
||||
size_t nthread = additionalReaders.size() + 1;
|
||||
size_t recordOff = getRecordOffset(name);
|
||||
std::vector<std::thread> loaderThreads;
|
||||
@ -369,20 +378,31 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
|
||||
std::vector<size_t> readSizes(nthread, 0);
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
for (size_t i = 0; i < nthread; i++) {
|
||||
loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
|
||||
loaderThreads.emplace_back([this,
|
||||
name,
|
||||
i,
|
||||
n,
|
||||
recordOff,
|
||||
perThreadSize,
|
||||
dst,
|
||||
&additionalReaders,
|
||||
&readSizes] {
|
||||
size_t startPos = i * perThreadSize;
|
||||
size_t endPos = std::min((i + 1) * perThreadSize, n);
|
||||
if (startPos < endPos) {
|
||||
size_t threadReadSize = endPos - startPos;
|
||||
size_t size = 0;
|
||||
if (i == 0) {
|
||||
size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
|
||||
size =
|
||||
read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
|
||||
} else {
|
||||
auto reader = additionalReaders[i - 1];
|
||||
size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
|
||||
size = reader->read(
|
||||
recordOff + startPos, (char*)dst + startPos, threadReadSize);
|
||||
}
|
||||
readSizes[i] = size;
|
||||
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
|
||||
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
|
||||
<< "] "
|
||||
<< "from " << name << " of size " << n;
|
||||
TORCH_CHECK(
|
||||
threadReadSize == size,
|
||||
@ -415,8 +435,8 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
|
||||
}
|
||||
|
||||
// read record with multi clients
|
||||
std::tuple<at::DataPtr, size_t>
|
||||
PyTorchStreamReader::getRecord(const std::string& name,
|
||||
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
|
||||
const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
|
||||
if (additionalReaders.empty()) {
|
||||
// No additional readers or record too small, use single threaded version
|
||||
@ -466,17 +486,20 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
|
||||
return stat.m_uncomp_size;
|
||||
}
|
||||
|
||||
|
||||
// inplace memory writing, in-tensor multi-threads, can be used for large tensor.
|
||||
size_t
|
||||
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
|
||||
// inplace memory writing, in-tensor multi-threads, can be used for large
|
||||
// tensor.
|
||||
size_t PyTorchStreamReader::getRecord(
|
||||
const std::string& name,
|
||||
void* dst,
|
||||
size_t n,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
|
||||
if (additionalReaders.empty()) {
|
||||
// No additional readers, use single threaded version
|
||||
return getRecord(name, dst, n);
|
||||
}
|
||||
|
||||
if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) &&
|
||||
c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
|
||||
return 0;
|
||||
}
|
||||
size_t key = getRecordID(name);
|
||||
@ -577,7 +600,8 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
|
||||
"reading file header");
|
||||
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
|
||||
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
|
||||
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
|
||||
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len +
|
||||
extra_len;
|
||||
}
|
||||
|
||||
size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
|
||||
@ -620,14 +644,16 @@ size_t ostream_write_func(
|
||||
return ret;
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32)
|
||||
: archive_name_(basename(file_name)),
|
||||
compute_crc32_(compute_crc32) {
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
const std::string& file_name,
|
||||
bool compute_crc32)
|
||||
: archive_name_(basename(file_name)), compute_crc32_(compute_crc32) {
|
||||
setup(file_name);
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32)
|
||||
const std::function<size_t(const void*, size_t)> writer_func,
|
||||
bool compute_crc32)
|
||||
: archive_name_("archive"),
|
||||
writer_func_(writer_func),
|
||||
compute_crc32_(compute_crc32) {
|
||||
@ -651,8 +677,10 @@ void PyTorchStreamWriter::setup(const string& file_name) {
|
||||
const std::string dir_name = parentdir(file_name);
|
||||
if (!dir_name.empty()) {
|
||||
struct stat st;
|
||||
bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
|
||||
TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
|
||||
bool dir_exists =
|
||||
(stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
|
||||
TORCH_CHECK(
|
||||
dir_exists, "Parent directory ", dir_name, " does not exist.");
|
||||
}
|
||||
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
|
||||
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
|
||||
@ -732,13 +760,16 @@ void PyTorchStreamWriter::writeEndOfFile() {
|
||||
~Finalizer() {
|
||||
var_ = true;
|
||||
}
|
||||
|
||||
private:
|
||||
bool& var_;
|
||||
} f(finalized_);
|
||||
|
||||
auto allRecords = getAllWrittenRecords();
|
||||
// If no ".data/version" or "version" record in the output model, rewrites version info
|
||||
if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
|
||||
// If no ".data/version" or "version" record in the output model, rewrites
|
||||
// version info
|
||||
if (allRecords.find(".data/version") == allRecords.end() &&
|
||||
allRecords.find("version") == allRecords.end()) {
|
||||
std::string version = std::to_string(version_);
|
||||
version.push_back('\n');
|
||||
if (version_ >= 0x6L) {
|
||||
@ -808,9 +839,8 @@ void PyTorchStreamWriter::writeSerializationId() {
|
||||
}
|
||||
std::ostringstream serialization_id_oss;
|
||||
serialization_id_oss << std::setfill('0') << std::setw(20)
|
||||
<< combined_record_name_hash
|
||||
<< std::setfill('0') << std::setw(20)
|
||||
<< combined_uncomp_crc32_;
|
||||
<< combined_record_name_hash << std::setfill('0')
|
||||
<< std::setw(20) << combined_uncomp_crc32_;
|
||||
serialization_id_ = serialization_id_oss.str();
|
||||
writeRecord(
|
||||
kSerializationIdRecordName,
|
||||
|
@ -16,7 +16,6 @@
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
#include "caffe2/serialize/versions.h"
|
||||
|
||||
|
||||
extern "C" {
|
||||
typedef struct mz_zip_archive mz_zip_archive;
|
||||
}
|
||||
@ -94,7 +93,8 @@ typedef struct mz_zip_archive mz_zip_archive;
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";
|
||||
static constexpr const char* kSerializationIdRecordName =
|
||||
".data/serialization_id";
|
||||
|
||||
struct MzZipReaderIterWrapper;
|
||||
|
||||
@ -102,9 +102,12 @@ class TORCH_API ChunkRecordIterator {
|
||||
public:
|
||||
~ChunkRecordIterator();
|
||||
|
||||
// Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
|
||||
// Read at most `chunkSize` into `buf`. Return the number of actual bytes
|
||||
// read.
|
||||
size_t next(void* buf);
|
||||
size_t recordSize() const { return recordSize_; }
|
||||
size_t recordSize() const {
|
||||
return recordSize_;
|
||||
}
|
||||
|
||||
private:
|
||||
ChunkRecordIterator(
|
||||
@ -129,13 +132,19 @@ class TORCH_API PyTorchStreamReader final {
|
||||
// return dataptr, size
|
||||
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
|
||||
// multi-thread getRecord
|
||||
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
|
||||
std::tuple<at::DataPtr, size_t> getRecord(
|
||||
const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
|
||||
// inplace memory writing
|
||||
size_t getRecord(const std::string& name, void* dst, size_t n);
|
||||
// inplace memory writing, multi-threads.
|
||||
// When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader
|
||||
// This approach can be used for reading large tensors.
|
||||
size_t getRecord(const std::string& name, void* dst, size_t n,
|
||||
// When additionalReaders is empty, the default behavior is call
|
||||
// getRecord(name, dst, n) with default reader This approach can be used for
|
||||
// reading large tensors.
|
||||
size_t getRecord(
|
||||
const std::string& name,
|
||||
void* dst,
|
||||
size_t n,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
|
||||
size_t getRecord(
|
||||
const std::string& name,
|
||||
@ -143,21 +152,24 @@ class TORCH_API PyTorchStreamReader final {
|
||||
size_t n,
|
||||
size_t chunk_size,
|
||||
void* buf,
|
||||
const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
|
||||
const std::function<void(void*, const void*, size_t)>& memcpy_func =
|
||||
nullptr);
|
||||
|
||||
// Concurrent reading records with multiple readers.
|
||||
// additionalReaders are additional clients to access the underlying record at different offsets
|
||||
// and write to different trunks of buffers.
|
||||
// If the overall size of the tensor is 10, and size of additionalReader is 2.
|
||||
// The default thread will read [0,4), the additional reader will read [4,8).
|
||||
// The default reader will read [8,10).
|
||||
// The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8),
|
||||
// the additional reader will write to buffer[8,10).
|
||||
// When additionalReaders is empty, the default behavior is call getRecord(name) with default reader
|
||||
// This approach can be used for reading large tensors.
|
||||
size_t getRecordMultiReaders(const std::string& name,
|
||||
// additionalReaders are additional clients to access the underlying record at
|
||||
// different offsets and write to different trunks of buffers. If the overall
|
||||
// size of the tensor is 10, and size of additionalReader is 2. The default
|
||||
// thread will read [0,4), the additional reader will read [4,8). The default
|
||||
// reader will read [8,10). The default reader will write to buffer[0,4), the
|
||||
// additional reader will write to buffer[4,8), the additional reader will
|
||||
// write to buffer[8,10). When additionalReaders is empty, the default
|
||||
// behavior is call getRecord(name) with default reader This approach can be
|
||||
// used for reading large tensors.
|
||||
size_t getRecordMultiReaders(
|
||||
const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
|
||||
void *dst, size_t n);
|
||||
void* dst,
|
||||
size_t n);
|
||||
|
||||
size_t getRecordSize(const std::string& name);
|
||||
|
||||
@ -184,6 +196,7 @@ class TORCH_API PyTorchStreamReader final {
|
||||
void setAdditionalReaderSizeThreshold(const size_t& size) {
|
||||
additional_reader_size_threshold_ = size;
|
||||
}
|
||||
|
||||
private:
|
||||
void init();
|
||||
size_t read(uint64_t pos, char* buf, size_t n);
|
||||
@ -205,9 +218,12 @@ class TORCH_API PyTorchStreamReader final {
|
||||
|
||||
class TORCH_API PyTorchStreamWriter final {
|
||||
public:
|
||||
explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true);
|
||||
explicit PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32 = true);
|
||||
const std::string& archive_name,
|
||||
bool compute_crc32 = true);
|
||||
explicit PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func,
|
||||
bool compute_crc32 = true);
|
||||
|
||||
void setMinVersion(const uint64_t version);
|
||||
|
||||
|
@ -5,9 +5,9 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "caffe2/serialize/inline_container.h"
|
||||
#include <c10/util/Logging.h>
|
||||
#include "c10/util/irange.h"
|
||||
#include "caffe2/serialize/inline_container.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
@ -77,9 +77,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
|
||||
// chunked getRecord() test
|
||||
ret = reader.getRecord(
|
||||
"key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
|
||||
memcpy(dst, src, n);
|
||||
});
|
||||
"key1",
|
||||
dst.data(),
|
||||
size,
|
||||
3,
|
||||
buf.data(),
|
||||
[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
|
||||
ASSERT_EQ(ret, size);
|
||||
ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
|
||||
|
||||
@ -97,9 +100,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
|
||||
// chunked getRecord() test
|
||||
ret = reader.getRecord(
|
||||
"key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
|
||||
memcpy(dst, src, n);
|
||||
});
|
||||
"key2",
|
||||
dst.data(),
|
||||
size,
|
||||
3,
|
||||
buf.data(),
|
||||
[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
|
||||
ASSERT_EQ(ret, size);
|
||||
ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
|
||||
// clean up
|
||||
@ -107,7 +113,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
}
|
||||
|
||||
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
|
||||
|
||||
std::ostringstream oss;
|
||||
// write records through writers
|
||||
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
||||
@ -361,7 +366,10 @@ TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
|
||||
});
|
||||
|
||||
std::string dup_serialization_id = "dup-serialization-id";
|
||||
writer.writeRecord(kSerializationIdRecordName, dup_serialization_id.c_str(), dup_serialization_id.size());
|
||||
writer.writeRecord(
|
||||
kSerializationIdRecordName,
|
||||
dup_serialization_id.c_str(),
|
||||
dup_serialization_id.size());
|
||||
|
||||
const std::unordered_set<std::string>& written_records =
|
||||
writer.getAllWrittenRecords();
|
||||
@ -415,8 +423,7 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
|
||||
{"pytorch.stream.reader.metadata",
|
||||
{{"serialization_id", writer.serializationId()},
|
||||
{"file_name", "archive"},
|
||||
{"file_size", str(iss.str().length())}}}
|
||||
};
|
||||
{"file_size", str(iss.str().length())}}}};
|
||||
ASSERT_EQ(expected_logs, logs);
|
||||
|
||||
// reset logger
|
||||
@ -433,7 +440,8 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
|
||||
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
|
||||
auto chunkSize = GetParam();
|
||||
std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
|
||||
std::string zipFileName =
|
||||
"output_chunk_" + std::to_string(chunkSize) + ".zip";
|
||||
const char* fileName = zipFileName.c_str();
|
||||
const std::string recordName = "key1";
|
||||
const size_t tensorDataSizeInBytes = 1000;
|
||||
|
Reference in New Issue
Block a user