Make JIT Serialization support arbitrary std::function<> IO (#28039)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28039

Right now, torch::save() uses std::ostream, which results in unnecessary
data copies in practice. The same applies to torch::load().

Adding a std::function<size_t(const void*, size_t)> as an output option,
parallel to the existing filename and std::ostream apis, gives users the
flexibility to emit directly to a backing store.

For a simple case of appending the output to a std::string, we observe
significant benchmark savings (on order of -50%), even with the
minor std::function<> dispatch overhead. The main reason is that
std::ostringstream effectively requires 2 extra copies of the data
beyond a simple string.append lambda.

We also provide a parallel api for the load(), though this one is
slightly more complex due to the need to do arbitrary position reads.

Test Plan:
buck test mode/dev-nosan caffe2/test/...
      (Basic serialization test in caffe2/test/cpp/api/serialize.cpp)
      Benchmark in experimental/jeremyl/c2/SerializationBench.cpp, with D17823443
        (1M time goes from 90ms -> 40ms, albeit with crc patch applied)

Differential Revision: D17939034

fbshipit-source-id: 344cce46f74b6438cb638a8cfbeccf4e1aa882d7
This commit is contained in:
Jeremy Lilley
2019-10-15 22:08:17 -07:00
committed by Facebook Github Bot
parent 9cc4405dc9
commit 2e0294cb39
11 changed files with 194 additions and 41 deletions

View File

@ -88,7 +88,9 @@ inline c10::optional<TempFile> try_make_tempfile(
if (fd == -1) {
return c10::nullopt;
}
return TempFile(std::string(filename.begin(), filename.end()), fd);
// Don't make the string with std::string(filename.begin(), filename.end()), or
// there will be a trailing '\0' at the end.
return TempFile(filename.data(), fd);
#endif // defined(_WIN32)
}

View File

@ -232,40 +232,51 @@ PyTorchStreamReader::~PyTorchStreamReader() {
valid("closing reader for archive ", archive_name_.c_str());
}
size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
size_t ostream_write_func(
void* pOpaque,
mz_uint64 file_ofs,
const void* pBuf,
size_t n) {
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
if (self->current_pos_ != file_ofs) {
// xxx - windows ostringstream refuses to seek to the end of an empty string
// so we workaround this by not calling seek unless necessary
// in the case of the first write (to the empty string) file_ofs and
// current_pos_ will be 0 and the seek won't occur.
self->out_->seekp(file_ofs);
if(!*self->out_)
return 0;
CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
}
size_t ret = self->writer_func_(pBuf, n);
if (n != ret) {
self->err_seen_ = true;
}
self->current_pos_ += ret;
return ret;
}
self->out_->write(static_cast<const char*>(pBuf), n);
if(!*self->out_)
return 0;
self->current_pos_ = file_ofs + n;
return n;
PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
: archive_name_(basename(file_name)) {
setup(file_name);
}
PyTorchStreamWriter::PyTorchStreamWriter(
std::string file_name,
std::ostream* out)
: ar_(caffe2::make_unique<mz_zip_archive>()),
archive_name_(basename(file_name)),
out_(out) {
const std::function<size_t(const void*, size_t)>& writer_func)
: archive_name_("archive"), writer_func_(writer_func) {
setup(archive_name_);
}
void PyTorchStreamWriter::setup(const string& file_name) {
ar_ = caffe2::make_unique<mz_zip_archive>();
memset(ar_.get(), 0, sizeof(mz_zip_archive));
archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().
if (archive_name_.size() == 0) {
CAFFE_THROW("invalid file name: ", file_name);
}
if (!out_) {
file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
out_ = &file_stream_;
if (!writer_func_) {
file_stream_.open(
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
valid("opening archive ", file_name.c_str());
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
file_stream_.write(static_cast<const char*>(buf), nbytes);
return !file_stream_ ? 0 : nbytes;
};
}
ar_->m_pIO_opaque = this;
@ -279,11 +290,14 @@ PyTorchStreamWriter::PyTorchStreamWriter(
writeRecord("version", version.str().c_str(), version.str().size());
}
void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size, bool compress) {
void PyTorchStreamWriter::writeRecord(
const std::string& name,
const void* data,
size_t size,
bool compress) {
AT_ASSERT(!finalized_);
std::stringstream ss;
ss << archive_name_ << "/" << name;
const std::string& full_name = ss.str();
AT_ASSERT(!archive_name_plus_slash_.empty());
std::string full_name = archive_name_plus_slash_ + name;
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
mz_zip_writer_add_mem_ex_v2(
@ -310,8 +324,9 @@ void PyTorchStreamWriter::writeEndOfFile() {
mz_zip_writer_finalize_archive(ar_.get());
mz_zip_writer_end(ar_.get());
valid("writing central directory for archive ", archive_name_.c_str());
if (file_stream_.is_open())
if (file_stream_.is_open()) {
file_stream_.close();
}
}
void PyTorchStreamWriter::valid(const char* what, const char* info) {
@ -324,7 +339,7 @@ void PyTorchStreamWriter::valid(const char* what, const char* info) {
": ",
mz_zip_get_error_string(err));
}
if (!*out_) {
if (err_seen_) {
CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
}
}

View File

@ -123,11 +123,15 @@ class CAFFE2_API PyTorchStreamReader final {
class CAFFE2_API PyTorchStreamWriter final {
public:
PyTorchStreamWriter(std::string archive_name, std::ostream* out=nullptr);
PyTorchStreamWriter(std::ostream* out)
: PyTorchStreamWriter("archive", out) {}
explicit PyTorchStreamWriter(std::string archive_name);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)>& writer_func);
void writeRecord(const std::string& name, const void* data, size_t size, bool compress = false);
void writeRecord(
const std::string& name,
const void* data,
size_t size,
bool compress = false);
void writeEndOfFile();
bool finalized() const {
@ -141,13 +145,16 @@ class CAFFE2_API PyTorchStreamWriter final {
~PyTorchStreamWriter();
private:
void setup(const string& file_name);
void valid(const char* what, const char* info = "");
size_t current_pos_ = 0;
std::unique_ptr<mz_zip_archive> ar_;
std::string archive_name_;
std::ostream* out_;
std::string archive_name_plus_slash_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
bool finalized_ = false;
bool err_seen_ = false;
friend size_t ostream_write_func(
void* pOpaque,
uint64_t file_ofs,

View File

@ -15,7 +15,10 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
std::ostringstream oss;
// write records through writers
PyTorchStreamWriter writer(&oss);
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
oss.write(static_cast<const char*>(b), n);
return oss ? n : 0;
});
std::array<char, 127> data1;
for (int i = 0; i < data1.size(); ++i) {

View File

@ -60,6 +60,37 @@ TEST(SerializeTest, BasicToFile) {
ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);
  auto x = torch::randn({5, 5});

  // Serialize by appending every chunk the writer emits to a string.
  std::string serialized;
  torch::save(x, [&](const void* chunk, size_t chunk_size) {
    serialized.append(static_cast<const char*>(chunk), chunk_size);
    return chunk_size;
  });

  // Round-trip through the flat-buffer load() overload.
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());
  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  // Round-trip through the reader-function load() overload.
  torch::Tensor z;
  torch::load(
      z,
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size()) {
          return 0;
        }
        // Clamp the read so it never runs past the end of the buffer.
        const size_t nbytes =
            std::min(n, serialized.size() - static_cast<size_t>(pos));
        memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
TEST(SerializeTest, Resized) {
torch::manual_seed(0);

View File

@ -81,6 +81,17 @@ class TORCH_API InputArchive final {
void load_from(std::istream& stream,
c10::optional<torch::Device> device = c10::nullopt);
// Loads given the specified flat array.
void load_from(const char* data, size_t size,
c10::optional<torch::Device> device = c10::nullopt);
// Loads given the specified read and size functions.
void load_from(
const std::function<size_t(
uint64_t pos, void* buf, size_t nbytes)>& read_func,
const std::function<size_t(void)>& size_func,
c10::optional<torch::Device> device = c10::nullopt);
/// Forwards all arguments to `read()`.
/// Useful for generic code that can be re-used for both `InputArchive` and
/// `OutputArchive` (where `operator()` forwards to `write()`).

View File

@ -62,6 +62,10 @@ class TORCH_API OutputArchive final {
/// `stream`.
void save_to(std::ostream& stream);
/// Saves the `OutputArchive` into a serialized representation using the
/// given writer function.
void save_to(const std::function<size_t(const void*, size_t)>& func);
/// Forwards all arguments to `write()`.
/// Useful for generic code that can be re-used for both `OutputArchive` and
/// `InputArchive` (where `operator()` forwards to `read()`).

View File

@ -5,7 +5,7 @@
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/script/module.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <c10/util/Exception.h>
#include <istream>
@ -94,5 +94,61 @@ void InputArchive::load_from(std::istream& stream,
c10::optional<torch::Device> device /*= c10::nullopt*/) {
module_ = torch::jit::load(stream, std::move(device));
}
void InputArchive::load_from(
    const char* data,
    size_t size,
    c10::optional<torch::Device> device /*= c10::nullopt*/) {
  using caffe2::serialize::ReadAdapterInterface;
  // Minimal adapter that exposes a flat in-memory buffer through the
  // random-access read interface torch::jit::load() expects.
  class OurAdapter : public ReadAdapterInterface {
   public:
    OurAdapter(const char* data, size_t size) : data_(data), size_(size) {}
    size_t size() const override {
      return size_;
    }
    size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
        const override {
      (void)what;
      // Reads starting past the end return 0; reads that would overrun
      // the buffer are clamped to the bytes that remain.
      if (pos >= size_) {
        return 0;
      }
      const size_t nread = std::min(n, size_ - static_cast<size_t>(pos));
      memcpy(buf, data_ + pos, nread);
      return nread;
    }

   private:
    const char* data_;
    size_t size_;
  };
  std::unique_ptr<OurAdapter> adapter(new OurAdapter(data, size));
  module_ = torch::jit::load(std::move(adapter), std::move(device));
}
void InputArchive::load_from(
    const std::function<size_t(uint64_t, void*, size_t)>& read_func,
    const std::function<size_t(void)>& size_func,
    c10::optional<torch::Device> device /*= c10::nullopt*/) {
  using caffe2::serialize::ReadAdapterInterface;
  // Adapter that forwards the random-access reads torch::jit::load()
  // performs to the caller-supplied read/size functions.
  class OurAdapter : public ReadAdapterInterface {
   public:
    OurAdapter(
        std::function<size_t(uint64_t, void*, size_t)> read_func,
        std::function<size_t(void)> size_func)
        : read_func_(std::move(read_func)), size_func_(std::move(size_func)) {}
    size_t size() const override {
      return size_func_();
    }
    size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
        const override {
      (void)what;
      return read_func_(pos, buf, n);
    }

   private:
    // Hold the callables by value. The adapter is handed off to
    // torch::jit::load() via unique_ptr, so reference members bound to the
    // caller's arguments would dangle if the adapter ever outlived this
    // stack frame; copying the std::functions is cheap and safe.
    const std::function<size_t(uint64_t, void*, size_t)> read_func_;
    const std::function<size_t(void)> size_func_;
  };
  std::unique_ptr<OurAdapter> adapter(new OurAdapter(read_func, size_func));
  module_ = torch::jit::load(std::move(adapter), std::move(device));
}
} // namespace serialize
} // namespace torch

View File

@ -42,5 +42,10 @@ void OutputArchive::save_to(const std::string& filename) {
void OutputArchive::save_to(std::ostream& stream) {
jit::ExportModule(module_, stream);
}
void OutputArchive::save_to(
    const std::function<size_t(const void*, size_t)>& func) {
  // Delegate to the JIT exporter, which streams the serialized module
  // through the caller-supplied writer function chunk by chunk.
  jit::ExportModule(module_, func);
}
} // namespace serialize
} // namespace torch

View File

@ -531,11 +531,12 @@ void GraphEncoder::EncodeTensor(
class ScriptModuleSerializer {
public:
ScriptModuleSerializer(const std::string& filename)
: writer_(filename.c_str()) {}
explicit ScriptModuleSerializer(const std::string& filename)
: writer_(filename) {}
ScriptModuleSerializer(std::ostream* ofs)
: ofs_(), writer_(ofs) {}
explicit ScriptModuleSerializer(
const std::function<size_t(const void *, size_t)>& writer_func)
: writer_(writer_func) {}
void serialize(
const script::Module& module,
@ -772,7 +773,6 @@ class ScriptModuleSerializer {
converted_types_.insert(class_type, std::move(info));
}
std::ofstream ofs_;
caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
@ -1023,7 +1023,11 @@ void ExportModule(
std::ostream& out,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(&out);
ScriptModuleSerializer serializer(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char *>(buf), nbytes);
return !out ? 0 : nbytes;
});
serializer.serialize(module, extra_files, bytecode_format);
}
@ -1036,5 +1040,14 @@ void ExportModule(
serializer.serialize(module, extra_files, bytecode_format);
}
void ExportModule(
    const script::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func,
    const script::ExtraFilesMap& extra_files,
    bool bytecode_format) {
  // Route the serialized archive bytes through the caller's writer function.
  ScriptModuleSerializer exporter(writer_func);
  exporter.serialize(module, extra_files, bytecode_format);
}
} // namespace jit
} // namespace torch

View File

@ -55,6 +55,12 @@ TORCH_API void ExportModule(
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
bool bytecode_format = false);
TORCH_API void ExportModule(
const script::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
bool bytecode_format = false);
// Surrounding system can install an additional hook to produce extra files
// with metadata based on environment every time a module is serialized.
using ExportModuleExtraFilesHook =