Automatic pulling ExtraFileMaps without explicit mapping.

Differential Revision: D45170126nnPull Request resolved: https://github.com/pytorch/pytorch/pull/99747
This commit is contained in:
kwanghoon-meta
2023-05-01 16:27:56 -07:00
committed by GitHub
parent a1d041728b
commit 3fb0bf4d96
5 changed files with 56 additions and 6 deletions

View File

@ -548,17 +548,18 @@ mobile::Module _load_for_mobile(
mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
if (getFileFormat(in) == FileFormat::FlatbufferFileFormat) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return _load_mobile_from_bytes(
data, size, device, extra_files, kDefaultMobileLoadOptions);
data, size, device, extra_files, module_load_options);
}
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
auto module = _load_for_mobile_impl(
std::move(rai), device, extra_files, kDefaultMobileLoadOptions);
std::move(rai), device, extra_files, module_load_options);
return module;
}
@ -649,6 +650,21 @@ mobile::Module _load_for_mobile_impl(
const size_t model_size = rai != nullptr ? rai->size() : 0;
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
if (module_load_options &
MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS) {
// ExtraFilesMap is serialized with a "extra/", hence it is necessary to
// account for when we de-serialize de-serialized filemap key values contain
// prefix and we need to remove prior to construct the map. "extra/" string
// has a length of 6 characters, hence we need only sub-string 6th position
// of a string. Please refer to following link for a detail:
// https://www.internalfb.com/code/fbsource/[9996fcb7a6fb]/fbcode/caffe2/torch/csrc/jit/mobile/import.cpp?lines=427-434
std::vector<std::string> all_files = reader->getAllRecords();
for (auto& file_name : all_files) {
if (file_name.find("extra/") == 0) {
extra_files[file_name.substr(6)] = "";
}
}
}
BytecodeDeserializer deserializer(std::move(reader), module_load_options);
std::string error_message;