mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Further cleaning up of torch_cpu header inclusions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109964 Approved by: https://github.com/ezyang, https://github.com/Skylion007
310 lines
11 KiB
C++
310 lines
11 KiB
C++
#include <ATen/Functions.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/passes/variadic_ops.h>
|
|
#include <torch/csrc/jit/runtime/interpreter.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Runs `UseVariadicStack` on a graph whose only `aten::stack` consumes a
// single-use `prim::ListConstruct`, then verifies that
//   * the pass returns true,
//   * the graph still lints and yields exactly the same outputs, and
//   * the list construction and `aten::stack` are gone, with one
//     `prim::VarStack` in their place.
TEST(StackOptTest, UseVariadicStack) {
  auto graph = std::make_shared<Graph>();

  // Six (56, 56, 56) tensors stacked along dim 0 produce a (6, 56, 56, 56)
  // tensor, which is what the output annotation below records.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56),
              %5: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %stack : Float(6, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  parseIR(input, graph.get());

  // One random CPU tensor per graph input.
  constexpr int kNumInputs = 6;
  std::vector<at::Tensor> inputs;
  inputs.reserve(kNumInputs);
  for (int i = 0; i < kNumInputs; ++i) {
    inputs.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::stack` with `prim::VarStack` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varstack)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}
|
|
|
|
// Runs `UseVariadicStack` on a graph with two independent
// `prim::ListConstruct` -> `aten::stack` pairs and verifies that both pairs
// are rewritten to `prim::VarStack` while the outputs stay exactly equal.
TEST(StackOptTest, UseVariadicStackReplaceMultiple) {
  auto graph = std::make_shared<Graph>();

  // Each stack combines two (56, 56, 56) tensors along dim 0, hence the
  // (2, 56, 56, 56) output annotations.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %stack1 : Float(2, 56, 56, 56) = aten::stack(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %stack2 : Float(2, 56, 56, 56) = aten::stack(%input2, %10)
          return (%stack1, %stack2)
      )IR";
  parseIR(input, graph.get());

  // One random CPU tensor per graph input.
  constexpr int kNumInputs = 4;
  std::vector<at::Tensor> inputs;
  inputs.reserve(kNumInputs);
  for (int i = 0; i < kNumInputs; ++i) {
    inputs.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full stack optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack1 : Tensor = prim::VarStack(%0, %1, %zero)
  //    %varstack2 : Tensor = prim::VarStack(%2, %3, %zero)
  //    return (%varstack1, %varstack2)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 2, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}
|
|
|
|
// The list feeding `aten::stack` is also a graph output here. The test
// expects `UseVariadicStack` to return true, to swap `aten::stack` for
// `prim::VarStack`, and to leave the `prim::ListConstruct` in the graph.
TEST(StackOptTest, UseVariadicStackWithMultipleListUses) {
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2)
          return (%stack, %input)
      )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  // Two random CPU tensors, one per graph input.
  std::vector<at::Tensor> tensors;
  for (int idx = 0; idx < 2; ++idx) {
    tensors.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  const auto expected = runGraph(graph, tensors);

  const bool pass_result = UseVariadicStack(graph);
  ASSERT_TRUE(pass_result);
  graph->lint();
  const auto actual = runGraph(graph, tensors);

  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once `aten::stack` has become `prim::VarStack`:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varstack : Tensor = prim::VarStack(%0, %1, %zero)
  //    return (%varstack, %input)
  testing::FileCheck checker;
  checker.check_count("= prim::ListConstruct(", 1, /*exactly=*/true)
      ->check_count("= prim::VarStack(", 1, /*exactly=*/true)
      ->check_count("= aten::stack(", 0, /*exactly=*/true)
      ->run(*graph);
}
|
|
|
|
// The list passed to `aten::stack` is appended to only after the stack op.
// The test expects `UseVariadicStack` to return true and to replace
// `aten::stack` with `prim::VarStack`, while the `prim::ListConstruct` stays.
// Note: the only stacking op in this test is `aten::stack`; "Cat" appears
// just in the test name.
TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) {
  auto graph = std::make_shared<Graph>();

  // When `aten::stack` runs the list still holds only %0 and %1, so its
  // output is annotated as (2, 56, 56, 56).
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%stack, %input)
      )IR";
  parseIR(input, graph.get());

  // One random CPU tensor per graph input.
  constexpr int kNumInputs = 3;
  std::vector<at::Tensor> inputs;
  inputs.reserve(kNumInputs);
  for (int i = 0; i < kNumInputs; ++i) {
    inputs.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list to `aten::stack` is mutated only after `aten::stack` op. So,
  // it should have been replaced with `prim::VarStack`. The transformed graph
  // should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarStack(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}
|
|
|
|
// The list is appended to before it reaches `aten::stack`. The test has two
// phases on the same graph:
//   1. `UseVariadicStack` is expected to return false and leave the graph
//      untouched.
//   2. `RemoveListMutationAndUseVariadicStack` is expected to return true,
//      drop the list (and its mutation) and emit a single `prim::VarStack`.
// Outputs are compared against the unoptimized run after each phase.
TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) {
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  // Three random CPU tensors, one per graph input.
  std::vector<at::Tensor> tensors;
  for (int idx = 0; idx < 3; ++idx) {
    tensors.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  const auto expected = runGraph(graph, tensors);

  // Phase 1: plain variadic rewrite must not fire.
  {
    const bool pass_result = UseVariadicStack(graph);
    ASSERT_FALSE(pass_result);
    graph->lint();
    const auto actual = runGraph(graph, tensors);
    ASSERT_TRUE(exactlyEqual(expected, actual));

    // The `prim::ListConstruct` is mutated ahead of `aten::stack`, so the
    // graph must still contain the original ops and no `prim::VarStack`.
    testing::FileCheck checker;
    checker.check_count("= prim::ListConstruct(", 1, /*exactly=*/true)
        ->check_count("= aten::stack(", 1, /*exactly=*/true)
        ->check_count("= prim::VarStack(", 0, /*exactly=*/true)
        ->run(*graph);
  }

  // Phase 2: rewrite that also removes list mutations.
  {
    const bool pass_result = RemoveListMutationAndUseVariadicStack(graph);
    ASSERT_TRUE(pass_result);
    graph->lint();
    const auto actual = runGraph(graph, tensors);
    ASSERT_TRUE(exactlyEqual(expected, actual));

    // With the list mutation removed and `aten::stack` replaced by
    // `prim::VarStack`, the graph should look like:
    //
    //  graph(%0 : ...,
    //        %1 : ...,
    //        %2 : ...):
    //    %3 : int = prim::Constant[value=0]()
    //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3)
    //    return (%7)
    testing::FileCheck checker;
    checker.check_count("= prim::VarStack(", 1, /*exactly=*/true)
        ->check_count("= prim::ListConstruct(", 0, /*exactly=*/true)
        ->check_count("= aten::stack(", 0, /*exactly=*/true)
        ->run(*graph);
  }
}
|
|
|
|
// One list is stacked four times with an `aten::append` between consecutive
// stacks. The test expects `RemoveListMutationAndUseVariadicStack` to return
// true, to remove the list and every `aten::stack`, and to leave four
// `prim::VarStack` ops, with outputs exactly equal to the unoptimized run.
TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();

  // The list holds 2, 3, 4 and 5 tensors at the successive `aten::stack`
  // calls; the output annotations follow those sizes.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack.1 : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %stack.2 : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %stack.3 : Float(4, 56, 56, 56) = aten::stack(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack.1, %stack.2, %stack.3, %stack.4)
      )IR";
  parseIR(input, graph.get());

  // One random CPU tensor per graph input.
  constexpr int kNumInputs = 5;
  std::vector<at::Tensor> inputs;
  inputs.reserve(kNumInputs);
  for (int i = 0; i < kNumInputs; ++i) {
    inputs.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::stack` ops
  // must be replaced with `prim::VarStack` ops in the graph. The transformed
  // graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarStack(%0, %1, %10)
  //    %6 : Tensor = prim::VarStack(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|