mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/89852). * __->__ #89852 * #89851 set -Wsuggest-override for builds Summary: This was flagged by a Meta internal build. Test Plan: Rely on CI. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89852 Approved by: https://github.com/malfet
320 lines
10 KiB
C++
320 lines
10 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/tensorexpr/test_base.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/passes/lower_tuples.h>
|
|
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
|
|
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <limits>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
class GraphOpt : public ::testing::Test {
|
|
public:
|
|
void SetUp() override {
|
|
old_cat_wo_conditionals_ = getCatWoConditionals();
|
|
getCatWoConditionals() = true;
|
|
}
|
|
|
|
void TearDown() override {
|
|
getCatWoConditionals() = old_cat_wo_conditionals_;
|
|
}
|
|
|
|
private:
|
|
bool old_cat_wo_conditionals_;
|
|
};
|
|
|
|
// A single `aten::log` consuming the result of `aten::cat` is expected to be
// moved onto each of the three cat inputs. The compiled kernel is then run
// and compared against the eager-mode result.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape of the optimized graph: one `aten::log` per cat input,
  // then the `aten::cat`, and no `aten::log` after it.
  testing::FileCheck checker;
  checker.check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log");
  checker.run(*kernel.graph());

  // Eager-mode reference on random inputs of the declared sizes.
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto joined = at::cat({x, y, z}, 0);
  auto expected = at::log(joined);

  // The kernel reads its inputs from the stack and leaves the output in
  // slot 0.
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{x, y, z});
  kernel.run(stack);
  auto actual = stack[0].toTensor();

  ASSERT_EQ(actual.sizes(), expected.sizes());
  ASSERT_EQ(actual.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(actual, expected));
#endif
}
|
|
|
|
// A chain of two single-tensor ops (`aten::log` followed by `aten::tanh`)
// after `aten::cat` is expected to be moved onto each cat input. The
// compiled kernel is then checked against the eager-mode result.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape of the optimized graph: three `aten::log`, three
  // `aten::tanh`, then `aten::cat`, with neither op appearing after the cat.
  testing::FileCheck checker;
  checker.check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh");
  checker.run(*kernel.graph());

  // Eager-mode reference on random inputs of the declared sizes.
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto joined = at::cat({x, y, z}, 0);
  auto expected = at::tanh(at::log(joined));

  // The kernel reads its inputs from the stack and leaves the output in
  // slot 0.
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{x, y, z});
  kernel.run(stack);
  auto actual = stack[0].toTensor();

  ASSERT_EQ(actual.sizes(), expected.sizes());
  ASSERT_EQ(actual.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(actual, expected));
#endif
}
|
|
|
|
// `aten::tanh` directly after `aten::cat` is expected to be moved onto the
// cat inputs, while the following `aten::mul` must stay after the cat
// because it takes two tensor inputs and is therefore not a single-tensor op.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape of the optimized graph: three `aten::tanh`, then
  // `aten::cat`, then `aten::mul`, and no `aten::tanh` after that.
  testing::FileCheck checker;
  checker.check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh");
  checker.run(*kernel.graph());

  // Eager-mode reference on random inputs of the declared sizes.
  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto joined = at::cat({x, y, z}, 0);
  auto expected = at::tanh(joined) * a;

  // Inputs go on the stack in graph-parameter order; the output ends up in
  // slot 0.
  std::vector<IValue> stack =
      fmap<IValue>(std::vector<at::Tensor>{a, x, y, z});
  kernel.run(stack);
  auto actual = stack[0].toTensor();

  ASSERT_EQ(actual.sizes(), expected.sizes());
  ASSERT_EQ(actual.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(actual, expected));
#endif
}
|
|
|
|
// The cat inputs are `Int` but its consumer `aten::tanh` produces `Float`.
// The `aten::tanh` is still expected to be moved onto the cat inputs, so the
// tensors fed to `aten::cat` become `Float` results of `tanh`, which is the
// op performing the type promotion.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape of the optimized graph: three `aten::tanh`, then
  // `aten::cat`, and no `aten::tanh` after it.
  testing::FileCheck checker;
  checker.check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh");
  checker.run(*kernel.graph());

  // Random integer inputs bounded by the largest `int` value.
  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto joined = at::cat({x, y, z}, 0);
  auto expected = at::tanh(joined);

  // The kernel reads its inputs from the stack and leaves the output in
  // slot 0.
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{x, y, z});
  kernel.run(stack);
  auto actual = stack[0].toTensor();

  ASSERT_EQ(actual.sizes(), expected.sizes());
  ASSERT_EQ(actual.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(actual, expected));
#endif
}
|
|
|
|
// Here `aten::cat` itself promotes types (two `Float` inputs and one
// `Double` input yield a `Double` result). That case is currently not
// handled by the optimization, so the graph is expected to stay as-is.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape: exactly the original order, `aten::cat` followed by
  // `aten::log`, with no further occurrence of either.
  testing::FileCheck checker;
  checker.check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log");
  checker.run(*kernel.graph());
#endif
}
|
|
|
|
// The only consumer of `aten::cat` is `aten::mul`, which takes a second
// tensor operand and so is not a single-tensor elementwise op. No
// transformation is expected.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape: `aten::cat` followed by `aten::mul`, with no further
  // occurrence of either.
  testing::FileCheck checker;
  checker.check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul");
  checker.run(*kernel.graph());
#endif
}
|
|
|
|
// Same situation as the previous test but with a longer tail: `aten::cat`
// feeds `aten::mul`, whose result feeds `aten::add`. Since the consumers of
// cat are not single-tensor elementwise ops, no transformation is expected.
// The whole body is compiled only when TORCH_ENABLE_LLVM is defined.
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  const auto ir_text = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Expected shape: `aten::cat`, `aten::mul`, `aten::add` in the original
  // order, with no further occurrence of any of them.
  testing::FileCheck checker;
  checker.check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add");
  checker.run(*kernel.graph());
#endif
}
|
|
|
|
// Runs removeGraphOutput, replaceListOutputWithTuple and LowerAllTuples in
// sequence on a graph that returns a tensor list plus an int, and checks
// that the resulting graph returns the three tensors directly.
TEST_F(GraphOpt, AOTGraphPrepPasses) {
  const auto ir_text = R"IR(
    graph(%x, %y, %z, %i : int):
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      return (%xyz_list, %i))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());

  // Drop the output at index 1 (`%i`), then rewrite the remaining list
  // output and lower tuples.
  removeGraphOutput(graph, 1);
  replaceListOutputWithTuple(graph);
  LowerAllTuples(graph);

  testing::FileCheck checker;
  checker.check("return (%x, %y, %z)");
  checker.run(*graph);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|