New flatbuffer_loader functions that do not depend on flatbuffers.h (#82618)

This is the first step towards hiding the flatbuffers types and headers from the load/serialize APIs. The two new functions make it possible to load modules without using `GetMutableModule` (defined by the generated header) or `FlatbufferLoader` (which depends on flatbuffers types).

D38292794 will remove the functions/class marked DEPRECATED here after migrating existing users.

Differential Revision: [D38292793](https://our.internmc.facebook.com/intern/diff/D38292793/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82618
Approved by: https://github.com/qihqi
This commit is contained in:
Dave Bort
2022-08-04 10:55:07 -07:00
committed by PyTorch MergeBot
parent 4ae40d74ac
commit 802a4fd286
2 changed files with 105 additions and 21 deletions

View File

@ -1,4 +1,3 @@
#include <flatbuffers/base.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <ATen/ATen.h>
@ -21,10 +20,9 @@
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h>
#include <torch/custom_class.h>
#include <flatbuffers/flatbuffers.h>
#ifndef DISABLE_UPGRADER
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
@ -49,10 +47,6 @@
namespace torch {
namespace jit {
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
static constexpr c10::string_view kCustomClassPrefix =
"__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
@ -648,22 +642,63 @@ void FlatbufferLoader::extractJitSourceAndConstants(
}
// Parses a flatbuffer-serialized module from a raw buffer and initializes a
// mobile::Module from it.
//
// @param data              Pointer to the flatbuffer buffer. Not owned; the
//                          caller must keep it alive unless
//                          `should_copy_tensor_memory` is true.
// @param (unnamed size_t)  Size of `data` in bytes. Currently unused because
//                          the flatbuffer API offers no way to pass it in
//                          (see comment below), but kept in the signature for
//                          future validation.
// @param (unnamed device)  Target device; currently unused by this overload.
// @param extra_files       If non-null, populated with the extra files
//                          embedded in the module.
// @param should_copy_tensor_memory  When true, tensor data is copied out of
//                          `data` so the buffer can be freed immediately;
//                          when false, returned tensors point into `data`.
mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
  FlatbufferLoader loader;
  loader.setShouldCopyTensorMemory(should_copy_tensor_memory);

  // Flatbuffer doesn't seem to have a way to provide the buffer size when
  // interacting with the buffer.
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  mobile::Module m = loader.parseModule(flatbuffer_module);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  return m;
}
// Owning overload: parses a module from `data` and transfers ownership of the
// buffer to the returned Module, so callers do not need to manage its
// lifetime themselves.
mobile::Module parse_and_initialize_mobile_module(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  // Delegate to the non-owning overload first; tensors keep pointing into
  // `data`, so ownership must be handed to the module afterwards.
  auto module = parse_and_initialize_mobile_module(
      data.get(),
      size,
      device,
      extra_files,
      /*should_copy_tensor_memory=*/false);
  module.set_delete_memory(std::move(data));
  return module;
}
// Like parse_and_initialize_mobile_module(), but also extracts the JIT source
// files and constants embedded in the flatbuffer so the caller can rebuild a
// jit::Module. Does not take ownership of `data`.
mobile::Module parse_and_initialize_mobile_module_for_jit(
    void* data,
    size_t,
    ExtraFilesMap& jit_sources,
    std::vector<IValue>& jit_constants,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files) {
  // Reject buffers that do not carry the expected flatbuffer identifier.
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
  FlatbufferLoader loader;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  mobile::Module parsed = loader.parseModule(flatbuffer_module);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  // The loader caches JIT sources/constants during parsing; pull them out for
  // the caller.
  loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
  return parsed;
}
mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device>,

View File

@ -12,9 +12,14 @@
#include <string>
#include <vector>
/**
* Defines the public API for loading flatbuffer-serialized mobile modules.
*/
namespace torch {
namespace jit {
/// Maps file names to file contents.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
// On high level, to produce a Module from a file on disk, we need to go
@ -27,6 +32,7 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
// Under this context, the structure described in 2. is
// mobile::serialization::Module
/// DEPRECATED: Use a parse/load function below.
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
// The caller is assumed to manage the lifetimes of Module.
// This function does step 3 described above.
@ -41,25 +47,67 @@ TORCH_API mobile::Module initialize_mobile_module(
bool should_copy_tensor_memory = false);
// Parse a mobile::Module from raw bytes.
//
// This function does steps 2+3 described above.
//
// Does not take ownership of `data`; if you want it to take ownership, see the
// shared_ptr overload of this function.
//
// If should_copy_tensor_memory is true, then the returned module will NOT have
// references to `data`, so `data` can be freed immediately.
//
// If should_copy_tensor_memory is false, then the returned module will have
// tensors that point inside of `data`; the caller will need to make sure that
// `data` outlives the returned Module.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
void* data,
size_t size, // of `data`, in bytes.
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr,
bool should_copy_tensor_memory = false);
// Parse a mobile::Module from raw bytes.
//
// This function does steps 2+3 described above.
//
// The returned Module holds a reference to `data`.
//
// If you do not want the Module to hold a reference to `data`, see the raw
// pointer overload of this function.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
    size_t size, // of `data`, in bytes.
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
// Parse a mobile::Module from raw bytes, also returning JIT-related metadata.
//
// This is the same as parse_and_initialize_mobile_module() except that it also
// extracts JIT source files and constants. Can be used to construct a
// jit::Module.
TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit(
void* data,
size_t size, // of `data`, in bytes.
ExtraFilesMap& jit_sources,
std::vector<IValue>& jit_constants,
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
// Load a mobile::Module from a filepath.
//
// This function does steps 1+2+3 described above.
//
// We need to have this as a convenience because the Python API will need to
// wrap this. C++ clients should use one of the versions of
// parse_and_initialize_mobile_module() so they can manage the raw data more
// directly.
TORCH_API mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
/// DEPRECATED: Use the `extra_files` parameter of one of the parse/load
/// functions above.
TORCH_API void parseExtraFiles(
mobile::serialization::Module* module,
ExtraFilesMap& extra_files);
@ -84,6 +132,7 @@ TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
// in this file directly.
TORCH_API bool register_flatbuffer_loader();
/// DEPRECATED: Use one of the parse/load functions above.
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();