mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Extending _get_bytecode_version to support flatbuffers format (#75021)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75021 Extending `_get_bytecode_version` to support flatbuffers. ghstack-source-id: 152771695 (Note: this ignores all push blocking failures!) Test Plan: ``` ~/fbsource/xplat] cd ~/fbsource/xplat/ && buck test //xplat/caffe2:test_lite_interpreter Building: finished in 0.8 sec (100%) 327/327 jobs, 0/327 updated Total time: 0.9 sec Testing: finished in 06:59.5 min (85 PASS/0 FAIL) BUILD SUCCEEDED RESULTS FOR //xplat/caffe2:test_lite_interpreter PASS 412.3s 85 Passed 0 Skipped 0 Failed //xplat/caffe2:test_lite_interpreter TESTS PASSED ``` Reviewed By: iseeyuan Differential Revision: D34900498 fbshipit-source-id: 65743076d43a933c5381ec128d0268f22c0a8441 (cherry picked from commit 457c76c7d1df6050b941c56a8198162e2e4a3388)
This commit is contained in:
committed by
PyTorch MergeBot
parent
835cc66e5d
commit
7aaa75af05
@ -249,6 +249,23 @@ TEST(FlatbufferTest, Inline) {
|
||||
AT_ASSERT(output.toTensor().item<float>() == 7.0);
|
||||
}
|
||||
|
||||
#if defined ENABLE_FLATBUFFER
|
||||
TEST(FlatbufferTest, GetByteCodeVersion) {
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
def forward(self, input: Tensor):
|
||||
return input + 1
|
||||
)");
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);
|
||||
auto version = _get_model_bytecode_version(ss);
|
||||
AT_ASSERT(version == caffe2::serialize::kProducedBytecodeVersion);
|
||||
ss.seekg(0, ss.beg);
|
||||
auto version_again = _get_model_bytecode_version(ss);
|
||||
AT_ASSERT(version == version_again);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(FlatbufferTest, Tuple) {
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
|
@ -656,12 +656,14 @@ void backportAllVersionCheck(
|
||||
|
||||
// Check backport model version
|
||||
auto backport_version = _get_model_bytecode_version(oss);
|
||||
backport_version = _get_model_bytecode_version(oss);
|
||||
AT_ASSERT(backport_version == current_to_version);
|
||||
|
||||
// Load and run the backport model, then compare the result with expect
|
||||
// result
|
||||
runAndCheckBytecodeModel(
|
||||
oss, input_data, expect_result_list, current_to_version);
|
||||
oss.seekg(0, oss.beg);
|
||||
runAndCheckTorchScriptModel(
|
||||
oss, input_data, expect_result_list, current_to_version);
|
||||
|
||||
|
@ -3,6 +3,10 @@
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
|
||||
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
|
||||
#include <torch/csrc/jit/mobile/file_format.h>
|
||||
#if defined(ENABLE_FLATBUFFER)
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
#endif
|
||||
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_constants.h>
|
||||
@ -69,13 +73,52 @@ uint64_t _get_model_bytecode_version(
|
||||
const std::vector<IValue>& bytecode_ivalues);
|
||||
|
||||
uint64_t _get_model_bytecode_version(std::istream& in) {
|
||||
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
|
||||
return _get_model_bytecode_version(std::move(rai));
|
||||
auto orig_pos = in.tellg();
|
||||
auto format = getFileFormat(in);
|
||||
switch (format) {
|
||||
case FileFormat::FlatbufferFileFormat: {
|
||||
#if !defined(ENABLE_FLATBUFFER)
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Flatbuffer input file but the build hasn't enabled flatbuffer");
|
||||
#else
|
||||
return get_bytecode_version(in);
|
||||
#endif
|
||||
}
|
||||
case FileFormat::ZipFileFormat: {
|
||||
std::unique_ptr<IStreamAdapter> rai =
|
||||
std::make_unique<IStreamAdapter>(&in);
|
||||
auto version = _get_model_bytecode_version(std::move(rai));
|
||||
in.seekg(orig_pos, in.beg);
|
||||
return version;
|
||||
}
|
||||
|
||||
default:
|
||||
TORCH_CHECK(false, "Unrecognized data format");
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t _get_model_bytecode_version(const std::string& filename) {
|
||||
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
|
||||
return _get_model_bytecode_version(std::move(rai));
|
||||
auto format = getFileFormat(filename);
|
||||
switch (format) {
|
||||
case FileFormat::FlatbufferFileFormat: {
|
||||
#if !defined(ENABLE_FLATBUFFER)
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Flatbuffer input file but the build hasn't enabled flatbuffer");
|
||||
#else
|
||||
return get_bytecode_version(filename);
|
||||
#endif
|
||||
}
|
||||
case FileFormat::ZipFileFormat: {
|
||||
std::unique_ptr<FileAdapter> rai =
|
||||
std::make_unique<FileAdapter>(filename);
|
||||
return _get_model_bytecode_version(std::move(rai));
|
||||
}
|
||||
|
||||
default:
|
||||
TORCH_CHECK(false, "Unrecognized data format");
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t _get_model_bytecode_version(
|
||||
|
@ -663,5 +663,27 @@ mobile::Module load_mobile_module_from_file(
|
||||
return parse_and_initialize_mobile_module(std::move(data), size, device);
|
||||
}
|
||||
|
||||
uint64_t get_bytecode_version(std::istream& in) {
|
||||
std::shared_ptr<char> data;
|
||||
size_t size = 0;
|
||||
std::tie(data, size) = get_stream_content(in);
|
||||
TORCH_CHECK(
|
||||
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
|
||||
"Format error");
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
|
||||
return flatbuffer_module->bytecode_version();
|
||||
}
|
||||
|
||||
uint64_t get_bytecode_version(const std::string& filename) {
|
||||
std::shared_ptr<char> data;
|
||||
size_t size = 0;
|
||||
std::tie(data, size) = get_file_content(filename.c_str());
|
||||
TORCH_CHECK(
|
||||
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
|
||||
"Format error");
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
|
||||
return flatbuffer_module->bytecode_version();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -62,6 +62,9 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
|
||||
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
|
||||
std::istream& in);
|
||||
|
||||
TORCH_API uint64_t get_bytecode_version(std::istream& in);
|
||||
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
|
||||
|
||||
class TORCH_API FlatbufferLoader {
|
||||
public:
|
||||
FlatbufferLoader();
|
||||
|
@ -135,12 +135,21 @@ class TORCH_API Module {
|
||||
mem_to_delete_ = delete_mem;
|
||||
}
|
||||
|
||||
void set_bytecode_version(int64_t version) {
|
||||
bytecode_version_ = version;
|
||||
}
|
||||
|
||||
int64_t bytecode_version() const {
|
||||
return bytecode_version_;
|
||||
}
|
||||
|
||||
private:
|
||||
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
||||
std::unordered_map<std::string, std::string> metadata_;
|
||||
std::shared_ptr<CompilationUnit> cu_;
|
||||
MobileDebugTable debug_table_;
|
||||
bool has_debug_handles_ = false;
|
||||
int64_t bytecode_version_;
|
||||
|
||||
// Extra handle for the module to delete when itself is deleted
|
||||
std::shared_ptr<char> mem_to_delete_;
|
||||
|
@ -377,6 +377,8 @@ mobile::Module jitModuleToMobile(
|
||||
backend_debug_info_map.begin(), backend_debug_info_map.end());
|
||||
m.setDebugTable(MobileDebugTable(
|
||||
debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
|
||||
|
||||
m.set_bytecode_version(options.model_version);
|
||||
return m;
|
||||
}
|
||||
|
||||
|
@ -386,9 +386,12 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
|
||||
jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
|
||||
}
|
||||
|
||||
const uint32_t bytecode_version =
|
||||
static_cast<uint32_t>(module.bytecode_version());
|
||||
|
||||
auto mod = CreateModule(
|
||||
fbb,
|
||||
0, /* version */
|
||||
/*bytecode_version=*/bytecode_version,
|
||||
extra_files_offset, /* extra_files */
|
||||
functions_offset,
|
||||
ivalue_index,
|
||||
|
Reference in New Issue
Block a user