#include #include #include namespace torch::jit::onnx { namespace { using scope_list = std::vector; // Annotated attributes retrieved from module by inspecting module annotations. // These attributes are not used inside the subgraph of ONNX local function // because they are not created by PyTorch JIT tracing, but they may be used by // consumers to determine whether or not to replace the function with a // particular fused kernel. static std::unordered_map scope_attr_map_; static std::shared_ptr scope_attr_graph_ = std::make_shared(); static bool HasSameAttribute( const Node* a, const Node* b, const c10::Symbol& attr); struct FunctionExtractor { public: FunctionExtractor( std::shared_ptr& graph, const std::unordered_set& module_names, const std::vector& param_names) : graph_(graph), module_names_(module_names.begin(), module_names.end()), param_names_(param_names.begin(), param_names.end()) {} NodeAttrNameMap run(); private: struct ScopeContext { std::unordered_set children_; ScopePtr scope_; node_list nlist_; value_list inputs_; value_list outputs_; std::unordered_map env_to_subgraph_; void PopulateInputsOutputs( const std::unordered_set& param_names); bool IsIdenticalFuncion(const ScopeContext& other_ctx) const; }; using ScopeCtxPtr = ScopeContext*; using scope_ctx_map = std::unordered_map; struct FunctionContext { FunctionContext( ScopePtr key, const scope_list& scopes, scope_ctx_map& scope_ctxs); void DebugPrint() const; void SetAttrName(Node* ref_n, Symbol attr, const std::string& name); std::optional FindAttrName(Node* ref_n, Symbol attr); std::optional FindAttrName(Node* ref_const_n); ScopePtr scope_key_; scope_ctx_map scope_ctxs_; std::unordered_map< Node*, std::unordered_map>> attribute_map_; // Passed later to serialization. NodeAttrNameMap node_attr_to_name_; }; using FunctionCtxPtr = FunctionContext*; using func_ctx_map = std::unordered_map; static bool IsValidScope(const ScopePtr& s); static std::optional InferScope(Node* n); static bool IsAncestor(const ScopePtr& parent, ScopePtr child); static std::optional FindCommonAncestor(ScopePtr a, ScopePtr b); static std::optional FindCommonAncestor(const scope_list& scopes); std::shared_ptr ConstructFuncGraph(FunctionContext& ctx); void ConvertScopeToFunction( const ScopePtr& scope_key, const scope_list& scope_list, scope_ctx_map& scope_ctxs, const std::shared_ptr& graph); static void HandleNoScopeNodes( scope_ctx_map& /*scope_ctxs*/, const node_list& no_scope_nlist); std::tuple PartitionNodesByScope(Block* b); scope_ctx_map PartitionNodesByScope(const std::shared_ptr& graph); static std::unordered_map PartitionIdenticalScopes( scope_ctx_map& scope_ctxs); static scope_list SortScopesByMaxDepth( std::unordered_map& /*identical_scope_map*/); Node* CreateFunctionDefNode( FunctionContext& func_ctx, const std::shared_ptr& graph, const std::string& domain_name, const std::string& func_name); Node* CreateFunctionNode( FunctionContext& func_ctx, ScopeContext& scope_ctx, const std::shared_ptr& graph, const std::string& domain_name, const std::string& func_name); static void DebugPrintScopeContexts(const scope_ctx_map& /*scope_ctxs*/); static void DebugPrintGraphWithFunction(const std::shared_ptr& g); static void DebugPrintConstantDiff(const FunctionContext&); std::shared_ptr graph_; std::unordered_set module_names_; std::unordered_set param_names_; // Track modules with same module name that are exported as different onnx // local functions. std::unordered_map module_variant_count_; func_ctx_map func_ctxs_; }; FunctionExtractor::FunctionContext::FunctionContext( ScopePtr key, const scope_list& scopes, scope_ctx_map& scope_ctxs) : scope_key_(std::move(key)) { GRAPH_UPDATE( "Process function context for scope ", scope_key_->name().toDisplayString()); TORCH_INTERNAL_ASSERT(!scopes.empty()); const auto& ref_ctx = scope_ctxs[scope_key_]; // NOTE: Function scopes must have same number and order of nodes. GRAPH_DEBUG( "Initialized function context for scope ", scope_key_->name().toDisplayString()); for (const auto& scope : scopes) { GRAPH_DEBUG( "Process function context for scope ", scope->name().toDisplayString()); TORCH_INTERNAL_ASSERT(scope_ctxs.find(scope) != scope_ctxs.end()); scope_ctxs_[scope] = scope_ctxs[scope]; if (scope_key_ == scope) { continue; } auto& scope_ctx = scope_ctxs[scope]; const auto& ns_a = ref_ctx->nlist_; const auto& ns_b = scope_ctx->nlist_; TORCH_INTERNAL_ASSERT(ns_a.size() == ns_b.size()); GRAPH_DEBUG("Process nodes of scope ", scope->name().toDisplayString()); for (const auto i : c10::irange(ns_a.size())) { TORCH_INTERNAL_ASSERT(ns_a[i]->kind() == ns_b[i]->kind()); auto n_a = ns_a[i]; auto n_b = ns_b[i]; std::vector diff_attrs; std::vector same_attrs; auto n_a_attr_names = n_a->attributeNames(); auto n_b_attr_names = n_b->attributeNames(); std::sort(n_a_attr_names.begin(), n_a_attr_names.end()); std::sort(n_b_attr_names.begin(), n_b_attr_names.end()); std::set_difference( n_a_attr_names.begin(), n_a_attr_names.end(), n_b_attr_names.begin(), n_b_attr_names.end(), std::inserter(diff_attrs, diff_attrs.begin())); std::set_intersection( n_a_attr_names.begin(), n_a_attr_names.end(), n_b_attr_names.begin(), n_b_attr_names.end(), std::inserter(same_attrs, same_attrs.begin())); for (auto attr_name : diff_attrs) { attribute_map_[n_a][attr_name].insert(n_b); } for (auto attr_name : same_attrs) { if (!HasSameAttribute(n_a, n_b, attr_name)) { attribute_map_[n_a][attr_name].insert(n_b); } } } GRAPH_DEBUG("Process scope complete. ", scope->name().toDisplayString()); } GRAPH_DEBUG( "Process function context complete. ", scope_key_->name().toDisplayString()); DebugPrint(); } void FunctionExtractor::FunctionContext::DebugPrint() const { GRAPH_DEBUG("Scope name: ", scope_key_->name().toDisplayString()); for (const auto& it : attribute_map_) { for (const auto& attr_it : it.second) { GRAPH_DEBUG( "Attribute value difference for attribute ", attr_it.first.toDisplayString()); GRAPH_DEBUG(*it.first); for (auto n : attr_it.second) { GRAPH_DEBUG(*n); } } } } void FunctionExtractor::FunctionContext::SetAttrName( Node* ref_n, Symbol attr, const std::string& name) { auto v_it = scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0)); TORCH_INTERNAL_ASSERT( v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end()); auto* n_in_def = v_it->second->node(); node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; } std::optional FunctionExtractor::FunctionContext::FindAttrName( Node* ref_n, Symbol attr) { auto v_it = scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0)); if (v_it == scope_ctxs_[scope_key_]->env_to_subgraph_.end()) { return std::nullopt; } auto* n_in_def = v_it->second->node(); auto n_attr_it = node_attr_to_name_.find(n_in_def); if (n_attr_it == node_attr_to_name_.end()) { return std::nullopt; } auto name_it = n_attr_it->second.find(attr.toUnqualString()); if (name_it == n_attr_it->second.end()) { return std::nullopt; } return name_it->second; } void FunctionExtractor::DebugPrintScopeContexts( const scope_ctx_map& scope_ctxs) { for (auto& it : scope_ctxs) { GRAPH_UPDATE( "Scope name: ", it.first->namesFromRoot(), " ", it.first->name().toDisplayString()); GRAPH_UPDATE("Children scopes: ", [&]() { std::stringstream ss; for (const auto& child_scope : it.second->children_) { ss << child_scope->name().toDisplayString() << " "; } return ss.str(); }()); GRAPH_UPDATE("Node types: \n", [&]() { std::stringstream ss; for (auto n : it.second->nlist_) { ss << " " << *n; } return ss.str(); }()); GRAPH_UPDATE("Node count: ", it.second->nlist_.size()); } } void FunctionExtractor::DebugPrintGraphWithFunction( const std::shared_ptr& g) { GRAPH_UPDATE("Local function definitions:"); for (auto* n : g->nodes()) { if (n->kind() == Symbol::onnx("LocalFunctionDef")) { GRAPH_UPDATE( n->s(attr::name), " graph: ", n->g(Symbol::attr("graph"))->toString()); } } GRAPH_UPDATE("Main graph: ", g->toString()); } bool FunctionExtractor::IsValidScope(const ScopePtr& s) { return !s->isRoot() && !s->isBlank(); } bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) { if (!IsValidScope(parent) || !IsValidScope(child) || parent->getDepth() >= child->getDepth()) { return false; } do { child = child->parent(); if (parent == child) { return true; } } while (IsValidScope(child)); return false; } std::optional FunctionExtractor::FindCommonAncestor( ScopePtr a, ScopePtr b) { if (!IsValidScope(a) || !IsValidScope(b)) { return std::nullopt; } auto diff = static_cast(a->getDepth()) - static_cast(b->getDepth()); if (diff != 0) { auto deeper_scope = diff > 0 ? a : b; auto other_scope = diff > 0 ? b : a; diff = std::abs(diff); while (diff > 0) { deeper_scope = deeper_scope->parent(); diff--; } a = deeper_scope; b = other_scope; } while (IsValidScope(a) && IsValidScope(b)) { if (a == b) { return a; } else { a = a->parent(); b = b->parent(); } } return std::nullopt; } std::optional FunctionExtractor::FindCommonAncestor( const scope_list& scopes) { if (scopes.empty()) { return std::nullopt; } std::optional common_ancestor = scopes.at(0); for (const auto& scope : scopes) { common_ancestor = FindCommonAncestor(common_ancestor.value(), scope); if (!common_ancestor.has_value()) { return std::nullopt; } } return common_ancestor; } std::optional FunctionExtractor::InferScope(Node* n) { // The scope of node n is assigned based on the following rules. // 1. If all uses of outputs of n belongs to the same scope, // assign that scope, otherwise // 2. If all nodes of inputs of n belongs to the same scope, // assign that scope, otherwise // 3. Find common ancestor of the scopes of uses of outputs of n, // and the scopes of nodes of inputs of n. scope_list input_scopes; scope_list output_scopes; for (auto input : n->inputs()) { input_scopes.emplace_back(input->node()->scope()); } for (auto output : n->outputs()) { for (auto use : output->uses()) { if (!IsValidScope(use.user->scope())) { auto inferred_output_scope = InferScope(use.user); if (inferred_output_scope.has_value() && IsValidScope(inferred_output_scope.value())) { use.user->setScope(inferred_output_scope.value()); } } output_scopes.emplace_back(use.user->scope()); } } if (!output_scopes.empty() && std::all_of( output_scopes.begin(), output_scopes.end(), [&output_scopes](const ScopePtr& scope) -> bool { return IsValidScope(scope) && scope == output_scopes.at(0); })) { return output_scopes.at(0); } else if ( !input_scopes.empty() && std::all_of( input_scopes.begin(), input_scopes.end(), [&input_scopes](const ScopePtr& scope) -> bool { return IsValidScope(scope) && scope == input_scopes.at(0); })) { return input_scopes.at(0); } else { scope_list scopes; std::copy_if( input_scopes.begin(), input_scopes.end(), std::back_inserter(scopes), IsValidScope); std::copy_if( output_scopes.begin(), output_scopes.end(), std::back_inserter(scopes), IsValidScope); if (!scopes.empty()) { auto common_ancestor = FindCommonAncestor(scopes); if (common_ancestor.has_value() && IsValidScope(common_ancestor.value())) { return common_ancestor; } } } return std::nullopt; } std::shared_ptr FunctionExtractor::ConstructFuncGraph( FunctionContext& func_ctx) { auto& ctx = *func_ctx.scope_ctxs_[func_ctx.scope_key_]; const auto& nlist = ctx.nlist_; const auto& scope = ctx.scope_; auto& env = ctx.env_to_subgraph_; auto g = std::make_shared(); GRAPH_DEBUG("Constructing graph for ", scope->namesFromRoot()); // TODO: Update input names of function to match those in Module source code // signature. // This requires mapping between function node inputs and Module inputs. // Due to the lack of such mapping, currently debugName is used as input // names. ctx.PopulateInputsOutputs(param_names_); for (auto* v : ctx.inputs_) { env[v] = g->addInput()->copyMetadata(v); GRAPH_DEBUG( "Add input value ", env[v]->debugName(), " for outer scope value ", v->debugName(), " from ", *v->node()); } for (auto* n : nlist) { auto clone_n = g->createClone(n, [&](Value* v) { TORCH_INTERNAL_ASSERT(env.find(v) != env.end()); return env[v]; }); for (const auto i : c10::irange(clone_n->outputs().size())) { env[n->output(i)] = clone_n->output(i); } g->insertNode(clone_n); } // If values are used outside of this graph, set as graph output. for (auto* v : ctx.outputs_) { TORCH_INTERNAL_ASSERT(env.find(v) != env.end()); g->registerOutput(env[v]); } GRAPH_DEBUG(g->toString()); return g; } Node* FunctionExtractor::CreateFunctionDefNode( FunctionContext& func_ctx, const std::shared_ptr& graph, const std::string& domain_name, const std::string& func_name) { const auto func_def_nk = Symbol::onnx("LocalFunctionDef"); const auto func_g_attr = Symbol::attr("graph"); const auto func_name_attr = attr::name; const auto func_domain_attr = Symbol::attr("domain"); auto func_graph = ConstructFuncGraph(func_ctx); // create and insert local function definition node auto func_def_n = graph->create(func_def_nk, 0); func_def_n->g_(func_g_attr, func_graph); func_def_n->s_(func_name_attr, func_name); func_def_n->s_(func_domain_attr, domain_name); graph->prependNode(func_def_n); // set constants and attributes of different values as function attributes. std::unordered_map base_attr_name_count; std::vector final_attr_names; auto adjust_attr_name = [&](std::string attr_name) { if (base_attr_name_count.find(attr_name) != base_attr_name_count.end()) { attr_name = attr_name + "." + std::to_string(base_attr_name_count[attr_name]++); } else { base_attr_name_count[attr_name] = 1; } return attr_name; }; for (const auto& n_it : func_ctx.attribute_map_) { auto* n = n_it.first; for (const auto& attr_it : n_it.second) { const auto& attr = attr_it.first; // Add prefix "inferred::" to name of inferred attribute. // This is to differentiate from annotated attributes picked up // from python module annotation. auto attr_name = "inferred::" + std::string(n->kind().toUnqualString()) + '_' + attr.toUnqualString(); auto final_attr_name = adjust_attr_name(attr_name); final_attr_names.emplace_back(final_attr_name); func_ctx.SetAttrName(n, attr, final_attr_name); } } // Set annotated attributes std::unordered_set annotated_attr_names; bool first_iteration = true; for (const auto& it : func_ctx.scope_ctxs_) { auto scope = it.first; auto annotated_attr_node = scope_attr_map_.find(scope); if (annotated_attr_node != scope_attr_map_.end()) { auto names = annotated_attr_node->second->attributeNames(); if (first_iteration) { std::copy( names.begin(), names.end(), std::inserter(annotated_attr_names, annotated_attr_names.end())); first_iteration = false; } else { auto unseen_attr_name = std::find_if( names.begin(), names.end(), [&annotated_attr_names](const Symbol& name) { return annotated_attr_names.find(name) == annotated_attr_names.end(); }); TORCH_CHECK( unseen_attr_name == names.end(), "Found outstanding annotated attribute ", *unseen_attr_name, " from module ", scope->name(), ". Please ensure module instances of the same class have the same set of annotated attributes."); } } } for (auto attr_name : annotated_attr_names) { final_attr_names.emplace_back(attr_name.toUnqualString()); } func_def_n->ss_(Symbol::attr("attributes"), final_attr_names); return func_def_n; } Node* FunctionExtractor::CreateFunctionNode( FunctionContext& func_ctx, ScopeContext& scope_ctx, const std::shared_ptr& graph, const std::string& domain_name, const std::string& func_name) { const auto& func_scope = func_ctx.scope_key_; GRAPH_DEBUG( "Create and insert local function for scope: ", func_scope->namesFromRoot()); scope_ctx.PopulateInputsOutputs(param_names_); auto last_n = *scope_ctx.nlist_.rbegin(); auto func_n = graph->create( Symbol::fromQualString(domain_name + "::" + func_name), scope_ctx.outputs_.size()); func_n->copyMetadata(last_n); for (auto* v : scope_ctx.inputs_) { func_n->addInput(v); } for (const auto i : c10::irange(scope_ctx.outputs_.size())) { func_n->output(i)->setType(scope_ctx.outputs_[i]->type()); scope_ctx.outputs_[i]->replaceAllUsesWith(func_n->output(i)); } // set attributes of different values as function attributes. auto copy_attr = [](Node* a, Node* b, Symbol attr, const std::string& new_name) { #define COPY_ATTR(kind) \ case AttributeKind::kind: { \ b->kind##_(Symbol::attr(new_name), a->kind(attr)); \ break; \ } switch (a->kindOf(attr)) { COPY_ATTR(f) COPY_ATTR(fs) COPY_ATTR(i) COPY_ATTR(is) COPY_ATTR(s) COPY_ATTR(ss) COPY_ATTR(t) COPY_ATTR(ts) #undef COPY_ATTR case AttributeKind::ival: case AttributeKind::g: case AttributeKind::gs: case AttributeKind::ty: case AttributeKind::tys: case AttributeKind::c: default: TORCH_INTERNAL_ASSERT( false, "Unexpected attribute type ", static_cast(a->kindOf(attr)), " from node ", *a); break; } }; for (const auto& it : func_ctx.attribute_map_) { auto* ref_n = it.first; for (const auto& attr_it : it.second) { const auto& attr = attr_it.first; auto attr_name = func_ctx.FindAttrName(ref_n, attr).value(); copy_attr(ref_n, func_n, attr, attr_name); for (auto* n : scope_ctx.nlist_) { if (attr_it.second.find(n) != attr_it.second.end()) { copy_attr(n, func_n, attr, attr_name); break; } } } } // annotated attributes auto scope = scope_ctx.scope_; auto annotated_attr_node = scope_attr_map_.find(scope); if (annotated_attr_node != scope_attr_map_.end()) { auto node = annotated_attr_node->second; for (auto attr : node->attributeNames()) { copy_attr(node, func_n, attr, attr.toUnqualString()); } } func_n->insertAfter(last_n); return func_n; } void FunctionExtractor::ConvertScopeToFunction( const ScopePtr& scope_key, const scope_list& scope_list, scope_ctx_map& scope_ctxs, const std::shared_ptr& graph) { // This function needs to be called always on inner most scopes. // 1. Generate function context, this identifies different constants and // attributes. // 2. Create function definition node, and insert to main graph. // 3. Create function node for each call, and replace subgraph nodes in parent // functions. func_ctxs_.insert(std::make_pair( scope_key, new FunctionContext(scope_key, scope_list, scope_ctxs))); auto& func_ctx = *func_ctxs_[scope_key]; const std::string module_class_name( ONNXScopeName::className(func_ctx.scope_key_)); auto pos = module_class_name.rfind('.'); TORCH_INTERNAL_ASSERT(pos != std::string::npos); auto construct_unique_module_name = [&](std::string module_name) { auto module_name_variant = module_variant_count_.find(module_name); if (module_name_variant != module_variant_count_.end()) { module_variant_count_[module_name]++; module_name += ("." + std::to_string(module_name_variant->second)); } else { module_variant_count_[module_name] = 0; } return module_name; }; const auto domain_name = module_class_name.substr(0, pos); const auto func_name = construct_unique_module_name(module_class_name.substr(pos + 1)); CreateFunctionDefNode(func_ctx, graph, domain_name, func_name); // create and insert local function node to graph. for (const auto& it : func_ctx.scope_ctxs_) { auto scope = it.first; auto& scope_ctx = *it.second; auto func_n = CreateFunctionNode(func_ctx, scope_ctx, graph, domain_name, func_name); std::unordered_set old_nodes( scope_ctx.nlist_.begin(), scope_ctx.nlist_.end()); auto last_n = *scope_ctx.nlist_.rbegin(); // replace function body nodes in parent scopes with local function node. for (auto& it : scope_ctxs) { const auto& parent_scope = it.first; auto& parent_ctx = *it.second; if (!IsAncestor(parent_scope, scope)) { continue; } auto& ctx_nlist = parent_ctx.nlist_; GRAPH_DEBUG( "Replace local function node in parent scope: ", it.first->namesFromRoot(), " nodes to remove: ", old_nodes.size(), " parent total nodes: ", ctx_nlist.size()); // insert local function node auto last_n_it = std::find(ctx_nlist.begin(), ctx_nlist.end(), last_n); ctx_nlist.insert(last_n_it, func_n); // remove replaced nodes from list ctx_nlist.erase( std::remove_if( ctx_nlist.begin(), ctx_nlist.end(), [&old_nodes](Node* n) { return old_nodes.find(n) != old_nodes.end(); }), ctx_nlist.end()); GRAPH_DEBUG("Parent total nodes after remove: ", ctx_nlist.size()); // refresh inputs/outputs. parent_ctx.PopulateInputsOutputs(param_names_); } } for (const auto& it : func_ctx.scope_ctxs_) { auto& scope_ctx = *it.second; // delete replaced nodes in graph. for (auto it = scope_ctx.nlist_.rbegin(); it != scope_ctx.nlist_.rend();) { auto* n = *it; it++; GRAPH_DEBUG("Destroying node ", *n); n->destroy(); } } } bool FunctionExtractor::ScopeContext::IsIdenticalFuncion( const ScopeContext& other_ctx) const { // Differentiate same function under different inputs. // When constants are passed in place of inputs, it leads to different // input count and node count. Likewise, due to different uses, output // count can be different as well. // For now export them as different functions. // Covered by `test_local_function_overloads` in // `test/onnx/test_utility_funs.py`. if (&other_ctx == this) { return true; } if (ONNXScopeName::className(this->scope_) != ONNXScopeName::className(other_ctx.scope_)) { return false; } if (this->inputs_.size() != other_ctx.inputs_.size() || this->outputs_.size() != other_ctx.outputs_.size()) { return false; } const auto& ns_a = this->nlist_; const auto& ns_b = other_ctx.nlist_; if (ns_a.size() != ns_b.size()) { return false; } for (const auto i : c10::irange(ns_a.size())) { if (ns_a[i]->kind() != ns_b[i]->kind()) { return false; } } return true; } void FunctionExtractor::ScopeContext::PopulateInputsOutputs( const std::unordered_set& param_names) { inputs_.clear(); outputs_.clear(); const auto& nlist = this->nlist_; std::unordered_set v_set; std::unordered_set n_set; value_list input_list; value_list initializer_list; // Add initializers after inputs. for (auto* n : nlist) { for (auto* v : n->inputs()) { if (v_set.find(v) == v_set.end()) { if (param_names.find(v->debugName()) != param_names.end()) { initializer_list.emplace_back(v); } else { input_list.emplace_back(v); } v_set.insert(v); } } for (auto* v : n->outputs()) { v_set.insert(v); } n_set.insert(n); } for (auto* v : input_list) { inputs_.emplace_back(v); } for (auto* v : initializer_list) { inputs_.emplace_back(v); } for (auto* n : nlist) { for (auto* v : n->outputs()) { bool used_outside = false; for (auto use : v->uses()) { used_outside |= (n_set.find(use.user) == n_set.end()); } if (used_outside) { outputs_.emplace_back(v); } } } } void FunctionExtractor::HandleNoScopeNodes( scope_ctx_map& scope_ctxs, const node_list& no_scope_nlist) { GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size()); for (auto n : no_scope_nlist) { TORCH_WARN( "ONNX function extraction cannot determine the scope for node: ", *n); } TORCH_INTERNAL_ASSERT( no_scope_nlist.empty(), "ONNX function extraction cannot determine the scope for the above nodes."); } std::tuple FunctionExtractor:: PartitionNodesByScope(Block* b) { scope_ctx_map scope_ctxs = {}; node_list no_scope_nlist; auto find_or_create_scope_ctx = [](scope_ctx_map& scope_ctxs, const ScopePtr& scope) { if (scope_ctxs.find(scope) == scope_ctxs.end()) { scope_ctxs.insert(std::make_pair(scope, new ScopeContext())); } return scope_ctxs[scope]; }; auto record_node_scope = [&scope_ctxs, &find_or_create_scope_ctx](Node* n) { const auto& scope = n->scope(); find_or_create_scope_ctx(scope_ctxs, scope)->scope_ = scope; auto tmp_scope = scope; while (IsValidScope(tmp_scope)) { find_or_create_scope_ctx(scope_ctxs, tmp_scope)->nlist_.emplace_back(n); if (IsValidScope(tmp_scope->parent())) { find_or_create_scope_ctx(scope_ctxs, tmp_scope->parent()) ->children_.insert(tmp_scope); } tmp_scope = tmp_scope->parent(); } }; for (auto* n : b->nodes()) { auto scope = n->scope(); if (scope && IsValidScope(scope)) { record_node_scope(n); } else { auto inferred_scope = InferScope(n); if (inferred_scope.has_value() && IsValidScope(inferred_scope.value())) { n->setScope(inferred_scope.value()); record_node_scope(n); } else { GRAPH_UPDATE("Cannot infer proper scope for node: ", *n); no_scope_nlist.emplace_back(n); } } for (auto* sub_b : n->blocks()) { auto [subblock_scope_ctxs, subblock_no_scope_nlist] = PartitionNodesByScope(sub_b); for (auto& it : subblock_scope_ctxs) { if (scope_ctxs.find(it.first) == scope_ctxs.end()) { scope_ctxs.insert(std::make_pair(it.first, it.second)); } else { for (auto* s_n : it.second->nlist_) { scope_ctxs[it.first]->nlist_.emplace_back(s_n); } for (const auto& s_child_scope : it.second->children_) { scope_ctxs[it.first]->children_.insert(s_child_scope); } } } no_scope_nlist.insert( no_scope_nlist.end(), subblock_no_scope_nlist.begin(), subblock_no_scope_nlist.end()); } } for (auto& it : scope_ctxs) { it.second->scope_ = it.first; it.second->PopulateInputsOutputs(param_names_); } return std::tie(scope_ctxs, no_scope_nlist); } FunctionExtractor::scope_ctx_map FunctionExtractor::PartitionNodesByScope( const std::shared_ptr& graph) { scope_ctx_map scope_ctxs; node_list no_scope_nlist; std::tie(scope_ctxs, no_scope_nlist) = PartitionNodesByScope(graph->block()); HandleNoScopeNodes(scope_ctxs, no_scope_nlist); return scope_ctxs; } std::unordered_map FunctionExtractor:: PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map& scope_ctxs) { std::unordered_map identical_scope_map; for (auto& it : scope_ctxs) { auto scope = it.first; const auto& scope_ctx = it.second; bool unique = true; for (auto& kv_it : identical_scope_map) { auto key_scope = kv_it.first; const auto& key_scope_ctx = scope_ctxs[key_scope]; auto& key_scope_vec = kv_it.second; if (key_scope_ctx->IsIdenticalFuncion(*scope_ctx)) { key_scope_vec.emplace_back(scope); unique = false; break; } } if (unique) { identical_scope_map[scope].emplace_back(scope); } } return identical_scope_map; } static bool HasSameAttribute( const Node* a, const Node* b, const c10::Symbol& attr) { if (!a->hasAttribute(attr) && !b->hasAttribute(attr)) { return true; } if (!a->hasAttribute(attr) || !b->hasAttribute(attr)) { return false; } auto a_kind = a->kindOf(attr); auto b_kind = b->kindOf(attr); if (a_kind != b_kind) { return false; } #define COMP_ATTR(kind) \ case AttributeKind::kind: { \ const auto& a_v = a->kind(attr); \ const auto& b_v = b->kind(attr); \ return a_v == b_v; \ } switch (a_kind) { COMP_ATTR(f) COMP_ATTR(fs) COMP_ATTR(i) COMP_ATTR(is) COMP_ATTR(s) COMP_ATTR(ss) #undef COMP_ATTR case AttributeKind::t: { const auto& a_v = a->t(attr); const auto& b_v = b->t(attr); return a_v.equal(b_v); } case AttributeKind::ts: { const auto& a_v = a->ts(attr); const auto& b_v = b->ts(attr); return std::equal( a_v.begin(), a_v.end(), b_v.begin(), b_v.end(), [](const at::Tensor& a_t, const at::Tensor& b_t) { return a_t.equal(b_t); }); } case AttributeKind::ival: case AttributeKind::g: case AttributeKind::gs: case AttributeKind::ty: case AttributeKind::tys: case AttributeKind::c: default: TORCH_INTERNAL_ASSERT( false, "Unexpected attribute type ", static_cast(a_kind), " from node ", *a); break; } return true; } scope_list FunctionExtractor::SortScopesByMaxDepth( std::unordered_map& identical_scope_map) { std::unordered_map scope_max_depth; for (const auto& it : identical_scope_map) { const auto& scopes = it.second; size_t max_depth = 0; for (const auto& scope : scopes) { if (scope->getDepth() > max_depth) { max_depth = scope->getDepth(); } } scope_max_depth[it.first] = max_depth; } scope_list sorted_scopes; sorted_scopes.reserve(scope_max_depth.size()); for (const auto& it : scope_max_depth) { sorted_scopes.emplace_back(it.first); } std::sort( sorted_scopes.begin(), sorted_scopes.end(), [&scope_max_depth](const ScopePtr& a, const ScopePtr& b) -> bool { return scope_max_depth[a] >= scope_max_depth[b]; }); return sorted_scopes; } NodeAttrNameMap FunctionExtractor::run() { auto scope_ctxs = PartitionNodesByScope(graph_); DebugPrintScopeContexts(scope_ctxs); auto identical_scope_map = PartitionIdenticalScopes(scope_ctxs); // Deepest scope comes first, guaranteeing no other scope can be its child. auto sorted_scope_keys = SortScopesByMaxDepth(identical_scope_map); for (const auto& scope_key : sorted_scope_keys) { if (module_names_.find(ONNXScopeName::className(scope_key)) != module_names_.end()) { ConvertScopeToFunction( scope_key, identical_scope_map[scope_key], scope_ctxs, graph_); } GRAPH_DEBUG("Main graph afterwards: ", graph_->toString()); } DebugPrintGraphWithFunction(graph_); // Construct return mappings NodeAttrNameMap node_attr_to_name; for (const auto& it : func_ctxs_) { auto func_ref_map = it.second->node_attr_to_name_; node_attr_to_name.insert(func_ref_map.begin(), func_ref_map.end()); } // Clear for (auto& it : scope_ctxs) { delete it.second; } scope_ctxs.clear(); for (auto& it : func_ctxs_) { delete it.second; } func_ctxs_.clear(); return node_attr_to_name; } // Retrieves the node representing the most recent // ScopePtr. This function should only be invoked from module forward hook. At // this point, module forward call is completed, and the most recent ScopePtr // is popped from TracingState. // This function inspects the node, and its subblock, to find // the node associated with the most recent ScopePtr. Node* NodeOfMostRecentScope(Node* forward_node) { TORCH_INTERNAL_ASSERT( forward_node->kind() == prim::TracedModuleForward, "forward_node got kind: ", forward_node->kind().toDisplayString()); auto* block = forward_node->blocks()[0]; for (auto* node : block->nodes().reverse()) { if (node->kind() == prim::TracedModuleForward) { Node* target_node = NodeOfMostRecentScope(node); if (scope_attr_map_.find(node->scope()) == scope_attr_map_.end()) { return target_node; } } } return forward_node; } } // namespace // FunctionExtractor runs in the following steps. Updates are made inplace to // the graph argument. // 1. Partition nodes into groups based on their scope information. // Each scope represents an individual nn.Module call. A ScopeContext object // is created for each group. // 2. Compare and find groups with the same subgraph pattern from step 1. // 3. Scopes are nested. Starting from the deepest scope, extract the // subgraph pattern, and define as local function node. Replace subgraph // pattern with a single node of the new local function node type. A // FunctionContext object is created for each function. // 4. Construct NodeAttrNameMap tracking mapping from attribute name of // IR Node inside function subgraph, to function attribute name. NodeAttrNameMap ONNXFunctionExtraction( std::shared_ptr& graph, const std::unordered_set& module_names, const std::vector& param_names) { GRAPH_UPDATE( "Export these module forward calls as functions: ", std::vector{module_names.begin(), module_names.end()}); FunctionExtractor fe(graph, module_names, param_names); return fe.run(); } void ONNXClearScopeRecords() { scope_attr_map_.clear(); scope_attr_graph_ = std::make_shared(); } void ONNXTrackScopeAttributes( std::shared_ptr& graph, std::map& attributes) { // Skip the "real" last node which is `return_node`. auto* last_node = graph->nodes().back()->prev(); auto* scope_node = NodeOfMostRecentScope(last_node); auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward); attr_node->setScope(scope_node->scope()); TORCH_INTERNAL_ASSERT( scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end()); scope_attr_map_[scope_node->scope()] = attr_node; for (const auto& it : attributes) { auto k = Symbol::attr(it.first); auto v = it.second; if (v.isTensor()) { attr_node->t_(k, v.toTensor()); } else if (v.isInt()) { attr_node->i_(k, v.toInt()); } else if (v.isDouble()) { attr_node->f_(k, v.toDouble()); } else if (v.isBool()) { attr_node->i_(k, v.toBool()); } else if (v.isString()) { attr_node->s_(k, v.toStringRef()); } else if (v.isIntList()) { attr_node->is_(k, v.toIntList().vec()); } else if (v.isBoolList()) { auto bool_list = v.toBoolList(); attr_node->is_( k, std::vector(bool_list.begin(), bool_list.end())); } else if (v.isDoubleList()) { attr_node->fs_(k, v.toDoubleList().vec()); } } } } // namespace torch::jit::onnx