// pytorch/torch/csrc/jit/backends/backend_detail.cpp

#include <torch/csrc/jit/backends/backend_detail.h>
#include <ATen/code_template.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch::jit::detail {
namespace {
/*
 * This is the API through which a backend's preprocess function obtains debug
 * handles corresponding to the nodes of the graph for the lowered methods of
 * the module.
 * Implementation: given a graph, a debug handle is requested for each node via
 * debug_info_recorder. debug_info_recorder returns the next debug handle and
 * records the node together with its debug info, such as source range and
 * inlined callstack.
 *
 * Backend code for lowering a module (preprocess) calls
 * generate_debug_handles(graph), which returns the debug handles corresponding
 * to the Node* of the given graph.
 *
 * In to_backend, after lowering, stopRecording is called on
 * BackendModuleDebugInfoRecorder: it extracts the debug map, which gets
 * stored as part of the lowered module.
 * During serialization, specifically bytecode serialization, a check is made
 * to see whether the model being serialized has any lowered modules. If so,
 * the corresponding debug map is extracted and serialized.
 */
NodeToDebugHandle generate_debug_handles(
BackendDebugInfoRecorder& debug_info_recorder,
const std::shared_ptr<Graph>& graph) {
NodeToDebugHandle node_to_debug_handles;
std::stack<Block*> blocks_to_visit;
// TODO: Look into using DepthFirstGraphNodeIterator
// At the moment it takes a non-const graph, but maybe we can make it
// general enough to work with both const and non-const graphs.
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
DebugHandleType debug_handle = debug_info_recorder.getNextDebugHandle(n);
node_to_debug_handles.emplace(n, debug_handle);
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
return node_to_debug_handles;
}
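// Illustrative sketch (not part of the original file): a backend's preprocess
// function receives a BackendDebugHandleGenerator (built further below in
// codegen_backend_module) and would typically invoke it per lowered method,
// along the lines of:
//
//   auto graph = mod.get_method("forward").graph();
//   NodeToDebugHandle handles = generate_debug_handles(graph);
//   for (Node* n : graph->nodes()) {
//     DebugHandleType handle = handles.at(n);
//     // ... store `handle` next to the backend's compiled artifact for `n`
//   }
//
// The name `mod` and the method "forward" are placeholders; the exact usage is
// backend-specific.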
std::unordered_map<std::string, BackendPreprocessFunction>&
backendPreprocessFunctions() {
static std::unordered_map<std::string, BackendPreprocessFunction>
preprocess_functions;
return preprocess_functions;
}
} // namespace
bool hasBackendPreprocessFunction(const std::string& name) {
return backendPreprocessFunctions().count(name);
}
void registerBackendPreprocessFunction(
const std::string& name,
const BackendPreprocessFunction& preprocess) {
TORCH_CHECK(
!detail::hasBackendPreprocessFunction(name),
"Preprocessing function for backend ",
name,
" is already registered. Ensure that registration is only called once.");
detail::backendPreprocessFunctions()[name] = preprocess;
}
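// Illustrative sketch (assumption, not shown in this file): backends normally
// do not call registerBackendPreprocessFunction directly. Registration is
// typically done through the backend_preprocess_register helper (see
// backend_preprocess.h), roughly:
//
//   c10::IValue my_preprocess(
//       const Module& mod,
//       const c10::Dict<IValue, IValue>& method_compile_spec,
//       const BackendDebugHandleGenerator& generate_debug_handles) {
//     // Produce whatever __processed_module should hold for this backend.
//     return mod._ivalue();
//   }
//   static auto reg = backend_preprocess_register("my_backend", my_preprocess);
//
// `my_preprocess` and "my_backend" are placeholder names.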
BackendPreprocessFunction getBackendPreprocessFunction(
const std::string& name) {
TORCH_CHECK(
hasBackendPreprocessFunction(name),
"Preprocessing function for backend ",
name,
" is not registered.");
return backendPreprocessFunctions()[name];
}
Module codegen_backend_module(
const std::string& backend_name,
const Module& orig_module,
const c10::Dict<IValue, IValue>& method_compile_spec,
const c10::DictTypePtr& any_dict_ty) {
const c10::QualifiedName qual_backend_name(
{"__torch__", "torch", "classes", kBackendsNamespace, backend_name});
// TODO: Validate method_compile_spec.
// Clone orig_module so that the backend transformation is functional:
// the original module is left untouched.
auto cloned_module = orig_module.clone();
auto module_name = orig_module.type()->name()->qualifiedName();
// Generate LoweredModule.
Module loweredModule(
"torch.jit.LoweredModule." + backend_name + "." + module_name,
std::make_shared<CompilationUnit>(),
/*shouldMangle=*/true);
// Generate WrapperModule.
Module wrapper(
"torch.jit.LoweredWrapper." + backend_name + "." + module_name,
std::make_shared<CompilationUnit>(),
/*shouldMangle=*/true);
// 1. Initialize the debug info recorder.
// 2. Later, call debug_info_recorder.stopRecording() to gather the
//    recorded debug info and save it in __backend_debug_info.
BackendDebugInfoRecorder debug_info_recorder;
// Generate attributes.
// This is the preprocessed module.
// For backwards compatibility with backends that implement preprocessing in
// the backend interface rather than as a separate function, we just pass
// the cloned original Module.
BackendDebugHandleGenerator debug_handle_generator =
[&](const std::shared_ptr<Graph>& g) {
return generate_debug_handles(debug_info_recorder, g);
};
loweredModule.register_attribute(
"__processed_module",
AnyType::get(),
detail::getBackendPreprocessFunction(backend_name)(
cloned_module, method_compile_spec, debug_handle_generator),
/*is_param=*/false);
// This is for the method_compile_spec passed in to to_<backend> or
// loaded from an exported model.
loweredModule.register_attribute(
"__method_compile_spec",
any_dict_ty,
method_compile_spec,
/*is_param=*/false);
// This is a pointer to a backend instance that is used to access
// compile and execute functions.
auto cls = getCustomClass(qual_backend_name.qualifiedName());
TORCH_INTERNAL_ASSERT(cls);
c10::intrusive_ptr<torch::CustomClassHolder> backend;
loweredModule.register_attribute(
"__backend", cls, IValue::make_capsule(backend));
// This is the dict of opaque backend handles returned by
// backend.compile, keyed by method name.
loweredModule.register_attribute(
"__handles",
any_dict_ty,
c10::impl::GenericDict(
any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
/*is_param=*/false);
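// Note: __handles is registered empty here; it is populated in __setstate__
// below, which assigns the result of
// self.__backend.compile(__processed_module, __method_compile_spec).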
// Methods.
// This is a helper function for creating a new instance of the
// backend class.
static const auto create_backend_ct = at::jit::CodeTemplate(R"(
def __create_backend(self):
    self.__backend = $name()
)");
at::jit::TemplateEnv create_backend_te;
create_backend_te.s("name", qual_backend_name.qualifiedName());
loweredModule.define(
create_backend_ct.format(create_backend_te), loweredModuleResolver());
// Helper function to expose backend.is_available() to Module generation code.
// Assumes self.__backend exists (i.e. __create_backend() has already been
// invoked).
loweredModule.define(
R"(
def __is_available(self):
    return self.__backend.is_available()
)",
loweredModuleResolver());
// backend_debug_info_class is an instance of BackendDebugInfo that
// stores debug information.
// The purpose of this class is to make the debug information available
// at model saving time for serializing it outside of the lowered module,
// while still tying it to the module's lifetime (so it gets destroyed along
// with it).
// Although this information is not serialized as part of the lowered
// module, we still need to provide a valid instance of the
// BackendDebugInfo class when the lowered module is deserialized.
// Since the deserialized module does not need this information,
// we create a "dummy" instance with no extra code dependencies (to avoid
// overhead) when __create_backend_debug_info is invoked in __setstate__.
c10::intrusive_ptr<torch::CustomClassHolder> backend_debug_info_class;
const c10::QualifiedName backend_debug_info_class_name(
{"__torch__",
"torch",
"classes",
kBackendUtilsNamespace,
kBackendDebugInfoClass});
auto debug_info_cls =
getCustomClass(backend_debug_info_class_name.qualifiedName());
TORCH_CHECK(debug_info_cls, "BackendDebugInfo class must be available.");
loweredModule.register_attribute(
"__backend_debug_info",
OptionalType::create(debug_info_cls),
IValue::make_capsule(backend_debug_info_class));
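// __backend_debug_info is populated near the end of this function, after
// preprocessing has finished and debug_info_recorder.stopRecording() has
// produced the debug map.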
static const auto create_backend_debug_info_ct = at::jit::CodeTemplate(R"(
def __create_backend_debug_info(self):
    self.__backend_debug_info = $backend_debug_info()
)");
at::jit::TemplateEnv create_backend_debug_info_te;
create_backend_debug_info_te.s(
"backend_debug_info", backend_debug_info_class_name.qualifiedName());
loweredModule.define(
create_backend_debug_info_ct.format(create_backend_debug_info_te),
loweredModuleResolver());
// getstate and setstate are for serialization/deserialization of
// the LoweredModule.
// setstate is in charge of initializing self.__backend by invoking
// __create_backend().
loweredModule.define(
R"(
def __getstate__(self):
    # The third element indicates whether __setstate__ must create
    # the backend instance. It is hardcoded to True since the only case
    # in which it can be False is when __setstate__ is called directly at
    # module creation time (below in codegen_backend_module), because
    # __create_backend has already been called by then.
    return self.__method_compile_spec, self.__processed_module, True
)",
loweredModuleResolver());
loweredModule.define(
R"(
def __setstate__(self, state):
    self.__method_compile_spec = state[0]
    self.__processed_module = state[1]
    # state[2] indicates whether to create the backend instance.
    if state[2]:
        self.__create_backend()
        self.__create_backend_debug_info()
    if self.__backend.is_available():
        self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
    else:
        raise Exception("Backend is not available.")
)",
loweredModuleResolver());
// This loop generates one method on the LoweredModule for every key
// in method_compile_spec.
std::vector<std::string> wrapper_methods;
for (auto& e : method_compile_spec) {
std::string method_name = e.key().toStringRef();
static const auto method_ct = at::jit::CodeTemplate(R"(
def $method(self${,def_inputs}):
    typed_inputs: List[Any] = [${fwd_inputs,}]
    if self.__backend.is_available():
        $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
        ${refine,}
        return $ret
    else:
        raise Exception("Backend is not available.")
)");
static const auto wrapper_method_ct = at::jit::CodeTemplate(R"(
def $method(self${,def_inputs}):
    return self.__loweredModule__.$method(${fwd_inputs})
)");
at::jit::TemplateEnv method_te, wrapper_method_te;
method_te.s("method", method_name);
wrapper_method_te.s("method", method_name);
auto method = orig_module.get_method(method_name);
auto& function = method.function();
auto& schema = function.getSchema();
// Generate the inputs for the function signature (def_inputs) and
// for passing to backend.execute (fwd_inputs).
std::vector<std::string> def_inputs, fwd_inputs;
for (const auto& arg : schema.arguments()) {
auto name = arg.name();
// Skip self since it is always present in the signature and never
// forwarded explicitly.
if (name == "self") {
continue;
}
auto default_value = arg.default_value();
if (arg.kwarg_only()) {
// If this is a kwarg, it needs to be emitted as keyword=value
// in the definition and keyword=keyword in the call to
// self.__backend.execute.
TORCH_INTERNAL_ASSERT(default_value.has_value());
std::stringstream def_ss, fwd_ss;
// Annotate type of the arg
def_ss << name << ": " << arg.type()->annotation_str(nullptr) << "=";
fwd_ss << name << "=" << name;
default_value->repr(
def_ss, [](std::ostream&, const IValue&) -> bool { return false; });
def_inputs.emplace_back(def_ss.str());
fwd_inputs.emplace_back(fwd_ss.str());
} else {
// If this is not a kwarg, it should be emitted as-is in the
// signature and in the call to self.__backend.execute.
std::stringstream def_ss;
// Annotate type of the arg
def_ss << name << ": " << arg.type()->annotation_str(nullptr);
def_inputs.emplace_back(def_ss.str());
fwd_inputs.emplace_back(name);
}
}
// Generate a comma-delimited list of identifiers to unpack
// outputs, as well as a list of isinstance checks to make sure
// the backend returned the types it was supposed to.
std::stringstream out_ss, type_check_ss;
std::vector<std::string> type_checks;
TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
auto out_ty = schema.returns().at(0).type();
out_ss << "_0";
type_check_ss << "assert isinstance(_0, ";
auto out_tuple_ty = out_ty->cast<TupleType>();
if (out_tuple_ty) {
auto tuple_elements = out_tuple_ty->elements();
type_check_ss << tuple_elements[0]->annotation_str() << ")";
type_checks.emplace_back(type_check_ss.str());
for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
type_check_ss.str(std::string());
type_check_ss.clear();
out_ss << ", _" << i;
type_check_ss << "assert isinstance(_" << i << ", "
<< tuple_elements[i]->annotation_str() << ")";
type_checks.emplace_back(type_check_ss.str());
}
} else {
type_check_ss << out_ty->annotation_str() << ")";
type_checks.emplace_back(type_check_ss.str());
}
method_te.v("def_inputs", def_inputs);
method_te.v("fwd_inputs", fwd_inputs);
method_te.v("refine", type_checks);
method_te.s("unpack", out_ss.str());
wrapper_method_te.v("def_inputs", def_inputs);
wrapper_method_te.v("fwd_inputs", fwd_inputs);
wrapper_methods.emplace_back(wrapper_method_ct.format(wrapper_method_te));
// If the output type is a single element tuple then add an extra comma
// to ensure the final output maintains this type.
if (out_tuple_ty && out_tuple_ty->elements().size() == 1) {
out_ss << ",";
}
method_te.s("ret", out_ss.str());
loweredModule.define(method_ct.format(method_te), loweredModuleResolver());
}
// If the backend is available, call __setstate__ to ensure that the returned
// Module is ready to run.
// Otherwise, emit a warning indicating that the resulting Module is not
// ready for execution until it is loaded onto a device where the backend is
// available.
loweredModule.run_method("__create_backend");
if (loweredModule.run_method("__is_available").toBool()) {
auto state = at::ivalue::Tuple::create(
method_compile_spec,
loweredModule.attr("__processed_module"),
/*create_backend*/ false);
loweredModule.run_method("__setstate__", state);
} else {
TORCH_WARN(
"Backend [",
backend_name,
"] is not available. Execution of this Module is still possible by "
"saving and loading on a device where the backend is available.");
}
// Stop debug info recording and retrieve the debug_info_map.
auto debug_info_map = debug_info_recorder.stopRecording();
loweredModule.run_method("__create_backend_debug_info");
auto backend_debug_info = loweredModule.attr("__backend_debug_info")
.toCustomClass<PyTorchBackendDebugInfo>();
backend_debug_info->setDebugInfoMap(std::move(debug_info_map));
// Wrap the lowered module to hide its custom serialization logic.
wrapper.register_module("__loweredModule__", loweredModule);
for (auto& method : wrapper_methods) {
wrapper.define(method);
}
return wrapper;
}
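// Illustrative usage (assumption, outside the scope of this file):
// codegen_backend_module backs the to_backend flow, which is typically reached
// from Python via something like
//
//   lowered = torch._C._jit_to_backend(
//       "my_backend", scripted_module, method_compile_spec)
//
// where "my_backend", scripted_module, and method_compile_spec are
// placeholders; the binding itself is registered in backend_init.cpp.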
} // namespace torch::jit::detail