mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding a hook (wrapper) for non-std stream reader in PyTorchStreamReader (#15551)
Summary: Implementing a custom stream is cumbersome, because a stream is tightly coupled to its underlying stream buffer. So in this PR, we add ReadAdapterInterface, and PyTorchStreamReader now reads through it. We implement IStreamAdapter as a wrapper around std::istream, and keep the user interface unchanged. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15551 Reviewed By: zrphercule Differential Revision: D13568907 Pulled By: houseroad fbshipit-source-id: 93708cb801248a6c101f35cb14d1631029365c3c
This commit is contained in:
committed by
Facebook Github Bot
parent
1488c5dd03
commit
a918f1d9af
@ -3,7 +3,10 @@ file(GLOB tmp *_test.cc)
|
||||
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
|
||||
list(APPEND Caffe2_CPU_SRCS
|
||||
${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8/miniz.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc)
|
||||
list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8)
|
||||
|
||||
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
|
||||
|
28
caffe2/serialize/file_adapter.cc
Normal file
28
caffe2/serialize/file_adapter.cc
Normal file
@ -0,0 +1,28 @@
|
||||
#include "caffe2/serialize/file_adapter.h"
|
||||
#include <c10/util/Exception.h>
|
||||
#include "caffe2/core/common.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// Opens `file_name` for binary reading and wraps the resulting stream in an
// IStreamAdapter, which size() and read() delegate to.
// Raises via AT_ERROR if the stream is in a failed state after open().
FileAdapter::FileAdapter(const std::string& file_name) {
  const auto open_mode = std::ifstream::in | std::ifstream::binary;
  file_stream_.open(file_name, open_mode);
  if (file_stream_.fail()) {
    AT_ERROR("open file failed, file path: ", file_name);
  }
  // The adapter is handed a pointer to the member stream, so both share the
  // lifetime of this FileAdapter.
  istream_adapter_ = caffe2::make_unique<IStreamAdapter>(&file_stream_);
}
|
||||
|
||||
size_t FileAdapter::size() const {
|
||||
return istream_adapter_->size();
|
||||
}
|
||||
|
||||
size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
|
||||
const {
|
||||
return istream_adapter_->read(pos, buf, n, what);
|
||||
}
|
||||
|
||||
// Nothing to release by hand: the members clean up after themselves.
FileAdapter::~FileAdapter() = default;
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
28
caffe2/serialize/file_adapter.h
Normal file
28
caffe2/serialize/file_adapter.h
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
class FileAdapter final : public ReadAdapterInterface {
|
||||
public:
|
||||
C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
|
||||
explicit FileAdapter(const std::string& file_name);
|
||||
size_t size() const override;
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override;
|
||||
~FileAdapter();
|
||||
|
||||
private:
|
||||
std::ifstream file_stream_;
|
||||
std::unique_ptr<IStreamAdapter> istream_adapter_;
|
||||
};
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
@ -8,12 +8,17 @@
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/Backend.h>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/serialize/file_adapter.h"
|
||||
#include "caffe2/serialize/inline_container.h"
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
#include "miniz.h"
|
||||
|
||||
namespace torch { namespace jit {
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
|
||||
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
|
||||
@ -42,27 +47,33 @@ static std::string basename(const std::string& name) {
|
||||
}
|
||||
|
||||
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
|
||||
in_->seekg(pos);
|
||||
if(!*in_)
|
||||
return 0;
|
||||
in_->read(static_cast<char*>(buf), n);
|
||||
if(!*in_)
|
||||
return 0;
|
||||
return n;
|
||||
return in_->read(pos, buf, n, "reading file");
|
||||
}
|
||||
|
||||
PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in)
|
||||
: ar_(new mz_zip_archive), in_(in) {
|
||||
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
|
||||
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
||||
in_(caffe2::make_unique<FileAdapter>(file_name)) {
|
||||
init();
|
||||
}
|
||||
|
||||
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
|
||||
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
||||
in_(caffe2::make_unique<IStreamAdapter>(in)) {
|
||||
init();
|
||||
}
|
||||
|
||||
PyTorchStreamReader::PyTorchStreamReader(
|
||||
std::unique_ptr<ReadAdapterInterface> in)
|
||||
: ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
|
||||
init();
|
||||
}
|
||||
|
||||
void PyTorchStreamReader::init() {
|
||||
AT_ASSERT(in_ != nullptr);
|
||||
AT_ASSERT(ar_ != nullptr);
|
||||
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
||||
|
||||
if (!in_) {
|
||||
file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
|
||||
in_ = &file_stream_;
|
||||
valid("opening archive");
|
||||
}
|
||||
|
||||
in_->seekg(0, in_->end);
|
||||
size_t size = in_->tellg();
|
||||
size_t size = in_->size();
|
||||
|
||||
// check for the old magic number,
|
||||
constexpr size_t kMagicValueLength = 8;
|
||||
@ -81,7 +92,6 @@ PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in
|
||||
mz_zip_reader_init(ar_.get(), size, 0);
|
||||
valid("reading zip archive");
|
||||
|
||||
|
||||
// figure out the archive_name (i.e. the zip folder all the other files are in)
|
||||
// all lookups to getRecord will be prefixed by this folder
|
||||
int n = mz_zip_reader_get_num_files(ar_.get());
|
||||
@ -126,9 +136,6 @@ void PyTorchStreamReader::valid(const char* what) {
|
||||
if (err != MZ_ZIP_NO_ERROR) {
|
||||
CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
|
||||
}
|
||||
if (!*in_) {
|
||||
CAFFE_THROW("PytorchStreamReader failed ", what, ".");
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
|
||||
@ -191,11 +198,12 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
|
||||
mz_zip_archive_file_stat stat;
|
||||
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
|
||||
valid("retriving file meta-data");
|
||||
in_->seekg(stat.m_local_header_ofs);
|
||||
valid("seeking to file header");
|
||||
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
|
||||
in_->read(reinterpret_cast<char*>(local_header), MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
|
||||
valid("reading file header");
|
||||
in_->read(
|
||||
stat.m_local_header_ofs,
|
||||
local_header,
|
||||
MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
|
||||
"reading file header");
|
||||
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
|
||||
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
|
||||
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
|
||||
@ -226,8 +234,12 @@ size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, s
|
||||
return n;
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name, std::ostream* out)
|
||||
: ar_(new mz_zip_archive), archive_name_(basename(file_name)), out_(out) {
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
std::string file_name,
|
||||
std::ostream* out)
|
||||
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
||||
archive_name_(basename(file_name)),
|
||||
out_(out) {
|
||||
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
||||
|
||||
if (archive_name_.size() == 0) {
|
||||
@ -302,4 +314,5 @@ PyTorchStreamWriter::~PyTorchStreamWriter() {
|
||||
}
|
||||
}
|
||||
|
||||
}} // namespace torch::jit
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
@ -11,6 +11,8 @@
|
||||
#include <c10/core/Backend.h>
|
||||
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
extern "C" {
|
||||
typedef struct mz_zip_archive mz_zip_archive;
|
||||
@ -84,7 +86,8 @@ typedef struct mz_zip_archive mz_zip_archive;
|
||||
// model.json as the last file when writing after we have accumulated all
|
||||
// other information.
|
||||
|
||||
namespace torch { namespace jit {
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
|
||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
|
||||
@ -97,9 +100,9 @@ constexpr uint64_t kFieldAlignment = 64;
|
||||
|
||||
class CAFFE2_API PyTorchStreamReader final {
|
||||
public:
|
||||
PyTorchStreamReader(std::string archive_name, std::istream* in=nullptr);
|
||||
PyTorchStreamReader(std::istream* in)
|
||||
: PyTorchStreamReader("archive", in) {}
|
||||
explicit PyTorchStreamReader(const std::string& file_name);
|
||||
explicit PyTorchStreamReader(std::istream* in);
|
||||
explicit PyTorchStreamReader(std::unique_ptr<ReadAdapterInterface> in);
|
||||
|
||||
// return dataptr, size
|
||||
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
|
||||
@ -109,15 +112,16 @@ class CAFFE2_API PyTorchStreamReader final {
|
||||
~PyTorchStreamReader();
|
||||
|
||||
private:
|
||||
size_t read(uint64_t pos, char* buf, size_t n);
|
||||
void valid(const char* what);
|
||||
size_t getFileID(const std::string& name);
|
||||
void init();
|
||||
size_t read(uint64_t pos, char* buf, size_t n);
|
||||
void valid(const char* what);
|
||||
size_t getFileID(const std::string& name);
|
||||
|
||||
friend size_t istream_read_func(void *pOpaque, uint64_t file_ofs, void *pBuf, size_t n);
|
||||
std::unique_ptr<mz_zip_archive> ar_;
|
||||
std::string archive_name_;
|
||||
std::istream* in_;
|
||||
std::ifstream file_stream_;
|
||||
friend size_t
|
||||
istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
|
||||
std::unique_ptr<mz_zip_archive> ar_;
|
||||
std::string archive_name_;
|
||||
std::unique_ptr<ReadAdapterInterface> in_;
|
||||
};
|
||||
|
||||
class CAFFE2_API PyTorchStreamWriter final {
|
||||
@ -150,4 +154,5 @@ class CAFFE2_API PyTorchStreamWriter final {
|
||||
friend size_t ostream_write_func(void *pOpaque, uint64_t file_ofs, const void *pBuf, size_t n);
|
||||
};
|
||||
|
||||
}} // namespace torch::jit
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
@ -6,7 +6,8 @@
|
||||
|
||||
#include "caffe2/serialize/inline_container.h"
|
||||
|
||||
namespace at {
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
namespace {
|
||||
|
||||
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
@ -14,7 +15,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
|
||||
std::ostringstream oss;
|
||||
// write records through writers
|
||||
torch::jit::PyTorchStreamWriter writer(&oss);
|
||||
PyTorchStreamWriter writer(&oss);
|
||||
std::array<char, 127> data1;
|
||||
|
||||
for (int i = 0; i < data1.size(); ++i) {
|
||||
@ -37,7 +38,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
std::istringstream iss(the_file);
|
||||
|
||||
// read records through readers
|
||||
torch::jit::PyTorchStreamReader reader(&iss);
|
||||
PyTorchStreamReader reader(&iss);
|
||||
at::DataPtr data_ptr;
|
||||
int64_t size;
|
||||
std::tie(data_ptr, size) = reader.getRecord("key1");
|
||||
@ -58,4 +59,5 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace at
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
39
caffe2/serialize/istream_adapter.cc
Normal file
39
caffe2/serialize/istream_adapter.cc
Normal file
@ -0,0 +1,39 @@
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// Wraps `istream`. The adapter only stores the pointer (its destructor does
// not delete it), so the stream must outlive the adapter.
// Raises via AT_ERROR when `istream` is null.
IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {
  // size() and read() dereference istream_ unconditionally, so reject a null
  // stream here with a clear error instead of crashing on first use.
  if (istream_ == nullptr) {
    AT_ERROR("istream reader failed: istream is null.");
  }
}
|
||||
|
||||
// Returns the stream length, measured as the position of the end of the
// stream. The read cursor is restored before returning. Each stream operation
// is checked by validate(), which raises if the stream has failed.
size_t IStreamAdapter::size() const {
  // Remember the cursor so that querying the size leaves it where it was.
  const std::istream::pos_type saved_pos = istream_->tellg();
  validate("getting the current position");

  istream_->seekg(0, std::ios_base::end);
  validate("seeking to end");
  const std::istream::pos_type end_pos = istream_->tellg();
  validate("getting size");

  istream_->seekg(saved_pos);
  validate("seeking to the original position");
  return end_pos;
}
|
||||
|
||||
size_t IStreamAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
|
||||
const {
|
||||
istream_->seekg(pos);
|
||||
validate(what);
|
||||
istream_->read(static_cast<char*>(buf), n);
|
||||
validate(what);
|
||||
return n;
|
||||
}
|
||||
|
||||
void IStreamAdapter::validate(const char* what) const {
|
||||
if (!*istream_) {
|
||||
AT_ERROR("istream reader failed: ", what, ".");
|
||||
}
|
||||
}
|
||||
|
||||
// The wrapped stream is not deleted here; only the pointer is dropped.
IStreamAdapter::~IStreamAdapter() = default;
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
28
caffe2/serialize/istream_adapter.h
Normal file
28
caffe2/serialize/istream_adapter.h
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <istream>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// this is a reader implemented by std::istream
|
||||
class IStreamAdapter final : public ReadAdapterInterface {
|
||||
public:
|
||||
C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
|
||||
explicit IStreamAdapter(std::istream* istream);
|
||||
size_t size() const override;
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override;
|
||||
~IStreamAdapter();
|
||||
|
||||
private:
|
||||
std::istream* istream_;
|
||||
void validate(const char* what) const;
|
||||
};
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
9
caffe2/serialize/read_adapter_interface.cc
Normal file
9
caffe2/serialize/read_adapter_interface.cc
Normal file
@ -0,0 +1,9 @@
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// Out-of-line definition of the interface's virtual destructor.
ReadAdapterInterface::~ReadAdapterInterface() = default;
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
21
caffe2/serialize/read_adapter_interface.h
Normal file
21
caffe2/serialize/read_adapter_interface.h
Normal file
@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// Abstract byte source used by PyTorchStreamReader. Implementing this
// interface lets the reader work with (file/stream/memory) backends other
// than a standard istream.
class ReadAdapterInterface {
 public:
  // Total number of bytes available from the source.
  virtual size_t size() const = 0;

  // Reads `n` bytes starting at offset `pos` into `buf` and returns a byte
  // count. `what` is a caller-supplied label for the operation.
  virtual size_t read(
      uint64_t pos,
      void* buf,
      size_t n,
      const char* what = "") const = 0;

  // Defined out of line in the matching .cc file.
  virtual ~ReadAdapterInterface();
};
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
@ -497,7 +497,7 @@ class ScriptModuleSerializer final {
|
||||
torch::ParameterDef* param_def);
|
||||
|
||||
std::ofstream ofs_;
|
||||
PyTorchStreamWriter writer_;
|
||||
caffe2::serialize::PyTorchStreamWriter writer_;
|
||||
|
||||
// all tensors that will be stored
|
||||
std::vector<at::Tensor> tensor_table_;
|
||||
|
@ -50,7 +50,7 @@ class ScriptModuleDeserializer final {
|
||||
|
||||
void loadTensorTable(torch::ModelDef* model_def);
|
||||
|
||||
PyTorchStreamReader reader_;
|
||||
caffe2::serialize::PyTorchStreamReader reader_;
|
||||
// this is a hack to make sure the script module created in C++ is the
|
||||
// same as created in Python
|
||||
ModuleLookup moduleLookup_;
|
||||
|
@ -55,6 +55,9 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::PyTorchStreamWriter;
|
||||
|
||||
// TODO: make a fake future for python
|
||||
namespace detail {
|
||||
class Future {};
|
||||
|
Reference in New Issue
Block a user