Extending _get_bytecode_version to support flatbuffers format (#75021)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75021

Extending `_get_bytecode_version` to support flatbuffers.
ghstack-source-id: 152771695

(Note: this ignores all push blocking failures!)

Test Plan:
```
~/fbsource/xplat] cd ~/fbsource/xplat/ && buck test //xplat/caffe2:test_lite_interpreter
Building: finished in 0.8 sec (100%) 327/327 jobs, 0/327 updated
  Total time: 0.9 sec
Testing: finished in 06:59.5 min (85 PASS/0 FAIL)
BUILD SUCCEEDED
RESULTS FOR //xplat/caffe2:test_lite_interpreter
PASS    412.3s 85 Passed   0 Skipped   0 Failed   //xplat/caffe2:test_lite_interpreter
TESTS PASSED
```

Reviewed By: iseeyuan

Differential Revision: D34900498

fbshipit-source-id: 65743076d43a933c5381ec128d0268f22c0a8441
(cherry picked from commit 457c76c7d1df6050b941c56a8198162e2e4a3388)
This commit is contained in:
Pavithran Ramachandran
2022-04-01 07:58:54 -07:00
committed by PyTorch MergeBot
parent 835cc66e5d
commit 7aaa75af05
8 changed files with 106 additions and 5 deletions

View File

@ -249,6 +249,23 @@ TEST(FlatbufferTest, Inline) {
AT_ASSERT(output.toTensor().item<float>() == 7.0);
}
// Use the parenthesized form for consistency with the other
// ENABLE_FLATBUFFER checks in this change (model_compatibility.cpp).
#if defined(ENABLE_FLATBUFFER)
// Verifies that _get_model_bytecode_version can read the bytecode version
// out of a flatbuffer-serialized module, and that a second call on the same
// (rewound) stream yields the same answer.
TEST(FlatbufferTest, GetByteCodeVersion) {
  Module m("m");
  // NOTE: the Python in the raw string must be consistently indented or
  // m.define() will fail to parse it.
  m.define(R"(
    def forward(self, input: Tensor):
      return input + 1
  )");
  std::stringstream ss;
  m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);
  auto version = _get_model_bytecode_version(ss);
  AT_ASSERT(version == caffe2::serialize::kProducedBytecodeVersion);
  // Rewind and query again to check the stream-based path is re-entrant.
  ss.seekg(0, ss.beg);
  auto version_again = _get_model_bytecode_version(ss);
  AT_ASSERT(version == version_again);
}
#endif
TEST(FlatbufferTest, Tuple) {
Module m("m");
m.define(R"JIT(

View File

@ -656,12 +656,14 @@ void backportAllVersionCheck(
// Check backport model version
auto backport_version = _get_model_bytecode_version(oss);
backport_version = _get_model_bytecode_version(oss);
AT_ASSERT(backport_version == current_to_version);
// Load and run the backport model, then compare the result with expect
// result
runAndCheckBytecodeModel(
oss, input_data, expect_result_list, current_to_version);
oss.seekg(0, oss.beg);
runAndCheckTorchScriptModel(
oss, input_data, expect_result_list, current_to_version);

View File

@ -3,6 +3,10 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
@ -69,13 +73,52 @@ uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues);
// Returns the bytecode version of a serialized mobile module read from `in`.
// Dispatches on the detected on-disk format: flatbuffer modules are handled
// by the flatbuffer loader (only when the build enables it), zip (pickle)
// modules go through the legacy IStreamAdapter path.
//
// FIX: the original body began with the pre-refactor two-line implementation
// (construct IStreamAdapter, return) left in ahead of the switch; that early
// return made the whole format dispatch unreachable. Those stale lines are
// removed here.
uint64_t _get_model_bytecode_version(std::istream& in) {
  auto orig_pos = in.tellg();
  auto format = getFileFormat(in);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
      TORCH_CHECK(
          false,
          "Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
      return get_bytecode_version(in);
#endif
    }
    case FileFormat::ZipFileFormat: {
      std::unique_ptr<IStreamAdapter> rai =
          std::make_unique<IStreamAdapter>(&in);
      auto version = _get_model_bytecode_version(std::move(rai));
      // Restore the stream position so callers can re-read from `in`.
      in.seekg(orig_pos, in.beg);
      return version;
    }
    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
}
// Returns the bytecode version of the serialized mobile module stored at
// `filename`, dispatching on the detected file format (flatbuffer vs. zip).
//
// FIX: as in the stream overload, the stale pre-refactor two-line body
// (construct FileAdapter, return) preceded the switch and short-circuited
// it; those dead lines are removed so the format dispatch actually runs.
uint64_t _get_model_bytecode_version(const std::string& filename) {
  auto format = getFileFormat(filename);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
      TORCH_CHECK(
          false,
          "Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
      return get_bytecode_version(filename);
#endif
    }
    case FileFormat::ZipFileFormat: {
      std::unique_ptr<FileAdapter> rai =
          std::make_unique<FileAdapter>(filename);
      return _get_model_bytecode_version(std::move(rai));
    }
    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
}
uint64_t _get_model_bytecode_version(

View File

@ -663,5 +663,27 @@ mobile::Module load_mobile_module_from_file(
return parse_and_initialize_mobile_module(std::move(data), size, device);
}
// Reads the bytecode_version field out of a flatbuffer-serialized module
// supplied as an input stream. Fails with "Format error" if the buffer does
// not carry the flatbuffer module identifier.
uint64_t get_bytecode_version(std::istream& in) {
  std::shared_ptr<char> content;
  size_t content_size = 0;
  std::tie(content, content_size) = get_stream_content(in);
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(content.get()),
      "Format error");
  return mobile::serialization::GetMutableModule(content.get())
      ->bytecode_version();
}
// Reads the bytecode_version field out of a flatbuffer-serialized module
// stored on disk at `filename`. Fails with "Format error" if the file does
// not carry the flatbuffer module identifier.
uint64_t get_bytecode_version(const std::string& filename) {
  std::shared_ptr<char> content;
  size_t content_size = 0;
  std::tie(content, content_size) = get_file_content(filename.c_str());
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(content.get()),
      "Format error");
  return mobile::serialization::GetMutableModule(content.get())
      ->bytecode_version();
}
} // namespace jit
} // namespace torch

View File

@ -62,6 +62,9 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in);
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();

View File

@ -135,12 +135,21 @@ class TORCH_API Module {
mem_to_delete_ = delete_mem;
}
// Records the bytecode version this module is serialized with; the exporter
// (jitModuleToMobile) calls this with options.model_version, and the
// flatbuffer serializer later reads it back via bytecode_version().
void set_bytecode_version(int64_t version) {
bytecode_version_ = version;
}
// Returns the bytecode version previously recorded by set_bytecode_version().
int64_t bytecode_version() const {
return bytecode_version_;
}
private:
c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
MobileDebugTable debug_table_;
bool has_debug_handles_ = false;
// FIX: default-initialize so a Module on which set_bytecode_version() was
// never called reports 0 rather than an indeterminate value (the serializer
// static_casts this field to uint32_t). Matches the in-class initializer
// style of has_debug_handles_ above.
int64_t bytecode_version_ = 0;
// Extra handle for the module to delete when itself is deleted
std::shared_ptr<char> mem_to_delete_;

View File

@ -377,6 +377,8 @@ mobile::Module jitModuleToMobile(
backend_debug_info_map.begin(), backend_debug_info_map.end());
m.setDebugTable(MobileDebugTable(
debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
m.set_bytecode_version(options.model_version);
return m;
}

View File

@ -386,9 +386,12 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
}
const uint32_t bytecode_version =
static_cast<uint32_t>(module.bytecode_version());
auto mod = CreateModule(
fbb,
0, /* version */
/*bytecode_version=*/bytecode_version,
extra_files_offset, /* extra_files */
functions_offset,
ivalue_index,