pytorch/torch/csrc/jit/passes/onnx/function_substitution.cpp

#include <torch/csrc/jit/passes/onnx/function_substitution.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/naming.h>

namespace torch::jit {
namespace {
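
// Variable name assigned to the top-level module in ONNX scope names. It is
// intentionally an empty string: the top module is not stored under any
// attribute, so it has no variable name of its own.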
const std::string kTopModuleVariableName;
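
// Strips TorchScript-internal atoms ("__torch__" and "__torch_mangle_*") from
// a qualified class name, e.g. "__torch__.__torch_mangle_7.MyModule" becomes
// "MyModule".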
std::string TidyClassNameFromTorchScript(
    const std::optional<c10::QualifiedName>& class_name) {
  if (!class_name) {
    return "UNKNOWN_CLASS";
  }
  std::string out;
  for (const auto& atom : class_name->atoms()) {
    bool is_internal_torch_atom = (atom == "__torch__");
    bool is_mangle_atom = (atom.find("__torch_mangle") != std::string::npos);
    if (!is_internal_torch_atom && !is_mangle_atom) {
      if (!out.empty()) {
        out += ".";
      }
      out += atom;
    }
  }
  return out;
}
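
// Best-effort recovery of the attribute name under which the called submodule
// is stored on its owner (e.g. "conv1", or "layers.0" for a ModuleList entry).
// Returns an empty string when the module node carries no name.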
std::string GetCallNodeVariableName(const Node* call_node) {
  TORCH_INTERNAL_ASSERT(
      call_node->kind() == prim::CallFunction ||
      call_node->kind() == prim::CallMethod);
  auto module_node = call_node->input(0)->node();
  if (!module_node->hasAttribute(attr::name)) {
    return "";
  }
  std::string module_name = module_node->s(attr::name);
  if (module_node->inputs().empty()) {
    return module_name;
  }
  // If the module comes from a container (e.g. ModuleList), attr::name on the
  // module node only carries the index. Walk up to the parent (container)
  // node to recover the variable name.
  auto parent_module_value = module_node->input(0);
  while (parent_module_value) {
    auto parent_module_type = parent_module_value->type()->cast<ClassType>();
    if (parent_module_type &&
        parent_module_type->name() ==
            "__torch__.torch.nn.modules.container.ModuleList") {
      auto parent_module_node = parent_module_value->node();
      module_name = parent_module_node->s(attr::name) + "." + module_name;
      parent_module_value = !parent_module_node->inputs().empty()
          ? parent_module_node->input(0)
          : nullptr;
    } else {
      break;
    }
  }
  return module_name;
}
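
// For a prim::CallMethod node that calls "forward", pushes a new scope that
// combines the tidied class name with the submodule's variable name, so that
// nodes inlined from the call are attributed to that submodule. Calls to
// other methods keep the current scope unchanged.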
ScopePtr ForwardCallScope(Graph& graph, Node* call_node) {
  TORCH_INTERNAL_ASSERT(call_node->kind() == prim::CallMethod);
  const std::string& method_name = call_node->s(attr::name);
  if (method_name == "forward") {
    const auto type = call_node->input(0)->type()->expect<c10::NamedType>();
    const std::string class_name = TidyClassNameFromTorchScript(type->name());
    const std::string variable_name = GetCallNodeVariableName(call_node);
    const std::string scope_name =
        onnx::ONNXScopeName::createFullScopeName(class_name, variable_name);
    return graph.current_scope()->push(Symbol::scope(scope_name));
  }
  return graph.current_scope();
}
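
// Recursively walks `block` and:
//  * rewrites calls to torch.nn.functional.interpolate into a single
//    aten::__interpolate node,
//  * inlines every other prim::CallFunction / prim::CallMethod after first
//    recursing into the callee, and
//  * stamps the current ONNX scope onto all remaining nodes.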
void functionCallSubstitution(Block* block) {
  auto graph = block->owningGraph();
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* cur = *it++;
    switch (cur->kind()) {
      case prim::CallFunction: {
        TORCH_INTERNAL_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
        auto function_constant = cur->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();
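        // Special case: calls into torch.nn.functional.interpolate are not
        // inlined; they are replaced by a single aten::__interpolate node
        // that the ONNX exporter knows how to convert.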
        if ((fun_type->function()->qualname().qualifiedName().find(
                 "torch.nn.functional") != std::string::npos) &&
            (fun_type->function()->qualname().qualifiedName().find(
                 "interpolate") != std::string::npos)) {
          // Remove input[0] and the node that feeds into it
          auto input_node_0 = cur->input(0)->node();
          cur->removeInput(0);
          if (!input_node_0->hasUses()) {
            input_node_0->destroy();
          }
          Node* interpolate_node = block->owningGraph()->create(
              Symbol::fromQualString("aten::__interpolate"),
              {cur->inputs()},
              cur->outputs().size());
          interpolate_node->output()->copyMetadata(cur->output());
          interpolate_node->insertAfter(cur);
          interpolate_node->copyMetadata(cur);
          cur->replaceAllUsesWith(interpolate_node);
          cur->removeAllInputs();
          cur->destroy();
          GRAPH_UPDATE(
              "ONNX function call substitution function: '",
              fun_type->function()->name(),
              "' to aten::__interpolate");
          GRAPH_UPDATE(
              "Function in ONNX function call substitution body: ",
              toGraphFunction(*fun_type->function()).optimized_graph());
        } else {
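          // Generic free-function call: recursively substitute inside the
          // callee graph first, then inline it at this call site.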
          // Remove input[0] and the node that feeds into it
          auto input_node_0 = cur->input(0)->node();
          cur->removeInput(0);
          if (!input_node_0->hasUses()) {
            input_node_0->destroy();
          }
          auto& graphFunction = toGraphFunction(*fun_type->function());
          functionCallSubstitution(graphFunction.graph()->block());
          inlineCallTo(cur, &graphFunction, false);
        }
      } break;
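      // Method calls: set up the ONNX scope for forward() calls, recurse into
      // the callee, then inline it under that scope.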
      case prim::CallMethod: {
        const std::string& name = cur->s(attr::name);
        if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
          Function& function = class_type->getMethod(name);
          ScopePtr call_scope = ForwardCallScope(*graph, cur);
          WithCurrentScope scope_guard(*graph, call_scope);
          GRAPH_DEBUG(
              "Setting scope guard for forward call: ",
              graph->current_scope()->namesFromRoot());
          if (auto graphFunction = tryToGraphFunction(function)) {
            GRAPH_DEBUG(
                "Inner graph for method call ",
                name,
                ": ",
                *graphFunction->graph());
            WithCurrentScope inner_graph_scope_guard(
                *graphFunction->graph(), call_scope);
            functionCallSubstitution(graphFunction->graph()->block());
            inlineCallTo(cur, graphFunction, false);
          }
        }
      } break;
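      // All other nodes: tag them with the current scope and recurse into any
      // sub-blocks (e.g. prim::If / prim::Loop bodies).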
      default: {
        if (!graph->current_scope()->isBlank()) {
          cur->setScope(graph->current_scope());
        }
        for (auto b : cur->blocks()) {
          functionCallSubstitution(b);
        }
      } break;
    }
    GRAPH_DEBUG(
        "Graph current scope after node process: ",
        graph->current_scope()->namesFromRoot());
  }
}
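
// Builds the root ONNX scope from the graph's first input, which is the
// top-level module when the graph still carries its module value. The top
// module is recorded with an empty variable name (kTopModuleVariableName).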
ScopePtr ONNXGraphTopLevelScope(Graph& graph) {
  if (graph.inputs().empty()) {
    return graph.current_scope();
  }
  if (auto top_module_type = graph.inputs().at(0)->type()->cast<ClassType>()) {
    auto scope_name = ::torch::jit::onnx::ONNXScopeName::createFullScopeName(
        TidyClassNameFromTorchScript(top_module_type->name()),
        kTopModuleVariableName);
    return graph.current_scope()->push(Symbol::scope(scope_name));
  }
  return graph.current_scope();
}
} // namespace

// This pass is to be used for ONNX conversion only. The ONNX converter depends
// on a number of deprecated aten operators. These operators are removed from
// the IR and replaced by the compiled Python function code. However, in order
// to maintain the behavior for ONNX conversion, we replace these function
// calls with the aten symbolic, which can still be used by the ONNX converter.
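//
// A minimal usage sketch (hypothetical caller, not part of this file),
// assuming a scripted module whose forward graph is being prepared for ONNX
// export:
//
//   torch::jit::Module module = torch::jit::load("model.pt");
//   std::shared_ptr<Graph> graph = module.get_method("forward").graph();
//   ONNXFunctionCallSubstitution(*graph);
//
// In practice the pass is invoked from the Python ONNX exporter as one of the
// pre-export graph passes.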
void ONNXFunctionCallSubstitution(Graph& graph) {
  GRAPH_DUMP("Before function call substitution calls: ", &graph);
  WithCurrentScope top_level_scope_guard(graph, ONNXGraphTopLevelScope(graph));
  functionCallSubstitution(graph.block());
  GRAPH_DUMP("After function call substitution calls: ", &graph);
}
} // namespace torch::jit