mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert] oss subgraph rewriter (#160780)
Summary: att Test Plan: ci Rollback Plan: Differential Revision: D80367765 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160780 Approved by: https://github.com/SherlockNoMad, https://github.com/georgiaphillips
This commit is contained in:
@ -631,6 +631,7 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/kernels/NativeKernels.cpp",
|
||||
"torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp",
|
||||
"torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp",
|
||||
"torch/nativert/graph/passes/SubgraphRewriter.cpp",
|
||||
]
|
||||
|
||||
torch_mobile_tracer_sources = [
|
||||
|
@ -36,6 +36,7 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp
|
||||
)
|
||||
|
||||
add_executable(test_nativert
|
||||
|
447
torch/nativert/graph/passes/SubgraphRewriter.cpp
Normal file
447
torch/nativert/graph/passes/SubgraphRewriter.cpp
Normal file
@ -0,0 +1,447 @@
|
||||
#include <variant>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
#include <torch/nativert/graph/passes/SubgraphRewriter.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
const std::string kDummyTarget = "dummy";
|
||||
|
||||
//-------------------------
|
||||
// SubgraphMatcher
|
||||
//-------------------------
|
||||
|
||||
SubgraphMatcher::SubgraphMatcher(const Graph* pattern)
|
||||
: pattern_(pattern), pattern_root_(findRootNode(pattern_)) {}
|
||||
|
||||
const Node* SubgraphMatcher::findRootNode(const Graph* g) {
|
||||
return g->outputNode()->inputs()[0].value->producer();
|
||||
}
|
||||
|
||||
std::optional<Match> SubgraphMatcher::match(Node* target_node) {
|
||||
if (!pattern_root_) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Match current_match;
|
||||
if (tryMatchNode(pattern_root_, target_node, current_match)) {
|
||||
for (const Value* output : pattern_->outputs()) {
|
||||
TORCH_CHECK(
|
||||
current_match.value_map.find(output) != current_match.value_map.end(),
|
||||
"Not all outputs were matched to the pattern. ",
|
||||
"Please check that the first output node suffices ",
|
||||
"to traverse all output values in the pattern.");
|
||||
}
|
||||
return current_match;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::vector<Match> SubgraphMatcher::matchAll(Graph* graph) {
|
||||
std::vector<Match> matches;
|
||||
|
||||
for (auto& node : graph->nodes()) {
|
||||
auto maybeMatch = match(&node);
|
||||
if (maybeMatch.has_value()) {
|
||||
matches.push_back(*maybeMatch);
|
||||
}
|
||||
}
|
||||
return matches;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool compareConstants(const Constant& a, const Constant& b) {
|
||||
return std::visit(
|
||||
[](const auto& lhs, const auto& rhs) -> bool {
|
||||
using LType = std::decay_t<decltype(lhs)>;
|
||||
using RType = std::decay_t<decltype(rhs)>;
|
||||
|
||||
// Handle directly comparable types
|
||||
if constexpr (
|
||||
std::is_same_v<LType, RType> &&
|
||||
!std::is_same_v<LType, std::unique_ptr<Graph>>) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
// Unsupported types (Graph)
|
||||
LOG(ERROR) << "Unsupported Constant types for pattern matching: "
|
||||
<< typeid(lhs).name() << " vs " << typeid(rhs).name();
|
||||
throw std::runtime_error("Unsupported Constant types.");
|
||||
},
|
||||
a,
|
||||
b);
|
||||
}
|
||||
|
||||
auto findMatchingAttribute(const Node* target_node, const Attribute& attr) {
|
||||
return std::find_if(
|
||||
target_node->attributes().begin(),
|
||||
target_node->attributes().end(),
|
||||
[&](const Attribute& otherAttr) {
|
||||
return attr.name == otherAttr.name &&
|
||||
compareConstants(attr.value, otherAttr.value);
|
||||
});
|
||||
}
|
||||
|
||||
auto findInputByName(const Node* pattern_node, const std::string& inputName) {
|
||||
return std::find_if(
|
||||
pattern_node->inputs().begin(),
|
||||
pattern_node->inputs().end(),
|
||||
[&](const NamedArgument& patternInput) {
|
||||
return inputName == patternInput.name;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool SubgraphMatcher::tryMatchNodeInputs(
|
||||
const Node* pattern_node,
|
||||
Node* target_node,
|
||||
Match& match) {
|
||||
TORCH_CHECK(
|
||||
pattern_node->numInputs() + pattern_node->attributes().size() ==
|
||||
target_node->numInputs() + target_node->attributes().size());
|
||||
TORCH_CHECK(target_node->numInputs() <= pattern_node->numInputs());
|
||||
TORCH_CHECK(pattern_node->attributes().size() <= target_node->numInputs());
|
||||
|
||||
// Target node inputs should match pattern node inputs
|
||||
for (const auto i : c10::irange(target_node->numInputs())) {
|
||||
// Compare input values
|
||||
// Current target node input should match a pattern node input
|
||||
const auto& inputMatch =
|
||||
findInputByName(pattern_node, target_node->inputs()[i].name);
|
||||
if (inputMatch == pattern_node->inputs().end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Value* pval = inputMatch->value;
|
||||
Value* tval = target_node->inputs()[i].value;
|
||||
if (!tryMatchValue(pval, tval, match)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern node attributes should match target node attributes
|
||||
std::unordered_set<std::string> matched_attributes;
|
||||
for (const auto i : c10::irange(pattern_node->attributes().size())) {
|
||||
// Compare attributes
|
||||
const auto& attr = pattern_node->attributes()[i];
|
||||
auto it = findMatchingAttribute(target_node, attr);
|
||||
if (it == target_node->attributes().end()) {
|
||||
return false; // Attribute not found or values differ
|
||||
}
|
||||
matched_attributes.insert(it->name);
|
||||
}
|
||||
|
||||
// Target node attributes that do not match pattern node attributes should
|
||||
// match pattern node inputs
|
||||
for (const auto i : c10::irange(target_node->attributes().size())) {
|
||||
const auto& it = target_node->attributes()[i];
|
||||
if (matched_attributes.find(it.name) != matched_attributes.end()) {
|
||||
continue; // Skip attributes already matched
|
||||
}
|
||||
const auto& patternInput = findInputByName(pattern_node, it.name);
|
||||
if (patternInput == pattern_node->inputs().end()) {
|
||||
return false;
|
||||
}
|
||||
if (patternInput->value->producer()->target() != "prim.Input" ||
|
||||
patternInput->value->users().size() > 1) {
|
||||
return false; // Only a pattern graph input should match a constant attr
|
||||
}
|
||||
|
||||
// Insert a dummy node to match the pattern input value
|
||||
// Record the attribute that should be used to replace the dummy node
|
||||
auto* targetGraph = target_node->owningGraph();
|
||||
Node* dummyNode = targetGraph->createNode(kDummyTarget);
|
||||
Value* dummyOutput = dummyNode->addOutput(
|
||||
targetGraph->getUniqueValueName(), Type::Kind::None);
|
||||
targetGraph->insertBefore(dummyNode, target_node);
|
||||
if (match.value_map.find(patternInput->value) != match.value_map.end()) {
|
||||
return match.value_map[patternInput->value]->producer()->target() ==
|
||||
kDummyTarget;
|
||||
}
|
||||
match.value_map[patternInput->value] = dummyOutput;
|
||||
match.dummy_input_to_attribute_map[dummyOutput] = &it.value;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SubgraphMatcher::tryMatchNode(
|
||||
const Node* pattern_node,
|
||||
Node* target_node,
|
||||
Match& match) {
|
||||
if (match.node_map.find(pattern_node) != match.node_map.end()) {
|
||||
return match.node_map[pattern_node] == target_node;
|
||||
}
|
||||
|
||||
// If the pattern node is an input, it should match every node
|
||||
if (pattern_node->target() == "prim.Input") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (pattern_node->target() != target_node->target() ||
|
||||
pattern_node->numOutputs() != target_node->numOutputs()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t deltaInputCount = static_cast<int64_t>(pattern_node->numInputs()) -
|
||||
static_cast<int64_t>(target_node->numInputs());
|
||||
int64_t deltaAttributesCount =
|
||||
static_cast<int64_t>(pattern_node->attributes().size()) -
|
||||
static_cast<int64_t>(target_node->attributes().size());
|
||||
// Number of inputs and attributes should match exactly
|
||||
// and the pattern should always have >= input count of the target node
|
||||
// and the pattern should always have <= attribute count of the target node
|
||||
if (deltaInputCount + deltaAttributesCount != 0 ||
|
||||
(deltaInputCount < 0 && deltaAttributesCount > 0)) {
|
||||
return false;
|
||||
}
|
||||
match.node_map[pattern_node] = target_node;
|
||||
|
||||
for (const auto i : c10::irange(pattern_node->numOutputs())) {
|
||||
const Value* pval = pattern_node->outputs()[i];
|
||||
Value* tval = target_node->outputs()[i];
|
||||
if (!tryMatchValue(pval, tval, match)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return tryMatchNodeInputs(pattern_node, target_node, match);
|
||||
}
|
||||
|
||||
bool SubgraphMatcher::isOutputValue(const Value* val) {
|
||||
for (const auto& output : pattern_->outputs()) {
|
||||
if (val == output) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool SubgraphMatcher::tryMatchValue(
|
||||
const Value* pval,
|
||||
Value* tval,
|
||||
Match& match) {
|
||||
if (match.value_map.find(pval) != match.value_map.end()) {
|
||||
return match.value_map[pval] == tval;
|
||||
}
|
||||
|
||||
const Node* pProducer = pval->producer();
|
||||
Node* tProducer = tval->producer();
|
||||
// If the value in the pattern is an input, then it could have other uses
|
||||
// outside of the subgraph. Similarly, output values can also have uses
|
||||
// outside of the matching subgraph.
|
||||
if (pval->users().size() != tval->users().size() &&
|
||||
pProducer->target() != "prim.Input" && !isOutputValue(pval)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (pval->type().kind() != tval->type().kind()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
match.value_map[pval] = tval;
|
||||
|
||||
return tryMatchNode(pProducer, tProducer, match);
|
||||
}
|
||||
|
||||
//-------------------------
|
||||
// SubgraphRewriter
|
||||
//-------------------------
|
||||
|
||||
void SubgraphRewriter::registerRewritePattern(
|
||||
const std::string& pattern,
|
||||
const std::string& replacement) {
|
||||
patterns_.emplace_back(RewriteRule{pattern, replacement});
|
||||
}
|
||||
|
||||
bool SubgraphRewriter::run(
|
||||
Graph* graph,
|
||||
const std::vector<MatchFilter>& filters) {
|
||||
bool mutated = false;
|
||||
for (const auto& [pattern, replacement] : patterns_) {
|
||||
const auto& pattern_graph = stringToGraph(pattern);
|
||||
const auto& replacement_graph = stringToGraph(replacement);
|
||||
mutated |= runForPattern(
|
||||
graph, *pattern_graph.get(), *replacement_graph.get(), filters);
|
||||
}
|
||||
return mutated;
|
||||
}
|
||||
|
||||
bool SubgraphRewriter::runForPattern(
|
||||
Graph* graph,
|
||||
const Graph& pattern,
|
||||
const Graph& replacement,
|
||||
const std::vector<MatchFilter>& filters) {
|
||||
SubgraphMatcher matcher(&pattern);
|
||||
std::vector<Match> matches = matcher.matchAll(graph);
|
||||
|
||||
VLOG(1) << "[GraphPasses] Found " << matches.size()
|
||||
<< " matches for : " << name_;
|
||||
|
||||
for (auto& m : matches) {
|
||||
if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
|
||||
return f(m, getVmap(pattern));
|
||||
})) {
|
||||
continue;
|
||||
}
|
||||
if (!overlapsWithUsedNodes(m, replacedNodes_)) {
|
||||
rewriteMatch(graph, m, pattern, replacement);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* v : valuesToRewrite_) {
|
||||
graph->replaceAllUses(v, valueRewrites_.at(v));
|
||||
}
|
||||
|
||||
for (auto* n : replacedNodes_) {
|
||||
for (const auto& input : n->inputs()) {
|
||||
input.value->eraseUser(n);
|
||||
}
|
||||
n->inputs().clear();
|
||||
}
|
||||
|
||||
for (auto* n : replacedNodes_) {
|
||||
n->destroy();
|
||||
}
|
||||
|
||||
bool mutated = (valuesToRewrite_.size() + valueRewrites_.size() +
|
||||
replacedNodes_.size()) > 0;
|
||||
|
||||
valuesToRewrite_.clear();
|
||||
valueRewrites_.clear();
|
||||
replacedNodes_.clear();
|
||||
|
||||
graph->cleanupDeadNodes();
|
||||
graph->finalize();
|
||||
graph->lint();
|
||||
|
||||
return mutated;
|
||||
}
|
||||
|
||||
bool SubgraphRewriter::overlapsWithUsedNodes(
|
||||
const Match& match,
|
||||
const std::unordered_set<Node*>& usedNodes) {
|
||||
// If any node or value used by this match is already in usedNodes/usedValues,
|
||||
// then this match overlaps with a previously selected match.
|
||||
for (auto& kv : match.node_map) {
|
||||
Node* target_node = kv.second;
|
||||
if (usedNodes.find(target_node) != usedNodes.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SubgraphRewriter::rewriteMatch(
|
||||
Graph* graph,
|
||||
const Match& match,
|
||||
const Graph& pattern,
|
||||
const Graph& replacement) {
|
||||
// TODO: Preserve original node metadata with python source traceback
|
||||
std::unordered_map<const Value*, Value*> valueMap;
|
||||
|
||||
// Find the point at which to insert the new subgraph
|
||||
// and get pointers to input/output values to insert at
|
||||
Node* insertionPoint = nullptr;
|
||||
std::vector<Value*> inputs, outputs;
|
||||
for (Value* v : pattern.inputs()) {
|
||||
if (match.value_map.find(v) == match.value_map.end()) {
|
||||
continue;
|
||||
}
|
||||
Value* input = match.value_map.at(v);
|
||||
// We want to insert after latest producer of any input that is not a dummy
|
||||
// node
|
||||
if (!insertionPoint ||
|
||||
(insertionPoint->isBefore(input->producer()) &&
|
||||
input->producer()->target() != kDummyTarget)) {
|
||||
insertionPoint = input->producer();
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
TORCH_CHECK(insertionPoint, "No insertion point found");
|
||||
|
||||
// Check we're not inserting after any of the outputs
|
||||
bool insertionPointValid = true;
|
||||
for (const auto* v : pattern.outputs()) {
|
||||
Value* output = match.value_map.at(v);
|
||||
outputs.push_back(match.value_map.at(v));
|
||||
for (const auto* user : output->users()) {
|
||||
if (user->isBefore(insertionPoint)) {
|
||||
insertionPointValid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!insertionPointValid) {
|
||||
return;
|
||||
}
|
||||
std::vector<Value*> newOutputs;
|
||||
{
|
||||
InsertingAfter guard(insertionPoint);
|
||||
|
||||
newOutputs = graph->insertGraph(replacement, inputs, valueMap);
|
||||
}
|
||||
TORCH_CHECK(outputs.size() == newOutputs.size());
|
||||
|
||||
for (auto i : c10::irange(outputs.size())) {
|
||||
valuesToRewrite_.push_back(outputs[i]);
|
||||
valueRewrites_[outputs[i]] = newOutputs[i];
|
||||
}
|
||||
|
||||
for (auto& patternNode : pattern.nodes()) {
|
||||
if (match.node_map.find(&patternNode) != match.node_map.end()) {
|
||||
Node* n = match.node_map.at(&patternNode);
|
||||
replacedNodes_.insert(n);
|
||||
}
|
||||
}
|
||||
|
||||
// Replace dummy values with constant attributes
|
||||
for (const auto& inputToAttr : match.dummy_input_to_attribute_map) {
|
||||
auto* dummy = inputToAttr.first;
|
||||
// dummy might not be used in rewritten graph
|
||||
// e.g., casted_batch_one_hot_lengths
|
||||
if (dummy->users().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto& userNode : dummy->users()) {
|
||||
auto& userInputs = userNode->inputs();
|
||||
replacedNodes_.insert(dummy->producer());
|
||||
for (auto it = userInputs.begin(); it != userInputs.end(); ++it) {
|
||||
if (it->value == dummy) {
|
||||
Attribute newAttr;
|
||||
std::visit(
|
||||
[&](auto&& val) -> void {
|
||||
using T = std::decay_t<decltype(val)>;
|
||||
if constexpr (std::is_same_v<T, std::unique_ptr<Graph>>) {
|
||||
LOG(ERROR)
|
||||
<< "Graph attributes are not supported yet. Skipping attribute";
|
||||
} else {
|
||||
newAttr.value = val;
|
||||
}
|
||||
},
|
||||
*inputToAttr.second);
|
||||
newAttr.name = it->name;
|
||||
userNode->addAttribute(std::move(newAttr));
|
||||
dummy->eraseUser(userNode);
|
||||
userInputs.erase(it);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c10::FastMap<std::string, const Value*> SubgraphRewriter::getVmap(
|
||||
const Graph& pattern) {
|
||||
c10::FastMap<std::string, const Value*> vmap;
|
||||
for (const auto& v : pattern.inputs()) {
|
||||
vmap[std::string(v->name())] = v;
|
||||
}
|
||||
for (const auto& n : pattern.nodes()) {
|
||||
for (const Value* v : n.outputs()) {
|
||||
vmap[std::string(v->name())] = v;
|
||||
}
|
||||
}
|
||||
return vmap;
|
||||
}
|
||||
} // namespace torch::nativert
|
198
torch/nativert/graph/passes/SubgraphRewriter.h
Normal file
198
torch/nativert/graph/passes/SubgraphRewriter.h
Normal file
@ -0,0 +1,198 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/FbcodeMaps.h>
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
/*
|
||||
* node_map: A map from nodes in the pattern to nodes in the actual graph.
|
||||
* value_map : A map between values in the pattern to values in the actual
|
||||
* graph.
|
||||
* dummy_input_to_attribute_map: A map between the actual dummy input values to
|
||||
* constant attributes in the actual graph that should replace the dummy nodes
|
||||
*/
|
||||
struct Match {
|
||||
std::unordered_map<const Node*, Node*> node_map;
|
||||
std::unordered_map<const Value*, Value*> value_map;
|
||||
std::unordered_map<Value*, const Constant*>
|
||||
dummy_input_to_attribute_map; // For constant attrs matching graph inputs
|
||||
};
|
||||
|
||||
using MatchFilter = std::function<
|
||||
bool(const Match&, const c10::FastMap<std::string, const Value*>&)>;
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Match& match) {
|
||||
out << "\nNode mapping:\n";
|
||||
for (const auto& kv : match.node_map) {
|
||||
const Node* patternNode = kv.first;
|
||||
Node* targetNode = kv.second;
|
||||
out << " Pattern Node: " << *patternNode
|
||||
<< " -> Target Node: " << *targetNode << "\n";
|
||||
}
|
||||
|
||||
out << "Value mapping:\n";
|
||||
for (const auto& kv : match.value_map) {
|
||||
const Value* patternValue = kv.first;
|
||||
Value* targetValue = kv.second;
|
||||
out << " Pattern Value: " << *patternValue
|
||||
<< " -> Target Value: " << *targetValue << "\n";
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class for matching a subgraph pattern within a larger graph.
|
||||
* It attempts to match a given `pattern` graph inside a target `graph`,
|
||||
* starting from a single "root" output node in the pattern graph. The
|
||||
* matching process works backward through the graph, comparing each node
|
||||
* in the pattern to corresponding nodes in the candidate graph.
|
||||
*
|
||||
* Note: This implementation currently only supports deterministic matching
|
||||
* for patterns with one output node. It also only matches nodes connecting to
|
||||
* output nodes
|
||||
*
|
||||
* Constraints for Patterns with Multiple Output Nodes:
|
||||
* To avoid an exponential increase in the search space, this implementation
|
||||
* starts searching from the first output node as an anchor as an heuristic. It
|
||||
* assumes that all other output nodes in the pattern are interconnected through
|
||||
* the graph from this anchor node, allowing the matcher to traverse from the
|
||||
* anchor to other outputs.
|
||||
*
|
||||
* Important: The order of output nodes in the pattern matters. For example:
|
||||
*
|
||||
* graph(%x):
|
||||
* %a = a.aaa(input=%x)
|
||||
* %b = b.bbb(input=%a)
|
||||
* return (%a, %b)
|
||||
*
|
||||
* If the search starts from %a, it will not explore the portion of the graph
|
||||
* connected to %b. However, if the order is switched:
|
||||
*
|
||||
* graph(%x):
|
||||
* %a = a.aaa(input=%x)
|
||||
* %b = b.bbb(input=%a)
|
||||
* return (%b, %a)
|
||||
*
|
||||
* The search will start from %b and successfully explore both %b and %a.
|
||||
*/
|
||||
class SubgraphMatcher {
|
||||
public:
|
||||
explicit SubgraphMatcher(const Graph* pattern);
|
||||
|
||||
/// Attempt to match the pattern at a given node in the target graph.
|
||||
/// If successful, returns a Match, otherwise std::nullopt.
|
||||
std::optional<Match> match(Node* target_node);
|
||||
|
||||
std::vector<Match> matchAll(Graph* target_graph);
|
||||
|
||||
private:
|
||||
const Graph* pattern_;
|
||||
const Node* pattern_root_;
|
||||
|
||||
/**
|
||||
* Finds the root output node of a Graph g to start a match from
|
||||
* Note that graphs with multiple output nodes, this will pick the first
|
||||
* output node in the order provided.
|
||||
**/
|
||||
const Node* findRootNode(const Graph* g);
|
||||
|
||||
/**
|
||||
* Tries to match nodes in the pattern_ graph with the target graph, starting
|
||||
* from pattern_node and target_node. Nodes are considered to match if they
|
||||
* have the same target type, and all input and output values to the nodes
|
||||
* match. Matching nodes are stored to `match`
|
||||
**/
|
||||
bool tryMatchNode(const Node* pattern_node, Node* target_node, Match& match);
|
||||
|
||||
/**
|
||||
* Match inputs of pattern_node w/ target_node. Store matching values to
|
||||
*`match`
|
||||
**/
|
||||
bool tryMatchNodeInputs(
|
||||
const Node* pattern_node,
|
||||
Node* target_node,
|
||||
Match& match);
|
||||
|
||||
/**
|
||||
* Tries to match values in the pattern_ graph with the target graph, starting
|
||||
* from pval and tval. Matching values are stored to `match`.
|
||||
**/
|
||||
bool tryMatchValue(const Value* pval, Value* tval, Match& match);
|
||||
|
||||
/**
|
||||
* Returns true of val is an output of its graph, and false otherwise
|
||||
**/
|
||||
bool isOutputValue(const Value* val);
|
||||
};
|
||||
|
||||
struct RewriteRule {
|
||||
std::string pattern;
|
||||
std::string replacement;
|
||||
};
|
||||
|
||||
/**
|
||||
* Rewrite subgraphs in a given graph.
|
||||
* TODO: Write more detailed documentation
|
||||
**/
|
||||
class SubgraphRewriter {
|
||||
public:
|
||||
SubgraphRewriter(const std::string& name) : name_(name) {}
|
||||
|
||||
/**
|
||||
* Registers the rewrite pattern.
|
||||
* @param patternA The subgraph str to match.
|
||||
* @param patternB The subgraph str to replace with.
|
||||
*/
|
||||
void registerRewritePattern(
|
||||
const std::string& pattern,
|
||||
const std::string& replacement);
|
||||
|
||||
/**
|
||||
* Runs the subgraph rewrite process on a graph.
|
||||
* @param graph The graph on which the rewrite is applied.
|
||||
* @param pattern The subgraph to match.
|
||||
* @param replacement The subgraph to replace with.
|
||||
* @param filters A list of filters to apply to the match. If any filter
|
||||
* predicate returns true, the match will not be considered.
|
||||
*/
|
||||
bool /* mutated? */ runForPattern(
|
||||
Graph* graph,
|
||||
const Graph& pattern,
|
||||
const Graph& replacement,
|
||||
const std::vector<MatchFilter>& filters);
|
||||
|
||||
bool /* mutated? */ run(
|
||||
Graph* graph,
|
||||
const MatchFilter& filter =
|
||||
[](const Match&, const c10::FastMap<std::string, const Value*>&) {
|
||||
return true;
|
||||
}) {
|
||||
return run(graph, std::vector<MatchFilter>({filter}));
|
||||
}
|
||||
|
||||
bool /* mutated? */ run(
|
||||
Graph* graph,
|
||||
const std::vector<MatchFilter>& filters);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::vector<RewriteRule> patterns_; // The subgraph pattern to match
|
||||
std::unordered_set<Node*> replacedNodes_;
|
||||
std::vector<Value*> valuesToRewrite_;
|
||||
std::unordered_map<const Value*, Value*> valueRewrites_;
|
||||
|
||||
// Helper methods
|
||||
bool overlapsWithUsedNodes(
|
||||
const Match& match,
|
||||
const std::unordered_set<Node*>& replacedNodes);
|
||||
void rewriteMatch(
|
||||
Graph* graph,
|
||||
const Match& match,
|
||||
const Graph& pattern,
|
||||
const Graph& replacement);
|
||||
|
||||
c10::FastMap<std::string, const Value*> getVmap(const Graph& pattern);
|
||||
};
|
||||
} // namespace torch::nativert
|
Reference in New Issue
Block a user