Drop caffe2 nomnigraph (#127086)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127086
Approved by: https://github.com/Skylion007
This commit is contained in:
Richard Barnes
2024-05-28 23:20:43 +00:00
committed by PyTorch MergeBot
parent f6ef832e87
commit 1be7e4086a
15 changed files with 0 additions and 1897 deletions

View File

@ -41,16 +41,6 @@ def get_c2_mpscnn_test():
return bool(int(c2_mpscnn_test)) return bool(int(c2_mpscnn_test))
def get_c2_nomnigraph():
c2_nomnigraph = native.read_config("caffe2", "enable_nomnigraph", "1")
expect(
c2_nomnigraph in ("0", "1"),
c2_nomnigraph,
)
return bool(int(c2_nomnigraph))
def get_c2_qpl(): def get_c2_qpl():
c2_qpl = native.read_config("caffe2", "enable_qpl", "1") c2_qpl = native.read_config("caffe2", "enable_qpl", "1")
@ -125,8 +115,6 @@ C2_XPLAT_HPTT_PREPROCESSOR_FLAGS = [
def get_c2_xplat_preprocessor_flags(): def get_c2_xplat_preprocessor_flags():
flags = get_c2_xplat_no_hptt_preprocessor_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS flags = get_c2_xplat_no_hptt_preprocessor_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS
if get_c2_nomnigraph():
flags.append("-DCAFFE2_OPTIMIZER")
return flags return flags
def get_c2_xplat_no_hptt_compiler_flags(): def get_c2_xplat_no_hptt_compiler_flags():

View File

@ -1,22 +0,0 @@
# ---[ CPU files.
file(GLOB_RECURSE NOMNI_SRCS *.cc)
file(GLOB_RECURSE NOMNI_TEST_SRCS *Test.cc)
exclude(NOMNI_SRCS "${NOMNI_SRCS}" "${NOMNI_TEST_SRCS}")
set(NOMNI_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)
list(APPEND Caffe2_CPU_SRCS ${NOMNI_SRCS})
list(APPEND Caffe2_CPU_INCLUDE ${NOMNI_INCLUDE_DIR})
list(APPEND Caffe2_GPU_INCLUDE ${NOMNI_INCLUDE_DIR})
list(APPEND Caffe2_CPU_TEST_SRCS ${NOMNI_TEST_SRCS})
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include
DESTINATION ${CMAKE_INSTALL_PREFIX}
FILES_MATCHING PATTERN "*.h")
# ---[ Send the lists to the parent scope.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
set(Caffe2_CPU_INCLUDE ${Caffe2_CPU_INCLUDE} PARENT_SCOPE)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
if(USE_TENSORRT)
set(Caffe2_GPU_INCLUDE ${Caffe2_GPU_INCLUDE} PARENT_SCOPE)
endif()

View File

@ -1,154 +0,0 @@
# nomnigraph
nomnigraph is caffe2's graph transformation subsystem
## Usage
The output of `caffe2::convertToNNModule(caffe2::NetDef)` (found in `caffe2/opt`) is an `NNModule`.
The output of `caffe2::convertToCaffe2Proto(nom::repr::NNModule*, caffe2::NetDef)` is a `NetDef`.
`convertToCaffe2Proto(convertToNNModule(n), n)` should basically return an unchanged network.
An `NNModule` is composed of both `dataFlow` and `controlFlow` graphs.
Creating a new operator is straightforward.
```cpp
auto reluNode = nn.dataFlow.createNode(make_unique<nom::repr::Relu>());
```
The line above does a few things worth talking about.
1) It creates a new node using the graph API (both dataFlow and controlFlow are `Graph`s).
2) It instantiates the node with data, specifically a `unique_ptr` to a neural network operator.
3) This `unique_ptr` contains a type that inherits from `NeuralNetOperator` and forms the fundamental representation described in the IR section below.
Inserting this operator into the graph would look something like this:
```cpp
auto edge = nn.dataFlow.createEdge(convOutputTensorNode, reluNode);
```
Some notes here:
1) Again the graph API is used to insert the node into the graph with an edge.
2) Operators are strictly connected to Tensors, not other operators.
## IR
nomnigraph has a *parallel* representation that can contain annotations with caffe2's OperatorDef.
If you call `caffe2::convertToNNModule(caffe2::NetDef)`, every operator in the `NNModule` will be annotated with a reference to the original operator in the net.
This means you should not delete the original protobuf.
```cpp
auto conv = repr::nn::get<repr::Conv>(convNode);
if (conv->getAnnotation()) {
auto annotation = dyn_cast<caffe2::Caffe2Annotation>(conv->getMutableAnnotation());
OperatorDef* op = annotation->getMutableOperatorDef();
// Do stuff with the caffe2 protobuf
}
```
If you create a new op, as shown in the example above and copied here:
```cpp
auto reluNode = nn.dataFlow.createNode(make_unique<nom::repr::Relu>());
```
it will not have a caffe2 annotation.
How does `caffe2::convertToCaffe2Proto(nom::repr::NNModule*, caffe2::NetDef)` deal with this?
Operators are either generated manually (see the implementation in `caffe2/opt/converter.cc`) or automatically.
The automatic generation is done by simply setting the operator `type` to the name of the operator.
If you'd like to add your own operator to a net and need it to be generated (i.e. are writing a transform that inserts
new nodes which have attributes likes args) you will need to add your own code to `caffe2/opt/converter.cc`.
Do not create `OperatorDef`s in the transformation itself! This is an anti-pattern as the logic becomes less portable.
## API
Below is a subset of selected API calls that are quite useful. Lower level manipulation calls are omitted.
### Graph transformation API
Nomnigraph provides a ReplaceSubgraph API to perform graph transformation operations without having to write custom subgraph matching logic. The main header file is [SubgraphMatcher.h](include/nomnigraph/Transformations/SubgraphMatcher.h).
ReplaceSubgraph API takes in
- A subgraph pattern to be matched
- A graph to be scanned for matching patterns
- A ReplaceGraph lambda function that takes in a matched subgraph; callers should implement specific graph transformation operation in the lambda.
The ReplaceSubgraph implementation takes care of the pattern matching part and also provides tools for callers to implement graph transformation logic with less effort.
Example usage of the API can be found in [subgraph_matcher_test.cc](tests/subgraph_matcher_test.cc)
Example usage of the API for NNGraph can be found in [neural_net_test.cc](tests/neural_net_test.cc)
### Graph API
Nomnigraph's core graph APIs provide a generic graph data structure and basic graph manipulation abilities. The main header file is [Graph.h](include/nomnigraph/Graph/Graph.h).
```cpp
auto g = Graph<T>(); // Constructor
Graph<T>::NodeRef n = g.createNode(T t); // Returns reference to the node
Graph<T>::EdgeRef e = g.createEdge(n1, n2); // Returns reference to the edge
g.deleteNode(n); // Deletes the node and all of its in/out edges from the graph
// Use g.deleteNode(n, false); to keep the edges around.
g.deleteEdge(e); // Deletes the edge between two nodes.
auto e = g.getEdge(n1, n2); // Gets the first edge that has n1 as a tail and n2 as the head.
auto ns = g.getMutableNodes(); // Returns a vector of Graph<T>::NodeRef
auto es = g.getMutableEdges(); // Returns a vector of Graph<T>::EdgeRef
T d = n->data(); // Get the data stored at the node
```
### NN API
NN (NeuralNet) extends core Graph with functionalities specific to neural network computation graph. The main header file is [NeuralNet.h](include/nomnigraph/Representations/NeuralNet.h).
Type checking & data accessing
```cpp
repr::NNModule nn = ...;
using namespace nom;
repr::NNGraph::NodeRef n; // Canonical node of the neural network
bool b = repr::nn::is<repr::Tensor>(n); // Checks the type stored on the node. (Works with parent types.)
repr::Conv* c = repr::nn::get<repr::Conv>(n); // Returns a pointer to the NeuralNetOperator or NeuralNetData in the node
```
Iterate through nodes in a NNGraph.
```cpp
auto pairs = dataIterator(nn); // A useful paradigm for iterating through nodes and corresponding data in no particular order.
auto nodeRefs = nodeIterator(nn); // Iterate through nodes in no particular order.
// See https://github.com/pytorch/pytorch/blob/main/caffe2/opt/mobile.cc#L106-L109
```
These functions make it easy to check attributes on nodes.
```cpp
// -- Tensor node functions --
bool b = hasProducer(tensorNode); // Checks for producers.
auto n = getProducer(tensorNode); // Returns the producer of the tensor
bool b = hasConsumer(tensorNode); // Checks for consumers.
std::vector<NNGraph::NodeRef> consumers = getConsumers(tensorNode); // Returns a vector of all consumers of the tensor.
// -- Operator node functions --
bool b = hasInputs(n); // Checks if there are any input tensors.
std::vector<NNGraph::NodeRef> getInputs(n); // Returns a vector of all the input tensor nodes.
std::vector<NNGraph::NodeRef> getOutputs(n); // Returns a vector of all the output tensor nodes.
```
These functions are less commonly useful
```cpp
coalesceInsertedDataDependencies(&nn); // Fixes up all the inserted dependencies in the dataflow graph.
insertOp<repr::Relu>(nn.dataFlow, n1, n2); // Inserts an operator into the dataflow graph and creates a new blob to do so.
// n1 or n2 must be a tensor and the inserted blob inherits the name from that, appending an underscore.
convertNode<repr::ConvRelu>(nn.dataFlow, n); // Converts the data at the node to a new node by calling the passed in type with the old node's data as the constructor argument.
```

View File

@ -1,245 +0,0 @@
#!/usr/bin/env python3
import argparse
from textwrap import dedent
from subprocess import call
def parse_lines(lines):
# States
EMPTY = 0
OP = 1
MACRO = 2
parse_state = EMPTY
# Preprocess the macros
curr_macro = ""
macros = {}
index = 0
while index < len(lines):
line = lines[index]
if line.lower().startswith("macro"):
assert parse_state == EMPTY
macro_line = line.split(" ")
# Support macros that look like attributes
# e.g. macro - CONV_LIKE
curr_macro = " ".join(macro_line[1:])
assert curr_macro not in macros, 'Macro "{}" defined twice.'.format(
curr_macro
)
macros[curr_macro] = []
parse_state = MACRO
lines = lines[:index] + lines[index + 1 :]
continue
elif line.lower().startswith("endmacro"):
assert parse_state == MACRO
parse_state = EMPTY
lines = lines[:index] + lines[index + 1 :]
continue
elif parse_state == MACRO:
macros[curr_macro].append(line)
lines = lines[:index] + lines[index + 1 :]
continue
index += 1
index = 0
while index < len(lines):
line = lines[index]
if line in macros:
lines = lines[:index] + macros[line] + lines[index + 1 :]
index += len(macros[line]) - 1
index += 1
# Now parse the file
curr_op = ""
# dict of the form
# opName : { attributes: [], ... }
ops = {}
# To preserve parsing order for dependencies (for things like init_from)
op_list = []
for line in lines:
if not len(line):
continue
if line[0] == "-":
assert parse_state is OP
attr = [_.strip() for _ in line[1:].split(":")]
assert attr[0][0].isupper()
if len(attr) == 2: # attribute : type
ops[curr_op]["attributes"].append((attr[0], attr[1]))
elif len(attr) == 3: # attribute : type
ops[curr_op]["attributes"].append((attr[0], attr[1], attr[2]))
else:
op = [l.strip() for l in line.split(":")]
assert len(op[0].split(" ")) == 1
parse_state = OP
curr_op = op[0]
assert curr_op not in ops
ops[curr_op] = {}
op_list.append(curr_op)
if len(op) > 1:
ops[curr_op]["init_from"] = [op[1]]
ops[curr_op]["attributes"] = []
return ops, op_list
def gen_class(op, op_def):
attributes = op_def["attributes"]
attribute_args = []
default_init = "NeuralNetOperator(NNKind::{op})".format(op=op)
attribute_init = [default_init]
attribute_declarations = []
attribute_getters = []
attribute_setters = []
for attr in attributes:
lower_name = attr[0][0].lower() + attr[0][1:]
private_name = lower_name + "_"
default_arg = "" if len(attr) < 3 else " = {}".format(attr[2])
name = attr[0]
t = attr[1]
attr_arg = "{type} {lower_name}".format(
type=t, lower_name=lower_name + default_arg
)
attr_init = "{private_name}({lower_name})".format(
private_name=private_name, lower_name=lower_name)
attr_declare = "{type} {private_name};".format(
type=t, private_name=private_name)
attr_get = dedent(
"""
{type} get{name}() const {{
return {private_name};
}}
""".format(
type=t, name=name, private_name=private_name
)
)
attr_set = dedent(
"""
void set{name}({type} {lower_name}) {{
{private_name} = {lower_name};
}}
""".format(
type=t, name=name, private_name=private_name, lower_name=lower_name
)
)
attribute_args.append(attr_arg)
attribute_init.append(attr_init)
attribute_declarations.append(attr_declare)
attribute_getters.append(attr_get)
attribute_setters.append(attr_set)
extra_init = ""
if "init_from" in op_def:
for other_op in op_def["init_from"]:
lower_other_op = other_op[0].lower() + other_op[1:]
other_init = [default_init]
for attr in attributes:
lower_name = attr[0][0].lower() + attr[0][1:]
private_name = lower_name + "_"
other_init.append(
"{private_name}({other_op}.get{name}())".format(
name=attr[0], private_name=private_name, other_op=lower_other_op
)
)
init = dedent(
"""
{op}(const {other_op}& {lower_other_op}) :
{other_init} {{}}
""".format(
op=op,
other_op=other_op,
lower_other_op=lower_other_op,
other_init=",\n ".join(other_init),
)
)
extra_init += init
return dedent(
"""
class {op} : public NeuralNetOperator {{
public:
{op}({attribute_args}) :
{attribute_init} {{}}
{extra_init}
~{op}() {{}}
NOMNIGRAPH_DEFINE_NN_RTTI({op});
{getters}{setters}
private:
{attribute_declarations}
}};
""".format(
op=op,
extra_init=extra_init,
getters="".join(attribute_getters),
setters="".join(attribute_setters),
attribute_args=",\n".join(attribute_args),
attribute_init=",\n".join(attribute_init),
attribute_declarations="\n".join(attribute_declarations),
)
)
def gen_classes(ops, op_list):
f = ""
for op in op_list:
f += gen_class(op, ops[op])
return f
def gen_enum(op_list):
return ",\n".join([op for op in op_list]) + "\n"
def gen_names(op_list):
f = ""
for op in op_list:
f += dedent(
"""
case NNKind::{name}:
return \"{name}\";
""".format(
name=op
)
)
return f
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate op files.")
parser.add_argument("--install_dir", help="installation directory")
parser.add_argument("--source_def", help="ops.def", action="append")
args = parser.parse_args()
install_dir = args.install_dir
sources = args.source_def
lines = []
for source in sources:
with open(source, "rb") as f:
lines_tmp = f.readlines()
lines += [l.strip().decode("utf-8") for l in lines_tmp]
ops, op_list = parse_lines(lines)
with open(install_dir + "/OpClasses.h", "wb") as f:
f.write(gen_classes(ops, op_list).encode("utf-8"))
with open(install_dir + "/OpNames.h", "wb") as f:
f.write(gen_names(op_list).encode("utf-8"))
with open(install_dir + "/OpEnum.h", "wb") as f:
f.write(gen_enum(op_list).encode("utf-8"))
try:
cmd = ["clang-format", "-i", install_dir + "/OpClasses.h"]
call(cmd)
cmd = ["clang-format", "-i", install_dir + "/OpNames.h"]
call(cmd)
cmd = ["clang-format", "-i", install_dir + "/OpEnum.h"]
call(cmd)
except Exception:
pass

View File

@ -1,80 +0,0 @@
macro - CONV_ATTRS
- KernelShape : vector<int>
- Pads : vector<int> : {0, 0}
- Strides : vector<int> : {1, 1}
- Group : int : 1
- Dilations: vector<int> : {1, 1}
endmacro
macro - POOL_ATTRS
- KernelShape : vector<int>
- Pads : vector<int> : {0, 0}
- Strides : vector<int> : {1, 1}
endmacro
Relu
Conv
- CONV_ATTRS
ConvRelu : Conv
- CONV_ATTRS
ConvTranspose
- CONV_ATTRS
AveragePool
- POOL_ATTRS
AveragePoolRelu : AveragePool
- POOL_ATTRS
MaxPool
- POOL_ATTRS
MaxPoolRelu : MaxPool
- POOL_ATTRS
Sum
SumRelu : Sum
Send
- Destination : string
Receive
- Source : string
BatchNormalization
- Epsilon : float : 1e-5f
- Momentum : float : 0.9f
- Spatial : bool : true
- IsTest : bool : false
Clip
- Min : float
- Max : float
FC
- Axis : int : 1
- AxisW : int : 1
GivenTensorFill
Concat
- Axis : int : -1
- AddAxis : bool : false
Softmax
ChannelShuffle
Add
- Broadcast : int : 0
Reshape
Flatten
CopyToOpenCL
CopyFromOpenCL
NCHW2NHWC
NHWC2NCHW
Declare
Export

View File

@ -1,126 +0,0 @@
#include "test_util.h"
#include <gtest/gtest.h>
TEST(DominatorTree, Test1) {
nom::Graph<std::string> graph;
auto r = graph.createNode(std::string("r"));
auto a = graph.createNode(std::string("a"));
auto b = graph.createNode(std::string("b"));
auto c = graph.createNode(std::string("c"));
auto d = graph.createNode(std::string("d"));
auto e = graph.createNode(std::string("e"));
auto f = graph.createNode(std::string("f"));
auto g = graph.createNode(std::string("g"));
auto l = graph.createNode(std::string("l"));
auto h = graph.createNode(std::string("h"));
auto i = graph.createNode(std::string("i"));
auto j = graph.createNode(std::string("j"));
auto k = graph.createNode(std::string("k"));
graph.createEdge(r, a);
graph.createEdge(r, b);
graph.createEdge(r, c);
graph.createEdge(c, f);
graph.createEdge(c, g);
graph.createEdge(g, j);
graph.createEdge(g, i);
graph.createEdge(f, i);
graph.createEdge(i, k);
graph.createEdge(k, i);
graph.createEdge(k, r);
graph.createEdge(a, d);
graph.createEdge(b, d);
graph.createEdge(b, a);
graph.createEdge(b, e);
graph.createEdge(d, l);
graph.createEdge(l, h);
graph.createEdge(h, k);
graph.createEdge(h, e);
graph.createEdge(e, h);
auto tree = nom::algorithm::dominatorTree(&graph, r);
auto map = nom::algorithm::immediateDominatorMap(&graph, r);
EXPECT_EQ(map[j], g);
EXPECT_EQ(map[g], c);
EXPECT_EQ(map[f], c);
EXPECT_EQ(map[l], d);
EXPECT_EQ(map[a], r);
EXPECT_EQ(map[b], r);
EXPECT_EQ(map[c], r);
EXPECT_EQ(map[d], r);
EXPECT_EQ(map[e], r);
EXPECT_EQ(map[h], r);
EXPECT_EQ(map[i], r);
EXPECT_EQ(map[k], r);
auto domFrontMap = nom::algorithm::dominanceFrontierMap(&graph, r);
}
// https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec04-SSA.pdf
// using example on page 24
TEST(DominatorTree, Test2) {
nom::Graph<std::string> graph;
auto entry = graph.createNode(std::string("entry"));
auto n1 = graph.createNode(std::string("1"));
auto n2 = graph.createNode(std::string("2"));
auto n3 = graph.createNode(std::string("3"));
auto n4 = graph.createNode(std::string("4"));
auto n5 = graph.createNode(std::string("5"));
auto n6 = graph.createNode(std::string("6"));
auto n7 = graph.createNode(std::string("7"));
auto exit = graph.createNode(std::string("exit"));
graph.createEdge(entry, n1);
graph.createEdge(n1, n2);
graph.createEdge(n1, n5);
graph.createEdge(n5, n1);
graph.createEdge(n2, n3);
graph.createEdge(n2, n4);
graph.createEdge(n3, n6);
graph.createEdge(n4, n6);
graph.createEdge(n6, n7);
graph.createEdge(n5, n7);
graph.createEdge(n7, exit);
auto domFrontMap = nom::algorithm::dominanceFrontierMap(&graph, entry);
using noderef = nom::Graph<std::string>::NodeRef;
std::unordered_map<noderef, std::unordered_set<noderef>> checkMap = {
{n1, {n1}},
{n2, {n7}},
{n3, {n6}},
{n4, {n6}},
{n5, {n1, n7}},
{n6, {n7}}
};
// NOLINTNEXTLINE(performance-for-range-copy)
for (auto pair : domFrontMap) {
EXPECT_EQ(pair.second, checkMap[pair.first]);
}
}
TEST(Subgraph, InduceEdges) {
auto g = createGraph();
auto sg = decltype(g)::SubgraphType();
for (const auto& node : g.getMutableNodes()) {
sg.addNode(node);
}
nom::algorithm::induceEdges(&sg);
for (const auto& edge : g.getMutableEdges()) {
EXPECT_TRUE(sg.hasEdge(edge));
}
}
TEST(Subgraph, InduceEdgesCycle) {
auto g = createGraphWithCycle();
auto sg = decltype(g)::SubgraphType();
for (const auto& node : g.getMutableNodes()) {
sg.addNode(node);
}
nom::algorithm::induceEdges(&sg);
for (const auto& edge : g.getMutableEdges()) {
EXPECT_TRUE(sg.hasEdge(edge));
}
}

View File

@ -1,123 +0,0 @@
#include <gtest/gtest.h>
#include <set>
#include "test_util.h"
#include "nomnigraph/Converters/Dot.h"
#include "nomnigraph/Graph/Algorithms.h"
#include "nomnigraph/Graph/Graph.h"
TEST(BinaryMatch, NoMatch) {
auto graph = createGraph();
auto matches = nom::algorithm::binaryMatch(
&graph, [](decltype(graph)::NodeRef n) { return false; });
EXPECT_EQ(matches.size(), 0);
}
TEST(BinaryMatch, AllMatch) {
auto graph = createGraph();
auto matches = nom::algorithm::binaryMatch(
&graph, [](decltype(graph)::NodeRef n) { return true; });
EXPECT_EQ(matches.size(), 1);
EXPECT_EQ(matches.front().getNodesCount(), graph.getMutableNodes().size());
}
TEST(BinaryMatch, EmptyGraph) {
nom::Graph<std::string> graph;
auto matches = nom::algorithm::binaryMatch(
&graph, [](decltype(graph)::NodeRef n) { return true; });
EXPECT_EQ(matches.size(), 0);
}
// We should get this back:
// +---+ +-------+
// | 4 | <-- | 2 |
// +---+ +-------+
// | |
// | |
// | v
// | +-------+
// | | 3 |
// | +-------+
// | |
// | |
// | v
// | +-------+
// +-----> | 6 |
// +-------+
TEST(BinaryMatch, Basic) {
auto graph = createGraph();
auto matches =
nom::algorithm::binaryMatch(&graph, [](decltype(graph)::NodeRef n) {
if (n->data() == "2" || n->data() == "3" || n->data() == "4" ||
n->data() == "6") {
return true;
}
return false;
});
EXPECT_EQ(matches.size(), 1);
auto match = matches.front();
EXPECT_EQ(match.getNodesCount(), 4);
std::set<std::string> exp{"2", "3", "4", "6"};
for (auto n : match.getNodes()) {
EXPECT_EQ(exp.count(n->data()), 1);
exp.erase(n->data());
}
// We found all the those nodes.
EXPECT_EQ(exp.size(), 0);
}
// The interesting bit about this test case is that
// the predicate does not match on 3.
//
// As such, this part of the graph
// +---+ +-------+
// | 4 | <-- | 2 |
// +---+ +-------+
// | |
// | |
// | v
// | +-------+
// | | 3 |
// | +-------+
// | |
// | |
// | v
// | +-------+
// +-----> | 6 |
// +-------+
//
// should match as { 4, 2 }, { 6 } not { 4, 2, 6 }
TEST(BinaryMatch, RemovedMiddleNode) {
auto graph = createGraph();
auto matches =
nom::algorithm::binaryMatch(&graph, [](decltype(graph)::NodeRef n) {
if (n->data() == "2" || n->data() == "4" || n->data() == "6") {
return true;
}
return false;
});
EXPECT_EQ(matches.size(), 2);
auto match1 = matches.front();
auto match2 = matches.back();
EXPECT_EQ(match1.getNodesCount(), 2);
EXPECT_EQ(match2.getNodesCount(), 1);
std::set<std::string> exp1{"2", "4"};
std::set<std::string> exp2{"6"};
for (auto n : match1.getNodes()) {
EXPECT_EQ(exp1.count(n->data()), 1);
exp1.erase(n->data());
}
for (auto n : match2.getNodes()) {
EXPECT_EQ(exp2.count(n->data()), 1);
exp2.erase(n->data());
}
EXPECT_EQ(exp1.size(), 0);
EXPECT_EQ(exp2.size(), 0);
}

View File

@ -1,219 +0,0 @@
#include "test_util.h"
#include "nomnigraph/Converters/Dot.h"
#include "nomnigraph/Graph/Algorithms.h"
#include "nomnigraph/Graph/Graph.h"
#include "nomnigraph/Support/Casting.h"
#include <gtest/gtest.h>
using TestGraph = nom::Graph<TestClass>;
TEST(Basic, CreateNodeAndEdge) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
g.createEdge(n1, n2);
EXPECT_TRUE(g.hasNode(n1));
EXPECT_TRUE(g.hasNode(n2));
}
TEST(Basic, DeleteNode) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
g.createEdge(n1, n2);
EXPECT_TRUE(g.hasNode(n1));
g.deleteNode(n1);
EXPECT_FALSE(g.hasNode(n1));
}
TEST(Basic, DeleteEdge) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto e = g.createEdge(n1, n2);
g.deleteEdge(e);
}
TEST(Basic, ReplaceEdges) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto n4 = createTestNode(g);
auto n5 = createTestNode(g);
g.createEdge(n1, n3);
g.createEdge(n2, n3);
g.createEdge(n3, n4);
/*
1 2 5
|
3
|
4
*/
EXPECT_FALSE(g.hasEdge(n1, n5));
EXPECT_FALSE(g.hasEdge(n2, n5));
g.replaceInEdges(n3, n5);
/*
1 2 3
| |
5 4
*/
EXPECT_TRUE(g.hasEdge(n1, n5));
EXPECT_TRUE(g.hasEdge(n2, n5));
EXPECT_FALSE(g.hasEdge(n5, n4));
g.replaceOutEdges(n3, n5);
/*
1 2 3
|
5
|
4
*/
EXPECT_TRUE(g.hasEdge(n5, n4));
g.replaceNode(n5, n3);
// Back to the original graph.
/*
1 2 5
|
3
|
4
*/
EXPECT_TRUE(g.hasEdge(n1, n3));
EXPECT_TRUE(g.hasEdge(n2, n3));
}
TEST(Basic, HasNode) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
g.createEdge(n1, n2);
g.createEdge(n1, n3);
// Current graph: 1 -> 2 -> 3
EXPECT_TRUE(g.hasNode(n1));
EXPECT_TRUE(g.hasNode(n2));
EXPECT_TRUE(g.hasNode(n3));
g.swapNodes(n1, n3);
// Current graph: 3 -> 2 -> 1
EXPECT_TRUE(g.hasNode(n1));
EXPECT_TRUE(g.hasNode(n3));
g.deleteNode(n1);
// Current graph: 3 -> 2
EXPECT_FALSE(g.hasNode(n1));
auto n4 = createTestNode(g);
EXPECT_TRUE(g.hasNode(n4));
g.replaceNode(n2, n4);
// Current graph: 3 -> 4 , 2
// replaceNode doesn't delete n2.
EXPECT_TRUE(g.hasNode(n2));
// Create a second graph g2, and move the nodes from g2 to g.
TestClass t5;
nom::Graph<TestClass> g2;
nom::Graph<TestClass>::NodeRef n5 = g2.createNode(std::move(t5));
EXPECT_TRUE(g2.hasNode(n5));
EXPECT_FALSE(g.hasNode(n5));
g2.moveNode(n5, &g);
// Current graph (g1): 3 -> 4, 2, 5
EXPECT_TRUE(g.hasNode(n5));
}
TEST(Basic, Moves) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto e1 = g.createEdge(n1, n2);
auto e2 = g.createEdge(n1, n3);
// Current graph: 1 -> 2 -> 3
TestGraph g2;
g.deleteEdge(e2);
g.moveNode(n1, &g2);
g.moveNode(n2, &g2);
g.moveEdge(e1, &g2);
EXPECT_TRUE(g.isValid());
EXPECT_TRUE(g2.isValid());
EXPECT_EQ(g.getMutableNodes().size(), 1);
EXPECT_EQ(g2.getMutableNodes().size(), 2);
EXPECT_EQ(g.getMutableEdges().size(), 0);
EXPECT_EQ(g2.getMutableEdges().size(), 1);
}
TEST(Basic, MoveSubgraph) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto e1 = g.createEdge(n1, n2);
auto e2 = g.createEdge(n1, n3);
// Current graph: 1 -> 2 -> 3
TestGraph g2;
g.deleteEdge(e2);
TestGraph::SubgraphType sg;
sg.addNode(n1);
sg.addNode(n2);
sg.addEdge(e1);
g.moveSubgraph(sg, &g2);
EXPECT_TRUE(g.isValid());
EXPECT_TRUE(g2.isValid());
EXPECT_EQ(g.getMutableNodes().size(), 1);
EXPECT_EQ(g2.getMutableNodes().size(), 2);
EXPECT_EQ(g.getMutableEdges().size(), 0);
EXPECT_EQ(g2.getMutableEdges().size(), 1);
}
TEST(Basic, DotGenerator) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto e12 = g.createEdge(n1, n2);
g.createEdge(n1, n3);
std::string dot = nom::converters::convertToDotString(&g, TestNodePrinter);
// sanity check
std::string prefix = "digraph G";
// Full string comparison of the output is not stable because the dot
// string includes node pointer address as node id. We should switch to
// comparing full output once dot generator no longer uses addresses.
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
TestGraph::SubgraphType sg;
sg.addNode(n1);
sg.addNode(n2);
sg.addEdge(e12);
// Convert to dot with subgraph clusters.
dot = nom::converters::convertToDotString<TestGraph>(
&g, {&sg}, TestNodePrinter);
// sanity check
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
// Convert a single subgraph to dot.
dot = nom::converters::convertToDotString<TestGraph>(sg, TestNodePrinter);
// sanity check
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
dot =
nom::converters::convertToDotRecordString<TestGraph>(&g, TestNodePrinter);
// sanity check
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
}

View File

@ -1,37 +0,0 @@
#include "test_util.h"
#include "nomnigraph/Transformations/Match.h"
#include <gtest/gtest.h>
TEST(Match, Basic) {
nom::Graph<std::string> graph;
auto entry = graph.createNode(std::string("entry"));
auto n1 = graph.createNode(std::string("1"));
auto n2 = graph.createNode(std::string("2"));
auto n3 = graph.createNode(std::string("3"));
auto n4 = graph.createNode(std::string("4"));
auto n5 = graph.createNode(std::string("5"));
auto n6 = graph.createNode(std::string("6"));
auto n7 = graph.createNode(std::string("7"));
auto exit = graph.createNode(std::string("exit"));
graph.createEdge(entry, n1);
graph.createEdge(n1, n2);
graph.createEdge(n1, n5);
graph.createEdge(n5, n1);
graph.createEdge(n2, n3);
graph.createEdge(n2, n4);
graph.createEdge(n3, n6);
graph.createEdge(n4, n6);
graph.createEdge(n6, n7);
graph.createEdge(n5, n7);
graph.createEdge(n7, exit);
nom::Graph<std::string> match_graph;
auto m1 = match_graph.createNode(std::string("1"));
auto m2 = match_graph.createNode(std::string("2"));
match_graph.createEdge(m1, m2);
nom::Match<decltype(graph)> m(match_graph);
EXPECT_EQ(m.match(graph).size(), 1);
}

View File

@ -1,97 +0,0 @@
#include <algorithm>
#include <memory>
#include "test_util.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"
#include <gtest/gtest.h>
using namespace nom;
using namespace nom::repr;
using namespace nom::repr::nn;
// Test for the NNGraph subgraph matching APIs.
TEST(NeuralNetGraph, ReplaceGraph) {
NNGraph graph;
auto input1 = graph.createNode(std::make_unique<Tensor>("input1"));
auto input2 = graph.createNode(std::make_unique<Tensor>("input2"));
// Test renaming blob
nn::get<Tensor>(input2)->setName("input2_renamed");
auto sum = graph.createNode(std::make_unique<Sum>());
auto sumOutput = graph.createNode(std::make_unique<Tensor>("sumOutput"));
auto relu = graph.createNode(std::make_unique<Relu>());
auto reluOutput = graph.createNode(std::make_unique<Tensor>("reluOutput"));
graph.createEdge(input1, sum);
graph.createEdge(input2, sum);
graph.createEdge(sum, sumOutput);
graph.createEdge(sumOutput, relu);
graph.createEdge(relu, reluOutput);
/* input1 input2
\ /
\ /
sum
|
|
sumOutput
|
relu
|
reluOutput
*/
auto mg = NNMatchGraph();
auto matchSumInput =
mg.createNode(std::move(matchExternalTensorNode().count(2)));
auto matchSum = mg.createNode(nn::is<Sum>);
mg.createEdge(matchSumInput, matchSum);
auto matchSumOutput = mg.createNode(nn::is<Tensor>);
mg.createEdge(matchSum, matchSumOutput);
auto matchRelu = mg.createNode(nn::is<Relu>);
mg.createEdge(matchSumOutput, matchRelu);
auto matchRoot = matchRelu;
EXPECT_FALSE(mg.isSubgraphMatch(sum, matchRoot).isMatch());
EXPECT_FALSE(mg.isSubgraphMatch(reluOutput, matchRoot).isMatch());
EXPECT_FALSE(mg.isSubgraphMatch(input1, matchRoot).isMatch());
EXPECT_TRUE(mg.isSubgraphMatch(relu, matchRoot).isMatch());
mg.replaceSubgraph(
graph,
matchRoot,
[&matchSumOutput](
NNGraph& g,
NNGraph::NodeRef relu,
const NNMatchGraph::SubgraphMatchResultType& matchResult) {
auto fusedNode = g.createNode(std::make_unique<SumRelu>());
auto sumNode =
getProducer(matchResult.getMatchNodeMap()->at(matchSumOutput));
g.replaceOutEdges(relu, fusedNode);
g.replaceInEdges(sumNode, fusedNode);
g.deleteNodes(matchResult.getMatchedSubgraph()->getNodes());
return true;
});
/*
Fused graph:
input1 input2
\ /
\ /
sumRelu
|
|
output
*/
EXPECT_EQ(graph.getNodesCount(), 4);
auto fusedNode = getProducer(reluOutput);
EXPECT_TRUE(is<SumRelu>(fusedNode));
EXPECT_EQ(getInputs(fusedNode).size(), 2);
}

View File

@ -1,646 +0,0 @@
#include <algorithm>
#include <functional>
#include "test_util.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"
#include <gtest/gtest.h>
namespace nom {
namespace matcher {
using NodeType = std::string;
using Criteria = std::string;
using TestGraph = Graph<NodeType>;
using TestMatchGraph = MatchGraph<TestGraph>;
using TestMatchPredicate = MatchPredicate<TestGraph>;
// Have just one TestMatchGraph in the tests to make it less verbose to create
// the match graphs.
TestMatchGraph graph;
// Call reset before creating a new TestMatchGraph.
void reset() {
graph = TestMatchGraph();
}
// Node matches a criteria (string) if the data string is the same as the
// criteria. Special case: "*" will match any thing.
TestMatchPredicate testMatchPredicate(const Criteria& criteria) {
return TestMatchPredicate([criteria](TestGraph::NodeRef node) {
return criteria == "*" || criteria == node->data();
});
}
Criteria any() {
return Criteria("*");
}
// Helper methods to make it less verbose to create match graphs.
TestMatchGraph::NodeRef Tree(
const Criteria& root,
const std::vector<TestMatchGraph::NodeRef>& children = {},
int count = 1) {
auto result =
graph.createNode(std::move(testMatchPredicate(root).count(count)));
for (auto& child : children) {
graph.createEdge(result, child);
}
return result;
}
TestMatchGraph::NodeRef NonTerminal(const Criteria& root, int count = 1) {
return graph.createNode(
std::move(testMatchPredicate(root).count(count).nonTerminal()));
}
std::map<std::string, std::string> TestGraphNodePrinter(
TestGraph::NodeRef node) {
std::map<std::string, std::string> labelMap;
labelMap["label"] = node->data();
return labelMap;
};
// Attempts to create a realistic dataflow graph that shows a fuse procedure.
struct DataFlowTestGraph {
const int numInputs = 4;
TestGraph graph;
TestGraph::NodeRef opB;
TestGraph::NodeRef opF;
TestGraph::NodeRef opC;
TestGraph::NodeRef opG;
TestGraph::NodeRef dataOut;
// Realistic data flow test graph.
/*
+---------------+
| |
| +---------+ | +---------+
+---------------------+ | input_A | | | input_B |
| +---------+ | +---------+
| | | |
| | | |
| v v v
+---------++---------+ +-------------------------+ +--------+
| input_C || input_D | --> | opC | --> | dataC2 |
+---------++---------+ +-------------------------+ +--------+
|
|
v
+---------+
| dataC | -+
+---------+ |
| |
| |
v |
+---------+ |
| opB | <+
+---------+
|
|
v
+---------+
| dataB |
+---------+
|
|
v
+---------+
| opF |
+---------+
|
|
v
+---------+
| dataF |
+---------+
|
|
v
+---------+ +---------+
| dataI | --> | opG |
+---------+ +---------+
|
|
v
+---------+
| dataOut |
+---------+
*/
DataFlowTestGraph() {
opC = graph.createNode("opC");
for (int i = 0; i < numInputs; i++) {
auto dataInput = graph.createNode("input");
graph.createEdge(dataInput, opC);
}
auto dataC = graph.createNode("dataC");
auto dataC2 = graph.createNode("dataC2");
graph.createEdge(opC, dataC);
graph.createEdge(opC, dataC2);
opB = graph.createNode("opB");
// There are 2 edges
graph.createEdge(dataC, opB);
graph.createEdge(dataC, opB);
auto dataB = graph.createNode("dataB");
graph.createEdge(opB, dataB);
opF = graph.createNode("opF");
graph.createEdge(dataB, opF);
auto dataF = graph.createNode("dataF");
graph.createEdge(opF, dataF);
auto dataI = graph.createNode("dataI");
opG = graph.createNode("opG");
graph.createEdge(dataF, opG);
graph.createEdge(dataI, opG);
dataOut = graph.createNode("dataOut");
graph.createEdge(opG, dataOut);
// Use nom::converters::convertToDotString(&graph, TestGraphNodePrinter)
// to visualize the graph.
}
};
struct DataFlowTestGraphCriteria {
TestMatchGraph::NodeRef matchOpCOutput;
TestMatchGraph::NodeRef matchOpG;
DataFlowTestGraphCriteria() {
auto matchOpCInputs =
graph.createNode(std::move(testMatchPredicate(Criteria("input"))
.starCount()
.nonTerminal()
.excludeFromSubgraph()));
auto matchOpC = graph.createNode(testMatchPredicate("opC"));
graph.createEdge(matchOpCInputs, matchOpC);
matchOpCOutput = graph.createNode(testMatchPredicate(any()));
graph.createEdge(matchOpC, matchOpCOutput);
auto matchOpB = graph.createNode(testMatchPredicate("opB"));
graph.createEdge(matchOpCOutput, matchOpB);
graph.createEdge(matchOpCOutput, matchOpB);
auto matchOpBOutput = graph.createNode(testMatchPredicate(any()));
graph.createEdge(matchOpB, matchOpBOutput);
auto matchOpF = graph.createNode(testMatchPredicate("opF"));
graph.createEdge(matchOpBOutput, matchOpF);
auto matchOpFOutput = graph.createNode(testMatchPredicate(any()));
graph.createEdge(matchOpF, matchOpFOutput);
matchOpG = graph.createNode(testMatchPredicate("opG"));
auto matchDataI = graph.createNode(std::move(
testMatchPredicate(any()).nonTerminal().excludeFromSubgraph()));
graph.createEdge(matchOpFOutput, matchOpG);
graph.createEdge(matchDataI, matchOpG);
}
};
TestGraph::NodeRef getInNode(TestGraph::NodeRef node, int index) {
return node->getInEdges()[index]->tail();
}
bool isSubgraphMatch(
TestGraph::NodeRef nodeRef,
const TestMatchGraph::NodeRef& criteria,
bool invertGraphTraversal = true) {
return graph.isSubgraphMatch(nodeRef, criteria, invertGraphTraversal)
.isMatch();
}
} // namespace matcher
} // namespace nom
using namespace nom::matcher;
// Simple test cases for node matching criteria.
TEST(SubgraphMatcher, IsNodeMatch) {
TestGraph g;
auto n1 = g.createNode("Hello");
auto n2 = g.createNode("Le");
g.createEdge(n1, n2);
EXPECT_TRUE(graph.isNodeMatch(n1, testMatchPredicate("Hello")));
EXPECT_FALSE(graph.isNodeMatch(n1, testMatchPredicate("G")));
EXPECT_TRUE(graph.isNodeMatch(n2, testMatchPredicate("Le")));
EXPECT_FALSE(graph.isNodeMatch(n2, testMatchPredicate("le")));
}
// Test subtree matching with a simple tree graph.
TEST(SubgraphMatcher, IsSubtreeMatch) {
TestGraph graph;
auto n1 = graph.createNode("1");
auto n2 = graph.createNode("2");
auto n3 = graph.createNode("3");
auto n4 = graph.createNode("4");
auto n5 = graph.createNode("5");
auto n6 = graph.createNode("6");
auto n7 = graph.createNode("7");
graph.createEdge(n1, n2);
graph.createEdge(n2, n3);
graph.createEdge(n2, n4);
graph.createEdge(n1, n5);
graph.createEdge(n5, n6);
graph.createEdge(n5, n7);
/* N1
/ \
N2 N5
/ \ / \
N3 N4 N6 N7
*/
reset();
auto subtree = Tree(any(), {Tree(any()), Tree(any())});
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n4, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
reset();
subtree = Tree(Criteria("5"), {Tree(any()), Tree(any())});
EXPECT_FALSE(isSubgraphMatch(n2, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
reset();
subtree = Tree(any(), {Tree(any()), Tree(Criteria("4"))});
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n5, subtree, false));
reset();
// Accepts non terminal node
subtree = Tree(any(), {NonTerminal(any()), NonTerminal(any())});
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n3, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n4, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n6, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n7, subtree, false));
}
// Test subtree matching in which * (repeated) matching of children is allowed.
TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
TestGraph graph;
auto n1 = graph.createNode("1");
auto n2 = graph.createNode("2");
auto n3A = graph.createNode("3");
auto n3B = graph.createNode("3");
auto n4 = graph.createNode("4");
auto n5A = graph.createNode("5");
auto n5B = graph.createNode("5");
auto n5C = graph.createNode("5");
graph.createEdge(n1, n2);
graph.createEdge(n1, n3A);
graph.createEdge(n1, n3B);
graph.createEdge(n1, n4);
graph.createEdge(n1, n4);
graph.createEdge(n1, n5A);
graph.createEdge(n1, n5B);
graph.createEdge(n1, n5C);
reset();
auto subtree = Tree(any(), {Tree(Criteria("2"))});
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
subtree =
Tree(any(), {Tree(Criteria("2"), {}, TestMatchPredicate::kStarCount)});
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
// clang-format off
subtree = Tree(any(), {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, 2),
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, 3)
});
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, 2),
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, 4)
});
// Failes because exepected 4 matches of n5 but found 3.
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, 2),
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount)
});
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount),
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount)
});
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount),
});
// Fails because there are unmatched edges.
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, 2),
Tree(Criteria("4")),
Tree(Criteria("5"), {}, 3)
});
// Fails because the count is wrong; we have 2 edges to node N4 while
// the pattern expects only 1.
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
// clang-format on
}
TEST(SubgraphMatcher, DagMatching) {
reset();
// clang-format off
auto n4match = Tree(Criteria("4"), {
Tree(Criteria("5"))
});
auto subgraph = Tree(Criteria("1"), {
Tree(Criteria("2"), {
n4match
}),
Tree(Criteria("3"), {
n4match
}),
});
// clang-format on
{
TestGraph graph;
auto n1 = graph.createNode("1");
auto n2 = graph.createNode("2");
auto n3 = graph.createNode("3");
auto n4 = graph.createNode("4");
auto n5 = graph.createNode("5");
graph.createEdge(n1, n2);
graph.createEdge(n1, n3);
graph.createEdge(n2, n4);
graph.createEdge(n3, n4);
graph.createEdge(n4, n5);
/* N1
/ \
N2 N3
\ /
N4
|
N5
*/
EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false));
}
{
TestGraph graph;
auto n1 = graph.createNode("1");
auto n2 = graph.createNode("2");
auto n3 = graph.createNode("3");
auto n4A = graph.createNode("4");
auto n4B = graph.createNode("4");
auto n5 = graph.createNode("5");
graph.createEdge(n1, n2);
graph.createEdge(n1, n3);
graph.createEdge(n2, n4A);
graph.createEdge(n3, n4B);
graph.createEdge(n4A, n5);
graph.createEdge(n4B, n5);
/* N1
/ \
N2 N3
/ \
N4A N4B
\ /
N5
*/
// This should fail because n4A and n4B are not the same node.
EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false));
}
}
TEST(SubgraphMatcher, DagMatchingMultiEdges) {
reset();
// clang-format off
auto n2match = Tree(Criteria("2"));
auto subgraph = Tree(Criteria("1"), {
n2match,
n2match
});
// clang-format on
{
TestGraph graph;
auto n1 = graph.createNode("1");
auto n2 = graph.createNode("2");
graph.createEdge(n1, n2);
graph.createEdge(n1, n2);
EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false));
}
{
TestGraph graph;
auto n1 = graph.createNode("1");
auto n2A = graph.createNode("2");
auto n2B = graph.createNode("2");
graph.createEdge(n1, n2A);
graph.createEdge(n1, n2B);
EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false));
}
}
TEST(SubgraphMatcher, DagMatchingRandomLargeGraph) {
reset();
// clang-format off
auto n4match = Tree(any(), {
NonTerminal(any(), 1)
});
auto subtree = Tree(any(), {
Tree(any(), {
n4match
}),
Tree(any(), {
n4match
}),
});
// clang-format on
/* N1
/ \
N2 N3
\ /
N4
|
N5
*/
// Look for the diamond pattern in a random large graph.
TestGraph graph;
std::vector<nom::Graph<std::string>::NodeRef> nodes;
// Here we create a test graph and then randomly embed the above
// pattern into the graph repeatedly (numPatterns times).
// The actual number of match will be less than numPatterns because the
// embedded patterns can overlap which become unmatched subgraphs.
const int numNodes = 50000;
const int numPatterns = 5000;
for (int i = 0; i < numNodes; i++) {
auto node = graph.createNode("Node");
nodes.emplace_back(node);
}
TestRandom random(517);
for (int i = 0; i < numPatterns; i++) {
std::vector<int> nodeIdx;
for (int k = 0; k < 5; k++) {
// NOLINTNEXTLINE(performance-inefficient-vector-operation)
nodeIdx.emplace_back(random.nextInt() % numNodes);
}
graph.createEdge(nodes[nodeIdx[0]], nodes[nodeIdx[1]]);
graph.createEdge(nodes[nodeIdx[0]], nodes[nodeIdx[2]]);
graph.createEdge(nodes[nodeIdx[1]], nodes[nodeIdx[3]]);
graph.createEdge(nodes[nodeIdx[2]], nodes[nodeIdx[3]]);
graph.createEdge(nodes[nodeIdx[3]], nodes[nodeIdx[4]]);
}
EXPECT_EQ(graph.getEdgesCount(), 5 * numPatterns);
int countMatch = 0;
for (auto node : graph.getMutableNodes()) {
if (isSubgraphMatch(node, subtree, false)) {
countMatch++;
}
}
EXPECT_EQ(countMatch, 1072);
}
TEST(SubgraphMatcher, IsSubtreeMatchRealistic) {
reset();
auto graph = DataFlowTestGraph();
auto subtree = DataFlowTestGraphCriteria().matchOpG;
EXPECT_FALSE(isSubgraphMatch(graph.opF, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.opC, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.opB, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.dataOut, subtree));
EXPECT_TRUE(isSubgraphMatch(graph.opG, subtree));
}
TEST(SubgraphMatcher, ReplaceGraphRealistic) {
reset();
auto testGraph = DataFlowTestGraph();
auto subtree = DataFlowTestGraphCriteria();
graph.replaceSubgraph(
testGraph.graph,
subtree.matchOpG,
[subtree](
TestGraph& g,
TestGraph::NodeRef opG,
const TestMatchGraph::SubgraphMatchResultType& matchResult) {
auto fusedNode = g.createNode("opFused");
auto opC = getInNode(
matchResult.getMatchNodeMap()->at(subtree.matchOpCOutput), 0);
g.replaceOutEdges(opG, fusedNode);
g.replaceInEdges(opG, fusedNode);
g.replaceInEdges(opC, fusedNode);
g.deleteNodes(matchResult.getMatchedSubgraph()->getNodes());
return true;
});
// Now the nodes are:
// - NumInputs input nodes
// - dataI node
// - fused node
// - output node
// - dataC2 node
auto nodes = testGraph.graph.getMutableNodes();
// Test that the graph is transformed as expected.
EXPECT_EQ(nodes.size(), testGraph.numInputs + 4);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
TestGraph::NodeRef opFused;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
TestGraph::NodeRef dataI;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
TestGraph::NodeRef dataOut;
for (auto node : nodes) {
if (node->data() == "opFused") {
opFused = node;
} else if (node->data() == "dataOut") {
dataOut = node;
} else if (node->data() == "dataI") {
dataI = node;
}
}
EXPECT_EQ(getInNode(dataOut, 0), opFused);
EXPECT_EQ(opFused->getInEdges().size(), testGraph.numInputs + 1);
EXPECT_EQ(getInNode(opFused, 0), dataI);
for (int i = 1; i <= testGraph.numInputs; i++) {
EXPECT_EQ(getInNode(opFused, i)->data(), "input");
}
// Use nom::converters::convertToDotString(&graph.graph, TestGraphNodePrinter)
// to visualize. The transformed graph looks like This
/*
+---------++---------+
| input_A || input_D |
+---------++---------+
| |
| |
v v
+---------+ +--------------------+ +---------+
| input_B | --> | opFused | <-- | input_C |
+---------+ +--------------------+ +---------+
| ^
| |
v |
+---------++---------+
| dataOut || dataI |
+---------++---------+
*/
}

View File

@ -1,60 +0,0 @@
#include <gtest/gtest.h>
#include "test_util.h"
#include "nomnigraph/Graph/Graph.h"
TEST(Tarjans, Simple) {
TestClass t1;
TestClass t2;
nom::Graph<TestClass> g;
nom::Graph<TestClass>::NodeRef n1 = g.createNode(std::move(t1));
nom::Graph<TestClass>::NodeRef n2 = g.createNode(std::move(t2));
g.createEdge(n1, n2);
g.createEdge(n2, n1);
auto sccs = nom::algorithm::tarjans(&g);
EXPECT_EQ(sccs.size(), 1);
}
TEST(Tarjans, WithEdgeStorage) {
TestClass t1;
TestClass t2;
nom::Graph<TestClass, TestClass> g;
nom::Graph<TestClass, TestClass>::NodeRef n1 = g.createNode(std::move(t1));
nom::Graph<TestClass, TestClass>::NodeRef n2 = g.createNode(std::move(t2));
g.createEdge(n1, n2, TestClass());
g.createEdge(n2, n1, TestClass());
auto sccs = nom::algorithm::tarjans(&g);
EXPECT_EQ(sccs.size(), 1);
}
TEST(Tarjans, DAG) {
auto graph = createGraph();
auto sccs = nom::algorithm::tarjans(&graph);
EXPECT_EQ(sccs.size(), 9);
}
TEST(Tarjans, Cycle) {
auto graph = createGraphWithCycle();
auto sccs = nom::algorithm::tarjans(&graph);
EXPECT_EQ(sccs.size(), 8);
}
TEST(Tarjans, Random) {
nom::Graph<TestClass> g;
std::vector<nom::Graph<TestClass>::NodeRef> nodes;
for (auto i = 0; i < 10; ++i) {
TestClass t;
nodes.emplace_back(g.createNode(std::move(t)));
}
for (auto i = 0; i < 30; ++i) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,clang-analyzer-security.insecureAPI.rand)
int ri1 = rand() % nodes.size();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,clang-analyzer-security.insecureAPI.rand)
int ri2 = rand() % nodes.size();
g.createEdge(nodes[ri1], nodes[ri2]);
}
auto sccs = nom::algorithm::tarjans(&g);
EXPECT_GE(sccs.size(), 1);
}

View File

@ -1,73 +0,0 @@
#include <gtest/gtest.h>
#include "test_util.h"
#include "nomnigraph/Graph/Graph.h"
using GraphT = nom::Graph<TestClass>;
using TopoSortT = nom::algorithm::TopoSort<GraphT>;
TEST(TopoSort, Simple) {
GraphT g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
g.createEdge(n1, n2);
auto res = nom::algorithm::topoSort(&g);
EXPECT_EQ(res.status, TopoSortT::Result::OK);
EXPECT_EQ(res.nodes.size(), 2);
EXPECT_EQ(res.nodes[0], n1);
EXPECT_EQ(res.nodes[1], n2);
}
TEST(TopoSort, DAG) {
GraphT g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto n4 = createTestNode(g);
g.createEdge(n1, n2);
g.createEdge(n1, n3);
g.createEdge(n2, n4);
g.createEdge(n3, n4);
auto res = nom::algorithm::topoSort(&g);
EXPECT_EQ(res.status, TopoSortT::Result::OK);
EXPECT_EQ(res.nodes.size(), 4);
auto i1 = std::find(res.nodes.begin(), res.nodes.end(), n1);
auto i2 = std::find(res.nodes.begin(), res.nodes.end(), n2);
auto i3 = std::find(res.nodes.begin(), res.nodes.end(), n3);
auto i4 = std::find(res.nodes.begin(), res.nodes.end(), n4);
ASSERT_TRUE(i1 != res.nodes.end());
ASSERT_TRUE(i2 != res.nodes.end());
ASSERT_TRUE(i3 != res.nodes.end());
ASSERT_TRUE(i4 != res.nodes.end());
ASSERT_LT(i1, i2);
ASSERT_LT(i1, i3);
ASSERT_LT(i2, i4);
ASSERT_LT(i3, i4);
}
TEST(TopoSort, Cycle1) {
GraphT g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
g.createEdge(n1, n2);
g.createEdge(n2, n1);
auto res = nom::algorithm::topoSort(&g);
EXPECT_EQ(res.status, TopoSortT::Result::CYCLE);
EXPECT_EQ(res.nodes.size(), 0);
}
TEST(TopoSort, Cycle2) {
GraphT g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto n4 = createTestNode(g);
g.createEdge(n1, n2);
g.createEdge(n2, n3);
g.createEdge(n3, n4);
g.createEdge(n4, n2);
auto res = nom::algorithm::topoSort(&g);
EXPECT_EQ(res.status, TopoSortT::Result::CYCLE);
EXPECT_EQ(res.nodes.size(), 0);
}

View File

@ -138,9 +138,6 @@ ignore_errors = True
[mypy-caffe2.proto.*] [mypy-caffe2.proto.*]
ignore_errors = True ignore_errors = True
[mypy-caffe2.core.nomnigraph.op_gen]
ignore_errors = True
[mypy-caffe2.distributed.store_ops_test_util] [mypy-caffe2.distributed.store_ops_test_util]
ignore_errors = True ignore_errors = True