mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Automatic pulling ExtraFileMaps without explicit mapping.
Differential Revision: D45170126nnPull Request resolved: https://github.com/pytorch/pytorch/pull/99747
This commit is contained in:
@ -232,6 +232,22 @@ TEST(FlatbufferTest, ExtraFiles) {
|
||||
|
||||
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
|
||||
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
|
||||
|
||||
// Test if flatbuffer does not require any explicit key entries mapping in the
|
||||
// extra file map.
|
||||
std::unordered_map<std::string, std::string>
|
||||
loaded_extra_files_without_explicit_entries;
|
||||
auto mobile_module3 = _load_for_mobile(
|
||||
ss,
|
||||
c10::nullopt,
|
||||
loaded_extra_files_without_explicit_entries,
|
||||
MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS);
|
||||
|
||||
ASSERT_EQ(
|
||||
loaded_extra_files_without_explicit_entries["metadata.json"], "abc");
|
||||
ASSERT_EQ(
|
||||
loaded_extra_files_without_explicit_entries["mobile_info.json"],
|
||||
"{\"key\": 23}");
|
||||
}
|
||||
|
||||
TEST(FlatbufferTest, Conv) {
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/resolver.h>
|
||||
@ -1015,6 +1014,20 @@ TEST(LiteInterpreterTest, ExtraFiles) {
|
||||
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
|
||||
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
|
||||
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
|
||||
|
||||
std::unordered_map<std::string, std::string>
|
||||
loaded_extra_files_without_explicit_mapping;
|
||||
iss.seekg(0, iss.beg);
|
||||
torch::jit::_load_for_mobile(
|
||||
iss,
|
||||
torch::kCPU,
|
||||
loaded_extra_files_without_explicit_mapping,
|
||||
MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS);
|
||||
ASSERT_EQ(
|
||||
loaded_extra_files_without_explicit_mapping["metadata.json"], "abc");
|
||||
ASSERT_EQ(
|
||||
loaded_extra_files_without_explicit_mapping["mobile_info.json"],
|
||||
"{\"key\": 23}");
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
|
||||
|
||||
@ -548,17 +548,18 @@ mobile::Module _load_for_mobile(
|
||||
mobile::Module _load_for_mobile(
|
||||
std::istream& in,
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files) {
|
||||
ExtraFilesMap& extra_files,
|
||||
uint64_t module_load_options) {
|
||||
if (getFileFormat(in) == FileFormat::FlatbufferFileFormat) {
|
||||
std::shared_ptr<char> data;
|
||||
size_t size = 0;
|
||||
std::tie(data, size) = get_stream_content(in);
|
||||
return _load_mobile_from_bytes(
|
||||
data, size, device, extra_files, kDefaultMobileLoadOptions);
|
||||
data, size, device, extra_files, module_load_options);
|
||||
}
|
||||
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
|
||||
auto module = _load_for_mobile_impl(
|
||||
std::move(rai), device, extra_files, kDefaultMobileLoadOptions);
|
||||
std::move(rai), device, extra_files, module_load_options);
|
||||
return module;
|
||||
}
|
||||
|
||||
@ -649,6 +650,21 @@ mobile::Module _load_for_mobile_impl(
|
||||
|
||||
const size_t model_size = rai != nullptr ? rai->size() : 0;
|
||||
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
|
||||
if (module_load_options &
|
||||
MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS) {
|
||||
// ExtraFilesMap is serialized with a "extra/", hence it is necessary to
|
||||
// account for when we de-serialize de-serialized filemap key values contain
|
||||
// prefix and we need to remove prior to construct the map. "extra/" string
|
||||
// has a length of 6 characters, hence we need only sub-string 6th position
|
||||
// of a string. Please refer to following link for a detail:
|
||||
// https://www.internalfb.com/code/fbsource/[9996fcb7a6fb]/fbcode/caffe2/torch/csrc/jit/mobile/import.cpp?lines=427-434
|
||||
std::vector<std::string> all_files = reader->getAllRecords();
|
||||
for (auto& file_name : all_files) {
|
||||
if (file_name.find("extra/") == 0) {
|
||||
extra_files[file_name.substr(6)] = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
BytecodeDeserializer deserializer(std::move(reader), module_load_options);
|
||||
|
||||
std::string error_message;
|
||||
|
||||
@ -23,7 +23,8 @@ constexpr const char* kArchiveNameVersion = "version";
|
||||
TORCH_API mobile::Module _load_for_mobile(
|
||||
std::istream& in,
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files);
|
||||
ExtraFilesMap& extra_file,
|
||||
uint64_t module_load_options = kDefaultMobileLoadOptions);
|
||||
|
||||
TORCH_API mobile::Module _load_for_mobile(
|
||||
const std::string& filename,
|
||||
|
||||
@ -7,6 +7,10 @@ using c10::IValue;
|
||||
|
||||
enum MobileModuleLoadOptions {
|
||||
OPERATOR_CHECK = 1,
|
||||
// PARSE_ALL_EXTRA_FILE_MAPS is used to gate for ExtraFileMaps to pull all
|
||||
// files automatically without explicit entries mapping. Refer to PR for a
|
||||
// detail: https://github.com/pytorch/pytorch/pull/99747
|
||||
PARSE_ALL_EXTRA_FILE_MAPS = 2,
|
||||
};
|
||||
|
||||
const uint64_t kDefaultMobileLoadOptions =
|
||||
|
||||
Reference in New Issue
Block a user