mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make JIT Serialization support arbitrary std::function<> IO (#28039)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28039 Right now, torch::save() uses std::ostream, which results in unnecessary data copies in practice. Similar for torch::load(). Adding a std::function<size_t(const void*, size_t)> as an output option, parallel to the existing filename and std::ostream apis, gives users the flexibility to emit directly to a backing store. For a simple case of appending the output to a std::string, we observe significant benchmark savings (on order of -50%), even with the minor std::function<> dispatch overhead. The main reason is that std::ostringstream effectively requires 2 extra copies of the data beyond a simple string.append lambda. We also provide a parallel api for the load(), though this one is slightly more complex due to the need to do arbitrary position reads. Test Plan: buck test mode/dev-nosan caffe2/test/... (Basic serialization test in caffe2/test/cpp/api/serialize.cpp) Benchmark in experimental/jeremyl/c2/SerializationBench.cpp, with D17823443 (1M time goes from 90ms -> 40ms, albeit with crc patch applied) Differential Revision: D17939034 fbshipit-source-id: 344cce46f74b6438cb638a8cfbeccf4e1aa882d7
This commit is contained in:
committed by
Facebook Github Bot
parent
9cc4405dc9
commit
2e0294cb39
@ -88,7 +88,9 @@ inline c10::optional<TempFile> try_make_tempfile(
|
||||
if (fd == -1) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
return TempFile(std::string(filename.begin(), filename.end()), fd);
|
||||
// Don't make the string from string(filename.begin(), filename.end(), or
|
||||
// there will be a trailing '\0' at the end.
|
||||
return TempFile(filename.data(), fd);
|
||||
#endif // defined(_WIN32)
|
||||
}
|
||||
|
||||
|
||||
@ -232,40 +232,51 @@ PyTorchStreamReader::~PyTorchStreamReader() {
|
||||
valid("closing reader for archive ", archive_name_.c_str());
|
||||
}
|
||||
|
||||
size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
|
||||
size_t ostream_write_func(
|
||||
void* pOpaque,
|
||||
mz_uint64 file_ofs,
|
||||
const void* pBuf,
|
||||
size_t n) {
|
||||
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
|
||||
if (self->current_pos_ != file_ofs) {
|
||||
// xxx - windows ostringstream refuses to seek to the end of an empty string
|
||||
// so we workaround this by not calling seek unless necessary
|
||||
// in the case of the first write (to the empty string) file_ofs and
|
||||
// current_pos_ will be 0 and the seek won't occur.
|
||||
self->out_->seekp(file_ofs);
|
||||
if(!*self->out_)
|
||||
return 0;
|
||||
CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
|
||||
}
|
||||
size_t ret = self->writer_func_(pBuf, n);
|
||||
if (n != ret) {
|
||||
self->err_seen_ = true;
|
||||
}
|
||||
self->current_pos_ += ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
self->out_->write(static_cast<const char*>(pBuf), n);
|
||||
if(!*self->out_)
|
||||
return 0;
|
||||
self->current_pos_ = file_ofs + n;
|
||||
return n;
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
|
||||
: archive_name_(basename(file_name)) {
|
||||
setup(file_name);
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
std::string file_name,
|
||||
std::ostream* out)
|
||||
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
||||
archive_name_(basename(file_name)),
|
||||
out_(out) {
|
||||
const std::function<size_t(const void*, size_t)>& writer_func)
|
||||
: archive_name_("archive"), writer_func_(writer_func) {
|
||||
setup(archive_name_);
|
||||
}
|
||||
|
||||
void PyTorchStreamWriter::setup(const string& file_name) {
|
||||
ar_ = caffe2::make_unique<mz_zip_archive>();
|
||||
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
||||
archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().
|
||||
|
||||
if (archive_name_.size() == 0) {
|
||||
CAFFE_THROW("invalid file name: ", file_name);
|
||||
}
|
||||
if (!out_) {
|
||||
file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
||||
out_ = &file_stream_;
|
||||
if (!writer_func_) {
|
||||
file_stream_.open(
|
||||
file_name,
|
||||
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
||||
valid("opening archive ", file_name.c_str());
|
||||
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
|
||||
file_stream_.write(static_cast<const char*>(buf), nbytes);
|
||||
return !file_stream_ ? 0 : nbytes;
|
||||
};
|
||||
}
|
||||
|
||||
ar_->m_pIO_opaque = this;
|
||||
@ -279,11 +290,14 @@ PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
writeRecord("version", version.str().c_str(), version.str().size());
|
||||
}
|
||||
|
||||
void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size, bool compress) {
|
||||
void PyTorchStreamWriter::writeRecord(
|
||||
const std::string& name,
|
||||
const void* data,
|
||||
size_t size,
|
||||
bool compress) {
|
||||
AT_ASSERT(!finalized_);
|
||||
std::stringstream ss;
|
||||
ss << archive_name_ << "/" << name;
|
||||
const std::string& full_name = ss.str();
|
||||
AT_ASSERT(!archive_name_plus_slash_.empty());
|
||||
std::string full_name = archive_name_plus_slash_ + name;
|
||||
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
|
||||
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
|
||||
mz_zip_writer_add_mem_ex_v2(
|
||||
@ -310,8 +324,9 @@ void PyTorchStreamWriter::writeEndOfFile() {
|
||||
mz_zip_writer_finalize_archive(ar_.get());
|
||||
mz_zip_writer_end(ar_.get());
|
||||
valid("writing central directory for archive ", archive_name_.c_str());
|
||||
if (file_stream_.is_open())
|
||||
if (file_stream_.is_open()) {
|
||||
file_stream_.close();
|
||||
}
|
||||
}
|
||||
|
||||
void PyTorchStreamWriter::valid(const char* what, const char* info) {
|
||||
@ -324,7 +339,7 @@ void PyTorchStreamWriter::valid(const char* what, const char* info) {
|
||||
": ",
|
||||
mz_zip_get_error_string(err));
|
||||
}
|
||||
if (!*out_) {
|
||||
if (err_seen_) {
|
||||
CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,11 +123,15 @@ class CAFFE2_API PyTorchStreamReader final {
|
||||
|
||||
class CAFFE2_API PyTorchStreamWriter final {
|
||||
public:
|
||||
PyTorchStreamWriter(std::string archive_name, std::ostream* out=nullptr);
|
||||
PyTorchStreamWriter(std::ostream* out)
|
||||
: PyTorchStreamWriter("archive", out) {}
|
||||
explicit PyTorchStreamWriter(std::string archive_name);
|
||||
explicit PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)>& writer_func);
|
||||
|
||||
void writeRecord(const std::string& name, const void* data, size_t size, bool compress = false);
|
||||
void writeRecord(
|
||||
const std::string& name,
|
||||
const void* data,
|
||||
size_t size,
|
||||
bool compress = false);
|
||||
void writeEndOfFile();
|
||||
|
||||
bool finalized() const {
|
||||
@ -141,13 +145,16 @@ class CAFFE2_API PyTorchStreamWriter final {
|
||||
~PyTorchStreamWriter();
|
||||
|
||||
private:
|
||||
void setup(const string& file_name);
|
||||
void valid(const char* what, const char* info = "");
|
||||
size_t current_pos_ = 0;
|
||||
std::unique_ptr<mz_zip_archive> ar_;
|
||||
std::string archive_name_;
|
||||
std::ostream* out_;
|
||||
std::string archive_name_plus_slash_;
|
||||
std::ofstream file_stream_;
|
||||
std::function<size_t(const void*, size_t)> writer_func_;
|
||||
bool finalized_ = false;
|
||||
bool err_seen_ = false;
|
||||
friend size_t ostream_write_func(
|
||||
void* pOpaque,
|
||||
uint64_t file_ofs,
|
||||
|
||||
@ -15,7 +15,10 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
|
||||
std::ostringstream oss;
|
||||
// write records through writers
|
||||
PyTorchStreamWriter writer(&oss);
|
||||
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
||||
oss.write(static_cast<const char*>(b), n);
|
||||
return oss ? n : 0;
|
||||
});
|
||||
std::array<char, 127> data1;
|
||||
|
||||
for (int i = 0; i < data1.size(); ++i) {
|
||||
|
||||
@ -60,6 +60,37 @@ TEST(SerializeTest, BasicToFile) {
|
||||
ASSERT_TRUE(x.allclose(y));
|
||||
}
|
||||
|
||||
TEST(SerializeTest, BasicViaFunc) {
|
||||
torch::manual_seed(0);
|
||||
|
||||
auto x = torch::randn({5, 5});
|
||||
|
||||
std::string serialized;
|
||||
torch::save(x, [&](const void* buf, size_t n) {
|
||||
serialized.append(reinterpret_cast<const char *>(buf), n);
|
||||
return n;
|
||||
});
|
||||
torch::Tensor y;
|
||||
torch::load(y, serialized.data(), serialized.size());
|
||||
|
||||
ASSERT_TRUE(y.defined());
|
||||
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
|
||||
ASSERT_TRUE(x.allclose(y));
|
||||
|
||||
torch::Tensor z;
|
||||
torch::load(z, [&](uint64_t pos, void* buf, size_t n) -> size_t {
|
||||
if (pos >= serialized.size()) return 0;
|
||||
size_t nbytes = std::min(static_cast<size_t>(pos) + n,
|
||||
serialized.size()) - pos;
|
||||
memcpy(buf, serialized.data() + pos, nbytes);
|
||||
return nbytes;
|
||||
},
|
||||
[&]() -> size_t { return serialized.size(); });
|
||||
ASSERT_TRUE(z.defined());
|
||||
ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
|
||||
ASSERT_TRUE(x.allclose(z));
|
||||
}
|
||||
|
||||
TEST(SerializeTest, Resized) {
|
||||
torch::manual_seed(0);
|
||||
|
||||
|
||||
@ -81,6 +81,17 @@ class TORCH_API InputArchive final {
|
||||
void load_from(std::istream& stream,
|
||||
c10::optional<torch::Device> device = c10::nullopt);
|
||||
|
||||
// Loads given the specified flat array.
|
||||
void load_from(const char* data, size_t size,
|
||||
c10::optional<torch::Device> device = c10::nullopt);
|
||||
|
||||
// Loads given the specified read and size functions.
|
||||
void load_from(
|
||||
const std::function<size_t(
|
||||
uint64_t pos, void* buf, size_t nbytes)>& read_func,
|
||||
const std::function<size_t(void)>& size_func,
|
||||
c10::optional<torch::Device> device = c10::nullopt);
|
||||
|
||||
/// Forwards all arguments to `read()`.
|
||||
/// Useful for generic code that can be re-used for both `InputArchive` and
|
||||
/// `OutputArchive` (where `operator()` forwards to `write()`).
|
||||
|
||||
@ -62,6 +62,10 @@ class TORCH_API OutputArchive final {
|
||||
/// `stream`.
|
||||
void save_to(std::ostream& stream);
|
||||
|
||||
/// Saves the `OutputArchive` into a serialized representation using the
|
||||
/// given writer function.
|
||||
void save_to(const std::function<size_t(const void*, size_t)>& func);
|
||||
|
||||
/// Forwards all arguments to `write()`.
|
||||
/// Useful for generic code that can be re-used for both `OutputArchive` and
|
||||
/// `InputArchive` (where `operator()` forwards to `read()`).
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
#include <torch/csrc/jit/import.h>
|
||||
#include <torch/csrc/jit/script/module.h>
|
||||
|
||||
#include <caffe2/serialize/read_adapter_interface.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <istream>
|
||||
@ -94,5 +94,61 @@ void InputArchive::load_from(std::istream& stream,
|
||||
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
||||
module_ = torch::jit::load(stream, std::move(device));
|
||||
}
|
||||
|
||||
void InputArchive::load_from(
|
||||
const char* data,
|
||||
size_t size,
|
||||
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
||||
using caffe2::serialize::ReadAdapterInterface;
|
||||
class OurAdapter : public ReadAdapterInterface {
|
||||
public:
|
||||
OurAdapter(const char* data, size_t size)
|
||||
: data_(data), size_(size) {
|
||||
}
|
||||
size_t size() const override { return size_; }
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override {
|
||||
(void) what;
|
||||
if (pos >= size_) {
|
||||
return 0;
|
||||
}
|
||||
size_t nread = std::min(static_cast<size_t>(pos) + n, size_) - pos;
|
||||
memcpy(buf, data_ + pos, nread);
|
||||
return nread;
|
||||
}
|
||||
private:
|
||||
const char* data_;
|
||||
size_t size_;
|
||||
};
|
||||
std::unique_ptr<OurAdapter> adapter(new OurAdapter(data, size));
|
||||
module_ = torch::jit::load(std::move(adapter), std::move(device));
|
||||
}
|
||||
|
||||
void InputArchive::load_from(
|
||||
const std::function<size_t(uint64_t, void*, size_t)>& read_func,
|
||||
const std::function<size_t(void)>& size_func,
|
||||
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
||||
using caffe2::serialize::ReadAdapterInterface;
|
||||
class OurAdapter : public ReadAdapterInterface {
|
||||
public:
|
||||
OurAdapter(const std::function<size_t(uint64_t, void*, size_t)>& read_func,
|
||||
const std::function<size_t(void)>& size_func)
|
||||
: read_func_(read_func),
|
||||
size_func_(size_func) {
|
||||
}
|
||||
size_t size() const override { return size_func_(); }
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override {
|
||||
(void)what;
|
||||
return read_func_(pos, buf, n);
|
||||
}
|
||||
private:
|
||||
const std::function<size_t(uint64_t, void*, size_t)>& read_func_;
|
||||
const std::function<size_t(void)>& size_func_;
|
||||
};
|
||||
std::unique_ptr<OurAdapter> adapter(new OurAdapter(read_func, size_func));
|
||||
module_ = torch::jit::load(std::move(adapter), std::move(device));
|
||||
}
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace torch
|
||||
|
||||
@ -42,5 +42,10 @@ void OutputArchive::save_to(const std::string& filename) {
|
||||
void OutputArchive::save_to(std::ostream& stream) {
|
||||
jit::ExportModule(module_, stream);
|
||||
}
|
||||
|
||||
void OutputArchive::save_to(
|
||||
const std::function<size_t(const void*, size_t)>& func) {
|
||||
jit::ExportModule(module_, func);
|
||||
}
|
||||
} // namespace serialize
|
||||
} // namespace torch
|
||||
|
||||
@ -531,11 +531,12 @@ void GraphEncoder::EncodeTensor(
|
||||
|
||||
class ScriptModuleSerializer {
|
||||
public:
|
||||
ScriptModuleSerializer(const std::string& filename)
|
||||
: writer_(filename.c_str()) {}
|
||||
explicit ScriptModuleSerializer(const std::string& filename)
|
||||
: writer_(filename) {}
|
||||
|
||||
ScriptModuleSerializer(std::ostream* ofs)
|
||||
: ofs_(), writer_(ofs) {}
|
||||
explicit ScriptModuleSerializer(
|
||||
const std::function<size_t(const void *, size_t)>& writer_func)
|
||||
: writer_(writer_func) {}
|
||||
|
||||
void serialize(
|
||||
const script::Module& module,
|
||||
@ -772,7 +773,6 @@ class ScriptModuleSerializer {
|
||||
converted_types_.insert(class_type, std::move(info));
|
||||
}
|
||||
|
||||
std::ofstream ofs_;
|
||||
caffe2::serialize::PyTorchStreamWriter writer_;
|
||||
std::vector<at::Tensor> constant_table_;
|
||||
|
||||
@ -1023,7 +1023,11 @@ void ExportModule(
|
||||
std::ostream& out,
|
||||
const script::ExtraFilesMap& extra_files,
|
||||
bool bytecode_format) {
|
||||
ScriptModuleSerializer serializer(&out);
|
||||
ScriptModuleSerializer serializer(
|
||||
[&](const void* buf, size_t nbytes) -> size_t {
|
||||
out.write(static_cast<const char *>(buf), nbytes);
|
||||
return !out ? 0 : nbytes;
|
||||
});
|
||||
serializer.serialize(module, extra_files, bytecode_format);
|
||||
}
|
||||
|
||||
@ -1036,5 +1040,14 @@ void ExportModule(
|
||||
serializer.serialize(module, extra_files, bytecode_format);
|
||||
}
|
||||
|
||||
void ExportModule(
|
||||
const script::Module& module,
|
||||
const std::function<size_t(const void*, size_t)>& writer_func,
|
||||
const script::ExtraFilesMap& extra_files,
|
||||
bool bytecode_format) {
|
||||
ScriptModuleSerializer serializer(writer_func);
|
||||
serializer.serialize(module, extra_files, bytecode_format);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -55,6 +55,12 @@ TORCH_API void ExportModule(
|
||||
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
|
||||
bool bytecode_format = false);
|
||||
|
||||
TORCH_API void ExportModule(
|
||||
const script::Module& module,
|
||||
const std::function<size_t(const void*, size_t)>& writer_func,
|
||||
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
|
||||
bool bytecode_format = false);
|
||||
|
||||
// Surrounding system can install an additional hook to produce extra files
|
||||
// with metadata based on environment every time a module is serialized.
|
||||
using ExportModuleExtraFilesHook =
|
||||
|
||||
Reference in New Issue
Block a user