mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor flatbuffer loader to allow overriding how IValues are parsed. (#71661)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71661 https://docs.google.com/document/d/1OoPKREoqNbOUIcbGzfk8TTIibeTgx3c6Lr3NthF7-PM/edit Test Plan: unittest Reviewed By: zhxchen17 Differential Revision: D33720630 fbshipit-source-id: da24993cf5568c689cb6fda64ba4943d77f8b5e6 (cherry picked from commit 327cf75d234ee2b1aea79dc909b890b96927f536)
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/dynamic_type.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
@ -32,7 +33,6 @@
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace {
|
||||
|
||||
using caffe2::serialize::IStreamAdapter;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
@ -43,56 +43,77 @@ static constexpr c10::string_view kCustomClassPrefix =
|
||||
static constexpr c10::string_view kTorchPrefix = "__torch__";
|
||||
static constexpr c10::string_view kJitPrefix = "torch.jit";
|
||||
|
||||
class FlatbufferLoader {
|
||||
public:
|
||||
FlatbufferLoader()
|
||||
: mcu_(std::make_shared<mobile::CompilationUnit>()),
|
||||
cu_(std::make_shared<CompilationUnit>()) {}
|
||||
template <typename T, typename U>
|
||||
std::vector<T> parseListNative(const U* list) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
|
||||
return {list->items()->begin(), list->items()->end()};
|
||||
}
|
||||
|
||||
mobile::Module parseModule(mobile::serialization::Module* module);
|
||||
IValue parseList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseTensor(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseTuple(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseDict(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseObject(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseIntList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseDoubleList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseBoolList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseBasic(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
IValue parseEnum(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue);
|
||||
|
||||
private:
|
||||
IValue parseIValue(const mobile::serialization::IValue* ivalue);
|
||||
IValue parseList(const mobile::serialization::List* list);
|
||||
at::Tensor parseTensor(const mobile::serialization::TensorMetadata* tensor);
|
||||
IValue parseTuple(const mobile::serialization::Tuple* tuple);
|
||||
IValue parseDict(const mobile::serialization::Dict* dict);
|
||||
IValue parseObject(const mobile::serialization::Object* object);
|
||||
std::unique_ptr<mobile::Function> parseFunction(
|
||||
const mobile::serialization::Function* method);
|
||||
FlatbufferLoader::FlatbufferLoader()
|
||||
: mcu_(std::make_shared<mobile::CompilationUnit>()),
|
||||
cu_(std::make_shared<CompilationUnit>()),
|
||||
ivalue_parsers_{nullptr} {
|
||||
registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::IntList, &parseIntList);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::BoolList, &parseBoolList);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::Object, &parseObject);
|
||||
registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
|
||||
registerIValueParser(
|
||||
mobile::serialization::IValueUnion::EnumValue, &parseEnum);
|
||||
}
|
||||
|
||||
IValue& getIValue(uint32_t pos) {
|
||||
TORCH_CHECK(pos < all_ivalues_.size());
|
||||
return all_ivalues_[pos];
|
||||
}
|
||||
|
||||
mobile::Function* getFunction(uint32_t pos) {
|
||||
return all_functions_[pos];
|
||||
}
|
||||
|
||||
ClassTypePtr getType(uint32_t pos) const {
|
||||
TORCH_CHECK(pos < all_ivalues_.size());
|
||||
return all_types_[pos];
|
||||
// auto iter = all_types_.find(pos);
|
||||
// AT_ASSERT(iter != all_types_.end(), "type not found at pos: ", pos);
|
||||
// return iter->second;
|
||||
}
|
||||
|
||||
c10::Storage getStorage(uint32_t index);
|
||||
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
|
||||
|
||||
// fields
|
||||
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
|
||||
std::vector<ClassTypePtr> all_types_;
|
||||
std::unordered_set<uint32_t> initialized_types_;
|
||||
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
|
||||
std::vector<bool> storage_loaded_;
|
||||
std::vector<c10::Storage> storages_;
|
||||
std::vector<IValue> all_ivalues_;
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu_;
|
||||
std::shared_ptr<CompilationUnit> cu_;
|
||||
mobile::serialization::Module* module_ = nullptr;
|
||||
};
|
||||
void FlatbufferLoader::registerIValueParser(
|
||||
mobile::serialization::IValueUnion ivalue_type,
|
||||
IValueParser parser) {
|
||||
ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
|
||||
}
|
||||
|
||||
mobile::Module FlatbufferLoader::parseModule(
|
||||
mobile::serialization::Module* module) {
|
||||
@ -120,12 +141,6 @@ mobile::Module FlatbufferLoader::parseModule(
|
||||
}
|
||||
|
||||
IValue& module_ivalue = getIValue(module->state_obj());
|
||||
// register function to class
|
||||
// for (const auto& func: all_functions_) {
|
||||
// const auto* fb_func = ivalues->Get(func.first)->val_as_Function();
|
||||
// auto class_type = getType(fb_func->class_type());
|
||||
// class_type->addMethod(func.second);
|
||||
// }
|
||||
return mobile::Module(module_ivalue.toObject(), mcu_);
|
||||
}
|
||||
|
||||
@ -196,7 +211,55 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
|
||||
return function;
|
||||
}
|
||||
|
||||
at::Tensor FlatbufferLoader::parseTensor(
|
||||
IValue parseEnum(
|
||||
FlatbufferLoader& loader,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const auto* enum_val = ivalue.val_as_EnumValue();
|
||||
auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
|
||||
->cast<c10::EnumType>();
|
||||
AT_ASSERT(
|
||||
enum_type,
|
||||
"Enum with type: " + enum_val->type_name()->str() + " not found.");
|
||||
IValue val = loader.getIValue(enum_val->value());
|
||||
for (const auto& p : enum_type->enumNamesValues()) {
|
||||
if (p.second == val) {
|
||||
auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
|
||||
enum_type, p.first, p.second);
|
||||
return IValue(std::move(enum_holder));
|
||||
}
|
||||
}
|
||||
AT_ASSERT(
|
||||
false, "Enum with type: " + enum_val->type_name()->str() + " not found.");
|
||||
}
|
||||
|
||||
IValue parseBasic(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
switch (ivalue.val_type()) {
|
||||
case mobile::serialization::IValueUnion::NONE:
|
||||
return {};
|
||||
case mobile::serialization::IValueUnion::Int:
|
||||
return ivalue.val_as_Int()->int_val();
|
||||
case mobile::serialization::IValueUnion::Bool:
|
||||
return ivalue.val_as_Bool()->bool_val();
|
||||
case mobile::serialization::IValueUnion::Double:
|
||||
return ivalue.val_as_Double()->double_val();
|
||||
case mobile::serialization::IValueUnion::ComplexDouble: {
|
||||
const auto* comp = ivalue.val_as_ComplexDouble();
|
||||
return c10::complex<double>(comp->real(), comp->imag());
|
||||
}
|
||||
case mobile::serialization::IValueUnion::String:
|
||||
return ivalue.val_as_String()->data()->str();
|
||||
case mobile::serialization::IValueUnion::Device: {
|
||||
return c10::Device(ivalue.val_as_Device()->str()->str());
|
||||
}
|
||||
default:
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor parseTensorFromMetadata(
|
||||
FlatbufferLoader* loader,
|
||||
const mobile::serialization::TensorMetadata* tensor_md) {
|
||||
at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
|
||||
auto options = at::CPU(type).options();
|
||||
@ -212,8 +275,9 @@ at::Tensor FlatbufferLoader::parseTensor(
|
||||
} break;
|
||||
case at::kPerChannelAffineFloatQParams:
|
||||
case at::kPerChannelAffine: {
|
||||
at::Tensor scales = parseTensor(schema->scales());
|
||||
at::Tensor zero_points = parseTensor(schema->zero_points());
|
||||
at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
|
||||
at::Tensor zero_points =
|
||||
parseTensorFromMetadata(loader, schema->zero_points());
|
||||
tensor = at::_empty_per_channel_affine_quantized(
|
||||
{0}, scales, zero_points, schema->axis(), options);
|
||||
} break;
|
||||
@ -230,7 +294,7 @@ at::Tensor FlatbufferLoader::parseTensor(
|
||||
at::TensorImpl* impl = tensor.unsafeGetTensorImpl();
|
||||
|
||||
c10::Storage storage;
|
||||
storage = getStorage(tensor_md->storage_location_index());
|
||||
storage = loader->getStorage(tensor_md->storage_location_index());
|
||||
impl->set_storage_keep_dtype(storage);
|
||||
impl->set_storage_offset(tensor_md->storage_offset());
|
||||
|
||||
@ -239,48 +303,93 @@ at::Tensor FlatbufferLoader::parseTensor(
|
||||
std::vector<int64_t> stride{
|
||||
tensor_md->strides()->begin(), tensor_md->strides()->end()};
|
||||
impl->set_sizes_and_strides(size, stride);
|
||||
#ifndef MIN_EDGE_RUNTIME
|
||||
tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
|
||||
#endif
|
||||
return tensor;
|
||||
}
|
||||
IValue FlatbufferLoader::parseList(const mobile::serialization::List* list) {
|
||||
|
||||
IValue parseTensor(
|
||||
FlatbufferLoader& loader,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const mobile::serialization::TensorMetadata* tensor_md =
|
||||
ivalue.val_as_TensorMetadata();
|
||||
return parseTensorFromMetadata(&loader, tensor_md);
|
||||
}
|
||||
|
||||
IValue parseList(
|
||||
FlatbufferLoader& loader,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const mobile::serialization::List* list = ivalue.val_as_List();
|
||||
auto res = c10::impl::GenericList(AnyType::get());
|
||||
for (int i : *list->items()) {
|
||||
res.emplace_back(getIValue(i));
|
||||
res.emplace_back(loader.getIValue(i));
|
||||
}
|
||||
auto type = getOrCreateTypeAnnotations(list->annotation_str());
|
||||
auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
|
||||
res.unsafeSetElementType(type->containedType(0));
|
||||
return res;
|
||||
}
|
||||
|
||||
IValue FlatbufferLoader::parseTuple(const mobile::serialization::Tuple* tuple) {
|
||||
IValue parseIntList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const auto& list = ivalue.val_as_IntList();
|
||||
return parseListNative<int64_t>(list);
|
||||
}
|
||||
|
||||
IValue parseBoolList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const auto& list = ivalue.val_as_DoubleList();
|
||||
return parseListNative<double>(list);
|
||||
}
|
||||
|
||||
IValue parseDoubleList(
|
||||
FlatbufferLoader&,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const auto& list = ivalue.val_as_BoolList();
|
||||
std::vector<uint8_t> res = parseListNative<uint8_t>(list);
|
||||
c10::List<bool> boollist;
|
||||
for (auto x : res) {
|
||||
boollist.push_back(x);
|
||||
}
|
||||
return boollist;
|
||||
}
|
||||
|
||||
IValue parseTuple(
|
||||
FlatbufferLoader& loader,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const auto& tuple = ivalue.val_as_Tuple();
|
||||
std::vector<IValue> res;
|
||||
for (int i : *tuple->items()) {
|
||||
res.emplace_back(getIValue(i));
|
||||
res.emplace_back(loader.getIValue(i));
|
||||
}
|
||||
return c10::ivalue::Tuple::create(res);
|
||||
}
|
||||
|
||||
IValue FlatbufferLoader::parseDict(const mobile::serialization::Dict* dict) {
|
||||
IValue parseDict(
|
||||
FlatbufferLoader& loader,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const auto* dict = ivalue.val_as_Dict();
|
||||
auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
|
||||
const auto* keys = dict->keys();
|
||||
const auto* values = dict->values();
|
||||
for (size_t i = 0; i < keys->size(); ++i) {
|
||||
uint32_t key = keys->Get(i);
|
||||
uint32_t val = values->Get(i);
|
||||
result.insert_or_assign(getIValue(key), getIValue(val));
|
||||
result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
|
||||
}
|
||||
auto type = getOrCreateTypeAnnotations(dict->annotation_str());
|
||||
auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
|
||||
result.unsafeSetKeyType(type->containedType(0));
|
||||
result.unsafeSetValueType(type->containedType(1));
|
||||
return result;
|
||||
}
|
||||
|
||||
IValue FlatbufferLoader::parseObject(
|
||||
ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
|
||||
const mobile::serialization::Object* object) {
|
||||
auto cls = getType(object->type_index());
|
||||
const mobile::serialization::ObjectType* obj_type =
|
||||
module_->object_types()->Get(object->type_index());
|
||||
auto cls = getType(object->type_index());
|
||||
bool initialized = true;
|
||||
if (cls == nullptr) {
|
||||
c10::string_view qn_str(
|
||||
obj_type->type_name()->c_str(), obj_type->type_name()->size());
|
||||
@ -296,34 +405,46 @@ IValue FlatbufferLoader::parseObject(
|
||||
}
|
||||
TORCH_CHECK(object->type_index() < all_ivalues_.size());
|
||||
all_types_[object->type_index()] = cls;
|
||||
initialized = false;
|
||||
|
||||
if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
|
||||
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||||
IValue val = getIValue(object->attrs()->Get(i));
|
||||
// Need to use concrete object's field's type to set type of field.
|
||||
cls->addAttribute(
|
||||
obj_type->attr_names()->Get(i)->str(),
|
||||
val.type<c10::DynamicType>());
|
||||
}
|
||||
}
|
||||
initialized_types_.insert(object->type_index());
|
||||
}
|
||||
return cls;
|
||||
}
|
||||
|
||||
IValue parseObject(
|
||||
FlatbufferLoader& loader,
|
||||
const mobile::serialization::IValue& ivalue) {
|
||||
const mobile::serialization::Object* object = ivalue.val_as_Object();
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
|
||||
const auto* cur_input = loader.getCurrentFlatbufferInput();
|
||||
const mobile::serialization::ObjectType* obj_type =
|
||||
cur_input->object_types()->Get(object->type_index());
|
||||
auto cls = loader.getOrCreateClassTypeForObject(object);
|
||||
Stack stack;
|
||||
switch (obj_type->type()) {
|
||||
case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
|
||||
auto obj = c10::ivalue::Object::create(
|
||||
at::StrongTypePtr(cu_, cls), object->attrs()->size());
|
||||
if (!initialized) {
|
||||
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||||
IValue val = getIValue(object->attrs()->Get(i));
|
||||
cls->addAttribute(
|
||||
obj_type->attr_names()->Get(i)->str(),
|
||||
val.type<c10::DynamicType>());
|
||||
obj->setSlot(i, std::move(val));
|
||||
}
|
||||
initialized_types_.insert(object->type_index());
|
||||
} else {
|
||||
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||||
IValue val = getIValue(object->attrs()->Get(i));
|
||||
obj->setSlot(i, std::move(val));
|
||||
}
|
||||
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
|
||||
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||||
IValue val = loader.getIValue(object->attrs()->Get(i));
|
||||
obj->setSlot(i, std::move(val));
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
|
||||
IValue input = getIValue(object->state());
|
||||
mobile::Function* setstate = getFunction(object->setstate_func());
|
||||
auto obj = c10::ivalue::Object::create(at::StrongTypePtr(cu_, cls), 0);
|
||||
IValue input = loader.getIValue(object->state());
|
||||
mobile::Function* setstate = loader.getFunction(object->setstate_func());
|
||||
auto obj =
|
||||
c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
|
||||
stack.push_back(obj);
|
||||
stack.emplace_back(std::move(input));
|
||||
setstate->run(stack);
|
||||
@ -332,7 +453,7 @@ IValue FlatbufferLoader::parseObject(
|
||||
case mobile::serialization::TypeType::CUSTOM_CLASS: {
|
||||
auto custom_class_type =
|
||||
torch::jit::getCustomClass(cls->name()->qualifiedName());
|
||||
IValue input = getIValue(object->state());
|
||||
IValue input = loader.getIValue(object->state());
|
||||
auto obj = c10::ivalue::Object::create(
|
||||
c10::StrongTypePtr(nullptr, custom_class_type), 1);
|
||||
stack.push_back(obj);
|
||||
@ -345,78 +466,10 @@ IValue FlatbufferLoader::parseObject(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
std::vector<T> parseListNative(const U* list) {
|
||||
return {list->items()->begin(), list->items()->end()};
|
||||
}
|
||||
|
||||
IValue FlatbufferLoader::parseIValue(
|
||||
const mobile::serialization::IValue* ivalue) {
|
||||
switch (ivalue->val_type()) {
|
||||
case mobile::serialization::IValueUnion::NONE:
|
||||
return {};
|
||||
case mobile::serialization::IValueUnion::Int:
|
||||
return ivalue->val_as_Int()->int_val();
|
||||
case mobile::serialization::IValueUnion::Bool:
|
||||
return ivalue->val_as_Bool()->bool_val();
|
||||
case mobile::serialization::IValueUnion::Double:
|
||||
return ivalue->val_as_Double()->double_val();
|
||||
case mobile::serialization::IValueUnion::ComplexDouble: {
|
||||
const auto* comp = ivalue->val_as_ComplexDouble();
|
||||
return c10::complex<double>(comp->real(), comp->imag());
|
||||
}
|
||||
case mobile::serialization::IValueUnion::TensorMetadata:
|
||||
return parseTensor(ivalue->val_as_TensorMetadata());
|
||||
case mobile::serialization::IValueUnion::String:
|
||||
return ivalue->val_as_String()->data()->str();
|
||||
case mobile::serialization::IValueUnion::List:
|
||||
return parseList(ivalue->val_as_List());
|
||||
case mobile::serialization::IValueUnion::IntList:
|
||||
return parseListNative<int64_t>(ivalue->val_as_IntList());
|
||||
case mobile::serialization::IValueUnion::DoubleList:
|
||||
return parseListNative<double>(ivalue->val_as_DoubleList());
|
||||
case mobile::serialization::IValueUnion::BoolList: {
|
||||
std::vector<uint8_t> res =
|
||||
parseListNative<uint8_t>(ivalue->val_as_BoolList());
|
||||
c10::List<bool> boollist;
|
||||
for (auto x : res) {
|
||||
boollist.push_back(x);
|
||||
}
|
||||
return boollist;
|
||||
}
|
||||
case mobile::serialization::IValueUnion::Tuple:
|
||||
return parseTuple(ivalue->val_as_Tuple());
|
||||
case mobile::serialization::IValueUnion::Dict:
|
||||
return parseDict(ivalue->val_as_Dict());
|
||||
case mobile::serialization::IValueUnion::Object: {
|
||||
auto val = parseObject(ivalue->val_as_Object());
|
||||
return val;
|
||||
}
|
||||
case mobile::serialization::IValueUnion::Device: {
|
||||
return c10::Device(ivalue->val_as_Device()->str()->str());
|
||||
}
|
||||
case mobile::serialization::IValueUnion::EnumValue: {
|
||||
const auto* enum_val = ivalue->val_as_EnumValue();
|
||||
auto enum_type = getOrCreateTypeAnnotations(enum_val->type_name())
|
||||
->cast<c10::EnumType>();
|
||||
AT_ASSERT(
|
||||
enum_type,
|
||||
"Enum with type: " + enum_val->type_name()->str() + " not found.");
|
||||
IValue val = getIValue(enum_val->value());
|
||||
for (const auto& p : enum_type->enumNamesValues()) {
|
||||
if (p.second == val) {
|
||||
auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
|
||||
enum_type, p.first, p.second);
|
||||
return IValue(std::move(enum_holder));
|
||||
}
|
||||
}
|
||||
AT_ASSERT(
|
||||
false,
|
||||
"Enum with type: " + enum_val->type_name()->str() + " not found.");
|
||||
}
|
||||
default:
|
||||
return {};
|
||||
}
|
||||
return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
|
||||
*this, *ivalue);
|
||||
}
|
||||
|
||||
void deleteNothing2(void*);
|
||||
@ -469,8 +522,6 @@ TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
|
||||
return type;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
mobile::Module parse_and_initialize_mobile_module(
|
||||
std::shared_ptr<char> data,
|
||||
size_t,
|
||||
|
@ -50,5 +50,61 @@ TORCH_API mobile::Module load_mobile_module_from_file(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device = c10::nullopt);
|
||||
|
||||
class FlatbufferLoader {
|
||||
public:
|
||||
FlatbufferLoader();
|
||||
|
||||
typedef IValue (
|
||||
*IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
|
||||
void registerIValueParser(
|
||||
mobile::serialization::IValueUnion ivalue_type,
|
||||
IValueParser parser);
|
||||
mobile::Module parseModule(mobile::serialization::Module* module);
|
||||
|
||||
IValue& getIValue(uint32_t pos) {
|
||||
TORCH_CHECK(pos < all_ivalues_.size());
|
||||
return all_ivalues_[pos];
|
||||
}
|
||||
|
||||
mobile::Function* getFunction(uint32_t pos) {
|
||||
return all_functions_[pos];
|
||||
}
|
||||
|
||||
ClassTypePtr getType(uint32_t pos) {
|
||||
TORCH_CHECK(pos < all_ivalues_.size());
|
||||
return all_types_[pos];
|
||||
}
|
||||
|
||||
c10::Storage getStorage(uint32_t index);
|
||||
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
|
||||
ClassTypePtr getOrCreateClassTypeForObject(
|
||||
const mobile::serialization::Object* object);
|
||||
|
||||
const mobile::serialization::Module* getCurrentFlatbufferInput() {
|
||||
return module_;
|
||||
}
|
||||
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu_;
|
||||
std::shared_ptr<CompilationUnit> cu_;
|
||||
|
||||
private:
|
||||
IValue parseIValue(const mobile::serialization::IValue* ivalue);
|
||||
std::unique_ptr<mobile::Function> parseFunction(
|
||||
const mobile::serialization::Function* method);
|
||||
|
||||
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
|
||||
std::vector<ClassTypePtr> all_types_;
|
||||
std::unordered_set<uint32_t> initialized_types_;
|
||||
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
|
||||
std::vector<bool> storage_loaded_;
|
||||
std::vector<c10::Storage> storages_;
|
||||
std::vector<IValue> all_ivalues_;
|
||||
std::array<
|
||||
IValueParser,
|
||||
static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
|
||||
ivalue_parsers_;
|
||||
mobile::serialization::Module* module_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user