mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor saving jit::Module to mobile .pt in 2 steps: (#66494)
Summary:
1. Convert Function -> mobile::Function.
2. Serialize the resulting mobile::Function.
This also opens up the opportunity to create a mobile::Module without saving and reloading.
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66494
Reviewed By: zhxchen17
Differential Revision: D32293022
Pulled By: qihqi
fbshipit-source-id: 29b43d47ff86071d5e2f9d6ca4dba4445711ce3d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e2aeb4a7af
commit
4eb772fde6
@ -38,6 +38,18 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Builds a CompilationOptions value from the process-wide export settings:
// the two BytecodeEmitMode flags, the mobile interface-call export switch,
// and the bytecode version constant kProducedBytecodeVersion.
CompilationOptions getOptionsFromGlobal() {
  CompilationOptions opts;

  // Flags read from BytecodeEmitMode.
  opts.enable_default_args_before_out_args =
      BytecodeEmitMode::is_default_args_before_out_args_enabled();
  opts.enable_default_value_for_unspecified_arg =
      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();

  // Whether interface calls are included follows the global export switch.
  opts.incl_interface_call = getMobileInterfaceCallExport();

  // Model version is taken from the serializer's produced-bytecode constant.
  opts.model_version = caffe2::serialize::kProducedBytecodeVersion;

  return opts;
}
|
||||
|
||||
// Convenience wrapper: forwards the given list of IValues to
// c10::ivalue::Tuple::create and returns the result as an IValue.
IValue to_tuple(std::initializer_list<IValue> ivalues) {
|
||||
return c10::ivalue::Tuple::create(ivalues);
|
||||
}
|
||||
@ -63,138 +75,49 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() {
|
||||
}
|
||||
|
||||
std::pair<IValue, IValue> getFunctionTuple(
|
||||
const Module& module,
|
||||
const Function& func,
|
||||
std::unique_ptr<Graph> optimizedGraph,
|
||||
const CompilationUnit& compilation_unit,
|
||||
const mobile::Function& func,
|
||||
BackendDebugInfoRecorder& debug_info_recorder,
|
||||
const std::string& qn,
|
||||
TypeNameUniquer& type_name_uniquer_) {
|
||||
TORCH_INTERNAL_ASSERT(optimizedGraph);
|
||||
std::shared_ptr<MobileCode> code;
|
||||
code = std::make_shared<MobileCode>(
|
||||
std::move(optimizedGraph), func.name(), BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */, BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */);
|
||||
auto instructions_copy = code->instructions();
|
||||
|
||||
// operator names
|
||||
std::vector<c10::OperatorName> opnames;
|
||||
std::vector<std::string> method_names;
|
||||
std::vector<int64_t> op_debug_handles;
|
||||
int next_new_op_index = 0;
|
||||
for (size_t i = 0; i < instructions_copy.size(); ++i) {
|
||||
Instruction ins = instructions_copy[i];
|
||||
if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
|
||||
// Found a new op (assumes new operators ordered by ascending ins.X)
|
||||
auto node = code->instructions_source()[i];
|
||||
opnames.emplace_back(node->schema().operator_name());
|
||||
next_new_op_index++;
|
||||
}
|
||||
// CALL nodes at this point represent built-in (i.e. non-Graph)
|
||||
// functions that were not inlined. Here we convert the CALL
|
||||
// instructions for these functions into INTERFACE_CALL instructions
|
||||
// s.t. at runtime, we will look up the Function* on the Type of the
|
||||
// 0th argument in the stack and call that directly.
|
||||
if (ins.op == CALL) {
|
||||
auto node = code->instructions_source()[i];
|
||||
if (node->kind() == prim::CallMethod) {
|
||||
// NB: replacing instruction
|
||||
auto method_name_idx =
|
||||
code->constant_table().size() + method_names.size();
|
||||
method_names.emplace_back(node->s(attr::name));
|
||||
Instruction new_instr{
|
||||
INTERFACE_CALL,
|
||||
static_cast<int32_t>(method_name_idx),
|
||||
static_cast<uint16_t>(node->inputs().size())};
|
||||
instructions_copy[i] = new_instr;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "Unsupported node kind on CALL opcode for mobile");
|
||||
}
|
||||
} else if (ins.op == RET) {
|
||||
auto node = code->instructions_source()[i];
|
||||
for (const auto& input : node->inputs()) {
|
||||
const auto& input_type = input->type();
|
||||
if (input_type->kind() == TypeKind::ListType ||
|
||||
input_type->kind() == TypeKind::DictType) {
|
||||
for (const TypePtr& element_type : input_type->containedTypes()) {
|
||||
TORCH_CHECK(
|
||||
element_type->kind() != TypeKind::ClassType,
|
||||
"Returining a list or dictionary with pytorch class type ",
|
||||
"is not supported in mobile module "
|
||||
"(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
|
||||
"Workaround: instead of using pytorch class as their element type, ",
|
||||
"use a combination of list, dictionary, and single types.");
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
isOpSupportedInMobile(ins.op),
|
||||
toString(ins.op),
|
||||
" is not supported in mobile module.");
|
||||
}
|
||||
auto node = code->instructions_source()[i];
|
||||
int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
|
||||
// Note 1-to-1 correspondence between instructions and debug handles
|
||||
op_debug_handles.emplace_back(debug_handle);
|
||||
}
|
||||
const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
|
||||
|
||||
// instructions
|
||||
std::vector<IValue> instructions;
|
||||
instructions.reserve(instructions_copy.size());
|
||||
for (Instruction ins : instructions_copy) {
|
||||
instructions.reserve(mobile_code_ptr->instructions_.size());
|
||||
for (Instruction ins : mobile_code_ptr->instructions_) {
|
||||
instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
|
||||
}
|
||||
|
||||
// operators
|
||||
std::vector<IValue> operators;
|
||||
auto op_to_specified_args = code->op_to_num_specified_args();
|
||||
operators.reserve(opnames.size());
|
||||
for (const auto& opname : opnames) {
|
||||
auto unique_name = c10::toString(opname);
|
||||
// For operator with vararg, adding default arguments would be confusing and
|
||||
// is not allowed. For an operator with num_args = -1, it means the number
|
||||
// of arguments is not available for this operator, we don't do any backward
|
||||
// compatibility adaptation at runtime.
|
||||
int num_args = -1;
|
||||
auto it = op_to_specified_args.find(unique_name);
|
||||
if (it != op_to_specified_args.end()) {
|
||||
num_args = it->second;
|
||||
}
|
||||
operators.reserve(mobile_code_ptr->op_names_.size());
|
||||
for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
|
||||
const auto& opname = mobile_code_ptr->op_names_[i];
|
||||
const int size = mobile_code_ptr->operator_input_sizes_[i];
|
||||
if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
|
||||
operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
|
||||
} else {
|
||||
operators.emplace_back(
|
||||
to_tuple({opname.name, opname.overload_name, num_args}));
|
||||
to_tuple({opname.name, opname.overload_name, size}));
|
||||
}
|
||||
}
|
||||
|
||||
// constants
|
||||
//
|
||||
// Make a copy of the constants and append the method names
|
||||
// that we emitted for the converted INTERFACE_CALL nodes above.
|
||||
auto constants = code->constant_table();
|
||||
for (auto& method_name : method_names) {
|
||||
constants.emplace_back(std::move(method_name));
|
||||
}
|
||||
|
||||
// types
|
||||
std::vector<IValue> types;
|
||||
types.reserve(code->type_table().size());
|
||||
types.reserve(mobile_code_ptr->types_.size());
|
||||
static const std::string torch_prefix("__torch__");
|
||||
static const std::string class_prefix("__torch__.torch.classes");
|
||||
std::shared_ptr<torch::jit::CompilationUnit> cu =
|
||||
module._ivalue()->compilation_unit();
|
||||
|
||||
for (const TypePtr& t : code->type_table()) {
|
||||
for (const TypePtr& t : mobile_code_ptr->types_) {
|
||||
std::string type_str = t->annotation_str();
|
||||
if (t->kind() == TypeKind::TupleType) {
|
||||
TORCH_CHECK(
|
||||
cu->get_named_tuple(t->str()),
|
||||
compilation_unit.get_named_tuple(t->str()),
|
||||
"Can't find definition for the qualified name: ",
|
||||
t->str(),
|
||||
"(TypeKind::TupleType) in compilation unit.",
|
||||
"Please report a bug to PyTorch.");
|
||||
auto named_tuple_type = cu->get_named_tuple(t->str());
|
||||
auto named_tuple_type = compilation_unit.get_named_tuple(t->str());
|
||||
if (named_tuple_type != nullptr) {
|
||||
std::string named_tuple_str = t->str();
|
||||
named_tuple_str.append("[NamedTuple, [");
|
||||
@ -254,12 +177,12 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
|
||||
// since the register location is embedded into the bytecode, pass the
|
||||
// register size
|
||||
auto register_size = static_cast<int>(code->register_size());
|
||||
auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
|
||||
|
||||
auto codeTable = Table(
|
||||
{{"instructions", to_tuple(instructions)},
|
||||
{"operators", to_tuple(operators)},
|
||||
{"constants", to_tuple(constants)},
|
||||
{"constants", to_tuple(mobile_code_ptr->constants_)},
|
||||
{"types", to_tuple(types)},
|
||||
{"register_size", register_size}});
|
||||
|
||||
@ -273,14 +196,7 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
}
|
||||
return c10::nullopt;
|
||||
};
|
||||
TORCH_CHECK(
|
||||
schema.overload_name().empty(), // @TODO: is this check correct?
|
||||
"Overloads are not supported in mobile modules.");
|
||||
TORCH_CHECK(
|
||||
!schema.is_vararg(), "Python *args are not supported in mobile modules.");
|
||||
TORCH_CHECK(
|
||||
!schema.is_varret(),
|
||||
"A variable number of return values is not supported in mobile modules.");
|
||||
|
||||
auto makeArgTuple = [&](const std::vector<Argument>& args) {
|
||||
std::vector<IValue> argTables;
|
||||
for (auto&& arg : args) {
|
||||
@ -315,6 +231,17 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
});
|
||||
|
||||
// function tuple
|
||||
std::string qn;
|
||||
if (func.name() == "__setstate__" || func.name() == "__getstate__") {
|
||||
auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
classtype, "class is null ", func.qualname().qualifiedName());
|
||||
qn = c10::QualifiedName(
|
||||
type_name_uniquer_.getUniqueName(classtype), func.name())
|
||||
.qualifiedName();
|
||||
} else {
|
||||
qn = func.qualname().qualifiedName();
|
||||
}
|
||||
auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
|
||||
|
||||
c10::optional<IValue> debug_info_vals;
|
||||
@ -324,41 +251,27 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
// debug handles generated by debug_handle_manager
|
||||
// will correspond to {source_range, inlinedCallStackPtr} which we will
|
||||
// serialize separately.
|
||||
IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
|
||||
IValue module_debug_tuple =
|
||||
c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
|
||||
auto function_debug_info =
|
||||
Table({{"function_debug_handles", module_debug_tuple}});
|
||||
debug_info_vals = to_tuple({qn, function_debug_info});
|
||||
return std::make_pair(bytecode_vals, debug_info_vals);
|
||||
}
|
||||
|
||||
void pushFunctionToIValues(
|
||||
BytecodeExportSet exportSet,
|
||||
void pushMobileFunctionsToIValues(
|
||||
const CompilationUnit& compilation_unit,
|
||||
const mobile::Module& module,
|
||||
std::vector<c10::IValue>& elements,
|
||||
std::vector<c10::IValue>& debugInfoElements,
|
||||
BackendDebugInfoRecorder& recorder,
|
||||
TypeNameUniquer& uniquer) {
|
||||
exportSet.visit(
|
||||
[&](const c10::QualifiedName& qn, ExportedFunction& exported) {
|
||||
auto tuple = getFunctionTuple(
|
||||
exported.mod,
|
||||
exported.function,
|
||||
std::move(exported.optimizedGraph),
|
||||
recorder,
|
||||
qn.qualifiedName(),
|
||||
uniquer);
|
||||
elements.push_back(std::move(tuple.first));
|
||||
debugInfoElements.push_back(std::move(tuple.second));
|
||||
});
|
||||
}
|
||||
|
||||
void pushFunctionToIValues(
|
||||
BytecodeExportSet exportSet,
|
||||
std::vector<c10::IValue>& elements,
|
||||
BackendDebugInfoRecorder& recorder,
|
||||
TypeNameUniquer& uniquer) {
|
||||
std::vector<c10::IValue> debugInfoElements;
|
||||
pushFunctionToIValues(
|
||||
std::move(exportSet), elements, debugInfoElements, recorder, uniquer);
|
||||
for (const auto& method : module.get_methods()) {
|
||||
auto tuple = getFunctionTuple(
|
||||
compilation_unit, method.function(), recorder, uniquer);
|
||||
elements.push_back(std::move(tuple.first));
|
||||
debugInfoElements.push_back(std::move(tuple.second));
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
|
||||
@ -402,61 +315,6 @@ std::vector<ModuleMethod> getModuleInterfaceExports(
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Registers `method` in `exportSet` under its export name.
//
// If the name is already present, nothing is re-added; the only action taken
// is a call to exportSet.update(name, toplevel) when `toplevel` is true.
// Otherwise a unique copy of the function's graph is inlined and stored in
// the set together with the owning module and the `toplevel` flag. When the
// mobile interface-call export switch is on, every ModuleMethod returned by
// getModuleInterfaceExports for the interface calls found in the inlined
// graph is exported recursively (with the default toplevel = false).
void exportFunction(
    BytecodeExportSet& exportSet,
    const ModuleMethod& method,
    bool toplevel = false) {
  const auto& export_name = method.exportName;
  const auto& function = method.function;

  const bool already_exported = exportSet.contains(export_name);
  if (already_exported) {
    if (toplevel) {
      exportSet.update(export_name, toplevel);
    }
    return;
  }

  // Work on a private copy so inlining does not touch the function's graph.
  auto inlined_graph = function.graph()->copyUnique();
  Inline(*inlined_graph);

  // Interface calls must be collected before the graph is moved into the set.
  auto interface_calls = getInterfaceCalls(*inlined_graph);
  exportSet.add(
      export_name,
      ExportedFunction{
          method.module, function, std::move(inlined_graph), toplevel});

  if (getMobileInterfaceCallExport()) {
    const auto interface_methods =
        getModuleInterfaceExports(method.module, interface_calls);
    for (const auto& interface_method : interface_methods) {
      exportFunction(exportSet, interface_method);
    }
  }
}
|
||||
|
||||
// Walks `ivalue` looking for objects whose type passes
// checkHasValidSetGetState and exports their `__setstate__` method.
//
// Non-object values are ignored. For an object whose type has valid
// set/get-state methods, `__setstate__` is exported under the name
// "<unique type name>.<method name>" unless that name is already in
// `exportSet`, and only if tryToGraphFunction yields a function for it.
// For any other object, each attribute slot is visited recursively; note
// that the recursive calls do not forward `toplevel` (it defaults to false).
void setstateTuple(
    BytecodeExportSet& exportSet,
    const Module& module,
    const IValue& ivalue,
    TypeNameUniquer& type_name_uniquer_,
    bool toplevel = false) {
  if (!ivalue.isObject()) {
    return;
  }

  auto object = ivalue.toObject();
  auto object_type = object->type();

  if (!checkHasValidSetGetState(object_type)) {
    // No usable __setstate__ here: descend into every attribute instead.
    const size_t attribute_count = object_type->numAttributes();
    for (size_t slot = 0; slot < attribute_count; ++slot) {
      setstateTuple(
          exportSet, module, object->getSlot(slot), type_name_uniquer_);
    }
    return;
  }

  Function& setstate_fn = object_type->getMethod("__setstate__");
  auto export_name =
      type_name_uniquer_.getUniqueName(object->type()).qualifiedName() + "." +
      setstate_fn.name();

  if (exportSet.contains(export_name)) {
    return;
  }

  if (auto graph_fn = tryToGraphFunction(setstate_fn)) {
    exportFunction(
        exportSet, ModuleMethod{module, *graph_fn, export_name}, toplevel);
  }
}
|
||||
|
||||
bool isLoweredModule(const Module& m) {
|
||||
c10::QualifiedName type_name;
|
||||
if (m.type()->name()) {
|
||||
@ -544,24 +402,6 @@ bool getMobileInterfaceCallExport() {
|
||||
return mobileInterfaceCallExport().load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
// Collects everything that has to be exported for `module` into a
// BytecodeExportSet: first each method returned by module.get_methods()
// (exported as top-level under its qualified name), then the
// `__setstate__` methods reachable from the module's ivalue.
BytecodeExportSet moduleMethodsTuple(
    const Module& module,
    TypeNameUniquer& type_name_uniquer_) {
  BytecodeExportSet exportSet;

  // Top-level methods of the module itself.
  auto module_methods = module.get_methods();
  for (const auto& module_method : module_methods) {
    const auto& graph_fn = toGraphFunction(module_method.function());
    exportFunction(
        exportSet,
        ModuleMethod{module, graph_fn, graph_fn.qualname()},
        /* toplevel */ true);
  }

  // __setstate__ of the module and of all of its components.
  setstateTuple(
      exportSet,
      module,
      module._ivalue(),
      type_name_uniquer_,
      /* toplevel */ true);

  return exportSet;
}
|
||||
|
||||
// Replaces the hook object returned by GetExtraFilesHook() with `hook`.
// The argument is taken by value and moved into place.
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
  ExportModuleExtraFilesHook& installed_hook = GetExtraFilesHook();
  installed_hook = std::move(hook);
}
|
||||
@ -774,9 +614,12 @@ void ScriptModuleSerializer::writeByteCode(
|
||||
// Always save debug handles
|
||||
debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));
|
||||
|
||||
BytecodeExportSet exportSet = moduleMethodsTuple(module, type_name_uniquer_);
|
||||
pushFunctionToIValues(
|
||||
std::move(exportSet),
|
||||
mobile::Module mobile_module =
|
||||
jitModuleToMobile(module, getOptionsFromGlobal());
|
||||
|
||||
pushMobileFunctionsToIValues(
|
||||
*module._ivalue()->compilation_unit(),
|
||||
mobile_module,
|
||||
elements,
|
||||
debug_info_elements,
|
||||
debug_info_recorder,
|
||||
@ -840,9 +683,9 @@ void ScriptModuleSerializer::writeByteCode(
|
||||
getBackendDebugInfoMap(module, backend_debug_info_map);
|
||||
// Now get the debug-handles-to-inlined-cs-ptr-map
|
||||
// And serialize that in a separate archive
|
||||
auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
|
||||
debug_handle_cs_ptr_map.insert(
|
||||
backend_debug_info_map.begin(), backend_debug_info_map.end());
|
||||
const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
|
||||
BackendDebugInfoMapType debug_handle_cs_ptr_map(
|
||||
debug_info.begin(), debug_info.end());
|
||||
CallStackDebugInfoPickler cs_debug_info_pickler;
|
||||
auto cs_data = cs_debug_info_pickler.pickle(
|
||||
debug_handle_cs_ptr_map, source_range_tags_);
|
||||
@ -962,31 +805,13 @@ void ExportModule(
|
||||
|
||||
namespace {
|
||||
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
|
||||
std::vector<c10::IValue> elements;
|
||||
BackendDebugInfoRecorder dummy;
|
||||
TypeNameUniquer dummy_uniquer = TypeNameUniquer();
|
||||
BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer);
|
||||
pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer);
|
||||
for (const auto& element : elements) {
|
||||
auto table = element.toTupleRef().elements()[1];
|
||||
auto row =
|
||||
table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
row->elements().at(0).toStringRef() == "operators",
|
||||
"Expected operators but found ",
|
||||
row->elements().at(0).toStringRef());
|
||||
const auto& ops_list = row->elements().at(1).toTupleRef().elements();
|
||||
for (const auto& op : ops_list) {
|
||||
const auto& op_item = op.toTupleRef().elements();
|
||||
TORCH_CHECK(
|
||||
op_item.size() >= 2,
|
||||
"There should be either two parts (name and overload name), ",
|
||||
"or three parts (name, overload name and number of specified args) ",
|
||||
"for an operator.");
|
||||
auto opname = op_item[0].toString()->string();
|
||||
auto overload = op_item[1].toString()->string();
|
||||
mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
|
||||
for (const auto& method : mobile_m.get_methods()) {
|
||||
for (const auto& op : method.function().get_code()->op_names_) {
|
||||
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
|
||||
opnames.emplace(overload.empty() ? opname : opname + "." + overload);
|
||||
opnames.emplace(
|
||||
op.overload_name.empty() ? op.name
|
||||
: op.name + "." + op.overload_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user