Files
pytorch/caffe2/serialize/inline_container_test.cc
Lu Fang 9f9f06c937 Improve inline container and add some test (#12993)
Summary:
Added getNextRecord/hasNextRecord methods. Even though the model data is stored at the end, the file can still be read from the beginning.

Added a gtest to cover the reader and writer code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12993

Reviewed By: yinghai

Differential Revision: D10860086

Pulled By: houseroad

fbshipit-source-id: 01b1380f8f50f5e853fe48a8136e3176eb3b0c29
2018-10-26 12:06:47 -07:00

73 lines
2.1 KiB
C++

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <tuple>
#include <utility>

#include <gtest/gtest.h>

#include "caffe2/serialize/inline_container.h"
namespace at {
namespace {

// Rounds `value` up to the next multiple of `alignment` (`alignment` > 0).
constexpr int64_t alignUp(int64_t value, int64_t alignment) {
  return (value + alignment - 1) / alignment * alignment;
}

// Fills `buf` with the byte values N, N-1, ..., 1 so each test payload has a
// recognizable, non-zero pattern.
template <std::size_t N>
void fillDescending(std::array<char, N>& buf) {
  static_assert(N <= 127, "values must fit in a (possibly signed) char");
  for (std::size_t i = 0; i < N; ++i) {
    buf[i] = static_cast<char>(N - i);
  }
}

// Deletes a file when destroyed, so a temporary file is cleaned up on every
// exit path, including the early return an ASSERT_* macro performs on failure.
class TempFileRemover {
 public:
  explicit TempFileRemover(std::string path) : path_(std::move(path)) {}
  TempFileRemover(const TempFileRemover&) = delete;
  TempFileRemover& operator=(const TempFileRemover&) = delete;
  ~TempFileRemover() {
    // Best effort: failing to delete the temporary file must not fail the test.
    std::remove(path_.c_str());
  }

 private:
  std::string path_;
};

// Round-trips two records through PyTorchFileWriter / PyTorchFileReader and
// checks the keys, sizes and payloads the reader returns via sequential access
// (hasNextRecord/getNextRecord), getLastRecord and getRecordWithKey.
TEST(PyTorchFileWriterAndReader, SaveAndLoad) {
  // Alignment the test expects the container to use: the first record's key
  // is asserted to equal it below.
  constexpr int64_t kFieldAlignment = 64;

  // Create a name for a temporary file.
  // TODO: std::tmpnam is racy; use different implementations for Windows
  // (GetTempFileName) and POSIX (mkstemp).
  const char* const tmp_cstr = std::tmpnam(nullptr);
  // std::tmpnam returns nullptr on failure; constructing a std::string from a
  // null pointer would be undefined behavior.
  ASSERT_TRUE(tmp_cstr != nullptr);
  const std::string tmp_name = tmp_cstr;
  // Declared before the writer and reader so it is destroyed after them,
  // i.e. once they no longer use the file.
  TempFileRemover tmp_file_remover{tmp_name};

  // Write records through the writer.
  torch::jit::PyTorchFileWriter writer{tmp_name};

  std::array<char, 127> data1;
  fillDescending(data1);
  // Query the size in its own statement: if writeRecord() and
  // getCurrentSize() both appeared inside one ASSERT_EQ, the order in which
  // they run would be unspecified and the check would be compiler-dependent.
  // NOTE(review): this pins the reading that writeRecord() returns the offset
  // at which the new record starts (its key) — confirm against
  // inline_container.h.
  const auto size_before_data1 = writer.getCurrentSize();
  ASSERT_EQ(writer.writeRecord(data1.data(), data1.size()), size_before_data1);

  std::array<char, 64> data2;
  fillDescending(data2);
  const auto size_before_data2 = writer.getCurrentSize();
  ASSERT_EQ(writer.writeRecord(data2.data(), data2.size()), size_before_data2);

  writer.writeEndOfFile();
  ASSERT_TRUE(writer.closed());

  // Read the records back through the reader.
  torch::jit::PyTorchFileReader reader{tmp_name};
  ASSERT_TRUE(reader.hasNextRecord());
  at::DataPtr data_ptr;
  int64_t key = 0;
  int64_t size = 0;

  // Sequential access: first record.
  std::tie(data_ptr, key, size) = reader.getNextRecord();
  ASSERT_EQ(key, kFieldAlignment);
  ASSERT_EQ(size, static_cast<int64_t>(data1.size()));
  ASSERT_EQ(std::memcmp(data_ptr.get(), data1.data(), data1.size()), 0);

  // Sequential access: second record. Its expected key is the first key plus
  // one alignment unit (presumably the first record's padded header — TODO
  // confirm) plus the first payload rounded up to the alignment.
  ASSERT_TRUE(reader.hasNextRecord());
  std::tie(data_ptr, key, size) = reader.getNextRecord();
  ASSERT_EQ(
      key,
      kFieldAlignment * 2 +
          alignUp(static_cast<int64_t>(data1.size()), kFieldAlignment));
  ASSERT_EQ(size, static_cast<int64_t>(data2.size()));
  ASSERT_EQ(std::memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_FALSE(reader.hasNextRecord());

  // getLastRecord() yields the final record (data2); afterwards there is
  // still no next record.
  std::tie(data_ptr, size) = reader.getLastRecord();
  ASSERT_EQ(size, static_cast<int64_t>(data2.size()));
  ASSERT_EQ(std::memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_FALSE(reader.hasNextRecord());

  // Random access by key: the first record. Afterwards a next record is
  // expected to be available again (presumably getRecordWithKey repositions
  // the read cursor just past the requested record — TODO confirm).
  std::tie(data_ptr, size) = reader.getRecordWithKey(kFieldAlignment);
  ASSERT_EQ(size, static_cast<int64_t>(data1.size()));
  ASSERT_EQ(std::memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  ASSERT_TRUE(reader.hasNextRecord());
  // The temporary file is deleted by tmp_file_remover when the test returns.
}

} // namespace
} // namespace at