mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Back out "Revert D34805092: Extend _save_for_mobile and _load_for_mobile to support flatbuffer format; Default format is pickle + Change buck targets to support only pickle
and pickle + flatbuffer
for migration" (#74594)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74594
Extending `_save_for_mobile` and `_load_for_mobile` to support faltbuffer format with additional optional argument which is set to pick pickle by default.
Adding new binary target with suffix `_pickle_and_flatbuffer` to help migration.
Size test in D34909502 shows the size has regressed by ~40K but after removing pickle and comparing lite_predictors we have ~120K size measure that we will achieve when deprecating pickle and moving to flatbuffer
**BEFORE:**
```lang=mermaid
graph TD;
torch_core-->torch_mobile_deserialize;
torch_mobile_core-->torch_mobile_deserialize;
jit_module_saving-->torch_core;
jit_module_saving-->torch_mobile_core;
torch_mobile_deserialize-->caffe2_serialize;
torch_mobile_deserialize-->torch_mobile_module;
caffe2_serialize-->miniz;
flatbuffer_loader-->mobile_bytecode;
flatbuffer_serializer-->mobile_bytecode;
mobile_bytecode-->flatbuffer_2.0;
flatbuffer_loader-->torch_mobile_module;
flatbuffer_serializer-->torch_mobile_module;
```
**AFTER:**
```lang=mermaid
graph TD;
torch_core-->torch_mobile_deserialize;
torch_mobile_core-->torch_mobile_deserialize;
jit_module_saving-->torch_core;
jit_module_saving-->torch_mobile_core;
torch_mobile_deserialize-->caffe2_serialize;
torch_mobile_deserialize-->torch_mobile_module;
caffe2_serialize-->miniz;
flatbuffer_loader-->mobile_bytecode;
flatbuffer_serializer-->mobile_bytecode;
mobile_bytecode-->flatbuffer_2.0;
torch_mobile_deserialize_pickle_and_flatbuffer-->|new| flatbuffer_loader;
torch_mobile_deserialize_pickle_and_flatbuffer-->|new| torch_mobile_deserialize;
torch_mobile_core_pickle_and_flatbuffer-->|new| torch_mobile_deserialize_pickle_and_flatbuffer;
torch_core_pickle_and_flatbuffer-->|new| torch_mobile_deserialize_pickle_and_flatbuffer;
jit_module_saving_pickle_and_flatbuffer-->|new| torch_core_pickle_and_flatbuffer;
jit_module_saving_pickle_and_flatbuffer-->|new| torch_mobile_core_pickle_and_flatbuffer;
flatbuffer_serializer-->torch_mobile_module;
jit_module_saving_pickle_and_flatbuffer-->|new|jit_module_saving;
jit_module_saving_pickle_and_flatbuffer-->|new|flatbuffer_serializer;
flatbuffer_loader-->torch_mobile_module;
```
Original commit changeset: 780dfb6fd6ba
Original Phabricator Diff: D34805092 (284b2b7135
)
ghstack-source-id: 152044801
(Note: this ignores all push blocking failures!)
Test Plan:
CI
```
~/fbsource/fbcode] cd ~/fbsource/fbcode/ && buck test -c fbcode.caffe2_enable_flatbuffer=1 //caffe2/test/cpp/jit:jit -- FlatbufferTest.ExtraFiles
Parsing buck files: finished in 0.9 sec
Building: finished in 5.3 sec (100%) 12992/54304 jobs, 0/54304 updated
Total time: 6.2 sec
More details at https://www.internalfb.com/intern/buck/build/2b387fff-f813-4cfa-b53f-eb2378630d4e
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: f93a84d6-e7ce-41a0-a97f-0ef3fa6d199d
Trace available for this run at /tmp/tpx-20220323-134108.766518-f93a84d6-e7ce-41a0-a97f-0ef3fa6d199d/trace.log
RemoteExecution session id: reSessionID-f93a84d6-e7ce-41a0-a97f-0ef3fa6d199d-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/4503599723101693
✓ ListingSuccess: caffe2/test/cpp/jit:jit : 486 tests discovered (19.122)
✓ Pass: caffe2/test/cpp/jit:jit - FlatbufferTest.ExtraFiles (0.187)
Summary
Pass: 1
ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/4503599723101693
```
Similar Build Deps Dags
```
[pavithran@devvm5216.vll0 /data/users/pavithran/fbsource] buck query 'allpaths(//xplat/caffe2:torch_mobile_all_ops_pickle_and_flatbuffer, //xplat/caffe2:torch_mobile_deserialize_pickle_and_flatbuffer)' --output-format dot-compact | pastry
P486770901: https://www.internalfb.com/intern/paste/P486770901/
[pavithran@devvm5216.vll0 /data/users/pavithran/fbsource] buck query 'allpaths(//xplat/caffe2:torch_mobile_all_ops, //xplat/caffe2:torch_mobile_deserialize)' --output-format dot-compact | pastry
P486771278: https://www.internalfb.com/intern/paste/P486771278/
```
pickle_and_flatbuffer: https://www.internalfb.com/intern/dgw/graph/?build_id=P486770901
pickle: https://www.internalfb.com/intern/dgw/graph/?build_id=P486771278
Reviewed By: iseeyuan
Differential Revision: D35067157
fbshipit-source-id: 9044259c17a2e0da79bd6aedb28efbdfd57e23e0
(cherry picked from commit f738069ec3a72e79da56172741d027de514e9e5f)
This commit is contained in:
committed by
PyTorch MergeBot
parent
d64e7634ff
commit
fc2cf3d26f
@ -153,16 +153,30 @@ TEST(FlatbufferTest, ExtraFiles) {
|
||||
extra_files["metadata.json"] = "abc";
|
||||
extra_files["mobile_info.json"] = "{\"key\": 23}";
|
||||
|
||||
std::unordered_map<std::string, std::string> loaded_extra_files;
|
||||
#if defined ENABLE_FLATBUFFER
|
||||
std::stringstream ss;
|
||||
module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);
|
||||
|
||||
loaded_extra_files["metadata.json"] = "";
|
||||
auto mobile_module = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
|
||||
|
||||
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
|
||||
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
|
||||
|
||||
// load it twice using the same stream
|
||||
auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
|
||||
#else
|
||||
CompilationOptions options;
|
||||
mobile::Module bc = jitModuleToMobile(*module, options);
|
||||
auto buff = save_mobile_module_to_bytes(bc, extra_files);
|
||||
|
||||
std::unordered_map<std::string, std::string> loaded_extra_files;
|
||||
loaded_extra_files["metadata.json"] = "";
|
||||
auto* flatbuffer_module =
|
||||
mobile::serialization::GetMutableModule(buff.data());
|
||||
|
||||
parseExtraFiles(flatbuffer_module, loaded_extra_files);
|
||||
#endif
|
||||
|
||||
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
|
||||
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
|
||||
|
@ -991,7 +991,6 @@ TEST(LiteInterpreterTest, ExtraFiles) {
|
||||
module->_save_for_mobile(oss, extra_files);
|
||||
|
||||
std::istringstream iss(oss.str());
|
||||
caffe2::serialize::IStreamAdapter adapter{&iss};
|
||||
std::unordered_map<std::string, std::string> loaded_extra_files;
|
||||
loaded_extra_files["metadata.json"] = "";
|
||||
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
|
||||
@ -1006,7 +1005,7 @@ TEST(LiteInterpreterTest, ExtraFiles) {
|
||||
loaded_extra_files[file_name.substr(6)] = "";
|
||||
}
|
||||
}
|
||||
|
||||
iss.seekg(0, iss.beg);
|
||||
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
|
||||
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
|
||||
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
|
||||
|
@ -223,12 +223,14 @@ struct TORCH_API Module : public Object {
|
||||
void _save_for_mobile(
|
||||
std::ostream& out,
|
||||
const ExtraFilesMap& extra_files = ExtraFilesMap(),
|
||||
bool save_mobile_debug_info = false) const;
|
||||
bool save_mobile_debug_info = false,
|
||||
bool use_flatbuffer = false) const;
|
||||
|
||||
void _save_for_mobile(
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& extra_files = ExtraFilesMap(),
|
||||
bool save_mobile_debug_info = false) const;
|
||||
bool save_mobile_debug_info = false,
|
||||
bool use_flatbuffer = false) const;
|
||||
|
||||
Module copy() const;
|
||||
|
||||
|
@ -16,25 +16,29 @@ void Module::save(const std::string& filename, const ExtraFilesMap& extra_files)
|
||||
void Module::_save_for_mobile(
|
||||
std::ostream& out,
|
||||
const ExtraFilesMap& extra_files,
|
||||
bool save_mobile_debug_info) const {
|
||||
bool save_mobile_debug_info,
|
||||
bool use_flatbuffer) const {
|
||||
ExportModule(
|
||||
*this,
|
||||
out,
|
||||
extra_files,
|
||||
true /* bytecode_format */,
|
||||
save_mobile_debug_info);
|
||||
save_mobile_debug_info,
|
||||
use_flatbuffer);
|
||||
}
|
||||
|
||||
void Module::_save_for_mobile(
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& extra_files,
|
||||
bool save_mobile_debug_info) const {
|
||||
bool save_mobile_debug_info,
|
||||
bool use_flatbuffer) const {
|
||||
ExportModule(
|
||||
*this,
|
||||
filename,
|
||||
extra_files,
|
||||
true /* bytecode_format */,
|
||||
save_mobile_debug_info);
|
||||
save_mobile_debug_info,
|
||||
use_flatbuffer);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/core/impl/alloc_cpu.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/ScopeExit.h>
|
||||
@ -589,26 +590,34 @@ std::tuple<std::shared_ptr<char>, size_t> get_file_content(
|
||||
// make sure buffer size is multiple of alignment
|
||||
size_t buffer_size =
|
||||
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
|
||||
#if defined(__ANDROID__)
|
||||
std::shared_ptr<char> data(
|
||||
static_cast<char*>(memalign(FLATBUFFERS_MAX_ALIGNMENT, buffer_size)),
|
||||
free);
|
||||
#elif defined(_WIN32)
|
||||
std::shared_ptr<char> data(
|
||||
static_cast<char*>(
|
||||
_aligned_malloc(buffer_size, FLATBUFFERS_MAX_ALIGNMENT)),
|
||||
_aligned_free); // NOLINT
|
||||
#else
|
||||
std::shared_ptr<char> data(
|
||||
static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, buffer_size)),
|
||||
free); // NOLINT
|
||||
#endif
|
||||
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
|
||||
fread(data.get(), size, 1, f);
|
||||
fclose(f);
|
||||
#endif
|
||||
return std::make_tuple(data, size);
|
||||
}
|
||||
|
||||
std::tuple<std::shared_ptr<char>, size_t> get_stream_content(std::istream& in) {
|
||||
// get size of the stream and reset to orig
|
||||
std::streampos orig_pos = in.tellg();
|
||||
in.seekg(orig_pos, std::ios::end);
|
||||
const long size = in.tellg();
|
||||
in.seekg(orig_pos, in.beg);
|
||||
|
||||
// read stream
|
||||
// NOLINT make sure buffer size is multiple of alignment
|
||||
size_t buffer_size =
|
||||
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
|
||||
std::shared_ptr<char> data(
|
||||
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
|
||||
in.read(data.get(), size);
|
||||
|
||||
// reset stream to original position
|
||||
in.seekg(orig_pos, in.beg);
|
||||
return std::make_tuple(data, size);
|
||||
}
|
||||
|
||||
void FlatbufferLoader::extractJitSourceAndConstants(
|
||||
ExtraFilesMap* jit_sources,
|
||||
std::vector<IValue>* constants) {
|
||||
@ -626,6 +635,9 @@ mobile::Module parse_and_initialize_mobile_module(
|
||||
std::shared_ptr<char> data,
|
||||
size_t,
|
||||
c10::optional<at::Device>) {
|
||||
TORCH_CHECK(
|
||||
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
|
||||
"Format error");
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
|
||||
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
|
||||
m.set_delete_memory(std::move(data));
|
||||
|
@ -59,6 +59,9 @@ TORCH_API void parseExtraFiles(
|
||||
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
|
||||
const char* filename);
|
||||
|
||||
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
|
||||
std::istream& in);
|
||||
|
||||
class TORCH_API FlatbufferLoader {
|
||||
public:
|
||||
FlatbufferLoader();
|
||||
|
@ -10,6 +10,10 @@
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/mobile/file_format.h>
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
#endif
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
@ -536,18 +540,47 @@ mobile::Module _load_for_mobile(
|
||||
std::istream& in,
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files) {
|
||||
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
|
||||
auto module = _load_for_mobile(std::move(rai), device, extra_files);
|
||||
return module;
|
||||
auto format = getFileFormat(in);
|
||||
switch (format) {
|
||||
case FileFormat::ZipFileFormat: {
|
||||
std::unique_ptr<IStreamAdapter> rai =
|
||||
std::make_unique<IStreamAdapter>(&in);
|
||||
auto module = _load_for_mobile(std::move(rai), device, extra_files);
|
||||
return module;
|
||||
}
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
case FileFormat::FlatbufferFileFormat: {
|
||||
std::shared_ptr<char> data;
|
||||
size_t size = 0;
|
||||
std::tie(data, size) = get_stream_content(in);
|
||||
auto* flatbuffer_module =
|
||||
mobile::serialization::GetMutableModule(data.get());
|
||||
mobile::Module m = initialize_mobile_module(flatbuffer_module);
|
||||
parseExtraFiles(flatbuffer_module, extra_files);
|
||||
return m;
|
||||
}
|
||||
#else
|
||||
case FileFormat::FlatbufferFileFormat: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Flatbuffer input file but the build hasn't enabled flatbuffer");
|
||||
}
|
||||
#endif
|
||||
default: {
|
||||
TORCH_CHECK(false, "Format error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mobile::Module _load_for_mobile(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files) {
|
||||
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
|
||||
auto module = _load_for_mobile(std::move(rai), device, extra_files);
|
||||
return module;
|
||||
return _load_for_mobile(
|
||||
filename,
|
||||
device,
|
||||
extra_files,
|
||||
/*module_load_options=*/_default_mobile_module_load_options);
|
||||
}
|
||||
|
||||
mobile::Module _load_for_mobile(
|
||||
@ -555,10 +588,37 @@ mobile::Module _load_for_mobile(
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files,
|
||||
uint64_t module_load_options) {
|
||||
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
|
||||
auto module = _load_for_mobile_impl(
|
||||
std::move(rai), device, extra_files, module_load_options);
|
||||
return module;
|
||||
auto format = getFileFormat(filename);
|
||||
switch (format) {
|
||||
case FileFormat::ZipFileFormat: {
|
||||
std::unique_ptr<FileAdapter> rai =
|
||||
std::make_unique<FileAdapter>(filename);
|
||||
auto module = _load_for_mobile_impl(
|
||||
std::move(rai), device, extra_files, module_load_options);
|
||||
return module;
|
||||
}
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
case FileFormat::FlatbufferFileFormat: {
|
||||
std::shared_ptr<char> data;
|
||||
size_t size = 0;
|
||||
std::tie(data, size) = get_file_content(filename.c_str());
|
||||
auto* flatbuffer_module =
|
||||
mobile::serialization::GetMutableModule(data.get());
|
||||
mobile::Module m = initialize_mobile_module(flatbuffer_module);
|
||||
parseExtraFiles(flatbuffer_module, extra_files);
|
||||
return m;
|
||||
}
|
||||
#else
|
||||
case FileFormat::FlatbufferFileFormat: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Flatbuffer input file but the build hasn't enabled flatbuffer");
|
||||
}
|
||||
#endif
|
||||
default: {
|
||||
TORCH_CHECK(false, "Format error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mobile::Module _load_for_mobile(
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
#include <exception>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
@ -1096,23 +1096,32 @@ void initJitScriptBindings(PyObject* module) {
|
||||
[](Module& m,
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
|
||||
bool _save_mobile_debug_info = false) {
|
||||
m._save_for_mobile(filename, _extra_files, _save_mobile_debug_info);
|
||||
bool _save_mobile_debug_info = false,
|
||||
bool _use_flatbuffer = false) {
|
||||
m._save_for_mobile(
|
||||
filename,
|
||||
_extra_files,
|
||||
_save_mobile_debug_info,
|
||||
_use_flatbuffer);
|
||||
},
|
||||
py::arg("filename"),
|
||||
py::arg("_extra_files") = ExtraFilesMap(),
|
||||
py::arg("_save_mobile_debug_info") = false)
|
||||
py::arg("_save_mobile_debug_info") = false,
|
||||
py::arg("_use_flatbuffer") = false)
|
||||
.def(
|
||||
"_save_to_buffer_for_mobile",
|
||||
[](Module& m,
|
||||
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
|
||||
bool _save_mobile_debug_info = false) {
|
||||
bool _save_mobile_debug_info = false,
|
||||
bool _use_flatbuffer = false) {
|
||||
std::ostringstream buf;
|
||||
m._save_for_mobile(buf, _extra_files, _save_mobile_debug_info);
|
||||
m._save_for_mobile(
|
||||
buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
|
||||
return py::bytes(buf.str());
|
||||
},
|
||||
py::arg("_extra_files") = ExtraFilesMap(),
|
||||
py::arg("_save_mobile_debug_info") = false)
|
||||
py::arg("_save_mobile_debug_info") = false,
|
||||
py::arg("_use_flatbuffer") = false)
|
||||
.def("_set_optimized", &Module::set_optimized)
|
||||
.def(
|
||||
"dump",
|
||||
@ -1891,6 +1900,10 @@ void initJitScriptBindings(PyObject* module) {
|
||||
std::istringstream in(buffer);
|
||||
return _get_mobile_model_contained_types(in);
|
||||
});
|
||||
m.def("_nn_module_to_mobile", [](const Module& module) {
|
||||
CompilationOptions options;
|
||||
return jitModuleToMobile(module, options);
|
||||
});
|
||||
py::class_<OperatorInfo>(m, "OperatorInfo")
|
||||
.def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
|
||||
m.def("_get_model_ops_and_info", [](const std::string& filename) {
|
||||
|
@ -158,21 +158,24 @@ TORCH_API void ExportModule(
|
||||
std::ostream& out,
|
||||
const ExtraFilesMap& metadata = ExtraFilesMap(),
|
||||
bool bytecode_format = false,
|
||||
bool save_mobile_debug_info = false);
|
||||
bool save_mobile_debug_info = false,
|
||||
bool use_flatbuffer = false);
|
||||
|
||||
TORCH_API void ExportModule(
|
||||
const Module& module,
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& metadata = ExtraFilesMap(),
|
||||
bool bytecode_format = false,
|
||||
bool save_mobile_debug_info = false);
|
||||
bool save_mobile_debug_info = false,
|
||||
bool use_flatbuffer = false);
|
||||
|
||||
TORCH_API void ExportModule(
|
||||
const Module& module,
|
||||
const std::function<size_t(const void*, size_t)>& writer_func,
|
||||
const ExtraFilesMap& metadata = ExtraFilesMap(),
|
||||
bool bytecode_format = false,
|
||||
bool save_mobile_debug_info = false);
|
||||
bool save_mobile_debug_info = false,
|
||||
bool use_flatbuffer = false);
|
||||
|
||||
// Write the bytes of a pickle archive and the tensors referenced inside that
|
||||
// archive
|
||||
|
@ -16,6 +16,9 @@
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||
#endif
|
||||
#include <torch/csrc/jit/serialization/import_export_constants.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
||||
@ -788,20 +791,45 @@ SerializationStorageContext& ScriptModuleSerializer::storage_context() {
|
||||
return storage_context_;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
void save_mobile_module_to(
|
||||
const Module& module,
|
||||
const ExtraFilesMap& extra_files,
|
||||
bool save_mobile_debug_info,
|
||||
const std::function<size_t(const void*, size_t)>& writer_func) {
|
||||
CompilationOptions options = getOptionsFromGlobal();
|
||||
mobile::Module mod = jitModuleToMobile(module, options);
|
||||
auto buffer = save_mobile_module_to_bytes(mod, extra_files);
|
||||
writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
void ExportModule(
|
||||
const Module& module,
|
||||
std::ostream& out,
|
||||
const ExtraFilesMap& extra_files,
|
||||
bool bytecode_format,
|
||||
bool save_mobile_debug_info) {
|
||||
caffe2::serialize::PyTorchStreamWriter writer(
|
||||
[&](const void* buf, size_t nbytes) -> size_t {
|
||||
out.write(static_cast<const char*>(buf), nbytes);
|
||||
return !out ? 0 : nbytes;
|
||||
});
|
||||
ScriptModuleSerializer serializer(writer);
|
||||
serializer.serialize(
|
||||
module, extra_files, bytecode_format, save_mobile_debug_info);
|
||||
bool save_mobile_debug_info,
|
||||
bool use_flatbuffer) {
|
||||
auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
|
||||
out.write(static_cast<const char*>(buf), nbytes);
|
||||
return !out ? 0 : nbytes;
|
||||
};
|
||||
if (use_flatbuffer) {
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
save_mobile_module_to(
|
||||
module, extra_files, save_mobile_debug_info, writer_func);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
|
||||
#endif
|
||||
} else {
|
||||
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
|
||||
ScriptModuleSerializer serializer(writer);
|
||||
serializer.serialize(
|
||||
module, extra_files, bytecode_format, save_mobile_debug_info);
|
||||
}
|
||||
}
|
||||
|
||||
void ExportModule(
|
||||
@ -809,11 +837,29 @@ void ExportModule(
|
||||
const std::string& filename,
|
||||
const ExtraFilesMap& extra_files,
|
||||
bool bytecode_format,
|
||||
bool save_mobile_debug_info) {
|
||||
caffe2::serialize::PyTorchStreamWriter writer(filename);
|
||||
ScriptModuleSerializer serializer(writer);
|
||||
serializer.serialize(
|
||||
module, extra_files, bytecode_format, save_mobile_debug_info);
|
||||
bool save_mobile_debug_info,
|
||||
bool use_flatbuffer) {
|
||||
if (use_flatbuffer) {
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
|
||||
std::fstream ofile(filename, std::ios::binary | std::ios::out);
|
||||
ofile.write(static_cast<const char*>(buf), nbytes);
|
||||
ofile.close();
|
||||
return !ofile ? 0 : nbytes;
|
||||
};
|
||||
save_mobile_module_to(
|
||||
module, extra_files, save_mobile_debug_info, writer_func);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
|
||||
#endif
|
||||
} else {
|
||||
caffe2::serialize::PyTorchStreamWriter writer(filename);
|
||||
ScriptModuleSerializer serializer(writer);
|
||||
serializer.serialize(
|
||||
module, extra_files, bytecode_format, save_mobile_debug_info);
|
||||
}
|
||||
}
|
||||
|
||||
void ExportModule(
|
||||
@ -821,11 +867,23 @@ void ExportModule(
|
||||
const std::function<size_t(const void*, size_t)>& writer_func,
|
||||
const ExtraFilesMap& extra_files,
|
||||
bool bytecode_format,
|
||||
bool save_mobile_debug_info) {
|
||||
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
|
||||
ScriptModuleSerializer serializer(writer);
|
||||
serializer.serialize(
|
||||
module, extra_files, bytecode_format, save_mobile_debug_info);
|
||||
bool save_mobile_debug_info,
|
||||
bool use_flatbuffer) {
|
||||
if (use_flatbuffer) {
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
save_mobile_module_to(
|
||||
module, extra_files, save_mobile_debug_info, writer_func);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
|
||||
#endif
|
||||
} else {
|
||||
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
|
||||
ScriptModuleSerializer serializer(writer);
|
||||
serializer.serialize(
|
||||
module, extra_files, bytecode_format, save_mobile_debug_info);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -2506,17 +2506,21 @@ inline const torch::jit::mobile::serialization::Module *GetModule(const void *bu
|
||||
return flatbuffers::GetRoot<torch::jit::mobile::serialization::Module>(buf);
|
||||
}
|
||||
|
||||
inline const torch::jit::mobile::serialization::Module *GetSizePrefixedModule(const void *buf) {
|
||||
return flatbuffers::GetSizePrefixedRoot<torch::jit::mobile::serialization::Module>(buf);
|
||||
}
|
||||
|
||||
inline Module *GetMutableModule(void *buf) {
|
||||
inline Module* GetMutableModule(void* buf) {
|
||||
return flatbuffers::GetMutableRoot<Module>(buf);
|
||||
}
|
||||
|
||||
inline torch::jit::mobile::serialization::Module *GetMutableSizePrefixedModule(void *buf) {
|
||||
return flatbuffers::GetMutableSizePrefixedRoot<torch::jit::mobile::serialization::Module>(buf);
|
||||
}
|
||||
// inline const torch::jit::mobile::serialization::Module
|
||||
// *GetSizePrefixedModule(const void *buf) {
|
||||
// return
|
||||
// flatbuffers::GetSizePrefixedRoot<torch::jit::mobile::serialization::Module>(buf);
|
||||
// }
|
||||
|
||||
// inline torch::jit::mobile::serialization::Module
|
||||
// *GetMutableSizePrefixedModule(void *buf) {
|
||||
// return
|
||||
// flatbuffers::GetMutableSizePrefixedRoot<torch::jit::mobile::serialization::Module>(buf);
|
||||
// }
|
||||
|
||||
inline const char *ModuleIdentifier() {
|
||||
return "PTMF";
|
||||
|
Reference in New Issue
Block a user