mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: To implement a stream is very annoying, since it is closely defined with the underlying storage streambuffer. So in this PR, we add ReadAdapterInterface and PyTorchStreamReader will use it. We implement IStreamAdapter as a wrapper of std::istream. And keep the user interface unchanged. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15551 Reviewed By: zrphercule Differential Revision: D13568907 Pulled By: houseroad fbshipit-source-id: 93708cb801248a6c101f35cb14d1631029365c3c
319 lines
9.6 KiB
C++
319 lines
9.6 KiB
C++
#include <cstdio>
|
|
#include <cstring>
|
|
#include <cerrno>
|
|
#include <istream>
|
|
#include <ostream>
|
|
#include <fstream>
|
|
|
|
#include <c10/core/Allocator.h>
|
|
#include <c10/core/Backend.h>
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/serialize/file_adapter.h"
|
|
#include "caffe2/serialize/inline_container.h"
|
|
#include "caffe2/serialize/istream_adapter.h"
|
|
#include "caffe2/serialize/read_adapter_interface.h"
|
|
|
|
#include "miniz.h"
|
|
|
|
namespace caffe2 {
|
|
namespace serialize {
|
|
|
|
size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
|
|
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
|
|
return self->read(file_ofs, static_cast<char*>(pBuf), n);
|
|
}
|
|
|
|
static std::string basename(const std::string& name) {
|
|
size_t start = 0;
|
|
for(size_t i = 0; i < name.size(); ++i) {
|
|
if (name[i] == '\\' || name[i] == '/') {
|
|
start = i + 1;
|
|
}
|
|
}
|
|
|
|
if (start >= name.size())
|
|
return "";
|
|
|
|
size_t end = name.size();
|
|
for(size_t i = end; i > start; --i) {
|
|
if (name[i - 1] == '.') {
|
|
end = i - 1;
|
|
break;
|
|
}
|
|
}
|
|
return name.substr(start, end - start);
|
|
}
|
|
|
|
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
|
|
return in_->read(pos, buf, n, "reading file");
|
|
}
|
|
|
|
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
|
|
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
|
in_(caffe2::make_unique<FileAdapter>(file_name)) {
|
|
init();
|
|
}
|
|
|
|
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
|
|
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
|
in_(caffe2::make_unique<IStreamAdapter>(in)) {
|
|
init();
|
|
}
|
|
|
|
PyTorchStreamReader::PyTorchStreamReader(
|
|
std::unique_ptr<ReadAdapterInterface> in)
|
|
: ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
|
|
init();
|
|
}
|
|
|
|
void PyTorchStreamReader::init() {
|
|
AT_ASSERT(in_ != nullptr);
|
|
AT_ASSERT(ar_ != nullptr);
|
|
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
|
|
|
size_t size = in_->size();
|
|
|
|
// check for the old magic number,
|
|
constexpr size_t kMagicValueLength = 8;
|
|
if (size > kMagicValueLength) {
|
|
char buf[kMagicValueLength];
|
|
read(0, buf, kMagicValueLength);
|
|
valid("checking magic number");
|
|
AT_ASSERTM(
|
|
memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
|
|
"File is an unsupported archive format from the preview release.");
|
|
}
|
|
|
|
ar_->m_pIO_opaque = this;
|
|
ar_->m_pRead = istream_read_func;
|
|
|
|
mz_zip_reader_init(ar_.get(), size, 0);
|
|
valid("reading zip archive");
|
|
|
|
// figure out the archive_name (i.e. the zip folder all the other files are in)
|
|
// all lookups to getRecord will be prefixed by this folder
|
|
int n = mz_zip_reader_get_num_files(ar_.get());
|
|
if (n == 0) {
|
|
CAFFE_THROW("archive does not contain any files");
|
|
}
|
|
size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
|
|
valid("getting filename");
|
|
std::string buf(name_size, '\0');
|
|
mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
|
|
valid("getting filename");
|
|
auto pos = buf.find_first_of('/');
|
|
if (pos == std::string::npos) {
|
|
CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
|
|
}
|
|
archive_name_ = buf.substr(0, pos);
|
|
|
|
// version check
|
|
at::DataPtr version_ptr;
|
|
size_t version_size;
|
|
std::tie(version_ptr, version_size) = getRecord("version");
|
|
std::string version(static_cast<const char*>(version_ptr.get()), version_size);
|
|
size_t version_number = caffe2::stoull(version);
|
|
AT_ASSERTM(
|
|
version_number >= kMinSupportedFileFormatVersion,
|
|
"Attempted to read a PyTorch file with version ",
|
|
c10::to_string(version_number),
|
|
", but the minimum supported version for reading is ",
|
|
c10::to_string(kMinSupportedFileFormatVersion),
|
|
". Your PyTorch script module file is too old. Please re-export it again.");
|
|
AT_ASSERTM(
|
|
version_number <= kMaxSupportedFileFormatVersion,
|
|
"Attempted to read a PyTorch file with version ",
|
|
version_number,
|
|
", but the maximum supported version for reading is ",
|
|
kMaxSupportedFileFormatVersion,
|
|
". Your PyTorch installation may be too old.");
|
|
}
|
|
|
|
void PyTorchStreamReader::valid(const char* what) {
|
|
auto err = mz_zip_get_last_error(ar_.get());
|
|
if (err != MZ_ZIP_NO_ERROR) {
|
|
CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
|
|
}
|
|
}
|
|
|
|
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
|
|
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
|
|
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
|
|
|
|
static std::string getPadding(size_t cursor, const std::string& filename, size_t size) {
|
|
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + sizeof(mz_uint16) * 2;
|
|
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
|
|
start += sizeof(mz_uint16) * 2;
|
|
if (size >= MZ_UINT32_MAX) {
|
|
start += 2*sizeof(mz_uint64);
|
|
}
|
|
if (cursor >= MZ_UINT32_MAX) {
|
|
start += sizeof(mz_uint64);
|
|
}
|
|
}
|
|
size_t mod = start % kFieldAlignment;
|
|
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
|
|
size_t padding_size = next_offset - start;
|
|
std::string buf(padding_size + 4, 'Z');
|
|
// zip extra encoding (key, size_of_extra_bytes)
|
|
buf[0] = 'F';
|
|
buf[1] = 'B';
|
|
buf[2] = (uint8_t) padding_size;
|
|
buf[3] = (uint8_t) (padding_size >> 8);
|
|
return buf;
|
|
}
|
|
|
|
size_t PyTorchStreamReader::getFileID(const std::string& name) {
|
|
std::stringstream ss;
|
|
ss << archive_name_ << "/" << name;
|
|
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
|
|
if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
|
|
CAFFE_THROW("file not found: ", ss.str());
|
|
}
|
|
valid("locating file");
|
|
return result;
|
|
}
|
|
|
|
// return dataptr, size
|
|
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
|
|
size_t key = getFileID(name);
|
|
mz_zip_archive_file_stat stat;
|
|
mz_zip_reader_file_stat(ar_.get(), key, &stat);
|
|
valid("retrieving file meta-data");
|
|
void * ptr = malloc(stat.m_uncomp_size);
|
|
mz_zip_reader_extract_to_mem(ar_.get(), key, ptr, stat.m_uncomp_size, 0);
|
|
valid("reading file");
|
|
|
|
at::DataPtr retval(ptr, ptr, free, at::kCPU);
|
|
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
|
|
}
|
|
|
|
static int64_t read_le_16(uint8_t* buf) {
|
|
return buf[0] + (buf[1] << 8);
|
|
}
|
|
|
|
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
|
|
mz_zip_archive_file_stat stat;
|
|
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
|
|
valid("retriving file meta-data");
|
|
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
|
|
in_->read(
|
|
stat.m_local_header_ofs,
|
|
local_header,
|
|
MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
|
|
"reading file header");
|
|
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
|
|
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
|
|
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
|
|
}
|
|
|
|
|
|
PyTorchStreamReader::~PyTorchStreamReader() {
|
|
mz_zip_reader_end(ar_.get());
|
|
valid("closing reader");
|
|
}
|
|
|
|
size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
|
|
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
|
|
if (self->current_pos_ != file_ofs) {
|
|
// xxx - windows ostringstream refuses to seek to the end of an empty string
|
|
// so we workaround this by not calling seek unless necessary
|
|
// in the case of the first write (to the empty string) file_ofs and
|
|
// current_pos_ will be 0 and the seek won't occur.
|
|
self->out_->seekp(file_ofs);
|
|
if(!*self->out_)
|
|
return 0;
|
|
}
|
|
|
|
self->out_->write(static_cast<const char*>(pBuf), n);
|
|
if(!*self->out_)
|
|
return 0;
|
|
self->current_pos_ = file_ofs + n;
|
|
return n;
|
|
}
|
|
|
|
PyTorchStreamWriter::PyTorchStreamWriter(
|
|
std::string file_name,
|
|
std::ostream* out)
|
|
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
|
archive_name_(basename(file_name)),
|
|
out_(out) {
|
|
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
|
|
|
if (archive_name_.size() == 0) {
|
|
CAFFE_THROW("invalid file name: ", file_name);
|
|
}
|
|
if (!out_) {
|
|
file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
|
out_ = &file_stream_;
|
|
valid("opening archive");
|
|
}
|
|
|
|
ar_->m_pIO_opaque = this;
|
|
ar_->m_pWrite = ostream_write_func;
|
|
|
|
mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
|
|
valid("initializing archive");
|
|
|
|
std::stringstream version;
|
|
version << kMaxSupportedFileFormatVersion << "\n";
|
|
writeRecord("version", version.str().c_str(), version.str().size());
|
|
}
|
|
|
|
void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size) {
|
|
AT_ASSERT(!finalized_);
|
|
std::stringstream ss;
|
|
ss << archive_name_ << "/" << name;
|
|
const std::string& full_name = ss.str();
|
|
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
|
|
uint32_t flags = 0;
|
|
mz_zip_writer_add_mem_ex_v2(
|
|
ar_.get(),
|
|
full_name.c_str(),
|
|
data,
|
|
size,
|
|
nullptr,
|
|
0,
|
|
flags,
|
|
0,
|
|
0,
|
|
nullptr,
|
|
padding.c_str(),
|
|
padding.size(),
|
|
nullptr,
|
|
0);
|
|
valid("writing file");
|
|
}
|
|
|
|
void PyTorchStreamWriter::writeEndOfFile() {
|
|
AT_ASSERT(!finalized_);
|
|
finalized_ = true;
|
|
mz_zip_writer_finalize_archive(ar_.get());
|
|
mz_zip_writer_end(ar_.get());
|
|
valid("writing central directory");
|
|
if (file_stream_.is_open())
|
|
file_stream_.close();
|
|
}
|
|
|
|
|
|
void PyTorchStreamWriter::valid(const char* what) {
|
|
auto err = mz_zip_get_last_error(ar_.get());
|
|
if (err != MZ_ZIP_NO_ERROR) {
|
|
CAFFE_THROW("PytorchStreamWriter failed ", what, ": ", mz_zip_get_error_string(err));
|
|
}
|
|
if (!*out_) {
|
|
CAFFE_THROW("PytorchStreamWriter failed ", what, ".");
|
|
}
|
|
}
|
|
|
|
PyTorchStreamWriter::~PyTorchStreamWriter() {
|
|
if (!finalized_) {
|
|
writeEndOfFile();
|
|
}
|
|
}
|
|
|
|
} // namespace serialize
|
|
} // namespace caffe2
|