Make more functions and variables static when they are not used outside their cpp files, and remove unused functions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150930
Approved by: https://github.com/Skylion007
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
192 lines
7.3 KiB
C++
#include <torch/csrc/jit/backends/backend_init.h>

#include <pybind11/iostream.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::jit {

// Get all types that are shared in the module hierarchy rooted at \p mod.
static std::unordered_set<TypePtr> getSharedModuleTypes(Module& mod) {
  // Maintain a set of all TypePtrs.
  std::unordered_set<TypePtr> types;
  // Maintain another set of TypePtrs that have been encountered more than once.
  std::unordered_set<TypePtr> duplicate_types;

  // Iterate over all modules in the hierarchy, including the root.
  for (auto module : mod.modules()) {
    auto module_type = module.type();
    if (types.count(module_type) > 0) {
      duplicate_types.insert(module_type);
    }

    types.insert(module_type);
  }

  return duplicate_types;
}

// Selectively lower \p mod to a backend. \p to_backend
// is called to lower modules. \p modules_to_lower contains
// qualified names of submodules of \p mod that should be lowered.
static void toBackendSelectiveImpl(
    Module& mod,
    const py::function& to_backend,
    const std::vector<std::string>& modules_to_lower,
    const std::unordered_set<TypePtr>& duplicate_types) {
  // This map will be used later to remap types in ancestor module graphs for
  // all lowered submodules.
  std::unordered_map<TypePtr, TypePtr> type_remap;

  // For each module that should be lowered:
  for (const auto& module_to_lower : modules_to_lower) {
    // Use QualifiedName to parse the qualified module names.
    c10::QualifiedName qual_module_name(module_to_lower);
    auto& atoms = qual_module_name.atoms();

    // Search through the module hierarchy using the atoms of
    // qual_module_name until current points to the module to
    // be lowered and parent points to its parent.
    Module current = mod;
    Module parent;

    for (size_t i = 0, e = atoms.size(); i < e; ++i) {
      IValue submodule = current.attr(atoms[i]);
      if (submodule.isModule()) {
        if (i == e - 1) {
          parent = current;
        }
        current = submodule.toModule();
      } else {
        std::stringstream err;
        err << "Attribute named " << atoms[i] << " is not a Module";
        throw std::runtime_error(err.str());
      }
    }

    // Check that the parent type is not shared and therefore can be edited.
    if (duplicate_types.count(parent.type()) > 0) {
      throw py::cast_error(c10::str(
          "Selective lowering is only supported for module hierarchies with unique types for selected modules; ",
          parent.type()->repr_str(),
          " is shared"));
    }

    // Call to_backend on the module that needs to be lowered. It needs to be
    // wrapped before doing so because _jit_to_backend accepts wrapped modules.
    // The result needs to be unwrapped in order to access its type below.
    auto lowered_submodule =
        py::cast<Module>(to_backend(py::module::import("torch.jit._recursive")
                                        .attr("wrap_cpp_module")(current))
                             .attr("_c"));

    // Adjust the parent's type so that the type of the submodule matches
    // the type of lowered_submodule.
    auto parent_type = parent.type();

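    // unsafeChangeAttributeType mutates the parent's ClassType in place; the
    // shared-type check above ensures this type is unique within the
    // hierarchy, so no other module is affected.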
    parent_type->unsafeChangeAttributeType(
        atoms.back(), lowered_submodule.type());
    parent.setattr(atoms.back(), lowered_submodule._ivalue());

    // Record the type mapping from old type -> lowered type.
    type_remap[current.type()] = lowered_submodule.type();
  }

  // Having lowered all of the modules that needed to be lowered, remap types in
  // all graphs in the hierarchy so that the graphs all use the new lowered
  // type.
  auto type_remap_fn = [&type_remap](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end())
      return in;
    return it->second;
  };

  // modules() iterates over all modules in the hierarchy including the root.
  for (auto module : mod.modules()) {
    auto module_type = module.type();
    for (auto& fn : module_type->methods()) {
      auto method = module.get_method(fn->name());
      auto graph = method.graph();
      graph->remapTypes(type_remap_fn);
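      // The method's schema also refers to module types, so remap it as well.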
      auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
      fn->setSchema(new_schema);
    }
  }
}

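// Helper for the _jit_to_backend binding: converts method_compile_spec from a
// py::dict into a Dict[str, Any] IValue and generates the backend-lowered
// Module via detail::codegen_backend_module.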
static Module codegen_func(
    const std::string& backend_name,
    const Module& orig_module,
    const py::dict& method_compile_spec) {
  // Represents a Type of Dict[str, Any].
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  return detail::codegen_backend_module(
      backend_name,
      orig_module,
      toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
      any_dict_ty);
}

void initJitBackendBindings(PyObject* module) {
  // Bind a function for lowering to each JIT backend. The name of the backend
  // must be the first argument. For example, to lower a Module to
  // "example_backend", declared as
  //
  //  static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
  //
  // this function must be called like
  //
  //  torch._C._jit_to_backend("example_backend", module, spec)
  auto m = py::handle(module).cast<py::module>();
  m.def(
      "_jit_to_backend",
      [=](const std::string& backend_name,
          py::handle orig_module,
          const py::dict& method_compile_spec) {
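        // Mirror C++ std::cerr/std::cout to Python's sys.stderr/sys.stdout for
        // the duration of this call so that backend output is visible from
        // Python.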
        py::scoped_ostream_redirect cerr(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect cout(
            std::cout, py::module_::import("sys").attr("stdout"));
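        // Lower the unwrapped C++ Module and wrap the result back into a
        // RecursiveScriptModule for the Python caller.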
        return py::module::import("torch.jit._recursive")
            .attr("wrap_cpp_module")(codegen_func(
                backend_name,
                py::cast<Module>(orig_module.attr("_c")),
                method_compile_spec));
      });

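  // Bind a function for selectively lowering submodules of a ScriptModule.
  // to_backend is a Python callable that lowers one wrapped submodule at a
  // time (for example, a wrapper around _jit_to_backend), and modules_to_lower
  // contains the qualified names of the submodules to lower. A sketch of the
  // expected call, with hypothetical submodule names:
  //
  //  torch._C._jit_to_backend_selective(module, to_backend, ["sub1", "sub2.child"])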
  m.def(
      "_jit_to_backend_selective",
      [=](py::handle orig_module,
          const py::function& to_backend,
          const std::vector<std::string>& modules_to_lower) {
        py::scoped_ostream_redirect cerr(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect cout(
            std::cout, py::module_::import("sys").attr("stdout"));
        if (auto original_module =
                as_module(py::cast<py::object>(orig_module))) {
          // Clone the Module to avoid editing types that are shared with
          // Modules in other instances outside this hierarchy.
          Module& mod = original_module.value();
          auto cloned_mod = mod.clone();
          // Get all shared module types. Type sharing is only a problem if the
          // parent modules of the ones to lower are in this set.
          auto shared_types = getSharedModuleTypes(cloned_mod);
          toBackendSelectiveImpl(
              cloned_mod, to_backend, modules_to_lower, shared_types);
          // Wrap the result in a RecursiveScriptModule because that's what
          // the caller passed in.
          return py::module::import("torch.jit._recursive")
              .attr("wrap_cpp_module")(cloned_mod);
        }

        throw py::cast_error(c10::str(
            "Object ", py::str(orig_module), " is not a ScriptModule"));
      });
}
} // namespace torch::jit