#include #include #include #include #include #include namespace torch { namespace jit { namespace SubgraphUtils { namespace { bool hasSubgraph(Node* n) { return n->hasAttribute(attr::Subgraph); } std::vector> gatherLastUses( at::ArrayRef values) { return fmap(values, [&](Value* v) -> c10::optional { return firstOrLastUse(v, /*find_first*/ false); }); } // When merging a node into a subgraph, we wish to preserve all of the // aliasing properties of the node's outputs. It is difficult to track // the node or its contained nodes through all of the ir manipulation // involved in merging; it is pretty easy to uniquely identify the value // based on its uses. We can identify the value by its last use in the graph. // Values which do not have uses or which do not have a last use // outside of the subgraph to be merged into we do not need to track. struct ValueMapper { // `to_merge` is the node we're merginginto a subgraph, `existing_subgraph` is // the subgraph node that we're merging into if it exists ValueMapper( Node* to_merge, AliasDb& db, c10::optional existing_subgraph) { last_uses_ = gatherLastUses(to_merge->outputs()); if (existing_subgraph) { existing_last_uses_ = gatherLastUses((*existing_subgraph)->outputs()); } WithInsertPoint guard(to_merge); auto g = to_merge->owningGraph(); // temporary node to put the aliasing properties of the node before its // merged and destroyed placeholder_node_ = g->insertNode(g->create(prim::Uninitialized, 0)); for (size_t i = 0; i < to_merge->outputs().size(); ++i) { Value* existing = to_merge->outputs().at(i); Value* new_value = placeholder_node_->insertOutput(i)->copyMetadata( to_merge->outputs().at(i)); db.replaceWithNewValue(existing, new_value); } } bool usesEqual(const Use& a, const Use& b) { return a.user == b.user && a.offset == b.offset; } void copyAliasing(Node* merged_node, AliasDb& db) { auto new_outputs = merged_node->outputs(); for (Value* v : new_outputs) { auto maybe_last_use = firstOrLastUse(v, /*find_first*/ false); // if it doesnt have a use it shouldnt have been added as output TORCH_INTERNAL_ASSERT(maybe_last_use); const Use last_use = *maybe_last_use; // existing outputs of the subgraph do not need to have alias db mappings // updated bool is_existing_value = false; for (size_t i = 0; i < existing_last_uses_.size() && !is_existing_value; ++i) { is_existing_value = existing_last_uses_[i].has_value() && usesEqual(*existing_last_uses_[i], last_use); } if (is_existing_value) { continue; } size_t i = 0; while (i < last_uses_.size() && last_uses_.at(i).has_value() && !usesEqual(*last_uses_.at(i), last_use)) { ++i; } TORCH_INTERNAL_ASSERT(i != last_uses_.size()); db.replaceWithNewValue(placeholder_node_->outputs().at(i), v); } placeholder_node_->destroy(); } std::vector> last_uses_; std::vector> existing_last_uses_; Node* placeholder_node_; }; Node* executeSubgraphMergeAndUpdateAliasing( Node* to_merge, c10::optional existing, AliasDb& db, const std::function& merge_fn) { // When we merge a node into a subgraph, the new subgraph outputs // have the same aliasing properties as the original node's outputs. // Here we create a placeholder node, transfer the aliasing properties // to the placeholder, execute the merge, and transfer the aliasing // properties to the appropriate fusion group outputs ValueMapper vm(to_merge, db, existing); Node* fusion_group = merge_fn(); vm.copyAliasing(fusion_group, db); return fusion_group; } // Combine the nodes in two subgraph together. The nodes will end up in // `mergeTo`, and `mergeFrom` is destroyed. void mergeSubgraph(Node* mergeTo, Node* mergeFrom) { bool merge_from_is_after = mergeFrom->isAfter(mergeTo); Node* nodeBeforeMergeFrom = mergeFrom->prev(); Node* nodeAfterMergeFrom = mergeFrom->next(); unmergeSubgraph(mergeFrom); graph_node_list_iterator end_it; graph_node_list_iterator it; if (merge_from_is_after) { it = nodeBeforeMergeFrom->iterator(); end_it = nodeAfterMergeFrom->iterator(); } else { end_it = nodeBeforeMergeFrom->reverseIterator(); it = nodeAfterMergeFrom->reverseIterator(); } ++it; std::vector merged_nodes; while (it != end_it) { Node* node = *it; ++it; mergeNodeIntoSubgraph(node, mergeTo); } } struct topo_cmp_value { bool operator()(Value* a, Value* b) const { if (a->node() == b->node()) { return a->unique() < b->unique(); } return a->node()->isBefore(b->node()); } }; struct topo_cmp_node { bool operator()(Node* a, Node* b) const { return a->isBefore(b); } }; void collectNodesToUnfuse(Node* start, std::set& s) { if (start->kind() == prim::Return || start->kind() == prim::Param) { GRAPH_DEBUG("reached the param or return node", getHeader(start)); return; } if (s.count(start) != 0) { // already visited, no need to visit descendants return; } GRAPH_DEBUG("collectNodesToUnfuse: inserting node ", getHeader(start)); s.insert(start); for (auto o : start->outputs()) { for (auto use : o->uses()) { collectNodesToUnfuse(use.user, s); } } } std::vector> buildAliasedSets( std::shared_ptr subgraph) { auto outputs = subgraph->outputs(); AliasDb alias_db(subgraph); TORCH_INTERNAL_ASSERT(outputs.size() > 1); std::vector> res; for (auto o : outputs) { auto grouped = false; for (auto& s : res) { auto os = *s.begin(); auto aliased = alias_db.mayContainAlias(os, o); GRAPH_DEBUG( "comparing %", o->debugName(), " with %", os->debugName(), " result ", aliased); if (aliased) { s.insert(o); GRAPH_DEBUG("Grouping %", o->debugName(), " with %", os->debugName()); grouped = true; } } if (!grouped) { res.push_back({o}); } } return res; } } // namespace std::shared_ptr getSubgraph(Node* n) { return n->g(attr::Subgraph); } void unmergeSubgraph(Node* subgraphNode) { // Inline the graph, replace uses of node outputs and destroy the node auto outerGraph = subgraphNode->owningGraph(); WithInsertPoint guard(subgraphNode); const auto subgraphOutputs = insertGraph( *outerGraph, *getSubgraph(subgraphNode), subgraphNode->inputs()); AT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size()); for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) { subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]); } subgraphNode->destroy(); } void collectNestedUses( std::unordered_set& closed_over_values, std::unordered_set& new_values, std::unordered_map& externalValuesMap, Node* input_node) { for (auto input : input_node->inputs()) { if (externalValuesMap.count(input) == 0 && new_values.count(input) == 0) { closed_over_values.insert(input); } } if (input_node->kind() == prim::If) { for (Block* block : input_node->blocks()) { for (Node* node : block->nodes()) { collectNestedUses( closed_over_values, new_values, externalValuesMap, node); } for (Value* v : block->outputs()) { if (externalValuesMap.count(v) == 0 && new_values.count(v) == 0) { closed_over_values.insert(v); } } } } else if (input_node->kind() == prim::Loop) { for (Value* v : input_node->inputs()) { if (externalValuesMap.count(v) == 0 && new_values.count(v) == 0) { closed_over_values.insert(v); } } Block* block = input_node->blocks().at(0); for (Value* v : block->inputs()) { new_values.insert(v); } for (Node* node : block->nodes()) { collectNestedUses( closed_over_values, new_values, externalValuesMap, node); } } else if (input_node->blocks().size() != 0) { TORCH_INTERNAL_ASSERT(false, input_node, " kind not handled yet"); } for (Value* output : input_node->outputs()) { new_values.insert(output); } } std::unordered_set closedOverValues( Node* toMerge, std::unordered_map& externalValuesMap) { std::unordered_set closed_over_values; std::unordered_set new_values; collectNestedUses(closed_over_values, new_values, externalValuesMap, toMerge); return closed_over_values; } void mergeNodeIntoSubgraph( Node* toMerge, Node* subgraphNode, bool destroyNode) { AT_ASSERT(hasSubgraph(subgraphNode) && toMerge != subgraphNode); if (hasSubgraph(toMerge)) { return mergeSubgraph(subgraphNode, toMerge); } auto subgraph = getSubgraph(subgraphNode); // Map from values in the surrounding graph to inputs/outputs in the subgraph std::unordered_map externalValuesMap; AT_ASSERT(subgraphNode->inputs().size() == subgraph->inputs().size()); size_t idx = 0; for (auto input : subgraphNode->inputs()) { externalValuesMap[input] = subgraph->inputs()[idx]; idx++; } for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) { externalValuesMap[subgraphNode->outputs().at(i)] = subgraph->outputs().at(i); } // Add n's inputs to the group's input list if we don't already have them bool merging_node_after_subgraph = toMerge->isAfter(subgraphNode); Node* guard_node = merging_node_after_subgraph ? *subgraph->nodes().end() : *subgraph->nodes().begin(); WithInsertPoint guard(guard_node); std::unordered_set closedValues = closedOverValues(toMerge, externalValuesMap); // There are currently downstream usage that relies on a fixed ordering // of graph inputs. TODO: remove std::vector orderedClosedValues; std::unordered_set orderedSeenValues; for (Value* input : toMerge->inputs()) { orderedClosedValues.push_back(input); orderedSeenValues.insert(input); } for (Value* closedValue : closedValues) { if (!orderedSeenValues.count(closedValue)) { orderedClosedValues.push_back(closedValue); orderedSeenValues.insert(closedValue); } } for (auto input : orderedClosedValues) { if (externalValuesMap.count(input) == 0) { // Clone constants inside the subgraph instead of referencing them, to // enable more optimizations if (auto value = toIValue(input)) { auto nv = subgraph->insertConstant(*value); nv->copyMetadata(input); externalValuesMap[input] = nv; } else { // The common case: this is a regular input, so just register it with // the group node and inner subgraph subgraphNode->addInput(input); auto inputToGraph = subgraph->addInput(); inputToGraph->copyMetadata(input); externalValuesMap[input] = inputToGraph; } } } // Merge the node into the graph auto mergedNode = subgraph->insertNode(subgraph->createClone( toMerge, [&](Value* v) { return externalValuesMap[v]; })); if (!merging_node_after_subgraph) { // If n's outputs were inputs to `group`, remove them since we just merged // n in. // // i.e., // x = f(w); group(x, y, z) becomes group(w, y, z). // x, y, z = f(w); group(x, y, z) becomes group(w). auto inputs = subgraphNode->inputs(); for (size_t i = 0; i < toMerge->outputs().size(); ++i) { auto it = std::find(inputs.begin(), inputs.end(), toMerge->outputs()[i]); if (it != inputs.end()) { size_t p = it - inputs.begin(); subgraphNode->removeInput(p); subgraph->inputs()[p]->replaceAllUsesWith(mergedNode->outputs()[i]); subgraph->eraseInput(p); } } } // Add n's outputs to the group node and inner subgraph outputs. for (const auto i : c10::irange(toMerge->outputs().size())) { auto oldOutput = toMerge->outputs()[i]; auto newOutput = mergedNode->outputs()[i]; subgraph->registerOutput(newOutput); auto groupOutput = subgraphNode->addOutput(); groupOutput->copyMetadata(oldOutput); oldOutput->replaceAllUsesWith(groupOutput); } // Remove the original node now that the merge is complete if (destroyNode) { toMerge->destroy(); } // We wait till destroying `toMerge` before pruning subgraph outputs, // since destroying `toMerge` could cause a subgraph output to no longer // have any uses const auto hasUsesOutsideSubgraph = [&](Value* v) { return std::any_of( v->uses().cbegin(), v->uses().cend(), [&](const Use& use) { return use.user->isAfter(subgraphNode); }); }; for (int64_t i = subgraphNode->outputs().size() - 1; i >= 0; i--) { if (!hasUsesOutsideSubgraph(subgraphNode->outputs().at(i))) { subgraphNode->eraseOutput(i); subgraph->eraseOutput(i); } } } Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) { auto graph = n->owningGraph(); auto subgraph = graph->create(subgraphKind, 0); subgraph->g_(attr::Subgraph, std::make_shared(graph->current_scope())); subgraph->insertBefore(n); mergeNodeIntoSubgraph(n, subgraph); return subgraph; } void mergeNodeIntoSubgraphAndUpdateAliasing( Node* to_merge, Node* subgraphNode, AliasDb& db) { executeSubgraphMergeAndUpdateAliasing(to_merge, subgraphNode, db, [&]() { mergeNodeIntoSubgraph(to_merge, subgraphNode); return subgraphNode; }); } Node* createSingletonSubgraphAndUpdateAliasing( Node* to_merge, Symbol subgraphKind, AliasDb& db) { return executeSubgraphMergeAndUpdateAliasing( to_merge, c10::nullopt, db, [&]() { return createSingletonSubgraph(to_merge, subgraphKind); }); } bool unmergeOutputsAlisingInputs(Node* subgraphNode) { GRAPH_DEBUG("unfuseOutputsAlisingInputs on ", getHeader(subgraphNode)); auto subgraph = subgraphNode->g(attr::Subgraph); AliasDb alias_db(subgraph); std::set nodes; for (auto o : subgraph->outputs()) { if (alias_db.mayContainAlias({o}, subgraph->inputs())) { collectNodesToUnfuse(o->node(), nodes); } } // unfuse in the reverse topo order for (auto it = nodes.rbegin(); it != nodes.rend(); it++) { SubgraphUtils::unmergeNode(*it, subgraphNode); } return !nodes.empty(); } bool unmergeAliasedOutputs(Node* subgraphNode) { GRAPH_DEBUG("unfuseAliasedOutputs on ", getHeader(subgraphNode)); if (subgraphNode->outputs().size() < 2) { return false; } auto subgraph = subgraphNode->g(attr::Subgraph); GRAPH_DUMP("unfuseAliasedOutputs Subgraph ", subgraph); auto sets = buildAliasedSets(subgraph); GRAPH_DEBUG("buildAliasedSets sets.size() = ", sets.size()); std::set nodes; for (auto i : c10::irange(sets.size())) { if (sets[i].size() <= 1) { GRAPH_DEBUG( "Set ", i, " with leader ", (*(sets[i].begin()))->debugName(), " size = ", sets[i].size()); continue; } // we have at least two aliased outputs // we skip the earliest node w.r.t. the topo order // NB. after some nodes are unfused, the outputs of some other nodes // may become the outputs of the subgraph and alias the remaining ones // so we have to re-run this function until there are no more changes auto it = ++sets[i].begin(); while (it != sets[i].end()) { GRAPH_DEBUG( "root aliased value ", (*it)->debugName(), " node ", *(*it)->node()); collectNodesToUnfuse((*it)->node(), nodes); it++; } } // unfuse in the reverse topo order for (auto it = nodes.rbegin(); it != nodes.rend(); it++) { unmergeNode(*it, subgraphNode); } return !nodes.empty(); } void unmergeNode(Node* n, Node* subgraphNode) { // collect output indices GRAPH_DEBUG("unfuseNode node ", getHeader(n)); auto subgraph = n->owningGraph(); std::set node_outputs(n->outputs().begin(), n->outputs().end()); std::set output_indices; std::set node_inputs(n->inputs().begin(), n->inputs().end()); std::unordered_map local_map; auto env = [&](Value* v) { auto it = local_map.find(v); if (it != local_map.end()) { return it->second; } TORCH_INTERNAL_ASSERT( false, "all inputs should've been mapped. Couldn't map %", v->debugName()); return v; }; for (auto i : c10::irange(subgraph->outputs().size())) { if (node_outputs.count(subgraph->outputs().at(i)) != 0) { output_indices.insert(i); } if (node_inputs.count(subgraph->outputs().at(i)) != 0) { GRAPH_DEBUG( "output %", subgraph->outputs().at(i)->debugName(), " is already subgraph's output"); GRAPH_DEBUG( "Mapping %", subgraph->outputs().at(i)->debugName(), " to %", subgraphNode->outputs().at(i)->debugName()); local_map[subgraph->outputs().at(i)] = subgraphNode->outputs().at(i); node_inputs.erase(subgraph->outputs().at(i)); } } WithInsertPoint wip(subgraphNode->next()); // these node inputs need to be added to subgraph's outputs // put them in vmap for (auto ni : node_inputs) { if (local_map.count(ni) != 0) { // this could happen if `n` uses two or more outputs // of a constant node and we already cloned the constant // into the outer graph and mapped its outputs continue; } Value* sno = nullptr; if (ni->node()->kind() == prim::Constant) { auto copy = subgraphNode->owningGraph()->createClone(ni->node(), env); subgraphNode->owningGraph()->insertNode(copy); // in case we have a multi-output const, map the rest of the outputs // so when we get to clone `n`, `n`'s clone will use the outputs of this // constant clone for (auto i : c10::irange(n->outputs().size())) { GRAPH_DEBUG( "Mapping %", ni->node()->output(i)->debugName(), " to %", copy->output(i)->debugName()); local_map[ni->node()->output(i)] = copy->output(i); } } else { subgraph->registerOutput(ni); sno = subgraphNode->addOutput(); sno->setType(ni->type()); GRAPH_DEBUG("Mapping %", ni->debugName(), " to %", sno->debugName()); local_map[ni] = sno; } } auto copy = subgraphNode->owningGraph()->createClone(n, env); GRAPH_DEBUG("copy ", *copy); for (auto i : c10::irange(n->outputs().size())) { auto oo = n->outputs()[i]; auto no = copy->outputs()[i]; no->copyMetadata(oo); GRAPH_DEBUG("Mapping %", oo->debugName(), " to %", no->debugName()); local_map[oo] = no; } subgraphNode->owningGraph()->insertNode(copy); for (auto it = output_indices.rbegin(); it != output_indices.rend(); it++) { auto replace_val = local_map[subgraph->outputs().at(*it)]; subgraphNode->outputs().at(*it)->replaceAllUsesWith(replace_val); subgraphNode->eraseOutput(*it); subgraph->eraseOutput(*it); } n->destroy(); } std::string truncateStrWithHash(const std::string& s, size_t maxlen) { if (s.size() <= maxlen) { return s; } std::string hash_str = c10::to_string(c10::hash{}(s)); // If hash-string plus '_' can fit into maxlen, then truncate the original // string correspondingly so that the final string with the hash included fits // into maxlen. If that's not possible, at least truncate the original string // to maxlen (and appen the hash to it). size_t trunc_len = (maxlen > hash_str.size() + 1) ? (maxlen - hash_str.size() - 1) : maxlen; std::stringstream truncated; truncated << s.substr(0, trunc_len); truncated << "_" << hash_str; return truncated.str(); } std::string generateNameForGraph( const std::shared_ptr& graph, size_t maxlen, const std::string& prefix) { std::stringstream graph_name; graph_name << prefix; for (Node* node : graph->nodes()) { if (!node->kind().is_aten()) { continue; } graph_name << "_" << node->kind().toUnqualString(); } return truncateStrWithHash(graph_name.str(), maxlen); } } // namespace SubgraphUtils } // namespace jit } // namespace torch