Make JIT Serialization support arbitrary std::function<> IO (#28039)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28039

Right now, torch::save() writes through a std::ostream, which results in
unnecessary data copies in practice. The same holds for torch::load().

Adding a std::function<size_t(const void*, size_t)> as an output option,
parallel to the existing filename and std::ostream APIs, gives users the
flexibility to emit the serialized bytes directly to their backing store.
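
For instance, a caller can stream the archive straight into a std::string.
The following is a minimal sketch mirroring the BasicViaFunc test added in
this change (error handling omitted):

    #include <torch/torch.h>
    #include <string>

    void save_to_string_sketch() {
      torch::Tensor x = torch::randn({5, 5});
      std::string serialized;
      // The writer lambda receives each chunk of serialized bytes and
      // returns how many bytes it actually consumed.
      torch::save(x, [&](const void* buf, size_t nbytes) -> size_t {
        serialized.append(static_cast<const char*>(buf), nbytes);
        return nbytes;
      });
    }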

For the simple case of appending the output to a std::string, we observe
significant benchmark savings (on the order of -50%), even with the
minor std::function<> dispatch overhead. The main reason is that
std::ostringstream effectively requires two extra copies of the data
beyond a simple string::append lambda.
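
For comparison, a sketch of the pre-existing std::ostream route that the
lambda approach is measured against: the bytes are first buffered inside the
ostringstream and then copied again by str() before reaching the caller's
string:

    #include <torch/torch.h>
    #include <sstream>
    #include <string>

    void save_via_ostringstream_sketch() {
      torch::Tensor x = torch::randn({5, 5});
      std::ostringstream oss;
      torch::save(x, oss);                 // existing std::ostream overload
      std::string serialized = oss.str();  // extra copy out of the stream buffer
    }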

We also provide a parallel API for load(), though this one is slightly more
complex because the reader needs to perform arbitrary-position reads.
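
A sketch of the read side (again mirroring the BasicViaFunc test below):
torch::load takes a positional read function plus a size function, since the
underlying zip reader needs random-access reads:

    #include <torch/torch.h>
    #include <algorithm>
    #include <cstdint>
    #include <cstring>
    #include <string>

    void load_from_string_sketch(const std::string& serialized) {
      torch::Tensor y;
      torch::load(
          y,
          // read_func: copy up to nbytes starting at absolute offset pos.
          [&](uint64_t pos, void* buf, size_t nbytes) -> size_t {
            if (pos >= serialized.size()) {
              return 0;
            }
            size_t n =
                std::min(nbytes, serialized.size() - static_cast<size_t>(pos));
            memcpy(buf, serialized.data() + pos, n);
            return n;
          },
          // size_func: total size of the serialized blob.
          [&]() -> size_t { return serialized.size(); });
    }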

Test Plan:
buck test mode/dev-nosan caffe2/test/...
      (Basic serialization test in caffe2/test/cpp/api/serialize.cpp)
      Benchmark in experimental/jeremyl/c2/SerializationBench.cpp, with D17823443
        (1M time goes from 90ms -> 40ms, albeit with crc patch applied)

Differential Revision: D17939034

fbshipit-source-id: 344cce46f74b6438cb638a8cfbeccf4e1aa882d7
Authored by Jeremy Lilley on 2019-10-15 22:08:17 -07:00
Committed by Facebook Github Bot
parent 9cc4405dc9
commit 2e0294cb39
11 changed files with 194 additions and 41 deletions

View File

@@ -88,7 +88,9 @@ inline c10::optional<TempFile> try_make_tempfile(
   if (fd == -1) {
     return c10::nullopt;
   }
-  return TempFile(std::string(filename.begin(), filename.end()), fd);
+  // Don't make the string from string(filename.begin(), filename.end(), or
+  // there will be a trailing '\0' at the end.
+  return TempFile(filename.data(), fd);
 #endif // defined(_WIN32)
 }

View File

@@ -232,40 +232,51 @@ PyTorchStreamReader::~PyTorchStreamReader() {
   valid("closing reader for archive ", archive_name_.c_str());
 }
 
-size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
+size_t ostream_write_func(
+    void* pOpaque,
+    mz_uint64 file_ofs,
+    const void* pBuf,
+    size_t n) {
   auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
   if (self->current_pos_ != file_ofs) {
-    // xxx - windows ostringstream refuses to seek to the end of an empty string
-    // so we workaround this by not calling seek unless necessary
-    // in the case of the first write (to the empty string) file_ofs and
-    // current_pos_ will be 0 and the seek won't occur.
-    self->out_->seekp(file_ofs);
-    if(!*self->out_)
-      return 0;
+    CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
   }
-  self->out_->write(static_cast<const char*>(pBuf), n);
-  if(!*self->out_)
-    return 0;
-  self->current_pos_ = file_ofs + n;
-  return n;
+  size_t ret = self->writer_func_(pBuf, n);
+  if (n != ret) {
+    self->err_seen_ = true;
+  }
+  self->current_pos_ += ret;
+  return ret;
+}
+
+PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
+    : archive_name_(basename(file_name)) {
+  setup(file_name);
 }
 
 PyTorchStreamWriter::PyTorchStreamWriter(
-    std::string file_name,
-    std::ostream* out)
-    : ar_(caffe2::make_unique<mz_zip_archive>()),
-      archive_name_(basename(file_name)),
-      out_(out) {
+    const std::function<size_t(const void*, size_t)>& writer_func)
+    : archive_name_("archive"), writer_func_(writer_func) {
+  setup(archive_name_);
+}
+
+void PyTorchStreamWriter::setup(const string& file_name) {
+  ar_ = caffe2::make_unique<mz_zip_archive>();
   memset(ar_.get(), 0, sizeof(mz_zip_archive));
+  archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().
   if (archive_name_.size() == 0) {
     CAFFE_THROW("invalid file name: ", file_name);
   }
-  if (!out_) {
-    file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
-    out_ = &file_stream_;
+  if (!writer_func_) {
+    file_stream_.open(
+        file_name,
+        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
     valid("opening archive ", file_name.c_str());
+    writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
+      file_stream_.write(static_cast<const char*>(buf), nbytes);
+      return !file_stream_ ? 0 : nbytes;
+    };
   }
 
   ar_->m_pIO_opaque = this;
@@ -279,11 +290,14 @@ PyTorchStreamWriter::PyTorchStreamWriter(
   writeRecord("version", version.str().c_str(), version.str().size());
 }
 
-void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size, bool compress) {
+void PyTorchStreamWriter::writeRecord(
+    const std::string& name,
+    const void* data,
+    size_t size,
+    bool compress) {
   AT_ASSERT(!finalized_);
-  std::stringstream ss;
-  ss << archive_name_ << "/" << name;
-  const std::string& full_name = ss.str();
+  AT_ASSERT(!archive_name_plus_slash_.empty());
+  std::string full_name = archive_name_plus_slash_ + name;
   std::string padding = getPadding(ar_->m_archive_size, full_name, size);
   uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
   mz_zip_writer_add_mem_ex_v2(
@@ -310,8 +324,9 @@ void PyTorchStreamWriter::writeEndOfFile() {
   mz_zip_writer_finalize_archive(ar_.get());
   mz_zip_writer_end(ar_.get());
   valid("writing central directory for archive ", archive_name_.c_str());
-  if (file_stream_.is_open())
+  if (file_stream_.is_open()) {
     file_stream_.close();
+  }
 }
 
 void PyTorchStreamWriter::valid(const char* what, const char* info) {
@@ -324,7 +339,7 @@ void PyTorchStreamWriter::valid(const char* what, const char* info) {
         ": ",
         mz_zip_get_error_string(err));
   }
-  if (!*out_) {
+  if (err_seen_) {
     CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
   }
 }

View File

@@ -123,11 +123,15 @@ class CAFFE2_API PyTorchStreamReader final {
 
 class CAFFE2_API PyTorchStreamWriter final {
  public:
-  PyTorchStreamWriter(std::string archive_name, std::ostream* out=nullptr);
-  PyTorchStreamWriter(std::ostream* out)
-    : PyTorchStreamWriter("archive", out) {}
+  explicit PyTorchStreamWriter(std::string archive_name);
+  explicit PyTorchStreamWriter(
+      const std::function<size_t(const void*, size_t)>& writer_func);
 
-  void writeRecord(const std::string& name, const void* data, size_t size, bool compress = false);
+  void writeRecord(
+      const std::string& name,
+      const void* data,
+      size_t size,
+      bool compress = false);
   void writeEndOfFile();
 
   bool finalized() const {
@@ -141,13 +145,16 @@ class CAFFE2_API PyTorchStreamWriter final {
   ~PyTorchStreamWriter();
 
  private:
+  void setup(const string& file_name);
   void valid(const char* what, const char* info = "");
   size_t current_pos_ = 0;
   std::unique_ptr<mz_zip_archive> ar_;
   std::string archive_name_;
-  std::ostream* out_;
+  std::string archive_name_plus_slash_;
   std::ofstream file_stream_;
+  std::function<size_t(const void*, size_t)> writer_func_;
   bool finalized_ = false;
+  bool err_seen_ = false;
   friend size_t ostream_write_func(
       void* pOpaque,
       uint64_t file_ofs,

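For reference, a hedged sketch of driving the lower-level writer directly
through the new writer-function constructor (the record name and payload are
made up for illustration, and the include path assumes the header lives at
caffe2/serialize/inline_container.h):

    #include <caffe2/serialize/inline_container.h>
    #include <string>

    void write_archive_to_string_sketch(std::string& out) {
      caffe2::serialize::PyTorchStreamWriter writer(
          [&](const void* buf, size_t nbytes) -> size_t {
            out.append(static_cast<const char*>(buf), nbytes);
            return nbytes;  // a short write would set err_seen_ internally
          });
      const char payload[] = "example payload";
      writer.writeRecord("my_record", payload, sizeof(payload));
      writer.writeEndOfFile();  // finalizes the zip central directory
    }
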
View File

@@ -15,7 +15,10 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   std::ostringstream oss;
   // write records through writers
-  PyTorchStreamWriter writer(&oss);
+  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
+    oss.write(static_cast<const char*>(b), n);
+    return oss ? n : 0;
+  });
   std::array<char, 127> data1;
   for (int i = 0; i < data1.size(); ++i) {

View File

@@ -60,6 +60,37 @@ TEST(SerializeTest, BasicToFile) {
   ASSERT_TRUE(x.allclose(y));
 }
 
+TEST(SerializeTest, BasicViaFunc) {
+  torch::manual_seed(0);
+
+  auto x = torch::randn({5, 5});
+
+  std::string serialized;
+  torch::save(x, [&](const void* buf, size_t n) {
+    serialized.append(reinterpret_cast<const char *>(buf), n);
+    return n;
+  });
+  torch::Tensor y;
+  torch::load(y, serialized.data(), serialized.size());
+
+  ASSERT_TRUE(y.defined());
+  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
+  ASSERT_TRUE(x.allclose(y));
+
+  torch::Tensor z;
+  torch::load(z, [&](uint64_t pos, void* buf, size_t n) -> size_t {
+    if (pos >= serialized.size()) return 0;
+    size_t nbytes = std::min(static_cast<size_t>(pos) + n,
+                             serialized.size()) - pos;
+    memcpy(buf, serialized.data() + pos, nbytes);
+    return nbytes;
+  },
+  [&]() -> size_t { return serialized.size(); });
+  ASSERT_TRUE(z.defined());
+  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
+  ASSERT_TRUE(x.allclose(z));
+}
+
 TEST(SerializeTest, Resized) {
   torch::manual_seed(0);

View File

@@ -81,6 +81,17 @@ class TORCH_API InputArchive final {
   void load_from(std::istream& stream,
                  c10::optional<torch::Device> device = c10::nullopt);
 
+  // Loads given the specified flat array.
+  void load_from(const char* data, size_t size,
+                 c10::optional<torch::Device> device = c10::nullopt);
+
+  // Loads given the specified read and size functions.
+  void load_from(
+      const std::function<size_t(
+          uint64_t pos, void* buf, size_t nbytes)>& read_func,
+      const std::function<size_t(void)>& size_func,
+      c10::optional<torch::Device> device = c10::nullopt);
+
   /// Forwards all arguments to `read()`.
   /// Useful for generic code that can be re-used for both `InputArchive` and
   /// `OutputArchive` (where `operator()` forwards to `write()`).

View File

@@ -62,6 +62,10 @@ class TORCH_API OutputArchive final {
   /// `stream`.
   void save_to(std::ostream& stream);
 
+  /// Saves the `OutputArchive` into a serialized representation using the
+  /// given writer function.
+  void save_to(const std::function<size_t(const void*, size_t)>& func);
+
   /// Forwards all arguments to `write()`.
   /// Useful for generic code that can be re-used for both `OutputArchive` and
   /// `InputArchive` (where `operator()` forwards to `read()`).

View File

@@ -5,7 +5,7 @@
 #include <torch/csrc/jit/import.h>
 #include <torch/csrc/jit/script/module.h>
 
+#include <caffe2/serialize/read_adapter_interface.h>
 #include <c10/util/Exception.h>
 
 #include <istream>
@@ -94,5 +94,61 @@ void InputArchive::load_from(std::istream& stream,
     c10::optional<torch::Device> device /*= c10::nullopt*/) {
   module_ = torch::jit::load(stream, std::move(device));
 }
 
+void InputArchive::load_from(
+    const char* data,
+    size_t size,
+    c10::optional<torch::Device> device /*= c10::nullopt*/) {
+  using caffe2::serialize::ReadAdapterInterface;
+  class OurAdapter : public ReadAdapterInterface {
+   public:
+    OurAdapter(const char* data, size_t size)
+        : data_(data), size_(size) {
+    }
+    size_t size() const override { return size_; }
+    size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+        const override {
+      (void) what;
+      if (pos >= size_) {
+        return 0;
+      }
+      size_t nread = std::min(static_cast<size_t>(pos) + n, size_) - pos;
+      memcpy(buf, data_ + pos, nread);
+      return nread;
+    }
+   private:
+    const char* data_;
+    size_t size_;
+  };
+  std::unique_ptr<OurAdapter> adapter(new OurAdapter(data, size));
+  module_ = torch::jit::load(std::move(adapter), std::move(device));
+}
+
+void InputArchive::load_from(
+    const std::function<size_t(uint64_t, void*, size_t)>& read_func,
+    const std::function<size_t(void)>& size_func,
+    c10::optional<torch::Device> device /*= c10::nullopt*/) {
+  using caffe2::serialize::ReadAdapterInterface;
+  class OurAdapter : public ReadAdapterInterface {
+   public:
+    OurAdapter(const std::function<size_t(uint64_t, void*, size_t)>& read_func,
+               const std::function<size_t(void)>& size_func)
+        : read_func_(read_func),
+          size_func_(size_func) {
+    }
+    size_t size() const override { return size_func_(); }
+    size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+        const override {
+      (void)what;
+      return read_func_(pos, buf, n);
+    }
+   private:
+    const std::function<size_t(uint64_t, void*, size_t)>& read_func_;
+    const std::function<size_t(void)>& size_func_;
+  };
+  std::unique_ptr<OurAdapter> adapter(new OurAdapter(read_func, size_func));
+  module_ = torch::jit::load(std::move(adapter), std::move(device));
+}
+
 } // namespace serialize
 } // namespace torch

View File

@@ -42,5 +42,10 @@ void OutputArchive::save_to(const std::string& filename) {
 void OutputArchive::save_to(std::ostream& stream) {
   jit::ExportModule(module_, stream);
 }
 
+void OutputArchive::save_to(
+    const std::function<size_t(const void*, size_t)>& func) {
+  jit::ExportModule(module_, func);
+}
+
 } // namespace serialize
 } // namespace torch

View File

@@ -531,11 +531,12 @@ void GraphEncoder::EncodeTensor(
 
 class ScriptModuleSerializer {
  public:
-  ScriptModuleSerializer(const std::string& filename)
-      : writer_(filename.c_str()) {}
+  explicit ScriptModuleSerializer(const std::string& filename)
+      : writer_(filename) {}
 
-  ScriptModuleSerializer(std::ostream* ofs)
-      : ofs_(), writer_(ofs) {}
+  explicit ScriptModuleSerializer(
+      const std::function<size_t(const void *, size_t)>& writer_func)
+      : writer_(writer_func) {}
 
   void serialize(
       const script::Module& module,
@@ -772,7 +773,6 @@ class ScriptModuleSerializer {
     converted_types_.insert(class_type, std::move(info));
   }
 
-  std::ofstream ofs_;
   caffe2::serialize::PyTorchStreamWriter writer_;
   std::vector<at::Tensor> constant_table_;
@@ -1023,7 +1023,11 @@ void ExportModule(
     std::ostream& out,
     const script::ExtraFilesMap& extra_files,
     bool bytecode_format) {
-  ScriptModuleSerializer serializer(&out);
+  ScriptModuleSerializer serializer(
+      [&](const void* buf, size_t nbytes) -> size_t {
+        out.write(static_cast<const char *>(buf), nbytes);
+        return !out ? 0 : nbytes;
+      });
   serializer.serialize(module, extra_files, bytecode_format);
 }
@@ -1036,5 +1040,14 @@ void ExportModule(
   serializer.serialize(module, extra_files, bytecode_format);
 }
 
+void ExportModule(
+    const script::Module& module,
+    const std::function<size_t(const void*, size_t)>& writer_func,
+    const script::ExtraFilesMap& extra_files,
+    bool bytecode_format) {
+  ScriptModuleSerializer serializer(writer_func);
+  serializer.serialize(module, extra_files, bytecode_format);
+}
+
 } // namespace jit
 } // namespace torch

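A hedged sketch of calling the new ExportModule overload added above to
serialize a scripted module into a std::string (the include path is an
assumption; extra_files and bytecode_format keep their defaults):

    #include <torch/csrc/jit/export.h>
    #include <string>

    std::string export_module_to_string_sketch(
        const torch::jit::script::Module& module) {
      std::string serialized;
      torch::jit::ExportModule(
          module, [&](const void* buf, size_t nbytes) -> size_t {
            // Append each chunk of the zip archive as it is produced.
            serialized.append(static_cast<const char*>(buf), nbytes);
            return nbytes;
          });
      return serialized;
    }
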
View File

@@ -55,6 +55,12 @@ TORCH_API void ExportModule(
     const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
     bool bytecode_format = false);
 
+TORCH_API void ExportModule(
+    const script::Module& module,
+    const std::function<size_t(const void*, size_t)>& writer_func,
+    const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
+    bool bytecode_format = false);
+
 // Surrounding system can install an additional hook to produce extra files
 // with metadata based on environment every time a module is serialized.
 using ExportModuleExtraFilesHook =