// pytorch/test/cpp/tensorexpr/test_graph_opt.cpp
#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <limits>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
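
// Fixture that enables the "cat without conditionals" lowering for the
// duration of each test and restores the previous setting afterwards.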
class GraphOpt : public ::testing::Test {
 public:
  void SetUp() override {
    old_cat_wo_conditionals_ = getCatWoConditionals();
    getCatWoConditionals() = true;
  }

  void TearDown() override {
    getCatWoConditionals() = old_cat_wo_conditionals_;
  }

 private:
  bool old_cat_wo_conditionals_;
};

TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` op must be moved to the inputs of `aten::cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
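
  // Rough sketch of the optimized graph this asserts (value names here are
  // hypothetical; only the op order is checked above): `log` is applied to
  // each input before the concatenation:
  //   %x_l = aten::log(%x)
  //   %y_l = aten::log(%y)
  //   %z_l = aten::log(%z)
  //   %cat = aten::cat([%x_l, %y_l, %z_l], dim=0)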

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::log(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` and `aten::tanh` ops must be moved to the inputs of
  // `aten::cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());
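
  // Rough sketch of the asserted shape (hypothetical value names; only the
  // op order is checked above): all three logs, then all three tanhs, then
  // the concatenation:
  //   %x_l = aten::log(%x);    %y_l = aten::log(%y);    %z_l = aten::log(%z)
  //   %x_t = aten::tanh(%x_l); %y_t = aten::tanh(%y_l); %z_t = aten::tanh(%z_l)
  //   %cat = aten::cat([%x_t, %y_t, %z_t], dim=0)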

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // But the `aten::mul` op must not be moved since it is not a single-tensor
  // op (it has 2 tensor inputs).
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());
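
  // Rough expected shape (hypothetical value names): `tanh` is applied to
  // each input of the concatenation, while the two-tensor `mul` stays after:
  //   %x_t = aten::tanh(%x);  %y_t = aten::tanh(%y);  %z_t = aten::tanh(%z)
  //   %cat = aten::cat([%x_t, %y_t, %z_t], dim=0)
  //   %out = aten::mul(%a, %cat)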

  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<at::Tensor> inputs = {a, x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // The inputs to `cat` should now have `Float` scalar type, since they are
  // the results of `tanh`, which performs the type promotion.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());
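
  // Rough expected shape (hypothetical value names): `tanh` promotes each
  // `Int` input to `Float`, so `cat` now concatenates `Float` tensors:
  //   %x_t : Float = aten::tanh(%x)
  //   %y_t : Float = aten::tanh(%y)
  //   %z_t : Float = aten::tanh(%z)
  //   %cat : Float = aten::cat([%x_t, %y_t, %z_t], dim=0)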

  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto ref = at::tanh(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation should have happened because the `aten::cat` op itself
  // performs type promotion: its inputs are `Float` and `Double`, while its
  // output is `Double`. This case is currently not handled.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumers of cat are not
  // single-tensor element-wise ops.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumers of cat are not
  // single-tensor element-wise ops.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, AOTGraphPrepPasses) {
  const auto graph_string = R"IR(
    graph(%x, %y, %z, %i : int):
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      return (%xyz_list, %i))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
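
  // The three prep passes below transform the graph in sequence (a rough
  // summary of what each step does to this graph):
  //   1. removeGraphOutput(g, 1)        -- drops the second output, %i.
  //   2. replaceListOutputWithTuple(g)  -- returns a tuple instead of a list.
  //   3. LowerAllTuples(g)              -- unpacks the tuple, so the graph
  //      returns (%x, %y, %z) directly, as the FileCheck below asserts.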
  removeGraphOutput(g, 1);
  replaceListOutputWithTuple(g);
  LowerAllTuples(g);

  testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
}
} // namespace jit
} // namespace torch