[pytorch][PR] Add ability for a mobile::Module to save as flatbuffer (#70201)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70201

Included functions:
save_mobile_module -> saves a mobile::Module to flatbuffer
load_mobile_module_from_file -> loads a flatbuffer into mobile::Module
parse_mobile_module -> parses from bytes or deserialized flatbuffer module object

Compared to previous attempts, this diff only adds flatbuffers to the CMake targets and leaves the fbcode/xplat ones unchanged.
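A minimal usage sketch of the new entry points (not part of this diff; the "model.ptl" input and the pre-existing torch::jit::_load_for_mobile loader are assumptions, and the in-memory parsing entry point is spelled parse_and_initialize_mobile_module in the code):

#include <torch/csrc/jit/mobile/import.h>                       // _load_for_mobile
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>            // load_mobile_module_from_file
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h> // save_mobile_module

void roundtrip_via_file() {
  // Load a module with the existing pickle-based mobile loader (assumed input file).
  torch::jit::mobile::Module m = torch::jit::_load_for_mobile("model.ptl");
  // mobile::Module -> flatbuffer file (output name is arbitrary).
  torch::jit::save_mobile_module(m, "model.ff");
  // flatbuffer file -> mobile::Module.
  torch::jit::mobile::Module m2 =
      torch::jit::load_mobile_module_from_file("model.ff");
}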

Test Plan: unittest

Reviewed By: malfet, gmagogsfm

Differential Revision: D33239362

fbshipit-source-id: b9ca36b83d6af2d78cc50b9eb9e2a6fa7fce0763
Author: Han Qi
Date: 2022-01-12 16:27:21 -08:00
Committed by: Facebook GitHub Bot
Parent: 7a93d8bb2d
Commit: 1bc3571078
20 changed files with 5132 additions and 2 deletions

@@ -49,7 +49,7 @@ jobs:
- name: Ensure canonical include
if: always()
run: |
(! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' || (echo "The above lines have include with quotes; please convert them to #include <xxxx>"; false))
(! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' ':(exclude)torch/csrc/jit/serialization/mobile_bytecode_generated.h'|| (echo "The above lines have include with quotes; please convert them to #include <xxxx>"; false))
- name: Ensure no versionless Python shebangs
if: always()
run: |

.gitmodules
@@ -142,3 +142,6 @@
[submodule "third_party/breakpad"]
path = third_party/breakpad
url = https://github.com/driazati/breakpad.git
[submodule "third_party/flatbuffers"]
path = third_party/flatbuffers
url = https://github.com/google/flatbuffers.git

@@ -1692,6 +1692,7 @@ cc_library(
":aten_headers",
":caffe2_headers",
"//c10:headers",
"@com_github_google_flatbuffers//:flatbuffers",
"@local_config_python//:python_headers",
"@onnx",
],
@@ -1725,6 +1726,8 @@ cc_library(
],
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + [
":cpp_generated_code",
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
],
copts = TORCH_COPTS,
defines = [

@@ -197,3 +197,8 @@ new_local_repository(
build_file = "@//third_party:cudnn.BUILD",
path = "/usr/",
)
local_repository(
name = "com_github_google_flatbuffers",
path = "third_party/flatbuffers",
)

@@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/flatbuffer_loader.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
@@ -595,6 +596,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
@@ -1645,6 +1647,9 @@ if(APPLE AND USE_PYTORCH_METAL)
endif()
endif()
target_link_libraries(torch_cpu PRIVATE flatbuffers)
# Note [Global dependencies]
# Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized,
# and they assume that all of their symbols will be available in the global namespace.

@@ -1996,3 +1996,6 @@ if(USE_KINETO)
message(STATUS "Configured Kineto")
endif()
endif()
# Include google/FlatBuffers
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)

cmake/FlatBuffers.cmake (new file)
@@ -0,0 +1,10 @@
set(FlatBuffers_Include ${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include)
file(GLOB FlatBuffers_Library_SRCS
${FlatBuffers_Include}/flatbuffers/*.h
)
add_library(flatbuffers INTERFACE)
target_sources(
flatbuffers
INTERFACE ${FlatBuffers_Library_SRCS}
)
target_include_directories(flatbuffers INTERFACE ${FlatBuffers_Include})

@@ -89,6 +89,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_script_profile.cpp
${JIT_TEST_ROOT}/test_shape_analysis.cpp
${JIT_TEST_ROOT}/test_jit_logging_levels.cpp
${JIT_TEST_ROOT}/test_flatbuffer.cpp
)
if(USE_CUDA)
@@ -101,6 +102,10 @@ add_executable(test_jit
${JIT_TEST_SRCS}
)
target_link_libraries(
test_jit PRIVATE flatbuffers)
# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_jit PRIVATE USE_GTEST)

File diff suppressed because it is too large.

Submodule third_party/flatbuffers added at f2f9380c86

@@ -70,6 +70,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
${TORCH_ROOT}/third_party/gloo
${TORCH_ROOT}/third_party/onnx
${TORCH_ROOT}/third_party/flatbuffers/include
${pybind11_INCLUDE_DIRS}
${TORCH_SRC_DIR}/csrc
@@ -345,6 +346,8 @@ if(HAVE_SOVERSION)
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
endif()
add_dependencies(torch_python torch_python_stubs)
add_dependencies(torch_python flatbuffers)
if(USE_PRECOMPILED_HEADERS)
target_precompile_headers(torch_python PRIVATE

@@ -0,0 +1,518 @@
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/ScopeExit.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/custom_class.h>
#include <flatbuffers/flatbuffers.h>
#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace {
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
static constexpr c10::string_view kCustomClassPrefix =
"__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
class FlatbufferLoader {
public:
FlatbufferLoader()
: mcu_(std::make_shared<mobile::CompilationUnit>()),
cu_(std::make_shared<CompilationUnit>()) {}
mobile::Module parseModule(mobile::serialization::Module* module);
private:
IValue parseIValue(const mobile::serialization::IValue* ivalue);
IValue parseList(const mobile::serialization::List* list);
at::Tensor parseTensor(const mobile::serialization::TensorMetadata* tensor);
IValue parseTuple(const mobile::serialization::Tuple* tuple);
IValue parseDict(const mobile::serialization::Dict* dict);
IValue parseObject(const mobile::serialization::Object* object);
std::unique_ptr<mobile::Function> parseFunction(
const mobile::serialization::Function* method);
IValue& getIValue(uint32_t pos) {
TORCH_CHECK(pos < all_ivalues_.size());
return all_ivalues_[pos];
}
mobile::Function* getFunction(uint32_t pos) {
return all_functions_[pos];
}
ClassTypePtr getType(uint32_t pos) const {
TORCH_CHECK(pos < all_ivalues_.size());
return all_types_[pos];
// auto iter = all_types_.find(pos);
// AT_ASSERT(iter != all_types_.end(), "type not found at pos: ", pos);
// return iter->second;
}
c10::Storage getStorage(uint32_t index);
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
// fields
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
std::vector<ClassTypePtr> all_types_;
std::unordered_set<uint32_t> initialized_types_;
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
std::vector<bool> storage_loaded_;
std::vector<c10::Storage> storages_;
std::vector<IValue> all_ivalues_;
std::shared_ptr<mobile::CompilationUnit> mcu_;
std::shared_ptr<CompilationUnit> cu_;
mobile::serialization::Module* module_ = nullptr;
};
mobile::Module FlatbufferLoader::parseModule(
mobile::serialization::Module* module) {
module_ = module;
all_ivalues_.clear();
all_types_.clear();
storages_.clear();
storage_loaded_.clear();
const auto* ivalues = module->ivalues();
all_ivalues_.resize(ivalues->size());
all_types_.resize(module->object_types()->size());
storages_.resize(module->storage_data_size());
storage_loaded_.resize(module->storage_data_size(), false);
for (uint32_t i = 0; i < ivalues->size(); i++) {
const auto* ival = ivalues->Get(i);
if (const auto* func = ival->val_as_Function()) {
auto func_ptr = parseFunction(func);
all_functions_[i] = func_ptr.get();
mcu_->register_function(std::move(func_ptr));
} else {
all_ivalues_[i] = parseIValue(ival);
}
}
IValue& module_ivalue = getIValue(module->state_obj());
// register function to class
// for (const auto& func: all_functions_) {
// const auto* fb_func = ivalues->Get(func.first)->val_as_Function();
// auto class_type = getType(fb_func->class_type());
// class_type->addMethod(func.second);
// }
return mobile::Module(module_ivalue.toObject(), mcu_);
}
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
const mobile::serialization::Function* method) {
auto function = std::make_unique<mobile::Function>(
c10::QualifiedName(method->qn()->str()));
// TODO(qihan) add debug handle
// const auto* debug_handle = method->debug_info()->debug_handle();
for (const auto* inst : *method->instructions()) {
function->append_instruction(
static_cast<OpCode>(inst->op()), inst->x(), inst->n());
}
for (uint32_t i : *method->constants()) {
function->append_constant(getIValue(i));
}
std::unordered_set<std::string> unsupported_op_names;
const int64_t model_version = 0x6L;
for (const auto* op : *method->operators()) {
c10::optional<int> num_args = c10::nullopt;
if (op->num_args_serialized() > -1) {
num_args = op->num_args_serialized();
}
auto op_found = function->append_operator(
op->name()->str(), op->overload_name()->str(), num_args, model_version);
if (!op_found) {
unsupported_op_names.emplace(
op->name()->str() + "/" + op->overload_name()->str());
}
}
AT_ASSERT(unsupported_op_names.empty());
for (const auto i : *method->type_annotations()) {
function->append_type(getOrCreateTypeAnnotations(i));
}
function->set_register_size(method->register_size());
if (method->schema()) {
auto parseArgList = [this](const auto* args_fb) {
std::vector<c10::Argument> args;
for (const auto* arg_tb : *args_fb) {
IValue default_value = getIValue(arg_tb->default_value());
TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
auto arg = c10::Argument(
arg_tb->name()->str(),
std::move(type_ptr),
c10::nullopt /*N*/,
std::move(default_value));
args.emplace_back(std::move(arg));
}
return args;
};
c10::FunctionSchema schema(
method->qn()->str(),
"" /*overload_name*/,
parseArgList(method->schema()->arguments()),
parseArgList(method->schema()->returns()),
false /*is_varargs*/,
false /*is_varret*/);
function->setSchema(std::move(schema));
}
return function;
}
at::Tensor FlatbufferLoader::parseTensor(
const mobile::serialization::TensorMetadata* tensor_md) {
at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
auto options = at::CPU(type).options();
at::Tensor tensor;
if (tensor_md->quantized_schema() != nullptr) {
// is quantized
const auto* schema = tensor_md->quantized_schema();
auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
switch (qscheme_type) {
case at::kPerTensorAffine: {
tensor = at::_empty_affine_quantized(
{0}, options, schema->scale(), schema->zero_point());
} break;
case at::kPerChannelAffineFloatQParams:
case at::kPerChannelAffine: {
at::Tensor scales = parseTensor(schema->scales());
at::Tensor zero_points = parseTensor(schema->zero_points());
tensor = at::_empty_per_channel_affine_quantized(
{0}, scales, zero_points, schema->axis(), options);
} break;
default:
TORCH_CHECK(
false,
"Unsupported tensor quantization type in serialization ",
toString(qscheme_type));
break;
}
} else {
tensor = at::empty({0}, options);
}
at::TensorImpl* impl = tensor.unsafeGetTensorImpl();
c10::Storage storage;
storage = getStorage(tensor_md->storage_location_index());
impl->set_storage_keep_dtype(storage);
impl->set_storage_offset(tensor_md->storage_offset());
std::vector<int64_t> size{
tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
std::vector<int64_t> stride{
tensor_md->strides()->begin(), tensor_md->strides()->end()};
impl->set_sizes_and_strides(size, stride);
tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
return tensor;
}
IValue FlatbufferLoader::parseList(const mobile::serialization::List* list) {
auto res = c10::impl::GenericList(AnyType::get());
for (int i : *list->items()) {
res.emplace_back(getIValue(i));
}
auto type =
getOrCreateTypeAnnotations(list->annotation_str())->cast<ListType>();
res.unsafeSetElementType(type->getElementType());
return res;
}
IValue FlatbufferLoader::parseTuple(const mobile::serialization::Tuple* tuple) {
std::vector<IValue> res;
for (int i : *tuple->items()) {
res.emplace_back(getIValue(i));
}
return c10::ivalue::Tuple::create(res);
}
IValue FlatbufferLoader::parseDict(const mobile::serialization::Dict* dict) {
auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
const auto* keys = dict->keys();
const auto* values = dict->values();
for (size_t i = 0; i < keys->size(); ++i) {
uint32_t key = keys->Get(i);
uint32_t val = values->Get(i);
result.insert_or_assign(getIValue(key), getIValue(val));
}
auto type =
getOrCreateTypeAnnotations(dict->annotation_str())->cast<DictType>();
result.unsafeSetKeyType(type->getKeyType());
result.unsafeSetValueType(type->getValueType());
return result;
}
IValue FlatbufferLoader::parseObject(
const mobile::serialization::Object* object) {
const mobile::serialization::ObjectType* obj_type =
module_->object_types()->Get(object->type_index());
auto cls = getType(object->type_index());
bool initialized = true;
if (cls == nullptr) {
c10::string_view qn_str(
obj_type->type_name()->c_str(), obj_type->type_name()->size());
if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
c10::QualifiedName qn(obj_type->type_name()->str());
cls = cu_->get_class(qn);
if (cls == nullptr) {
cls = ClassType::create(qn, cu_, true);
cu_->register_type(cls);
}
} else {
cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
}
TORCH_CHECK(object->type_index() < all_ivalues_.size());
all_types_[object->type_index()] = cls;
initialized = false;
}
Stack stack;
switch (obj_type->type()) {
case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
auto obj = c10::ivalue::Object::create(
at::StrongTypePtr(cu_, cls), object->attrs()->size());
if (!initialized) {
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
IValue val = getIValue(object->attrs()->Get(i));
cls->addAttribute(obj_type->attr_names()->Get(i)->str(), val.type());
obj->setSlot(i, std::move(val));
}
initialized_types_.insert(object->type_index());
} else {
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
IValue val = getIValue(object->attrs()->Get(i));
obj->setSlot(i, std::move(val));
}
}
return obj;
}
case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
IValue input = getIValue(object->state());
mobile::Function* setstate = getFunction(object->setstate_func());
auto obj = c10::ivalue::Object::create(at::StrongTypePtr(cu_, cls), 0);
std::cerr << "here 2: " << cls.get() << std::endl;
stack.push_back(obj);
stack.emplace_back(std::move(input));
setstate->run(stack);
return obj;
}
case mobile::serialization::TypeType::CUSTOM_CLASS: {
auto custom_class_type =
torch::jit::getCustomClass(cls->name()->qualifiedName());
IValue input = getIValue(object->state());
auto obj = c10::ivalue::Object::create(
c10::StrongTypePtr(nullptr, custom_class_type), 1);
std::cerr << "here 3: " << cls.get() << std::endl;
stack.push_back(obj);
stack.emplace_back(std::move(input));
custom_class_type->getMethod("__setstate__").run(stack);
return obj;
}
default:
AT_ASSERT(false, "need to be object");
}
}
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
return {list->items()->begin(), list->items()->end()};
}
IValue FlatbufferLoader::parseIValue(
const mobile::serialization::IValue* ivalue) {
switch (ivalue->val_type()) {
case mobile::serialization::IValueUnion::NONE:
return {};
case mobile::serialization::IValueUnion::Int:
return ivalue->val_as_Int()->int_val();
case mobile::serialization::IValueUnion::Bool:
return ivalue->val_as_Bool()->bool_val();
case mobile::serialization::IValueUnion::Double:
return ivalue->val_as_Double()->double_val();
case mobile::serialization::IValueUnion::ComplexDouble: {
const auto* comp = ivalue->val_as_ComplexDouble();
return c10::complex<double>(comp->real(), comp->imag());
}
case mobile::serialization::IValueUnion::TensorMetadata:
return parseTensor(ivalue->val_as_TensorMetadata());
case mobile::serialization::IValueUnion::String:
return ivalue->val_as_String()->data()->str();
case mobile::serialization::IValueUnion::List:
return parseList(ivalue->val_as_List());
case mobile::serialization::IValueUnion::IntList:
return parseListNative<int64_t>(ivalue->val_as_IntList());
case mobile::serialization::IValueUnion::DoubleList:
return parseListNative<double>(ivalue->val_as_DoubleList());
case mobile::serialization::IValueUnion::BoolList: {
std::vector<uint8_t> res =
parseListNative<uint8_t>(ivalue->val_as_BoolList());
c10::List<bool> boollist;
for (auto x : res) {
boollist.push_back(x);
}
return boollist;
}
case mobile::serialization::IValueUnion::Tuple:
return parseTuple(ivalue->val_as_Tuple());
case mobile::serialization::IValueUnion::Dict:
return parseDict(ivalue->val_as_Dict());
case mobile::serialization::IValueUnion::Object: {
auto val = parseObject(ivalue->val_as_Object());
return val;
}
case mobile::serialization::IValueUnion::Device: {
return c10::Device(ivalue->val_as_Device()->str()->str());
}
case mobile::serialization::IValueUnion::EnumValue: {
const auto* enum_val = ivalue->val_as_EnumValue();
auto enum_type = getOrCreateTypeAnnotations(enum_val->type_name())
->cast<c10::EnumType>();
AT_ASSERT(
enum_type,
"Enum with type: " + enum_val->type_name()->str() + " not found.");
IValue val = getIValue(enum_val->value());
for (const auto& p : enum_type->enumNamesValues()) {
if (p.second == val) {
auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
enum_type, p.first, p.second);
return IValue(std::move(enum_holder));
}
}
AT_ASSERT(
false,
"Enum with type: " + enum_val->type_name()->str() + " not found.");
}
default:
return {};
}
}
void deleteNothing2(void*);
void deleteNothing2(void*) {}
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
TORCH_CHECK(index < storage_loaded_.size());
TORCH_CHECK(index < storages_.size());
if (!storage_loaded_[index]) {
auto* storage = module_->storage_data()->GetMutableObject(index);
size_t size = storage->data()->size();
void* ptr = static_cast<void*>(storage->mutable_data()->data());
at::DataPtr data(ptr, ptr, deleteNothing2, DeviceType::CPU);
storages_[index] =
c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
storage_loaded_[index] = true;
}
return storages_[index];
}
TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
const flatbuffers::String* offset) {
auto iter = type_annotations_.find(offset);
if (iter != type_annotations_.end()) {
return iter->second;
}
TypePtr type;
c10::string_view qn_str(offset->c_str(), offset->size());
c10::QualifiedName qn(offset->str());
if (qn_str.starts_with(kCustomClassPrefix)) {
type = getCustomClass(qn.qualifiedName());
TORCH_CHECK(
type,
"The implementation of class ",
qn.qualifiedName(),
" cannot be found.");
} else if (
qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
if (cu_->get_class(qn) == nullptr) {
auto classtype = ClassType::create(qn, cu_, true);
cu_->register_type(classtype);
type = classtype;
} else {
type = cu_->get_class(qn);
}
} else {
type = c10::parseType(qn.qualifiedName());
}
type_annotations_[offset] = type;
return type;
}
} // namespace
mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t,
c10::optional<at::Device>) {
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
return m;
}
mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device>) {
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
return m;
}
mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<c10::Device> device) {
#if defined(HAVE_MMAP)
int fd = open(filename.c_str(), O_RDONLY);
struct stat statbuf {};
fstat(fd, &statbuf);
int size = statbuf.st_size;
void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
FILE* f = fopen(filename.c_str(), "rb");
fseek(f, 0, SEEK_END);
long size = ftell(f);
fseek(f, 0, SEEK_SET);
std::shared_ptr<char> data(static_cast<char*>(malloc(size)), free); // NOLINT
fread(data.get(), size, 1, f);
fclose(f);
#endif
return parse_and_initialize_mobile_module(std::move(data), size, device);
}
} // namespace jit
} // namespace torch

@@ -0,0 +1,54 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
#include <torch/custom_class.h>
#include <string>
#include <vector>
namespace torch {
namespace jit {
// At a high level, to produce a Module from a file on disk, we need to go
// through the following steps:
// 1. Read: Read the file from disk -> memory
// 2. Deserialize: Parse the bytes to produce some in-memory manipulable
//    structure
// 3. Module initialization: Produce mobile::Module out of the structure
//    produced in 2.
// In this context, the structure described in 2. is
// mobile::serialization::Module.
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
// The caller is assumed to manage the lifetimes of Module.
// This function does step 3 described above.
TORCH_API mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device> device = c10::nullopt);
// Parse a mobile::Module from raw bytes.
// Ownership of data is shared with the returned Module.
// (Feel free to pass in a unique_ptr too!)
// This function does steps 2+3 described above.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device = c10::nullopt);
// Load a mobile::Module from a filepath.
// This function does steps 1+2+3 described above.
// We need to have this as a convenience because the Python
// API will need to wrap this. C++ clients should use one of
// the versions above.
TORCH_API mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<at::Device> device = c10::nullopt);
} // namespace jit
} // namespace torch
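A minimal sketch (not part of this diff) of steps 2+3 from the comment above: bytes already in memory are parsed into a mobile::Module. The bytes here are assumed to come from save_mobile_module_to_bytes, declared in flatbuffer_serializer.h elsewhere in this change; any other source of a valid serialized Module would work the same way.

#include <cstring>
#include <memory>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>

torch::jit::mobile::Module roundtrip_in_memory(const torch::jit::mobile::Module& m) {
  flatbuffers::DetachedBuffer buf = torch::jit::save_mobile_module_to_bytes(m);
  // parse_and_initialize_mobile_module shares ownership of the bytes with the
  // returned Module, so copy them into a shared_ptr<char>-owned buffer first.
  std::shared_ptr<char> data(new char[buf.size()], std::default_delete<char[]>());
  std::memcpy(data.get(), buf.data(), buf.size());
  return torch::jit::parse_and_initialize_mobile_module(std::move(data), buf.size());
}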

@@ -19,7 +19,8 @@ void CompilationUnit::register_function(std::unique_ptr<Function> fn) {
methods_.emplace_back(std::move(fn));
}
Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
const Function* CompilationUnit::find_function(
const c10::QualifiedName& qn) const {
for (auto& fn : methods_) {
if (fn->qualname() == qn) {
return fn.get();
@@ -28,6 +29,12 @@ Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
return nullptr;
}
Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
// NOLINTNEXTLINE
return const_cast<Function*>(
static_cast<const CompilationUnit*>(this)->find_function(qn));
}
Method Module::get_method(const std::string& name) const {
if (auto method = find_method(name)) {
return *method;

@@ -40,6 +40,7 @@ class CompilationUnit {
return methods_;
}
Function* find_function(const c10::QualifiedName& qn);
const Function* find_function(const c10::QualifiedName& qn) const;
private:
std::vector<std::unique_ptr<Function>> methods_;
@@ -130,12 +131,19 @@ class TORCH_API Module {
return *cu_.get();
}
void set_delete_memory(std::shared_ptr<char> delete_mem) {
mem_to_delete_ = delete_mem;
}
private:
c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
MobileDebugTable debug_table_;
bool has_debug_handles_ = false;
// Extra handle for the module to delete when itself is deleted
std::shared_ptr<char> mem_to_delete_;
};
} // namespace mobile
} // namespace jit

@@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& out, Instruction inst);
bool isOpSupportedInMobile(OpCode op);
char const* toString(OpCode op);
std::ostream& operator<<(std::ostream& out, Instruction inst);
} // namespace jit
} // namespace torch

@@ -0,0 +1,681 @@
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <ATen/ATen.h>
#include <c10/core/CPUAllocator.h>
#include <flatbuffers/flatbuffers.h>
#include <torch/csrc/jit/mobile/code.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/export.h>
#include <string>
namespace torch {
namespace jit {
using flatbuffers::FlatBufferBuilder;
using mobile::serialization::CreateArg;
using mobile::serialization::CreateDebugInfo;
using mobile::serialization::CreateDict;
using mobile::serialization::CreateFunctionDirect;
using mobile::serialization::CreateIValue;
using mobile::serialization::CreateList;
using mobile::serialization::CreateModule;
using mobile::serialization::CreateObject;
using mobile::serialization::CreateOperator;
using mobile::serialization::CreateTensorMetadataDirect;
using mobile::serialization::CreateTupleDirect;
namespace {
// We will store IValue NONE in index 0 in flatbuffer.
constexpr int kNoneIndex = 0;
class FlatbufferSerializer {
public:
FlatbufferSerializer() = default;
flatbuffers::DetachedBuffer serializeModule(
const mobile::Module& module,
bool include_tensor_data_in_flatbuffer);
private:
template <typename It>
std::vector<uint32_t> storeIValuesAndGetIndexes(
flatbuffers::FlatBufferBuilder& fbb,
It begin,
It end) {
std::vector<uint32_t> indexes;
for (; begin != end; ++begin) {
indexes.push_back(storeIValueAndGetIndex(fbb, *begin));
}
return indexes;
}
flatbuffers::Offset<mobile::serialization::Tuple> tupleToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& tuple);
flatbuffers::Offset<mobile::serialization::List> listToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& list);
flatbuffers::Offset<mobile::serialization::Dict> dictToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& list);
flatbuffers::Offset<mobile::serialization::Object> objectToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue);
flatbuffers::Offset<mobile::serialization::TensorMetadata> tensorToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue);
flatbuffers::Offset<mobile::serialization::Function> functionToFB(
flatbuffers::FlatBufferBuilder& fbb,
const std::string& qn,
const mobile::Function& func);
flatbuffers::Offset<mobile::serialization::IValue> iValueToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue);
flatbuffers::Offset<jit::mobile::serialization::Schema> CreateFBSchema(
flatbuffers::FlatBufferBuilder& fbb,
const std::vector<Argument>& args,
const std::vector<Argument>& returns,
c10::TypePrinter type_printer);
flatbuffers::Offset<mobile::serialization::ObjectType> classTypeToFB(
flatbuffers::FlatBufferBuilder& fbb,
ClassTypePtr class_ptr);
uint32_t storeIValueAndGetIndex(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue);
uint32_t storeFunctionAndGetIndex(
flatbuffers::FlatBufferBuilder& fbb,
const std::string& qn,
const mobile::Function& function);
uint32_t storeClassTypeAndGetIndex(
flatbuffers::FlatBufferBuilder& fbb,
ClassTypePtr class_type);
uint32_t insertIValue(
flatbuffers::Offset<mobile::serialization::IValue> ivalue) {
uint32_t size = ivalue_offsets_.size();
ivalue_offsets_.push_back(ivalue);
return size;
}
std::vector<at::Tensor> tensor_data_;
std::unordered_map<const void*, uint32_t> memoized_storage_map_;
std::vector<flatbuffers::Offset<mobile::serialization::IValue>>
ivalue_offsets_;
std::vector<flatbuffers::Offset<mobile::serialization::ObjectType>>
obj_types_offset_;
// qualified name to serialized class, type or function
std::unordered_map<std::string, uint32_t> qn_to_serialized_values_;
// cache of some ivalues
struct IValueHash {
size_t operator()(const IValue& val) const {
return IValue::hash(val);
}
};
std::unordered_map<IValue, uint32_t, IValueHash> cached_ivalues_;
const mobile::CompilationUnit* mcu_ = nullptr;
};
flatbuffers::Offset<jit::mobile::serialization::Schema> FlatbufferSerializer::
CreateFBSchema(
flatbuffers::FlatBufferBuilder& fbb,
const std::vector<Argument>& args,
const std::vector<Argument>& returns,
c10::TypePrinter type_printer) {
std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> arg_vec;
arg_vec.reserve(args.size());
std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> return_vec;
return_vec.reserve(returns.size());
for (const auto& arg : args) {
int index = storeIValueAndGetIndex(fbb, arg.default_value());
arg_vec.emplace_back(CreateArg(
fbb,
fbb.CreateSharedString(arg.name()),
fbb.CreateSharedString(arg.type()->annotation_str(type_printer)),
index));
}
for (const auto& ret : returns) {
int index = storeIValueAndGetIndex(fbb, ret.default_value());
return_vec.emplace_back(CreateArg(
fbb,
fbb.CreateSharedString(ret.name()),
fbb.CreateSharedString(ret.type()->annotation_str(type_printer)),
index));
}
return CreateSchema(
fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec));
}
flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
functionToFB(
FlatBufferBuilder& fbb,
const std::string& qn,
const mobile::Function& func) {
const auto& code = func.get_code();
// instructions
std::vector<mobile::serialization::Instruction> instruction_vector;
for (const auto& inst : code.instructions_) {
instruction_vector.emplace_back(inst.op, inst.N, inst.X);
}
// operators
std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
operator_vector;
operator_vector.reserve(code.op_names_.size());
for (int i = 0; i < code.op_names_.size(); ++i) {
const auto& opname = code.op_names_[i];
const int op_size = code.operator_input_sizes_[i];
operator_vector.push_back(CreateOperator(
fbb,
fbb.CreateSharedString(opname.name),
fbb.CreateSharedString(opname.overload_name),
op_size));
}
const auto& constants = code.constants_;
std::vector<uint32_t> constant_indexes;
constant_indexes.reserve(constants.size());
for (const auto& constant : constants) {
constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant));
}
// types
static const std::string torch_prefix("__torch__");
static const std::string class_prefix("__torch__.torch.classes");
std::vector<flatbuffers::Offset<flatbuffers::String>> type_offsets;
for (const TypePtr& t : code.types_) {
auto type_str = t->annotation_str();
if (type_str.find(torch_prefix) == 0) {
TORCH_CHECK(
type_str.find(class_prefix) == 0,
"__torch__ types other than torchbind (__torch__.torch.classes)"
"are not supported in lite interpreter. ",
"Workaround: instead of using arbitrary class type (class Foo()), ",
"define a pytorch class (class Foo(torch.nn.Module)).");
}
type_offsets.push_back(fbb.CreateSharedString(type_str));
}
// since the register location is embedded into the bytecode, pass the
// register size
auto register_size = static_cast<int>(code.register_size_);
// schema
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
auto namedType = t.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return namedType->name().value().qualifiedName();
}
return c10::nullopt;
};
flatbuffers::Offset<mobile::serialization::Schema> schema_offset = 0;
if (func.hasSchema()) {
const auto& schema = func.getSchema();
TORCH_CHECK(
schema.overload_name().empty(), // @TODO: is this check correct?
"Overloads are not supported in mobile modules.");
TORCH_CHECK(
!schema.is_vararg(),
"Python *args are not supported in mobile modules.");
TORCH_CHECK(
!schema.is_varret(),
"A variable number of return values is not supported in mobile modules.");
schema_offset =
CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer);
}
auto debug_info_offset =
CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_));
// auto classtype = schema.arguments()[0].type()->cast<ClassType>();
// uint32_t class_type = storeClassTypeAndGetIndex(fbb, classtype);
auto function_offset = CreateFunctionDirect(
fbb,
qn.c_str(),
&instruction_vector,
&operator_vector,
&constant_indexes,
&type_offsets,
register_size,
schema_offset,
debug_info_offset,
0);
return function_offset;
}
flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
const mobile::Module& module,
bool include_tensor_data_in_flatbuffer) {
FlatBufferBuilder fbb;
mcu_ = &module.compilation_unit();
// first element is None.
insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0));
auto methods = module.get_methods();
std::vector<uint32_t> functions_index;
functions_index.reserve(methods.size());
for (const auto& method : methods) {
auto func_offset = storeFunctionAndGetIndex(
fbb, method.function().qualname().qualifiedName(), method.function());
functions_index.push_back(func_offset);
}
auto functions_offset = fbb.CreateVector(functions_index);
uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue());
flatbuffers::Offset<flatbuffers::Vector<
flatbuffers::Offset<mobile::serialization::StorageData>>>
storage_data_offset = 0;
if (include_tensor_data_in_flatbuffer) {
std::vector<flatbuffers::Offset<mobile::serialization::StorageData>>
storage_data;
for (auto td : tensor_data_) {
if (td.storage().device_type() != DeviceType::CPU) {
td = at::empty({0}, td.options())
.set_(
td.storage(),
/* storage_offset = */ 0,
/* size = */
{static_cast<int64_t>(
td.storage().nbytes() / td.element_size())},
/* stride = */ {1})
.cpu();
}
fbb.ForceVectorAlignment(
td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT);
auto storage_offset = mobile::serialization::CreateStorageData(
fbb,
fbb.CreateVector(
reinterpret_cast<const uint8_t*>(td.storage().data()),
td.storage().nbytes()));
storage_data.push_back(storage_offset);
}
storage_data_offset = fbb.CreateVector(storage_data);
}
auto mod = CreateModule(
fbb,
0, /* version */
0, /* extra_files */
functions_offset,
ivalue_index,
fbb.CreateVector(ivalue_offsets_),
tensor_data_.size(),
storage_data_offset,
fbb.CreateVector(obj_types_offset_));
fbb.Finish(mod);
return fbb.Release();
}
flatbuffers::Offset<mobile::serialization::Tuple> FlatbufferSerializer::
tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) {
const auto& elements = tuple.toTuple()->elements();
std::vector<uint32_t> items =
storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
return CreateTupleDirect(fbb, &items);
}
flatbuffers::Offset<mobile::serialization::List> FlatbufferSerializer::listToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& list) {
const auto& elements = list.toList();
std::vector<uint32_t> items =
storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
return CreateList(
fbb,
fbb.CreateVector(items),
fbb.CreateSharedString(list.type()->annotation_str()));
}
flatbuffers::Offset<mobile::serialization::Dict> FlatbufferSerializer::dictToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue) {
const auto& dict = ivalue.toGenericDict();
std::vector<uint32_t> keys;
std::vector<uint32_t> values;
keys.reserve(dict.size());
values.reserve(dict.size());
for (const auto& entry : dict) {
int key_index = storeIValueAndGetIndex(fbb, entry.key());
keys.push_back(key_index);
int value_index = storeIValueAndGetIndex(fbb, entry.value());
values.push_back(value_index);
}
return CreateDict(
fbb,
fbb.CreateVector(keys),
fbb.CreateVector(values),
fbb.CreateSharedString(ivalue.type()->annotation_str()));
}
flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
classTypeToFB(FlatBufferBuilder& fbb, ClassTypePtr class_ptr) {
mobile::serialization::TypeType typetype =
mobile::serialization::TypeType::UNSET;
flatbuffers::Offset<
flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
names_offset = 0;
c10::QualifiedName setstate_name(*class_ptr->name(), "__setstate__");
const mobile::Function* setstate = mcu_->find_function(setstate_name);
if (setstate != nullptr) {
typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
} else if (class_ptr->findMethod("__setstate__")) {
typetype = mobile::serialization::TypeType::CUSTOM_CLASS;
} else {
size_t num_attr = class_ptr->numAttributes();
std::vector<flatbuffers::Offset<flatbuffers::String>> names;
std::vector<uint32_t> type_index;
for (size_t i = 0; i < num_attr; ++i) {
names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
}
names_offset = fbb.CreateVector(names);
typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD;
}
auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName());
return CreateObjectType(fbb, name_offset, typetype, names_offset);
}
uint32_t FlatbufferSerializer::storeFunctionAndGetIndex(
flatbuffers::FlatBufferBuilder& fbb,
const std::string& qn,
const mobile::Function& function) {
auto iter = qn_to_serialized_values_.find(qn);
if (iter != qn_to_serialized_values_.end()) {
return iter->second;
}
auto offset = CreateIValue(
fbb,
mobile::serialization::IValueUnion::Function,
functionToFB(fbb, qn, function).Union());
uint32_t index = insertIValue(offset);
qn_to_serialized_values_[qn] = index;
return index;
}
uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex(
FlatBufferBuilder& fbb,
ClassTypePtr class_ptr) {
const auto& type_str = class_ptr->name()->qualifiedName();
auto iter = qn_to_serialized_values_.find(type_str);
if (iter != qn_to_serialized_values_.end()) {
return iter->second;
}
auto offset = classTypeToFB(fbb, class_ptr);
uint32_t res = obj_types_offset_.size();
obj_types_offset_.push_back(offset);
qn_to_serialized_values_[type_str] = res;
return res;
}
flatbuffers::Offset<mobile::serialization::Object> FlatbufferSerializer::
objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
auto obj = ivalue.toObject();
auto type = obj->type();
// rename type?
// check getstate
// save state as ivalue
flatbuffers::Offset<flatbuffers::Vector<uint32_t>> attrs = 0;
uint32_t state_index = 0;
uint32_t setstate_func_index = 0;
const auto qn = type->name()->qualifiedName() + ".__setstate__";
auto getstate = type->findMethod("__getstate__");
auto setstate = type->findMethod("__setstate__");
if (getstate && setstate) {
auto state = (*getstate)({obj});
state_index = storeIValueAndGetIndex(fbb, state);
auto func_index = qn_to_serialized_values_.find(qn);
if (func_index != qn_to_serialized_values_.end()) {
setstate_func_index = func_index->second;
}
} else {
size_t num_attr = type->numAttributes();
std::vector<uint32_t> tuple_index;
for (size_t i = 0; i < num_attr; ++i) {
tuple_index.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i)));
}
attrs = fbb.CreateVector(tuple_index);
}
uint32_t type_index = storeClassTypeAndGetIndex(fbb, type);
return CreateObject(fbb, type_index, state_index, attrs, setstate_func_index);
}
flatbuffers::Offset<mobile::serialization::TensorMetadata> FlatbufferSerializer::
FlatbufferSerializer::tensorToFB(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue) {
auto& tensor = ivalue.toTensor();
bool quantized = tensor.is_quantized();
const at::Storage& storage = tensor.storage();
flatbuffers::Offset<mobile::serialization::QuantizedSchema> qschema_offset =
0;
if (quantized) {
double scale = 0;
int32_t zero_point = 0;
flatbuffers::Offset<mobile::serialization::TensorMetadata> scales = 0;
flatbuffers::Offset<mobile::serialization::TensorMetadata> zero_points = 0;
int32_t axis = 0;
switch (tensor.qscheme()) {
case at::kPerTensorAffine:
scale = tensor.q_scale();
zero_point = tensor.q_zero_point();
break;
case at::kPerChannelAffineFloatQParams:
case at::kPerChannelAffine: {
scales = tensorToFB(fbb, tensor.q_per_channel_scales());
zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points());
axis = tensor.q_per_channel_axis();
} break;
default:
TORCH_CHECK(
false,
"Unsupported tensor quantization type in serialization ",
toString(tensor.qscheme()));
break;
}
qschema_offset = mobile::serialization::CreateQuantizedSchema(
fbb,
static_cast<int8_t>(tensor.qscheme()),
scale,
zero_point,
scales,
zero_points,
axis);
}
void* addr = storage.unsafeGetStorageImpl();
uint32_t storage_index = 0;
auto it = memoized_storage_map_.find(addr);
if (it != memoized_storage_map_.end()) {
storage_index = it->second;
} else {
storage_index = tensor_data_.size();
memoized_storage_map_[addr] = storage_index;
tensor_data_.push_back(tensor);
}
std::vector<int> sizes{tensor.sizes().begin(), tensor.sizes().end()};
std::vector<int> strides{tensor.strides().begin(), tensor.strides().end()};
return CreateTensorMetadataDirect(
fbb,
/* storage_location_index */ storage_index,
/* scalar_type */ static_cast<int8_t>(tensor.scalar_type()),
/* int32_t storage_offset */ tensor.storage_offset(),
/* sizes */ &sizes,
/* strides */ &strides,
/* bool requires_grad */ tensor.requires_grad(),
/* qschema */ qschema_offset);
}
uint32_t FlatbufferSerializer::storeIValueAndGetIndex(
flatbuffers::FlatBufferBuilder& fbb,
const IValue& ivalue) {
if (ivalue.isNone()) {
return kNoneIndex;
}
try {
auto iter = cached_ivalues_.find(ivalue);
if (iter != cached_ivalues_.end()) {
return iter->second;
}
} catch (const std::runtime_error&) {
// Thrown if ivalue is not hashable.
} catch (const c10::Error&) {
// Thrown if ivalue doesn't have a proper operator==.
}
auto offset = iValueToFB(fbb, ivalue);
uint32_t index = insertIValue(offset);
try {
cached_ivalues_[ivalue] = index;
} catch (const std::runtime_error&) {
} catch (const c10::Error&) {
}
return index;
}
flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
using mobile::serialization::IValueUnion;
IValueUnion ivalue_type = IValueUnion::NONE;
flatbuffers::Offset<void> offset = 0;
if (ivalue.isTensor()) {
ivalue_type = IValueUnion::TensorMetadata;
offset = tensorToFB(fbb, ivalue).Union();
} else if (ivalue.isTuple()) {
ivalue_type = IValueUnion::Tuple;
offset = tupleToFB(fbb, ivalue).Union();
} else if (ivalue.isDouble()) {
ivalue_type = IValueUnion::Double;
offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble()))
.Union();
} else if (ivalue.isComplexDouble()) {
auto comp = ivalue.toComplexDouble();
ivalue_type = IValueUnion::ComplexDouble;
offset = fbb.CreateStruct(mobile::serialization::ComplexDouble(
comp.real(), comp.imag()))
.Union();
} else if (ivalue.isInt()) {
ivalue_type = IValueUnion::Int;
offset =
fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union();
} else if (ivalue.isBool()) {
ivalue_type = IValueUnion::Bool;
offset =
fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union();
} else if (ivalue.isString()) {
ivalue_type = IValueUnion::String;
offset = mobile::serialization::CreateString(
fbb, fbb.CreateSharedString(ivalue.toString()->string()))
.Union();
} else if (ivalue.isGenericDict()) {
ivalue_type = IValueUnion::Dict;
offset = dictToFB(fbb, ivalue).Union();
} else if (ivalue.isNone()) {
ivalue_type = IValueUnion::NONE;
offset = 0;
} else if (ivalue.isIntList()) {
ivalue_type = IValueUnion::IntList;
offset = mobile::serialization::CreateIntList(
fbb, fbb.CreateVector(ivalue.toIntVector()))
.Union();
} else if (ivalue.isDoubleList()) {
ivalue_type = IValueUnion::DoubleList;
offset = mobile::serialization::CreateDoubleList(
fbb, fbb.CreateVector(ivalue.toDoubleVector()))
.Union();
} else if (ivalue.isBoolList()) {
ivalue_type = IValueUnion::BoolList;
auto boollist = ivalue.toBoolList();
std::vector<uint8_t> bool_vec(boollist.begin(), boollist.end());
offset =
mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union();
} else if (ivalue.isList()) {
ivalue_type = IValueUnion::List;
offset = listToFB(fbb, ivalue).Union();
} else if (ivalue.isObject()) {
ivalue_type = IValueUnion::Object;
offset = objectToFB(fbb, ivalue).Union();
} else if (ivalue.isDevice()) {
ivalue_type = IValueUnion::Device;
offset = mobile::serialization::CreateDevice(
fbb, fbb.CreateSharedString(ivalue.toDevice().str()))
.Union();
} else if (ivalue.isEnum()) {
const auto& enum_holder = ivalue.toEnumHolder();
const auto& qualified_class_name =
enum_holder->type()->qualifiedClassName();
uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value());
ivalue_type = IValueUnion::EnumValue;
offset = mobile::serialization::CreateEnumValue(
fbb,
fbb.CreateSharedString(qualified_class_name.qualifiedName()),
ival_pos)
.Union();
} else {
AT_ERROR("Invalid IValue type for serialization: ", ivalue.tagKind());
}
return CreateIValue(fbb, ivalue_type, offset);
}
} // namespace
void save_mobile_module(
const mobile::Module& module,
const std::string& filename) {
FlatbufferSerializer fb_serializer;
auto buffer = fb_serializer.serializeModule(module, true);
std::fstream ofile(filename, std::ios::binary | std::ios::out);
ofile.write(reinterpret_cast<char*>(buffer.data()), buffer.size());
ofile.close();
}
flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
const mobile::Module& module) {
FlatbufferSerializer fb_serializer;
return fb_serializer.serializeModule(module, true);
}
} // namespace jit
} // namespace torch

@@ -0,0 +1,26 @@
#pragma once
#include <ATen/core/qualified_name.h>
#include <flatbuffers/flatbuffers.h>
#include <string>
#include <vector>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
namespace torch {
namespace jit {
TORCH_API void save_mobile_module(
const mobile::Module& module,
const std::string& filename);
TORCH_API flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
const mobile::Module& module);
} // namespace jit
} // namespace torch

@@ -0,0 +1,197 @@
namespace torch.jit.mobile.serialization;
struct Int {
int_val:long;
}
struct Bool {
bool_val:bool;
}
struct Double{
double_val:double;
}
struct PerTensorAffineSchema {
q_scale:double;
q_zero_point:int;
}
table QuantizedSchema {
qscheme:byte;
scale:double;
zero_point:int;
scales:TensorMetadata;
zero_points:TensorMetadata;
axis:int;
}
table TensorMetadata {
// torch._utils _rebuild_tensor_v2
storage_location_index:uint;
// enum ScalarType
scalar_type:byte;
storage_offset:int;
sizes:[int];
strides:[int];
requires_grad:bool;
// only set for quantized tensors
quantized_schema:QuantizedSchema;
}
table String {
data:string;
}
table Device {
str:string;
}
table List {
items:[uint];
annotation_str:string; // to recover key/val type
}
table IntList {
items:[long];
}
table DoubleList {
items:[double];
}
table BoolList {
items:[bool];
}
table Tuple {
items:[uint];
}
table Dict {
keys:[uint];
values:[uint];
annotation_str:string; // to recover key/val type
}
enum TypeType :ubyte {
UNSET,
CLASS_WITH_FIELD,
CUSTOM_CLASS,
CLASS_WITH_SETSTATE,
NON_OBJ,
}
table ObjectType {
type_name:string;
type:TypeType;
// Below fields are optional
attr_names:[string];
}
table Object {
type_index:uint;
state:uint;
attrs:[uint];
setstate_func:uint;
}
struct ComplexDouble {
real:double;
imag:double;
}
table EnumValue {
type_name:string;
value:uint; // index to ivalues;
}
struct Instruction {
// Should op be enum instead?
op:byte;
n:ushort;
x:int;
}
table Operator {
name:string;
overload_name:string;
num_args_serialized:int = -1;
}
table Arg {
name:string;
// Why do we use string to represent types
// rather than index into Code.types?
type:string;
default_value:uint; // position into ivalues
}
table Schema {
arguments:[Arg];
returns:[Arg];
}
table DebugInfo {
debug_handle:[long];
}
table Function {
qn:string;
instructions:[Instruction];
operators:[Operator];
constants:[uint]; // index to ivalue
type_annotations:[string];
register_size:int;
schema:Schema;
debug_info:DebugInfo;
class_type:uint; // index into type table
}
table StorageData {
data:[ubyte] (force_align:16);
}
// Is it needed to represent other types?
union IValueUnion {
Int,
Bool,
Double,
ComplexDouble,
TensorMetadata,
String,
List,
Tuple,
Dict,
Object,
IntList,
DoubleList,
BoolList,
Device,
EnumValue,
Function,
}
table IValue {
val:IValueUnion;
}
table ExtraFile {
name:string;
content:string;
}
table Module {
version:int;
extra_files:[ExtraFile];
methods:[uint]; // index to ivalues
state_obj:uint; // index to ivalues
ivalues:[IValue];
storage_data_size:int; // number of storage data;
storage_data:[StorageData];
object_types:[ObjectType];
}
root_type Module;

File diff suppressed because it is too large.