Refactor saving jit::Module to mobile .pt in two steps (#66494)

Summary:
1. Convert Function -> mobile::Function.
2. Serialize the mobile::Function.

This also opens the opportunity to create a mobile::Module directly in memory, without a save/reload round trip.
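
In code, the new save path boils down to the following (a minimal sketch assembled from names introduced in this diff; `jit_module` and the serializer-state variables are assumed to be in scope as they are in ScriptModuleSerializer::writeByteCode):

    // Step 1: lower the JIT module and its Functions to an in-memory
    // mobile::Module holding mobile::Functions, using the globally
    // configured bytecode-emit options.
    mobile::Module mobile_module =
        jitModuleToMobile(jit_module, getOptionsFromGlobal());

    // Step 2: serialize each mobile::Function into the bytecode tables.
    pushMobileFunctionsToIValues(
        *jit_module._ivalue()->compilation_unit(),
        mobile_module,
        elements,
        debug_info_elements,
        debug_info_recorder,
        type_name_uniquer_);

Because step 1 already produces a usable mobile::Module, a caller could in principle run it directly instead of saving and reloading.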

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66494

Reviewed By: zhxchen17

Differential Revision: D32293022

Pulled By: qihqi

fbshipit-source-id: 29b43d47ff86071d5e2f9d6ca4dba4445711ce3d
Author: Han Qi
Date: 2021-11-17 11:57:56 -08:00
Committed by: Facebook GitHub Bot
Parent: e2aeb4a7af
Commit: 4eb772fde6

13 changed files with 1449 additions and 378 deletions


@@ -38,6 +38,18 @@
 namespace torch {
 namespace jit {
+CompilationOptions getOptionsFromGlobal() {
+  CompilationOptions compilation_options;
+  compilation_options.enable_default_args_before_out_args =
+      BytecodeEmitMode::is_default_args_before_out_args_enabled();
+  compilation_options.enable_default_value_for_unspecified_arg =
+      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
+  compilation_options.incl_interface_call = getMobileInterfaceCallExport();
+  compilation_options.model_version =
+      caffe2::serialize::kProducedBytecodeVersion;
+  return compilation_options;
+}
 IValue to_tuple(std::initializer_list<IValue> ivalues) {
   return c10::ivalue::Tuple::create(ivalues);
 }
@@ -63,138 +75,49 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() {
 }
 std::pair<IValue, IValue> getFunctionTuple(
-    const Module& module,
-    const Function& func,
-    std::unique_ptr<Graph> optimizedGraph,
+    const CompilationUnit& compilation_unit,
+    const mobile::Function& func,
     BackendDebugInfoRecorder& debug_info_recorder,
-    const std::string& qn,
     TypeNameUniquer& type_name_uniquer_) {
-  TORCH_INTERNAL_ASSERT(optimizedGraph);
-  std::shared_ptr<MobileCode> code;
-  code = std::make_shared<MobileCode>(
-      std::move(optimizedGraph),
-      func.name(),
-      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */,
-      BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */);
-  auto instructions_copy = code->instructions();
-  // operator names
-  std::vector<c10::OperatorName> opnames;
-  std::vector<std::string> method_names;
-  std::vector<int64_t> op_debug_handles;
-  int next_new_op_index = 0;
-  for (size_t i = 0; i < instructions_copy.size(); ++i) {
-    Instruction ins = instructions_copy[i];
-    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
-      // Found a new op (assumes new operators ordered by ascending ins.X)
-      auto node = code->instructions_source()[i];
-      opnames.emplace_back(node->schema().operator_name());
-      next_new_op_index++;
-    }
-    // CALL nodes at this point represent built-in (i.e. non-Graph)
-    // functions that were not inlined. Here we convert the CALL
-    // instructions for these functions into INTERFACE_CALL instructions
-    // s.t. at runtime, we will look up the Function* on the Type of the
-    // 0th argument in the stack and call that directly.
-    if (ins.op == CALL) {
-      auto node = code->instructions_source()[i];
-      if (node->kind() == prim::CallMethod) {
-        // NB: replacing instruction
-        auto method_name_idx =
-            code->constant_table().size() + method_names.size();
-        method_names.emplace_back(node->s(attr::name));
-        Instruction new_instr{
-            INTERFACE_CALL,
-            static_cast<int32_t>(method_name_idx),
-            static_cast<uint16_t>(node->inputs().size())};
-        instructions_copy[i] = new_instr;
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            false, "Unsupported node kind on CALL opcode for mobile");
-      }
-    } else if (ins.op == RET) {
-      auto node = code->instructions_source()[i];
-      for (const auto& input : node->inputs()) {
-        const auto& input_type = input->type();
-        if (input_type->kind() == TypeKind::ListType ||
-            input_type->kind() == TypeKind::DictType) {
-          for (const TypePtr& element_type : input_type->containedTypes()) {
-            TORCH_CHECK(
-                element_type->kind() != TypeKind::ClassType,
-                "Returning a list or dictionary with pytorch class type ",
-                "is not supported in mobile module "
-                "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
-                "Workaround: instead of using pytorch class as their element type, ",
-                "use a combination of list, dictionary, and single types.");
-          }
-        }
-      }
-    } else {
-      TORCH_CHECK(
-          isOpSupportedInMobile(ins.op),
-          toString(ins.op),
-          " is not supported in mobile module.");
-    }
-    auto node = code->instructions_source()[i];
-    int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
-    // Note 1-to-1 correspondence between instructions and debug handles
-    op_debug_handles.emplace_back(debug_handle);
-  }
+  const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
   // instructions
   std::vector<IValue> instructions;
-  instructions.reserve(instructions_copy.size());
-  for (Instruction ins : instructions_copy) {
+  instructions.reserve(mobile_code_ptr->instructions_.size());
+  for (Instruction ins : mobile_code_ptr->instructions_) {
     instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
   }
   // operators
   std::vector<IValue> operators;
-  auto op_to_specified_args = code->op_to_num_specified_args();
-  operators.reserve(opnames.size());
-  for (const auto& opname : opnames) {
-    auto unique_name = c10::toString(opname);
-    // For operator with vararg, adding default arguments would be confusing and
-    // is not allowed. For an operator with num_args = -1, it means the number
-    // of arguments is not available for this operator, we don't do any backward
-    // compatibility adaptation at runtime.
-    int num_args = -1;
-    auto it = op_to_specified_args.find(unique_name);
-    if (it != op_to_specified_args.end()) {
-      num_args = it->second;
-    }
+  operators.reserve(mobile_code_ptr->op_names_.size());
+  for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
+    const auto& opname = mobile_code_ptr->op_names_[i];
+    const int size = mobile_code_ptr->operator_input_sizes_[i];
     if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
       operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
     } else {
       operators.emplace_back(
-          to_tuple({opname.name, opname.overload_name, num_args}));
+          to_tuple({opname.name, opname.overload_name, size}));
     }
   }
   // constants
-  //
-  // Make a copy of the constants and append the method names
-  // that we emitted for the converted INTERFACE_CALL nodes above.
-  auto constants = code->constant_table();
-  for (auto& method_name : method_names) {
-    constants.emplace_back(std::move(method_name));
-  }
   // types
   std::vector<IValue> types;
-  types.reserve(code->type_table().size());
+  types.reserve(mobile_code_ptr->types_.size());
   static const std::string torch_prefix("__torch__");
   static const std::string class_prefix("__torch__.torch.classes");
-  std::shared_ptr<torch::jit::CompilationUnit> cu =
-      module._ivalue()->compilation_unit();
-  for (const TypePtr& t : code->type_table()) {
+  for (const TypePtr& t : mobile_code_ptr->types_) {
     std::string type_str = t->annotation_str();
     if (t->kind() == TypeKind::TupleType) {
       TORCH_CHECK(
-          cu->get_named_tuple(t->str()),
+          compilation_unit.get_named_tuple(t->str()),
          "Can't find definition for the qualified name: ",
          t->str(),
          "(TypeKind::TupleType) in compilation unit.",
          "Please report a bug to PyTorch.");
-      auto named_tuple_type = cu->get_named_tuple(t->str());
+      auto named_tuple_type = compilation_unit.get_named_tuple(t->str());
       if (named_tuple_type != nullptr) {
         std::string named_tuple_str = t->str();
         named_tuple_str.append("[NamedTuple, [");
@@ -254,12 +177,12 @@ std::pair<IValue, IValue> getFunctionTuple(
   // since the register location is embedded into the bytecode, pass the
   // register size
-  auto register_size = static_cast<int>(code->register_size());
+  auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
   auto codeTable = Table(
       {{"instructions", to_tuple(instructions)},
        {"operators", to_tuple(operators)},
-       {"constants", to_tuple(constants)},
+       {"constants", to_tuple(mobile_code_ptr->constants_)},
        {"types", to_tuple(types)},
        {"register_size", register_size}});
@@ -273,14 +196,7 @@ std::pair<IValue, IValue> getFunctionTuple(
     }
     return c10::nullopt;
   };
-  TORCH_CHECK(
-      schema.overload_name().empty(), // @TODO: is this check correct?
-      "Overloads are not supported in mobile modules.");
-  TORCH_CHECK(
-      !schema.is_vararg(), "Python *args are not supported in mobile modules.");
-  TORCH_CHECK(
-      !schema.is_varret(),
-      "A variable number of return values is not supported in mobile modules.");
   auto makeArgTuple = [&](const std::vector<Argument>& args) {
     std::vector<IValue> argTables;
     for (auto&& arg : args) {
@@ -315,6 +231,17 @@ std::pair<IValue, IValue> getFunctionTuple(
       });
   // function tuple
+  std::string qn;
+  if (func.name() == "__setstate__" || func.name() == "__getstate__") {
+    auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
+    TORCH_INTERNAL_ASSERT(
+        classtype, "class is null ", func.qualname().qualifiedName());
+    qn = c10::QualifiedName(
+             type_name_uniquer_.getUniqueName(classtype), func.name())
+             .qualifiedName();
+  } else {
+    qn = func.qualname().qualifiedName();
+  }
   auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
   c10::optional<IValue> debug_info_vals;
@@ -324,41 +251,27 @@ std::pair<IValue, IValue> getFunctionTuple(
   // debug handles generated by debug_handle_manager
   // will correspond to {source_range, inlinedCallStackPtr} which we will
   // serialize separately.
-  IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
+  IValue module_debug_tuple =
+      c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
   auto function_debug_info =
       Table({{"function_debug_handles", module_debug_tuple}});
   debug_info_vals = to_tuple({qn, function_debug_info});
   return std::make_pair(bytecode_vals, debug_info_vals);
 }
-void pushFunctionToIValues(
-    BytecodeExportSet exportSet,
+void pushMobileFunctionsToIValues(
+    const CompilationUnit& compilation_unit,
+    const mobile::Module& module,
     std::vector<c10::IValue>& elements,
     std::vector<c10::IValue>& debugInfoElements,
     BackendDebugInfoRecorder& recorder,
     TypeNameUniquer& uniquer) {
-  exportSet.visit(
-      [&](const c10::QualifiedName& qn, ExportedFunction& exported) {
-        auto tuple = getFunctionTuple(
-            exported.mod,
-            exported.function,
-            std::move(exported.optimizedGraph),
-            recorder,
-            qn.qualifiedName(),
-            uniquer);
-        elements.push_back(std::move(tuple.first));
-        debugInfoElements.push_back(std::move(tuple.second));
-      });
-}
-void pushFunctionToIValues(
-    BytecodeExportSet exportSet,
-    std::vector<c10::IValue>& elements,
-    BackendDebugInfoRecorder& recorder,
-    TypeNameUniquer& uniquer) {
-  std::vector<c10::IValue> debugInfoElements;
-  pushFunctionToIValues(
-      std::move(exportSet), elements, debugInfoElements, recorder, uniquer);
+  for (const auto& method : module.get_methods()) {
+    auto tuple = getFunctionTuple(
+        compilation_unit, method.function(), recorder, uniquer);
+    elements.push_back(std::move(tuple.first));
+    debugInfoElements.push_back(std::move(tuple.second));
+  }
 }
 std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
@@ -402,61 +315,6 @@ std::vector<ModuleMethod> getModuleInterfaceExports(
   return ret;
 }
-void exportFunction(
-    BytecodeExportSet& exportSet,
-    const ModuleMethod& method,
-    bool toplevel = false) {
-  const auto& func = method.function;
-  const auto& qn = method.exportName;
-  if (exportSet.contains(qn)) {
-    if (toplevel) {
-      exportSet.update(qn, toplevel);
-    }
-    return;
-  }
-  auto graph = func.graph()->copyUnique();
-  Inline(*graph);
-  auto interfaceCalls = getInterfaceCalls(*graph);
-  exportSet.add(
-      qn, ExportedFunction{method.module, func, std::move(graph), toplevel});
-  if (!getMobileInterfaceCallExport()) {
-    return;
-  }
-  auto interfaces = getModuleInterfaceExports(method.module, interfaceCalls);
-  for (const auto& interface : interfaces) {
-    exportFunction(exportSet, interface);
-  }
-}
-void setstateTuple(
-    BytecodeExportSet& exportSet,
-    const Module& module,
-    const IValue& ivalue,
-    TypeNameUniquer& type_name_uniquer_,
-    bool toplevel = false) {
-  if (!ivalue.isObject())
-    return;
-  auto obj = ivalue.toObject();
-  auto type = obj->type();
-  if (checkHasValidSetGetState(type)) {
-    Function& setstate = type->getMethod("__setstate__");
-    auto qn = type_name_uniquer_.getUniqueName(obj->type()).qualifiedName() +
-        "." + setstate.name();
-    if (exportSet.contains(qn)) {
-      return;
-    }
-    if (auto f = tryToGraphFunction(setstate)) {
-      exportFunction(exportSet, ModuleMethod{module, *f, qn}, toplevel);
-    }
-  } else {
-    for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
-      setstateTuple(exportSet, module, obj->getSlot(i), type_name_uniquer_);
-    }
-  }
-}
 bool isLoweredModule(const Module& m) {
   c10::QualifiedName type_name;
   if (m.type()->name()) {
@@ -544,24 +402,6 @@ bool getMobileInterfaceCallExport() {
   return mobileInterfaceCallExport().load(std::memory_order_relaxed);
 }
-BytecodeExportSet moduleMethodsTuple(
-    const Module& module,
-    TypeNameUniquer& type_name_uniquer_) {
-  BytecodeExportSet exportSet;
-  auto methods = module.get_methods();
-  // top level methods
-  for (const auto& method : methods) {
-    const auto& f = toGraphFunction(method.function());
-    exportFunction(
-        exportSet, ModuleMethod{module, f, f.qualname()}, /* toplevel */ true);
-  }
-  // __setstate__ of all components
-  setstateTuple(exportSet, module, module._ivalue(), type_name_uniquer_, true);
-  return exportSet;
-}
 void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
   GetExtraFilesHook() = std::move(hook);
 }
@@ -774,9 +614,12 @@ void ScriptModuleSerializer::writeByteCode(
   // Always save debug handles
   debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));
-  BytecodeExportSet exportSet = moduleMethodsTuple(module, type_name_uniquer_);
-  pushFunctionToIValues(
-      std::move(exportSet),
+  mobile::Module mobile_module =
+      jitModuleToMobile(module, getOptionsFromGlobal());
+  pushMobileFunctionsToIValues(
+      *module._ivalue()->compilation_unit(),
+      mobile_module,
       elements,
       debug_info_elements,
       debug_info_recorder,
@@ -840,9 +683,9 @@ void ScriptModuleSerializer::writeByteCode(
     getBackendDebugInfoMap(module, backend_debug_info_map);
     // Now get the debug-handles-to-inlined-cs-ptr-map
     // And serialize that in a separate archive
-    auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
-    debug_handle_cs_ptr_map.insert(
-        backend_debug_info_map.begin(), backend_debug_info_map.end());
+    const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
+    BackendDebugInfoMapType debug_handle_cs_ptr_map(
+        debug_info.begin(), debug_info.end());
     CallStackDebugInfoPickler cs_debug_info_pickler;
     auto cs_data = cs_debug_info_pickler.pickle(
         debug_handle_cs_ptr_map, source_range_tags_);
@@ -962,31 +805,13 @@ void ExportModule(
 namespace {
 void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
-  std::vector<c10::IValue> elements;
-  BackendDebugInfoRecorder dummy;
-  TypeNameUniquer dummy_uniquer = TypeNameUniquer();
-  BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer);
-  pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer);
-  for (const auto& element : elements) {
-    auto table = element.toTupleRef().elements()[1];
-    auto row =
-        table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
-    TORCH_INTERNAL_ASSERT(
-        row->elements().at(0).toStringRef() == "operators",
-        "Expected operators but found ",
-        row->elements().at(0).toStringRef());
-    const auto& ops_list = row->elements().at(1).toTupleRef().elements();
-    for (const auto& op : ops_list) {
-      const auto& op_item = op.toTupleRef().elements();
-      TORCH_CHECK(
-          op_item.size() >= 2,
-          "There should be either two parts (name and overload name), ",
-          "or three parts (name, overload name and number of specified args) ",
-          "for an operator.");
-      auto opname = op_item[0].toString()->string();
-      auto overload = op_item[1].toString()->string();
+  mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
+  for (const auto& method : mobile_m.get_methods()) {
+    for (const auto& op : method.function().get_code()->op_names_) {
       // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
-      opnames.emplace(overload.empty() ? opname : opname + "." + overload);
+      opnames.emplace(
+          op.overload_name.empty() ? op.name
+                                   : op.name + "." + op.overload_name);
     }
   }
 }
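
With this change, export_opnames no longer reconstructs operator names by parsing the serialized bytecode tuples; it lowers the module in memory and reads op_names_ directly. A minimal usage sketch (assuming the public torch::jit::export_opnames wrapper declared in torch/csrc/jit/serialization/export.h, which forwards to the helper above; the model path is a placeholder):

    #include <torch/script.h>
    #include <iostream>

    int main() {
      // Load any TorchScript module; "model.pt" is an illustrative path.
      torch::jit::Module m = torch::jit::load("model.pt");
      // Internally this now runs jitModuleToMobile(m, getOptionsFromGlobal())
      // and walks each mobile method's op_names_, with no save/load round trip.
      for (const auto& name : torch::jit::export_opnames(m)) {
        std::cout << name << "\n";
      }
      return 0;
    }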