mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make JIT Serialization support arbitrary std::function<> IO (#28039)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28039 Right now, torch::save() uses std::ostream, which results in unnecessary data copies in practice. Similar for torch::load(). Adding a std::function<size_t(const void*, size_t)> as an output option, parallel to the existing filename and std::ostream apis, gives users the flexibility to emit directly to a backing store. For a simple case of appending the output to a std::string, we observe significant benchmark savings (on order of -50%), even with the minor std::function<> dispatch overhead. The main reason is that std::ostringstream effectively requires 2 extra copies of the data beyond a simple string.append lambda. We also provide a parallel api for the load(), though this one is slightly more complex due to the need to do arbitrary position reads. Test Plan: buck test mode/dev-nosan caffe2/test/... (Basic serialization test in caffe2/test/cpp/api/serialize.cpp) Benchmark in experimental/jeremyl/c2/SerializationBench.cpp, with D17823443 (1M time goes from 90ms -> 40ms, albeit with crc patch applied) Differential Revision: D17939034 fbshipit-source-id: 344cce46f74b6438cb638a8cfbeccf4e1aa882d7
This commit is contained in:
committed by
Facebook Github Bot
parent
9cc4405dc9
commit
2e0294cb39
@ -88,7 +88,9 @@ inline c10::optional<TempFile> try_make_tempfile(
|
|||||||
if (fd == -1) {
|
if (fd == -1) {
|
||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
return TempFile(std::string(filename.begin(), filename.end()), fd);
|
// Don't make the string from string(filename.begin(), filename.end(), or
|
||||||
|
// there will be a trailing '\0' at the end.
|
||||||
|
return TempFile(filename.data(), fd);
|
||||||
#endif // defined(_WIN32)
|
#endif // defined(_WIN32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,40 +232,51 @@ PyTorchStreamReader::~PyTorchStreamReader() {
|
|||||||
valid("closing reader for archive ", archive_name_.c_str());
|
valid("closing reader for archive ", archive_name_.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
|
size_t ostream_write_func(
|
||||||
|
void* pOpaque,
|
||||||
|
mz_uint64 file_ofs,
|
||||||
|
const void* pBuf,
|
||||||
|
size_t n) {
|
||||||
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
|
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
|
||||||
if (self->current_pos_ != file_ofs) {
|
if (self->current_pos_ != file_ofs) {
|
||||||
// xxx - windows ostringstream refuses to seek to the end of an empty string
|
CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
|
||||||
// so we workaround this by not calling seek unless necessary
|
|
||||||
// in the case of the first write (to the empty string) file_ofs and
|
|
||||||
// current_pos_ will be 0 and the seek won't occur.
|
|
||||||
self->out_->seekp(file_ofs);
|
|
||||||
if(!*self->out_)
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
size_t ret = self->writer_func_(pBuf, n);
|
||||||
|
if (n != ret) {
|
||||||
|
self->err_seen_ = true;
|
||||||
|
}
|
||||||
|
self->current_pos_ += ret;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
self->out_->write(static_cast<const char*>(pBuf), n);
|
PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
|
||||||
if(!*self->out_)
|
: archive_name_(basename(file_name)) {
|
||||||
return 0;
|
setup(file_name);
|
||||||
self->current_pos_ = file_ofs + n;
|
|
||||||
return n;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||||
std::string file_name,
|
const std::function<size_t(const void*, size_t)>& writer_func)
|
||||||
std::ostream* out)
|
: archive_name_("archive"), writer_func_(writer_func) {
|
||||||
: ar_(caffe2::make_unique<mz_zip_archive>()),
|
setup(archive_name_);
|
||||||
archive_name_(basename(file_name)),
|
}
|
||||||
out_(out) {
|
|
||||||
|
void PyTorchStreamWriter::setup(const string& file_name) {
|
||||||
|
ar_ = caffe2::make_unique<mz_zip_archive>();
|
||||||
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
memset(ar_.get(), 0, sizeof(mz_zip_archive));
|
||||||
|
archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().
|
||||||
|
|
||||||
if (archive_name_.size() == 0) {
|
if (archive_name_.size() == 0) {
|
||||||
CAFFE_THROW("invalid file name: ", file_name);
|
CAFFE_THROW("invalid file name: ", file_name);
|
||||||
}
|
}
|
||||||
if (!out_) {
|
if (!writer_func_) {
|
||||||
file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
file_stream_.open(
|
||||||
out_ = &file_stream_;
|
file_name,
|
||||||
|
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
||||||
valid("opening archive ", file_name.c_str());
|
valid("opening archive ", file_name.c_str());
|
||||||
|
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
|
||||||
|
file_stream_.write(static_cast<const char*>(buf), nbytes);
|
||||||
|
return !file_stream_ ? 0 : nbytes;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
ar_->m_pIO_opaque = this;
|
ar_->m_pIO_opaque = this;
|
||||||
@ -279,11 +290,14 @@ PyTorchStreamWriter::PyTorchStreamWriter(
|
|||||||
writeRecord("version", version.str().c_str(), version.str().size());
|
writeRecord("version", version.str().c_str(), version.str().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size, bool compress) {
|
void PyTorchStreamWriter::writeRecord(
|
||||||
|
const std::string& name,
|
||||||
|
const void* data,
|
||||||
|
size_t size,
|
||||||
|
bool compress) {
|
||||||
AT_ASSERT(!finalized_);
|
AT_ASSERT(!finalized_);
|
||||||
std::stringstream ss;
|
AT_ASSERT(!archive_name_plus_slash_.empty());
|
||||||
ss << archive_name_ << "/" << name;
|
std::string full_name = archive_name_plus_slash_ + name;
|
||||||
const std::string& full_name = ss.str();
|
|
||||||
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
|
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
|
||||||
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
|
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
|
||||||
mz_zip_writer_add_mem_ex_v2(
|
mz_zip_writer_add_mem_ex_v2(
|
||||||
@ -310,8 +324,9 @@ void PyTorchStreamWriter::writeEndOfFile() {
|
|||||||
mz_zip_writer_finalize_archive(ar_.get());
|
mz_zip_writer_finalize_archive(ar_.get());
|
||||||
mz_zip_writer_end(ar_.get());
|
mz_zip_writer_end(ar_.get());
|
||||||
valid("writing central directory for archive ", archive_name_.c_str());
|
valid("writing central directory for archive ", archive_name_.c_str());
|
||||||
if (file_stream_.is_open())
|
if (file_stream_.is_open()) {
|
||||||
file_stream_.close();
|
file_stream_.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PyTorchStreamWriter::valid(const char* what, const char* info) {
|
void PyTorchStreamWriter::valid(const char* what, const char* info) {
|
||||||
@ -324,7 +339,7 @@ void PyTorchStreamWriter::valid(const char* what, const char* info) {
|
|||||||
": ",
|
": ",
|
||||||
mz_zip_get_error_string(err));
|
mz_zip_get_error_string(err));
|
||||||
}
|
}
|
||||||
if (!*out_) {
|
if (err_seen_) {
|
||||||
CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
|
CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,11 +123,15 @@ class CAFFE2_API PyTorchStreamReader final {
|
|||||||
|
|
||||||
class CAFFE2_API PyTorchStreamWriter final {
|
class CAFFE2_API PyTorchStreamWriter final {
|
||||||
public:
|
public:
|
||||||
PyTorchStreamWriter(std::string archive_name, std::ostream* out=nullptr);
|
explicit PyTorchStreamWriter(std::string archive_name);
|
||||||
PyTorchStreamWriter(std::ostream* out)
|
explicit PyTorchStreamWriter(
|
||||||
: PyTorchStreamWriter("archive", out) {}
|
const std::function<size_t(const void*, size_t)>& writer_func);
|
||||||
|
|
||||||
void writeRecord(const std::string& name, const void* data, size_t size, bool compress = false);
|
void writeRecord(
|
||||||
|
const std::string& name,
|
||||||
|
const void* data,
|
||||||
|
size_t size,
|
||||||
|
bool compress = false);
|
||||||
void writeEndOfFile();
|
void writeEndOfFile();
|
||||||
|
|
||||||
bool finalized() const {
|
bool finalized() const {
|
||||||
@ -141,13 +145,16 @@ class CAFFE2_API PyTorchStreamWriter final {
|
|||||||
~PyTorchStreamWriter();
|
~PyTorchStreamWriter();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void setup(const string& file_name);
|
||||||
void valid(const char* what, const char* info = "");
|
void valid(const char* what, const char* info = "");
|
||||||
size_t current_pos_ = 0;
|
size_t current_pos_ = 0;
|
||||||
std::unique_ptr<mz_zip_archive> ar_;
|
std::unique_ptr<mz_zip_archive> ar_;
|
||||||
std::string archive_name_;
|
std::string archive_name_;
|
||||||
std::ostream* out_;
|
std::string archive_name_plus_slash_;
|
||||||
std::ofstream file_stream_;
|
std::ofstream file_stream_;
|
||||||
|
std::function<size_t(const void*, size_t)> writer_func_;
|
||||||
bool finalized_ = false;
|
bool finalized_ = false;
|
||||||
|
bool err_seen_ = false;
|
||||||
friend size_t ostream_write_func(
|
friend size_t ostream_write_func(
|
||||||
void* pOpaque,
|
void* pOpaque,
|
||||||
uint64_t file_ofs,
|
uint64_t file_ofs,
|
||||||
|
@ -15,7 +15,10 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
|||||||
|
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
// write records through writers
|
// write records through writers
|
||||||
PyTorchStreamWriter writer(&oss);
|
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
||||||
|
oss.write(static_cast<const char*>(b), n);
|
||||||
|
return oss ? n : 0;
|
||||||
|
});
|
||||||
std::array<char, 127> data1;
|
std::array<char, 127> data1;
|
||||||
|
|
||||||
for (int i = 0; i < data1.size(); ++i) {
|
for (int i = 0; i < data1.size(); ++i) {
|
||||||
|
@ -60,6 +60,37 @@ TEST(SerializeTest, BasicToFile) {
|
|||||||
ASSERT_TRUE(x.allclose(y));
|
ASSERT_TRUE(x.allclose(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SerializeTest, BasicViaFunc) {
|
||||||
|
torch::manual_seed(0);
|
||||||
|
|
||||||
|
auto x = torch::randn({5, 5});
|
||||||
|
|
||||||
|
std::string serialized;
|
||||||
|
torch::save(x, [&](const void* buf, size_t n) {
|
||||||
|
serialized.append(reinterpret_cast<const char *>(buf), n);
|
||||||
|
return n;
|
||||||
|
});
|
||||||
|
torch::Tensor y;
|
||||||
|
torch::load(y, serialized.data(), serialized.size());
|
||||||
|
|
||||||
|
ASSERT_TRUE(y.defined());
|
||||||
|
ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
|
||||||
|
ASSERT_TRUE(x.allclose(y));
|
||||||
|
|
||||||
|
torch::Tensor z;
|
||||||
|
torch::load(z, [&](uint64_t pos, void* buf, size_t n) -> size_t {
|
||||||
|
if (pos >= serialized.size()) return 0;
|
||||||
|
size_t nbytes = std::min(static_cast<size_t>(pos) + n,
|
||||||
|
serialized.size()) - pos;
|
||||||
|
memcpy(buf, serialized.data() + pos, nbytes);
|
||||||
|
return nbytes;
|
||||||
|
},
|
||||||
|
[&]() -> size_t { return serialized.size(); });
|
||||||
|
ASSERT_TRUE(z.defined());
|
||||||
|
ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
|
||||||
|
ASSERT_TRUE(x.allclose(z));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SerializeTest, Resized) {
|
TEST(SerializeTest, Resized) {
|
||||||
torch::manual_seed(0);
|
torch::manual_seed(0);
|
||||||
|
|
||||||
|
@ -81,6 +81,17 @@ class TORCH_API InputArchive final {
|
|||||||
void load_from(std::istream& stream,
|
void load_from(std::istream& stream,
|
||||||
c10::optional<torch::Device> device = c10::nullopt);
|
c10::optional<torch::Device> device = c10::nullopt);
|
||||||
|
|
||||||
|
// Loads given the specified flat array.
|
||||||
|
void load_from(const char* data, size_t size,
|
||||||
|
c10::optional<torch::Device> device = c10::nullopt);
|
||||||
|
|
||||||
|
// Loads given the specified read and size functions.
|
||||||
|
void load_from(
|
||||||
|
const std::function<size_t(
|
||||||
|
uint64_t pos, void* buf, size_t nbytes)>& read_func,
|
||||||
|
const std::function<size_t(void)>& size_func,
|
||||||
|
c10::optional<torch::Device> device = c10::nullopt);
|
||||||
|
|
||||||
/// Forwards all arguments to `read()`.
|
/// Forwards all arguments to `read()`.
|
||||||
/// Useful for generic code that can be re-used for both `InputArchive` and
|
/// Useful for generic code that can be re-used for both `InputArchive` and
|
||||||
/// `OutputArchive` (where `operator()` forwards to `write()`).
|
/// `OutputArchive` (where `operator()` forwards to `write()`).
|
||||||
|
@ -62,6 +62,10 @@ class TORCH_API OutputArchive final {
|
|||||||
/// `stream`.
|
/// `stream`.
|
||||||
void save_to(std::ostream& stream);
|
void save_to(std::ostream& stream);
|
||||||
|
|
||||||
|
/// Saves the `OutputArchive` into a serialized representation using the
|
||||||
|
/// given writer function.
|
||||||
|
void save_to(const std::function<size_t(const void*, size_t)>& func);
|
||||||
|
|
||||||
/// Forwards all arguments to `write()`.
|
/// Forwards all arguments to `write()`.
|
||||||
/// Useful for generic code that can be re-used for both `OutputArchive` and
|
/// Useful for generic code that can be re-used for both `OutputArchive` and
|
||||||
/// `InputArchive` (where `operator()` forwards to `read()`).
|
/// `InputArchive` (where `operator()` forwards to `read()`).
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include <torch/csrc/jit/import.h>
|
#include <torch/csrc/jit/import.h>
|
||||||
#include <torch/csrc/jit/script/module.h>
|
#include <torch/csrc/jit/script/module.h>
|
||||||
|
#include <caffe2/serialize/read_adapter_interface.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
#include <istream>
|
#include <istream>
|
||||||
@ -94,5 +94,61 @@ void InputArchive::load_from(std::istream& stream,
|
|||||||
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
||||||
module_ = torch::jit::load(stream, std::move(device));
|
module_ = torch::jit::load(stream, std::move(device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InputArchive::load_from(
|
||||||
|
const char* data,
|
||||||
|
size_t size,
|
||||||
|
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
||||||
|
using caffe2::serialize::ReadAdapterInterface;
|
||||||
|
class OurAdapter : public ReadAdapterInterface {
|
||||||
|
public:
|
||||||
|
OurAdapter(const char* data, size_t size)
|
||||||
|
: data_(data), size_(size) {
|
||||||
|
}
|
||||||
|
size_t size() const override { return size_; }
|
||||||
|
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||||
|
const override {
|
||||||
|
(void) what;
|
||||||
|
if (pos >= size_) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t nread = std::min(static_cast<size_t>(pos) + n, size_) - pos;
|
||||||
|
memcpy(buf, data_ + pos, nread);
|
||||||
|
return nread;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
const char* data_;
|
||||||
|
size_t size_;
|
||||||
|
};
|
||||||
|
std::unique_ptr<OurAdapter> adapter(new OurAdapter(data, size));
|
||||||
|
module_ = torch::jit::load(std::move(adapter), std::move(device));
|
||||||
|
}
|
||||||
|
|
||||||
|
void InputArchive::load_from(
|
||||||
|
const std::function<size_t(uint64_t, void*, size_t)>& read_func,
|
||||||
|
const std::function<size_t(void)>& size_func,
|
||||||
|
c10::optional<torch::Device> device /*= c10::nullopt*/) {
|
||||||
|
using caffe2::serialize::ReadAdapterInterface;
|
||||||
|
class OurAdapter : public ReadAdapterInterface {
|
||||||
|
public:
|
||||||
|
OurAdapter(const std::function<size_t(uint64_t, void*, size_t)>& read_func,
|
||||||
|
const std::function<size_t(void)>& size_func)
|
||||||
|
: read_func_(read_func),
|
||||||
|
size_func_(size_func) {
|
||||||
|
}
|
||||||
|
size_t size() const override { return size_func_(); }
|
||||||
|
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||||
|
const override {
|
||||||
|
(void)what;
|
||||||
|
return read_func_(pos, buf, n);
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
const std::function<size_t(uint64_t, void*, size_t)>& read_func_;
|
||||||
|
const std::function<size_t(void)>& size_func_;
|
||||||
|
};
|
||||||
|
std::unique_ptr<OurAdapter> adapter(new OurAdapter(read_func, size_func));
|
||||||
|
module_ = torch::jit::load(std::move(adapter), std::move(device));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace serialize
|
} // namespace serialize
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -42,5 +42,10 @@ void OutputArchive::save_to(const std::string& filename) {
|
|||||||
void OutputArchive::save_to(std::ostream& stream) {
|
void OutputArchive::save_to(std::ostream& stream) {
|
||||||
jit::ExportModule(module_, stream);
|
jit::ExportModule(module_, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OutputArchive::save_to(
|
||||||
|
const std::function<size_t(const void*, size_t)>& func) {
|
||||||
|
jit::ExportModule(module_, func);
|
||||||
|
}
|
||||||
} // namespace serialize
|
} // namespace serialize
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -531,11 +531,12 @@ void GraphEncoder::EncodeTensor(
|
|||||||
|
|
||||||
class ScriptModuleSerializer {
|
class ScriptModuleSerializer {
|
||||||
public:
|
public:
|
||||||
ScriptModuleSerializer(const std::string& filename)
|
explicit ScriptModuleSerializer(const std::string& filename)
|
||||||
: writer_(filename.c_str()) {}
|
: writer_(filename) {}
|
||||||
|
|
||||||
ScriptModuleSerializer(std::ostream* ofs)
|
explicit ScriptModuleSerializer(
|
||||||
: ofs_(), writer_(ofs) {}
|
const std::function<size_t(const void *, size_t)>& writer_func)
|
||||||
|
: writer_(writer_func) {}
|
||||||
|
|
||||||
void serialize(
|
void serialize(
|
||||||
const script::Module& module,
|
const script::Module& module,
|
||||||
@ -772,7 +773,6 @@ class ScriptModuleSerializer {
|
|||||||
converted_types_.insert(class_type, std::move(info));
|
converted_types_.insert(class_type, std::move(info));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ofstream ofs_;
|
|
||||||
caffe2::serialize::PyTorchStreamWriter writer_;
|
caffe2::serialize::PyTorchStreamWriter writer_;
|
||||||
std::vector<at::Tensor> constant_table_;
|
std::vector<at::Tensor> constant_table_;
|
||||||
|
|
||||||
@ -1023,7 +1023,11 @@ void ExportModule(
|
|||||||
std::ostream& out,
|
std::ostream& out,
|
||||||
const script::ExtraFilesMap& extra_files,
|
const script::ExtraFilesMap& extra_files,
|
||||||
bool bytecode_format) {
|
bool bytecode_format) {
|
||||||
ScriptModuleSerializer serializer(&out);
|
ScriptModuleSerializer serializer(
|
||||||
|
[&](const void* buf, size_t nbytes) -> size_t {
|
||||||
|
out.write(static_cast<const char *>(buf), nbytes);
|
||||||
|
return !out ? 0 : nbytes;
|
||||||
|
});
|
||||||
serializer.serialize(module, extra_files, bytecode_format);
|
serializer.serialize(module, extra_files, bytecode_format);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1036,5 +1040,14 @@ void ExportModule(
|
|||||||
serializer.serialize(module, extra_files, bytecode_format);
|
serializer.serialize(module, extra_files, bytecode_format);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExportModule(
|
||||||
|
const script::Module& module,
|
||||||
|
const std::function<size_t(const void*, size_t)>& writer_func,
|
||||||
|
const script::ExtraFilesMap& extra_files,
|
||||||
|
bool bytecode_format) {
|
||||||
|
ScriptModuleSerializer serializer(writer_func);
|
||||||
|
serializer.serialize(module, extra_files, bytecode_format);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -55,6 +55,12 @@ TORCH_API void ExportModule(
|
|||||||
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
|
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
|
||||||
bool bytecode_format = false);
|
bool bytecode_format = false);
|
||||||
|
|
||||||
|
TORCH_API void ExportModule(
|
||||||
|
const script::Module& module,
|
||||||
|
const std::function<size_t(const void*, size_t)>& writer_func,
|
||||||
|
const script::ExtraFilesMap& metadata = script::ExtraFilesMap(),
|
||||||
|
bool bytecode_format = false);
|
||||||
|
|
||||||
// Surrounding system can install an additional hook to produce extra files
|
// Surrounding system can install an additional hook to produce extra files
|
||||||
// with metadata based on environment every time a module is serialized.
|
// with metadata based on environment every time a module is serialized.
|
||||||
using ExportModuleExtraFilesHook =
|
using ExportModuleExtraFilesHook =
|
||||||
|
Reference in New Issue
Block a user