// test/cpp/tensorexpr/test_te_fuser_pass.cpp
//
// Tests for the TensorExpr fuser pass (FuseTensorExprs).

#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <sstream>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
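
// RAII guard that overrides the CPU fuser setting for the duration of a test
// and restores the previous value on destruction.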
struct WithCPUFuser {
  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
    overrideCanFuseOnCPU(val);
  }

  ~WithCPUFuser() {
    overrideCanFuseOnCPU(cpuFuserEnabled);
  }

  bool cpuFuserEnabled;
};

TEST(TEFuserPass, FuserPass_1) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
      %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
      %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
      %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("aten::add_")
      ->check("prim::TensorExprGroup_")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_2) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
      %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
      %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
      return (%d))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("aten::add_")
      ->check("prim::TensorExprGroup_0")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_3) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(128, strides=[1], device=cpu),
          %y : Float(128, strides=[1], device=cpu)):
      %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR";
  {
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // We should not create a fusion group since its size would be too small.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should create a fusion group since its size is above the threshold.
    testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
  }
}

TEST(TEFuserPass, FuserPass_0DimInput) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(device=cpu),
          %y : Float(device=cpu)):
      %one : int = prim::Constant[value=1]()
      %a : Float(device=cpu) = aten::mul(%x, %y)
      %b : Float(device=cpu) = aten::add(%x, %a, %one)
      return (%b))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g);

  // We should fuse 0-dim tensors too.
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_UnfusibleDevice) {
  WithCPUFuser cf(false);
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(10, strides=[1], device=cpu)):
      %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%a))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 1);

  // Test that we don't start fusion groups from nodes on an unfusible device.
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_UnknownShapes) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %a : Tensor = aten::mul(%x, %y)
      %b : Tensor = aten::mul(%x, %a)
      return (%b))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g);

  // Test that we're not generating fusion groups when shapes are not known.
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Multidevice) {
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(30, strides=[1], device=cpu)):
        %dim : int = prim::Constant[value=0]()
        %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
        return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should be able to fuse this.
    testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cuda:0),
            %z : Float(30, strides=[1], device=cpu)):
        %dim : int = prim::Constant[value=0]()
        %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
        return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should not fuse this aten::cat since its inputs are on different
    // devices.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(10, strides=[1], device=cuda:0)):
        %dim : int = prim::Constant[value=0]()
        %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
        %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
        %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
        return (%r))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // Test that we check devices before merging one node (cat) into another
    // (mul).
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(10, strides=[1], device=cuda:0)):
        %z2 : Tensor = aten::mul(%z, %z)
        %dim : int = prim::Constant[value=0]()
        %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
        return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // Test that we check devices before merging one node (mul) into another
    // (cat).
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cuda:0)):
        %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
        return (%r))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should not fuse this graph since its inputs are on different devices.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cuda:0),
            %y : Float(20, strides=[1], device=cuda:1),
            %z : Float(20, strides=[1], device=cpu)):
        %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
        %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
        %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
        return (%x2, %y2, %z2))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // We should not fuse these computations since they run on different
    // devices.
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
}

TEST(TEFuserPass, FuserPass_MergeGroups) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%a : Float(128, strides=[1], device=cpu),
          %b : Float(128, strides=[1], device=cpu)):
      %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
      %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
      return (%x, %y))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 1);

  // The %x and %y computations are completely independent, and yet we should
  // put them into a single fusion group rather than having two separate ones.
  testing::FileCheck()
      .check("= prim::TensorExprGroup_")
      ->check_not("= prim::TensorExprGroup_")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Bool(8, strides=[1], device=cpu),
          %y : Bool(8, strides=[1], device=cpu)):
      %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
      %b : Tensor = aten::__or__(%a, %y)
      return (%b)
  )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
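
  // %b's shape is unknown, so aten::__or__ cannot be fused; the remaining
  // aten::__and__ alone is below the minimum group size.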
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Stack) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%y.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu),
          %x.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu)):
      %1 : int = prim::Constant[value=2]()
      %9 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%x.1)
      %7 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%y.1)
      %5 : Tensor[] = prim::ListConstruct(%9, %7)
      %z.2 : Float(5, 3, 2, 3, 6, strides=[108, 36, 18, 6, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1)
      return (%z.2)
  )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
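
  // With static shapes known, aten::stack participates in fusion (#74077), so
  // the two aten::tanh ops and the aten::stack should form one fusion group.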
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Where) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
      return (%b)
  )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
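
  // The three-argument aten::where is fusible, so aten::eq and aten::where
  // should form a fusion group.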
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_WhereList) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Tensor[] = aten::where(%cond)
      return (%b)
  )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
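
  // The single-argument aten::where overload returns a Tensor[] and is not
  // supported by the fuser, so no group should be formed.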
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, DynamicShapeFusion) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(10, 5, strides=[5, 1], device=cpu),
          %1 : Float(10, 5, strides=[5, 1], device=cpu)):
      %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1)
      %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1)
      return (%3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(
      g,
      /* min_group_size = */ 2,
      /* add_composed_op = */ true,
      /* fuse_to_dynamic_shapes = */ true);
  Code code(g, "");
  testing::FileCheck()
      .check("prim::TensorExprDynamicGroup_")
      ->check("prim::TensorExprDynamicGuard")
      ->check("prim::TensorExprGroup_")
      ->run(*g);
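
  // Interpret the fused graph and compare the result against an eager-mode
  // reference.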
  auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) {
    TORCH_INTERNAL_ASSERT(inputs.size() == 2);
    auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]);

    InterpreterState interp(code);
    Stack stack(inputs.begin(), inputs.end());
    interp.run(stack);
    at::Tensor out = pop(stack).toTensor();
    ASSERT_TRUE(at::allclose(out, ref));
  };
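
  // The same fused code should handle inputs of several different shapes.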
  std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})};
  run_and_compare(inputs);

  std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})};
  run_and_compare(inputs2);

  std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})};
  run_and_compare(inputs3);
}

} // namespace jit
} // namespace torch