mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove some clang-tidy suppression in torch/csrc/jit by applying fixes or refactoring.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152271
Approved by: https://github.com/Skylion007, https://github.com/malfet
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
861 lines
30 KiB
C++
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>

#ifdef FLATBUFFERS_VERSION_MAJOR
#error "flatbuffer_serializer.h must not include any flatbuffers headers"
#endif // FLATBUFFERS_VERSION_MAJOR

#include <fstream>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Exception.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/code.h>
#include <torch/csrc/jit/mobile/train/export_data.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>

#if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
#include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
namespace flatbuffers = flatbuffers_fbsource;
#define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
#else
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
#endif

namespace torch::jit {

using flatbuffers::FlatBufferBuilder;
using mobile::serialization::CreateArg;
using mobile::serialization::CreateDebugInfo;
using mobile::serialization::CreateDict;
using mobile::serialization::CreateFunctionDirect;
using mobile::serialization::CreateIValue;
using mobile::serialization::CreateList;
using mobile::serialization::CreateModule;
using mobile::serialization::CreateObject;
using mobile::serialization::CreateOperator;
using mobile::serialization::CreateTensorMetadataDirect;
using mobile::serialization::CreateTupleDirect;

namespace {

// TODO: remove once caffe2::kProducedBytecodeVersion is >= 9 and flatbuffer is
// launched.
constexpr uint32_t kMinVersion = 9;

// We will store IValue NONE in index 0 in flatbuffer.
constexpr int kNoneIndex = 0;

static TypePtr realType(TypePtr type) {
  if (auto dyn = type->castRaw<c10::DynamicType>()) {
    return dyn->fallback();
  } else {
    return type;
  }
}

auto print_type(const c10::Type& t) -> std::optional<std::string> {
  auto namedType = t.cast<c10::NamedType>();
  if (namedType && namedType->name()) {
    return namedType->name().value().qualifiedName();
  }
  if (auto dyn = t.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str();
  }
  return std::nullopt;
}

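// Serializes a mobile::Module into a flatbuffer
// mobile::serialization::Module. IValues live in a single flat table
// (ivalue_offsets_) and are referenced by index; functions, class types,
// tensor storages, and hashable constants are memoized so that each
// distinct value is written at most once.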
class FlatbufferSerializer {
 public:
  FlatbufferSerializer() = default;

  flatbuffers::DetachedBuffer serializeModule(
      const mobile::Module& module,
      bool include_tensor_data_in_flatbuffer,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      const ExtraFilesMap& jit_sources = ExtraFilesMap(),
      const std::vector<IValue>& jit_constants = {});

 private:
  template <typename It>
  std::vector<uint32_t> storeIValuesAndGetIndexes(
      flatbuffers::FlatBufferBuilder& fbb,
      It begin,
      It end) {
    std::vector<uint32_t> indexes;
    for (; begin != end; ++begin) {
      indexes.push_back(storeIValueAndGetIndex(fbb, *begin));
    }
    return indexes;
  }

  flatbuffers::Offset<mobile::serialization::Tuple> tupleToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& tuple);

  flatbuffers::Offset<mobile::serialization::List> listToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  flatbuffers::Offset<mobile::serialization::Dict> dictToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::Object> objectToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::TensorMetadata> tensorToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::Function> functionToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& func);

  flatbuffers::Offset<mobile::serialization::IValue> iValueToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<jit::mobile::serialization::Schema> CreateFBSchema(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::vector<Argument>& args,
      const std::vector<Argument>& returns,
      const c10::TypePrinter& type_printer);

  flatbuffers::Offset<mobile::serialization::ObjectType> classTypeToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const ClassTypePtr& class_ptr);

  uint32_t storeIValueAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);
  uint32_t storeFunctionAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& function);

  uint32_t storeClassTypeAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const ClassTypePtr& class_type);

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::ExtraFile>>>
  storeExtraFilesAndGetOffset(
      FlatBufferBuilder& fbb,
      const ExtraFilesMap& extra_files);

  uint32_t insertIValue(
      flatbuffers::Offset<mobile::serialization::IValue> ivalue) {
    uint32_t size = ivalue_offsets_.size();
    ivalue_offsets_.push_back(ivalue);
    return size;
  }

  std::vector<at::Tensor> tensor_data_;

  std::unordered_map<const void*, uint32_t> memoized_storage_map_;

  std::vector<flatbuffers::Offset<mobile::serialization::IValue>>
      ivalue_offsets_;
  std::vector<flatbuffers::Offset<mobile::serialization::ObjectType>>
      obj_types_offset_;

  // qualified name to serialized class, type or function
  std::unordered_map<std::string, uint32_t> qn_to_serialized_values_;

  // cache of some ivalues
  struct IValueHash {
    size_t operator()(const IValue& val) const {
      return IValue::hash(val);
    }
  };

  struct IValueEqual {
    // Copy of this
    // https://www.internalfb.com/code/aros/[3b875bce7ffa2adacdcea9b3e0cb6d304737a193]/xros/third-party/caffe2/caffe2/aten/src/ATen/core/ivalue.cpp?lines=266
    // but without relying on the aten::nonzero operator being present in the
    // binary.
    bool operator()(const IValue& lhs, const IValue& rhs) const {
      // The only case where equals() doesn't return a bool is tensor
      // comparison, so let's fall back to pointer comparison there.
      if (lhs.isTensor() || rhs.isTensor()) {
        if (lhs.isTensor() && rhs.isTensor()) {
          return (&lhs.toTensor()) == (&rhs.toTensor());
        }
        return false;
      }
      IValue eq = lhs.equals(rhs);
      if (eq.isBool()) {
        return eq.toBool();
      }
      return false;
    }
  };

  std::unordered_map<IValue, uint32_t, IValueHash, IValueEqual> cached_ivalues_;
  const mobile::CompilationUnit* mcu_ = nullptr;
};

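// Converts a function schema (its argument and return lists) into the
// flatbuffer Schema table. Default values are stored through the shared
// IValue table and referenced by index.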
flatbuffers::Offset<jit::mobile::serialization::Schema> FlatbufferSerializer::
    CreateFBSchema(
        flatbuffers::FlatBufferBuilder& fbb,
        const std::vector<Argument>& args,
        const std::vector<Argument>& returns,
        const c10::TypePrinter& type_printer) {
  std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> arg_vec;
  arg_vec.reserve(args.size());
  std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> return_vec;
  return_vec.reserve(returns.size());
  for (const auto& arg : args) {
    auto index = storeIValueAndGetIndex(fbb, arg.default_value());
    arg_vec.emplace_back(CreateArg(
        fbb,
        fbb.CreateSharedString(arg.name()),
        fbb.CreateSharedString(
            realType(arg.type())->annotation_str(type_printer)),
        index));
  }

  for (const auto& ret : returns) {
    auto index = storeIValueAndGetIndex(fbb, ret.default_value());
    return_vec.emplace_back(CreateArg(
        fbb,
        fbb.CreateSharedString(ret.name()),
        fbb.CreateSharedString(
            realType(ret.type())->annotation_str(type_printer)),
        index));
  }
  return CreateSchema(
      fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec));
}

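// Serializes a single mobile::Function: its instructions, operators,
// constants, type annotations, register size, optional schema, and debug
// handles.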
flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
    functionToFB(
        FlatBufferBuilder& fbb,
        const std::string& qn,
        const mobile::Function& func) {
  const auto& code = func.get_code();

  // instructions
  std::vector<mobile::serialization::Instruction> instruction_vector;
  instruction_vector.reserve(code.instructions_.size());
  for (const auto& inst : code.instructions_) {
    instruction_vector.emplace_back(inst.op, inst.N, inst.X);
  }

  // operators
  std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
      operator_vector;
  operator_vector.reserve(code.op_names_.size());
  for (const auto i : c10::irange(code.op_names_.size())) {
    const auto& opname = code.op_names_[i];
    const int op_size = code.operator_input_sizes_[i];
    operator_vector.push_back(CreateOperator(
        fbb,
        fbb.CreateSharedString(opname.name),
        fbb.CreateSharedString(opname.overload_name),
        op_size));
  }

  const auto& constants = code.constants_;

  std::vector<uint32_t> constant_indexes;
  constant_indexes.reserve(constants.size());
  for (const auto& constant : constants) {
    constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant));
  }

  // types
  static const std::string torch_prefix("__torch__");
  static const std::string class_prefix("__torch__.torch.classes");
  std::vector<flatbuffers::Offset<flatbuffers::String>> type_offsets;

  for (const TypePtr& t : code.types_) {
    auto type_str = realType(t)->annotation_str();
    if (type_str.find(torch_prefix) == 0) {
      TORCH_CHECK(
          type_str.find(class_prefix) == 0,
          "__torch__ types other than custom c++ classes (__torch__.torch.classes) "
          "are not supported in lite interpreter. ",
          "Workaround: instead of using arbitrary class type (class Foo()), ",
          "define a pytorch class (class Foo(torch.nn.Module)).");
    }

    type_offsets.push_back(fbb.CreateSharedString(type_str));
  }

  // since the register location is embedded into the bytecode, pass the
  // register size
  auto register_size = static_cast<int>(code.register_size_);

  // schema
  auto type_printer = [&](const c10::Type& t) -> std::optional<std::string> {
    auto namedType = t.cast<c10::NamedType>();
    if (namedType && namedType->name()) {
      return namedType->name().value().qualifiedName();
    }
    if (auto dyn = t.castRaw<c10::DynamicType>()) {
      return dyn->fallback()->annotation_str();
    }
    return std::nullopt;
  };

  flatbuffers::Offset<mobile::serialization::Schema> schema_offset = 0;
  uint32_t class_index = 0;
  if (func.hasSchema()) {
    const auto& schema = func.getSchema();
    TORCH_CHECK(
        schema.overload_name().empty(), // @TODO: is this check correct?
        "Overloads are not supported in mobile modules.");
    TORCH_CHECK(
        !schema.is_vararg(),
        "Python *args are not supported in mobile modules.");
    TORCH_CHECK(
        !schema.is_varret(),
        "A variable number of return values is not supported in mobile modules.");
    schema_offset =
        CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer);
    auto classtype = schema.arguments()[0].type()->cast<ClassType>();
    class_index = storeClassTypeAndGetIndex(fbb, classtype);
  }

  auto debug_info_offset =
      CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_));

  auto function_offset = CreateFunctionDirect(
      fbb,
      qn.c_str(),
      &instruction_vector,
      &operator_vector,
      &constant_indexes,
      &type_offsets,
      register_size,
      schema_offset,
      debug_info_offset,
      class_index);
  return function_offset;
}

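// Writes a name -> content map as a vector of ExtraFile tables; used for
// both extra_files and jit_sources.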
flatbuffers::Offset<
    flatbuffers::Vector<flatbuffers::Offset<mobile::serialization::ExtraFile>>>
FlatbufferSerializer::storeExtraFilesAndGetOffset(
    FlatBufferBuilder& fbb,
    const ExtraFilesMap& extra_files) {
  std::vector<flatbuffers::Offset<mobile::serialization::ExtraFile>>
      extra_file_offsets;

  for (const auto& extra_file : extra_files) {
    flatbuffers::Offset<mobile::serialization::ExtraFile> extra_file_offset =
        mobile::serialization::CreateExtraFile(
            fbb,
            fbb.CreateSharedString(extra_file.first),
            fbb.CreateString(extra_file.second));
    extra_file_offsets.emplace_back(extra_file_offset);
  }
  return fbb.CreateVector(extra_file_offsets);
}

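// Top-level entry point: serializes the module's methods, its object state,
// extra files, JIT sources and constants, and optionally the raw tensor
// storage data, then finishes the flatbuffer.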
flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
    const mobile::Module& module,
    bool include_tensor_data_in_flatbuffer,
    const ExtraFilesMap& extra_files,
    const ExtraFilesMap& jit_sources,
    const std::vector<IValue>& jit_constants) {
  FlatBufferBuilder fbb;

  mcu_ = &module.compilation_unit();

  // first element is None.
  insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0));

  auto methods = module.get_methods();
  std::vector<uint32_t> functions_index;
  functions_index.reserve(methods.size());
  for (const auto& method : methods) {
    auto func_offset = storeFunctionAndGetIndex(
        fbb, method.function().qualname().qualifiedName(), method.function());
    functions_index.push_back(func_offset);
  }

  auto functions_offset = fbb.CreateVector(functions_index);
  uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue());

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::StorageData>>>
      storage_data_offset = 0;
  auto extra_files_offset = storeExtraFilesAndGetOffset(fbb, extra_files);

  auto jit_source_offset = storeExtraFilesAndGetOffset(fbb, jit_sources);
  std::vector<uint32_t> jit_constants_indexes;
  jit_constants_indexes.reserve(jit_constants.size());
  const uint32_t mobile_ivalue_size = ivalue_offsets_.size();
  for (const auto& ival : jit_constants) {
    jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
  }
  const uint32_t operator_version =
      static_cast<uint32_t>(module.min_operator_version());
  uint32_t bytecode_version = static_cast<uint32_t>(module.bytecode_version());
  if (bytecode_version < kMinVersion) {
    bytecode_version = kMinVersion;
  }

  // NOTE: saving of storage has to be the last thing to do.
  if (include_tensor_data_in_flatbuffer) {
    std::vector<flatbuffers::Offset<mobile::serialization::StorageData>>
        storage_data;
    for (auto td : tensor_data_) {
      if (td.storage().device_type() != DeviceType::CPU) {
        td = at::empty({0}, td.options())
                 .set_(
                     td.storage(),
                     /* storage_offset = */ 0,
                     /* size = */
                     {static_cast<int64_t>(
                         td.storage().nbytes() / td.element_size())},
                     /* stride = */ {1})
                 .cpu();
      }
      fbb.ForceVectorAlignment(
          td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT);
      auto storage_offset = mobile::serialization::CreateStorageData(
          fbb,
          fbb.CreateVector(
              reinterpret_cast<const uint8_t*>(td.storage().data()),
              td.storage().nbytes()));
      storage_data.push_back(storage_offset);
    }
    storage_data_offset = fbb.CreateVector(storage_data);
  }

  auto mod = CreateModule(
      fbb,
      /*bytecode_version=*/bytecode_version,
      extra_files_offset, /* extra_files */
      functions_offset,
      ivalue_index,
      fbb.CreateVector(ivalue_offsets_),
      static_cast<int32_t>(tensor_data_.size()),
      storage_data_offset,
      fbb.CreateVector(obj_types_offset_),
      jit_source_offset,
      fbb.CreateVector(jit_constants_indexes),
      operator_version,
      mobile_ivalue_size);
  FinishModuleBuffer(fbb, mod);
  return fbb.Release();
}

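// Container converters. Every element is stored through the shared IValue
// table and referenced by index; lists and dicts additionally record their
// annotated type string.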
flatbuffers::Offset<mobile::serialization::Tuple> FlatbufferSerializer::
    tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) {
  const auto& elements = tuple.toTuple()->elements();
  std::vector<uint32_t> items =
      storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
  return CreateTupleDirect(fbb, &items);
}

flatbuffers::Offset<mobile::serialization::List> FlatbufferSerializer::listToFB(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& list) {
  const auto& elements = list.toList();
  std::vector<uint32_t> items =
      storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
  return CreateList(
      fbb,
      fbb.CreateVector(items),
      fbb.CreateSharedString(
          realType(list.type<c10::Type>())->annotation_str(print_type)));
}

flatbuffers::Offset<mobile::serialization::Dict> FlatbufferSerializer::dictToFB(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  const auto& dict = ivalue.toGenericDict();
  std::vector<uint32_t> keys;
  std::vector<uint32_t> values;
  keys.reserve(dict.size());
  values.reserve(dict.size());
  for (const auto& entry : dict) {
    auto key_index = storeIValueAndGetIndex(fbb, entry.key());
    keys.push_back(key_index);
    auto value_index = storeIValueAndGetIndex(fbb, entry.value());
    values.push_back(value_index);
  }

  return CreateDict(
      fbb,
      fbb.CreateVector(keys),
      fbb.CreateVector(values),
      fbb.CreateSharedString(
          realType(ivalue.type<c10::Type>())->annotation_str(print_type)));
}

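// Serializes a class type. The TypeType tag distinguishes classes whose
// __setstate__/__getstate__ exist as mobile functions, custom classes whose
// methods provide them, and plain classes, for which the attribute names
// are recorded instead.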
flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
    classTypeToFB(FlatBufferBuilder& fbb, const ClassTypePtr& class_ptr) {
  mobile::serialization::TypeType typetype =
      mobile::serialization::TypeType::UNSET;

  flatbuffers::Offset<
      flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
      names_offset = 0;
  c10::QualifiedName setstate_name(*class_ptr->name(), "__setstate__");
  c10::QualifiedName getstate_name(*class_ptr->name(), "__getstate__");
  const mobile::Function* setstate = mcu_->find_function(setstate_name);
  const mobile::Function* getstate = mcu_->find_function(getstate_name);
  if (setstate != nullptr && getstate != nullptr) {
    typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
  } else if (
      class_ptr->findMethod("__setstate__") &&
      class_ptr->findMethod("__getstate__")) {
    typetype = mobile::serialization::TypeType::CUSTOM_CLASS;
  } else {
    size_t num_attr = class_ptr->numAttributes();
    std::vector<flatbuffers::Offset<flatbuffers::String>> names;
    names.reserve(num_attr);
    for (size_t i = 0; i < num_attr; ++i) {
      names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
    }
    names_offset = fbb.CreateVector(names);
    typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD;
  }

  auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName());
  return CreateObjectType(fbb, name_offset, typetype, names_offset);
}

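// Memoized by qualified name: serializes the function on first use and
// returns the cached IValue-table index on subsequent calls.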
uint32_t FlatbufferSerializer::storeFunctionAndGetIndex(
    flatbuffers::FlatBufferBuilder& fbb,
    const std::string& qn,
    const mobile::Function& function) {
  auto iter = qn_to_serialized_values_.find(qn);
  if (iter != qn_to_serialized_values_.end()) {
    return iter->second;
  }

  auto offset = CreateIValue(
      fbb,
      mobile::serialization::IValueUnion::Function,
      functionToFB(fbb, qn, function).Union());

  uint32_t index = insertIValue(offset);
  qn_to_serialized_values_[qn] = index;
  return index;
}

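// Also memoized by qualified name, but the returned index refers to the
// object-type table (obj_types_offset_), not the IValue table.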
uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex(
    FlatBufferBuilder& fbb,
    const ClassTypePtr& class_ptr) {
  const auto& type_str = class_ptr->name()->qualifiedName();
  auto iter = qn_to_serialized_values_.find(type_str);
  if (iter != qn_to_serialized_values_.end()) {
    return iter->second;
  }

  auto offset = classTypeToFB(fbb, class_ptr);
  uint32_t res = obj_types_offset_.size();
  obj_types_offset_.push_back(offset);
  qn_to_serialized_values_[type_str] = res;
  return res;
}

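// Serializes an object either through its __getstate__ result (when the
// class defines __getstate__/__setstate__) or by storing each attribute
// slot directly.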
flatbuffers::Offset<mobile::serialization::Object> FlatbufferSerializer::
    objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  auto obj = ivalue.toObject();
  auto type = obj->type();
  // rename type?
  // check getstate

  // save state as ivalue
  flatbuffers::Offset<flatbuffers::Vector<uint32_t>> attrs = 0;
  uint32_t state_index = 0;
  uint32_t setstate_func_index = 0;
  const auto qn = type->name()->qualifiedName() + ".__setstate__";
  auto getstate = type->findMethod("__getstate__");
  auto setstate = type->findMethod("__setstate__");
  if (getstate && setstate) {
    auto state = (*getstate)({obj});
    state_index = storeIValueAndGetIndex(fbb, state);
    auto func_index = qn_to_serialized_values_.find(qn);
    if (func_index != qn_to_serialized_values_.end()) {
      setstate_func_index = func_index->second;
    }
  } else {
    size_t num_attr = type->numAttributes();
    std::vector<uint32_t> tuple_index;
    tuple_index.reserve(num_attr);
    for (size_t i = 0; i < num_attr; ++i) {
      tuple_index.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i)));
    }
    attrs = fbb.CreateVector(tuple_index);
  }

  uint32_t type_index = storeClassTypeAndGetIndex(fbb, type);
  return CreateObject(fbb, type_index, state_index, attrs, setstate_func_index);
}

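// Serializes tensor metadata (dtype, sizes, strides, offset, and any
// quantization schema). The underlying storage is deduplicated through
// memoized_storage_map_ and written out later in serializeModule.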
flatbuffers::Offset<mobile::serialization::TensorMetadata> FlatbufferSerializer::
    tensorToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  auto& tensor = ivalue.toTensor();
  bool quantized = tensor.is_quantized();
  const at::Storage& storage = tensor.storage();

  flatbuffers::Offset<mobile::serialization::QuantizedSchema> qschema_offset =
      0;
  if (quantized) {
    double scale = 0;
    int64_t zero_point = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> scales = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> zero_points = 0;
    int64_t axis = 0;

    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        scale = tensor.q_scale();
        zero_point = tensor.q_zero_point();
        break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        scales = tensorToFB(fbb, tensor.q_per_channel_scales());
        zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points());
        axis = tensor.q_per_channel_axis();
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }

    qschema_offset = mobile::serialization::CreateQuantizedSchema(
        fbb,
        static_cast<int8_t>(tensor.qscheme()),
        scale,
        static_cast<int32_t>(zero_point),
        scales,
        zero_points,
        static_cast<int32_t>(axis));
  }

  void* addr = storage.unsafeGetStorageImpl();
  uint32_t storage_index = 0;
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    storage_index = it->second;
  } else {
    storage_index = tensor_data_.size();
    memoized_storage_map_[addr] = storage_index;
    tensor_data_.push_back(tensor);
  }

  std::vector<int> sizes{tensor.sizes().begin(), tensor.sizes().end()};
  std::vector<int> strides{tensor.strides().begin(), tensor.strides().end()};

  return CreateTensorMetadataDirect(
      fbb,
      /* storage_location_index */ storage_index,
      /* scalar_type */ static_cast<int8_t>(tensor.scalar_type()),
      /* int32_t storage_offset */
      static_cast<int32_t>(tensor.storage_offset()),
      /* sizes */ &sizes,
      /* strides */ &strides,
      /* bool requires_grad */ tensor.requires_grad(),
      /* qschema */ qschema_offset);
}

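// Returns the IValue-table index for ivalue, reusing a cached entry when
// the value is hashable and equal to one that was already stored.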
uint32_t FlatbufferSerializer::storeIValueAndGetIndex(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  if (ivalue.isNone()) {
    return kNoneIndex;
  }

  try {
    auto iter = cached_ivalues_.find(ivalue);
    if (iter != cached_ivalues_.end()) {
      return iter->second;
    }
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) {
    // Thrown if the ivalue is not hashable or does not have a proper
    // operator==. We deliberately swallow everything here, since in either
    // case we simply skip the cache lookup.
  }

  auto offset = iValueToFB(fbb, ivalue);
  uint32_t index = insertIValue(offset);
  try {
    cached_ivalues_[ivalue] = index;
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) {
    // Thrown if the ivalue is not hashable or does not have a proper
    // operator==. We deliberately swallow everything here, since in either
    // case we simply skip caching the value.
  }

  return index;
}

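// Dispatches on the IValue's tag and builds the matching member of the
// IValueUnion. Unsupported tags are a hard error.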
flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
    iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  using mobile::serialization::IValueUnion;

  IValueUnion ivalue_type = IValueUnion::NONE;
  flatbuffers::Offset<void> offset = 0;

  if (ivalue.isTensor()) {
    ivalue_type = IValueUnion::TensorMetadata;
    offset = tensorToFB(fbb, ivalue).Union();
  } else if (ivalue.isTuple()) {
    ivalue_type = IValueUnion::Tuple;
    offset = tupleToFB(fbb, ivalue).Union();
  } else if (ivalue.isDouble()) {
    ivalue_type = IValueUnion::Double;
    offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble()))
                 .Union();
  } else if (ivalue.isComplexDouble()) {
    auto comp = ivalue.toComplexDouble();
    ivalue_type = IValueUnion::ComplexDouble;
    offset = fbb.CreateStruct(mobile::serialization::ComplexDouble(
                                  comp.real(), comp.imag()))
                 .Union();
  } else if (ivalue.isInt()) {
    ivalue_type = IValueUnion::Int;
    offset =
        fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union();
  } else if (ivalue.isBool()) {
    ivalue_type = IValueUnion::Bool;
    offset =
        fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union();
  } else if (ivalue.isString()) {
    ivalue_type = IValueUnion::String;
    offset = mobile::serialization::CreateString(
                 fbb, fbb.CreateSharedString(ivalue.toStringRef()))
                 .Union();
  } else if (ivalue.isGenericDict()) {
    ivalue_type = IValueUnion::Dict;
    offset = dictToFB(fbb, ivalue).Union();
  } else if (ivalue.isNone()) {
    ivalue_type = IValueUnion::NONE;
    offset = 0;
  } else if (ivalue.isIntList()) {
    ivalue_type = IValueUnion::IntList;
    offset = mobile::serialization::CreateIntList(
                 fbb, fbb.CreateVector(ivalue.toIntVector()))
                 .Union();
  } else if (ivalue.isDoubleList()) {
    ivalue_type = IValueUnion::DoubleList;
    offset = mobile::serialization::CreateDoubleList(
                 fbb, fbb.CreateVector(ivalue.toDoubleVector()))
                 .Union();
  } else if (ivalue.isBoolList()) {
    ivalue_type = IValueUnion::BoolList;
    auto boollist = ivalue.toBoolList();
    std::vector<uint8_t> bool_vec(boollist.begin(), boollist.end());
    offset =
        mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union();
  } else if (ivalue.isList()) {
    ivalue_type = IValueUnion::List;
    offset = listToFB(fbb, ivalue).Union();
  } else if (ivalue.isObject()) {
    ivalue_type = IValueUnion::Object;
    offset = objectToFB(fbb, ivalue).Union();
  } else if (ivalue.isDevice()) {
    ivalue_type = IValueUnion::Device;
    offset = mobile::serialization::CreateDevice(
                 fbb, fbb.CreateSharedString(ivalue.toDevice().str()))
                 .Union();
  } else if (ivalue.isEnum()) {
    const auto& enum_holder = ivalue.toEnumHolder();
    const auto& qualified_class_name =
        enum_holder->type()->qualifiedClassName();
    uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value());
    ivalue_type = IValueUnion::EnumValue;
    offset = mobile::serialization::CreateEnumValue(
                 fbb,
                 fbb.CreateSharedString(qualified_class_name.qualifiedName()),
                 ival_pos)
                 .Union();
  } else {
    TORCH_CHECK(
        false, "Invalid IValue type for serialization: ", ivalue.tagKind());
  }
  return CreateIValue(fbb, ivalue_type, offset);
}

} // namespace

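// Serializes `module` and writes the resulting flatbuffer to `filename`.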
void save_mobile_module(
    const mobile::Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files,
    const ExtraFilesMap& jit_sources,
    const std::vector<IValue>& jit_constants) {
  auto buffer = save_mobile_module_to_bytes(
      module, extra_files, jit_sources, jit_constants);
  std::fstream ofile(filename, std::ios::binary | std::ios::out);
  ofile.write(
      reinterpret_cast<char*>(buffer->data()),
      static_cast<std::streamsize>(buffer->size()));
  ofile.close();
}

/// Deletes a DetachedBuffer, along with the internal
/// flatbuffers::DetachedBuffer if present. Used as a custom deleter for
/// std::unique_ptr; see UniqueDetachedBuffer and make_unique_detached_buffer.
void DetachedBuffer::destroy(DetachedBuffer* buf) {
  // May be null.
  delete static_cast<flatbuffers::DetachedBuffer*>(buf->data_owner_);
  delete buf;
}

/// Provides access to DetachedBuffer::destroy().
struct DetachedBufferFriend {
  /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer.
  static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer(
      DetachedBuffer* buf) {
    return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy);
  }
};

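// Serializes `module` into an owned byte buffer. The
// flatbuffers::DetachedBuffer is heap-allocated and owned by the returned
// DetachedBuffer wrapper so that no flatbuffers types leak into the public
// header.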
DetachedBuffer::UniqueDetachedBuffer save_mobile_module_to_bytes(
    const mobile::Module& module,
    const ExtraFilesMap& extra_files,
    const ExtraFilesMap& jit_sources,
    const std::vector<IValue>& jit_constants) {
  FlatbufferSerializer fb_serializer;
  flatbuffers::DetachedBuffer buf = fb_serializer.serializeModule(
      module,
      /*include_tensor_data_in_flatbuffer=*/true,
      extra_files,
      jit_sources,
      jit_constants);
  flatbuffers::DetachedBuffer* buf_ptr =
      new flatbuffers::DetachedBuffer(std::move(buf));
  DetachedBuffer* ret =
      new DetachedBuffer(buf_ptr->data(), buf_ptr->size(), buf_ptr);
  return DetachedBufferFriend::make_unique_detached_buffer(ret);
}

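// Serializes `module` and hands the resulting bytes to `writer_func` in a
// single call.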
void save_mobile_module_to_func(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func) {
  auto buffer = save_mobile_module_to_bytes(module);
  writer_func(buffer->data(), buffer->size());
}

bool register_flatbuffer_serializer() {
  return true;
}

} // namespace torch::jit