mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[JIT] Combine concat nodes where possible (#67000)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67000

See the [related issue](https://github.com/pytorch/pytorch/issues/66654) for context.

This new JIT optimization transforms patterns like this:
```
%inputs.1 : Tensor[] = prim::ListConstruct(%a, %b, %c)
%concat.1 : Tensor = aten::cat(%inputs.1, %dim)
%inputs.2 : Tensor[] = prim::ListConstruct(%x, %concat.1, %y)
%concat.2 : Tensor = aten::cat(%inputs.2, %dim)
```
into this:
```
%inputs.2 : Tensor[] = prim::ListConstruct(%x, %a, %b, %c, %y)
%concat.2 : Tensor = aten::cat(%inputs.2, %dim)
```
(It can do this for chains of `aten::cat` longer than 2 as well.)

A few conditions have to hold:
1. The `dim`s have to match.
2. `inputs.1` and `inputs.2` cannot be mutated.

Test Plan: `buck test caffe2/test/cpp/jit:jit -- ConcatOpt`

Reviewed By: d1jang

Differential Revision: D31819491

fbshipit-source-id: 9f1a501d52099eb1a630b5dd906df4c38c3817ba
This commit is contained in:
committed by
Facebook GitHub Bot
parent
30cda0b28c
commit
c697eeba72
@ -654,5 +654,92 @@ TEST(
|
||||
->run(*graph);
|
||||
}
|
||||
|
||||
// Two aten::cat nodes along the same dim, where the first cat's result is an
// element of the second cat's input list, should be merged into one cat.
TEST(ConcatOpt, CombineConcatsSimpleCase) {
  const std::string ir_text =
      R"IR(
        graph(%0: Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir_text, graph.get());

  // Record the reference result before the graph is rewritten.
  std::vector<at::Tensor> graph_args = {at::rand({1})};
  auto expected = runGraph(graph, graph_args);

  // The pass must report that it changed the graph.
  const bool changed = CombineConcats(graph);
  ASSERT_TRUE(changed);
  graph->lint();

  // The rewritten graph has to produce exactly the same outputs.
  auto actual = runGraph(graph, graph_args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph after CombineConcats:
  //   graph(%0 : Tensor):
  //     %dim : int = prim::Constant[value=0]()
  //     %input : Tensor[] = prim::ListConstruct(%0, %0, %0)
  //     %concat : Tensor = aten::cat(%input, %dim)
  //     return (%concat)
  testing::FileCheck checker;
  checker.check_count("prim::ListConstruct", 1, /*exactly*/ true);
  checker.check_count("aten::cat", 1, /*exactly*/ true);
  checker.run(*graph);
}
|
||||
|
||||
// A chain of three aten::cat nodes along the same dim, each feeding the next
// one's input list, should collapse into a single cat over one flat list.
TEST(ConcatOpt, CombineConcatsLongChain) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0)
          %concat.3 : Tensor = aten::cat(%input.3, %dim)
          return (%concat.3)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  // Reference outputs from the unoptimized graph.
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  // The optimized graph must yield exactly the same outputs.
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  // graph(%0 : Tensor, %1 : Tensor):
  //   %dim : int = prim::Constant[value=0]()
  //   %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0)
  //   %concat : Tensor = aten::cat(%input, %dim)
  //   return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}
|
||||
|
||||
// The list %input.2 is passed to aten::append before it is concatenated, so
// the pass is expected to leave the graph alone and report no change.
TEST(ConcatOpt, CombineConcatsMutation) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %input.3 : Tensor[] = aten::append(%input.2, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  // The graph is never executed in this test, so no input tensors are built.
  // No modifications due to aten::append
  ASSERT_FALSE(CombineConcats(graph));
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user