mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Added getNextRecord/hasNextRecord methods. Even though the model data is stored at the end, we can still read the file from the beginning. Added gtest to cover the reader's and writer's code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12993 Reviewed By: yinghai Differential Revision: D10860086 Pulled By: houseroad fbshipit-source-id: 01b1380f8f50f5e853fe48a8136e3176eb3b0c29
73 lines
2.1 KiB
C++
73 lines
2.1 KiB
C++
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <tuple>

#include <gtest/gtest.h>

#include "caffe2/serialize/inline_container.h"
|
|
|
|
namespace at {
|
|
namespace {
|
|
|
|
// Round-trip test for the inline container: writes two records with
// PyTorchFileWriter, then reads them back with PyTorchFileReader via
// sequential iteration (hasNextRecord/getNextRecord), getLastRecord and
// getRecordWithKey, checking the reported keys, sizes and payload bytes.
TEST(PyTorchFileWriterAndReader, SaveAndLoad) {
  // Unit in which the expected record keys below are expressed.
  // NOTE(review): presumably mirrors the container's on-disk field alignment;
  // confirm against inline_container.h.
  constexpr int64_t kFieldAlignment = 64;

  // create a name for temporary file
  // TODO to have different implementation for Windows and POSIX
  // std::tmpnam may return a null pointer when no name can be generated;
  // constructing a std::string from it would be undefined behaviour.
  const char* raw_tmp_name = std::tmpnam(nullptr);
  ASSERT_TRUE(raw_tmp_name != nullptr);
  const std::string tmp_name = raw_tmp_name;

  // Removes the temporary file when the test body is left, including the
  // early returns taken by a failing ASSERT_*. It is declared before the
  // writer and the reader so that it is destroyed after both of them.
  struct TempFileGuard {
    const std::string& path;
    ~TempFileGuard() {
      std::remove(path.c_str());
    }
  };
  const TempFileGuard tmp_file_guard{tmp_name};

  // write records through writers
  torch::jit::PyTorchFileWriter writer{tmp_name};

  // First payload: 127 bytes holding the descending values 127, 126, ..., 1.
  std::array<char, 127> data1;
  for (std::size_t i = 0; i < data1.size(); ++i) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  // NOTE(review): both operands are evaluated as arguments of a single
  // comparison, so their relative evaluation order is unspecified. If
  // getCurrentSize() changes across writeRecord(), this check is
  // compiler-dependent -- confirm the intended order and split the calls.
  ASSERT_EQ(
      writer.writeRecord(data1.data(), data1.size()), writer.getCurrentSize());

  // Second payload: 64 bytes holding the descending values 64, 63, ..., 1.
  std::array<char, 64> data2;
  for (std::size_t i = 0; i < data2.size(); ++i) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  // NOTE(review): same unspecified evaluation order as for the first record.
  ASSERT_EQ(
      writer.writeRecord(data2.data(), data2.size()), writer.getCurrentSize());

  writer.writeEndOfFile();
  ASSERT_TRUE(writer.closed());

  // Signed copies of the payload lengths, so that comparisons against the
  // int64_t values produced by the reader do not mix signedness.
  const int64_t data1_len = static_cast<int64_t>(data1.size());
  const int64_t data2_len = static_cast<int64_t>(data2.size());

  // read records through readers
  torch::jit::PyTorchFileReader reader{tmp_name};
  at::DataPtr data_ptr;
  int64_t key;
  int64_t size;

  // First record, read sequentially from the beginning of the file.
  ASSERT_TRUE(reader.hasNextRecord());
  std::tie(data_ptr, key, size) = reader.getNextRecord();
  ASSERT_EQ(key, kFieldAlignment);
  ASSERT_EQ(size, data1_len);
  ASSERT_EQ(std::memcmp(data_ptr.get(), data1.data(), data1.size()), 0);

  // Second record: its expected key is two alignment units plus the first
  // payload's length rounded up to a whole number of alignment units.
  const int64_t data1_padded_len =
      (data1_len + kFieldAlignment - 1) / kFieldAlignment * kFieldAlignment;
  ASSERT_TRUE(reader.hasNextRecord());
  std::tie(data_ptr, key, size) = reader.getNextRecord();
  ASSERT_EQ(key, kFieldAlignment * 2 + data1_padded_len);
  ASSERT_EQ(size, data2_len);
  ASSERT_EQ(std::memcmp(data_ptr.get(), data2.data(), data2.size()), 0);

  // Both records have been consumed.
  ASSERT_FALSE(reader.hasNextRecord());

  // getLastRecord must yield the second payload and leave the reader with no
  // further record to iterate.
  std::tie(data_ptr, size) = reader.getLastRecord();
  ASSERT_EQ(size, data2_len);
  ASSERT_EQ(std::memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_FALSE(reader.hasNextRecord());

  // Looking the first record up by its key must yield the first payload;
  // afterwards the reader is expected to report a following record again.
  std::tie(data_ptr, size) = reader.getRecordWithKey(kFieldAlignment);
  ASSERT_EQ(size, data1_len);
  ASSERT_EQ(std::memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  ASSERT_TRUE(reader.hasNextRecord());

  // clean up: performed by tmp_file_guard when the test body exits.
}
|
|
|
|
} // namespace
|
|
} // namespace at
|