mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Drop caffe2 nomnigraph (#127086)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127086 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
f6ef832e87
commit
1be7e4086a
12
c2_defs.bzl
12
c2_defs.bzl
@ -41,16 +41,6 @@ def get_c2_mpscnn_test():
|
||||
|
||||
return bool(int(c2_mpscnn_test))
|
||||
|
||||
def get_c2_nomnigraph():
|
||||
c2_nomnigraph = native.read_config("caffe2", "enable_nomnigraph", "1")
|
||||
|
||||
expect(
|
||||
c2_nomnigraph in ("0", "1"),
|
||||
c2_nomnigraph,
|
||||
)
|
||||
|
||||
return bool(int(c2_nomnigraph))
|
||||
|
||||
def get_c2_qpl():
|
||||
c2_qpl = native.read_config("caffe2", "enable_qpl", "1")
|
||||
|
||||
@ -125,8 +115,6 @@ C2_XPLAT_HPTT_PREPROCESSOR_FLAGS = [
|
||||
|
||||
def get_c2_xplat_preprocessor_flags():
|
||||
flags = get_c2_xplat_no_hptt_preprocessor_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS
|
||||
if get_c2_nomnigraph():
|
||||
flags.append("-DCAFFE2_OPTIMIZER")
|
||||
return flags
|
||||
|
||||
def get_c2_xplat_no_hptt_compiler_flags():
|
||||
|
@ -1,22 +0,0 @@
|
||||
# ---[ CPU files.
|
||||
file(GLOB_RECURSE NOMNI_SRCS *.cc)
|
||||
file(GLOB_RECURSE NOMNI_TEST_SRCS *Test.cc)
|
||||
exclude(NOMNI_SRCS "${NOMNI_SRCS}" "${NOMNI_TEST_SRCS}")
|
||||
set(NOMNI_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
|
||||
list(APPEND Caffe2_CPU_SRCS ${NOMNI_SRCS})
|
||||
list(APPEND Caffe2_CPU_INCLUDE ${NOMNI_INCLUDE_DIR})
|
||||
list(APPEND Caffe2_GPU_INCLUDE ${NOMNI_INCLUDE_DIR})
|
||||
list(APPEND Caffe2_CPU_TEST_SRCS ${NOMNI_TEST_SRCS})
|
||||
|
||||
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
DESTINATION ${CMAKE_INSTALL_PREFIX}
|
||||
FILES_MATCHING PATTERN "*.h")
|
||||
|
||||
# ---[ Send the lists to the parent scope.
|
||||
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
|
||||
set(Caffe2_CPU_INCLUDE ${Caffe2_CPU_INCLUDE} PARENT_SCOPE)
|
||||
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
|
||||
if(USE_TENSORRT)
|
||||
set(Caffe2_GPU_INCLUDE ${Caffe2_GPU_INCLUDE} PARENT_SCOPE)
|
||||
endif()
|
@ -1,154 +0,0 @@
|
||||
# nomnigraph
|
||||
|
||||
nomnigraph is caffe2's graph transformation subsystem
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
The output of `caffe2::convertToNNModule(caffe2::NetDef)` (found in `caffe2/opt`) is an `NNModule`.
|
||||
The output of `caffe2::convertToCaffe2Proto(nom::repr::NNModule*, caffe2::NetDef)` is a `NetDef`.
|
||||
`convertToCaffe2Proto(convertToNNModule(n), n)` should basically return an unchanged network.
|
||||
|
||||
An `NNModule` is composed of both `dataFlow` and `controlFlow` graphs.
|
||||
|
||||
Creating a new operator is straightforward.
|
||||
```cpp
|
||||
auto reluNode = nn.dataFlow.createNode(make_unique<nom::repr::Relu>());
|
||||
```
|
||||
The line above does a few things worth talking about.
|
||||
|
||||
1) It creates a new node using the graph API (both dataFlow and controlFlow are `Graph`s).
|
||||
2) It instantiates the node with data, specifically a `unique_ptr` to a neural network operator.
|
||||
3) This `unique_ptr` contains a type that inherits from `NeuralNetOperator` and forms the fundamental representation described in the IR section below.
|
||||
|
||||
Inserting this operator into the graph would look something like this:
|
||||
|
||||
```cpp
|
||||
auto edge = nn.dataFlow.createEdge(convOutputTensorNode, reluNode);
|
||||
```
|
||||
|
||||
Some notes here:
|
||||
1) Again the graph API is used to insert the node into the graph with an edge.
|
||||
2) Operators are strictly connected to Tensors, not other operators.
|
||||
|
||||
## IR
|
||||
|
||||
nomnigraph has a *parallel* representation that can contain annotations with caffe2's OperatorDef.
|
||||
|
||||
If you call `caffe2::convertToNNModule(caffe2::NetDef)`, every operator in the `NNModule` will be annotated with a reference to the original operator in the net.
|
||||
|
||||
This means you should not delete the original protobuf.
|
||||
|
||||
```cpp
|
||||
auto conv = repr::nn::get<repr::Conv>(convNode);
|
||||
if (conv->getAnnotation()) {
|
||||
auto annotation = dyn_cast<caffe2::Caffe2Annotation>(conv->getMutableAnnotation());
|
||||
OperatorDef* op = annotation->getMutableOperatorDef();
|
||||
// Do stuff with the caffe2 protobuf
|
||||
}
|
||||
```
|
||||
|
||||
If you create a new op, as shown in the example above and copied here:
|
||||
```cpp
|
||||
auto reluNode = nn.dataFlow.createNode(make_unique<nom::repr::Relu>());
|
||||
```
|
||||
it will not have a caffe2 annotation.
|
||||
|
||||
How does `caffe2::convertToCaffe2Proto(nom::repr::NNModule*, caffe2::NetDef)` deal with this?
|
||||
|
||||
Operators are either generated manually (see the implementation in `caffe2/opt/converter.cc`) or automatically.
|
||||
The automatic generation is done by simply setting the operator `type` to the name of the operator.
|
||||
If you'd like to add your own operator to a net and need it to be generated (i.e. are writing a transform that inserts
|
||||
new nodes which have attributes likes args) you will need to add your own code to `caffe2/opt/converter.cc`.
|
||||
|
||||
Do not create `OperatorDef`s in the transformation itself! This is an anti-pattern as the logic becomes less portable.
|
||||
|
||||
## API
|
||||
|
||||
Below is a subset of selected API calls that are quite useful. Lower level manipulation calls are omitted.
|
||||
|
||||
### Graph transformation API
|
||||
Nomnigraph provides a ReplaceSubgraph API to perform graph transformation operations without having to write custom subgraph matching logic. The main header file is [SubgraphMatcher.h](include/nomnigraph/Transformations/SubgraphMatcher.h).
|
||||
|
||||
ReplaceSubgraph API takes in
|
||||
- A subgraph pattern to be matched
|
||||
- A graph to be scanned for matching patterns
|
||||
- A ReplaceGraph lambda function that takes in a matched subgraph; callers should implement specific graph transformation operation in the lambda.
|
||||
|
||||
The ReplaceSubgraph implementation takes care of the pattern matching part and also provides tools for callers to implement graph transformation logic with less effort.
|
||||
|
||||
Example usage of the API can be found in [subgraph_matcher_test.cc](tests/subgraph_matcher_test.cc)
|
||||
|
||||
Example usage of the API for NNGraph can be found in [neural_net_test.cc](tests/neural_net_test.cc)
|
||||
|
||||
### Graph API
|
||||
Nomnigraph's core graph APIs provide a generic graph data structure and basic graph manipulation abilities. The main header file is [Graph.h](include/nomnigraph/Graph/Graph.h).
|
||||
|
||||
```cpp
|
||||
auto g = Graph<T>(); // Constructor
|
||||
|
||||
Graph<T>::NodeRef n = g.createNode(T t); // Returns reference to the node
|
||||
|
||||
Graph<T>::EdgeRef e = g.createEdge(n1, n2); // Returns reference to the edge
|
||||
|
||||
g.deleteNode(n); // Deletes the node and all of its in/out edges from the graph
|
||||
// Use g.deleteNode(n, false); to keep the edges around.
|
||||
|
||||
g.deleteEdge(e); // Deletes the edge between two nodes.
|
||||
|
||||
auto e = g.getEdge(n1, n2); // Gets the first edge that has n1 as a tail and n2 as the head.
|
||||
|
||||
auto ns = g.getMutableNodes(); // Returns a vector of Graph<T>::NodeRef
|
||||
|
||||
auto es = g.getMutableEdges(); // Returns a vector of Graph<T>::EdgeRef
|
||||
|
||||
T d = n->data(); // Get the data stored at the node
|
||||
```
|
||||
|
||||
### NN API
|
||||
NN (NeuralNet) extends core Graph with functionalities specific to neural network computation graph. The main header file is [NeuralNet.h](include/nomnigraph/Representations/NeuralNet.h).
|
||||
|
||||
Type checking & data accessing
|
||||
|
||||
```cpp
|
||||
repr::NNModule nn = ...;
|
||||
using namespace nom;
|
||||
|
||||
repr::NNGraph::NodeRef n; // Canonical node of the neural network
|
||||
|
||||
bool b = repr::nn::is<repr::Tensor>(n); // Checks the type stored on the node. (Works with parent types.)
|
||||
|
||||
repr::Conv* c = repr::nn::get<repr::Conv>(n); // Returns a pointer to the NeuralNetOperator or NeuralNetData in the node
|
||||
```
|
||||
|
||||
Iterate through nodes in a NNGraph.
|
||||
```cpp
|
||||
auto pairs = dataIterator(nn); // A useful paradigm for iterating through nodes and corresponding data in no particular order.
|
||||
auto nodeRefs = nodeIterator(nn); // Iterate through nodes in no particular order.
|
||||
// See https://github.com/pytorch/pytorch/blob/main/caffe2/opt/mobile.cc#L106-L109
|
||||
```
|
||||
|
||||
|
||||
These functions make it easy to check attributes on nodes.
|
||||
```cpp
|
||||
// -- Tensor node functions --
|
||||
bool b = hasProducer(tensorNode); // Checks for producers.
|
||||
auto n = getProducer(tensorNode); // Returns the producer of the tensor
|
||||
bool b = hasConsumer(tensorNode); // Checks for consumers.
|
||||
std::vector<NNGraph::NodeRef> consumers = getConsumers(tensorNode); // Returns a vector of all consumers of the tensor.
|
||||
|
||||
// -- Operator node functions --
|
||||
bool b = hasInputs(n); // Checks if there are any input tensors.
|
||||
std::vector<NNGraph::NodeRef> getInputs(n); // Returns a vector of all the input tensor nodes.
|
||||
std::vector<NNGraph::NodeRef> getOutputs(n); // Returns a vector of all the output tensor nodes.
|
||||
```
|
||||
|
||||
These functions are less commonly useful
|
||||
```cpp
|
||||
coalesceInsertedDataDependencies(&nn); // Fixes up all the inserted dependencies in the dataflow graph.
|
||||
|
||||
insertOp<repr::Relu>(nn.dataFlow, n1, n2); // Inserts an operator into the dataflow graph and creates a new blob to do so.
|
||||
// n1 or n2 must be a tensor and the inserted blob inherits the name from that, appending an underscore.
|
||||
|
||||
convertNode<repr::ConvRelu>(nn.dataFlow, n); // Converts the data at the node to a new node by calling the passed in type with the old node's data as the constructor argument.
|
||||
```
|
@ -1,245 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import argparse
|
||||
from textwrap import dedent
|
||||
from subprocess import call
|
||||
|
||||
|
||||
def parse_lines(lines):
|
||||
# States
|
||||
EMPTY = 0
|
||||
OP = 1
|
||||
MACRO = 2
|
||||
parse_state = EMPTY
|
||||
|
||||
# Preprocess the macros
|
||||
curr_macro = ""
|
||||
macros = {}
|
||||
|
||||
index = 0
|
||||
while index < len(lines):
|
||||
line = lines[index]
|
||||
if line.lower().startswith("macro"):
|
||||
assert parse_state == EMPTY
|
||||
macro_line = line.split(" ")
|
||||
# Support macros that look like attributes
|
||||
# e.g. macro - CONV_LIKE
|
||||
curr_macro = " ".join(macro_line[1:])
|
||||
assert curr_macro not in macros, 'Macro "{}" defined twice.'.format(
|
||||
curr_macro
|
||||
)
|
||||
macros[curr_macro] = []
|
||||
parse_state = MACRO
|
||||
lines = lines[:index] + lines[index + 1 :]
|
||||
continue
|
||||
elif line.lower().startswith("endmacro"):
|
||||
assert parse_state == MACRO
|
||||
parse_state = EMPTY
|
||||
lines = lines[:index] + lines[index + 1 :]
|
||||
continue
|
||||
elif parse_state == MACRO:
|
||||
macros[curr_macro].append(line)
|
||||
lines = lines[:index] + lines[index + 1 :]
|
||||
continue
|
||||
index += 1
|
||||
|
||||
index = 0
|
||||
while index < len(lines):
|
||||
line = lines[index]
|
||||
if line in macros:
|
||||
lines = lines[:index] + macros[line] + lines[index + 1 :]
|
||||
index += len(macros[line]) - 1
|
||||
index += 1
|
||||
|
||||
# Now parse the file
|
||||
curr_op = ""
|
||||
# dict of the form
|
||||
# opName : { attributes: [], ... }
|
||||
ops = {}
|
||||
# To preserve parsing order for dependencies (for things like init_from)
|
||||
op_list = []
|
||||
|
||||
for line in lines:
|
||||
if not len(line):
|
||||
continue
|
||||
if line[0] == "-":
|
||||
assert parse_state is OP
|
||||
attr = [_.strip() for _ in line[1:].split(":")]
|
||||
assert attr[0][0].isupper()
|
||||
if len(attr) == 2: # attribute : type
|
||||
ops[curr_op]["attributes"].append((attr[0], attr[1]))
|
||||
elif len(attr) == 3: # attribute : type
|
||||
ops[curr_op]["attributes"].append((attr[0], attr[1], attr[2]))
|
||||
else:
|
||||
op = [l.strip() for l in line.split(":")]
|
||||
assert len(op[0].split(" ")) == 1
|
||||
parse_state = OP
|
||||
curr_op = op[0]
|
||||
assert curr_op not in ops
|
||||
ops[curr_op] = {}
|
||||
op_list.append(curr_op)
|
||||
if len(op) > 1:
|
||||
ops[curr_op]["init_from"] = [op[1]]
|
||||
ops[curr_op]["attributes"] = []
|
||||
return ops, op_list
|
||||
|
||||
|
||||
def gen_class(op, op_def):
|
||||
attributes = op_def["attributes"]
|
||||
attribute_args = []
|
||||
default_init = "NeuralNetOperator(NNKind::{op})".format(op=op)
|
||||
attribute_init = [default_init]
|
||||
attribute_declarations = []
|
||||
attribute_getters = []
|
||||
attribute_setters = []
|
||||
for attr in attributes:
|
||||
lower_name = attr[0][0].lower() + attr[0][1:]
|
||||
private_name = lower_name + "_"
|
||||
default_arg = "" if len(attr) < 3 else " = {}".format(attr[2])
|
||||
name = attr[0]
|
||||
t = attr[1]
|
||||
attr_arg = "{type} {lower_name}".format(
|
||||
type=t, lower_name=lower_name + default_arg
|
||||
)
|
||||
attr_init = "{private_name}({lower_name})".format(
|
||||
private_name=private_name, lower_name=lower_name)
|
||||
attr_declare = "{type} {private_name};".format(
|
||||
type=t, private_name=private_name)
|
||||
attr_get = dedent(
|
||||
"""
|
||||
{type} get{name}() const {{
|
||||
return {private_name};
|
||||
}}
|
||||
""".format(
|
||||
type=t, name=name, private_name=private_name
|
||||
)
|
||||
)
|
||||
attr_set = dedent(
|
||||
"""
|
||||
void set{name}({type} {lower_name}) {{
|
||||
{private_name} = {lower_name};
|
||||
}}
|
||||
""".format(
|
||||
type=t, name=name, private_name=private_name, lower_name=lower_name
|
||||
)
|
||||
)
|
||||
attribute_args.append(attr_arg)
|
||||
attribute_init.append(attr_init)
|
||||
attribute_declarations.append(attr_declare)
|
||||
attribute_getters.append(attr_get)
|
||||
attribute_setters.append(attr_set)
|
||||
|
||||
extra_init = ""
|
||||
if "init_from" in op_def:
|
||||
for other_op in op_def["init_from"]:
|
||||
lower_other_op = other_op[0].lower() + other_op[1:]
|
||||
other_init = [default_init]
|
||||
for attr in attributes:
|
||||
lower_name = attr[0][0].lower() + attr[0][1:]
|
||||
private_name = lower_name + "_"
|
||||
other_init.append(
|
||||
"{private_name}({other_op}.get{name}())".format(
|
||||
name=attr[0], private_name=private_name, other_op=lower_other_op
|
||||
)
|
||||
)
|
||||
init = dedent(
|
||||
"""
|
||||
{op}(const {other_op}& {lower_other_op}) :
|
||||
{other_init} {{}}
|
||||
""".format(
|
||||
op=op,
|
||||
other_op=other_op,
|
||||
lower_other_op=lower_other_op,
|
||||
other_init=",\n ".join(other_init),
|
||||
)
|
||||
)
|
||||
extra_init += init
|
||||
|
||||
return dedent(
|
||||
"""
|
||||
class {op} : public NeuralNetOperator {{
|
||||
public:
|
||||
{op}({attribute_args}) :
|
||||
{attribute_init} {{}}
|
||||
{extra_init}
|
||||
~{op}() {{}}
|
||||
|
||||
NOMNIGRAPH_DEFINE_NN_RTTI({op});
|
||||
{getters}{setters}
|
||||
private:
|
||||
{attribute_declarations}
|
||||
}};
|
||||
|
||||
""".format(
|
||||
op=op,
|
||||
extra_init=extra_init,
|
||||
getters="".join(attribute_getters),
|
||||
setters="".join(attribute_setters),
|
||||
attribute_args=",\n".join(attribute_args),
|
||||
attribute_init=",\n".join(attribute_init),
|
||||
attribute_declarations="\n".join(attribute_declarations),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def gen_classes(ops, op_list):
|
||||
f = ""
|
||||
for op in op_list:
|
||||
f += gen_class(op, ops[op])
|
||||
return f
|
||||
|
||||
|
||||
def gen_enum(op_list):
|
||||
return ",\n".join([op for op in op_list]) + "\n"
|
||||
|
||||
|
||||
def gen_names(op_list):
|
||||
f = ""
|
||||
for op in op_list:
|
||||
f += dedent(
|
||||
"""
|
||||
case NNKind::{name}:
|
||||
return \"{name}\";
|
||||
""".format(
|
||||
name=op
|
||||
)
|
||||
)
|
||||
return f
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate op files.")
|
||||
parser.add_argument("--install_dir", help="installation directory")
|
||||
parser.add_argument("--source_def", help="ops.def", action="append")
|
||||
args = parser.parse_args()
|
||||
install_dir = args.install_dir
|
||||
sources = args.source_def
|
||||
|
||||
lines = []
|
||||
for source in sources:
|
||||
with open(source, "rb") as f:
|
||||
lines_tmp = f.readlines()
|
||||
lines += [l.strip().decode("utf-8") for l in lines_tmp]
|
||||
ops, op_list = parse_lines(lines)
|
||||
|
||||
with open(install_dir + "/OpClasses.h", "wb") as f:
|
||||
f.write(gen_classes(ops, op_list).encode("utf-8"))
|
||||
with open(install_dir + "/OpNames.h", "wb") as f:
|
||||
f.write(gen_names(op_list).encode("utf-8"))
|
||||
with open(install_dir + "/OpEnum.h", "wb") as f:
|
||||
f.write(gen_enum(op_list).encode("utf-8"))
|
||||
|
||||
try:
|
||||
cmd = ["clang-format", "-i", install_dir + "/OpClasses.h"]
|
||||
call(cmd)
|
||||
cmd = ["clang-format", "-i", install_dir + "/OpNames.h"]
|
||||
call(cmd)
|
||||
cmd = ["clang-format", "-i", install_dir + "/OpEnum.h"]
|
||||
call(cmd)
|
||||
except Exception:
|
||||
pass
|
@ -1,80 +0,0 @@
|
||||
macro - CONV_ATTRS
|
||||
- KernelShape : vector<int>
|
||||
- Pads : vector<int> : {0, 0}
|
||||
- Strides : vector<int> : {1, 1}
|
||||
- Group : int : 1
|
||||
- Dilations: vector<int> : {1, 1}
|
||||
endmacro
|
||||
|
||||
macro - POOL_ATTRS
|
||||
- KernelShape : vector<int>
|
||||
- Pads : vector<int> : {0, 0}
|
||||
- Strides : vector<int> : {1, 1}
|
||||
endmacro
|
||||
|
||||
Relu
|
||||
|
||||
Conv
|
||||
- CONV_ATTRS
|
||||
|
||||
ConvRelu : Conv
|
||||
- CONV_ATTRS
|
||||
|
||||
ConvTranspose
|
||||
- CONV_ATTRS
|
||||
|
||||
AveragePool
|
||||
- POOL_ATTRS
|
||||
|
||||
AveragePoolRelu : AveragePool
|
||||
- POOL_ATTRS
|
||||
|
||||
MaxPool
|
||||
- POOL_ATTRS
|
||||
|
||||
MaxPoolRelu : MaxPool
|
||||
- POOL_ATTRS
|
||||
|
||||
Sum
|
||||
|
||||
SumRelu : Sum
|
||||
|
||||
Send
|
||||
- Destination : string
|
||||
|
||||
Receive
|
||||
- Source : string
|
||||
|
||||
BatchNormalization
|
||||
- Epsilon : float : 1e-5f
|
||||
- Momentum : float : 0.9f
|
||||
- Spatial : bool : true
|
||||
- IsTest : bool : false
|
||||
|
||||
Clip
|
||||
- Min : float
|
||||
- Max : float
|
||||
|
||||
FC
|
||||
- Axis : int : 1
|
||||
- AxisW : int : 1
|
||||
|
||||
GivenTensorFill
|
||||
Concat
|
||||
- Axis : int : -1
|
||||
- AddAxis : bool : false
|
||||
Softmax
|
||||
ChannelShuffle
|
||||
Add
|
||||
- Broadcast : int : 0
|
||||
Reshape
|
||||
Flatten
|
||||
|
||||
CopyToOpenCL
|
||||
CopyFromOpenCL
|
||||
|
||||
NCHW2NHWC
|
||||
NHWC2NCHW
|
||||
|
||||
Declare
|
||||
Export
|
@ -1,126 +0,0 @@
|
||||
#include "test_util.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
TEST(DominatorTree, Test1) {
|
||||
nom::Graph<std::string> graph;
|
||||
auto r = graph.createNode(std::string("r"));
|
||||
auto a = graph.createNode(std::string("a"));
|
||||
auto b = graph.createNode(std::string("b"));
|
||||
auto c = graph.createNode(std::string("c"));
|
||||
auto d = graph.createNode(std::string("d"));
|
||||
auto e = graph.createNode(std::string("e"));
|
||||
auto f = graph.createNode(std::string("f"));
|
||||
auto g = graph.createNode(std::string("g"));
|
||||
auto l = graph.createNode(std::string("l"));
|
||||
auto h = graph.createNode(std::string("h"));
|
||||
auto i = graph.createNode(std::string("i"));
|
||||
auto j = graph.createNode(std::string("j"));
|
||||
auto k = graph.createNode(std::string("k"));
|
||||
graph.createEdge(r, a);
|
||||
graph.createEdge(r, b);
|
||||
graph.createEdge(r, c);
|
||||
graph.createEdge(c, f);
|
||||
graph.createEdge(c, g);
|
||||
graph.createEdge(g, j);
|
||||
graph.createEdge(g, i);
|
||||
graph.createEdge(f, i);
|
||||
graph.createEdge(i, k);
|
||||
graph.createEdge(k, i);
|
||||
graph.createEdge(k, r);
|
||||
graph.createEdge(a, d);
|
||||
graph.createEdge(b, d);
|
||||
graph.createEdge(b, a);
|
||||
graph.createEdge(b, e);
|
||||
graph.createEdge(d, l);
|
||||
graph.createEdge(l, h);
|
||||
graph.createEdge(h, k);
|
||||
graph.createEdge(h, e);
|
||||
graph.createEdge(e, h);
|
||||
|
||||
auto tree = nom::algorithm::dominatorTree(&graph, r);
|
||||
auto map = nom::algorithm::immediateDominatorMap(&graph, r);
|
||||
|
||||
EXPECT_EQ(map[j], g);
|
||||
EXPECT_EQ(map[g], c);
|
||||
EXPECT_EQ(map[f], c);
|
||||
EXPECT_EQ(map[l], d);
|
||||
EXPECT_EQ(map[a], r);
|
||||
EXPECT_EQ(map[b], r);
|
||||
EXPECT_EQ(map[c], r);
|
||||
EXPECT_EQ(map[d], r);
|
||||
EXPECT_EQ(map[e], r);
|
||||
EXPECT_EQ(map[h], r);
|
||||
EXPECT_EQ(map[i], r);
|
||||
EXPECT_EQ(map[k], r);
|
||||
auto domFrontMap = nom::algorithm::dominanceFrontierMap(&graph, r);
|
||||
}
|
||||
|
||||
// https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec04-SSA.pdf
|
||||
// using example on page 24
|
||||
TEST(DominatorTree, Test2) {
|
||||
nom::Graph<std::string> graph;
|
||||
auto entry = graph.createNode(std::string("entry"));
|
||||
auto n1 = graph.createNode(std::string("1"));
|
||||
auto n2 = graph.createNode(std::string("2"));
|
||||
auto n3 = graph.createNode(std::string("3"));
|
||||
auto n4 = graph.createNode(std::string("4"));
|
||||
auto n5 = graph.createNode(std::string("5"));
|
||||
auto n6 = graph.createNode(std::string("6"));
|
||||
auto n7 = graph.createNode(std::string("7"));
|
||||
auto exit = graph.createNode(std::string("exit"));
|
||||
graph.createEdge(entry, n1);
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n1, n5);
|
||||
graph.createEdge(n5, n1);
|
||||
graph.createEdge(n2, n3);
|
||||
graph.createEdge(n2, n4);
|
||||
graph.createEdge(n3, n6);
|
||||
graph.createEdge(n4, n6);
|
||||
graph.createEdge(n6, n7);
|
||||
graph.createEdge(n5, n7);
|
||||
graph.createEdge(n7, exit);
|
||||
|
||||
auto domFrontMap = nom::algorithm::dominanceFrontierMap(&graph, entry);
|
||||
using noderef = nom::Graph<std::string>::NodeRef;
|
||||
std::unordered_map<noderef, std::unordered_set<noderef>> checkMap = {
|
||||
{n1, {n1}},
|
||||
{n2, {n7}},
|
||||
{n3, {n6}},
|
||||
{n4, {n6}},
|
||||
{n5, {n1, n7}},
|
||||
{n6, {n7}}
|
||||
};
|
||||
// NOLINTNEXTLINE(performance-for-range-copy)
|
||||
for (auto pair : domFrontMap) {
|
||||
EXPECT_EQ(pair.second, checkMap[pair.first]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Subgraph, InduceEdges) {
|
||||
auto g = createGraph();
|
||||
auto sg = decltype(g)::SubgraphType();
|
||||
for (const auto& node : g.getMutableNodes()) {
|
||||
sg.addNode(node);
|
||||
}
|
||||
|
||||
nom::algorithm::induceEdges(&sg);
|
||||
|
||||
for (const auto& edge : g.getMutableEdges()) {
|
||||
EXPECT_TRUE(sg.hasEdge(edge));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Subgraph, InduceEdgesCycle) {
|
||||
auto g = createGraphWithCycle();
|
||||
auto sg = decltype(g)::SubgraphType();
|
||||
for (const auto& node : g.getMutableNodes()) {
|
||||
sg.addNode(node);
|
||||
}
|
||||
|
||||
nom::algorithm::induceEdges(&sg);
|
||||
|
||||
for (const auto& edge : g.getMutableEdges()) {
|
||||
EXPECT_TRUE(sg.hasEdge(edge));
|
||||
}
|
||||
}
|
@ -1,123 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <set>
|
||||
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Converters/Dot.h"
|
||||
#include "nomnigraph/Graph/Algorithms.h"
|
||||
#include "nomnigraph/Graph/Graph.h"
|
||||
|
||||
TEST(BinaryMatch, NoMatch) {
|
||||
auto graph = createGraph();
|
||||
auto matches = nom::algorithm::binaryMatch(
|
||||
&graph, [](decltype(graph)::NodeRef n) { return false; });
|
||||
EXPECT_EQ(matches.size(), 0);
|
||||
}
|
||||
|
||||
TEST(BinaryMatch, AllMatch) {
|
||||
auto graph = createGraph();
|
||||
auto matches = nom::algorithm::binaryMatch(
|
||||
&graph, [](decltype(graph)::NodeRef n) { return true; });
|
||||
EXPECT_EQ(matches.size(), 1);
|
||||
EXPECT_EQ(matches.front().getNodesCount(), graph.getMutableNodes().size());
|
||||
}
|
||||
|
||||
TEST(BinaryMatch, EmptyGraph) {
|
||||
nom::Graph<std::string> graph;
|
||||
auto matches = nom::algorithm::binaryMatch(
|
||||
&graph, [](decltype(graph)::NodeRef n) { return true; });
|
||||
EXPECT_EQ(matches.size(), 0);
|
||||
}
|
||||
|
||||
// We should get this back:
|
||||
// +---+ +-------+
|
||||
// | 4 | <-- | 2 |
|
||||
// +---+ +-------+
|
||||
// | |
|
||||
// | |
|
||||
// | v
|
||||
// | +-------+
|
||||
// | | 3 |
|
||||
// | +-------+
|
||||
// | |
|
||||
// | |
|
||||
// | v
|
||||
// | +-------+
|
||||
// +-----> | 6 |
|
||||
// +-------+
|
||||
TEST(BinaryMatch, Basic) {
|
||||
auto graph = createGraph();
|
||||
auto matches =
|
||||
nom::algorithm::binaryMatch(&graph, [](decltype(graph)::NodeRef n) {
|
||||
if (n->data() == "2" || n->data() == "3" || n->data() == "4" ||
|
||||
n->data() == "6") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
EXPECT_EQ(matches.size(), 1);
|
||||
auto match = matches.front();
|
||||
EXPECT_EQ(match.getNodesCount(), 4);
|
||||
std::set<std::string> exp{"2", "3", "4", "6"};
|
||||
for (auto n : match.getNodes()) {
|
||||
EXPECT_EQ(exp.count(n->data()), 1);
|
||||
exp.erase(n->data());
|
||||
}
|
||||
|
||||
// We found all the those nodes.
|
||||
EXPECT_EQ(exp.size(), 0);
|
||||
}
|
||||
|
||||
// The interesting bit about this test case is that
|
||||
// the predicate does not match on 3.
|
||||
//
|
||||
// As such, this part of the graph
|
||||
// +---+ +-------+
|
||||
// | 4 | <-- | 2 |
|
||||
// +---+ +-------+
|
||||
// | |
|
||||
// | |
|
||||
// | v
|
||||
// | +-------+
|
||||
// | | 3 |
|
||||
// | +-------+
|
||||
// | |
|
||||
// | |
|
||||
// | v
|
||||
// | +-------+
|
||||
// +-----> | 6 |
|
||||
// +-------+
|
||||
//
|
||||
// should match as { 4, 2 }, { 6 } not { 4, 2, 6 }
|
||||
TEST(BinaryMatch, RemovedMiddleNode) {
|
||||
auto graph = createGraph();
|
||||
auto matches =
|
||||
nom::algorithm::binaryMatch(&graph, [](decltype(graph)::NodeRef n) {
|
||||
if (n->data() == "2" || n->data() == "4" || n->data() == "6") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
EXPECT_EQ(matches.size(), 2);
|
||||
auto match1 = matches.front();
|
||||
auto match2 = matches.back();
|
||||
|
||||
EXPECT_EQ(match1.getNodesCount(), 2);
|
||||
EXPECT_EQ(match2.getNodesCount(), 1);
|
||||
|
||||
std::set<std::string> exp1{"2", "4"};
|
||||
std::set<std::string> exp2{"6"};
|
||||
for (auto n : match1.getNodes()) {
|
||||
EXPECT_EQ(exp1.count(n->data()), 1);
|
||||
exp1.erase(n->data());
|
||||
}
|
||||
for (auto n : match2.getNodes()) {
|
||||
EXPECT_EQ(exp2.count(n->data()), 1);
|
||||
exp2.erase(n->data());
|
||||
}
|
||||
|
||||
EXPECT_EQ(exp1.size(), 0);
|
||||
EXPECT_EQ(exp2.size(), 0);
|
||||
}
|
@ -1,219 +0,0 @@
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Converters/Dot.h"
|
||||
#include "nomnigraph/Graph/Algorithms.h"
|
||||
#include "nomnigraph/Graph/Graph.h"
|
||||
#include "nomnigraph/Support/Casting.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using TestGraph = nom::Graph<TestClass>;
|
||||
|
||||
TEST(Basic, CreateNodeAndEdge) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
EXPECT_TRUE(g.hasNode(n1));
|
||||
EXPECT_TRUE(g.hasNode(n2));
|
||||
}
|
||||
|
||||
TEST(Basic, DeleteNode) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
EXPECT_TRUE(g.hasNode(n1));
|
||||
g.deleteNode(n1);
|
||||
EXPECT_FALSE(g.hasNode(n1));
|
||||
}
|
||||
|
||||
TEST(Basic, DeleteEdge) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto e = g.createEdge(n1, n2);
|
||||
g.deleteEdge(e);
|
||||
}
|
||||
|
||||
TEST(Basic, ReplaceEdges) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
auto n4 = createTestNode(g);
|
||||
auto n5 = createTestNode(g);
|
||||
|
||||
g.createEdge(n1, n3);
|
||||
g.createEdge(n2, n3);
|
||||
g.createEdge(n3, n4);
|
||||
/*
|
||||
1 2 5
|
||||
|
|
||||
3
|
||||
|
|
||||
4
|
||||
*/
|
||||
|
||||
EXPECT_FALSE(g.hasEdge(n1, n5));
|
||||
EXPECT_FALSE(g.hasEdge(n2, n5));
|
||||
g.replaceInEdges(n3, n5);
|
||||
/*
|
||||
1 2 3
|
||||
| |
|
||||
5 4
|
||||
*/
|
||||
EXPECT_TRUE(g.hasEdge(n1, n5));
|
||||
EXPECT_TRUE(g.hasEdge(n2, n5));
|
||||
|
||||
EXPECT_FALSE(g.hasEdge(n5, n4));
|
||||
g.replaceOutEdges(n3, n5);
|
||||
/*
|
||||
1 2 3
|
||||
|
|
||||
5
|
||||
|
|
||||
4
|
||||
*/
|
||||
EXPECT_TRUE(g.hasEdge(n5, n4));
|
||||
|
||||
g.replaceNode(n5, n3);
|
||||
// Back to the original graph.
|
||||
/*
|
||||
1 2 5
|
||||
|
|
||||
3
|
||||
|
|
||||
4
|
||||
*/
|
||||
EXPECT_TRUE(g.hasEdge(n1, n3));
|
||||
EXPECT_TRUE(g.hasEdge(n2, n3));
|
||||
}
|
||||
|
||||
TEST(Basic, HasNode) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
g.createEdge(n1, n3);
|
||||
// Current graph: 1 -> 2 -> 3
|
||||
EXPECT_TRUE(g.hasNode(n1));
|
||||
EXPECT_TRUE(g.hasNode(n2));
|
||||
EXPECT_TRUE(g.hasNode(n3));
|
||||
g.swapNodes(n1, n3);
|
||||
// Current graph: 3 -> 2 -> 1
|
||||
EXPECT_TRUE(g.hasNode(n1));
|
||||
EXPECT_TRUE(g.hasNode(n3));
|
||||
g.deleteNode(n1);
|
||||
// Current graph: 3 -> 2
|
||||
EXPECT_FALSE(g.hasNode(n1));
|
||||
auto n4 = createTestNode(g);
|
||||
EXPECT_TRUE(g.hasNode(n4));
|
||||
g.replaceNode(n2, n4);
|
||||
// Current graph: 3 -> 4 , 2
|
||||
// replaceNode doesn't delete n2.
|
||||
EXPECT_TRUE(g.hasNode(n2));
|
||||
|
||||
// Create a second graph g2, and move the nodes from g2 to g.
|
||||
TestClass t5;
|
||||
nom::Graph<TestClass> g2;
|
||||
nom::Graph<TestClass>::NodeRef n5 = g2.createNode(std::move(t5));
|
||||
EXPECT_TRUE(g2.hasNode(n5));
|
||||
|
||||
EXPECT_FALSE(g.hasNode(n5));
|
||||
g2.moveNode(n5, &g);
|
||||
// Current graph (g1): 3 -> 4, 2, 5
|
||||
EXPECT_TRUE(g.hasNode(n5));
|
||||
}
|
||||
|
||||
TEST(Basic, Moves) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
auto e1 = g.createEdge(n1, n2);
|
||||
auto e2 = g.createEdge(n1, n3);
|
||||
// Current graph: 1 -> 2 -> 3
|
||||
|
||||
TestGraph g2;
|
||||
g.deleteEdge(e2);
|
||||
g.moveNode(n1, &g2);
|
||||
g.moveNode(n2, &g2);
|
||||
g.moveEdge(e1, &g2);
|
||||
EXPECT_TRUE(g.isValid());
|
||||
EXPECT_TRUE(g2.isValid());
|
||||
EXPECT_EQ(g.getMutableNodes().size(), 1);
|
||||
EXPECT_EQ(g2.getMutableNodes().size(), 2);
|
||||
EXPECT_EQ(g.getMutableEdges().size(), 0);
|
||||
EXPECT_EQ(g2.getMutableEdges().size(), 1);
|
||||
}
|
||||
|
||||
TEST(Basic, MoveSubgraph) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
auto e1 = g.createEdge(n1, n2);
|
||||
auto e2 = g.createEdge(n1, n3);
|
||||
// Current graph: 1 -> 2 -> 3
|
||||
|
||||
TestGraph g2;
|
||||
|
||||
g.deleteEdge(e2);
|
||||
|
||||
TestGraph::SubgraphType sg;
|
||||
sg.addNode(n1);
|
||||
sg.addNode(n2);
|
||||
sg.addEdge(e1);
|
||||
|
||||
g.moveSubgraph(sg, &g2);
|
||||
|
||||
EXPECT_TRUE(g.isValid());
|
||||
EXPECT_TRUE(g2.isValid());
|
||||
EXPECT_EQ(g.getMutableNodes().size(), 1);
|
||||
EXPECT_EQ(g2.getMutableNodes().size(), 2);
|
||||
EXPECT_EQ(g.getMutableEdges().size(), 0);
|
||||
EXPECT_EQ(g2.getMutableEdges().size(), 1);
|
||||
}
|
||||
|
||||
TEST(Basic, DotGenerator) {
|
||||
TestGraph g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
auto e12 = g.createEdge(n1, n2);
|
||||
g.createEdge(n1, n3);
|
||||
|
||||
std::string dot = nom::converters::convertToDotString(&g, TestNodePrinter);
|
||||
|
||||
// sanity check
|
||||
std::string prefix = "digraph G";
|
||||
// Full string comparison of the output is not stable because the dot
|
||||
// string includes node pointer address as node id. We should switch to
|
||||
// comparing full output once dot generator no longer uses addresses.
|
||||
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
|
||||
|
||||
TestGraph::SubgraphType sg;
|
||||
sg.addNode(n1);
|
||||
sg.addNode(n2);
|
||||
sg.addEdge(e12);
|
||||
|
||||
// Convert to dot with subgraph clusters.
|
||||
dot = nom::converters::convertToDotString<TestGraph>(
|
||||
&g, {&sg}, TestNodePrinter);
|
||||
|
||||
// sanity check
|
||||
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
|
||||
|
||||
// Convert a single subgraph to dot.
|
||||
dot = nom::converters::convertToDotString<TestGraph>(sg, TestNodePrinter);
|
||||
|
||||
// sanity check
|
||||
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
|
||||
|
||||
dot =
|
||||
nom::converters::convertToDotRecordString<TestGraph>(&g, TestNodePrinter);
|
||||
// sanity check
|
||||
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Transformations/Match.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
TEST(Match, Basic) {
|
||||
nom::Graph<std::string> graph;
|
||||
auto entry = graph.createNode(std::string("entry"));
|
||||
auto n1 = graph.createNode(std::string("1"));
|
||||
auto n2 = graph.createNode(std::string("2"));
|
||||
auto n3 = graph.createNode(std::string("3"));
|
||||
auto n4 = graph.createNode(std::string("4"));
|
||||
auto n5 = graph.createNode(std::string("5"));
|
||||
auto n6 = graph.createNode(std::string("6"));
|
||||
auto n7 = graph.createNode(std::string("7"));
|
||||
auto exit = graph.createNode(std::string("exit"));
|
||||
graph.createEdge(entry, n1);
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n1, n5);
|
||||
graph.createEdge(n5, n1);
|
||||
graph.createEdge(n2, n3);
|
||||
graph.createEdge(n2, n4);
|
||||
graph.createEdge(n3, n6);
|
||||
graph.createEdge(n4, n6);
|
||||
graph.createEdge(n6, n7);
|
||||
graph.createEdge(n5, n7);
|
||||
graph.createEdge(n7, exit);
|
||||
|
||||
nom::Graph<std::string> match_graph;
|
||||
auto m1 = match_graph.createNode(std::string("1"));
|
||||
auto m2 = match_graph.createNode(std::string("2"));
|
||||
match_graph.createEdge(m1, m2);
|
||||
|
||||
nom::Match<decltype(graph)> m(match_graph);
|
||||
EXPECT_EQ(m.match(graph).size(), 1);
|
||||
}
|
@ -1,97 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Representations/NeuralNet.h"
|
||||
#include "nomnigraph/Transformations/SubgraphMatcher.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace nom;
|
||||
using namespace nom::repr;
|
||||
using namespace nom::repr::nn;
|
||||
|
||||
// Test for the NNGraph subgraph matching APIs.
|
||||
TEST(NeuralNetGraph, ReplaceGraph) {
|
||||
NNGraph graph;
|
||||
|
||||
auto input1 = graph.createNode(std::make_unique<Tensor>("input1"));
|
||||
auto input2 = graph.createNode(std::make_unique<Tensor>("input2"));
|
||||
// Test renaming blob
|
||||
nn::get<Tensor>(input2)->setName("input2_renamed");
|
||||
auto sum = graph.createNode(std::make_unique<Sum>());
|
||||
auto sumOutput = graph.createNode(std::make_unique<Tensor>("sumOutput"));
|
||||
auto relu = graph.createNode(std::make_unique<Relu>());
|
||||
auto reluOutput = graph.createNode(std::make_unique<Tensor>("reluOutput"));
|
||||
|
||||
graph.createEdge(input1, sum);
|
||||
graph.createEdge(input2, sum);
|
||||
graph.createEdge(sum, sumOutput);
|
||||
graph.createEdge(sumOutput, relu);
|
||||
graph.createEdge(relu, reluOutput);
|
||||
|
||||
/* input1 input2
|
||||
\ /
|
||||
\ /
|
||||
sum
|
||||
|
|
||||
|
|
||||
sumOutput
|
||||
|
|
||||
relu
|
||||
|
|
||||
reluOutput
|
||||
*/
|
||||
|
||||
auto mg = NNMatchGraph();
|
||||
auto matchSumInput =
|
||||
mg.createNode(std::move(matchExternalTensorNode().count(2)));
|
||||
auto matchSum = mg.createNode(nn::is<Sum>);
|
||||
mg.createEdge(matchSumInput, matchSum);
|
||||
|
||||
auto matchSumOutput = mg.createNode(nn::is<Tensor>);
|
||||
mg.createEdge(matchSum, matchSumOutput);
|
||||
|
||||
auto matchRelu = mg.createNode(nn::is<Relu>);
|
||||
mg.createEdge(matchSumOutput, matchRelu);
|
||||
|
||||
auto matchRoot = matchRelu;
|
||||
EXPECT_FALSE(mg.isSubgraphMatch(sum, matchRoot).isMatch());
|
||||
EXPECT_FALSE(mg.isSubgraphMatch(reluOutput, matchRoot).isMatch());
|
||||
EXPECT_FALSE(mg.isSubgraphMatch(input1, matchRoot).isMatch());
|
||||
|
||||
EXPECT_TRUE(mg.isSubgraphMatch(relu, matchRoot).isMatch());
|
||||
|
||||
mg.replaceSubgraph(
|
||||
graph,
|
||||
matchRoot,
|
||||
[&matchSumOutput](
|
||||
NNGraph& g,
|
||||
NNGraph::NodeRef relu,
|
||||
const NNMatchGraph::SubgraphMatchResultType& matchResult) {
|
||||
auto fusedNode = g.createNode(std::make_unique<SumRelu>());
|
||||
auto sumNode =
|
||||
getProducer(matchResult.getMatchNodeMap()->at(matchSumOutput));
|
||||
g.replaceOutEdges(relu, fusedNode);
|
||||
g.replaceInEdges(sumNode, fusedNode);
|
||||
g.deleteNodes(matchResult.getMatchedSubgraph()->getNodes());
|
||||
return true;
|
||||
});
|
||||
|
||||
/*
|
||||
Fused graph:
|
||||
|
||||
input1 input2
|
||||
\ /
|
||||
\ /
|
||||
sumRelu
|
||||
|
|
||||
|
|
||||
output
|
||||
*/
|
||||
EXPECT_EQ(graph.getNodesCount(), 4);
|
||||
auto fusedNode = getProducer(reluOutput);
|
||||
EXPECT_TRUE(is<SumRelu>(fusedNode));
|
||||
EXPECT_EQ(getInputs(fusedNode).size(), 2);
|
||||
}
|
@ -1,646 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Transformations/SubgraphMatcher.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace nom {
|
||||
|
||||
namespace matcher {
|
||||
|
||||
using NodeType = std::string;
|
||||
using Criteria = std::string;
|
||||
using TestGraph = Graph<NodeType>;
|
||||
using TestMatchGraph = MatchGraph<TestGraph>;
|
||||
using TestMatchPredicate = MatchPredicate<TestGraph>;
|
||||
|
||||
// Have just one TestMatchGraph in the tests to make it less verbose to create
|
||||
// the match graphs.
|
||||
TestMatchGraph graph;
|
||||
// Call reset before creating a new TestMatchGraph.
|
||||
void reset() {
|
||||
graph = TestMatchGraph();
|
||||
}
|
||||
|
||||
// Node matches a criteria (string) if the data string is the same as the
|
||||
// criteria. Special case: "*" will match any thing.
|
||||
TestMatchPredicate testMatchPredicate(const Criteria& criteria) {
|
||||
return TestMatchPredicate([criteria](TestGraph::NodeRef node) {
|
||||
return criteria == "*" || criteria == node->data();
|
||||
});
|
||||
}
|
||||
|
||||
Criteria any() {
|
||||
return Criteria("*");
|
||||
}
|
||||
|
||||
// Helper methods to make it less verbose to create match graphs.
|
||||
TestMatchGraph::NodeRef Tree(
|
||||
const Criteria& root,
|
||||
const std::vector<TestMatchGraph::NodeRef>& children = {},
|
||||
int count = 1) {
|
||||
auto result =
|
||||
graph.createNode(std::move(testMatchPredicate(root).count(count)));
|
||||
for (auto& child : children) {
|
||||
graph.createEdge(result, child);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TestMatchGraph::NodeRef NonTerminal(const Criteria& root, int count = 1) {
|
||||
return graph.createNode(
|
||||
std::move(testMatchPredicate(root).count(count).nonTerminal()));
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> TestGraphNodePrinter(
|
||||
TestGraph::NodeRef node) {
|
||||
std::map<std::string, std::string> labelMap;
|
||||
labelMap["label"] = node->data();
|
||||
return labelMap;
|
||||
};
|
||||
|
||||
// Attempts to create a realistic dataflow graph that shows a fuse procedure.
|
||||
struct DataFlowTestGraph {
|
||||
const int numInputs = 4;
|
||||
|
||||
TestGraph graph;
|
||||
|
||||
TestGraph::NodeRef opB;
|
||||
TestGraph::NodeRef opF;
|
||||
TestGraph::NodeRef opC;
|
||||
TestGraph::NodeRef opG;
|
||||
TestGraph::NodeRef dataOut;
|
||||
|
||||
// Realistic data flow test graph.
|
||||
/*
|
||||
|
||||
|
||||
+---------------+
|
||||
| |
|
||||
| +---------+ | +---------+
|
||||
+---------------------+ | input_A | | | input_B |
|
||||
| +---------+ | +---------+
|
||||
| | | |
|
||||
| | | |
|
||||
| v v v
|
||||
+---------++---------+ +-------------------------+ +--------+
|
||||
| input_C || input_D | --> | opC | --> | dataC2 |
|
||||
+---------++---------+ +-------------------------+ +--------+
|
||||
|
|
||||
|
|
||||
v
|
||||
+---------+
|
||||
| dataC | -+
|
||||
+---------+ |
|
||||
| |
|
||||
| |
|
||||
v |
|
||||
+---------+ |
|
||||
| opB | <+
|
||||
+---------+
|
||||
|
|
||||
|
|
||||
v
|
||||
+---------+
|
||||
| dataB |
|
||||
+---------+
|
||||
|
|
||||
|
|
||||
v
|
||||
+---------+
|
||||
| opF |
|
||||
+---------+
|
||||
|
|
||||
|
|
||||
v
|
||||
+---------+
|
||||
| dataF |
|
||||
+---------+
|
||||
|
|
||||
|
|
||||
v
|
||||
+---------+ +---------+
|
||||
| dataI | --> | opG |
|
||||
+---------+ +---------+
|
||||
|
|
||||
|
|
||||
v
|
||||
+---------+
|
||||
| dataOut |
|
||||
+---------+
|
||||
*/
|
||||
DataFlowTestGraph() {
|
||||
opC = graph.createNode("opC");
|
||||
|
||||
for (int i = 0; i < numInputs; i++) {
|
||||
auto dataInput = graph.createNode("input");
|
||||
graph.createEdge(dataInput, opC);
|
||||
}
|
||||
|
||||
auto dataC = graph.createNode("dataC");
|
||||
auto dataC2 = graph.createNode("dataC2");
|
||||
graph.createEdge(opC, dataC);
|
||||
graph.createEdge(opC, dataC2);
|
||||
|
||||
opB = graph.createNode("opB");
|
||||
// There are 2 edges
|
||||
graph.createEdge(dataC, opB);
|
||||
graph.createEdge(dataC, opB);
|
||||
|
||||
auto dataB = graph.createNode("dataB");
|
||||
graph.createEdge(opB, dataB);
|
||||
|
||||
opF = graph.createNode("opF");
|
||||
graph.createEdge(dataB, opF);
|
||||
|
||||
auto dataF = graph.createNode("dataF");
|
||||
graph.createEdge(opF, dataF);
|
||||
|
||||
auto dataI = graph.createNode("dataI");
|
||||
|
||||
opG = graph.createNode("opG");
|
||||
graph.createEdge(dataF, opG);
|
||||
graph.createEdge(dataI, opG);
|
||||
|
||||
dataOut = graph.createNode("dataOut");
|
||||
graph.createEdge(opG, dataOut);
|
||||
|
||||
// Use nom::converters::convertToDotString(&graph, TestGraphNodePrinter)
|
||||
// to visualize the graph.
|
||||
}
|
||||
};
|
||||
|
||||
struct DataFlowTestGraphCriteria {
|
||||
TestMatchGraph::NodeRef matchOpCOutput;
|
||||
TestMatchGraph::NodeRef matchOpG;
|
||||
|
||||
DataFlowTestGraphCriteria() {
|
||||
auto matchOpCInputs =
|
||||
graph.createNode(std::move(testMatchPredicate(Criteria("input"))
|
||||
.starCount()
|
||||
.nonTerminal()
|
||||
.excludeFromSubgraph()));
|
||||
auto matchOpC = graph.createNode(testMatchPredicate("opC"));
|
||||
graph.createEdge(matchOpCInputs, matchOpC);
|
||||
|
||||
matchOpCOutput = graph.createNode(testMatchPredicate(any()));
|
||||
graph.createEdge(matchOpC, matchOpCOutput);
|
||||
|
||||
auto matchOpB = graph.createNode(testMatchPredicate("opB"));
|
||||
graph.createEdge(matchOpCOutput, matchOpB);
|
||||
graph.createEdge(matchOpCOutput, matchOpB);
|
||||
|
||||
auto matchOpBOutput = graph.createNode(testMatchPredicate(any()));
|
||||
graph.createEdge(matchOpB, matchOpBOutput);
|
||||
|
||||
auto matchOpF = graph.createNode(testMatchPredicate("opF"));
|
||||
graph.createEdge(matchOpBOutput, matchOpF);
|
||||
|
||||
auto matchOpFOutput = graph.createNode(testMatchPredicate(any()));
|
||||
graph.createEdge(matchOpF, matchOpFOutput);
|
||||
|
||||
matchOpG = graph.createNode(testMatchPredicate("opG"));
|
||||
auto matchDataI = graph.createNode(std::move(
|
||||
testMatchPredicate(any()).nonTerminal().excludeFromSubgraph()));
|
||||
graph.createEdge(matchOpFOutput, matchOpG);
|
||||
graph.createEdge(matchDataI, matchOpG);
|
||||
}
|
||||
};
|
||||
|
||||
TestGraph::NodeRef getInNode(TestGraph::NodeRef node, int index) {
|
||||
return node->getInEdges()[index]->tail();
|
||||
}
|
||||
|
||||
bool isSubgraphMatch(
|
||||
TestGraph::NodeRef nodeRef,
|
||||
const TestMatchGraph::NodeRef& criteria,
|
||||
bool invertGraphTraversal = true) {
|
||||
return graph.isSubgraphMatch(nodeRef, criteria, invertGraphTraversal)
|
||||
.isMatch();
|
||||
}
|
||||
} // namespace matcher
|
||||
|
||||
} // namespace nom
|
||||
|
||||
using namespace nom::matcher;
|
||||
|
||||
// Simple test cases for node matching criteria.
|
||||
TEST(SubgraphMatcher, IsNodeMatch) {
|
||||
TestGraph g;
|
||||
auto n1 = g.createNode("Hello");
|
||||
auto n2 = g.createNode("Le");
|
||||
g.createEdge(n1, n2);
|
||||
|
||||
EXPECT_TRUE(graph.isNodeMatch(n1, testMatchPredicate("Hello")));
|
||||
EXPECT_FALSE(graph.isNodeMatch(n1, testMatchPredicate("G")));
|
||||
EXPECT_TRUE(graph.isNodeMatch(n2, testMatchPredicate("Le")));
|
||||
EXPECT_FALSE(graph.isNodeMatch(n2, testMatchPredicate("le")));
|
||||
}
|
||||
|
||||
// Test subtree matching with a simple tree graph.
|
||||
TEST(SubgraphMatcher, IsSubtreeMatch) {
|
||||
TestGraph graph;
|
||||
auto n1 = graph.createNode("1");
|
||||
auto n2 = graph.createNode("2");
|
||||
auto n3 = graph.createNode("3");
|
||||
auto n4 = graph.createNode("4");
|
||||
auto n5 = graph.createNode("5");
|
||||
auto n6 = graph.createNode("6");
|
||||
auto n7 = graph.createNode("7");
|
||||
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n2, n3);
|
||||
graph.createEdge(n2, n4);
|
||||
graph.createEdge(n1, n5);
|
||||
graph.createEdge(n5, n6);
|
||||
graph.createEdge(n5, n7);
|
||||
/* N1
|
||||
/ \
|
||||
N2 N5
|
||||
/ \ / \
|
||||
N3 N4 N6 N7
|
||||
*/
|
||||
|
||||
reset();
|
||||
auto subtree = Tree(any(), {Tree(any()), Tree(any())});
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
|
||||
EXPECT_FALSE(isSubgraphMatch(n4, subtree, false));
|
||||
|
||||
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
|
||||
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(Criteria("5"), {Tree(any()), Tree(any())});
|
||||
EXPECT_FALSE(isSubgraphMatch(n2, subtree, false));
|
||||
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(any(), {Tree(any()), Tree(Criteria("4"))});
|
||||
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
|
||||
EXPECT_FALSE(isSubgraphMatch(n5, subtree, false));
|
||||
|
||||
reset();
|
||||
// Accepts non terminal node
|
||||
subtree = Tree(any(), {NonTerminal(any()), NonTerminal(any())});
|
||||
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
|
||||
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
|
||||
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
|
||||
EXPECT_FALSE(isSubgraphMatch(n3, subtree, false));
|
||||
EXPECT_FALSE(isSubgraphMatch(n4, subtree, false));
|
||||
EXPECT_FALSE(isSubgraphMatch(n6, subtree, false));
|
||||
EXPECT_FALSE(isSubgraphMatch(n7, subtree, false));
|
||||
}
|
||||
|
||||
// Test subtree matching in which * (repeated) matching of children is allowed.
|
||||
TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
|
||||
TestGraph graph;
|
||||
auto n1 = graph.createNode("1");
|
||||
auto n2 = graph.createNode("2");
|
||||
auto n3A = graph.createNode("3");
|
||||
auto n3B = graph.createNode("3");
|
||||
auto n4 = graph.createNode("4");
|
||||
auto n5A = graph.createNode("5");
|
||||
auto n5B = graph.createNode("5");
|
||||
auto n5C = graph.createNode("5");
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n1, n3A);
|
||||
graph.createEdge(n1, n3B);
|
||||
graph.createEdge(n1, n4);
|
||||
graph.createEdge(n1, n4);
|
||||
graph.createEdge(n1, n5A);
|
||||
graph.createEdge(n1, n5B);
|
||||
graph.createEdge(n1, n5C);
|
||||
|
||||
reset();
|
||||
auto subtree = Tree(any(), {Tree(Criteria("2"))});
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree =
|
||||
Tree(any(), {Tree(Criteria("2"), {}, TestMatchPredicate::kStarCount)});
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
// clang-format off
|
||||
subtree = Tree(any(), {
|
||||
Tree(Criteria("2")),
|
||||
Tree(Criteria("3"), {}, 2),
|
||||
Tree(Criteria("4"), {}, 2),
|
||||
Tree(Criteria("5"), {}, 3)
|
||||
});
|
||||
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(any(), {
|
||||
Tree(Criteria("2")),
|
||||
Tree(Criteria("3"), {}, 2),
|
||||
Tree(Criteria("4"), {}, 2),
|
||||
Tree(Criteria("5"), {}, 4)
|
||||
});
|
||||
// Failes because exepected 4 matches of n5 but found 3.
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(any(), {
|
||||
Tree(Criteria("2")),
|
||||
Tree(Criteria("3"), {}, 2),
|
||||
Tree(Criteria("4"), {}, 2),
|
||||
Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount)
|
||||
});
|
||||
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(any(), {
|
||||
Tree(Criteria("2")),
|
||||
Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount),
|
||||
Tree(Criteria("4"), {}, 2),
|
||||
Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount)
|
||||
});
|
||||
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(any(), {
|
||||
Tree(Criteria("2")),
|
||||
Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount),
|
||||
});
|
||||
// Fails because there are unmatched edges.
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
|
||||
|
||||
reset();
|
||||
subtree = Tree(any(), {
|
||||
Tree(Criteria("2")),
|
||||
Tree(Criteria("3"), {}, 2),
|
||||
Tree(Criteria("4")),
|
||||
Tree(Criteria("5"), {}, 3)
|
||||
});
|
||||
// Fails because the count is wrong; we have 2 edges to node N4 while
|
||||
// the pattern expects only 1.
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
TEST(SubgraphMatcher, DagMatching) {
|
||||
reset();
|
||||
|
||||
// clang-format off
|
||||
auto n4match = Tree(Criteria("4"), {
|
||||
Tree(Criteria("5"))
|
||||
});
|
||||
auto subgraph = Tree(Criteria("1"), {
|
||||
Tree(Criteria("2"), {
|
||||
n4match
|
||||
}),
|
||||
Tree(Criteria("3"), {
|
||||
n4match
|
||||
}),
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
{
|
||||
TestGraph graph;
|
||||
auto n1 = graph.createNode("1");
|
||||
auto n2 = graph.createNode("2");
|
||||
auto n3 = graph.createNode("3");
|
||||
auto n4 = graph.createNode("4");
|
||||
auto n5 = graph.createNode("5");
|
||||
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n1, n3);
|
||||
graph.createEdge(n2, n4);
|
||||
graph.createEdge(n3, n4);
|
||||
graph.createEdge(n4, n5);
|
||||
|
||||
/* N1
|
||||
/ \
|
||||
N2 N3
|
||||
\ /
|
||||
N4
|
||||
|
|
||||
N5
|
||||
*/
|
||||
|
||||
EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false));
|
||||
}
|
||||
|
||||
{
|
||||
TestGraph graph;
|
||||
auto n1 = graph.createNode("1");
|
||||
auto n2 = graph.createNode("2");
|
||||
auto n3 = graph.createNode("3");
|
||||
auto n4A = graph.createNode("4");
|
||||
auto n4B = graph.createNode("4");
|
||||
auto n5 = graph.createNode("5");
|
||||
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n1, n3);
|
||||
graph.createEdge(n2, n4A);
|
||||
graph.createEdge(n3, n4B);
|
||||
graph.createEdge(n4A, n5);
|
||||
graph.createEdge(n4B, n5);
|
||||
|
||||
/* N1
|
||||
/ \
|
||||
N2 N3
|
||||
/ \
|
||||
N4A N4B
|
||||
\ /
|
||||
N5
|
||||
*/
|
||||
|
||||
// This should fail because n4A and n4B are not the same node.
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SubgraphMatcher, DagMatchingMultiEdges) {
|
||||
reset();
|
||||
|
||||
// clang-format off
|
||||
auto n2match = Tree(Criteria("2"));
|
||||
auto subgraph = Tree(Criteria("1"), {
|
||||
n2match,
|
||||
n2match
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
{
|
||||
TestGraph graph;
|
||||
auto n1 = graph.createNode("1");
|
||||
auto n2 = graph.createNode("2");
|
||||
|
||||
graph.createEdge(n1, n2);
|
||||
graph.createEdge(n1, n2);
|
||||
|
||||
EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false));
|
||||
}
|
||||
|
||||
{
|
||||
TestGraph graph;
|
||||
auto n1 = graph.createNode("1");
|
||||
auto n2A = graph.createNode("2");
|
||||
auto n2B = graph.createNode("2");
|
||||
|
||||
graph.createEdge(n1, n2A);
|
||||
graph.createEdge(n1, n2B);
|
||||
|
||||
EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SubgraphMatcher, DagMatchingRandomLargeGraph) {
|
||||
reset();
|
||||
// clang-format off
|
||||
auto n4match = Tree(any(), {
|
||||
NonTerminal(any(), 1)
|
||||
});
|
||||
auto subtree = Tree(any(), {
|
||||
Tree(any(), {
|
||||
n4match
|
||||
}),
|
||||
Tree(any(), {
|
||||
n4match
|
||||
}),
|
||||
});
|
||||
// clang-format on
|
||||
/* N1
|
||||
/ \
|
||||
N2 N3
|
||||
\ /
|
||||
N4
|
||||
|
|
||||
N5
|
||||
*/
|
||||
|
||||
// Look for the diamond pattern in a random large graph.
|
||||
TestGraph graph;
|
||||
std::vector<nom::Graph<std::string>::NodeRef> nodes;
|
||||
|
||||
// Here we create a test graph and then randomly embed the above
|
||||
// pattern into the graph repeatedly (numPatterns times).
|
||||
// The actual number of match will be less than numPatterns because the
|
||||
// embedded patterns can overlap which become unmatched subgraphs.
|
||||
const int numNodes = 50000;
|
||||
const int numPatterns = 5000;
|
||||
|
||||
for (int i = 0; i < numNodes; i++) {
|
||||
auto node = graph.createNode("Node");
|
||||
nodes.emplace_back(node);
|
||||
}
|
||||
|
||||
TestRandom random(517);
|
||||
for (int i = 0; i < numPatterns; i++) {
|
||||
std::vector<int> nodeIdx;
|
||||
for (int k = 0; k < 5; k++) {
|
||||
// NOLINTNEXTLINE(performance-inefficient-vector-operation)
|
||||
nodeIdx.emplace_back(random.nextInt() % numNodes);
|
||||
}
|
||||
graph.createEdge(nodes[nodeIdx[0]], nodes[nodeIdx[1]]);
|
||||
graph.createEdge(nodes[nodeIdx[0]], nodes[nodeIdx[2]]);
|
||||
graph.createEdge(nodes[nodeIdx[1]], nodes[nodeIdx[3]]);
|
||||
graph.createEdge(nodes[nodeIdx[2]], nodes[nodeIdx[3]]);
|
||||
graph.createEdge(nodes[nodeIdx[3]], nodes[nodeIdx[4]]);
|
||||
}
|
||||
EXPECT_EQ(graph.getEdgesCount(), 5 * numPatterns);
|
||||
|
||||
int countMatch = 0;
|
||||
for (auto node : graph.getMutableNodes()) {
|
||||
if (isSubgraphMatch(node, subtree, false)) {
|
||||
countMatch++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(countMatch, 1072);
|
||||
}
|
||||
|
||||
TEST(SubgraphMatcher, IsSubtreeMatchRealistic) {
|
||||
reset();
|
||||
auto graph = DataFlowTestGraph();
|
||||
auto subtree = DataFlowTestGraphCriteria().matchOpG;
|
||||
|
||||
EXPECT_FALSE(isSubgraphMatch(graph.opF, subtree));
|
||||
EXPECT_FALSE(isSubgraphMatch(graph.opC, subtree));
|
||||
EXPECT_FALSE(isSubgraphMatch(graph.opB, subtree));
|
||||
EXPECT_FALSE(isSubgraphMatch(graph.dataOut, subtree));
|
||||
|
||||
EXPECT_TRUE(isSubgraphMatch(graph.opG, subtree));
|
||||
}
|
||||
|
||||
TEST(SubgraphMatcher, ReplaceGraphRealistic) {
|
||||
reset();
|
||||
auto testGraph = DataFlowTestGraph();
|
||||
auto subtree = DataFlowTestGraphCriteria();
|
||||
|
||||
graph.replaceSubgraph(
|
||||
testGraph.graph,
|
||||
subtree.matchOpG,
|
||||
[subtree](
|
||||
TestGraph& g,
|
||||
TestGraph::NodeRef opG,
|
||||
const TestMatchGraph::SubgraphMatchResultType& matchResult) {
|
||||
auto fusedNode = g.createNode("opFused");
|
||||
auto opC = getInNode(
|
||||
matchResult.getMatchNodeMap()->at(subtree.matchOpCOutput), 0);
|
||||
g.replaceOutEdges(opG, fusedNode);
|
||||
g.replaceInEdges(opG, fusedNode);
|
||||
g.replaceInEdges(opC, fusedNode);
|
||||
g.deleteNodes(matchResult.getMatchedSubgraph()->getNodes());
|
||||
return true;
|
||||
});
|
||||
|
||||
// Now the nodes are:
|
||||
// - NumInputs input nodes
|
||||
// - dataI node
|
||||
// - fused node
|
||||
// - output node
|
||||
// - dataC2 node
|
||||
auto nodes = testGraph.graph.getMutableNodes();
|
||||
|
||||
// Test that the graph is transformed as expected.
|
||||
EXPECT_EQ(nodes.size(), testGraph.numInputs + 4);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
TestGraph::NodeRef opFused;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
TestGraph::NodeRef dataI;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
TestGraph::NodeRef dataOut;
|
||||
for (auto node : nodes) {
|
||||
if (node->data() == "opFused") {
|
||||
opFused = node;
|
||||
} else if (node->data() == "dataOut") {
|
||||
dataOut = node;
|
||||
} else if (node->data() == "dataI") {
|
||||
dataI = node;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(getInNode(dataOut, 0), opFused);
|
||||
|
||||
EXPECT_EQ(opFused->getInEdges().size(), testGraph.numInputs + 1);
|
||||
EXPECT_EQ(getInNode(opFused, 0), dataI);
|
||||
for (int i = 1; i <= testGraph.numInputs; i++) {
|
||||
EXPECT_EQ(getInNode(opFused, i)->data(), "input");
|
||||
}
|
||||
|
||||
// Use nom::converters::convertToDotString(&graph.graph, TestGraphNodePrinter)
|
||||
// to visualize. The transformed graph looks like This
|
||||
/*
|
||||
|
||||
+---------++---------+
|
||||
| input_A || input_D |
|
||||
+---------++---------+
|
||||
| |
|
||||
| |
|
||||
v v
|
||||
+---------+ +--------------------+ +---------+
|
||||
| input_B | --> | opFused | <-- | input_C |
|
||||
+---------+ +--------------------+ +---------+
|
||||
| ^
|
||||
| |
|
||||
v |
|
||||
+---------++---------+
|
||||
| dataOut || dataI |
|
||||
+---------++---------+
|
||||
*/
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Graph/Graph.h"
|
||||
|
||||
TEST(Tarjans, Simple) {
|
||||
TestClass t1;
|
||||
TestClass t2;
|
||||
nom::Graph<TestClass> g;
|
||||
nom::Graph<TestClass>::NodeRef n1 = g.createNode(std::move(t1));
|
||||
nom::Graph<TestClass>::NodeRef n2 = g.createNode(std::move(t2));
|
||||
g.createEdge(n1, n2);
|
||||
g.createEdge(n2, n1);
|
||||
auto sccs = nom::algorithm::tarjans(&g);
|
||||
EXPECT_EQ(sccs.size(), 1);
|
||||
}
|
||||
|
||||
TEST(Tarjans, WithEdgeStorage) {
|
||||
TestClass t1;
|
||||
TestClass t2;
|
||||
nom::Graph<TestClass, TestClass> g;
|
||||
nom::Graph<TestClass, TestClass>::NodeRef n1 = g.createNode(std::move(t1));
|
||||
nom::Graph<TestClass, TestClass>::NodeRef n2 = g.createNode(std::move(t2));
|
||||
g.createEdge(n1, n2, TestClass());
|
||||
g.createEdge(n2, n1, TestClass());
|
||||
auto sccs = nom::algorithm::tarjans(&g);
|
||||
EXPECT_EQ(sccs.size(), 1);
|
||||
}
|
||||
|
||||
TEST(Tarjans, DAG) {
|
||||
auto graph = createGraph();
|
||||
auto sccs = nom::algorithm::tarjans(&graph);
|
||||
EXPECT_EQ(sccs.size(), 9);
|
||||
}
|
||||
|
||||
TEST(Tarjans, Cycle) {
|
||||
auto graph = createGraphWithCycle();
|
||||
auto sccs = nom::algorithm::tarjans(&graph);
|
||||
EXPECT_EQ(sccs.size(), 8);
|
||||
}
|
||||
|
||||
TEST(Tarjans, Random) {
|
||||
nom::Graph<TestClass> g;
|
||||
std::vector<nom::Graph<TestClass>::NodeRef> nodes;
|
||||
for (auto i = 0; i < 10; ++i) {
|
||||
TestClass t;
|
||||
nodes.emplace_back(g.createNode(std::move(t)));
|
||||
}
|
||||
for (auto i = 0; i < 30; ++i) {
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,clang-analyzer-security.insecureAPI.rand)
|
||||
int ri1 = rand() % nodes.size();
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,clang-analyzer-security.insecureAPI.rand)
|
||||
int ri2 = rand() % nodes.size();
|
||||
g.createEdge(nodes[ri1], nodes[ri2]);
|
||||
}
|
||||
|
||||
auto sccs = nom::algorithm::tarjans(&g);
|
||||
EXPECT_GE(sccs.size(), 1);
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test_util.h"
|
||||
|
||||
#include "nomnigraph/Graph/Graph.h"
|
||||
|
||||
using GraphT = nom::Graph<TestClass>;
|
||||
using TopoSortT = nom::algorithm::TopoSort<GraphT>;
|
||||
|
||||
TEST(TopoSort, Simple) {
|
||||
GraphT g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
auto res = nom::algorithm::topoSort(&g);
|
||||
EXPECT_EQ(res.status, TopoSortT::Result::OK);
|
||||
EXPECT_EQ(res.nodes.size(), 2);
|
||||
EXPECT_EQ(res.nodes[0], n1);
|
||||
EXPECT_EQ(res.nodes[1], n2);
|
||||
}
|
||||
|
||||
TEST(TopoSort, DAG) {
|
||||
GraphT g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
auto n4 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
g.createEdge(n1, n3);
|
||||
g.createEdge(n2, n4);
|
||||
g.createEdge(n3, n4);
|
||||
auto res = nom::algorithm::topoSort(&g);
|
||||
EXPECT_EQ(res.status, TopoSortT::Result::OK);
|
||||
EXPECT_EQ(res.nodes.size(), 4);
|
||||
auto i1 = std::find(res.nodes.begin(), res.nodes.end(), n1);
|
||||
auto i2 = std::find(res.nodes.begin(), res.nodes.end(), n2);
|
||||
auto i3 = std::find(res.nodes.begin(), res.nodes.end(), n3);
|
||||
auto i4 = std::find(res.nodes.begin(), res.nodes.end(), n4);
|
||||
ASSERT_TRUE(i1 != res.nodes.end());
|
||||
ASSERT_TRUE(i2 != res.nodes.end());
|
||||
ASSERT_TRUE(i3 != res.nodes.end());
|
||||
ASSERT_TRUE(i4 != res.nodes.end());
|
||||
ASSERT_LT(i1, i2);
|
||||
ASSERT_LT(i1, i3);
|
||||
ASSERT_LT(i2, i4);
|
||||
ASSERT_LT(i3, i4);
|
||||
}
|
||||
|
||||
TEST(TopoSort, Cycle1) {
|
||||
GraphT g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
g.createEdge(n2, n1);
|
||||
auto res = nom::algorithm::topoSort(&g);
|
||||
EXPECT_EQ(res.status, TopoSortT::Result::CYCLE);
|
||||
EXPECT_EQ(res.nodes.size(), 0);
|
||||
}
|
||||
|
||||
TEST(TopoSort, Cycle2) {
|
||||
GraphT g;
|
||||
auto n1 = createTestNode(g);
|
||||
auto n2 = createTestNode(g);
|
||||
auto n3 = createTestNode(g);
|
||||
auto n4 = createTestNode(g);
|
||||
g.createEdge(n1, n2);
|
||||
g.createEdge(n2, n3);
|
||||
g.createEdge(n3, n4);
|
||||
g.createEdge(n4, n2);
|
||||
auto res = nom::algorithm::topoSort(&g);
|
||||
EXPECT_EQ(res.status, TopoSortT::Result::CYCLE);
|
||||
EXPECT_EQ(res.nodes.size(), 0);
|
||||
}
|
Reference in New Issue
Block a user