Remove flatbuffer types/headers from flatbuffer_loader.h (#82893)
This completely hides the flatbuffer types and headers from users of flatbuffer_loader/serializer, turning them into an internal implementation detail. A followup diff will fix up the buck files to hide the dependencies more thoroughly.

While doing this I found another use of a flatbuffer-defined name (`FLATBUFFERS_MAX_ALIGNMENT`), which highlighted the issues described in T128189662.

Differential Revision: [D38292794](https://our.internmc.facebook.com/intern/diff/D38292794/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82893
Approved by: https://github.com/qihqi
Committed by: PyTorch MergeBot
Parent: 0c7ca2d97b
Commit: 1d56ea5e92
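The commit applies a compile-time firewall pattern: the public header expresses everything in std/c10 terms (e.g. its own kFlatbufferDataAlignmentBytes constant instead of the FLATBUFFERS_MAX_ALIGNMENT macro), and the .cpp includes its own header first, then fails the build if any flatbuffers macro leaked through. A minimal sketch of the pattern, with a hypothetical "widget" library standing in for flatbuffers:

// --- illustrative sketch, not part of the commit ---
// widget_loader.h: the public header exposes only standard/own types.
#pragma once
#include <cstddef>
#include <string>

// Own constant replaces the third-party WIDGET_MAX_ALIGNMENT macro.
constexpr size_t kWidgetDataAlignmentBytes = 16;
std::string load_widget(const void* data, size_t size);

// widget_loader.cpp: include our own header FIRST; if it transitively pulled
// in the widget library, the library's version macro would already be defined
// and the #error below would fail the build.
#include "widget_loader.h"
#ifdef WIDGET_VERSION_MAJOR
#error "widget_loader.h must not include any widget library headers"
#endif
#include <widget/widget.h> // the real dependency stays private to the .cpp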
@@ -21,21 +21,24 @@
 namespace py = pybind11;
 
+using torch::jit::kFlatbufferDataAlignmentBytes;
+
 static std::shared_ptr<char> copyStr(const std::string& bytes) {
-  size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
-      FLATBUFFERS_MAX_ALIGNMENT;
+  size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
+      kFlatbufferDataAlignmentBytes;
 #ifdef _WIN32
   std::shared_ptr<char> bytes_copy(
-      static_cast<char*>(_aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT)),
+      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
       _aligned_free);
 #elif defined(__APPLE__)
   void* p;
-  ::posix_memalign(&p, FLATBUFFERS_MAX_ALIGNMENT, size);
+  ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
   TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
   std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
 #else
   std::shared_ptr<char> bytes_copy(
-      static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size)), free);
+      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
+      free);
 #endif
   memcpy(bytes_copy.get(), bytes.data(), bytes.size());
   return bytes_copy;
 }
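The size computation above rounds the copy buffer up to the next multiple of the alignment; because it divides first and then adds one full block, it always over-allocates by 1 to kFlatbufferDataAlignmentBytes bytes, even when the input is already aligned: 100 bytes gives (100/16 + 1) * 16 = 112, and an exact multiple like 96 bytes still grows to (96/16 + 1) * 16 = 112. A self-contained check of that arithmetic, with a local kAlign standing in for the constant:

// --- illustrative sketch, not part of the commit ---
#include <cstddef>

constexpr size_t kAlign = 16; // stand-in for kFlatbufferDataAlignmentBytes

// Same rounding rule as copyStr: integer-divide, add one block, multiply back.
constexpr size_t paddedSize(size_t n) {
  return (n / kAlign + 1) * kAlign;
}

static_assert(paddedSize(100) == 112, "rounds up to the next multiple");
static_assert(paddedSize(96) == 112, "exact multiples still gain a full block");
static_assert(paddedSize(0) == 16, "empty input still gets one block");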
@@ -1,5 +1,19 @@
 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
 
+#ifdef FLATBUFFERS_VERSION_MAJOR
+#error "flatbuffer_loader.h must not include any flatbuffers headers"
+#endif // FLATBUFFERS_VERSION_MAJOR
+
+#include <array>
+#include <istream>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
 #include <ATen/ATen.h>
 #include <ATen/core/dynamic_type.h>
 #include <ATen/core/ivalue.h>
@@ -12,8 +26,10 @@
 #include <caffe2/serialize/inline_container.h>
 #include <torch/csrc/jit/frontend/script_type_parser.h>
 #include <torch/csrc/jit/mobile/file_format.h>
+#include <torch/csrc/jit/mobile/function.h>
 #include <torch/csrc/jit/mobile/import.h>
 #include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/mobile/observer.h>
 #include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/runtime/instruction.h>
@@ -28,35 +44,110 @@
 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
 #endif
 
 #if defined(HAVE_MMAP)
 #include <fcntl.h>
 #include <sys/mman.h>
 #include <sys/stat.h>
 #include <unistd.h>
 #endif
 
+#ifdef _WIN32
+#include <malloc.h>
+#else
+#include <cstdlib>
+#endif
+
+#include <string>
+#include <vector>
+
 namespace torch {
 namespace jit {
 
+// Our own alignment requirement does not need to be exactly the same as what
+// flatbuffers supports, but what flatbuffers supports needs to satisfy our
+// requirement.
+static_assert(
+    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
+    "Sizes must be compatible");
+static_assert(
+    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
+        kFlatbufferDataAlignmentBytes,
+    "Must be a power of 2");
+
 namespace {
 
 static constexpr c10::string_view kCustomClassPrefix =
     "__torch__.torch.classes";
 static constexpr c10::string_view kTorchPrefix = "__torch__";
 static constexpr c10::string_view kJitPrefix = "torch.jit";
 
-template <typename T, typename U>
-std::vector<T> parseListNative(const U* list) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
-  return {list->items()->begin(), list->items()->end()};
-}
+class FlatbufferLoader final {
+ public:
+  FlatbufferLoader();
+
+  typedef IValue (
+      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
+  void registerIValueParser(
+      mobile::serialization::IValueUnion ivalue_type,
+      IValueParser parser);
+  mobile::Module parseModule(mobile::serialization::Module* module);
+
+  void extractJitSourceAndConstants(
+      ExtraFilesMap* jit_sources,
+      std::vector<IValue>* constants);
+
+  typedef TypePtr (*TypeResolver)(
+      const std::string& type_str,
+      std::shared_ptr<CompilationUnit> cu);
+
+  void internal_registerTypeResolver(TypeResolver type_resolver);
+
+  IValue& getIValue(uint32_t pos) {
+    TORCH_CHECK(pos < all_ivalues_.size());
+    return all_ivalues_[pos];
+  }
+
+  mobile::Function* getFunction(uint32_t pos) {
+    return all_functions_[pos];
+  }
+
+  ClassTypePtr getType(uint32_t pos) {
+    TORCH_CHECK(pos < all_types_.size());
+    return all_types_[pos];
+  }
+
+  c10::Storage getStorage(uint32_t index);
+  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
+  ClassTypePtr getOrCreateClassTypeForObject(
+      const mobile::serialization::Object* object);
+
+  const mobile::serialization::Module* getCurrentFlatbufferInput() {
+    return module_;
+  }
+
+  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
+    should_copy_tensor_memory_ = should_copy_tensor_memory;
+  }
+
+  std::shared_ptr<mobile::CompilationUnit> mcu_;
+  std::shared_ptr<CompilationUnit> cu_;
+
+ private:
+  IValue parseIValue(const mobile::serialization::IValue* ivalue);
+  std::unique_ptr<mobile::Function> parseFunction(
+      const mobile::serialization::Function* method);
+  void parseAndPopulate(
+      uint32_t i,
+      const mobile::serialization::IValue* ivalue);
+
+  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
+  std::vector<ClassTypePtr> all_types_;
+  std::unordered_set<uint32_t> initialized_types_;
+  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
+  std::vector<bool> storage_loaded_;
+  std::vector<c10::Storage> storages_;
+  std::vector<IValue> all_ivalues_;
+  std::array<
+      IValueParser,
+      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
+      ivalue_parsers_;
+  TypeResolver type_resolver_ = nullptr;
+  mobile::serialization::Module* module_ = nullptr;
+  bool module_parsed_ = false;
+  bool should_copy_tensor_memory_ = false;
+  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
+  uint32_t mobile_ivalue_size_ = 0;
+};
 
 IValue parseList(
     FlatbufferLoader&,
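The second static_assert relies on x & ~(x - 1) isolating the lowest set bit of x: subtracting 1 flips exactly the trailing bits, so the expression equals x only when x has a single set bit, i.e. is a power of two. For 16, 16 & ~15 = 0b10000 & ~0b01111 = 16; for a non-power like 12, 12 & ~11 = 0b1100 & ~0b1011 = 0b0100 = 4 != 12. A standalone restatement of the check:

// --- illustrative sketch, not part of the commit ---
#include <cstdint>

// Equals x only when x has exactly one set bit (and x > 0), i.e. is a power
// of two; the guard against 0 is needed because 0 & ~(0 - 1) == 0.
constexpr bool isPowerOfTwo(uint32_t x) {
  return x != 0 && (x & ~(x - 1)) == x;
}

static_assert(isPowerOfTwo(16), "16 = 0b10000 has one set bit");
static_assert(!isPowerOfTwo(12), "12 & ~11 == 4, not 12");
static_assert(isPowerOfTwo(1), "1 = 2^0");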
@@ -225,7 +316,6 @@ mobile::Module FlatbufferLoader::parseModule(
   return m;
 }
 
-namespace {
 void appendUpgraderFunctions(mobile::Function* function) {
 #ifndef DISABLE_UPGRADER
   for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
@@ -233,7 +323,6 @@ void appendUpgraderFunctions(mobile::Function* function) {
   }
 #endif
 }
-} // namespace
 
 std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
     const mobile::serialization::Function* method) {
@@ -266,9 +355,7 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
         op->name()->str(), op->overload_name()->str(), num_args);
   }
 
-  if (should_load_operators_) {
-    function->initialize_operators(true);
-  }
+  function->initialize_operators(true);
 
   for (const auto i : *method->type_annotations()) {
     function->append_type(getOrCreateTypeAnnotations(i));
@@ -434,6 +521,12 @@ IValue parseList(
   return res;
 }
 
+template <typename T, typename U>
+std::vector<T> parseListNative(const U* list) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
+  return {list->items()->begin(), list->items()->end()};
+}
+
 IValue parseIntList(
     FlatbufferLoader&,
     const mobile::serialization::IValue& ivalue) {
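parseListNative leans on std::vector's iterator-pair constructor: brace-initializing from two iterators copies the flatbuffer-owned range into an owned vector, converting element types along the way. A minimal stand-in, with a hypothetical FakeList in place of the flatbuffer-generated list type:

// --- illustrative sketch, not part of the commit ---
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in: anything exposing items() with begin()/end() fits.
struct FakeList {
  std::vector<int> storage{1, 2, 3};
  const std::vector<int>* items() const { return &storage; }
};

template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  return {list->items()->begin(), list->items()->end()};
}

int main() {
  FakeList list;
  // Elements convert int -> int64_t as they are copied out.
  std::vector<int64_t> copied = parseListNative<int64_t>(&list);
  assert(copied.size() == 3 && copied[2] == 3);
  return 0;
}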
@@ -641,6 +734,8 @@ void FlatbufferLoader::extractJitSourceAndConstants(
   parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
 }
 
+} // namespace
+
 mobile::Module parse_and_initialize_mobile_module(
     void* data,
     size_t,
@@ -649,6 +744,8 @@ mobile::Module parse_and_initialize_mobile_module(
     bool should_copy_tensor_memory) {
   TORCH_CHECK(
       mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+  // TODO(T128189662): If not copying, enforce that data is aligned to
+  // kFlatbufferDataAlignmentBytes, and add unit tests.
 
   FlatbufferLoader loader;
   loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
@@ -687,6 +784,8 @@ mobile::Module parse_and_initialize_mobile_module_for_jit(
     ExtraFilesMap* extra_files) {
   TORCH_CHECK(
       mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+  // TODO(T128189662): Enforce that data is aligned to
+  // kFlatbufferDataAlignmentBytes, and add unit tests.
 
   FlatbufferLoader loader;
   auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
@@ -699,16 +798,6 @@ mobile::Module parse_and_initialize_mobile_module_for_jit(
   return m;
 }
 
-mobile::Module initialize_mobile_module(
-    mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device>,
-    bool should_copy_tensor_memory) {
-  auto flatbufferLoader = FlatbufferLoader();
-  flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
-  mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
-  return m;
-}
-
 mobile::Module load_mobile_module_from_file(
     const std::string& filename,
     c10::optional<c10::Device> device,
@@ -786,7 +875,8 @@ mobile::Module load_mobile_module_from_stream_with_copy(
       std::move(data), size, device, extra_files);
 }
 
-static mobile::Module parse_flatbuffer_no_object(
+namespace {
+mobile::Module parse_flatbuffer_no_object(
     std::shared_ptr<char> data,
     size_t size,
     c10::optional<at::Device> device) {
@@ -815,6 +905,7 @@ static mobile::Module parse_flatbuffer_no_object(
   m.set_delete_memory(std::move(data));
   return m;
 }
+} // namespace
 
 bool register_flatbuffer_loader() {
   load_flatbuffer_bytes = parse_and_initialize_mobile_module;
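register_flatbuffer_loader installs the loader by assigning function pointers (load_flatbuffer_bytes = parse_and_initialize_mobile_module;), so the generic import path can dispatch to flatbuffer support when this translation unit is linked in, without naming any flatbuffer symbol itself. A minimal sketch of that hook pattern (all names hypothetical):

// --- illustrative sketch, not part of the commit ---
#include <cstddef>

struct Module {}; // stand-in for mobile::Module

// The generic importer declares a null hook and checks it before dispatching.
using LoadBytesFn = Module (*)(void* data, size_t size);
static LoadBytesFn load_bytes_hook = nullptr;

Module parse_bytes_impl(void* /*data*/, size_t /*size*/) {
  return Module{}; // a real implementation would parse the buffer here
}

// Calling this (e.g. from a static initializer in the loader's translation
// unit) wires in the implementation; the importer never names it directly.
bool register_loader() {
  load_bytes_hook = parse_bytes_impl;
  return true;
}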
@@ -1,25 +1,32 @@
 #pragma once
 
 #include <istream>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
 
 #include <ATen/core/ivalue.h>
-#include <caffe2/serialize/inline_container.h>
-#include <torch/csrc/jit/mobile/function.h>
-#include <torch/csrc/jit/mobile/interpreter.h>
+#include <c10/core/Device.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Optional.h>
 #include <torch/csrc/jit/mobile/module.h>
-#include <torch/csrc/jit/runtime/instruction.h>
-#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
-#include <torch/custom_class.h>
+
+/**
+ * Defines the public API for loading flatbuffer-serialized mobile modules.
+ * Note that this header must not include or depend on flatbuffer-defined
+ * types, to avoid leaking those details to PyTorch clients.
+ */
 
 namespace torch {
 namespace jit {
 
+/// All non-copied data pointers provided to `parse_and_initialize_*` functions
+/// must be aligned to this boundary. Since the Module will point directly into
+/// the data, this alignment is necessary to ensure that certain types/structs
+/// are properly aligned.
+constexpr size_t kFlatbufferDataAlignmentBytes = 16;
+
 /// Maps file names to file contents.
 using ExtraFilesMap = std::unordered_map<std::string, std::string>;
@@ -30,22 +37,9 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 //    structure
 // 3. Module initialization: Produce mobile::Module out of the structure
 //    produced in 2.
-// Under this context, the structure described in 2. is
-// mobile::serialization::Module
-
-/// DEPRECATED: Use a parse/load function below.
-// Parse a mobile::Module from flatbuffer's in-memory Module representation.
-// The caller is assumed to manage the lifetimes of Module.
-// This function does step 3 described above.
-// If should_copy_tensor_memory is true, then the returned module will NOT
-// have refences to flatbuffer_module, so it can be discarded.
-// If should_copy_tensor_memory is false, then returned module will have
-// tensors that points inside of flatbuffer_module; the caller need to make
-// sure that flatbuffer_module outlives returned Module.
-TORCH_API mobile::Module initialize_mobile_module(
-    mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device> device = c10::nullopt,
-    bool should_copy_tensor_memory = false);
+// Under this context, the structure described in 2. is the flatbuffer-defined
+// type mobile::serialization::Module. However, this step/type is not visible in
+// the public API.
 
 // Parse a mobile::Module from raw bytes.
 //
@@ -59,7 +53,8 @@ TORCH_API mobile::Module initialize_mobile_module(
 //
 // If should_copy_tensor_memory is false, then returned module will have tensors
 // that points inside of `data`; the caller will need to make sure that `data`
-// outlives the returned Module.
+// outlives the returned Module. Also, `data` must be aligned to
+// kFlatbufferDataAlignmentBytes.
 TORCH_API mobile::Module parse_and_initialize_mobile_module(
     void* data,
     size_t size, // of `data`, in bytes.
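Taken together, the contract for the raw-pointer overload is: the caller owns `data`, keeps it alive for the Module's lifetime unless should_copy_tensor_memory is true, and hands it over 16-byte aligned. A hypothetical caller satisfying that contract (error handling and Windows aligned allocation omitted; assumes the defaulted device/extra_files parameters so only data and size are passed):

// --- illustrative sketch, not part of the commit ---
#include <cstdlib>
#include <cstring>
#include <string>

#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/module.h>

torch::jit::mobile::Module loadFromBytes(const std::string& model_bytes) {
  using torch::jit::kFlatbufferDataAlignmentBytes;
  // Round up and copy into an aligned buffer, as copyStr does earlier in
  // this diff.
  size_t padded = (model_bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
      kFlatbufferDataAlignmentBytes;
  void* buf = aligned_alloc(kFlatbufferDataAlignmentBytes, padded);
  std::memcpy(buf, model_bytes.data(), model_bytes.size());
  // NOTE: `buf` must outlive the returned Module (it is deliberately leaked
  // in this sketch); pass should_copy_tensor_memory = true or use the
  // shared_ptr overload to avoid managing the lifetime by hand.
  return torch::jit::parse_and_initialize_mobile_module(
      buf, model_bytes.size());
}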
@@ -71,7 +66,8 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module(
 //
 // This function does steps 2+3 described above.
 //
-// The returned Module holds a reference to `data`.
+// The returned Module holds a reference to `data`, which must be aligned to
+// kFlatbufferDataAlignmentBytes.
 //
 // If you do not want the Module to hold a reference to `data`, see the raw
 // pointer overload of this function.
@@ -107,12 +103,6 @@ TORCH_API mobile::Module load_mobile_module_from_file(
     c10::optional<at::Device> device = c10::nullopt,
     ExtraFilesMap* extra_files = nullptr);
 
-/// DEPRECATED: Use the `extra_files` parameter of one of the parse/load
-/// functions above.
-TORCH_API void parseExtraFiles(
-    mobile::serialization::Module* module,
-    ExtraFilesMap& extra_files);
-
 TORCH_API uint64_t get_bytecode_version(std::istream& in);
 TORCH_API uint64_t get_bytecode_version(const std::string& filename);
 TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);
@@ -133,97 +123,5 @@ TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
 // in this file directly.
 TORCH_API bool register_flatbuffer_loader();
 
-/// DEPRECATED: Use one of the parse/load functions above.
-class TORCH_API FlatbufferLoader {
- public:
-  FlatbufferLoader();
-
-  typedef IValue (
-      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
-  void registerIValueParser(
-      mobile::serialization::IValueUnion ivalue_type,
-      IValueParser parser);
-  mobile::Module parseModule(mobile::serialization::Module* module);
-
-  void extractJitSourceAndConstants(
-      ExtraFilesMap* jit_sources,
-      std::vector<IValue>* constants);
-
-  typedef TypePtr (*TypeResolver)(
-      const std::string& type_str,
-      std::shared_ptr<CompilationUnit> cu);
-
-  void internal_registerTypeResolver(TypeResolver type_resolver);
-
-  IValue& getIValue(uint32_t pos) {
-    TORCH_CHECK(pos < all_ivalues_.size());
-    return all_ivalues_[pos];
-  }
-
-  mobile::Function* getFunction(uint32_t pos) {
-    return all_functions_[pos];
-  }
-
-  ClassTypePtr getType(uint32_t pos) {
-    TORCH_CHECK(pos < all_ivalues_.size());
-    return all_types_[pos];
-  }
-
-  c10::Storage getStorage(uint32_t index);
-  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
-  ClassTypePtr getOrCreateClassTypeForObject(
-      const mobile::serialization::Object* object);
-
-  const mobile::serialization::Module* getCurrentFlatbufferInput() {
-    return module_;
-  }
-
-  bool getShouldCopyTensorMemory() {
-    return should_copy_tensor_memory_;
-  }
-
-  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
-    should_copy_tensor_memory_ = should_copy_tensor_memory;
-  }
-
-  // Whether or not should load operators in functions.
-  // Not loading operators is useful because if an operator is not found
-  // then we throw exceptions, and sometimes we want to print out
-  // what operators are included before that to debug.
-  void setShouldLoadOperators(bool should_load_operators) {
-    should_load_operators_ = should_load_operators;
-  }
-
-  std::shared_ptr<mobile::CompilationUnit> mcu_;
-  std::shared_ptr<CompilationUnit> cu_;
-
- private:
-  IValue parseIValue(const mobile::serialization::IValue* ivalue);
-  std::unique_ptr<mobile::Function> parseFunction(
-      const mobile::serialization::Function* method);
-  void parseAndPopulate(
-      uint32_t i,
-      const mobile::serialization::IValue* ivalue);
-
-  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
-  std::vector<ClassTypePtr> all_types_;
-  std::unordered_set<uint32_t> initialized_types_;
-  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
-  std::vector<bool> storage_loaded_;
-  std::vector<c10::Storage> storages_;
-  std::vector<IValue> all_ivalues_;
-  std::array<
-      IValueParser,
-      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
-      ivalue_parsers_;
-  TypeResolver type_resolver_ = nullptr;
-  mobile::serialization::Module* module_ = nullptr;
-  bool module_parsed_ = false;
-  bool should_copy_tensor_memory_ = false;
-  bool should_load_operators_ = true;
-  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
-  uint32_t mobile_ivalue_size_ = 0;
-};
-
 } // namespace jit
 } // namespace torch