Fix load_extra_only api for flatbuffers and enable flatbuffers in mobile for OSS properly (#83855)

`_load_extra_only_for_mobile` API hasn't handled flatbuffers logic yet. Update the api accordingly.

Also find out mobile build in OSS doesn't build with flatbuffers. Filed task T129996445 to track

Differential Revision: [D38890847](https://our.internmc.facebook.com/intern/diff/D38890847/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D38890847/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83855
Approved by: https://github.com/qihqi
This commit is contained in:
chenlai
2022-08-22 20:35:11 -07:00
committed by PyTorch MergeBot
parent bbe803cb35
commit 25dd2a0422
2 changed files with 33 additions and 4 deletions

View File

@ -696,16 +696,43 @@ void _load_extra_only_for_mobile(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
auto observer = torch::observerConfig().getModuleObserver();
// NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
auto instance_key = std::rand();
if (observer) {
observer->onEnterLoadModel(instance_key);
}
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
BytecodeDeserializer deserializer(std::move(reader));
deserializer.deserialize_only_extra(device, extra_files);
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
BytecodeDeserializer deserializer(std::move(reader));
deserializer.deserialize_only_extra(device, extra_files);
break;
}
case FileFormat::FlatbufferFileFormat: {
// TODO: the current flatbuffers implementation will always load the
// whole module including the extra files. Ideally it should be
// possible to just get the extra files given data
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
if (load_flatbuffer_bytes != nullptr) {
load_flatbuffer_bytes(data, size, device, &extra_files);
} else {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
break;
}
default: {
TORCH_CHECK(false, "Format error");
}
}
}
namespace mobile {