mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable concurrent reader for getRecord function (#111426)
Summary: Zion-4s core has poor perf when it comes to reading the large tensor (e.g. 300G), no matter for manifold downloading or reading from files. In this diff, I changed the getRecord function from single thread to multiple threads by passing multiple readers to getRecord function and access the same record at different chunks with different readers. We control the number of additional reader with the`sigrid_model_manager_additional_reader` flag. The default value is 0. When `additional_reader=2`, we allocate `2` extra read client threads. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111426 Approved by: https://github.com/jiayisuse
This commit is contained in:
committed by
PyTorch MergeBot
parent
9d0c3e21d0
commit
12a6f5aa6b
@ -8,6 +8,7 @@
|
||||
#include <sstream>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <thread>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
@ -346,6 +347,84 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
|
||||
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
|
||||
}
|
||||
|
||||
size_t
|
||||
PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
|
||||
void *dst, size_t n){
|
||||
|
||||
size_t nthread = additionalReaders.size()+1;
|
||||
size_t recordOff = getRecordOffset(name);
|
||||
std::vector<std::thread> loaderThreads;
|
||||
size_t perThreadSize = (n+nthread-1)/nthread;
|
||||
std::vector<size_t> readSizes(nthread, 0);
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
for(size_t i = 0; i < nthread ; i++){
|
||||
loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
|
||||
size_t startPos = i*perThreadSize;
|
||||
size_t endPos = std::min((i+1)*perThreadSize,n);
|
||||
if (startPos < endPos){
|
||||
size_t threadReadSize = endPos - startPos;
|
||||
size_t size = 0;
|
||||
if (i==0){
|
||||
size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
|
||||
}else{
|
||||
auto reader = additionalReaders[i-1];
|
||||
size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
|
||||
}
|
||||
readSizes[i] = size;
|
||||
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
|
||||
<< "from " << name << " of size " << n;
|
||||
TORCH_CHECK(
|
||||
threadReadSize == size,
|
||||
"record size ",
|
||||
threadReadSize,
|
||||
" mismatch with read size ",
|
||||
size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& thread : loaderThreads) {
|
||||
thread.join();
|
||||
}
|
||||
loaderThreads.clear();
|
||||
|
||||
auto total_read_n = std::reduce(readSizes.begin(),readSizes.end());
|
||||
TORCH_CHECK(
|
||||
n == total_read_n,
|
||||
"Multi reader total read size ",
|
||||
total_read_n,
|
||||
" mismatch with dst size ",
|
||||
n);
|
||||
|
||||
return total_read_n;
|
||||
}
|
||||
|
||||
// read record with multi clients
|
||||
std::tuple<at::DataPtr, size_t>
|
||||
PyTorchStreamReader::getRecord(const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
|
||||
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
at::DataPtr retval;
|
||||
return std::make_tuple(std::move(retval), 0);
|
||||
}
|
||||
size_t key = getRecordID(name);
|
||||
mz_zip_archive_file_stat stat;
|
||||
mz_zip_reader_file_stat(ar_.get(), key, &stat);
|
||||
auto n = stat.m_uncomp_size;
|
||||
valid("retrieving file meta-data for ", name.c_str());
|
||||
if(additionalReaders.empty() || n < additional_reader_size_threshold_){
|
||||
// No additional readers or record too small, use single threaded version
|
||||
return getRecord(name);
|
||||
}
|
||||
|
||||
at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
|
||||
void* dst = retval.get();
|
||||
PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
|
||||
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
|
||||
}
|
||||
|
||||
// inplace memory writing
|
||||
size_t
|
||||
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
|
||||
@ -369,6 +448,34 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
|
||||
return stat.m_uncomp_size;
|
||||
}
|
||||
|
||||
|
||||
// inplace memory writing, in-tensor multi-threads, can be used for large tensor.
|
||||
size_t
|
||||
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
return 0;
|
||||
}
|
||||
size_t key = getRecordID(name);
|
||||
mz_zip_archive_file_stat stat;
|
||||
mz_zip_reader_file_stat(ar_.get(), key, &stat);
|
||||
TORCH_CHECK(
|
||||
n == stat.m_uncomp_size,
|
||||
"record size ",
|
||||
stat.m_uncomp_size,
|
||||
" mismatch with dst size ",
|
||||
n);
|
||||
valid("retrieving file meta-data for ", name.c_str());
|
||||
|
||||
if(additionalReaders.empty() || n < additional_reader_size_threshold_){
|
||||
// No additional readers, use single threaded version
|
||||
return getRecord(name, dst, n);
|
||||
}
|
||||
|
||||
PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
|
||||
return stat.m_uncomp_size;
|
||||
}
|
||||
|
||||
size_t PyTorchStreamReader::getRecord(
|
||||
const std::string& name,
|
||||
void* dst,
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
#include "caffe2/serialize/versions.h"
|
||||
|
||||
|
||||
extern "C" {
|
||||
typedef struct mz_zip_archive mz_zip_archive;
|
||||
}
|
||||
@ -126,8 +127,15 @@ class TORCH_API PyTorchStreamReader final {
|
||||
|
||||
// return dataptr, size
|
||||
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
|
||||
// multi-thread getRecord
|
||||
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
|
||||
// inplace memory writing
|
||||
size_t getRecord(const std::string& name, void* dst, size_t n);
|
||||
// inplace memory writing, multi-threads.
|
||||
// When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader
|
||||
// This approach can be used for reading large tensors.
|
||||
size_t getRecord(const std::string& name, void* dst, size_t n,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
|
||||
size_t getRecord(
|
||||
const std::string& name,
|
||||
void* dst,
|
||||
@ -136,6 +144,20 @@ class TORCH_API PyTorchStreamReader final {
|
||||
void* buf,
|
||||
const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
|
||||
|
||||
// Concurrent reading records with multiple readers.
|
||||
// additionalReaders are additional clients to access the underlying record at different offsets
|
||||
// and write to different trunks of buffers.
|
||||
// If the overall size of the tensor is 10, and size of additionalReader is 2.
|
||||
// The default thread will read [0,4), the additional reader will read [4,8).
|
||||
// The default reader will read [8,10).
|
||||
// The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8),
|
||||
// the additional reader will write to buffer[8,10).
|
||||
// When additionalReaders is empty, the default behavior is call getRecord(name) with default reader
|
||||
// This approach can be used for reading large tensors.
|
||||
size_t getRecordMultiReaders(const std::string& name,
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
|
||||
void *dst, size_t n);
|
||||
|
||||
size_t getRecordSize(const std::string& name);
|
||||
|
||||
size_t getRecordOffset(const std::string& name);
|
||||
@ -158,7 +180,9 @@ class TORCH_API PyTorchStreamReader final {
|
||||
void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
|
||||
load_debug_symbol_ = should_load_debug_symbol;
|
||||
}
|
||||
|
||||
void setAdditionalReaderSizeThreshold(const size_t& size){
|
||||
additional_reader_size_threshold_ = size;
|
||||
}
|
||||
private:
|
||||
void init();
|
||||
size_t read(uint64_t pos, char* buf, size_t n);
|
||||
@ -175,6 +199,7 @@ class TORCH_API PyTorchStreamReader final {
|
||||
std::mutex reader_lock_;
|
||||
bool load_debug_symbol_ = true;
|
||||
std::string serialization_id_;
|
||||
size_t additional_reader_size_threshold_;
|
||||
};
|
||||
|
||||
class TORCH_API PyTorchStreamWriter final {
|
||||
|
@ -105,6 +105,86 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
remove(file_name);
|
||||
}
|
||||
|
||||
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
|
||||
|
||||
std::ostringstream oss;
|
||||
// write records through writers
|
||||
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
||||
oss.write(static_cast<const char*>(b), n);
|
||||
return oss ? n : 0;
|
||||
});
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
|
||||
std::array<char, 127> data1;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
|
||||
std::array<char, 64> data2;
|
||||
for (auto i : c10::irange(data1.size())) {
|
||||
data1[i] = data1.size() - i;
|
||||
}
|
||||
writer.writeRecord("key1", data1.data(), data1.size());
|
||||
|
||||
for (auto i : c10::irange(data2.size())) {
|
||||
data2[i] = data2.size() - i;
|
||||
}
|
||||
writer.writeRecord("key2", data2.data(), data2.size());
|
||||
|
||||
const std::unordered_set<std::string>& written_records =
|
||||
writer.getAllWrittenRecords();
|
||||
ASSERT_EQ(written_records.size(), 2);
|
||||
ASSERT_EQ(written_records.count("key1"), 1);
|
||||
ASSERT_EQ(written_records.count("key2"), 1);
|
||||
|
||||
writer.writeEndOfFile();
|
||||
ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
|
||||
|
||||
std::string the_file = oss.str();
|
||||
const char* file_name = "output.zip";
|
||||
std::ofstream foo(file_name);
|
||||
foo.write(the_file.c_str(), the_file.size());
|
||||
foo.close();
|
||||
|
||||
// read records through pytorchStreamReader
|
||||
std::istringstream iss(the_file);
|
||||
PyTorchStreamReader reader(&iss);
|
||||
reader.setAdditionalReaderSizeThreshold(0);
|
||||
// before testing, sanity check
|
||||
int64_t size1, size2, ret;
|
||||
at::DataPtr data_ptr;
|
||||
std::tie(data_ptr, size1) = reader.getRecord("key1");
|
||||
std::tie(data_ptr, size2) = reader.getRecord("key2");
|
||||
|
||||
// Test getRecord(name, additional_readers)
|
||||
std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
|
||||
for(int i=0; i<10; ++i){
|
||||
// Test various sized additional readers.
|
||||
std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
|
||||
ASSERT_EQ(ret, size1);
|
||||
ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);
|
||||
|
||||
std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
|
||||
ASSERT_EQ(ret, size2);
|
||||
ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
|
||||
}
|
||||
|
||||
// Inplace multi-threading getRecord(name, dst, n, additional_readers) test
|
||||
additionalReader.clear();
|
||||
std::vector<uint8_t> dst1(size1), dst2(size2);
|
||||
for(int i=0; i<10; ++i){
|
||||
// Test various sizes of read threads
|
||||
additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
|
||||
|
||||
ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
|
||||
ASSERT_EQ(ret, size1);
|
||||
ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);
|
||||
|
||||
ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
|
||||
ASSERT_EQ(ret, size2);
|
||||
ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
|
||||
}
|
||||
// clean up
|
||||
remove(file_name);
|
||||
}
|
||||
|
||||
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
|
||||
std::ostringstream oss;
|
||||
// write records through writers
|
||||
|
Reference in New Issue
Block a user