Remove flatbuffer types/headers from flatbuffer_loader.h (#82893)

This completely hides the flatbuffer types and headers from users of flatbuffer_loader/serializer, turning them into an internal implementation detail.

A followup diff will fix up the buck files to hide the dependencies more thoroughly.

While doing this I found another use of a flatbuffer-defined name (`FLATBUFFERS_MAX_ALIGNMENT`), which highlighted the issues described in T128189662.

Differential Revision: [D38292794](https://our.internmc.facebook.com/intern/diff/D38292794/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82893
Approved by: https://github.com/qihqi
This commit is contained in:
Dave Bort
2022-08-05 14:10:14 -07:00
committed by PyTorch MergeBot
parent 0c7ca2d97b
commit 1d56ea5e92
3 changed files with 150 additions and 158 deletions

View File

@ -21,21 +21,24 @@
namespace py = pybind11;
using torch::jit::kFlatbufferDataAlignmentBytes;
static std::shared_ptr<char> copyStr(const std::string& bytes) {
size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
FLATBUFFERS_MAX_ALIGNMENT;
size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
std::shared_ptr<char> bytes_copy(
static_cast<char*>(_aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT)),
static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
_aligned_free);
#elif defined(__APPLE__)
void* p;
::posix_memalign(&p, FLATBUFFERS_MAX_ALIGNMENT, size);
::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
std::shared_ptr<char> bytes_copy(
static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size)), free);
static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
free);
#endif
memcpy(bytes_copy.get(), bytes.data(), bytes.size());
return bytes_copy;

View File

@ -1,5 +1,19 @@
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#ifdef FLATBUFFERS_VERSION_MAJOR
#error "flatbuffer_loader.h must not include any flatbuffers headers"
#endif // FLATBUFFERS_VERSION_MAJOR
#include <array>
#include <istream>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/ivalue.h>
@ -12,8 +26,10 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
@ -28,35 +44,110 @@
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#endif
#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#ifdef _WIN32
#include <malloc.h>
#else
#include <cstdlib>
#endif
#include <string>
#include <vector>
namespace torch {
namespace jit {
// Our own alignment requirement does not need to be exactly the same as what
// flatbuffers supports, but what flatbuffers supports needs to satisfy our
// requirement.
static_assert(
kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
"Sizes must be compatible");
static_assert(
(kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
kFlatbufferDataAlignmentBytes,
"Must be a power of 2");
namespace {
static constexpr c10::string_view kCustomClassPrefix =
"__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
return {list->items()->begin(), list->items()->end()};
}
class FlatbufferLoader final {
public:
FlatbufferLoader();
typedef IValue (
*IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
void registerIValueParser(
mobile::serialization::IValueUnion ivalue_type,
IValueParser parser);
mobile::Module parseModule(mobile::serialization::Module* module);
void extractJitSourceAndConstants(
ExtraFilesMap* jit_sources,
std::vector<IValue>* constants);
typedef TypePtr (*TypeResolver)(
const std::string& type_str,
std::shared_ptr<CompilationUnit> cu);
void internal_registerTypeResolver(TypeResolver type_resolver);
IValue& getIValue(uint32_t pos) {
TORCH_CHECK(pos < all_ivalues_.size());
return all_ivalues_[pos];
}
mobile::Function* getFunction(uint32_t pos) {
return all_functions_[pos];
}
ClassTypePtr getType(uint32_t pos) {
TORCH_CHECK(pos < all_types_.size());
return all_types_[pos];
}
c10::Storage getStorage(uint32_t index);
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
ClassTypePtr getOrCreateClassTypeForObject(
const mobile::serialization::Object* object);
const mobile::serialization::Module* getCurrentFlatbufferInput() {
return module_;
}
void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
should_copy_tensor_memory_ = should_copy_tensor_memory;
}
std::shared_ptr<mobile::CompilationUnit> mcu_;
std::shared_ptr<CompilationUnit> cu_;
private:
IValue parseIValue(const mobile::serialization::IValue* ivalue);
std::unique_ptr<mobile::Function> parseFunction(
const mobile::serialization::Function* method);
void parseAndPopulate(
uint32_t i,
const mobile::serialization::IValue* ivalue);
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
std::vector<ClassTypePtr> all_types_;
std::unordered_set<uint32_t> initialized_types_;
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
std::vector<bool> storage_loaded_;
std::vector<c10::Storage> storages_;
std::vector<IValue> all_ivalues_;
std::array<
IValueParser,
static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
ivalue_parsers_;
TypeResolver type_resolver_ = nullptr;
mobile::serialization::Module* module_ = nullptr;
bool module_parsed_ = false;
bool should_copy_tensor_memory_ = false;
// 0 -> mobile_ivalue_size_ elements are from the mobile module.
uint32_t mobile_ivalue_size_ = 0;
};
IValue parseList(
FlatbufferLoader&,
@ -225,7 +316,6 @@ mobile::Module FlatbufferLoader::parseModule(
return m;
}
namespace {
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
@ -233,7 +323,6 @@ void appendUpgraderFunctions(mobile::Function* function) {
}
#endif
}
} // namespace
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
const mobile::serialization::Function* method) {
@ -266,9 +355,7 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
op->name()->str(), op->overload_name()->str(), num_args);
}
if (should_load_operators_) {
function->initialize_operators(true);
}
function->initialize_operators(true);
for (const auto i : *method->type_annotations()) {
function->append_type(getOrCreateTypeAnnotations(i));
@ -434,6 +521,12 @@ IValue parseList(
return res;
}
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
return {list->items()->begin(), list->items()->end()};
}
IValue parseIntList(
FlatbufferLoader&,
const mobile::serialization::IValue& ivalue) {
@ -641,6 +734,8 @@ void FlatbufferLoader::extractJitSourceAndConstants(
parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}
} // namespace
mobile::Module parse_and_initialize_mobile_module(
void* data,
size_t,
@ -649,6 +744,8 @@ mobile::Module parse_and_initialize_mobile_module(
bool should_copy_tensor_memory) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
// TODO(T128189662): If not copying, enforce that data is aligned to
// kFlatbufferDataAlignmentBytes, and add unit tests.
FlatbufferLoader loader;
loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
@ -687,6 +784,8 @@ mobile::Module parse_and_initialize_mobile_module_for_jit(
ExtraFilesMap* extra_files) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
// TODO(T128189662): Enforce that data is aligned to
// kFlatbufferDataAlignmentBytes, and add unit tests.
FlatbufferLoader loader;
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
@ -699,16 +798,6 @@ mobile::Module parse_and_initialize_mobile_module_for_jit(
return m;
}
mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device>,
bool should_copy_tensor_memory) {
auto flatbufferLoader = FlatbufferLoader();
flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
return m;
}
mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<c10::Device> device,
@ -786,7 +875,8 @@ mobile::Module load_mobile_module_from_stream_with_copy(
std::move(data), size, device, extra_files);
}
static mobile::Module parse_flatbuffer_no_object(
namespace {
mobile::Module parse_flatbuffer_no_object(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device) {
@ -815,6 +905,7 @@ static mobile::Module parse_flatbuffer_no_object(
m.set_delete_memory(std::move(data));
return m;
}
} // namespace
bool register_flatbuffer_loader() {
load_flatbuffer_bytes = parse_and_initialize_mobile_module;

View File

@ -1,25 +1,32 @@
#pragma once
#include <istream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
#include <torch/custom_class.h>
/**
* Defines the public API for loading flatbuffer-serialized mobile modules.
* Note that this header must not include or depend on flatbuffer-defined
* types, to avoid leaking those details to PyTorch clients.
*/
namespace torch {
namespace jit {
/// All non-copied data pointers provided to `parse_and_initialize_*` functions
/// must be aligned to this boundary. Since the Module will point directly into
/// the data, this alignment is necessary to ensure that certain types/structs
/// are properly aligned.
constexpr size_t kFlatbufferDataAlignmentBytes = 16;
/// Maps file names to file contents.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
@ -30,22 +37,9 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
// structure
// 3. Module initialization: Produce mobile::Module out of the structure
// produced in 2.
// Under this context, the structure described in 2. is
// mobile::serialization::Module
/// DEPRECATED: Use a parse/load function below.
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
// The caller is assumed to manage the lifetimes of Module.
// This function does step 3 described above.
// If should_copy_tensor_memory is true, then the returned module will NOT
// have refences to flatbuffer_module, so it can be discarded.
// If should_copy_tensor_memory is false, then returned module will have
// tensors that points inside of flatbuffer_module; the caller need to make
// sure that flatbuffer_module outlives returned Module.
TORCH_API mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device> device = c10::nullopt,
bool should_copy_tensor_memory = false);
// Under this context, the structure described in 2. is the flatbuffer-defined
// type mobile::serialization::Module. However, this step/type is not visible in
// the public API.
// Parse a mobile::Module from raw bytes.
//
@ -59,7 +53,8 @@ TORCH_API mobile::Module initialize_mobile_module(
//
// If should_copy_tensor_memory is false, then returned module will have tensors
// that points inside of `data`; the caller will need to make sure that `data`
// outlives the returned Module.
// outlives the returned Module. Also, `data` must be aligned to
// kFlatbufferDataAlignmentBytes.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
void* data,
size_t size, // of `data`, in bytes.
@ -71,7 +66,8 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module(
//
// This function does steps 2+3 described above.
//
// The returned Module holds a reference to `data`.
// The returned Module holds a reference to `data`, which must be aligned to
// kFlatbufferDataAlignmentBytes.
//
// If you do not want the Module to hold a reference to `data`, see the raw
// pointer overload of this function.
@ -107,12 +103,6 @@ TORCH_API mobile::Module load_mobile_module_from_file(
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
/// DEPRECATED: Use the `extra_files` parameter of one of the parse/load
/// functions above.
TORCH_API void parseExtraFiles(
mobile::serialization::Module* module,
ExtraFilesMap& extra_files);
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);
@ -133,97 +123,5 @@ TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
// in this file directly.
TORCH_API bool register_flatbuffer_loader();
/// DEPRECATED: Use one of the parse/load functions above.
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();
typedef IValue (
*IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
void registerIValueParser(
mobile::serialization::IValueUnion ivalue_type,
IValueParser parser);
mobile::Module parseModule(mobile::serialization::Module* module);
void extractJitSourceAndConstants(
ExtraFilesMap* jit_sources,
std::vector<IValue>* constants);
typedef TypePtr (*TypeResolver)(
const std::string& type_str,
std::shared_ptr<CompilationUnit> cu);
void internal_registerTypeResolver(TypeResolver type_resolver);
IValue& getIValue(uint32_t pos) {
TORCH_CHECK(pos < all_ivalues_.size());
return all_ivalues_[pos];
}
mobile::Function* getFunction(uint32_t pos) {
return all_functions_[pos];
}
ClassTypePtr getType(uint32_t pos) {
TORCH_CHECK(pos < all_ivalues_.size());
return all_types_[pos];
}
c10::Storage getStorage(uint32_t index);
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
ClassTypePtr getOrCreateClassTypeForObject(
const mobile::serialization::Object* object);
const mobile::serialization::Module* getCurrentFlatbufferInput() {
return module_;
}
bool getShouldCopyTensorMemory() {
return should_copy_tensor_memory_;
}
void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
should_copy_tensor_memory_ = should_copy_tensor_memory;
}
// Whether or not should load operators in functions.
// Not loading operators is useful because if an operator is not found
// then we throw exceptions, and sometimes we want to print out
// what operators are included before that to debug.
void setShouldLoadOperators(bool should_load_operators) {
should_load_operators_ = should_load_operators;
}
std::shared_ptr<mobile::CompilationUnit> mcu_;
std::shared_ptr<CompilationUnit> cu_;
private:
IValue parseIValue(const mobile::serialization::IValue* ivalue);
std::unique_ptr<mobile::Function> parseFunction(
const mobile::serialization::Function* method);
void parseAndPopulate(
uint32_t i,
const mobile::serialization::IValue* ivalue);
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
std::vector<ClassTypePtr> all_types_;
std::unordered_set<uint32_t> initialized_types_;
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
std::vector<bool> storage_loaded_;
std::vector<c10::Storage> storages_;
std::vector<IValue> all_ivalues_;
std::array<
IValueParser,
static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
ivalue_parsers_;
TypeResolver type_resolver_ = nullptr;
mobile::serialization::Module* module_ = nullptr;
bool module_parsed_ = false;
bool should_copy_tensor_memory_ = false;
bool should_load_operators_ = true;
// 0 -> mobile_ivalue_size_ elements are from the mobile module.
uint32_t mobile_ivalue_size_ = 0;
};
} // namespace jit
} // namespace torch