Update PyTorchStreamReader API to take cpu allocator override (#150439)

Summary: Add allocator param in getRecord

Test Plan:
newly added UT
```
buck test caffe2/caffe2/serialize:inline_container_test
```

Differential Revision: D72252585

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150439
Approved by: https://github.com/albanD
This commit is contained in:
Xintong Hu
2025-04-18 01:53:14 +00:00
committed by PyTorch MergeBot
parent b434322075
commit a6182903cd
3 changed files with 212 additions and 12 deletions

View File

@ -361,7 +361,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name) {
const std::string& name,
std::optional<at::Allocator*> allocator) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
at::DataPtr retval;
@ -371,7 +372,9 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), key, &stat);
valid("retrieving file meta-data for ", name.c_str());
at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
at::Allocator* allocatorPtr =
allocator.has_value() ? allocator.value() : c10::GetCPUAllocator();
at::DataPtr retval = allocatorPtr->allocate(stat.m_uncomp_size);
mz_zip_reader_extract_to_mem(
ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
valid("reading file ", name.c_str());
@ -449,10 +452,11 @@ size_t PyTorchStreamReader::getRecordMultiReaders(
// read record with multi clients
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
std::optional<at::Allocator*> allocator) {
if (additionalReaders.empty()) {
// No additional readers or record too small, use single threaded version
return getRecord(name);
return getRecord(name, allocator);
}
if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
@ -469,7 +473,9 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
return getRecord(name);
}
at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
at::Allocator* allocatorPtr =
allocator.has_value() ? allocator.value() : c10::GetCPUAllocator();
at::DataPtr retval = allocatorPtr->allocate(stat.m_uncomp_size);
void* dst = retval.get();
PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
@ -760,11 +766,7 @@ void PyTorchStreamWriter::writeRecord(
}
std::string full_name = archive_name_plus_slash_ + name;
size_t padding_size = detail::getPadding(
ar_->m_archive_size,
full_name.size(),
size,
padding_,
alignment_);
ar_->m_archive_size, full_name.size(), size, padding_, alignment_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
if (!compute_crc32_) {
#if (!defined(FBCODE_CAFFE2))

View File

@ -130,11 +130,15 @@ class TORCH_API PyTorchStreamReader final {
explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);
// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
// set allocator to override default cpu allocator
std::tuple<at::DataPtr, size_t> getRecord(
const std::string& name,
std::optional<at::Allocator*> allocator = std::nullopt);
// multi-thread getRecord
std::tuple<at::DataPtr, size_t> getRecord(
const std::string& name,
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
std::optional<at::Allocator*> allocator = std::nullopt);
// inplace memory writing
size_t getRecord(const std::string& name, void* dst, size_t n);
// inplace memory writing, multi-threads.

View File

@ -6,6 +6,7 @@
#include <gtest/gtest.h>
#include <c10/util/Logging.h>
#include "c10/core/CPUAllocator.h"
#include "c10/util/irange.h"
#include "caffe2/serialize/inline_container.h"
@ -432,6 +433,199 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
const std::map<std::string, std::string>& metadata_map) {});
}
class TestAllocator : public at::Allocator {
public:
explicit TestAllocator(at::Allocator* allocator): baseAllocator_(allocator) {}
at::DataPtr allocate(size_t nbytes) override {
allocatedBytes_ += nbytes;
return baseAllocator_->allocate(nbytes);
}
at::DeleterFnPtr raw_deleter() const override {
return baseAllocator_->raw_deleter();
}
void copy_data(void* dest, const void* src, std::size_t count) const override {
default_copy_data(dest, src, count);
}
size_t getAllocatedBytes() {
return allocatedBytes_;
}
private:
at::Allocator* baseAllocator_;
size_t allocatedBytes_{0};
};
TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
  // Create two counting allocators: one is installed as the process-wide
  // default CPU allocator, the other is only used when explicitly passed
  // to getRecord(). The byte counters verify which allocator serviced
  // each read.
  auto defaultAllocator = at::GetCPUAllocator();
  TestAllocator overrideAllocator(defaultAllocator);
  TestAllocator baseAllocator(defaultAllocator);
  c10::SetCPUAllocator(&baseAllocator, 10 /* priority */);
  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  const size_t kBytes1 = 127;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, kBytes1> data1;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());
  const size_t kBytes2 = 64;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, kBytes2> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());
  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
  std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // Binary mode: the archive is raw zip bytes, and text-mode newline
  // translation would corrupt it on platforms that perform it.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();
  std::istringstream iss(the_file);
  // read records through readers
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));
  // snapshot the bytes allocated before any read
  const auto allocBytes = baseAllocator.getAllocatedBytes();
  at::DataPtr data_ptr;
  int64_t size = 0;
  // allocated with the override allocator; default allocator untouched
  std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
  // no override given: falls back to the default (base) allocator
  std::tie(data_ptr, size) = reader.getRecord("key1");
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);
  std::tie(data_ptr, size) = reader.getRecord("key2", &overrideAllocator);
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1 + kBytes2);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);
  std::tie(data_ptr, size) = reader.getRecord("key2");
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1 + kBytes2);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1 + kBytes2);
  // passing the default allocator explicitly also counts against it
  std::tie(data_ptr, size) = reader.getRecord("key2", &baseAllocator);
  EXPECT_EQ(
      baseAllocator.getAllocatedBytes(), allocBytes + kBytes1 + 2 * kBytes2);
  // clean up the temp file, consistent with the multi-thread test
  remove(file_name);
}
// Verifies that the allocator override is honored by the multi-reader
// (multi-threaded) getRecord overloads as well, and that the default CPU
// allocator is used when no override is passed.
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreadsWithAllocator) {
// Two counting allocators: baseAllocator becomes the process default,
// overrideAllocator is only used when passed explicitly.
auto defaultAllocator = at::GetCPUAllocator();
TestAllocator overrideAllocator(defaultAllocator);
TestAllocator baseAllocator(defaultAllocator);
c10::SetCPUAllocator(&baseAllocator, 10 /* priority */);
std::ostringstream oss;
// write records through writers
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
oss.write(static_cast<const char*>(b), n);
return oss ? n : 0;
});
const size_t kBytes1 = 127;
const size_t kBytes2 = 64;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
std::array<char, kBytes1> data1;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
std::array<char, kBytes2> data2;
for (auto i : c10::irange(data1.size())) {
data1[i] = data1.size() - i;
}
writer.writeRecord("key1", data1.data(), data1.size());
for (auto i : c10::irange(data2.size())) {
data2[i] = data2.size() - i;
}
writer.writeRecord("key2", data2.data(), data2.size());
const std::unordered_set<std::string>& written_records =
writer.getAllWrittenRecords();
ASSERT_EQ(written_records.size(), 2);
ASSERT_EQ(written_records.count("key1"), 1);
ASSERT_EQ(written_records.count("key2"), 1);
writer.writeEndOfFile();
ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
std::string the_file = oss.str();
const char* file_name = "output.zip";
std::ofstream foo(file_name);
foo.write(the_file.c_str(), the_file.size());
foo.close();
// read records through pytorchStreamReader
std::istringstream iss(the_file);
PyTorchStreamReader reader(&iss);
// Threshold 0 forces the multi-reader path even for these tiny records.
reader.setAdditionalReaderSizeThreshold(0);
// before testing, sanity check
int64_t size1, size2, ret;
at::DataPtr data_ptr;
std::tie(data_ptr, size1) = reader.getRecord("key1");
std::tie(data_ptr, size2) = reader.getRecord("key2");
// Test getRecord(name, additional_readers)
std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
// Running totals expected from each allocator across iterations.
size_t allocatedBytes = 0;
auto baseAllocBytes = baseAllocator.getAllocatedBytes();
for (int i = 0; i < 10; ++i) {
// Test various sized additional readers.
// key1 read with the override allocator: only its counter advances.
std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader, &overrideAllocator);
ASSERT_EQ(ret, size1);
allocatedBytes += size1;
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), allocatedBytes);
EXPECT_EQ(baseAllocator.getAllocatedBytes(), baseAllocBytes);
ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);
// key2 read without an override: only the base (default) counter advances.
baseAllocBytes += size2;
std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
ASSERT_EQ(ret, size2);
ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), allocatedBytes);
EXPECT_EQ(baseAllocator.getAllocatedBytes(), baseAllocBytes);
}
// Inplace multi-threading getRecord(name, dst, n, additional_readers) test
// (caller-provided buffers, so no allocator involvement expected here).
additionalReader.clear();
std::vector<uint8_t> dst1(size1), dst2(size2);
for (int i = 0; i < 10; ++i) {
// Test various sizes of read threads
additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
ASSERT_EQ(ret, size1);
ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
ASSERT_EQ(ret, size2);
}
// clean up
remove(file_name);
}
// Value-parameterized fixture over an int64_t parameter for the chunk-record
// iterator tests (instantiated by the INSTANTIATE_TEST_SUITE_P below).
// NOTE(review): the parameter's meaning is not visible in this chunk —
// confirm against the test bodies.
class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
INSTANTIATE_TEST_SUITE_P(
ChunkRecordIteratorTestGroup,