mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Remove const_cast's from subgraph matcher. (#20303)
Summary: The trick here is that creating a mapping from const values to const values means that downstream clients that want to mutate the output of the mapping are stuck. However, a mapping from const values to non-const values is just fine and doesn't put constraints on downstream clients. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20303 Differential Revision: D15284076 fbshipit-source-id: 16206fd910dd5f83218525ca301b1889df0586cb
This commit is contained in:
committed by
Facebook Github Bot
parent
e47b210075
commit
02df1ccd9c
@ -70,10 +70,10 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
|
||||
// we matched.
|
||||
std::vector<Value*> inputs, outputs;
|
||||
for (Value* v : pattern_graph.inputs()) {
|
||||
inputs.push_back(const_cast<Value*>(match.values_map.at(v)));
|
||||
inputs.push_back(match.values_map.at(v));
|
||||
}
|
||||
for (Value* v : pattern_graph.outputs()) {
|
||||
outputs.push_back(const_cast<Value*>(match.values_map.at(v)));
|
||||
outputs.push_back(match.values_map.at(v));
|
||||
}
|
||||
|
||||
// Insert a clone of replacement subgraph after the matched subgraph.
|
||||
@ -81,7 +81,7 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
|
||||
// new subgraph, and we will get `new_outputs` vector containing values
|
||||
// produced by this new subgraph - we will then rewrite old outputs with the
|
||||
// new ones.
|
||||
WithInsertPoint insert_point(const_cast<Node*>(match.anchor));
|
||||
WithInsertPoint insert_point(match.anchor);
|
||||
std::vector<Value*> new_outputs =
|
||||
inlineCallTo(*graph, replacement_graph, inputs);
|
||||
|
||||
@ -94,7 +94,7 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
|
||||
// Record all planned deletions
|
||||
for (Node* pattern_n : pattern_graph.nodes()) {
|
||||
if (match.nodes_map.count(pattern_n)) {
|
||||
Node* n = const_cast<Node*>(match.nodes_map.at(pattern_n));
|
||||
Node* n = match.nodes_map.at(pattern_n);
|
||||
nodes_to_delete_.insert(n);
|
||||
}
|
||||
}
|
||||
@ -116,7 +116,7 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph(
|
||||
|
||||
bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
|
||||
for (auto n : match->nodes_map) {
|
||||
if (nodes_to_delete_.count(const_cast<Node*>(n.second))) {
|
||||
if (nodes_to_delete_.count(n.second)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -21,24 +21,24 @@ class SubgraphMatcher {
|
||||
* is the same as in the corresponding matchGraph node, its type is the same,
|
||||
* and all nodes producing input-values also match.
|
||||
*/
|
||||
bool matchesSubgraphFromAnchorNode(const Node* anchor);
|
||||
bool matchesSubgraphFromAnchorNode(Node* anchor);
|
||||
|
||||
/** \brief Return match map for nodes. */
|
||||
std::unordered_map<const Node*, const Node*> nodes_map() const {
|
||||
std::unordered_map<const Node*, Node*> nodes_map() const {
|
||||
return nodes_map_;
|
||||
}
|
||||
|
||||
/** \brief Return match map for values. */
|
||||
std::unordered_map<const Value*, const Value*> values_map() const {
|
||||
std::unordered_map<const Value*, Value*> values_map() const {
|
||||
return values_map_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool matchValues(const Value* v1, const Value* v2);
|
||||
bool matchNodes(const Node* n1, const Node* n2);
|
||||
bool matchValues(const Value* v1, Value* v2);
|
||||
bool matchNodes(const Node* n1, Node* n2);
|
||||
|
||||
std::unordered_map<const Node*, const Node*> nodes_map_;
|
||||
std::unordered_map<const Value*, const Value*> values_map_;
|
||||
std::unordered_map<const Node*, Node*> nodes_map_;
|
||||
std::unordered_map<const Value*, Value*> values_map_;
|
||||
|
||||
const Graph& pattern_;
|
||||
const Node* anchor_ = nullptr;
|
||||
@ -73,7 +73,7 @@ bool patternGraphIsValid(const Graph& pattern) {
|
||||
* 1) the nodes defining them match
|
||||
* 2) they have the same number of uses, except they are entry or exit nodes.
|
||||
*/
|
||||
bool SubgraphMatcher::matchValues(const Value* v1, const Value* v2) {
|
||||
bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
|
||||
// Check if we've already visited these values.
|
||||
if (values_map_.count(v1)) {
|
||||
return values_map_.at(v1) == v2;
|
||||
@ -104,7 +104,7 @@ bool SubgraphMatcher::matchValues(const Value* v1, const Value* v2) {
|
||||
* A special case is when N1 is PARAM - this is considered outside the pattern,
|
||||
* so it matches everything.
|
||||
*/
|
||||
bool SubgraphMatcher::matchNodes(const Node* n1, const Node* n2) {
|
||||
bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
|
||||
// Check if we've already visited these nodes.
|
||||
if (nodes_map_.count(n1)) {
|
||||
return nodes_map_.at(n1) == n2;
|
||||
@ -148,7 +148,7 @@ bool SubgraphMatcher::matchNodes(const Node* n1, const Node* n2) {
|
||||
* Recursively try to match pattern with the actual graph starting from the
|
||||
* exiting node in the pattern and anchor node in the actual graph.
|
||||
*/
|
||||
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(const Node* anchor) {
|
||||
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
|
||||
nodes_map_.clear();
|
||||
values_map_.clear();
|
||||
anchor_ = anchor;
|
||||
@ -169,24 +169,24 @@ bool SubgraphMatcher::matchesSubgraphFromAnchorNode(const Node* anchor) {
|
||||
// Main entry point for the subgraph matching.
|
||||
std::vector<Match> findPatternMatches(
|
||||
const Graph& pattern,
|
||||
const Graph& graph) {
|
||||
Graph& graph) {
|
||||
AT_ASSERT(patternGraphIsValid(pattern));
|
||||
|
||||
SubgraphMatcher m(pattern);
|
||||
std::vector<Match> matches;
|
||||
std::stack<const Block*> blocks_to_visit;
|
||||
std::stack<Block*> blocks_to_visit;
|
||||
|
||||
// Iterate over all nodes in the graph (including nodes in subblocks) trying
|
||||
// to match the pattern each node.
|
||||
blocks_to_visit.push(graph.block());
|
||||
while (!blocks_to_visit.empty()) {
|
||||
const Block* block = blocks_to_visit.top();
|
||||
Block* block = blocks_to_visit.top();
|
||||
blocks_to_visit.pop();
|
||||
for (const Node* n : block->nodes()) {
|
||||
for (Node* n : block->nodes()) {
|
||||
if (m.matchesSubgraphFromAnchorNode(n)) {
|
||||
matches.push_back({n, m.nodes_map(), m.values_map()});
|
||||
}
|
||||
for (const Block* subblock : n->blocks()) {
|
||||
for (Block* subblock : n->blocks()) {
|
||||
blocks_to_visit.push(subblock);
|
||||
}
|
||||
}
|
||||
|
@ -17,9 +17,9 @@ namespace jit {
|
||||
* (match-map values). We keep such maps for both nodes and values.
|
||||
*/
|
||||
struct Match {
|
||||
const Node* anchor;
|
||||
std::unordered_map<const Node*, const Node*> nodes_map;
|
||||
std::unordered_map<const Value*, const Value*> values_map;
|
||||
Node* anchor;
|
||||
std::unordered_map<const Node*, Node*> nodes_map;
|
||||
std::unordered_map<const Value*, Value*> values_map;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -42,9 +42,12 @@ struct Match {
|
||||
* - Aliasing nodes in the graph can not consitute a match (i.e. in all found
|
||||
* matches no nodes in the subgraph alias with each other). TODO: the check not
|
||||
* implemented yet.
|
||||
* - The matcher will not mutate either the pattern graph or the matched graph,
|
||||
* but the latter is taken as non-const so that Match may contain non-const
|
||||
* pointers. This enables clients of this API to use Match to drive mutations.
|
||||
*/
|
||||
std::vector<Match> TORCH_API
|
||||
findPatternMatches(const Graph& pattern, const Graph& graph);
|
||||
findPatternMatches(const Graph& pattern, Graph& graph);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user