#include #include "test/cpp/jit/test_utils.h" #include "torch/csrc/jit/ir/subgraph_matcher.h" namespace torch { namespace jit { TEST(SubgraphMatcherTest, Trivial1) { Graph graph, pattern; parseIR( R"IR( graph(%0): %a = a::aaa(%0) return (%a))IR", &graph); parseIR( R"IR( graph(%0): %x = a::aaa(%0) return (%x))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } TEST(SubgraphMatcherTest, Trivial2) { Graph graph; auto* g_in = graph.addInput(); auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); g_tanh->addInput(g_in); graph.registerOutput(g_tanh->output()); Graph pattern; auto* p_in = pattern.addInput(); auto* p_tanh = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1)); p_tanh->addInput(p_in); pattern.registerOutput(p_tanh->output()); auto matches = findPatternMatches(pattern, graph); AT_ASSERT(matches.size() == 1); for (const Match& m : matches) { AT_ASSERT(m.values_map.at(p_in) == g_in); AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output()); AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh); } } TEST(SubgraphMatcherTest, Trivial3) { Graph graph, pattern; parseIR( R"IR( graph(%0): %a = a::a(%0) %b = a::b(%0) %c = a::c(%a, %b) return (%c))IR", &graph); parseIR( R"IR( graph(%a, %b): %c = a::c(%a, %b) return (%c))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } TEST(SubgraphMatcherTest, Trivial4) { Graph graph; auto* g_in0 = graph.addInput(); auto* g_in1 = graph.addInput(); auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1)); g_mul->addInput(g_in0); g_mul->addInput(g_in1); graph.registerOutput(g_mul->output()); Graph pattern; auto* p_in0 = pattern.addInput(); auto* p_in1 = pattern.addInput(); auto* p_mul = pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1)); p_mul->addInput(p_in0); p_mul->addInput(p_in1); pattern.registerOutput(p_mul->output()); auto matches = findPatternMatches(pattern, graph); AT_ASSERT(matches.size() == 1); for (const Match& m : matches) { AT_ASSERT(m.values_map.at(p_in0) == g_in0); AT_ASSERT(m.values_map.at(p_in1) == g_in1); AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output()); AT_ASSERT(m.nodes_map.at(p_mul) == g_mul); } } TEST(SubgraphMatcherTest, Linear1) { Graph graph, pattern; parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b = b::bbb(%a) %c = c::ccc(%b) %d = d::ddd(%c) %a = a::aaa(%0) return (%d))IR", &graph); parseIR( R"IR( graph(%0): %x = b::bbb(%0) %y = c::ccc(%x) return (%y))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } TEST(SubgraphMatcherTest, Linear2) { Graph graph; auto* g_in = graph.addInput(); auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); g_tanh->addInput(g_in); auto* g_tanh2 = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); g_tanh2->addInput(g_tanh->output()); graph.registerOutput(g_tanh2->output()); Graph pattern; auto* p_in = pattern.addInput(); auto* p_tanh = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1)); p_tanh->addInput(p_in); auto* p_tanh2 = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1)); p_tanh2->addInput(p_tanh->output()); pattern.registerOutput(p_tanh2->output()); auto matches = findPatternMatches(pattern, graph); AT_ASSERT(matches.size() == 1); for (const Match& m : matches) { AT_ASSERT(m.values_map.at(p_in) == g_in); AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output()); AT_ASSERT(m.values_map.at(p_tanh2->output()) == g_tanh2->output()); AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh); AT_ASSERT(m.nodes_map.at(p_tanh2) == g_tanh2); } } /** * Test diamond pattern: * * ooo * | * aaa * / \ * bbb ccc * \ / * ddd * | * eee */ TEST(SubgraphMatcherTest, Diamond1) { Graph graph, pattern1, pattern2; parseIR( R"IR( graph(%0): %o = o::ooo(%0) %a = a::aaa(%o) %b = b::bbb(%a) %c = c::ccc(%a) %d = d::ddd(%b, %c) %e = e::eee(%d) return (%e))IR", &graph); parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b = b::bbb(%a) %c = c::ccc(%a) %d = d::ddd(%b, %c) return (%d))IR", &pattern1); AT_ASSERT(!findPatternMatches(pattern1, graph).empty()); // Check that order of nodes inside the diamond does not affect the result parseIR( R"IR( graph(%0): %a = a::aaa(%0) %c = c::ccc(%a) %b = b::bbb(%a) %d = d::ddd(%b, %c) return (%d))IR", &pattern2); AT_ASSERT(!findPatternMatches(pattern2, graph).empty()); } /** * Test diamond pattern: * * i0 * | * chunk * / \ * os[0] os[1] * \ / * * * | * o1 */ TEST(SubgraphMatcherTest, Diamond2) { Graph graph; auto* g_in = graph.addInput(); auto* g_chunk = graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/2)); g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0); g_chunk->addInput(g_in); auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1)); g_mul->addInput(g_chunk->outputs()[0]); g_mul->addInput(g_chunk->outputs()[1]); graph.registerOutput(g_mul->output()); Graph pattern; auto* p_in = pattern.addInput(); auto* p_chunk = pattern.insertNode( pattern.create(prim::ConstantChunk, /*num_outputs =*/2)); p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0); p_chunk->addInput(p_in); auto* p_mul = pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1)); p_mul->addInput(p_chunk->outputs()[0]); p_mul->addInput(p_chunk->outputs()[1]); pattern.registerOutput(p_mul->output()); auto matches = findPatternMatches(pattern, graph); AT_ASSERT(matches.size() == 1); for (const Match& m : matches) { AT_ASSERT(m.values_map.at(p_in) == g_in); AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]); AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]); AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output()); AT_ASSERT(m.nodes_map.at(p_mul) == g_mul); } } TEST(SubgraphMatcherTest, XPattern) { Graph graph, pattern; parseIR( R"IR( graph(%0, %1): %b = b::bbb(%0) %c = c::ccc(%1) %x = x::xxx(%b, %c) %e = e::eee(%x) %f = f::fff(%x) %g = g::ggg(%e, %f) return (%g))IR", &graph); parseIR( R"IR( graph(%0, %1): %b = b::bbb(%0) %c = c::ccc(%1) %x = x::xxx(%b, %c) %e = e::eee(%x) %f = f::fff(%x) %g = g::ggg(%e, %f) return (%g))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } TEST(SubgraphMatcherTest, MultipleMatches) { Graph graph, pattern; parseIR( R"IR( graph(%t0): %t1 = a::aaa(%t0) %t2 = a::aaa(%t1) %t3 = a::aaa(%t2) %t4 = a::aaa(%t3) return (%t4))IR", &graph); parseIR( R"IR( graph(%t0): %t1 = a::aaa(%t0) return (%t1))IR", &pattern); auto matches = findPatternMatches(pattern, graph); AT_ASSERT(matches.size() == 4); } TEST(SubgraphMatcherTest, OverlappingMatches) { Graph graph, pattern; parseIR( R"IR( graph(%t0): %t1 = a::aaa(%t0) %t2 = a::aaa(%t1) %t3 = a::aaa(%t2) %t4 = a::aaa(%t3) return (%t4))IR", &graph); parseIR( R"IR( graph(%t0): %t1 = a::aaa(%t0) %t2 = a::aaa(%t1) return (%t2))IR", &pattern); auto matches = findPatternMatches(pattern, graph); AT_ASSERT(matches.size() == 3); } TEST(SubgraphMatcherTest, MatchInBasicBlocks1) { Graph graph; parseIR( R"IR( graph(%a, %b, %c): %d = aten::mul(%a, %b) %x = prim::If(%c) block0(): %x1 = aten::mul(%a, %d) -> (%x1) block1(): %x2 = aten::mul(%b, %d) -> (%x2) return (%x))IR", &graph); // Ensure the matches don't cross basic block boundaries Graph pattern0; parseIR( R"IR( graph(%x, %y): %z = aten::mul(%x, %y) return (%z))IR", &pattern0); AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3); Graph pattern1; parseIR( R"IR( graph(%x, %y): %z1 = aten::mul(%x, %y) %z2 = aten::mul(%y, %z1) return (%z2))IR", &pattern1); AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); } TEST(SubgraphMatcherTest, MatchInBasicBlocks2) { Graph graph; parseIR( R"IR( graph(%a, %b): %x = my::mul(%a, %b) %y = my::node_with_subblock() block0(): %z = my::mul(%b, %x) -> (%z) return (%y))IR", &graph); // Check that we can match both mul ops Graph pattern0; parseIR( R"IR( graph(%x, %y): %z = my::mul(%x, %y) return (%z))IR", &pattern0); AT_ASSERT(findPatternMatches(pattern0, graph).size() == 2); // Ensure the matches don't cross basic block boundaries Graph pattern1; parseIR( R"IR( graph(%x, %y): %u = my::mul(%x, %y) %v = my::mul(%y, %u) return (%v))IR", &pattern1); AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); } TEST(SubgraphMatcherTest, MatchesAttributes) { Graph graph; parseIR( R"IR( graph(%0): %a = a::a[isattr=[1,2]](%0) %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0) %c = a::c[myattr="qqq"](%a, %b) return (%c))IR", &graph); { Graph pattern; parseIR( R"IR( graph(%a, %b): %c = a::c[myattr="qqq"](%a, %b) return (%c))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } { Graph pattern; parseIR( R"IR( graph(%a, %b): %c = a::c[myattr="zzz"](%a, %b) return (%c))IR", &pattern); AT_ASSERT(findPatternMatches(pattern, graph).empty()); } { Graph pattern; parseIR( R"IR( graph(%0): %b = a::b[extraattr=10](%0) return (%b))IR", &pattern); AT_ASSERT(findPatternMatches(pattern, graph).empty()); } { Graph pattern; parseIR( R"IR( graph(%0): %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0) return (%b))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } { Graph pattern; parseIR( R"IR( graph(%0): %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j, strattr="rrr"](%0) return (%b))IR", &pattern); AT_ASSERT(findPatternMatches(pattern, graph).empty()); } { Graph pattern; parseIR( R"IR( graph(%0): %a = a::a[isattr=[1,2]](%0) return (%a))IR", &pattern); // Lists are not supported yet, thus we shouldn't match for now. AT_ASSERT(findPatternMatches(pattern, graph).empty()); } { Graph pattern; parseIR( R"IR( graph(%a, %b): %c = a::c[myattr="q.*"](%a, %b) return (%c))IR", &pattern); AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } } TEST(SubgraphMatcherTest, BadPattern) { Graph graph, pattern1, pattern2; parseIR( R"IR( graph(%x): %y = my::op1(%x) %z = my::op2(%x) return (%y, %z))IR", &graph); parseIR( R"IR( graph(%x): %y = my::node_with_subblock() block0(): %z = my::op(%x) -> (%z) return (%y))IR", &pattern1); // No support for patterns with subblocks // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_ANY_THROW(findPatternMatches(pattern1, graph)); parseIR( R"IR( graph(%x): %y = my::op1(%x) %z = my::op2(%x) return (%y, %z))IR", &pattern2); // Not supported multi-output pattern, because not the whole pattern is // covered by a traversal up from the first output (`%z = ...` is not // visited). See the note "Multi-output Patterns" in subgraph_matcher.h. // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_ANY_THROW(findPatternMatches(pattern2, graph)); } TEST(SubgraphMatcherTest, MultiOutput) { { Graph graph, pattern; parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b = b::bbb(%a) %c = c::ccc(%a, %b) %x = a::aaa(%c) %y = b::bbb(%x) %z = d::ddd(%x, %y) return (%y))IR", &graph); parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b = b::bbb(%a) return (%b, %a))IR", &pattern); AT_ASSERT(findPatternMatches(pattern, graph).size() == 2); } { Graph graph, pattern; parseIR( R"IR( graph(%0, %1): %a1, %a2 = a::aaa(%0, %1) %b = b::bbb(%a1) %c = c::ccc(%b) %x1, %x2 = a::aaa(%c, %a2) %y = b::bbb(%x1) %z = d::ddd(%y) return (%z))IR", &graph); parseIR( R"IR( graph(%0, %1): %a1, %a2 = a::aaa(%0, %1) %b = b::bbb(%a1) return (%b, %a2))IR", &pattern); AT_ASSERT(findPatternMatches(pattern, graph).size() == 2); } } } // namespace jit } // namespace torch