[JIT] Combine concat nodes where possible (#67000)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67000

See the [related issue](https://github.com/pytorch/pytorch/issues/66654) for context.

This new JIT optimization transforms patterns like this:
```
%inputs.1 : Tensor[] = prim::ListConstruct(%a, %b, %c)
%concat.1 : Tensor = aten::cat(%inputs.1, %dim)
%inputs.2 : Tensor[] = prim::ListConstruct(%x, %concat.1, %y)
%concat.2 : Tensor = aten::cat(%inputs.2, %dim)
```
into this:
```
%inputs.2 : Tensor[] = prim::ListConstruct(%x, %a, %b, %c, %y)
%concat.2 : Tensor = aten::cat(%inputs.2, %dim)
```
(it can do this for chains of `aten::cat` longer than 2 as well)

A few conditions have to hold:
1. The `dim` arguments have to match.
2. `inputs.1` and `inputs.2` must not be mutated (e.g., via `aten::append`).
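
For reference, here is a minimal sketch of driving the new pass directly on a parsed graph, modeled on the tests added in this diff. The header path is an assumption based on where the existing concat optimizations live; the tests below are the authoritative usage.
```
// Sketch only: the concat_opt.h path is assumed, not confirmed by this diff.
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/concat_opt.h>

using namespace torch::jit;

bool flattenNestedCats() {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
        graph(%a: Tensor, %x: Tensor):
          %dim : int = prim::Constant[value=0]()
          %inputs.1 : Tensor[] = prim::ListConstruct(%a, %a)
          %concat.1 : Tensor = aten::cat(%inputs.1, %dim)
          %inputs.2 : Tensor[] = prim::ListConstruct(%x, %concat.1, %x)
          %concat.2 : Tensor = aten::cat(%inputs.2, %dim)
          return (%concat.2)
      )IR",
      graph.get());
  // Returns true iff the graph was modified; afterwards a single
  // prim::ListConstruct feeds a single aten::cat.
  return CombineConcats(graph);
}
```
If either condition above is violated (e.g., one of the lists is mutated by `aten::append`), the pass leaves the graph untouched and returns `false`, as exercised by the `CombineConcatsMutation` test below.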

Test Plan: `buck test caffe2/test/cpp/jit:jit -- ConcatOpt`

Reviewed By: d1jang

Differential Revision: D31819491

fbshipit-source-id: 9f1a501d52099eb1a630b5dd906df4c38c3817ba
Author: Mike Iovine
Date: 2021-11-15 11:57:50 -08:00
Committed by: Facebook GitHub Bot
Parent: 30cda0b28c
Commit: c697eeba72

4 changed files with 289 additions and 0 deletions

@@ -654,5 +654,92 @@ TEST(
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsSimpleCase) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  // graph(%0 : Tensor):
  //   %dim : int = prim::Constant[value=0]()
  //   %input : Tensor[] = prim::ListConstruct(%0, %0, %0)
  //   %concat : Tensor = aten::cat(%input, %dim)
  //   return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsLongChain) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0)
          %concat.3 : Tensor = aten::cat(%input.3, %dim)
          return (%concat.3)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  // graph(%0 : Tensor, %1 : Tensor):
  //   %dim : int = prim::Constant[value=0]()
  //   %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0)
  //   %concat : Tensor = aten::cat(%input, %dim)
  //   return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsMutation) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %input.3 : Tensor[] = aten::append(%input.2, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  // No modifications due to aten::append
  ASSERT_FALSE(CombineConcats(graph));
}

} // namespace jit
} // namespace torch