diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc index 3839fb5bbb83..84ad13db0c96 100644 --- a/caffe2/serialize/file_adapter.cc +++ b/caffe2/serialize/file_adapter.cc @@ -21,7 +21,8 @@ FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) { auto error_msg = std::system_category().default_error_condition(old_errno).message(); #endif - TORCH_CHECK(false, + TORCH_CHECK( + false, "open file failed because of errno ", old_errno, " on fopen: ", diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h index 36bc5de27100..23307abc904d 100644 --- a/caffe2/serialize/file_adapter.h +++ b/caffe2/serialize/file_adapter.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include "caffe2/serialize/istream_adapter.h" #include "caffe2/serialize/read_adapter_interface.h" diff --git a/caffe2/serialize/in_memory_adapter.h b/caffe2/serialize/in_memory_adapter.h index 817f3f5d6774..f9f5619ac3bc 100644 --- a/caffe2/serialize/in_memory_adapter.h +++ b/caffe2/serialize/in_memory_adapter.h @@ -1,7 +1,6 @@ #pragma once -#include #include - +#include namespace caffe2 { namespace serialize { @@ -17,7 +16,7 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") const override { - (void) what; + (void)what; memcpy(buf, (int8_t*)(data_) + pos, n); return n; } @@ -27,6 +26,5 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { off_t size_; }; - } // namespace serialize } // namespace caffe2 diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index ed2c5eeb6900..2554d25fd2a3 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -1,13 +1,13 @@ -#include -#include -#include -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -48,25 +48,27 @@ ChunkRecordIterator::~ChunkRecordIterator() { mz_zip_reader_extract_iter_free(iter_->impl); } -size_t ChunkRecordIterator::next(void* buf){ +size_t ChunkRecordIterator::next(void* buf) { size_t want_size = std::min(chunkSize_, recordSize_ - offset_); if (want_size == 0) { return 0; } - size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size); + size_t read_size = + mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size); TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0"); offset_ += read_size; return read_size; } -size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) { +size_t +istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) { auto self = static_cast(pOpaque); return self->read(file_ofs, static_cast(pBuf), n); } static std::string basename(const std::string& name) { size_t start = 0; - for(size_t i = 0; i < name.size(); ++i) { + for (size_t i = 0; i < name.size(); ++i) { if (name[i] == '\\' || name[i] == '/') { start = i + 1; } @@ -77,7 +79,7 @@ static std::string basename(const std::string& name) { } size_t end = name.size(); - for(size_t i = end; i > start; --i) { + for (size_t i = end; i > start; --i) { if (name[i - 1] == '.') { end = i - 1; break; @@ -92,13 +94,13 @@ static std::string parentdir(const std::string& name) { end = name.find_last_of('\\'); } - #ifdef WIN32 +#ifdef WIN32 if (end != std::string::npos && end > 1 && name[end - 1] == ':') { // This is a Windows root directory, so include the slash in // the parent directory end++; } - #endif +#endif if (end == std::string::npos) { return ""; @@ -157,8 +159,8 @@ void PyTorchStreamReader::init() { mz_zip_reader_init(ar_.get(), size, 0); valid("reading zip archive"); - // figure out the archive_name (i.e. the zip folder all the other files are in) - // all lookups to getRecord will be prefixed by this folder + // figure out the archive_name (i.e. the zip folder all the other files are + // in) all lookups to getRecord will be prefixed by this folder mz_uint n = mz_zip_reader_get_num_files(ar_.get()); if (n == 0) { CAFFE_THROW("archive does not contain any files"); @@ -201,15 +203,15 @@ void PyTorchStreamReader::init() { TORCH_CHECK(hasRecord("version")) std::tie(version_ptr, version_size) = getRecord("version"); } - std::string version(static_cast(version_ptr.get()), version_size); + std::string version( + static_cast(version_ptr.get()), version_size); try { version_ = std::stoull(version); } catch (const std::invalid_argument& e) { - CAFFE_THROW("Couldn't parse the version ", - version, - " as Long Long."); + CAFFE_THROW("Couldn't parse the version ", version, " as Long Long."); } - if (version_ < static_cast(kMinSupportedFileFormatVersion)) { + if (version_ < + static_cast(kMinSupportedFileFormatVersion)) { CAFFE_THROW( "Attempted to read a PyTorch file with version ", std::to_string(version_), @@ -219,7 +221,8 @@ void PyTorchStreamReader::init() { " with latest version of PyTorch to mitigate this issue."); } - if (version_ > static_cast(kMaxSupportedFileFormatVersion)) { + if (version_ > + static_cast(kMaxSupportedFileFormatVersion)) { CAFFE_THROW( "Attempted to read a PyTorch file with version ", version_, @@ -277,12 +280,13 @@ size_t getPadding( padding_buf[3] = (uint8_t)(padding_size >> 8); return padding_size_plus_fbxx; } -} +} // namespace detail bool PyTorchStreamReader::hasRecord(const std::string& name) { std::lock_guard guard(reader_lock_); - if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) { + if ((!load_debug_symbol_) && + c10::ends_with(std::string_view(name), kDebugPklSuffix)) { return false; } std::string ss = archive_name_plus_slash_ + name; @@ -307,7 +311,8 @@ std::vector PyTorchStreamReader::getAllRecords() { // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE]; for (size_t i = 0; i < num_files; i++) { - mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE); + mz_zip_reader_get_filename( + ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE); if (strncmp( buf, archive_name_plus_slash_.data(), @@ -319,7 +324,9 @@ std::vector PyTorchStreamReader::getAllRecords() { buf); } if ((load_debug_symbol_) || - (!c10::ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) { + (!c10::ends_with( + std::string_view(buf + archive_name_plus_slash_.size()), + kDebugPklSuffix))) { // NOLINTNEXTLINE(modernize-use-emplace) out.push_back(buf + archive_name_plus_slash_.size()); } @@ -340,7 +347,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) { } // return dataptr, size -std::tuple PyTorchStreamReader::getRecord(const std::string& name) { +std::tuple PyTorchStreamReader::getRecord( + const std::string& name) { std::lock_guard guard(reader_lock_); if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) { at::DataPtr retval; @@ -351,45 +359,57 @@ std::tuple PyTorchStreamReader::getRecord(const std::string mz_zip_reader_file_stat(ar_.get(), key, &stat); valid("retrieving file meta-data for ", name.c_str()); at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size); - mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0); + mz_zip_reader_extract_to_mem( + ar_.get(), key, retval.get(), stat.m_uncomp_size, 0); valid("reading file ", name.c_str()); return std::make_tuple(std::move(retval), stat.m_uncomp_size); } -size_t -PyTorchStreamReader::getRecordMultiReaders(const std::string& name, - std::vector>& additionalReaders, - void *dst, size_t n){ - - size_t nthread = additionalReaders.size()+1; +size_t PyTorchStreamReader::getRecordMultiReaders( + const std::string& name, + std::vector>& additionalReaders, + void* dst, + size_t n) { + size_t nthread = additionalReaders.size() + 1; size_t recordOff = getRecordOffset(name); std::vector loaderThreads; - size_t perThreadSize = (n+nthread-1)/nthread; + size_t perThreadSize = (n + nthread - 1) / nthread; std::vector readSizes(nthread, 0); std::lock_guard guard(reader_lock_); - for(size_t i = 0; i < nthread ; i++){ - loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{ - size_t startPos = i*perThreadSize; - size_t endPos = std::min((i+1)*perThreadSize,n); - if (startPos < endPos){ + for (size_t i = 0; i < nthread; i++) { + loaderThreads.emplace_back([this, + name, + i, + n, + recordOff, + perThreadSize, + dst, + &additionalReaders, + &readSizes] { + size_t startPos = i * perThreadSize; + size_t endPos = std::min((i + 1) * perThreadSize, n); + if (startPos < endPos) { size_t threadReadSize = endPos - startPos; size_t size = 0; - if (i==0){ - size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize); - }else{ - auto reader = additionalReaders[i-1]; - size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize); + if (i == 0) { + size = + read(recordOff + startPos, (char*)dst + startPos, threadReadSize); + } else { + auto reader = additionalReaders[i - 1]; + size = reader->read( + recordOff + startPos, (char*)dst + startPos, threadReadSize); } readSizes[i] = size; - LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] " - << "from " << name << " of size " << n; + LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos + << "] " + << "from " << name << " of size " << n; TORCH_CHECK( - threadReadSize == size, - "record size ", - threadReadSize, - " mismatch with read size ", - size); + threadReadSize == size, + "record size ", + threadReadSize, + " mismatch with read size ", + size); } }); } @@ -400,7 +420,7 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name, loaderThreads.clear(); size_t total_read_n = 0; - for (auto& r : readSizes){ + for (auto& r : readSizes) { total_read_n += r; } @@ -415,10 +435,10 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name, } // read record with multi clients -std::tuple -PyTorchStreamReader::getRecord(const std::string& name, - std::vector>& additionalReaders) { - if(additionalReaders.empty()){ +std::tuple PyTorchStreamReader::getRecord( + const std::string& name, + std::vector>& additionalReaders) { + if (additionalReaders.empty()) { // No additional readers or record too small, use single threaded version return getRecord(name); } @@ -432,7 +452,7 @@ PyTorchStreamReader::getRecord(const std::string& name, mz_zip_reader_file_stat(ar_.get(), key, &stat); auto n = stat.m_uncomp_size; valid("retrieving file meta-data for ", name.c_str()); - if(n < additional_reader_size_threshold_){ + if (n < additional_reader_size_threshold_) { // Reader size too small, use single threaded version return getRecord(name); } @@ -466,17 +486,20 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) { return stat.m_uncomp_size; } - -// inplace memory writing, in-tensor multi-threads, can be used for large tensor. -size_t -PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n, - std::vector>& additionalReaders) { - if(additionalReaders.empty()){ +// inplace memory writing, in-tensor multi-threads, can be used for large +// tensor. +size_t PyTorchStreamReader::getRecord( + const std::string& name, + void* dst, + size_t n, + std::vector>& additionalReaders) { + if (additionalReaders.empty()) { // No additional readers, use single threaded version return getRecord(name, dst, n); } - if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) { + if ((!load_debug_symbol_) && + c10::ends_with(std::string_view(name), kDebugPklSuffix)) { return 0; } size_t key = getRecordID(name); @@ -490,7 +513,7 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n, n); valid("retrieving file meta-data for ", name.c_str()); - if(n < additional_reader_size_threshold_){ + if (n < additional_reader_size_threshold_) { // Reader size too small, use single threaded version return getRecord(name, dst, n); } @@ -577,7 +600,8 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) { "reading file header"); size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS); size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); - return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len; + return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + + extra_len; } size_t PyTorchStreamReader::getRecordSize(const std::string& name) { @@ -620,14 +644,16 @@ size_t ostream_write_func( return ret; } -PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32) - : archive_name_(basename(file_name)), - compute_crc32_(compute_crc32) { +PyTorchStreamWriter::PyTorchStreamWriter( + const std::string& file_name, + bool compute_crc32) + : archive_name_(basename(file_name)), compute_crc32_(compute_crc32) { setup(file_name); } PyTorchStreamWriter::PyTorchStreamWriter( - const std::function writer_func, bool compute_crc32) + const std::function writer_func, + bool compute_crc32) : archive_name_("archive"), writer_func_(writer_func), compute_crc32_(compute_crc32) { @@ -649,10 +675,12 @@ void PyTorchStreamWriter::setup(const string& file_name) { valid("opening archive ", file_name.c_str()); const std::string dir_name = parentdir(file_name); - if(!dir_name.empty()) { + if (!dir_name.empty()) { struct stat st; - bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR)); - TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist."); + bool dir_exists = + (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR)); + TORCH_CHECK( + dir_exists, "Parent directory ", dir_name, " does not exist."); } TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened."); writer_func_ = [this](const void* buf, size_t nbytes) -> size_t { @@ -728,17 +756,20 @@ void PyTorchStreamWriter::writeEndOfFile() { // destructor would would result in `std::terminate()` // See https://github.com/pytorch/pytorch/issues/87997/ struct Finalizer { - Finalizer(bool& var): var_(var) {} + Finalizer(bool& var) : var_(var) {} ~Finalizer() { var_ = true; } + private: bool& var_; } f(finalized_); auto allRecords = getAllWrittenRecords(); - // If no ".data/version" or "version" record in the output model, rewrites version info - if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) { + // If no ".data/version" or "version" record in the output model, rewrites + // version info + if (allRecords.find(".data/version") == allRecords.end() && + allRecords.find("version") == allRecords.end()) { std::string version = std::to_string(version_); version.push_back('\n'); if (version_ >= 0x6L) { @@ -749,7 +780,7 @@ void PyTorchStreamWriter::writeEndOfFile() { } // If no "byteorder" record in the output model, rewrites byteorder info - if(allRecords.find("byteorder") == allRecords.end()) { + if (allRecords.find("byteorder") == allRecords.end()) { #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ std::string byteorder = "little"; #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -808,9 +839,8 @@ void PyTorchStreamWriter::writeSerializationId() { } std::ostringstream serialization_id_oss; serialization_id_oss << std::setfill('0') << std::setw(20) - << combined_record_name_hash - << std::setfill('0') << std::setw(20) - << combined_uncomp_crc32_; + << combined_record_name_hash << std::setfill('0') + << std::setw(20) << combined_uncomp_crc32_; serialization_id_ = serialization_id_oss.str(); writeRecord( kSerializationIdRecordName, diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 55a723f3b891..59e0991399a3 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -16,7 +16,6 @@ #include "caffe2/serialize/read_adapter_interface.h" #include "caffe2/serialize/versions.h" - extern "C" { typedef struct mz_zip_archive mz_zip_archive; } @@ -94,7 +93,8 @@ typedef struct mz_zip_archive mz_zip_archive; namespace caffe2 { namespace serialize { -static constexpr const char* kSerializationIdRecordName = ".data/serialization_id"; +static constexpr const char* kSerializationIdRecordName = + ".data/serialization_id"; struct MzZipReaderIterWrapper; @@ -102,12 +102,15 @@ class TORCH_API ChunkRecordIterator { public: ~ChunkRecordIterator(); - // Read at most `chunkSize` into `buf`. Return the number of actual bytes read. + // Read at most `chunkSize` into `buf`. Return the number of actual bytes + // read. size_t next(void* buf); - size_t recordSize() const { return recordSize_; } + size_t recordSize() const { + return recordSize_; + } private: - ChunkRecordIterator( + ChunkRecordIterator( size_t recordSize, size_t chunkSize, std::unique_ptr iter); @@ -129,35 +132,44 @@ class TORCH_API PyTorchStreamReader final { // return dataptr, size std::tuple getRecord(const std::string& name); // multi-thread getRecord - std::tuple getRecord(const std::string& name, std::vector>& additionalReaders); + std::tuple getRecord( + const std::string& name, + std::vector>& additionalReaders); // inplace memory writing size_t getRecord(const std::string& name, void* dst, size_t n); // inplace memory writing, multi-threads. - // When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader - // This approach can be used for reading large tensors. - size_t getRecord(const std::string& name, void* dst, size_t n, - std::vector>& additionalReaders); + // When additionalReaders is empty, the default behavior is call + // getRecord(name, dst, n) with default reader This approach can be used for + // reading large tensors. + size_t getRecord( + const std::string& name, + void* dst, + size_t n, + std::vector>& additionalReaders); size_t getRecord( const std::string& name, void* dst, size_t n, size_t chunk_size, void* buf, - const std::function& memcpy_func = nullptr); + const std::function& memcpy_func = + nullptr); // Concurrent reading records with multiple readers. - // additionalReaders are additional clients to access the underlying record at different offsets - // and write to different trunks of buffers. - // If the overall size of the tensor is 10, and size of additionalReader is 2. - // The default thread will read [0,4), the additional reader will read [4,8). - // The default reader will read [8,10). - // The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8), - // the additional reader will write to buffer[8,10). - // When additionalReaders is empty, the default behavior is call getRecord(name) with default reader - // This approach can be used for reading large tensors. - size_t getRecordMultiReaders(const std::string& name, - std::vector>& additionalReaders, - void *dst, size_t n); + // additionalReaders are additional clients to access the underlying record at + // different offsets and write to different trunks of buffers. If the overall + // size of the tensor is 10, and size of additionalReader is 2. The default + // thread will read [0,4), the additional reader will read [4,8). The default + // reader will read [8,10). The default reader will write to buffer[0,4), the + // additional reader will write to buffer[4,8), the additional reader will + // write to buffer[8,10). When additionalReaders is empty, the default + // behavior is call getRecord(name) with default reader This approach can be + // used for reading large tensors. + size_t getRecordMultiReaders( + const std::string& name, + std::vector>& additionalReaders, + void* dst, + size_t n); size_t getRecordSize(const std::string& name); @@ -181,9 +193,10 @@ class TORCH_API PyTorchStreamReader final { void setShouldLoadDebugSymbol(bool should_load_debug_symbol) { load_debug_symbol_ = should_load_debug_symbol; } - void setAdditionalReaderSizeThreshold(const size_t& size){ + void setAdditionalReaderSizeThreshold(const size_t& size) { additional_reader_size_threshold_ = size; } + private: void init(); size_t read(uint64_t pos, char* buf, size_t n); @@ -205,9 +218,12 @@ class TORCH_API PyTorchStreamReader final { class TORCH_API PyTorchStreamWriter final { public: - explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true); explicit PyTorchStreamWriter( - const std::function writer_func, bool compute_crc32 = true); + const std::string& archive_name, + bool compute_crc32 = true); + explicit PyTorchStreamWriter( + const std::function writer_func, + bool compute_crc32 = true); void setMinVersion(const uint64_t version); diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 4e027f681961..1fb19c88d330 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -5,9 +5,9 @@ #include -#include "caffe2/serialize/inline_container.h" #include #include "c10/util/irange.h" +#include "caffe2/serialize/inline_container.h" namespace caffe2 { namespace serialize { @@ -77,9 +77,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0); // chunked getRecord() test ret = reader.getRecord( - "key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) { - memcpy(dst, src, n); - }); + "key1", + dst.data(), + size, + 3, + buf.data(), + [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); }); ASSERT_EQ(ret, size); ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0); @@ -97,9 +100,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0); // chunked getRecord() test ret = reader.getRecord( - "key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) { - memcpy(dst, src, n); - }); + "key2", + dst.data(), + size, + 3, + buf.data(), + [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); }); ASSERT_EQ(ret, size); ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0); // clean up @@ -107,7 +113,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { } TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) { - std::ostringstream oss; // write records through writers PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { @@ -156,7 +161,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) { // Test getRecord(name, additional_readers) std::vector> additionalReader; - for(int i=0; i<10; ++i){ + for (int i = 0; i < 10; ++i) { // Test various sized additional readers. std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader); ASSERT_EQ(ret, size1); @@ -170,7 +175,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) { // Inplace multi-threading getRecord(name, dst, n, additional_readers) test additionalReader.clear(); std::vector dst1(size1), dst2(size2); - for(int i=0; i<10; ++i){ + for (int i = 0; i < 10; ++i) { // Test various sizes of read threads additionalReader.push_back(std::make_unique(&iss)); @@ -324,7 +329,7 @@ TEST(PytorchStreamWriterAndReader, ValidSerializationId) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) std::array data1; - for (auto i: c10::irange(data1.size())) { + for (auto i : c10::irange(data1.size())) { data1[i] = data1.size() - i; } writer.writeRecord("key1.debug_pkl", data1.data(), data1.size()); @@ -361,7 +366,10 @@ TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) { }); std::string dup_serialization_id = "dup-serialization-id"; - writer.writeRecord(kSerializationIdRecordName, dup_serialization_id.c_str(), dup_serialization_id.size()); + writer.writeRecord( + kSerializationIdRecordName, + dup_serialization_id.c_str(), + dup_serialization_id.size()); const std::unordered_set& written_records = writer.getAllWrittenRecords(); @@ -410,13 +418,12 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) { std::map> expected_logs = { {"pytorch.stream.writer.metadata", {{"serialization_id", writer.serializationId()}, - {"file_name", "archive"}, - {"file_size", str(oss.str().length())}}}, + {"file_name", "archive"}, + {"file_size", str(oss.str().length())}}}, {"pytorch.stream.reader.metadata", {{"serialization_id", writer.serializationId()}, - {"file_name", "archive"}, - {"file_size", str(iss.str().length())}}} - }; + {"file_name", "archive"}, + {"file_size", str(iss.str().length())}}}}; ASSERT_EQ(expected_logs, logs); // reset logger @@ -433,7 +440,8 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(ChunkRecordIteratorTest, ChunkRead) { auto chunkSize = GetParam(); - std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip"; + std::string zipFileName = + "output_chunk_" + std::to_string(chunkSize) + ".zip"; const char* fileName = zipFileName.c_str(); const std::string recordName = "key1"; const size_t tensorDataSizeInBytes = 1000;