Remove const_cast's from subgraph matcher. (#20303)

Summary:
The trick here is that creating a mapping from const values to
const values means that downstream clients that want to mutate
the output of the mapping are stuck.  However, a mapping from
const values to non-const values is just fine and doesn't put
constraints on downstream clients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20303

Differential Revision: D15284076

fbshipit-source-id: 16206fd910dd5f83218525ca301b1889df0586cb
This commit is contained in:
Owen Anderson
2019-05-09 18:03:41 -07:00
committed by Facebook Github Bot
parent e47b210075
commit 02df1ccd9c
3 changed files with 27 additions and 24 deletions

View File

@ -70,10 +70,10 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
// we matched.
std::vector<Value*> inputs, outputs;
for (Value* v : pattern_graph.inputs()) {
inputs.push_back(const_cast<Value*>(match.values_map.at(v)));
inputs.push_back(match.values_map.at(v));
}
for (Value* v : pattern_graph.outputs()) {
outputs.push_back(const_cast<Value*>(match.values_map.at(v)));
outputs.push_back(match.values_map.at(v));
}
// Insert a clone of replacement subgraph after the matched subgraph.
@ -81,7 +81,7 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
// new subgraph, and we will get `new_outputs` vector containing values
// produced by this new subgraph - we will then rewrite old outputs with the
// new ones.
WithInsertPoint insert_point(const_cast<Node*>(match.anchor));
WithInsertPoint insert_point(match.anchor);
std::vector<Value*> new_outputs =
inlineCallTo(*graph, replacement_graph, inputs);
@ -94,7 +94,7 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
// Record all planned deletions
for (Node* pattern_n : pattern_graph.nodes()) {
if (match.nodes_map.count(pattern_n)) {
Node* n = const_cast<Node*>(match.nodes_map.at(pattern_n));
Node* n = match.nodes_map.at(pattern_n);
nodes_to_delete_.insert(n);
}
}
@ -116,7 +116,7 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
for (auto n : match->nodes_map) {
if (nodes_to_delete_.count(const_cast<Node*>(n.second))) {
if (nodes_to_delete_.count(n.second)) {
return true;
}
}

View File

@ -21,24 +21,24 @@ class SubgraphMatcher {
* is the same as in the corresponding matchGraph node, its type is the same,
* and all nodes producing input-values also match.
*/
bool matchesSubgraphFromAnchorNode(const Node* anchor);
bool matchesSubgraphFromAnchorNode(Node* anchor);
/** \brief Return match map for nodes. */
std::unordered_map<const Node*, const Node*> nodes_map() const {
std::unordered_map<const Node*, Node*> nodes_map() const {
return nodes_map_;
}
/** \brief Return match map for values. */
std::unordered_map<const Value*, const Value*> values_map() const {
std::unordered_map<const Value*, Value*> values_map() const {
return values_map_;
}
private:
bool matchValues(const Value* v1, const Value* v2);
bool matchNodes(const Node* n1, const Node* n2);
bool matchValues(const Value* v1, Value* v2);
bool matchNodes(const Node* n1, Node* n2);
std::unordered_map<const Node*, const Node*> nodes_map_;
std::unordered_map<const Value*, const Value*> values_map_;
std::unordered_map<const Node*, Node*> nodes_map_;
std::unordered_map<const Value*, Value*> values_map_;
const Graph& pattern_;
const Node* anchor_ = nullptr;
@ -73,7 +73,7 @@ bool patternGraphIsValid(const Graph& pattern) {
* 1) the nodes defining them match
* 2) they have the same number of uses, except they are entry or exit nodes.
*/
bool SubgraphMatcher::matchValues(const Value* v1, const Value* v2) {
bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
// Check if we've already visited these values.
if (values_map_.count(v1)) {
return values_map_.at(v1) == v2;
@ -104,7 +104,7 @@ bool SubgraphMatcher::matchValues(const Value* v1, const Value* v2) {
* A special case is when N1 is PARAM - this is considered outside the pattern,
* so it matches everything.
*/
bool SubgraphMatcher::matchNodes(const Node* n1, const Node* n2) {
bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
// Check if we've already visited these nodes.
if (nodes_map_.count(n1)) {
return nodes_map_.at(n1) == n2;
@ -148,7 +148,7 @@ bool SubgraphMatcher::matchNodes(const Node* n1, const Node* n2) {
* Recursively try to match pattern with the actual graph starting from the
* exiting node in the pattern and anchor node in the actual graph.
*/
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(const Node* anchor) {
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
nodes_map_.clear();
values_map_.clear();
anchor_ = anchor;
@ -169,24 +169,24 @@ bool SubgraphMatcher::matchesSubgraphFromAnchorNode(const Node* anchor) {
// Main entry point for the subgraph matching.
std::vector<Match> findPatternMatches(
const Graph& pattern,
const Graph& graph) {
Graph& graph) {
AT_ASSERT(patternGraphIsValid(pattern));
SubgraphMatcher m(pattern);
std::vector<Match> matches;
std::stack<const Block*> blocks_to_visit;
std::stack<Block*> blocks_to_visit;
// Iterate over all nodes in the graph (including nodes in subblocks) trying
// to match the pattern each node.
blocks_to_visit.push(graph.block());
while (!blocks_to_visit.empty()) {
const Block* block = blocks_to_visit.top();
Block* block = blocks_to_visit.top();
blocks_to_visit.pop();
for (const Node* n : block->nodes()) {
for (Node* n : block->nodes()) {
if (m.matchesSubgraphFromAnchorNode(n)) {
matches.push_back({n, m.nodes_map(), m.values_map()});
}
for (const Block* subblock : n->blocks()) {
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}

View File

@ -17,9 +17,9 @@ namespace jit {
* (match-map values). We keep such maps for both nodes and values.
*/
struct Match {
const Node* anchor;
std::unordered_map<const Node*, const Node*> nodes_map;
std::unordered_map<const Value*, const Value*> values_map;
Node* anchor;
std::unordered_map<const Node*, Node*> nodes_map;
std::unordered_map<const Value*, Value*> values_map;
};
/**
@ -42,9 +42,12 @@ struct Match {
* - Aliasing nodes in the graph can not consitute a match (i.e. in all found
* matches no nodes in the subgraph alias with each other). TODO: the check not
* implemented yet.
* - The matcher will not mutate either the pattern graph or the matched graph,
* but the latter is taken as non-const so that Match may contain non-const
* pointers. This enables clients of this API to use Match to drive mutations.
*/
std::vector<Match> TORCH_API
findPatternMatches(const Graph& pattern, const Graph& graph);
findPatternMatches(const Graph& pattern, Graph& graph);
} // namespace jit
} // namespace torch