Enable concurrent reader for getRecord function (#111426)

Summary:
A Zion-4s core has poor performance when reading a large tensor (e.g. 300 GB), whether the data is downloaded from Manifold or read from local files. In this diff, I changed the getRecord function from single-threaded to multi-threaded: callers can pass multiple readers to getRecord, and each reader accesses a different chunk of the same record.
We control the number of additional readers with the `sigrid_model_manager_additional_reader` flag. The default value is 0. When `additional_reader=2`, we allocate 2 extra read-client threads.
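For illustration, here is a minimal sketch of how the new overload might be driven from C++. The checkpoint path, record name, and threshold below are made-up example values, and `FileAdapter` is just one `ReadAdapterInterface` implementation that could back the extra readers:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"

using caffe2::serialize::FileAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

int main() {
  // Hypothetical checkpoint path and record name, for illustration only.
  const std::string path = "/tmp/model.pt";
  PyTorchStreamReader reader(path);
  // Only fan reads out to extra threads for records of at least 100 MB (example value).
  reader.setAdditionalReaderSizeThreshold(100ull * 1024 * 1024);

  // Two additional readers, each wrapping its own handle to the same file, so
  // three threads can read disjoint chunks of the record concurrently.
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReaders;
  additionalReaders.push_back(std::make_shared<FileAdapter>(path));
  additionalReaders.push_back(std::make_shared<FileAdapter>(path));

  auto [dataPtr, size] = reader.getRecord("data/0", additionalReaders);
  std::cout << "read " << size << " bytes into " << dataPtr.get() << "\n";
  return 0;
}
```

Each extra reader wraps its own handle to the archive, so the worker threads can seek and read independently of the default reader.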
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111426
Approved by: https://github.com/jiayisuse
Zhijing Li (Accelerator Enablement)
2023-11-02 22:07:04 +00:00
committed by PyTorch MergeBot
parent 9d0c3e21d0
commit 12a6f5aa6b
3 changed files with 213 additions and 1 deletion


@@ -8,6 +8,7 @@
#include <sstream>
#include <sys/stat.h>
#include <sys/types.h>
#include <thread>
#include <c10/core/Allocator.h>
#include <c10/core/CPUAllocator.h>
@@ -346,6 +347,84 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}

size_t
PyTorchStreamReader::getRecordMultiReaders(
    const std::string& name,
    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
    void* dst,
    size_t n) {
  size_t nthread = additionalReaders.size() + 1;
  size_t recordOff = getRecordOffset(name);
  std::vector<std::thread> loaderThreads;
  size_t perThreadSize = (n + nthread - 1) / nthread;
  std::vector<size_t> readSizes(nthread, 0);
  std::lock_guard<std::mutex> guard(reader_lock_);
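  // Each thread i reads the half-open chunk [i * perThreadSize, min((i + 1) * perThreadSize, n))
  // of the record; thread 0 uses this reader's own client, while thread i > 0
  // uses additionalReaders[i - 1].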
  for (size_t i = 0; i < nthread; i++) {
    loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes] {
      size_t startPos = i * perThreadSize;
      size_t endPos = std::min((i + 1) * perThreadSize, n);
      if (startPos < endPos) {
        size_t threadReadSize = endPos - startPos;
        size_t size = 0;
        if (i == 0) {
          size = read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
        } else {
          auto reader = additionalReaders[i - 1];
          size = reader->read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
        }
        readSizes[i] = size;
        LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
                  << "from " << name << " of size " << n;
        TORCH_CHECK(
            threadReadSize == size,
            "record size ",
            threadReadSize,
            " mismatch with read size ",
            size);
      }
    });
  }

  for (auto& thread : loaderThreads) {
    thread.join();
  }
  loaderThreads.clear();

  auto total_read_n = std::reduce(readSizes.begin(), readSizes.end());
  TORCH_CHECK(
      n == total_read_n,
      "Multi reader total read size ",
      total_read_n,
      " mismatch with dst size ",
      n);
  return total_read_n;
}
// Read a record using multiple read clients (additional readers).
std::tuple<at::DataPtr, size_t>
PyTorchStreamReader::getRecord(
    const std::string& name,
    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  auto n = stat.m_uncomp_size;
  valid("retrieving file meta-data for ", name.c_str());

  if (additionalReaders.empty() || n < additional_reader_size_threshold_) {
    // No additional readers, or the record is too small; use the single-threaded version.
    return getRecord(name);
  }

  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  void* dst = retval.get();
  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
// inplace memory writing
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
@@ -369,6 +448,34 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
  return stat.m_uncomp_size;
}

// In-place memory writing, multi-threaded within a single tensor; useful for large tensors.
size_t PyTorchStreamReader::getRecord(
    const std::string& name,
    void* dst,
    size_t n,
    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());

  if (additionalReaders.empty() || n < additional_reader_size_threshold_) {
    // No additional readers, or the record is too small; use the single-threaded version.
    return getRecord(name, dst, n);
  }

  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
  return stat.m_uncomp_size;
}

size_t PyTorchStreamReader::getRecord(
    const std::string& name,
    void* dst,


@@ -16,6 +16,7 @@
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"
extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
@@ -126,8 +127,15 @@ class TORCH_API PyTorchStreamReader final {
  // return dataptr, size
  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
  // multi-threaded getRecord using additional readers
  std::tuple<at::DataPtr, size_t> getRecord(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  // in-place memory writing
  size_t getRecord(const std::string& name, void* dst, size_t n);
  // in-place memory writing, multi-threaded.
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name, dst, n) with the default reader.
  // This approach can be used for reading large tensors.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);

  size_t getRecord(
      const std::string& name,
      void* dst,
@@ -136,6 +144,20 @@ class TORCH_API PyTorchStreamReader final {
      void* buf,
      const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);

  // Concurrent reading of a single record with multiple readers.
  // additionalReaders are extra clients that access the underlying record at
  // different offsets and write to different chunks of the destination buffer.
  // For example, if the record size is 10 and there are 2 additional readers
  // (3 threads in total), the chunk size is ceil(10 / 3) = 4:
  // the default reader reads bytes [0, 4) and writes them to buffer[0, 4),
  // the first additional reader reads [4, 8) and writes to buffer[4, 8),
  // and the second additional reader reads [8, 10) and writes to buffer[8, 10).
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name) with the default reader.
  // This approach can be used for reading large tensors.
  size_t getRecordMultiReaders(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
      void* dst,
      size_t n);

  size_t getRecordSize(const std::string& name);
  size_t getRecordOffset(const std::string& name);
@@ -158,7 +180,9 @@ class TORCH_API PyTorchStreamReader final {
  void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
    load_debug_symbol_ = should_load_debug_symbol;
  }
  void setAdditionalReaderSizeThreshold(const size_t& size) {
    additional_reader_size_threshold_ = size;
  }

 private:
  void init();
  size_t read(uint64_t pos, char* buf, size_t n);
@@ -175,6 +199,7 @@ class TORCH_API PyTorchStreamReader final {
  std::mutex reader_lock_;
  bool load_debug_symbol_ = true;
  std::string serialization_id_;
  size_t additional_reader_size_threshold_;
};
class TORCH_API PyTorchStreamWriter final {


@@ -105,6 +105,86 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
  remove(file_name);
}

TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output.zip";
  std::ofstream foo(file_name);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  // read records through pytorchStreamReader
  std::istringstream iss(the_file);
  PyTorchStreamReader reader(&iss);
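  // The threshold is set to 0 so that even these tiny test records take the
  // multi-reader path whenever additional readers are supplied.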
  reader.setAdditionalReaderSizeThreshold(0);
  // before testing, sanity check
  int64_t size1, size2, ret;
  at::DataPtr data_ptr;
  std::tie(data_ptr, size1) = reader.getRecord("key1");
  std::tie(data_ptr, size2) = reader.getRecord("key2");

  // Test getRecord(name, additional_readers)
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
  for (int i = 0; i < 10; ++i) {
    // Test various sized additional readers.
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));

    std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);

    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
  }
  // In-place multi-threaded getRecord(name, dst, n, additional_readers) test
  additionalReader.clear();
  std::vector<uint8_t> dst1(size1), dst2(size2);
  for (int i = 0; i < 10; ++i) {
    // Test various sizes of read threads
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));

    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);

    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
  }

  // clean up
  remove(file_name);
}
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
std::ostringstream oss;
// write records through writers