mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
New flatbuffer_loader functions that do not depend on flatbuffers.h (#82618)
This is the first step towards hiding the flatbuffers types and headers from the load/serialize APIs. The two new functions make it possible to load modules without using `GetMutableModule` (defined by the generated header) or `FlatbufferLoader` (which depends on flatbuffers types). D38292794 will remove the functions/class marked DEPRECATED here after migrating existing users. Differential Revision: [D38292793](https://our.internmc.facebook.com/intern/diff/D38292793/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/82618 Approved by: https://github.com/qihqi
This commit is contained in:
committed by
PyTorch MergeBot
parent
4ae40d74ac
commit
802a4fd286
@ -1,4 +1,3 @@
|
||||
#include <flatbuffers/base.h>
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
@ -21,10 +20,9 @@
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_constants.h>
|
||||
#include <torch/csrc/jit/serialization/import_read.h>
|
||||
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h>
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
#include <flatbuffers/flatbuffers.h>
|
||||
|
||||
#ifndef DISABLE_UPGRADER
|
||||
#include <torch/csrc/jit/mobile/parse_bytecode.h>
|
||||
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
|
||||
@ -49,10 +47,6 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using caffe2::serialize::IStreamAdapter;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::ReadAdapterInterface;
|
||||
|
||||
static constexpr c10::string_view kCustomClassPrefix =
|
||||
"__torch__.torch.classes";
|
||||
static constexpr c10::string_view kTorchPrefix = "__torch__";
|
||||
@ -648,22 +642,63 @@ void FlatbufferLoader::extractJitSourceAndConstants(
|
||||
}
|
||||
|
||||
mobile::Module parse_and_initialize_mobile_module(
|
||||
std::shared_ptr<char> data,
|
||||
void* data,
|
||||
size_t,
|
||||
c10::optional<at::Device>,
|
||||
ExtraFilesMap* extra_files) {
|
||||
ExtraFilesMap* extra_files,
|
||||
bool should_copy_tensor_memory) {
|
||||
TORCH_CHECK(
|
||||
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
|
||||
"Format error");
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
|
||||
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
|
||||
m.set_delete_memory(std::move(data));
|
||||
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
|
||||
|
||||
FlatbufferLoader loader;
|
||||
loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
|
||||
|
||||
// Flatbuffer doesn't seem to have a way to provide the buffer size when
|
||||
// interacting with the buffer.
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
|
||||
mobile::Module m = loader.parseModule(flatbuffer_module);
|
||||
if (extra_files != nullptr) {
|
||||
parseExtraFiles(flatbuffer_module, *extra_files);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
mobile::Module parse_and_initialize_mobile_module(
|
||||
std::shared_ptr<char> data,
|
||||
size_t size,
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap* extra_files) {
|
||||
mobile::Module m = parse_and_initialize_mobile_module(
|
||||
data.get(),
|
||||
size,
|
||||
device,
|
||||
extra_files,
|
||||
/*should_copy_tensor_memory=*/false);
|
||||
m.set_delete_memory(std::move(data));
|
||||
return m;
|
||||
}
|
||||
|
||||
mobile::Module parse_and_initialize_mobile_module_for_jit(
|
||||
void* data,
|
||||
size_t,
|
||||
ExtraFilesMap& jit_sources,
|
||||
std::vector<IValue>& jit_constants,
|
||||
c10::optional<at::Device>,
|
||||
ExtraFilesMap* extra_files) {
|
||||
TORCH_CHECK(
|
||||
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
|
||||
|
||||
FlatbufferLoader loader;
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
|
||||
mobile::Module m = loader.parseModule(flatbuffer_module);
|
||||
if (extra_files != nullptr) {
|
||||
parseExtraFiles(flatbuffer_module, *extra_files);
|
||||
}
|
||||
|
||||
loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
|
||||
return m;
|
||||
}
|
||||
|
||||
mobile::Module initialize_mobile_module(
|
||||
mobile::serialization::Module* flatbuffer_module,
|
||||
c10::optional<at::Device>,
|
||||
|
@ -12,9 +12,14 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
/**
|
||||
* Defines the public API for loading flatbuffer-serialized mobile modules.
|
||||
*/
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
/// Maps file names to file contents.
|
||||
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
|
||||
|
||||
// On high level, to produce a Module from a file on disk, we need to go
|
||||
@ -27,6 +32,7 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
|
||||
// Under this context, the structure described in 2. is
|
||||
// mobile::serialization::Module
|
||||
|
||||
/// DEPRECATED: Use a parse/load function below.
|
||||
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
|
||||
// The caller is assumed to manage the lifetimes of Module.
|
||||
// This function does step 3 described above.
|
||||
@ -41,25 +47,67 @@ TORCH_API mobile::Module initialize_mobile_module(
|
||||
bool should_copy_tensor_memory = false);
|
||||
|
||||
// Parse a mobile::Module from raw bytes.
|
||||
// ownership of data is shared to the returned Module.
|
||||
// (Feel free to pass in a unique_ptr too!)
|
||||
// This function does steps 2+3 described above
|
||||
//
|
||||
// This function does steps 2+3 described above.
|
||||
//
|
||||
// Does not take ownership of `data`; if you want it to take ownership, see the
|
||||
// shared_ptr overload of this function.
|
||||
//
|
||||
// If should_copy_tensor_memory is true, then the returned module will NOT have
|
||||
// references to `data`, so `data` can be freed immediately.
|
||||
//
|
||||
// If should_copy_tensor_memory is false, then returned module will have tensors
|
||||
// that point inside `data`; the caller will need to make sure that `data`
|
||||
// outlives the returned Module.
|
||||
TORCH_API mobile::Module parse_and_initialize_mobile_module(
|
||||
void* data,
|
||||
size_t size, // of `data`, in bytes.
|
||||
c10::optional<at::Device> device = c10::nullopt,
|
||||
ExtraFilesMap* extra_files = nullptr,
|
||||
bool should_copy_tensor_memory = false);
|
||||
|
||||
// Parse a mobile::Module from raw bytes.
|
||||
//
|
||||
// This function does steps 2+3 described above.
|
||||
//
|
||||
// The returned Module holds a reference to `data`.
|
||||
//
|
||||
// If you do not want the Module to hold a reference to `data`, see the raw
|
||||
// pointer overload of this function.
|
||||
TORCH_API mobile::Module parse_and_initialize_mobile_module(
|
||||
std::shared_ptr<char> data,
|
||||
size_t size,
|
||||
size_t size, // of `data`, in bytes.
|
||||
c10::optional<at::Device> device = c10::nullopt,
|
||||
ExtraFilesMap* extra_files = nullptr);
|
||||
|
||||
// Parse a mobile::Module from raw bytes, also returning JIT-related metadata.
|
||||
//
|
||||
// This is the same as parse_and_initialize_mobile_module() except that it also
|
||||
// extracts JIT source files and constants. Can be used to construct a
|
||||
// jit::Module.
|
||||
TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit(
|
||||
void* data,
|
||||
size_t size, // of `data`, in bytes.
|
||||
ExtraFilesMap& jit_sources,
|
||||
std::vector<IValue>& jit_constants,
|
||||
c10::optional<at::Device> device = c10::nullopt,
|
||||
ExtraFilesMap* extra_files = nullptr);
|
||||
|
||||
// Load a mobile::Module from a filepath.
|
||||
//
|
||||
// This function does steps 1+2+3 described above.
|
||||
// We need to have this as a convenience because Python
|
||||
// API will need to wrap this. C++ clients should use one
|
||||
// versions above.
|
||||
//
|
||||
// We need to have this as a convenience because Python API will need to wrap
|
||||
// this. C++ clients should use one of the versions of
|
||||
// parse_and_initialize_mobile_module() so they can manage the raw data more
|
||||
// directly.
|
||||
TORCH_API mobile::Module load_mobile_module_from_file(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device = c10::nullopt,
|
||||
ExtraFilesMap* extra_files = nullptr);
|
||||
|
||||
/// DEPRECATED: Use the `extra_files` parameter of one of the parse/load
|
||||
/// functions above.
|
||||
TORCH_API void parseExtraFiles(
|
||||
mobile::serialization::Module* module,
|
||||
ExtraFilesMap& extra_files);
|
||||
@ -84,6 +132,7 @@ TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
|
||||
// in this file directly.
|
||||
TORCH_API bool register_flatbuffer_loader();
|
||||
|
||||
/// DEPRECATED: Use one of the parse/load functions above.
|
||||
class TORCH_API FlatbufferLoader {
|
||||
public:
|
||||
FlatbufferLoader();
|
||||
|
Reference in New Issue
Block a user