Files
pytorch/caffe2/core/graph_test.cc
Orion Reblitz-Richardson 1d5780d42c Remove Apache headers from source.
* LICENSE file contains details, so removing from individual source files.
2018-03-27 13:10:18 -07:00

201 lines
6.0 KiB
C++

#include <algorithm>
#include <atomic>
#include <climits>

#include <gtest/gtest.h>
#include "caffe2/core/graph.h"
#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
namespace {
using transform::Graph;

// Tally of GraphDummyOp::Run() invocations. The enclosing anonymous namespace
// already gives it internal linkage, so no `static` is needed.
std::atomic<int> counter{0};
// Placeholder operator backing the GraphDummyOp{1,2,3} registrations below.
// Run() only bumps the file-level `counter` atomically and always succeeds.
class GraphDummyOp final : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;

  bool Run(int /* unused */) override {
    ++counter;
    return true;
  }
};
// Register three distinct operator type names, all implemented by
// GraphDummyOp, so the tests below can build NetDefs whose ops differ by type
// (compare_netdefs checks op types). Each schema accepts 0..INT_MAX inputs and
// outputs and allows in-place aliasing for the index pairs {0, 0} and {1, 1}
// (presumably (input index, output index) pairs — see OpSchema::AllowInplace).
REGISTER_CPU_OPERATOR(GraphDummyOp1, GraphDummyOp);
OPERATOR_SCHEMA(GraphDummyOp1)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});
REGISTER_CPU_OPERATOR(GraphDummyOp2, GraphDummyOp);
OPERATOR_SCHEMA(GraphDummyOp2)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});
REGISTER_CPU_OPERATOR(GraphDummyOp3, GraphDummyOp);
OPERATOR_SCHEMA(GraphDummyOp3)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});
// Checks that two NetDefs are equal in terms of op count and, op by op, each
// op's type, inputs, and outputs. Other NetDef/OperatorDef fields are not
// compared.
//
// Every size mismatch is reported with EXPECT_EQ (non-fatal), so the
// element-wise loops only walk the prefix that both sides actually have;
// indexing a protobuf repeated field past its size is not allowed.
void compare_netdefs(const NetDef& net_a, const NetDef& net_b) {
  EXPECT_EQ(net_a.op_size(), net_b.op_size());
  const int num_ops = std::min(net_a.op_size(), net_b.op_size());
  for (int i = 0; i < num_ops; i++) {
    const auto& op_a = net_a.op(i);
    const auto& op_b = net_b.op(i);
    EXPECT_EQ(op_a.type(), op_b.type()) << "op " << i;

    EXPECT_EQ(op_a.input_size(), op_b.input_size()) << "op " << i;
    const int num_inputs = std::min(op_a.input_size(), op_b.input_size());
    for (int j = 0; j < num_inputs; j++) {
      EXPECT_EQ(op_a.input(j), op_b.input(j))
          << "op " << i << ", input " << j;
    }

    EXPECT_EQ(op_a.output_size(), op_b.output_size()) << "op " << i;
    const int num_outputs = std::min(op_a.output_size(), op_b.output_size());
    for (int j = 0; j < num_outputs; j++) {
      EXPECT_EQ(op_a.output(j), op_b.output(j))
          << "op " << i << ", output " << j;
    }
  }
}
// Linear pipeline in -> mid1 -> mid2 -> mid3 -> out: node k must be linked
// only to node k+1, and GetNetDef() must reproduce the input NetDef.
TEST(GraphTest, TestGenerateGraphChain) {
  Workspace ws;
  ws.CreateBlob("in");

  NetDef netdef;
  AddOp(&netdef, "GraphDummyOp1", {"in"}, {"mid1"});
  AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid2"});
  AddOp(&netdef, "GraphDummyOp1", {"mid2"}, {"mid3"});
  AddOp(&netdef, "GraphDummyOp2", {"mid3"}, {"out"});

  Graph g(netdef);
  EXPECT_EQ(g.size(), 4);

  // Visit each edge src -> dst once and check it from both endpoints.
  for (int src = 0, dst = 1; dst < 4; ++src, ++dst) {
    EXPECT_EQ(g.node(src).children.size(), 1);
    EXPECT_TRUE(g.node(src).children.count(dst));
    EXPECT_EQ(g.node(dst).parents.size(), 1);
    EXPECT_TRUE(g.node(dst).parents.count(src));
  }

  compare_netdefs(g.GetNetDef(), netdef);
}
// Four-op chain where ops 1..3 read and write the same blob "out" in place;
// the dependency structure must still be the straight line 0 -> 1 -> 2 -> 3.
TEST(GraphTest, TestGenerateGraphChainInPlace) {
  Workspace ws;
  ws.CreateBlob("in");

  NetDef netdef;
  AddOp(&netdef, "GraphDummyOp1", {"in"}, {"out"});
  AddOp(&netdef, "GraphDummyOp2", {"out"}, {"out"});
  AddOp(&netdef, "GraphDummyOp1", {"out"}, {"out"});
  AddOp(&netdef, "GraphDummyOp2", {"out"}, {"out"});

  Graph g(netdef);
  const int num_nodes = 4;
  EXPECT_EQ(g.size(), num_nodes);

  // Every node except the last feeds exactly its successor...
  for (int i = 0; i + 1 < num_nodes; ++i) {
    const auto& kids = g.node(i).children;
    EXPECT_EQ(kids.size(), 1);
    EXPECT_TRUE(kids.count(i + 1));
  }
  // ...and every node except the first depends only on its predecessor.
  for (int i = 1; i < num_nodes; ++i) {
    const auto& deps = g.node(i).parents;
    EXPECT_EQ(deps.size(), 1);
    EXPECT_TRUE(deps.count(i - 1));
  }

  compare_netdefs(g.GetNetDef(), netdef);
}
// Diamond graph: node 0 writes mid1, which nodes 1 and 2 both read; node 3
// joins their outputs (mid2, mid3).
TEST(GraphTest, TestGenerateGraphBranch) {
  Workspace ws;
  ws.CreateBlob("in");

  NetDef netdef;
  AddOp(&netdef, "GraphDummyOp1", {"in"}, {"mid1"});
  AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid2"});
  AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid3"});
  AddOp(&netdef, "GraphDummyOp3", {"mid2", "mid3"}, {"out"});

  Graph g(netdef);
  EXPECT_EQ(g.size(), 4);

  // Expected {parent count, child count} for each node, in op order.
  const int expected_degree[4][2] = {{0, 2}, {1, 1}, {1, 1}, {2, 0}};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(g.node(i).parents.size(), expected_degree[i][0]);
    EXPECT_EQ(g.node(i).children.size(), expected_degree[i][1]);
  }

  compare_netdefs(g.GetNetDef(), netdef);
}
// Two diamonds back to back, reusing the blob names "in", "mid1" and "mid2".
// Node 3 closes the first diamond (reads mid1/mid2) and opens the second
// (its output "in" is read by nodes 4 and 5).
TEST(GraphTest, TestReusedInputs) {
  Workspace ws;
  ws.CreateBlob("in");

  NetDef netdef;
  AddOp(&netdef, "GraphDummyOp1", {"in"}, {"in"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"});
  AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"});
  AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"});

  Graph g(netdef);
  EXPECT_EQ(g.size(), 7);

  // Per-node expectations, indexed by op position in `netdef`.
  const int want_parents[7] = {0, 1, 1, 2, 1, 1, 2};
  const int want_children[7] = {2, 1, 1, 2, 1, 1, 0};
  for (int idx = 0; idx < 7; ++idx) {
    EXPECT_EQ(g.node(idx).parents.size(), want_parents[idx]);
    EXPECT_EQ(g.node(idx).children.size(), want_children[idx]);
  }

  compare_netdefs(g.GetNetDef(), netdef);
}
// Perimeter of the single-node subgraph {3}: node 3 reads mid1 (written by
// node 1) and mid2 (written by node 2), and its output "in" is read by
// nodes 4 and 5. Per these expectations, input pairs are presumably
// (blob, producing node) and output pairs (blob, consuming node).
TEST(GraphTest, TestGetPerimeter) {
  Workspace ws;
  ws.CreateBlob("in");
  NetDef netdef;
  AddOp(&netdef, "GraphDummyOp1", {"in"}, {"in"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"});
  AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"});
  AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"});
  AddOp(&netdef, "GraphDummyOp1", {"mid1", "mid2"}, {"in"});
  Graph g(netdef);
  std::vector<int> subgraph = {3};

  auto subgraph_input = g.GetSubgraphInput(subgraph);
  // ASSERT, not EXPECT: the element accesses below would read out of bounds
  // if the size were wrong.
  ASSERT_EQ(subgraph_input.size(), 2);
  EXPECT_EQ(subgraph_input[0], std::make_pair(string("mid1"), 1));
  EXPECT_EQ(subgraph_input[1], std::make_pair(string("mid2"), 2));

  auto subgraph_output = g.GetSubgraphOutput(subgraph);
  ASSERT_EQ(subgraph_output.size(), 2);
  EXPECT_EQ(subgraph_output[0], std::make_pair(string("in"), 4));
  EXPECT_EQ(subgraph_output[1], std::make_pair(string("in"), 5));
}
} // namespace
} // namespace caffe2