Revert D34805092: Extend _save_for_mobile and _load_for_mobile to support flatbuffer format; Default format is pickle + Change buck targets to support only pickle and pickle + flatbuffer for migration

Test Plan: revert-hammer

Differential Revision:
D34805092 (284b2b7135)

Original commit changeset: 57f3fc81d68f

Original Phabricator Diff: D34805092 (284b2b7135)

fbshipit-source-id: 780dfb6fd6ba5f9348f24a2fb3c57971b7155541
(cherry picked from commit bebeb8b84e11c34cbde4857d0e1c291731a7c781)
This commit is contained in:
Nikita Shulga
2022-03-22 15:39:28 -07:00
committed by PyTorch MergeBot
parent 144b7de9dd
commit c53b3ed20f
11 changed files with 48 additions and 261 deletions

View File

@ -153,30 +153,16 @@ TEST(FlatbufferTest, ExtraFiles) {
extra_files["metadata.json"] = "abc";
extra_files["mobile_info.json"] = "{\"key\": 23}";
std::unordered_map<std::string, std::string> loaded_extra_files;
#if defined ENABLE_FLATBUFFER
std::stringstream ss;
module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);
loaded_extra_files["metadata.json"] = "";
auto mobile_module = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
// load it twice using the same stream
auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
#else
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(*module, options);
auto buff = save_mobile_module_to_bytes(bc, extra_files);
std::unordered_map<std::string, std::string> loaded_extra_files;
loaded_extra_files["metadata.json"] = "";
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(buff.data());
parseExtraFiles(flatbuffer_module, loaded_extra_files);
#endif
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

View File

@ -991,9 +991,9 @@ TEST(LiteInterpreterTest, ExtraFiles) {
module->_save_for_mobile(oss, extra_files);
std::istringstream iss(oss.str());
caffe2::serialize::IStreamAdapter adapter{&iss};
std::unordered_map<std::string, std::string> loaded_extra_files;
loaded_extra_files["metadata.json"] = "";
ASSERT_TRUE(iss.tellg() == std::ios::beg);
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
@ -1006,8 +1006,7 @@ TEST(LiteInterpreterTest, ExtraFiles) {
loaded_extra_files[file_name.substr(6)] = "";
}
}
iss.seekg(0, std::ios::beg);
ASSERT_TRUE(iss.tellg() == std::ios::beg);
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

View File

@ -223,14 +223,12 @@ struct TORCH_API Module : public Object {
void _save_for_mobile(
std::ostream& out,
const ExtraFilesMap& extra_files = ExtraFilesMap(),
bool save_mobile_debug_info = false,
bool use_flatbuffer = false) const;
bool save_mobile_debug_info = false) const;
void _save_for_mobile(
const std::string& filename,
const ExtraFilesMap& extra_files = ExtraFilesMap(),
bool save_mobile_debug_info = false,
bool use_flatbuffer = false) const;
bool save_mobile_debug_info = false) const;
Module copy() const;

View File

@ -16,29 +16,25 @@ void Module::save(const std::string& filename, const ExtraFilesMap& extra_files)
void Module::_save_for_mobile(
std::ostream& out,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
bool use_flatbuffer) const {
bool save_mobile_debug_info) const {
ExportModule(
*this,
out,
extra_files,
true /* bytecode_format */,
save_mobile_debug_info,
use_flatbuffer);
save_mobile_debug_info);
}
void Module::_save_for_mobile(
const std::string& filename,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
bool use_flatbuffer) const {
bool save_mobile_debug_info) const {
ExportModule(
*this,
filename,
extra_files,
true /* bytecode_format */,
save_mobile_debug_info,
use_flatbuffer);
save_mobile_debug_info);
}
} // namespace jit

View File

@ -609,34 +609,6 @@ std::tuple<std::shared_ptr<char>, size_t> get_file_content(
return std::make_tuple(data, size);
}
// Reads the remaining contents of `in` (from its current position) into a
// freshly allocated, FLATBUFFERS_MAX_ALIGNMENT-aligned buffer and returns
// {buffer, byte_count}. The stream position is restored before returning.
// NOTE(review): this function appears in a removed diff hunk (revert of
// D34805092); it is documented here as-is, not modified.
std::tuple<std::shared_ptr<char>, size_t> get_stream_content(std::istream& in) {
// get size of the stream and reset to orig
std::streampos orig_pos = in.tellg();
// NOTE(review): seeking with offset `orig_pos` relative to std::ios::end looks
// suspicious — for orig_pos != 0 this lands past end-of-stream, making the
// subsequent tellg() report total_size + orig_pos rather than the number of
// unread bytes. The conventional idiom is seekg(0, std::ios::end) followed by
// size = tellg() - orig_pos. TODO confirm against the original upstream code.
in.seekg(orig_pos, std::ios::end);
const long size = in.tellg();
in.seekg(orig_pos, in.beg);
// read stream
// NOLINT make sure buffer size is multiple of alignment
// Buffer is rounded UP to the next multiple of FLATBUFFERS_MAX_ALIGNMENT so
// the flatbuffer parser can rely on aligned access into the raw bytes.
size_t buffer_size =
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
#ifdef _WIN32
// Windows has no C11 aligned_alloc; use the MSVC-specific pair
// _aligned_malloc/_aligned_free (memory from _aligned_malloc must NOT be
// released with plain free()).
std::shared_ptr<char> data(
static_cast<char*>(
_aligned_malloc(buffer_size, FLATBUFFERS_MAX_ALIGNMENT)),
_aligned_free); // NOLINT
#else
// C11 aligned_alloc; plain free() is the matching deleter here.
std::shared_ptr<char> data(
static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, buffer_size)),
free); // NOLINT
#endif
in.read(data.get(), size);
// reset stream to original position
in.seekg(orig_pos, in.beg);
// NOTE(review): the returned size is the raw byte count read, not the padded
// buffer_size; callers presumably only need the logical length.
return std::make_tuple(data, size);
}
void FlatbufferLoader::extractJitSourceAndConstants(
ExtraFilesMap* jit_sources,
std::vector<IValue>* constants) {
@ -654,9 +626,6 @@ mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t,
c10::optional<at::Device>) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));

View File

@ -59,9 +59,6 @@ TORCH_API void parseExtraFiles(
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename);
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in);
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();

View File

@ -10,10 +10,6 @@
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
@ -540,72 +536,18 @@ mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
auto format = getFileFormat(in);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto module = _load_for_mobile(std::move(rai), device, extra_files);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
return m;
}
#else
case FileFormat::FlatbufferFileFormat: {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
auto module = _load_for_mobile(std::move(rai), device, extra_files);
return module;
}
// Loads a mobile::Module from a file path, populating `extra_files` with any
// extra-file entries found in the archive.
// NOTE(review): this span is a diff rendering with the +/- markers stripped.
// It fuses TWO variants of the function: the pre-revert body (the
// format-dispatching switch below, which handles both Zip and Flatbuffer
// inputs) and the post-revert body (the three trailing FileAdapter lines,
// which assume the pickle/zip format only). As written, the switch either
// returns or throws on every path, so the trailing lines are unreachable;
// in the real post-revert source only the trailing lines exist. Do not treat
// this text as a single compilable function.
mobile::Module _load_for_mobile(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
// --- pre-revert variant: sniff the on-disk format and dispatch ---
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto module = _load_for_mobile(std::move(rai), device, extra_files);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
// Read the whole file into an aligned buffer and parse it in place.
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
return m;
}
#else
case FileFormat::FlatbufferFileFormat: {
// Flatbuffer input detected but this build was compiled without
// ENABLE_FLATBUFFER support — fail loudly rather than misparse.
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
// --- post-revert variant: unconditionally treat the file as a zip archive ---
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
auto module = _load_for_mobile(std::move(rai), device, extra_files);
return module;
}
mobile::Module _load_for_mobile(
@ -613,37 +555,10 @@ mobile::Module _load_for_mobile(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto module = _load_for_mobile_impl(
std::move(rai), device, extra_files, module_load_options);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
return m;
}
#else
case FileFormat::FlatbufferFileFormat: {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
auto module = _load_for_mobile_impl(
std::move(rai), device, extra_files, module_load_options);
return module;
}
mobile::Module _load_for_mobile(

View File

@ -9,6 +9,7 @@
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>
#include <exception>
#include <fstream>
#include <string>

View File

@ -1096,32 +1096,23 @@ void initJitScriptBindings(PyObject* module) {
[](Module& m,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
bool _save_mobile_debug_info = false,
bool _use_flatbuffer = false) {
m._save_for_mobile(
filename,
_extra_files,
_save_mobile_debug_info,
_use_flatbuffer);
bool _save_mobile_debug_info = false) {
m._save_for_mobile(filename, _extra_files, _save_mobile_debug_info);
},
py::arg("filename"),
py::arg("_extra_files") = ExtraFilesMap(),
py::arg("_save_mobile_debug_info") = false,
py::arg("_use_flatbuffer") = false)
py::arg("_save_mobile_debug_info") = false)
.def(
"_save_to_buffer_for_mobile",
[](Module& m,
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
bool _save_mobile_debug_info = false,
bool _use_flatbuffer = false) {
bool _save_mobile_debug_info = false) {
std::ostringstream buf;
m._save_for_mobile(
buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
m._save_for_mobile(buf, _extra_files, _save_mobile_debug_info);
return py::bytes(buf.str());
},
py::arg("_extra_files") = ExtraFilesMap(),
py::arg("_save_mobile_debug_info") = false,
py::arg("_use_flatbuffer") = false)
py::arg("_save_mobile_debug_info") = false)
.def("_set_optimized", &Module::set_optimized)
.def(
"dump",
@ -1900,10 +1891,6 @@ void initJitScriptBindings(PyObject* module) {
std::istringstream in(buffer);
return _get_mobile_model_contained_types(in);
});
m.def("_nn_module_to_mobile", [](const Module& module) {
CompilationOptions options;
return jitModuleToMobile(module, options);
});
py::class_<OperatorInfo>(m, "OperatorInfo")
.def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
m.def("_get_model_ops_and_info", [](const std::string& filename) {

View File

@ -158,24 +158,21 @@ TORCH_API void ExportModule(
std::ostream& out,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false,
bool use_flatbuffer = false);
bool save_mobile_debug_info = false);
TORCH_API void ExportModule(
const Module& module,
const std::string& filename,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false,
bool use_flatbuffer = false);
bool save_mobile_debug_info = false);
TORCH_API void ExportModule(
const Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false,
bool use_flatbuffer = false);
bool save_mobile_debug_info = false);
// Write the bytes of a pickle archive and the tensors referenced inside that
// archive

View File

@ -16,9 +16,6 @@
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#endif
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
@ -791,45 +788,20 @@ SerializationStorageContext& ScriptModuleSerializer::storage_context() {
return storage_context_;
}
#if defined(ENABLE_FLATBUFFER)
void save_mobile_module_to(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) {
CompilationOptions options = getOptionsFromGlobal();
mobile::Module mod = jitModuleToMobile(module, options);
auto buffer = save_mobile_module_to_bytes(mod, extra_files);
writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
}
#endif
void ExportModule(
const Module& module,
std::ostream& out,
const ExtraFilesMap& extra_files,
bool bytecode_format,
bool save_mobile_debug_info,
bool use_flatbuffer) {
auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char*>(buf), nbytes);
return !out ? 0 : nbytes;
};
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
bool save_mobile_debug_info) {
caffe2::serialize::PyTorchStreamWriter writer(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char*>(buf), nbytes);
return !out ? 0 : nbytes;
});
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
void ExportModule(
@ -837,29 +809,11 @@ void ExportModule(
const std::string& filename,
const ExtraFilesMap& extra_files,
bool bytecode_format,
bool save_mobile_debug_info,
bool use_flatbuffer) {
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
std::fstream ofile(filename, std::ios::binary | std::ios::out);
ofile.write(static_cast<const char*>(buf), nbytes);
ofile.close();
return !ofile ? 0 : nbytes;
};
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
caffe2::serialize::PyTorchStreamWriter writer(filename);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
bool save_mobile_debug_info) {
caffe2::serialize::PyTorchStreamWriter writer(filename);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
void ExportModule(
@ -867,23 +821,11 @@ void ExportModule(
const std::function<size_t(const void*, size_t)>& writer_func,
const ExtraFilesMap& extra_files,
bool bytecode_format,
bool save_mobile_debug_info,
bool use_flatbuffer) {
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
bool save_mobile_debug_info) {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
namespace {