#include #include #include #include #include #include #include namespace torch { namespace jit { TEST(StackOptTest, UseVariadicStack) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Float(56, 56, 56), %1: Float(56, 56, 56), %2: Float(56, 56, 56), %3: Float(56, 56, 56), %4: Float(56, 56, 56), %5: Float(56, 56, 56)): %10 : int = prim::Constant[value=0]() %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5) %stack : Float(5, 56, 56, 56) = aten::stack(%input, %10) return (%stack) )IR"; parseIR(input, graph.get()); std::vector inputs = { at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; auto orig_outputs = runGraph(graph, inputs); ASSERT_TRUE(UseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After replacing `aten::stack` with `prim::VarStack` we should have the // following graph: // // graph(%0 : ..., // %1 : ...): // %zero : int = prim:Constant[value=0]() // %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero) // return (%varstack) testing::FileCheck() .check_count("= prim::VarStack(", 1, /*exactly*/ true) ->check_count("= aten::stack(", 0, /*exactly*/ true) ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) ->run(*graph); } TEST(StackOptTest, UseVariadicStackReplaceMultiple) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Float(56, 56, 56), %1: Float(56, 56, 56), %2: Float(56, 56, 56), %3: Float(56, 56, 56)): %10 : int = prim::Constant[value=0]() %input1 : Tensor[] = prim::ListConstruct(%0, %1) %stack1 : Float(4, 56, 56, 56) = aten::stack(%input1, %10) %input2 : Tensor[] = prim::ListConstruct(%2, %3) %stack2 : Float(4, 56, 56, 56) = aten::stack(%input2, %10) return (%stack1, %stack2) )IR"; parseIR(input, graph.get()); std::vector inputs = { at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; auto orig_outputs = runGraph(graph, inputs); ASSERT_TRUE(UseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After full stack optimization we should have the following graph: // // graph(%0 : ..., // %1 : ..., // %2 : ..., // %3 : ....): // %zero : int = prim:Constant[value=0]() // %varcat1 : Tensor = prim::VarStack(%0, %1, %zero) // %varcat2 : Tensor = prim::VarStack(%2, %3, %zero) // return (%varcat1, %varcat2) testing::FileCheck() .check_count("= prim::VarStack(", 2, /*exactly*/ true) ->check_count("= aten::stack(", 0, /*exactly*/ true) ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) ->run(*graph); } TEST(StackOptTest, UseVariadicStackWithMultipleListUses) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Float(56, 56, 56), %1: Float(56, 56, 56)): %2 : int = prim::Constant[value=0]() %input : Tensor[] = prim::ListConstruct(%0, %1) %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2) return (%stack, %input) )IR"; parseIR(input, graph.get()); std::vector inputs = { at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; auto orig_outputs = runGraph(graph, inputs); ASSERT_TRUE(UseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After replacing `aten::stack` with `prim::VarStack` we should have the // following graph: // // graph(%0 : ..., // %1 : ...): // %zero : int = prim:Constant[value=0]() // %input : Tensor[] = prim::ListConstruct(%0, %1) // %varcat : Tensor = prim::VarStack(%0, %1, %zero) // return (%varcat, %input) testing::FileCheck() .check_count("= prim::ListConstruct(", 1, /*exactly*/ true) ->check_count("= prim::VarStack(", 1, /*exactly*/ true) ->check_count("= aten::stack(", 0, /*exactly*/ true) ->run(*graph); } TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Float(56, 56, 56), %1: Float(56, 56, 56), %2: Float(56, 56, 56)): %10 : int = prim::Constant[value=0]() %input : Tensor[] = prim::ListConstruct(%0, %1) %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10) %11 : Tensor = aten::append(%input, %2) return (%stack, %input) )IR"; parseIR(input, graph.get()); std::vector inputs = { at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; auto orig_outputs = runGraph(graph, inputs); ASSERT_TRUE(UseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // The input list to `aten::stack` is mutated only after `aten::stack` op. So, // it should have been replaced with `prim::VarStack`. The transformed graph // should look like the following: // // graph(%0 : ..., // %1 : ..., // %2 : ...): // %3 : int = prim:Constant[value=0]() // %4 : Tensor[] = prim::ListConstruct(%0, %1) // %7 : Tensor = prim::VarStack(%0, %1, %3) // %6 : Tensor = aten::append(%4, %2) // return (%7, %4) testing::FileCheck() .check_count("= prim::ListConstruct(", 1, /*exactly*/ true) ->check_count("= prim::VarStack(", 1, /*exactly*/ true) ->check_count("= aten::stack(", 0, /*exactly*/ true) ->run(*graph); } TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Float(56, 56, 56), %1: Float(56, 56, 56), %2: Float(56, 56, 56)): %10 : int = prim::Constant[value=0]() %input : Tensor[] = prim::ListConstruct(%0, %1) %11 : Tensor = aten::append(%input, %2) %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10) return (%stack) )IR"; parseIR(input, graph.get()); std::vector inputs = { at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; auto orig_outputs = runGraph(graph, inputs); { ASSERT_FALSE(UseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // No transformation should have happened since the `prim::ListConstruct` is // mutated before `aten::stack`. testing::FileCheck() .check_count("= prim::ListConstruct(", 1, /*exactly*/ true) ->check_count("= aten::stack(", 1, /*exactly*/ true) ->check_count("= prim::VarStack(", 0, /*exactly*/ true) ->run(*graph); } { ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // The mutation of the list must be removed and the `aten::stack` op must // be replaced with the `prim::VarStack` op in the graph. The transformed // graph should look like the following: // // graph(%0 : ..., // %1 : ..., // %2 : ...): // %3 : int = prim:Constant[value=0]() // %7 : Tensor = prim::VarStack(%0, %1, %2, %3) // return (%7) testing::FileCheck() .check_count("= prim::VarStack(", 1, /*exactly*/ true) ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) ->check_count("= aten::stack(", 0, /*exactly*/ true) ->run(*graph); } } TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Float(56, 56, 56), %1: Float(56, 56, 56), %2: Float(56, 56, 56), %3: Float(56, 56, 56), %4: Float(56, 56, 56)): %10 : int = prim::Constant[value=0]() %input : Tensor[] = prim::ListConstruct(%0, %1) %stack.1 : Float(5, 56, 56, 56) = aten::stack(%input, %10) %11 : Tensor = aten::append(%input, %2) %stack.2 : Float(5, 56, 56, 56) = aten::stack(%input, %10) %12 : Tensor = aten::append(%input, %3) %stack.3 : Float(5, 56, 56, 56) = aten::stack(%input, %10) %13 : Tensor = aten::append(%input, %4) %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10) return (%stack.1, %stack.2, %stack.3, %stack.4) )IR"; parseIR(input, graph.get()); std::vector inputs = { at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; auto orig_outputs = runGraph(graph, inputs); ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // All the mutations of the list must be removed and the `aten::stack` ops // must be replaced with `prim::VarStack` ops in the graph. The transformed // graph should look like the following: // // graph(%0 : ..., // %1 : ..., // %2 : ..., // %3 : ..., // %4 : ...): // %10 : int = prim:Constant[value=0]() // %5 : Tensor = prim::VarStack(%0, %1, %10) // %6 : Tensor = prim::VarStack(%0, %1, %2, %10) // %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10) // %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10) // return (%5, %6, %7, %8) testing::FileCheck() .check_count("= prim::VarStack(", 4, /*exactly*/ true) ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) ->check_count("= aten::stack(", 0, /*exactly*/ true) ->run(*graph); } } // namespace jit } // namespace torch