Update PyTorchStreamReader API to take cpu allocator override (#150439)
Summary: Add allocator param in getRecord

Test Plan: newly added UT
```
buck test caffe2/caffe2/serialize:inline_container_test
```

Differential Revision: D72252585

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150439
Approved by: https://github.com/albanD
commit a6182903cd (parent b434322075)
committed by PyTorch MergeBot
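For context, here is a minimal caller-side sketch of the new override. It is not part of the diff; the archive path and record name are illustrative assumptions.

```cpp
#include <string>

#include "c10/core/CPUAllocator.h"
#include "caffe2/serialize/inline_container.h"

// Reads one record twice: once with the default CPU allocator and once with a
// caller-supplied allocator (here the default allocator stands in for a custom one).
void readWithAllocatorOverride(const std::string& archive_path) {
  caffe2::serialize::PyTorchStreamReader reader(archive_path);

  // Existing call sites keep compiling: the allocator parameter defaults to std::nullopt.
  auto [default_ptr, default_size] = reader.getRecord("data.pkl");

  // New: pass any c10::Allocator* to own the returned buffer.
  c10::Allocator* custom_allocator = c10::GetCPUAllocator();  // stand-in for a custom allocator
  auto [custom_ptr, custom_size] = reader.getRecord("data.pkl", custom_allocator);
}
```

The diff follows.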
```diff
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -361,7 +361,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
 
 // return dataptr, size
 std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
-    const std::string& name) {
+    const std::string& name,
+    std::optional<at::Allocator*> allocator) {
   std::lock_guard<std::mutex> guard(reader_lock_);
   if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
     at::DataPtr retval;
@@ -371,7 +372,9 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
   mz_zip_archive_file_stat stat;
   mz_zip_reader_file_stat(ar_.get(), key, &stat);
   valid("retrieving file meta-data for ", name.c_str());
-  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
+  at::Allocator* allocatorPtr =
+      allocator.has_value() ? allocator.value() : c10::GetCPUAllocator();
+  at::DataPtr retval = allocatorPtr->allocate(stat.m_uncomp_size);
   mz_zip_reader_extract_to_mem(
       ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
   valid("reading file ", name.c_str());
@@ -449,10 +452,11 @@ size_t PyTorchStreamReader::getRecordMultiReaders(
 // read record with multi clients
 std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
     const std::string& name,
-    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
+    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
+    std::optional<at::Allocator*> allocator) {
   if (additionalReaders.empty()) {
     // No additional readers or record too small, use single threaded version
-    return getRecord(name);
+    return getRecord(name, allocator);
   }
 
   if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
@@ -469,7 +473,9 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
     return getRecord(name);
   }
 
-  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
+  at::Allocator* allocatorPtr =
+      allocator.has_value() ? allocator.value() : c10::GetCPUAllocator();
+  at::DataPtr retval = allocatorPtr->allocate(stat.m_uncomp_size);
   void* dst = retval.get();
   PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
   return std::make_tuple(std::move(retval), stat.m_uncomp_size);
@@ -760,11 +766,7 @@ void PyTorchStreamWriter::writeRecord(
   }
   std::string full_name = archive_name_plus_slash_ + name;
   size_t padding_size = detail::getPadding(
-      ar_->m_archive_size,
-      full_name.size(),
-      size,
-      padding_,
-      alignment_);
+      ar_->m_archive_size, full_name.size(), size, padding_, alignment_);
   uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
   if (!compute_crc32_) {
 #if (!defined(FBCODE_CAFFE2))
```
```diff
--- a/caffe2/serialize/inline_container.h
+++ b/caffe2/serialize/inline_container.h
@@ -130,11 +130,15 @@ class TORCH_API PyTorchStreamReader final {
   explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);
 
   // return dataptr, size
-  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
+  // set allocator to override default cpu allocator
+  std::tuple<at::DataPtr, size_t> getRecord(
+      const std::string& name,
+      std::optional<at::Allocator*> allocator = std::nullopt);
   // multi-thread getRecord
   std::tuple<at::DataPtr, size_t> getRecord(
       const std::string& name,
-      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
+      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
+      std::optional<at::Allocator*> allocator = std::nullopt);
   // inplace memory writing
   size_t getRecord(const std::string& name, void* dst, size_t n);
   // inplace memory writing, multi-threads.
```
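Any subclass of at::Allocator can be supplied through the new parameter. The sketch below is an illustrative assumption, not part of the PR: a wrapper that zero-fills buffers on top of the default CPU allocator. The test changes that follow add a similar wrapper, TestAllocator, which counts allocated bytes instead.

```cpp
#include <cstring>

#include "c10/core/CPUAllocator.h"

// Hypothetical allocator override: delegates to the default CPU allocator but
// zero-fills every buffer before returning it.
class ZeroFillAllocator final : public c10::Allocator {
 public:
  c10::DataPtr allocate(size_t nbytes) override {
    c10::DataPtr ptr = c10::GetCPUAllocator()->allocate(nbytes);
    if (nbytes > 0) {
      std::memset(ptr.get(), 0, nbytes);
    }
    return ptr;
  }
  c10::DeleterFnPtr raw_deleter() const override {
    return c10::GetCPUAllocator()->raw_deleter();
  }
  void copy_data(void* dest, const void* src, std::size_t count) const override {
    default_copy_data(dest, src, count);
  }
};
```

An instance of such a wrapper can then be passed as `reader.getRecord(name, &zero_fill)`.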
```diff
--- a/caffe2/serialize/inline_container_test.cc
+++ b/caffe2/serialize/inline_container_test.cc
@@ -6,6 +6,7 @@
 #include <gtest/gtest.h>
 
 #include <c10/util/Logging.h>
+#include "c10/core/CPUAllocator.h"
 #include "c10/util/irange.h"
 #include "caffe2/serialize/inline_container.h"
 
@@ -432,6 +433,199 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
       const std::map<std::string, std::string>& metadata_map) {});
 }
 
+class TestAllocator : public at::Allocator {
+ public:
+  explicit TestAllocator(at::Allocator* allocator): baseAllocator_(allocator) {}
+  at::DataPtr allocate(size_t nbytes) override {
+    allocatedBytes_ += nbytes;
+    return baseAllocator_->allocate(nbytes);
+  }
+  at::DeleterFnPtr raw_deleter() const override {
+    return baseAllocator_->raw_deleter();
+  }
+  void copy_data(void* dest, const void* src, std::size_t count) const override {
+    default_copy_data(dest, src, count);
+  }
+  size_t getAllocatedBytes() {
+    return allocatedBytes_;
+  }
+ private:
+  at::Allocator* baseAllocator_;
+  size_t allocatedBytes_{0};
+};
+
+TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
+  // create two test allocators: one is supposed to be the default allocator,
+  // the other one is only used when the user specifies it
+  auto defaultAllocator = at::GetCPUAllocator();
+  TestAllocator overrideAllocator(defaultAllocator);
+  TestAllocator baseAllocator(defaultAllocator);
+  c10::SetCPUAllocator(&baseAllocator, 10 /* priority */);
+
+  std::ostringstream oss;
+  // write records through writers
+  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
+    oss.write(static_cast<const char*>(b), n);
+    return oss ? n : 0;
+  });
+  const size_t kBytes1 = 127;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
+  std::array<char, kBytes1> data1;
+  // Inplace memory buffer
+  std::vector<uint8_t> buf(data1.size());
+
+  for (auto i : c10::irange(data1.size())) {
+    data1[i] = data1.size() - i;
+  }
+  writer.writeRecord("key1", data1.data(), data1.size());
+
+  const size_t kBytes2 = 64;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
+  std::array<char, kBytes2> data2;
+  for (auto i : c10::irange(data2.size())) {
+    data2[i] = data2.size() - i;
+  }
+  writer.writeRecord("key2", data2.data(), data2.size());
+
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
+
+  writer.writeEndOfFile();
+  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
+
+  std::string the_file = oss.str();
+  const char* file_name = "output.zip";
+  std::ofstream foo(file_name);
+  foo.write(the_file.c_str(), the_file.size());
+  foo.close();
+
+  std::istringstream iss(the_file);
+
+  // read records through readers
+  PyTorchStreamReader reader(&iss);
+  ASSERT_TRUE(reader.hasRecord("key1"));
+  ASSERT_TRUE(reader.hasRecord("key2"));
+  ASSERT_FALSE(reader.hasRecord("key2000"));
+  // get the bytes allocated before read
+  const auto allocBytes = baseAllocator.getAllocatedBytes();
+  at::DataPtr data_ptr;
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  int64_t size;
+  // allocated with override allocator
+  std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
+  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
+  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
+  // allocate with base allocator
+  std::tie(data_ptr, size) = reader.getRecord("key1");
+  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
+  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);
+
+  std::tie(data_ptr, size) = reader.getRecord("key2", &overrideAllocator);
+  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1 + kBytes2);
+  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);
+  std::tie(data_ptr, size) = reader.getRecord("key2");
+  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1 + kBytes2);
+  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1 + kBytes2);
+  std::tie(data_ptr, size) = reader.getRecord("key2", &baseAllocator);
+  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1 + 2 * kBytes2);
+}
+
+
+TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreadsWithAllocator) {
+  auto defaultAllocator = at::GetCPUAllocator();
+  TestAllocator overrideAllocator(defaultAllocator);
+  TestAllocator baseAllocator(defaultAllocator);
+  c10::SetCPUAllocator(&baseAllocator, 10 /* priority */);
+  std::ostringstream oss;
+  // write records through writers
+  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
+    oss.write(static_cast<const char*>(b), n);
+    return oss ? n : 0;
+  });
+
+  const size_t kBytes1 = 127;
+  const size_t kBytes2 = 64;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
+  std::array<char, kBytes1> data1;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
+  std::array<char, kBytes2> data2;
+  for (auto i : c10::irange(data1.size())) {
+    data1[i] = data1.size() - i;
+  }
+  writer.writeRecord("key1", data1.data(), data1.size());
+
+  for (auto i : c10::irange(data2.size())) {
+    data2[i] = data2.size() - i;
+  }
+  writer.writeRecord("key2", data2.data(), data2.size());
+
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
+
+  writer.writeEndOfFile();
+  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
+
+  std::string the_file = oss.str();
+  const char* file_name = "output.zip";
+  std::ofstream foo(file_name);
+  foo.write(the_file.c_str(), the_file.size());
+  foo.close();
+
+  // read records through pytorchStreamReader
+  std::istringstream iss(the_file);
+  PyTorchStreamReader reader(&iss);
+  reader.setAdditionalReaderSizeThreshold(0);
+  // before testing, sanity check
+  int64_t size1, size2, ret;
+  at::DataPtr data_ptr;
+  std::tie(data_ptr, size1) = reader.getRecord("key1");
+  std::tie(data_ptr, size2) = reader.getRecord("key2");
+
+  // Test getRecord(name, additional_readers)
+  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
+  size_t allocatedBytes = 0;
+  auto baseAllocBytes = baseAllocator.getAllocatedBytes();
+  for (int i = 0; i < 10; ++i) {
+    // Test various sized additional readers.
+    std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader, &overrideAllocator);
+    ASSERT_EQ(ret, size1);
+    allocatedBytes += size1;
+    EXPECT_EQ(overrideAllocator.getAllocatedBytes(), allocatedBytes);
+    EXPECT_EQ(baseAllocator.getAllocatedBytes(), baseAllocBytes);
+    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);
+
+    baseAllocBytes += size2;
+    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
+    ASSERT_EQ(ret, size2);
+    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
+    EXPECT_EQ(overrideAllocator.getAllocatedBytes(), allocatedBytes);
+    EXPECT_EQ(baseAllocator.getAllocatedBytes(), baseAllocBytes);
+  }
+
+  // Inplace multi-threading getRecord(name, dst, n, additional_readers) test
+  additionalReader.clear();
+  std::vector<uint8_t> dst1(size1), dst2(size2);
+  for (int i = 0; i < 10; ++i) {
+    // Test various sizes of read threads
+    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
+
+    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
+    ASSERT_EQ(ret, size1);
+
+    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
+    ASSERT_EQ(ret, size2);
+  }
+  // clean up
+  remove(file_name);
+}
+
 class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
 INSTANTIATE_TEST_SUITE_P(
     ChunkRecordIteratorTestGroup,
```