#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/concat_opt.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Graph after EliminateConcatCommonInputs:
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor = prim::VarConcat(%0, %1, %3)
  //    %7 : Tensor = prim::VarConcat(%4, %2, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //    return (%8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%4, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %2, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Graph after EliminateConcatCommonInputs:
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor = prim::VarConcat(%1, %2, %3)
  //    %7 : Tensor = prim::VarConcat(%0, %4, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //    return (%8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(%1, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%0, %4, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          #CHECK: prim::VarConcat
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          #CHECK: prim::VarConcat
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %0, %2, %5)
          #CHECK: prim::ListConstruct
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2)
          return (%res)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_FALSE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the inputs
  // to the `cat` are in a different order.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, MoreCommonInputsElimination) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %5)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %4, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2, %concat.3, %concat.4)
          return (%res)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Each `prim::VarConcat` should now reuse the previous concat result and
  // append only the one new input.
  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%6, %2, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%11, %3, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%12, %4, %5)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, ExpandConcat) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          %input : Tensor[] = prim::ListConstruct(%4, %5)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full concat optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    ...
  //    %4 : Tensor = aten::clamp_max(...)
  //    %5 : Tensor = aten::clamp_max(...)
  //    %13 : int[] = prim::ListConstruct(...)
  //    %14 : Tensor = aten::empty(%13, ...)   // concat buffer
  //    %17 : Tensor = aten::slice(%14, ...)   // slice for %4
  //    %18 : Tensor = aten::copy_(%17, %4)
  //    %20 : Tensor = aten::slice(%14, ...)   // slice for %5
  //    %21 : Tensor = aten::copy_(%20, %5)
  //    return (%14)
  testing::FileCheck()
      .check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= aten::clamp_max(", 2, /*exactly*/ true)
      ->check_count("= aten::empty(", 1, /*exactly*/ true)
      ->check_count("= aten::slice(", 1, /*exactly*/ true)
      ->check_count("= aten::copy_(", 1, /*exactly*/ true)
      ->check_count("= aten::slice(", 1, /*exactly*/ true)
      ->check_count("= aten::copy_(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, ConcatWithoutResultShape) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Tensor = aten::cat(%6, %2)
          return (%7)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the output
  // shape of `aten::cat` is not known.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, ConcatWithoutInputShape) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Tensor = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%6, %2)
          return (%7)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the shape of %5,
  // which is an input to `aten::cat`, is not known.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, UseVariadicCat) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %5: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %concat : Float(224, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::cat` with `prim::VarConcat` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varcat : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varcat)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %concat1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %concat2 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input2, %10)
          return (%concat1, %concat2)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full concat optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varcat1 : Tensor = prim::VarConcat(%0, %1, %zero)
  //    %varcat2 : Tensor = prim::VarConcat(%2, %3, %zero)
  //    return (%varcat1, %varcat2)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 2, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat, %input)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::cat` with `prim::VarConcat` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varcat : Tensor = prim::VarConcat(%0, %1, %zero)
  //    return (%varcat, %input)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%concat, %input)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list to `aten::cat` is mutated only after the `aten::cat` op,
  // so it should have been replaced with `prim::VarConcat`. The transformed
  // graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarConcat(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %concat : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  {
    ASSERT_FALSE(UseVariadicCat(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // No transformation should have happened since the `prim::ListConstruct`
    // is mutated before `aten::cat`.
    testing::FileCheck()
        .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
        ->check_count("= aten::cat(", 1, /*exactly*/ true)
        ->check_count("= prim::VarConcat(", 0, /*exactly*/ true)
        ->run(*graph);
  }

  {
    ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // The mutation of the list must be removed and the `aten::cat` op must
    // be replaced with the `prim::VarConcat` op in the graph.
    // The transformed graph should look like the following:
    //
    //  graph(%0 : ...,
    //        %1 : ...,
    //        %2 : ...):
    //    %3 : int = prim::Constant[value=0]()
    //    %7 : Tensor = prim::VarConcat(%0, %1, %2, %3)
    //    return (%7)
    testing::FileCheck()
        .check_count("= prim::VarConcat(", 1, /*exactly*/ true)
        ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
        ->check_count("= aten::cat(", 0, /*exactly*/ true)
        ->run(*graph);
  }
}

TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat.1, %concat.2, %concat.3, %concat.4)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::cat` ops must
  // be replaced with `prim::VarConcat` ops in the graph.
  // The transformed graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarConcat(%0, %1, %10)
  //    %6 : Tensor = prim::VarConcat(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarConcat(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(
    ConcatOptTest,
    RemoveListMutationUseVariadicCatAndCommonInputsElimination) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %features.2 : Tensor[] = prim::ListConstruct(%0, %1)
          %6 : Tensor[] = aten::append(%features.2, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)
          %7 : Tensor[] = aten::append(%features.2, %0)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing:
  //     * Remove list mutation
  //     * Use variadic cat
  //     * Eliminate common inputs
  // we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %10 : Tensor = prim::VarConcat(%0, %1, %2, %3)
  //    %12 : Tensor = prim::VarConcat(%10, %0, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%10, %12)
  //    return (%8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%10, %0, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%10, %12)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsSimpleCase) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  //  graph(%0 : Tensor):
  //    %dim : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %0, %0)
  //    %concat : Tensor = aten::cat(%input, %dim)
  //    return (%concat)
  testing::FileCheck()
.check_count("prim::ListConstruct", 1, /*exactly*/ true) ->check_count("aten::cat", 1, /*exactly*/ true) ->run(*graph); } TEST(ConcatOpt, CombineConcatsLongChain) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Tensor, %1 : Tensor): %dim : int = prim::Constant[value=0]() %input.1 : Tensor[] = prim::ListConstruct(%0, %0) %concat.1 : Tensor = aten::cat(%input.1, %dim) %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1) %concat.2 : Tensor = aten::cat(%input.2, %dim) %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0) %concat.3 : Tensor = aten::cat(%input.3, %dim) return (%concat.3) )IR"; parseIR(input, graph.get()); std::vector inputs = {at::rand({1}), at::randn({1})}; auto orig_outputs = runGraph(graph, inputs); ASSERT_TRUE(CombineConcats(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After performing CombineConcats: // graph(%0 : Tensor): // %dim : int = prim::Constant[value=0]() // %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0) // %concat : Tensor = aten::cat(%input, %dim) // return (%concat) testing::FileCheck() .check_count("prim::ListConstruct", 1, /*exactly*/ true) ->check_count("aten::cat", 1, /*exactly*/ true) ->run(*graph); } TEST(ConcatOpt, CombineConcatsMutation) { auto graph = std::make_shared(); const std::string input = R"IR( graph(%0: Tensor, %1 : Tensor): %dim : int = prim::Constant[value=0]() %input.1 : Tensor[] = prim::ListConstruct(%0, %0) %concat.1 : Tensor = aten::cat(%input.1, %dim) %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1) %input.3 : Tensor[] = aten::append(%input.2, %0) %concat.2 : Tensor = aten::cat(%input.2, %dim) return (%concat.2) )IR"; parseIR(input, graph.get()); std::vector inputs = {at::rand({1}), at::randn({1})}; // No modifications due to aten::append ASSERT_FALSE(CombineConcats(graph)); } } // namespace jit } // namespace torch