Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Revert D20624571: [pytorch][PR] [TensorExpr] Extend arithmetic simplifier to work with multi variable expressions

Test Plan: revert-hammer

Differential Revision: D20624571

Original commit changeset: e49049377bee

fbshipit-source-id: 7d8dda0c3b44be1c3236a0313bbfa128b7015de7
Committed by: Facebook GitHub Bot
Parent: ee7cd84fac
Commit: a7f8655314
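The reverted change extended IRSimplifier so it could fold and factorize expressions containing more than one variable. A rough sketch of the behaviour being removed (not part of this commit; the fixture and API names follow the test code in the diff below):

  // Minimal sketch, assuming the tensorexpr test fixtures shown below.
  KernelScope kernel_scope;
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  // With the reverted simplifier: (2 * x) + (2 * y) => 2 * (x + y).
  ExprHandle body = ExprHandle(2) * x + ExprHandle(2) * y;
  ExprHandle simplified = IRSimplifier::simplify(body);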
@@ -481,7 +481,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_simplifier.cpp
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp
    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -11,40 +11,6 @@ namespace jit {
using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;

#define IS_NODE(T, node)                                      \
  {                                                           \
    auto* node_ = dynamic_cast<const T*>(node);               \
    EXPECT_NE(nullptr, node_) << "Expected node to be " #T;   \
  }

#define IS_NODE_WITH_NAME(T, node, name)                      \
  auto* name = dynamic_cast<const T*>(node);                  \
  EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T;

#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type)           \
  const T* name = nullptr;                                        \
  {                                                               \
    auto* node_ = dynamic_cast<const Cast*>(node);                \
    EXPECT_NE(nullptr, node_);                                    \
    EXPECT_EQ(node_->dtype().scalar_type(), ScalarType::Type);    \
    name = dynamic_cast<const T*>(node_->src_value());            \
  }                                                               \
  EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T;

#define IS_IMM_WITH_VAL(T, node, val)                                 \
  {                                                                   \
    auto* node_ = dynamic_cast<const T##Imm*>(node);                  \
    EXPECT_NE(nullptr, node_) << "Expected node to be " #T "Imm";     \
    EXPECT_EQ(node_->value(), val) << "Expected Imm to be " << val;   \
  }

#define IS_VAR_WITH_NAME(node, name)                                      \
  {                                                                       \
    auto* node_ = dynamic_cast<const Var*>(node);                         \
    EXPECT_NE(nullptr, node_) << "Expected node to be Var";               \
    EXPECT_EQ(node_->name_hint(), name) << "Expected var to be " #name;   \
  }

void testConstantFoldSimple() {
  KernelScope kernel_scope;
  ExprHandle a(2.0f);
@ -167,33 +133,17 @@ void testConstantFoldIntrinsics() {
|
||||
|
||||
void testConstantFoldWithVar() {
|
||||
KernelScope kernel_scope;
|
||||
{
|
||||
VarHandle x("x", kInt);
|
||||
ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));
|
||||
VarHandle x("x", kFloat);
|
||||
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(body);
|
||||
const Mul* root = newF.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
EXPECT_NE(dynamic_cast<const IntImm*>(root->lhs()), nullptr);
|
||||
ExprHandle newF = IRSimplifier::simplify(body);
|
||||
const Mul* root = newF.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
|
||||
|
||||
ExprHandle result = Let::make(x, ExprHandle(3), newF);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<int>(), 3 * (2 + 4));
|
||||
}
|
||||
|
||||
{
|
||||
VarHandle x("x", kFloat);
|
||||
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(body);
|
||||
const Mul* root = newF.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
|
||||
|
||||
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
|
||||
}
|
||||
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
|
||||
}
|
||||
|
||||
void testUnFoldableExpr() {
|
||||
@ -278,22 +228,34 @@ void testHashEquivalenceAfterFolding() {
|
||||
ExprHandle a(2.0f);
|
||||
ExprHandle b(3.0f);
|
||||
ExprHandle c(5.0f);
|
||||
ExprHandle f1 = ((a + b) * x);
|
||||
ExprHandle f2 = (c * x);
|
||||
ExprHandle f = ((a + b) * x) * (c * x);
|
||||
|
||||
const Mul* root = f.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
|
||||
HashProvider hasher;
|
||||
auto hash_l = hasher.hash(f1.node());
|
||||
auto hash_r = hasher.hash(f2.node());
|
||||
auto hash_f = hasher.hash(f.node());
|
||||
auto hash_l = hasher.hash(root->lhs());
|
||||
auto hash_r = hasher.hash(root->rhs());
|
||||
|
||||
// Root not equal to either branch, and branches not equal.
|
||||
EXPECT_NE(hash_f, hash_l);
|
||||
EXPECT_NE(hash_f, hash_r);
|
||||
EXPECT_NE(hash_l, hash_r);
|
||||
|
||||
ExprHandle ff1 = IRSimplifier::simplify(f1);
|
||||
ExprHandle ff2 = IRSimplifier::simplify(f2);
|
||||
ExprHandle newF = IRSimplifier::simplify(f);
|
||||
|
||||
auto hash_l_n = hasher.hash(ff1.node());
|
||||
auto hash_r_n = hasher.hash(ff2.node());
|
||||
const Mul* newRoot = newF.AsNode<Mul>();
|
||||
EXPECT_NE(newRoot, nullptr);
|
||||
|
||||
// branches are now equal.
|
||||
auto hash_f_n = hasher.hash(newF.node());
|
||||
auto hash_l_n = hasher.hash(newRoot->lhs());
|
||||
auto hash_r_n = hasher.hash(newRoot->rhs());
|
||||
|
||||
// Root not equal to either branch.
|
||||
EXPECT_NE(hash_f_n, hash_l_n);
|
||||
EXPECT_NE(hash_f_n, hash_r_n);
|
||||
// but branches are now equal.
|
||||
EXPECT_EQ(hash_l_n, hash_r_n);
|
||||
}
|
||||
|
||||
@ -381,16 +343,11 @@ void testHashLargeExpression() {
|
||||
EXPECT_NE(hash_t, hash_f);
|
||||
}
|
||||
|
||||
/// (2 + x) + 4 => x + 6
|
||||
/// (2.f + x) + 4.f => x + 6.f
|
||||
void testSimplifyAdd() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
VarHandle m("m", kInt);
|
||||
VarHandle n("n", kInt);
|
||||
VarHandle n_1("n_1", kInt);
|
||||
ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
|
||||
VarHandle x("x", kFloat);
|
||||
ExprHandle body = (ExprHandle(2.f) + x) + ExprHandle(4.f);
|
||||
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
const Add* root = simplified.AsNode<Add>();
|
||||
@ -398,43 +355,51 @@ void testSimplifyAdd() {
|
||||
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
|
||||
EXPECT_NE(lhs, nullptr);
|
||||
EXPECT_EQ(lhs->name_hint(), "x");
|
||||
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
|
||||
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
|
||||
EXPECT_NE(rhs, nullptr);
|
||||
EXPECT_EQ(rhs->value(), 6.f);
|
||||
}
|
||||
|
||||
/// (2 - x) - 4 => -2 - x
|
||||
/// (2.f - x) - 4.f => -2.f - x
|
||||
void testSimplifySub() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
|
||||
VarHandle x("x", kFloat);
|
||||
ExprHandle body = (ExprHandle(2.f) - x) - ExprHandle(4.f);
|
||||
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
const Sub* root = simplified.AsNode<Sub>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
|
||||
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
|
||||
EXPECT_NE(lhs, nullptr);
|
||||
EXPECT_EQ(lhs->value(), -2);
|
||||
EXPECT_EQ(lhs->value(), -2.f);
|
||||
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
|
||||
EXPECT_NE(rhs, nullptr);
|
||||
EXPECT_EQ(rhs->name_hint(), "x");
|
||||
}
|
||||
|
||||
/// 2 * (1 - x) - 4 => -2 * (x + 3)
|
||||
/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f)
|
||||
void testSimplifyMultiLayer() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4));
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
VarHandle x("x", kFloat);
|
||||
ExprHandle body = ExprHandle(2.f) * ((ExprHandle(1.f) - x) - ExprHandle(4.f));
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_IMM_WITH_VAL(Int, add->rhs(), 3);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
const Sub* root = simplified.AsNode<Sub>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
|
||||
EXPECT_NE(lhs, nullptr);
|
||||
EXPECT_EQ(lhs->value(), -6.f);
|
||||
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
|
||||
EXPECT_NE(rhs, nullptr);
|
||||
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
|
||||
EXPECT_NE(varX, nullptr);
|
||||
EXPECT_EQ(varX->name_hint(), "x");
|
||||
const FloatImm* mulRhs = dynamic_cast<const FloatImm*>(rhs->rhs());
|
||||
EXPECT_NE(mulRhs, nullptr);
|
||||
EXPECT_EQ(mulRhs->value(), 2.f);
|
||||
}
|
||||
|
||||
/// 2 * (3 * x) - (x * 4) => 2 * x
|
||||
/// 2 * (3 * x) - (x * 4) => x * 2
|
||||
void testSimplifyMultiTerm() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
@ -444,30 +409,30 @@ void testSimplifyMultiTerm() {
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
const Mul* root = simplified.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
|
||||
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
|
||||
EXPECT_NE(lhs, nullptr);
|
||||
EXPECT_EQ(lhs->value(), 2);
|
||||
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
|
||||
EXPECT_EQ(lhs->name_hint(), "x");
|
||||
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
|
||||
EXPECT_NE(rhs, nullptr);
|
||||
EXPECT_EQ(rhs->name_hint(), "x");
|
||||
EXPECT_EQ(rhs->value(), 2);
|
||||
}
|
||||
|
||||
/// 2 * (3 * (long)x) - (x * 4) => 2 * x
|
||||
/// 2 * (3 * (f)x) - (x * 4) => x * 2.f
|
||||
void testSimplifyCasts() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kLong);
|
||||
VarHandle x("x", kFloat);
|
||||
ExprHandle body =
|
||||
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
|
||||
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
const Mul* root = simplified.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
const LongImm* lhs = dynamic_cast<const LongImm*>(root->lhs());
|
||||
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
|
||||
EXPECT_NE(lhs, nullptr);
|
||||
EXPECT_EQ(lhs->value(), 2);
|
||||
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
|
||||
EXPECT_EQ(lhs->name_hint(), "x");
|
||||
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
|
||||
EXPECT_NE(rhs, nullptr);
|
||||
EXPECT_EQ(rhs->name_hint(), "x");
|
||||
EXPECT_EQ(rhs->value(), 2);
|
||||
}
|
||||
|
||||
/// (x + 0) * 1 => x
|
||||
@ -487,39 +452,20 @@ void testSimplifyMultiVar() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
ExprHandle body = x * 24 + y * 34;
|
||||
ExprHandle body = y * 24 + x * 34;
|
||||
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
const Add* root = simplified.AsNode<Add>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
|
||||
EXPECT_NE(lhs, nullptr);
|
||||
const Var* varX = dynamic_cast<const Var*>(lhs->rhs());
|
||||
EXPECT_NE(varX, nullptr);
|
||||
EXPECT_EQ(varX->name_hint(), "y");
|
||||
const Var* varY = dynamic_cast<const Var*>(lhs->lhs());
|
||||
EXPECT_EQ(varY->name_hint(), "y");
|
||||
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
|
||||
EXPECT_NE(rhs, nullptr);
|
||||
const Var* varY = dynamic_cast<const Var*>(rhs->rhs());
|
||||
EXPECT_NE(varY, nullptr);
|
||||
EXPECT_EQ(varY->name_hint(), "x");
|
||||
}
|
||||
|
||||
// x + 2 + y => x + y + 2
|
||||
void testSimplifyReorderings() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
ExprHandle body = x + 2 + y;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
const Add* root = simplified.AsNode<Add>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
IS_IMM_WITH_VAL(Int, root->rhs(), 2);
|
||||
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
|
||||
EXPECT_NE(varX, nullptr);
|
||||
EXPECT_EQ(varX->name_hint(), "x");
|
||||
}
|
||||
|
||||
/// y + x * 0 => y
|
||||
@ -530,621 +476,9 @@ void testSimplifyEliminatesVar() {
|
||||
ExprHandle body = y + x * ExprHandle(0);
|
||||
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_VAR_WITH_NAME(simplified.node(), "y");
|
||||
}
|
||||
|
||||
void testSimplifyAdds() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// (x + y) + (x + y) => 2 * (x + y)
|
||||
ExprHandle body = (x + y) + (x + y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
|
||||
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Add, root->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// (x * y) + (x * y) => 2 * (x * y)
|
||||
ExprHandle body = (x * y) + (x * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
|
||||
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Mul, root->rhs(), mul);
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// (x - y) + (x - y) => -2 * (y - x)
|
||||
ExprHandle body = (x - y) + (x - y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "x");
|
||||
}
|
||||
}
|
||||
|
||||
void testSimplifyMuls() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// (x + y) * (x + y) => (x + y) * (x + y)
|
||||
// We don't attempt to simplify multiplication of polynomials since the
|
||||
// result is only very rarely more efficient.
|
||||
ExprHandle body = (x + y) * (x + y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// x * y * x * y => x * x * y * y
|
||||
// These get reordered only.
|
||||
ExprHandle body = x * y * x * y;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul1);
|
||||
IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2);
|
||||
IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3);
|
||||
IS_VAR_WITH_NAME(mul1->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(mul2->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(mul3->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul3->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// (x - y) * (x - y) => (x - y) * (x - y)
|
||||
// As with Add we don't attempt simplification of this.
|
||||
ExprHandle body = (x - y) * (x - y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs);
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// (x + y) * (x - y) => (x - y) * (x - y)
|
||||
// Don't simplify with different ops on each side.
|
||||
ExprHandle body = (x + y) * (x - y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
|
||||
IS_VAR_WITH_NAME(lhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
|
||||
IS_VAR_WITH_NAME(rhs->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
}
|
||||
}
|
||||
|
||||
// Subtracting an expr from itself results in zero.
|
||||
void testSimplifySubs() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// (x + y) - (x + y) => 0
|
||||
ExprHandle body = (x + y) - (x + y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
// (x * y) - (x * y) => 0
|
||||
ExprHandle body = (x * y) - (x * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
// (x - y) - (x - y) => 0
|
||||
ExprHandle body = (x - y) - (x - y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
// (x + y) - 2 * (x + y) => -1 * (x + y)
|
||||
ExprHandle body = (x + y) - ExprHandle(2) * (x + y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// (x + y) - y => x
|
||||
ExprHandle body = (x + y) - y;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_VAR_WITH_NAME(simplified.node(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// (x - y) - y => x - 2 * y
|
||||
ExprHandle body = (x - y) - y;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_VAR_WITH_NAME(sub->lhs(), "x");
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// 2 * x - x => x
|
||||
ExprHandle body = (ExprHandle(2) * x) - x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_VAR_WITH_NAME(simplified.node(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// x - 2 * x = -1 * x
|
||||
// We don't have a unary negate, but this could be 0 -x I guess?
|
||||
ExprHandle body = x - (ExprHandle(2) * x);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// (x + y + 5) * (x - x) => 0
|
||||
// Cancelling out one side of Mul cancels both.
|
||||
ExprHandle body = (x + y + 5) * (x - x);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Test that mixing ops together simplifies as expected.
|
||||
void testSimplifyMultiOp() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// (x * y) + (x - y) => (x * y) + x - y
|
||||
//
|
||||
ExprHandle body = (x * y) + (x - y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), mul);
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
IS_VAR_WITH_NAME(sub->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// (x + y) - (x * y) => x + y - (x * y)
|
||||
ExprHandle body = (x + y) - (x * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// (x - y) - (x + y) => -2 * y
|
||||
ExprHandle body = (x - y) - (x + y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "y");
|
||||
}
|
||||
}
|
||||
|
||||
// Test that chaining many ops together works as expected.
|
||||
void testSimplifyManyOps() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// x + y + x + x + y + y + x + y + x = 4 * y + 5 * x
|
||||
ExprHandle body = x + y + x + x + y + y + x + y + x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 4);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// x - y + x + x - y - y + x - y + x = 5 * x - 4 * y
|
||||
ExprHandle body = x - y + x + x - y - y + x - y + x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// x + y + x - x - y - y + x + y + x = 3 * x
|
||||
ExprHandle body = x + y + x - x - y - y + x + y + x;
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 3);
|
||||
IS_VAR_WITH_NAME(mul->rhs(), "x");
|
||||
}
|
||||
}
|
||||
|
||||
void testSimplifyFactorization() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// (2 * x) + (2 * y) => 2 * (x + y)
|
||||
ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// Factorization when scalars have common divider.
|
||||
// (2 * x) + (4 * y) => 2 * (2 * y + x)
|
||||
ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), mul2);
|
||||
IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
|
||||
IS_VAR_WITH_NAME(mul2->rhs(), "y");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// Factorization attempt without a common divider.
|
||||
// (2 * x) + (5 * y) => (5 * y) + (2 * x)
|
||||
ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
|
||||
IS_VAR_WITH_NAME(lhs->rhs(), "y");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 2);
|
||||
IS_VAR_WITH_NAME(rhs->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// Factorization after merging.
|
||||
// (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
|
||||
ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
|
||||
(ExprHandle(8) * x + ExprHandle(6) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
|
||||
IS_VAR_WITH_NAME(add->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(add->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// Factorization with common divider but different signs.
|
||||
// (-2 * x) + (4 * y) => -2 * (x - 2 * y)
|
||||
ExprHandle body = (ExprHandle(-2) * x + ExprHandle(4) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
|
||||
IS_VAR_WITH_NAME(sub->lhs(), "x");
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
|
||||
IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
|
||||
IS_VAR_WITH_NAME(mul2->rhs(), "y");
|
||||
}
|
||||
}
|
||||
|
||||
// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (3 * z + y + 4 * x)
|
||||
void testSimplifyFactorizeUneven() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
VarHandle z("z", kInt);
|
||||
ExprHandle body =
|
||||
(ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
|
||||
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Add, root->rhs(), add1);
|
||||
IS_NODE_WITH_NAME(Add, add1->lhs(), add2);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul);
|
||||
IS_NODE_WITH_NAME(Mul, add2->lhs(), zmul);
|
||||
|
||||
IS_IMM_WITH_VAL(Int, zmul->lhs(), 3);
|
||||
IS_VAR_WITH_NAME(zmul->rhs(), "z");
|
||||
|
||||
IS_VAR_WITH_NAME(add2->rhs(), "y");
|
||||
|
||||
IS_IMM_WITH_VAL(Int, xmul->lhs(), 4);
|
||||
IS_VAR_WITH_NAME(xmul->rhs(), "x");
|
||||
}
|
||||
|
||||
// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
|
||||
// This is kind of a placeholder test for variable factorization.
|
||||
void testSimplifyDeeperTerms() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Add, simplified.node(), add);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
|
||||
IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
|
||||
IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm);
|
||||
IS_VAR_WITH_NAME(xxTerm->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
|
||||
IS_IMM_WITH_VAL(Int, rhs->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm);
|
||||
IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
|
||||
}
|
||||
|
||||
// Tests the difference between two less trivial expressions.
|
||||
// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1
|
||||
void testSimplifyDeeperDifference() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle n("n", kInt);
|
||||
VarHandle n_1("n_1", kInt);
|
||||
VarHandle m("m", kInt);
|
||||
ExprHandle body =
|
||||
(m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
|
||||
}
|
||||
|
||||
// Test constant folding into the difference between expressions.
|
||||
// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3
|
||||
void testSimplifyFoldComplexDifference() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle n("n", kInt);
|
||||
VarHandle n_1("n_1", kInt);
|
||||
VarHandle m("m", kInt);
|
||||
ExprHandle body =
|
||||
(IntImm::make(2) +
|
||||
(Cast::make(
|
||||
kChar,
|
||||
(m * (ExprHandle(1) * n_1) + (n + 1)) -
|
||||
(m * (ExprHandle(1) * n_1) + n))));
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 3);
|
||||
}
|
||||
|
||||
void testSimplifyIfComponents() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
ExprHandle body = IfThenElse::make(
|
||||
((ExprHandle(5) - ExprHandle(4)) * x) > y,
|
||||
ExprHandle(2) * x - x,
|
||||
ExprHandle(2) * y - y);
|
||||
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr);
|
||||
|
||||
IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp);
|
||||
EXPECT_EQ(cmp->compare_select_op(), kGT);
|
||||
IS_VAR_WITH_NAME(cmp->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(cmp->rhs(), "y");
|
||||
|
||||
IS_VAR_WITH_NAME(ifexpr->true_value(), "x");
|
||||
IS_VAR_WITH_NAME(ifexpr->false_value(), "y");
|
||||
}
|
||||
|
||||
void testSimplifyOpaqueTerms() {
|
||||
KernelScope kernel_scope;
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
{
|
||||
// 2 * (x/y) * y - (x/y) * y => y * (x/y)
|
||||
ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_VAR_WITH_NAME(mul->lhs(), "y");
|
||||
IS_NODE_WITH_NAME(Div, mul->rhs(), div);
|
||||
IS_VAR_WITH_NAME(div->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(div->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// x%y - (x%y - 1) => 1
|
||||
ExprHandle body = (x % y) - ((x % y) - 1);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
void testSimplifyWontReorderFloat() {
|
||||
KernelScope kernel_scope;
|
||||
|
||||
{
|
||||
// 3 * (3 * x) - 3 * (3 * y) => -9 * (y - x)
|
||||
// This is an expression we can simplify.
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
|
||||
ExprHandle(3) * (ExprHandle(3) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
|
||||
IS_IMM_WITH_VAL(Int, mul->lhs(), -9);
|
||||
IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
|
||||
IS_VAR_WITH_NAME(sub->lhs(), "y");
|
||||
IS_VAR_WITH_NAME(sub->rhs(), "x");
|
||||
}
|
||||
|
||||
{
|
||||
// 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y).
|
||||
// If the vars are floating point, ops are not associative and we can't
|
||||
// reorder.
|
||||
VarHandle x("x", kFloat);
|
||||
VarHandle y("y", kFloat);
|
||||
|
||||
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
|
||||
ExprHandle(3) * (ExprHandle(3) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
|
||||
IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
|
||||
IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3);
|
||||
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
|
||||
IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
|
||||
IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
|
||||
IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y).
|
||||
// We will simplify subexprs if they don't reorder floating point ops.
|
||||
VarHandle x("x", kDouble);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
|
||||
ExprHandle(3) * (ExprHandle(3) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
|
||||
IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
|
||||
IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3);
|
||||
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double);
|
||||
IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9);
|
||||
IS_VAR_WITH_NAME(rhsMul->rhs(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
// Prevent reordering if FP propagated from dtypes.
|
||||
VarHandle x("x", kInt);
|
||||
VarHandle y("y", kInt);
|
||||
|
||||
ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) -
|
||||
ExprHandle(3) * (ExprHandle(3.f) * y);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
|
||||
IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
|
||||
IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float);
|
||||
IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3);
|
||||
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
|
||||
|
||||
IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
|
||||
IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
|
||||
IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
|
||||
IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast);
|
||||
IS_VAR_WITH_NAME(yCast->src_value(), "y");
|
||||
}
|
||||
|
||||
{
|
||||
VarHandle x("x", kFloat);
|
||||
VarHandle y("y", kFloat);
|
||||
// x%y - (x%y - 1) => x%y - (x%y - 1).
|
||||
// We won't reorder opaque ops if they are FP.
|
||||
ExprHandle body = (x % y) - ((x % y) - 1);
|
||||
ExprHandle simplified = IRSimplifier::simplify(body);
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
|
||||
IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod);
|
||||
IS_VAR_WITH_NAME(lhsMod->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(lhsMod->rhs(), "y");
|
||||
|
||||
IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub);
|
||||
IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod);
|
||||
IS_VAR_WITH_NAME(rhsMod->lhs(), "x");
|
||||
IS_VAR_WITH_NAME(rhsMod->rhs(), "y");
|
||||
IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1);
|
||||
}
|
||||
const Var* root = simplified.AsNode<Var>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
EXPECT_EQ(root->name_hint(), "y");
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
@ -9,119 +9,105 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#define TH_FORALL_TESTS(_) \
|
||||
_(ExprBasicValueTest) \
|
||||
_(ExprBasicValueTest02) \
|
||||
_(ExprLetTest01) \
|
||||
_(ExprLetStmtTest01) \
|
||||
_(ExprLetTest02) \
|
||||
_(ExprIntTest) \
|
||||
_(ExprFloatTest) \
|
||||
_(ExprByteTest) \
|
||||
_(ExprCharTest) \
|
||||
_(ExprShortTest) \
|
||||
_(ExprLongTest) \
|
||||
_(ExprHalfTest) \
|
||||
_(ExprDoubleTest) \
|
||||
_(ExprVectorAdd01) \
|
||||
_(ExprCompareSelectEQ) \
|
||||
_(ExprSubstitute01) \
|
||||
_(ExprMath01) \
|
||||
_(ExprUnaryMath01) \
|
||||
_(ExprBinaryMath01) \
|
||||
_(ExprDynamicShapeAdd) \
|
||||
_(ExprBitwiseOps) \
|
||||
_(IRPrinterBasicValueTest) \
|
||||
_(IRPrinterBasicValueTest02) \
|
||||
_(IRPrinterLetTest01) \
|
||||
_(IRPrinterLetTest02) \
|
||||
_(IRPrinterCastTest) \
|
||||
_(ExprSimple01) \
|
||||
_(ExprLower01) \
|
||||
_(ExprSimple02) \
|
||||
_(ExprSplitWithTailNone) \
|
||||
_(ExprSplitWithMask01) \
|
||||
_(ScheduleBroadcastAddBuffer) \
|
||||
_(ScheduleFunctionCall01) \
|
||||
_(ScheduleInlineFunc01) \
|
||||
_(ScheduleFuserStyle) \
|
||||
_(ScheduleFuserThreeArg) \
|
||||
_(ScheduleDynamicShape2D) \
|
||||
_(TypeTest01) \
|
||||
_(TypePropagation) \
|
||||
_(Cond01) \
|
||||
_(IfThenElse01) \
|
||||
_(IfThenElse02) \
|
||||
_(ATen_cast_Float) \
|
||||
_(ATennegInt) \
|
||||
_(ATennegFloat) \
|
||||
_(ATenaddInt) \
|
||||
_(ATenaddFloat) \
|
||||
_(ATensubInt) \
|
||||
_(ATensubFloat) \
|
||||
_(ATenlerp) \
|
||||
_(ATenaddcmulInt) \
|
||||
_(ATenaddcmulFloat) \
|
||||
_(ATenmulInt) \
|
||||
_(ATenmulFloat) \
|
||||
_(ATendivInt) \
|
||||
_(ATendivFloat) \
|
||||
_(ATenmaxInt) \
|
||||
_(ATenmaxFloat) \
|
||||
_(ATenminInt) \
|
||||
_(ATenminFloat) \
|
||||
_(ATen_sigmoid_backward) \
|
||||
_(ATen_tanh_backward) \
|
||||
_(ATenreciprocal) \
|
||||
_(ATenreluInt) \
|
||||
_(ATenreluFloat) \
|
||||
_(ATenlogFloat) \
|
||||
_(ATenlog10Float) \
|
||||
_(ATenlog2Float) \
|
||||
_(ATenexpFloat) \
|
||||
_(ATenerfFloat) \
|
||||
_(ATencosFloat) \
|
||||
_(ATeneqInt) \
|
||||
_(ATengeInt) \
|
||||
_(ATengtInt) \
|
||||
_(ATenleInt) \
|
||||
_(ATenltInt) \
|
||||
_(ConstantFoldSimple) \
|
||||
_(ConstantFoldTwoLayer) \
|
||||
_(ConstantFoldShifts) \
|
||||
_(ConstantFoldBitwise) \
|
||||
_(ConstantFoldMultiOp) \
|
||||
_(ConstantFoldMinMax) \
|
||||
_(ConstantFoldIntrinsics) \
|
||||
_(ConstantFoldWithVar) \
|
||||
_(UnFoldableExpr) \
|
||||
_(HashSimple) \
|
||||
_(HashEquivalence) \
|
||||
_(HashEquivalenceAfterFolding) \
|
||||
_(HashDifferenceTypes) \
|
||||
_(HashLargeExpression) \
|
||||
_(SimplifyAdd) \
|
||||
_(SimplifySub) \
|
||||
_(SimplifyMultiLayer) \
|
||||
_(SimplifyMultiTerm) \
|
||||
_(SimplifyCasts) \
|
||||
_(SimplifyEliminatesNoOps) \
|
||||
_(SimplifyMultiVar) \
|
||||
_(SimplifyReorderings) \
|
||||
_(SimplifyEliminatesVar) \
|
||||
_(SimplifyAdds) \
|
||||
_(SimplifyMuls) \
|
||||
_(SimplifySubs) \
|
||||
_(SimplifyMultiOp) \
|
||||
_(SimplifyManyOps) \
|
||||
_(SimplifyFactorization) \
|
||||
_(SimplifyFactorizeUneven) \
|
||||
_(SimplifyDeeperTerms) \
|
||||
_(SimplifyDeeperDifference) \
|
||||
_(SimplifyFoldComplexDifference) \
|
||||
_(SimplifyIfComponents) \
|
||||
_(SimplifyOpaqueTerms) \
|
||||
_(SimplifyWontReorderFloat) \
|
||||
#define TH_FORALL_TESTS(_) \
|
||||
_(ExprBasicValueTest) \
|
||||
_(ExprBasicValueTest02) \
|
||||
_(ExprLetTest01) \
|
||||
_(ExprLetStmtTest01) \
|
||||
_(ExprLetTest02) \
|
||||
_(ExprIntTest) \
|
||||
_(ExprFloatTest) \
|
||||
_(ExprByteTest) \
|
||||
_(ExprCharTest) \
|
||||
_(ExprShortTest) \
|
||||
_(ExprLongTest) \
|
||||
_(ExprHalfTest) \
|
||||
_(ExprDoubleTest) \
|
||||
_(ExprVectorAdd01) \
|
||||
_(ExprCompareSelectEQ) \
|
||||
_(ExprSubstitute01) \
|
||||
_(ExprMath01) \
|
||||
_(ExprUnaryMath01) \
|
||||
_(ExprBinaryMath01) \
|
||||
_(ExprDynamicShapeAdd) \
|
||||
_(ExprBitwiseOps) \
|
||||
_(IRPrinterBasicValueTest) \
|
||||
_(IRPrinterBasicValueTest02) \
|
||||
_(IRPrinterLetTest01) \
|
||||
_(IRPrinterLetTest02) \
|
||||
_(IRPrinterCastTest) \
|
||||
_(ExprSimple01) \
|
||||
_(ExprLower01) \
|
||||
_(ExprSimple02) \
|
||||
_(ExprSplitWithTailNone) \
|
||||
_(ExprSplitWithMask01) \
|
||||
_(ScheduleBroadcastAddBuffer) \
|
||||
_(ScheduleFunctionCall01) \
|
||||
_(ScheduleInlineFunc01) \
|
||||
_(ScheduleFuserStyle) \
|
||||
_(ScheduleFuserThreeArg) \
|
||||
_(ScheduleDynamicShape2D) \
|
||||
_(TypeTest01) \
|
||||
_(TypePropagation) \
|
||||
_(Cond01) \
|
||||
_(IfThenElse01) \
|
||||
_(IfThenElse02) \
|
||||
_(ATen_cast_Float) \
|
||||
_(ATennegInt) \
|
||||
_(ATennegFloat) \
|
||||
_(ATenaddInt) \
|
||||
_(ATenaddFloat) \
|
||||
_(ATensubInt) \
|
||||
_(ATensubFloat) \
|
||||
_(ATenlerp) \
|
||||
_(ATenaddcmulInt) \
|
||||
_(ATenaddcmulFloat) \
|
||||
_(ATenmulInt) \
|
||||
_(ATenmulFloat) \
|
||||
_(ATendivInt) \
|
||||
_(ATendivFloat) \
|
||||
_(ATenmaxInt) \
|
||||
_(ATenmaxFloat) \
|
||||
_(ATenminInt) \
|
||||
_(ATenminFloat) \
|
||||
_(ATen_sigmoid_backward) \
|
||||
_(ATen_tanh_backward) \
|
||||
_(ATenreciprocal) \
|
||||
_(ATenreluInt) \
|
||||
_(ATenreluFloat) \
|
||||
_(ATenlogFloat) \
|
||||
_(ATenlog10Float) \
|
||||
_(ATenlog2Float) \
|
||||
_(ATenexpFloat) \
|
||||
_(ATenerfFloat) \
|
||||
_(ATencosFloat) \
|
||||
_(ATeneqInt) \
|
||||
_(ATengeInt) \
|
||||
_(ATengtInt) \
|
||||
_(ATenleInt) \
|
||||
_(ATenltInt) \
|
||||
_(ConstantFoldSimple) \
|
||||
_(ConstantFoldTwoLayer) \
|
||||
_(ConstantFoldShifts) \
|
||||
_(ConstantFoldBitwise) \
|
||||
_(ConstantFoldMultiOp) \
|
||||
_(ConstantFoldMinMax) \
|
||||
_(ConstantFoldIntrinsics) \
|
||||
_(ConstantFoldWithVar) \
|
||||
_(UnFoldableExpr) \
|
||||
_(HashSimple) \
|
||||
_(HashEquivalence) \
|
||||
_(HashEquivalenceAfterFolding) \
|
||||
_(HashDifferenceTypes) \
|
||||
_(HashLargeExpression) \
|
||||
_(SimplifyAdd) \
|
||||
_(SimplifySub) \
|
||||
_(SimplifyMultiLayer) \
|
||||
_(SimplifyMultiTerm) \
|
||||
_(SimplifyCasts) \
|
||||
_(SimplifyEliminatesNoOps) \
|
||||
_(SimplifyMultiVar) \
|
||||
_(SimplifyEliminatesVar) \
|
||||
_(StmtClone)
|
||||
|
||||
#define TH_FORALL_TESTS_LLVM(_) \
|
||||
|
@@ -205,7 +205,6 @@ libtorch_sources = [
    "torch/csrc/jit/tensorexpr/ir.cpp",
    "torch/csrc/jit/tensorexpr/ir_mutator.cpp",
    "torch/csrc/jit/tensorexpr/ir_printer.cpp",
    "torch/csrc/jit/tensorexpr/ir_simplifier.cpp",
    "torch/csrc/jit/tensorexpr/ir_visitor.cpp",
    "torch/csrc/jit/tensorexpr/kernel.cpp",
    "torch/csrc/jit/tensorexpr/llvm_codegen.cpp",
@ -59,24 +59,24 @@ class Value {
|
||||
void* ptr;
|
||||
};
|
||||
|
||||
#define VALUE_AS_DISPATCH(Type, Name) \
|
||||
template <> \
|
||||
inline Type Value::as<Type>() const { \
|
||||
if (dtype_ != k##Name) { \
|
||||
throw unsupported_dtype(); \
|
||||
} \
|
||||
return Name##values[0]; \
|
||||
#define VALUE_AS_DISPATCH(Type, Name) \
|
||||
template <> \
|
||||
inline Type Value::as<Type>() const { \
|
||||
if (dtype_ != k##Name) { \
|
||||
throw unsupported_dtype(); \
|
||||
} \
|
||||
return Name##values[0]; \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH);
|
||||
#undef VALUE_AS_DISPATCH
|
||||
|
||||
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
|
||||
template <> \
|
||||
inline const std::vector<Type>& Value::as_vec<Type>() const { \
|
||||
if (dtype_.scalar_type() != ScalarType::Name) { \
|
||||
throw unsupported_dtype(); \
|
||||
} \
|
||||
return Name##values; \
|
||||
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
|
||||
template <> \
|
||||
inline const std::vector<Type>& Value::as_vec<Type>() const { \
|
||||
if (dtype_.scalar_type() != ScalarType::Name) { \
|
||||
throw unsupported_dtype(); \
|
||||
} \
|
||||
return Name##values; \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH);
|
||||
#undef VALUE_AS_VEC_DISPATCH
|
||||
@ -479,6 +479,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
|
||||
throw malformed_input(v);
|
||||
}
|
||||
|
||||
|
||||
if (src_dtype != dst_dtype) {
|
||||
switch (src_dtype.scalar_type()) {
|
||||
#define SRC_TYPE_CASE(Type, Name) \
|
||||
@ -911,28 +912,6 @@ inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) {
|
||||
return stmt->accept_mutator(&var_sub);
|
||||
}
|
||||
|
||||
// Uses the evaluator to fold an Expression with constant terms.
|
||||
// E.g. evaluateOp(Add(3, 4)) => 7.
|
||||
// Expr v must not have any unbound Vars.
|
||||
static Expr* evaluateOp(const Expr* v) {
|
||||
ExprHandle handle(v);
|
||||
ExprEval<SimpleIREvaluator> eval(handle);
|
||||
|
||||
switch (v->dtype().scalar_type()) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
case ScalarType::Name: { \
|
||||
Type val = eval.value<Type>(); \
|
||||
return getImmediateByType(v->dtype().scalar_type(), val); \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -31,8 +31,6 @@ enum IRNodeType {
|
||||
kCompareSelect,
|
||||
kLet,
|
||||
kCast,
|
||||
kBroadcast,
|
||||
kRamp,
|
||||
kNone
|
||||
};
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/tensorexpr/ir.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
|
||||
@ -318,13 +316,6 @@ class HashProvider : public IRVisitor {
|
||||
putHash(v, hash);
|
||||
}
|
||||
|
||||
template <typename... Types>
|
||||
SimplifierHashType hash_combine(const Types&... args) {
|
||||
SimplifierHashType seed = 0;
|
||||
_hash_combine(seed, args...);
|
||||
return seed;
|
||||
}
|
||||
|
||||
private:
|
||||
SimplifierHashType hashOf(const Expr* e) {
|
||||
auto it = exprToHash_.find(e);
|
||||
@ -380,10 +371,6 @@ class HashProvider : public IRVisitor {
|
||||
(seed << 7) + (seed >> 4);
|
||||
}
|
||||
|
||||
void _hash_combine(SimplifierHashType& seed, const Expr* e) {
|
||||
_hash_combine(seed, hash(e));
|
||||
}
|
||||
|
||||
template <typename T, typename... Types>
|
||||
void _hash_combine(
|
||||
SimplifierHashType& seed,
|
||||
@ -393,6 +380,13 @@ class HashProvider : public IRVisitor {
|
||||
_hash_combine(seed, args...);
|
||||
}
|
||||
|
||||
template <typename... Types>
|
||||
SimplifierHashType hash_combine(const Types&... args) {
|
||||
SimplifierHashType seed = 0;
|
||||
_hash_combine(seed, args...);
|
||||
return seed;
|
||||
}
|
||||
|
||||
void putHash(const KernelScopedObject* e, SimplifierHashType h) {
|
||||
auto res = exprToHash_.emplace(e, h);
|
||||
if (res.second == false) {
|
||||
|
@ -283,94 +283,24 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
|
||||
|
||||
// Get immediate by ScalarType.
|
||||
template <typename T>
|
||||
Expr* getImmediateByType(ScalarType immType, T initialVal) {
|
||||
ExprHandle getImmediateByType(ScalarType immType, T initialVal) {
|
||||
switch (immType) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
case ScalarType::Name: \
|
||||
return new Name##Imm(initialVal);
|
||||
return Name##Imm::make(initialVal);
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
default:
|
||||
throw unsupported_dtype();
|
||||
}
|
||||
return nullptr;
|
||||
return ExprHandle();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Expr* getImmediateByType(Dtype dtype, T initialVal) {
|
||||
ExprHandle getImmediateByType(Dtype dtype, T initialVal) {
|
||||
return getImmediateByType<T>(dtype.scalar_type(), initialVal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T immediateAs(const Expr* e) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
|
||||
return imm->value(); \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
throw unsupported_dtype();
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool immediateEquals(const Expr* e, T val) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
|
||||
return imm->value() == val; \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
throw unsupported_dtype();
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool immediateIsNegative(const T* e) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
|
||||
return imm->value() < 0; \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
return false;
|
||||
}
|
||||
|
||||
// Creates a new Expr of the given type with the provided lhs and rhs.
|
||||
static const Expr* newBinaryOpOfType(
|
||||
IRNodeType expr_type,
|
||||
const Expr* lhs,
|
||||
const Expr* rhs,
|
||||
bool option) {
|
||||
switch (expr_type) {
|
||||
case IRNodeType::kAdd:
|
||||
return new Add(lhs, rhs);
|
||||
case IRNodeType::kSub:
|
||||
return new Sub(lhs, rhs);
|
||||
case IRNodeType::kMul:
|
||||
return new Mul(lhs, rhs);
|
||||
case IRNodeType::kDiv:
|
||||
return new Div(lhs, rhs);
|
||||
case IRNodeType::kMod:
|
||||
return new Mod(lhs, rhs);
|
||||
case IRNodeType::kMax:
|
||||
return new Max(lhs, rhs, option);
|
||||
case IRNodeType::kMin:
|
||||
return new Min(lhs, rhs, option);
|
||||
case IRNodeType::kAnd:
|
||||
return new And(lhs, rhs);
|
||||
case IRNodeType::kXor:
|
||||
return new Xor(lhs, rhs);
|
||||
case IRNodeType::kLshift:
|
||||
return new Lshift(lhs, rhs);
|
||||
case IRNodeType::kRshift:
|
||||
return new Rshift(lhs, rhs);
|
||||
default:
|
||||
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the value to the var and evaluate the body.
|
||||
class Let : public ExprNode<Let> {
|
||||
public:
|
||||
@ -424,7 +354,7 @@ class Ramp : public ExprNode<Ramp> {
|
||||
}
|
||||
|
||||
Ramp(const Expr* base, const Expr* stride, int lanes)
|
||||
: ExprNodeBase(Dtype(base->dtype(), lanes), kRamp),
|
||||
: ExprNodeBase(Dtype(base->dtype(), lanes)),
|
||||
base_(base),
|
||||
stride_(stride),
|
||||
lanes_(lanes) {
|
||||
@ -490,7 +420,7 @@ class Broadcast : public ExprNode<Broadcast> {
|
||||
return ExprHandle(new Broadcast(value.node(), lanes));
|
||||
}
|
||||
Broadcast(const Expr* value, int lanes)
|
||||
: ExprNodeBase(Dtype(value->dtype(), lanes), kBroadcast),
|
||||
: ExprNodeBase(Dtype(value->dtype(), lanes)),
|
||||
value_(value),
|
||||
lanes_(lanes) {}
|
||||
|
||||
@ -632,7 +562,8 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
|
||||
const ExprHandle& ret_val1,
|
||||
const ExprHandle& ret_val2,
|
||||
CompareSelectOperation cmp_op) {
|
||||
if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
|
||||
if (lhs.dtype() != rhs.dtype() ||
|
||||
ret_val1.dtype() != ret_val2.dtype()) {
|
||||
throw malformed_input();
|
||||
}
|
||||
return ExprHandle(new CompareSelect(
|
||||
@ -859,8 +790,48 @@ class Intrinsics : public CallNode<Intrinsics> {
|
||||
IntrinsicsOp op_type_;
|
||||
};
|
||||
|
||||
class Polynomial;
class Term;
/* An internal only Expr used in IR simplification.
 * Encodes relationship y = Ax + B, where A and B are Immediates.
 * Not required to be implemented by codegen. */
class LinearForm : public ExprNode<LinearForm> {
 public:
  LinearForm(const Expr* x, const Expr* A, const Expr* B)
      : ExprNodeBase(dtypeFor(x, A, B)), x_(x), A_(A), B_(B) {}

  LinearForm(const Expr* x)
      : ExprNodeBase(x->dtype()),
        x_(x),
        A_(new CharImm(1)),
        B_(new CharImm(0)) {}

  const Expr* getX() const {
    return x_;
  }
  const Expr* getA() const {
    return A_;
  }
  const Expr* getB() const {
    return B_;
  }

  void setA(const Expr* A) {
    A_ = A;
  }

  void setB(const Expr* B) {
    B_ = B;
  }

  static Dtype dtypeFor(const Expr* A, const Expr* B, const Expr* C) {
    return ToDtype(promoteTypes(
        A->dtype().scalar_type(), promoteTypes(B->dtype(), C->dtype())));
  }

 private:
  const Expr* x_;
  const Expr* A_;
  const Expr* B_;
};
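A minimal usage sketch of the class above (my own illustration, not code from this diff; it assumes an existing const Expr* x of integral dtype and uses only the constructors and Imm types shown in this commit):

  // Hypothetical illustration only: encode y = 2*x + 3 as a LinearForm.
  const Expr* lf = new LinearForm(x, new IntImm(2), new IntImm(3));
  // lf->getA() is the IntImm 2 and lf->getB() is the IntImm 3.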
class FunctionCall;
|
||||
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include <torch/csrc/jit/tensorexpr/eval.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -215,7 +214,6 @@ const Expr* IRMutator::mutate(const IfThenElse* v) {
|
||||
const Expr* condition_new = condition->accept_mutator(this);
|
||||
const Expr* true_value_new = true_value->accept_mutator(this);
|
||||
const Expr* false_value_new = false_value->accept_mutator(this);
|
||||
|
||||
if (condition == condition_new && true_value == true_value_new &&
|
||||
false_value == false_value_new) {
|
||||
return v;
|
||||
@ -234,24 +232,11 @@ const Expr* IRMutator::mutate(const FunctionCall* v) {
|
||||
return this->mutate(base);
|
||||
}
|
||||
|
||||
const Expr* IRMutator::mutate(const Term* v) {
|
||||
const Expr* newScalar = v->scalar()->accept_mutator(this);
|
||||
|
||||
std::vector<const Expr*> variables;
|
||||
for (const auto* t : v->variables()) {
|
||||
variables.push_back(t->accept_mutator(this));
|
||||
}
|
||||
return new Term(v->hasher(), newScalar, variables);
|
||||
}
|
||||
|
||||
const Expr* IRMutator::mutate(const Polynomial* v) {
|
||||
const Expr* newScalar = v->scalar()->accept_mutator(this);
|
||||
|
||||
std::vector<const Term*> variables;
|
||||
for (const auto* t : v->variables()) {
|
||||
variables.push_back(static_cast<const Term*>(t->accept_mutator(this)));
|
||||
}
|
||||
return new Polynomial(v->hasher(), newScalar, variables);
|
||||
const Expr* IRMutator::mutate(const LinearForm* v) {
|
||||
const Expr* new_x = v->getX()->accept_mutator(this);
|
||||
const Expr* new_a = v->getA()->accept_mutator(this);
|
||||
const Expr* new_b = v->getB()->accept_mutator(this);
|
||||
return new LinearForm(new_x, new_a, new_b);
|
||||
}
|
||||
|
||||
const Expr* IRMutator::mutate(const BaseCallNode* v) {
|
||||
|
@ -45,8 +45,7 @@ class Allocate;
|
||||
class Free;
|
||||
class Cond;
|
||||
class Stmt;
|
||||
class Term;
|
||||
class Polynomial;
|
||||
class LinearForm;
|
||||
|
||||
class TORCH_API IRMutator {
|
||||
public:
|
||||
@ -87,8 +86,7 @@ class TORCH_API IRMutator {
|
||||
virtual const Expr* mutate(const Intrinsics* v);
|
||||
virtual const Expr* mutate(const FunctionCall* v);
|
||||
|
||||
virtual const Expr* mutate(const Term* v);
|
||||
virtual const Expr* mutate(const Polynomial* v);
|
||||
virtual const Expr* mutate(const LinearForm* v);
|
||||
|
||||
virtual Stmt* mutate(const For* v);
|
||||
virtual Stmt* mutate(const Block* v);
|
||||
|
@ -1,7 +1,5 @@
|
||||
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
|
||||
|
||||
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
@ -200,7 +198,7 @@ template <
|
||||
typename T,
|
||||
std::enable_if_t<!std::is_floating_point<T>::value>* = nullptr>
|
||||
static void formatImm(std::ostream& os, T v) {
|
||||
os << +v;
|
||||
os << v;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
@ -370,33 +368,9 @@ void IRPrinter::visit(const Cond* v) {
|
||||
}
|
||||
}
|
||||
|
||||
void IRPrinter::visit(const Term* v) {
|
||||
os() << "Term(";
|
||||
v->scalar()->accept(this);
|
||||
for (auto* t : v->variables()) {
|
||||
os() << ",";
|
||||
t->accept(this);
|
||||
}
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void IRPrinter::visit(const Polynomial* v) {
|
||||
bool first = true;
|
||||
os() << "Polynomial(";
|
||||
for (auto* t : v->variables()) {
|
||||
emitIndent();
|
||||
if (!first) {
|
||||
os() << " + ";
|
||||
}
|
||||
first = false;
|
||||
t->accept(this);
|
||||
}
|
||||
|
||||
if (!first) {
|
||||
os() << " + ";
|
||||
}
|
||||
v->scalar()->accept(this);
|
||||
os() << ")";
|
||||
void IRPrinter::visit(const LinearForm* v) {
|
||||
os() << "(" << *v->getA() << ") * (" << *v->getX() << ") + (" << *v->getB()
|
||||
<< ")" << std::endl;
|
||||
}
|
||||
|
||||
void IRPrinter::emitIndent() {
|
||||
|
@ -48,8 +48,7 @@ class TORCH_API IRPrinter : public IRVisitor {
|
||||
void visit(const Allocate* v) override;
|
||||
void visit(const Free* v) override;
|
||||
void visit(const Cond* v) override;
|
||||
void visit(const Term* v) override;
|
||||
void visit(const Polynomial* v) override;
|
||||
void visit(const LinearForm* v) override;
|
||||
|
||||
std::ostream& os() {
|
||||
return printer_os_;
|
||||
|
[File diff suppressed because it is too large]
@ -1,275 +1,367 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/tensorexpr/eval.h>
|
||||
#include <torch/csrc/jit/tensorexpr/hash_server.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
|
||||
#include <torch/csrc/jit/tensorexpr/types.h>
|
||||
|
||||
/* IR Simplification
 *
 * Simplifies expressions in two stages:
 * 1. Recursively traverse the map combining similar operations into Terms
 * (interacted via Multiplication) and Polynomials (interacted via Addition). We
 * reorder the components of each Term or Polynomial into a consistent order to
 * allow combination or cancelling of like terms.
 * 2. Once the format of the tree is minimal, expand each Term into a sequence
 * of Muls, and each Polynomial into a sequence of Adds.
 */
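As an illustrative sketch of the two stages described above (my own example, not code from this diff):

  // Stage 1: gather   x + 2 * x + 3  =>  Polynomial(scalar: 3, Term(3, x))
  // Stage 2: expand   Polynomial(scalar: 3, Term(3, x))  =>  Add(Mul(3, x), 3)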
#include "torch/csrc/jit/tensorexpr/eval.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
|
||||
#include "torch/csrc/jit/tensorexpr/types.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
// A bunch of helpers for determine the Dtype of the output of a multi argument
|
||||
// Term or Polynomial.
|
||||
namespace {
|
||||
template <class ExprType>
|
||||
Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
|
||||
Dtype t = s->dtype();
|
||||
bool first = true;
|
||||
// Uses the evaluator to fold an operation with constant terms.
|
||||
// Expr v must be evaluatable without Vars.
|
||||
static Expr* evaluateOp(const Expr* v) {
|
||||
ExprHandle handle(v);
|
||||
ExprEval<SimpleIREvaluator> eval(handle);
|
||||
|
||||
for (auto* e : v) {
|
||||
if (first) {
|
||||
t = Dtype(t.scalar_type(), e->dtype().lanes());
|
||||
first = false;
|
||||
switch (v->dtype().scalar_type()) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
case ScalarType::Name: { \
|
||||
Type val = eval.value<Type>(); \
|
||||
return getImmediateByType(v->dtype().scalar_type(), val).node(); \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
} // namespace tensorexpr
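
evaluateOp above leans on the expression evaluator for folding: once every operand of an operation is an immediate, the whole subtree can be replaced by a single immediate of the result type. A minimal standalone sketch of that idea over a toy expression node (the types here are illustrative, not the torch ones):

#include <iostream>
#include <memory>

// Toy expression tree: either a constant leaf or an Add of two subtrees.
struct Expr {
  bool is_const;
  double value;                    // valid when is_const
  std::unique_ptr<Expr> lhs, rhs;  // valid when !is_const (an Add node)
};

static std::unique_ptr<Expr> constant(double v) {
  return std::unique_ptr<Expr>(new Expr{true, v, nullptr, nullptr});
}

// Fold an Add whose operands are both constants into a single constant,
// mirroring the "evaluate and replace with an immediate" step.
static std::unique_ptr<Expr> foldAdd(std::unique_ptr<Expr> l,
                                     std::unique_ptr<Expr> r) {
  if (l->is_const && r->is_const) {
    return constant(l->value + r->value);
  }
  return std::unique_ptr<Expr>(
      new Expr{false, 0.0, std::move(l), std::move(r)});
}

int main() {
  auto e = foldAdd(constant(2.0), constant(4.0));
  std::cout << (e->is_const ? e->value : 0.0) << "\n";  // prints 6
}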

static const Expr* newBinaryOpOfType(
    IRNodeType expr_type,
    const Expr* lhs,
    const Expr* rhs,
    bool option) {
  switch (expr_type) {
    case IRNodeType::kAdd:
      return new Add(lhs, rhs);
    case IRNodeType::kSub:
      return new Sub(lhs, rhs);
    case IRNodeType::kMul:
      return new Mul(lhs, rhs);
    case IRNodeType::kDiv:
      return new Div(lhs, rhs);
    case IRNodeType::kMod:
      return new Mod(lhs, rhs);
    case IRNodeType::kMax:
      return new Max(lhs, rhs, option);
    case IRNodeType::kMin:
      return new Min(lhs, rhs, option);
    case IRNodeType::kAnd:
      return new And(lhs, rhs);
    case IRNodeType::kXor:
      return new Xor(lhs, rhs);
    case IRNodeType::kLshift:
      return new Lshift(lhs, rhs);
    case IRNodeType::kRshift:
      return new Rshift(lhs, rhs);
    default:
      LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
      return nullptr;
  }
}

/* Interprets expr as an Immediate and returns the value as type T. */
template <typename T>
T immediateAs(const Expr* expr) {
  T val{0};
  switch (expr->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name)                                          \
  case ScalarType::Name:                                               \
    if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(expr)) { \
      val = imm->value();                                              \
    } else {                                                           \
      LOG(FATAL) << "Bad expr: " << *expr << "\n";                     \
    }                                                                  \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
    default:
      LOG(FATAL) << "Unsupported datatype: " << expr->dtype();
  }

  return val;
}

/* Takes a LinearForm and converts it to Mul + (Add/Sub). */
const Expr* expandLinearForm(const LinearForm* v, IRMutator* mutator) {
  const Expr* mul = nullptr;
  const Expr* A = v->getA();
  const Expr* B = v->getB();
  const Expr* X = v->getX();
  // we only really care about 0 and 1, so double should be fine.
  double Aval = immediateAs<double>(A);
  double Bval = immediateAs<double>(B);

  // First handle A.
  if (Aval == 0) {
    if (Bval == 0) {
      return getImmediateByType(X->dtype(), 0).node();
    }
    t = promoteTypes(t, e->dtype());
  }
  return t;
}
    return B;
  } else if (Aval == 1) {
    mul = X;
  } else if (Aval == -1) {
    return new Sub(B, X);
  } else if (Aval < 0) {
    // Negate A.
    ExprHandle zero = getImmediateByType(A->dtype(), 0);
    Sub* A_Sub = new Sub(zero.node(), A);

template <class ExprType>
Dtype promoteTypesVec(std::vector<const ExprType*>& v) {
  if (v.empty()) {
    throw malformed_input("empty list of types");
    return new Sub(B, new Mul(X, evaluateOp(A_Sub)));
  } else {
    mul = new Mul(X, A);
  }

  Dtype t = v[0]->dtype();
  for (auto* e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}

template <class ExprType>
Dtype promoteTypesMap(
    const Expr* s,
    std::unordered_map<SimplifierHashType, const ExprType*>& m) {
  Dtype t = s->dtype();
  bool first = true;
  for (auto& e : m) {
    if (first) {
      t = Dtype(t.scalar_type(), e.second->dtype().lanes());
      first = false;
    }
    t = promoteTypes(t, e.second->dtype());
  }
  return t;
}

template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
  return e->dtype();
}

template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
  Dtype lhs = e->dtype();
  Dtype rhs = promoteTypesVar(es...);
  if (e->isConstant()) {
    lhs = Dtype(lhs.scalar_type(), rhs.lanes());
  if (Bval == 0) {
    return mul;
  }

  return promoteTypes(lhs, rhs);
  return new Add(mul, B);
}
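
The case analysis in expandLinearForm is easiest to read numerically: a LinearForm models A*X + B with constant A and B, and expansion avoids emitting a Mul when A is 0, 1 or -1, and folds the negation when A is negative. A standalone sketch of the same branches over plain values (string output stands in for building Mul/Add/Sub nodes; operand order is simplified):

#include <iostream>
#include <string>

// Expand a*x + b into the minimal textual expression, mirroring the case
// analysis above (a == 0, 1, -1, negative, general).
static std::string expandLinear(double a, double b, const std::string& x) {
  if (a == 0) {
    return b == 0 ? "0" : std::to_string(b);
  }
  if (a == 1) {
    return b == 0 ? x : x + " + " + std::to_string(b);
  }
  if (a == -1) {
    return std::to_string(b) + " - " + x;
  }
  if (a < 0) {
    // Negate A so the multiply uses a positive constant: b - (-a)*x.
    return std::to_string(b) + " - " + std::to_string(-a) + " * " + x;
  }
  std::string mul = x + " * " + std::to_string(a);
  return b == 0 ? mul : mul + " + " + std::to_string(b);
}

int main() {
  std::cout << expandLinear(0, 0, "x") << "\n";   // 0
  std::cout << expandLinear(1, 3, "x") << "\n";   // x + 3.000000
  std::cout << expandLinear(-2, 5, "x") << "\n";  // 5.000000 - 2.000000 * x
}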

// Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast
// or Ramp).
bool isMultilanePrimitive(const Expr* e) {
  return e->expr_type() == IRNodeType::kBroadcast ||
      e->expr_type() == IRNodeType::kRamp;
}
} // namespace

// A Term represents a grouping of Exprs through multiplication.
// E.g. product(scalar, *variables).
class Term : public ExprNode<Term> {
/* Expand any remaining LinearTerms into their component pieces */
class LinearFormExpander : public IRMutator {
 public:
  template <class... Args>
  Term(HashProvider& hasher, const Expr* s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addComponent(ts...);
    sort();
  const Expr* mutate(const LinearForm* v) {
    return expandLinearForm(v, this);
  }

  Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor from a map of hash -> var, used when merging Terms.
  Term(
      HashProvider& hasher,
      const Expr* s,
      std::unordered_map<SimplifierHashType, const Expr*> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    for (auto& p : varmap) {
      addComponent(p.second);
    }
    sort();
  }

  const Expr* scalar() const {
    return scalar_;
  }
  const std::vector<const Expr*>& variables() const {
    return variables_;
  }
  HashProvider& hasher() const {
    return hasher_;
  }

  // Produce a hash of just the variable components of this term, to determine
  // if it can be combined with another term.
  SimplifierHashType hashVars() const;

 private:
  std::vector<const Expr*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;

  void addComponent() {}
  void addComponent(const Expr* e) {
    variables_.push_back(e);
  }
  template <class... Es>
  void addComponent(const Expr* e, Es... es) {
    addComponent(e);
    addComponent(es...);
  }

  // Sort by hash to normalize order of components.
  void sort();
};
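
A Term is essentially a constant scalar times a sorted list of variables, and sorting by hash is what makes commutative reordering fall out: products over the same variables get the same key, so 2*x*y and 3*y*x can be recognised as like terms. A standalone sketch of that keying idea, with sorted names standing in for SimplifierHashType hashes:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// A toy Term: constant scalar times a product of variables.
struct ToyTerm {
  int scalar;
  std::vector<std::string> vars;
};

// Canonical key over the variable component only (the analogue of hashVars):
// sort the names and join them, so the ordering of the product doesn't matter.
static std::string varKey(ToyTerm t) {
  std::sort(t.vars.begin(), t.vars.end());
  std::string key;
  for (const auto& v : t.vars) {
    if (!key.empty()) key += "*";
    key += v;
  }
  return key;
}

int main() {
  ToyTerm a{2, {"x", "y"}};
  ToyTerm b{3, {"y", "x"}};
  // Same variable key -> the scalars can simply be added: 2*x*y + 3*y*x = 5*x*y.
  std::cout << std::boolalpha << (varKey(a) == varKey(b)) << "\n";   // true
  std::cout << (a.scalar + b.scalar) << " * " << varKey(a) << "\n";  // 5 * x*y
}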

// Polynomial represents a grouping of Exprs by addition.
// E.g. sum(*variables, scalar).
// This would better be called Expression, but, naming conflict...
class Polynomial : public ExprNode<Polynomial> {
/* Simplify the IR by combining arithmetic expressions over a common term.
 */
class IRSimplifier : public IRMutator {
 public:
  template <class... Args>
  Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addTerm(ts...);
    sort();
  }
  const Expr* mutate(const Add* v) override {
    const Expr* lhs = v->lhs();
    const Expr* rhs = v->rhs();
    const Expr* lhs_new = lhs->accept_mutator(this);
    const Expr* rhs_new = rhs->accept_mutator(this);

  Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Helper constructor for list of terms with no scalar component.
  Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
      : ExprNodeBase(promoteTypesVec(terms)),
        variables_(std::move(terms)),
        scalar_(getImmediateByType(dtype(), 0)),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor for map of hash -> var, used when merging
  // Polynomials.
  Polynomial(
      HashProvider& hasher,
      const Expr* s,
      std::unordered_map<SimplifierHashType, const Term*> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    for (auto& p : varmap) {
      addTerm(p.second);
    // Constant Folding.
    if (lhs_new->isConstant() && rhs_new->isConstant()) {
      const Expr* result = evaluateOp(v);
      return result;
    }
    sort();

    const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
    const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);

    if (lhsLinear && rhsLinear) {
      // Can add two LinearTerms if they reference the same Var.
      if (lhsLinear->getX() == rhsLinear->getX()) {
        Add* A_Add = new Add(lhsLinear->getA(), rhsLinear->getA());
        Add* B_Add = new Add(lhsLinear->getB(), rhsLinear->getB());

        LinearForm* linear = new LinearForm(
            lhsLinear->getX(), evaluateOp(A_Add), evaluateOp(B_Add));
        return linear;
      }

      // otherwise cannot simplify further.
      return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
    }

    // Can add a scalar into the B term of LinearTerm.
    if (lhsLinear && rhs_new->isConstant()) {
      Add* B_Add = new Add(lhsLinear->getB(), rhs_new);
      LinearForm* linear = new LinearForm(
          lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Add));
      return linear;
    }

    if (rhsLinear && lhs_new->isConstant()) {
      Add* B_Add = new Add(rhsLinear->getB(), lhs_new);
      LinearForm* linear = new LinearForm(
          rhsLinear->getX(), rhsLinear->getA(), evaluateOp(B_Add));
      return linear;
    }

    // Can create a LinearTerm over any sub expression.
    if (lhs_new->isConstant()) {
      LinearForm* linear = new LinearForm(rhs_new);
      linear->setB(evaluateOp(lhs_new));
      return linear;
    }

    if (rhs_new->isConstant()) {
      LinearForm* linear = new LinearForm(lhs_new);
      linear->setB(evaluateOp(rhs_new));
      return linear;
    }

    /// Broadcasts are a bit more involved.
    if (const Broadcast* bc = dynamic_cast<const Broadcast*>(lhs_new)) {
      if (const Expr* ret = handleBroadcastAdd(bc, rhs_new)) {
        return ret;
      }
    }

    if (const Broadcast* bc = dynamic_cast<const Broadcast*>(rhs_new)) {
      if (const Expr* ret = handleBroadcastAdd(bc, lhs_new)) {
        return ret;
      }
    }

    // No change.
    if (lhs == lhs_new && rhs == rhs_new) {
      return v;
    }

    // Cannot simplify.
    return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
  }
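
The first branch of mutate(Add) merges two LinearForms over the same X coefficient-wise: (a1*x + b1) + (a2*x + b2) == (a1 + a2)*x + (b1 + b2), with the constant additions folded immediately by evaluateOp. A tiny standalone sketch of that merge on plain doubles:

#include <iostream>

// Minimal stand-in for a LinearForm over a single variable x: a*x + b.
struct Linear {
  double a;
  double b;
};

// Merge two linear forms over the same variable by adding coefficients; this
// is the folding that evaluateOp(A_Add) / evaluateOp(B_Add) do on the IR side.
static Linear addLinear(const Linear& lhs, const Linear& rhs) {
  return Linear{lhs.a + rhs.a, lhs.b + rhs.b};
}

int main() {
  Linear l{2.0, 1.0};  // 2*x + 1
  Linear r{3.0, 4.0};  // 3*x + 4
  Linear sum = addLinear(l, r);
  std::cout << sum.a << " * x + " << sum.b << "\n";  // 5 * x + 5
}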

  const Expr* scalar() const {
    return scalar_;
  }
  const std::vector<const Term*>& variables() const {
    return variables_;
  }
  HashProvider& hasher() const {
    return hasher_;
  const Expr* mutate(const Sub* v) override {
    const Expr* lhs = v->lhs();
    const Expr* rhs = v->rhs();
    const Expr* lhs_new = lhs->accept_mutator(this);
    const Expr* rhs_new = rhs->accept_mutator(this);

    // Constant Folding.
    if (lhs_new->isConstant() && rhs_new->isConstant()) {
      const Expr* result = evaluateOp(v);
      return result;
    }

    const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
    const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);

    if (lhsLinear && rhsLinear) {
      // Can sub two LinearTerms if they reference the same Var.
      if (lhsLinear->getX() == rhsLinear->getX()) {
        Sub* A_Sub = new Sub(lhsLinear->getA(), rhsLinear->getA());
        Sub* B_Sub = new Sub(lhsLinear->getB(), rhsLinear->getB());

        LinearForm* linear = new LinearForm(
            lhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub));
        return linear;
      }

      // otherwise cannot simplify further.
      return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
    }

    // Can just sub from B term if LHS is a LinearTerm.
    if (lhsLinear && rhs_new->isConstant()) {
      Sub* B_Sub = new Sub(lhsLinear->getB(), rhs_new);
      LinearForm* linear = new LinearForm(
          lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Sub));
      return linear;
    }

    // Slightly more complicated if the RHS is LinearTerm.
    if (rhsLinear && lhs_new->isConstant()) {
      // The linear needs to be negated.
      ExprHandle zero = getImmediateByType(rhsLinear->getA()->dtype(), 0);
      Sub* A_Sub = new Sub(zero.node(), rhsLinear->getA());
      Sub* B_Sub = new Sub(rhsLinear->getB(), lhs_new);
      LinearForm* linear = new LinearForm(
          rhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub));
      return linear;
    }

    // Can create a new LinearTerm, but since the B term is defined as Add we
    // must negate it.
    if (rhs_new->isConstant()) {
      LinearForm* linear = new LinearForm(lhs_new);

      ExprHandle zero = getImmediateByType(linear->getA()->dtype(), 0);
      Sub* B_Sub = new Sub(zero.node(), rhs_new);
      linear->setB(evaluateOp(B_Sub));
      return linear;
    }

    // Can create a new LinearTerm with the A term -1 to negate the Expr.
    if (lhs_new->isConstant()) {
      // Negate by using -1 as the first linear.
      ExprHandle negOne = getImmediateByType(rhs_new->dtype(), -1);
      LinearForm* linear =
          new LinearForm(rhs_new, negOne.node(), evaluateOp(lhs_new));
      return linear;
    }

    // Nothing to do.
    if (lhs == lhs_new && rhs == rhs_new) {
      return v;
    }

    // Cannot simplify.
    return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
  }
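
Subtraction needs negation on top of the same merging: mathematically, c - (a*x + b) == (-a)*x + (c - b), and (a*x + b) - c == a*x + (b - c). A standalone numeric sketch of those identities (illustrative only; it mirrors the intent of the negation branches above rather than their exact node construction):

#include <iostream>

// Minimal stand-in for a LinearForm over a single variable x: a*x + b.
struct Linear {
  double a;
  double b;
};

// c - (a*x + b)  ==>  (-a)*x + (c - b)
static Linear subConstMinusLinear(double c, const Linear& l) {
  return Linear{-l.a, c - l.b};
}

// (a*x + b) - c  ==>  a*x + (b - c)
static Linear subLinearMinusConst(const Linear& l, double c) {
  return Linear{l.a, l.b - c};
}

int main() {
  Linear l{2.0, 3.0};  // 2*x + 3
  Linear p = subConstMinusLinear(10.0, l);
  Linear q = subLinearMinusConst(l, 1.0);
  std::cout << p.a << " * x + " << p.b << "\n";  // -2 * x + 7
  std::cout << q.a << " * x + " << q.b << "\n";  // 2 * x + 2
}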

  SimplifierHashType hashVars() const;
  const Expr* mutate(const Mul* v) override {
    const Expr* lhs = v->lhs();
    const Expr* rhs = v->rhs();
    const Expr* lhs_new = lhs->accept_mutator(this);
    const Expr* rhs_new = rhs->accept_mutator(this);

 private:
  std::vector<const Term*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;
    // Constant Folding.
    if (lhs_new->isConstant() && rhs_new->isConstant()) {
      return evaluateOp(v);
    }

  void addTerm(const Term* t) {
    variables_.push_back(t);
    const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
    const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);

    if (lhsLinear && rhsLinear) {
      // Lets not get into higher order terms.
      return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
    }

    // Easy to simplify into an existing LinearTerm by multiplying A and B.
    if (lhsLinear && rhs_new->isConstant()) {
      Mul* A_Mul = new Mul(lhsLinear->getA(), rhs_new);
      Mul* B_Mul = new Mul(lhsLinear->getB(), rhs_new);
      LinearForm* linear = new LinearForm(
          lhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul));
      return linear;
    }

    if (rhsLinear && lhs_new->isConstant()) {
      Mul* A_Mul = new Mul(rhsLinear->getA(), lhs_new);
      Mul* B_Mul = new Mul(rhsLinear->getB(), lhs_new);
      LinearForm* linear = new LinearForm(
          rhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul));
      return linear;
    }

    // Easy to create a new LinearTerm by setting term A.
    if (lhs_new->isConstant()) {
      LinearForm* linear = new LinearForm(rhs_new);
      linear->setA(evaluateOp(lhs_new));
      return linear;
    }

    if (rhs_new->isConstant()) {
      LinearForm* linear = new LinearForm(lhs_new);
      linear->setA(evaluateOp(rhs_new));
      return linear;
    }

    // Broadcasts have special logic.
    if (const Broadcast* bc = dynamic_cast<const Broadcast*>(lhs_new)) {
      if (const Expr* ret = handleBroadcastMul(bc, rhs_new)) {
        return ret;
      }
    }

    if (const Broadcast* bc = dynamic_cast<const Broadcast*>(rhs_new)) {
      if (const Expr* ret = handleBroadcastMul(bc, lhs_new)) {
        return ret;
      }
    }

    // Cannot be simplified, just exit.
    if (lhs == lhs_new && rhs == rhs_new) {
      return v;
    }

    return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
  }
  template <class... Ts>
  void addTerm(const Term* t, Ts... ts) {
    addTerm(t);
    addTerm(ts...);
  }

  // Sort by hash to normalize order of terms.
  void sort();
};

// Simplify the IR by combining arithmetic expressions over common terms.
class TORCH_API PolynomialTransformer : public IRMutator {
 public:
  // Inserts term into the provided map, in the case of a hash collision
  // combines the term with the existing and updates the map.
  void addOrUpdateTerm(
      std::unordered_map<SimplifierHashType, const Term*>& varmap,
      const Term* term);

  // Add Polynomial expressions, combining Terms representing the same
  // variables.
  const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs);

  // Insert a new Term into the provided polynomial. If the new term has common
  // variables to an existing term it is combined.
  const Expr* insertTerm(const Polynomial* poly, const Term* term);

  // Merge and simplify addition.
  const Expr* mutate(const Add* v) override;

  // Subtract one term from another, cancelling if necessary.
  const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated);

  // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
  // possible.
  const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs);

  // Merge and simplify subtraction.
  const Expr* mutate(const Sub* v) override;

  // Multiply two terms together, usually creating a new term with the variable
  // lists concatenated.
  const Term* mulTerms(const Term* lhs, const Term* rhs);

  // Multiply a Polynomial by a Term.
  const Expr* polyByTerm(const Polynomial* poly, const Term* term);

  // Merge and simplify multiplication.
  const Expr* mutate(const Mul* v) override;

  const Expr* mutate(const Div* v) override {
    // TODO div simplification will require a rational node.
@ -304,9 +396,37 @@ class TORCH_API PolynomialTransformer : public IRMutator {
    return mutateBinaryOp(v, this, v->propagate_nans());
  }

  const Expr* mutate(const Intrinsics* v) override;
  const Expr* mutate(const Intrinsics* v) override {
    std::vector<const Expr*> new_params;
    bool changed = false;
    bool allConstant = true;
    for (const auto* p : v->params()) {
      const Expr* new_child = p->accept_mutator(this);
      new_params.push_back(new_child);

  const Expr* mutate(const Cast* v) override;
      changed |= p != new_child;
      allConstant &= new_child->isConstant();
    }

    const Expr* node = v;
    if (changed) {
      node = new Intrinsics(v->op_type(), new_params);
    }

    if (!allConstant || !v->isPure()) {
      return node;
    }

    return evaluateOp(node);
  }

  const Expr* mutate(const Cast* v) override {
    if (v->src_value()->isConstant()) {
      return evaluateOp(v);
    }

    return v;
  }

  template <typename Op>
  static const Expr* mutateBinaryOp(
@ -332,40 +452,12 @@ class TORCH_API PolynomialTransformer : public IRMutator {
    return evaluateOp(node);
  }

  static const Expr* simplify(const Expr* e);
  static ExprHandle simplify(const ExprHandle& e);
  static Stmt* simplify(Stmt* e);

 private:
  HashProvider hasher_;
}; // namespace tensorexpr
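
addOrUpdateTerm and insertTerm describe a hash-keyed merge: terms are bucketed by the hash of their variable component, and a collision means the scalars are combined. A standalone sketch of that bookkeeping, with a string key standing in for SimplifierHashType; dropping entries whose scalars cancel to zero is an extra assumption of this sketch, shown as one way x - x can disappear:

#include <iostream>
#include <string>
#include <unordered_map>

// varmap maps a canonical key for the variable part of a term to its scalar.
// Inserting an existing key combines the scalars; a zero result removes the
// entry in this sketch.
static void addOrUpdate(std::unordered_map<std::string, int>& varmap,
                        const std::string& key,
                        int scalar) {
  auto it = varmap.find(key);
  if (it == varmap.end()) {
    varmap[key] = scalar;
    return;
  }
  it->second += scalar;
  if (it->second == 0) {
    varmap.erase(it);
  }
}

int main() {
  std::unordered_map<std::string, int> varmap;
  addOrUpdate(varmap, "x*y", 2);  // 2*x*y
  addOrUpdate(varmap, "x*y", 3);  // + 3*x*y -> 5*x*y
  addOrUpdate(varmap, "z", 1);    // + z
  addOrUpdate(varmap, "z", -1);   // - z -> cancels
  for (const auto& kv : varmap) {
    std::cout << kv.second << " * " << kv.first << "\n";  // 5 * x*y
  }
}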

// Expands Terms and Polynomial expressions into primitive operations.
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public IRMutator {
  PolynomialTransformer* simplifier_;

 public:
  TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {}

  // Expand Terms out to a series of Muls.
  const Expr* mutate(const Term* v) override;

  // Trivially factorize terms by GCD of scalar components.
  const Expr* factorizePolynomial(const Polynomial* poly);

  // Expand Polynomials out to a series of Adds.
  const Expr* mutate(const Polynomial* v);
};
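
The factorization TermExpander performs is over the scalar components only: the GCD of all term scalars (including the constant) can be pulled out in front of the Polynomial. A standalone arithmetic sketch of that step (C++17 for std::gcd):

#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Scalars of the terms in 6*x + 9*y + 3, including the constant.
  std::vector<int> scalars = {6, 9, 3};

  // GCD of all scalar components.
  int g = 0;
  for (int s : scalars) {
    g = std::gcd(g, s);
  }

  std::cout << "common factor: " << g << "\n";  // 3
  // 6*x + 9*y + 3  ==  3 * (2*x + 3*y + 1)
  for (int s : scalars) {
    std::cout << s / g << " ";                  // 2 3 1
  }
  std::cout << "\n";
}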

class TORCH_API IRSimplifier {
 public:
  static const Expr* simplify(const Expr* e) {
    PolynomialTransformer simplifier;
    IRSimplifier simplifier;
    e = e->accept_mutator(&simplifier);

    // There may be terms left in the IR, expand them.
    TermExpander expander(&simplifier);
    LinearFormExpander expander;
    e = e->accept_mutator(&expander);

    return e;
@ -376,15 +468,74 @@ class TORCH_API IRSimplifier {
  }

  static Stmt* simplify(Stmt* s) {
    PolynomialTransformer simplifier;
    IRSimplifier simplifier;
    s = s->accept_mutator(&simplifier);

    // There may be terms left in the IR, expand them.
    TermExpander expander(&simplifier);
    LinearFormExpander expander;
    s = s->accept_mutator(&expander);

    return s;
  }

 private:
  /* Expands lhs and rhs if they are LinearTerms, creating a new op to hold
   * them. If either side expands to a constant term, attempt simplification of
   * the new op. */
  const Expr* expandAndRecurse(
      IRNodeType expr_type,
      const Expr* lhs,
      const Expr* rhs) {
    if (const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs)) {
      lhs = expandLinearForm(lhsLinear, this);
    }
    if (const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs)) {
      rhs = expandLinearForm(rhsLinear, this);
    }
    const Expr* result = newBinaryOpOfType(expr_type, lhs, rhs, false);

    // lhs or rhs can become constant during expansion, if either is now
    // constant we can keep merging into a linear term. Have another attempt to
    // simplify the new op.
    if (lhs->isConstant() || rhs->isConstant()) {
      return result->accept_mutator(this);
    }

    return result;
  }

  /* Handles optimization cases for Broadcast() + Other */
  const Expr* handleBroadcastAdd(const Broadcast* bc, const Expr* other) {
    if (bc->value()->isConstant() && immediateAs<int>(bc->value()) == 0) {
      return other;
    }

    if (const Ramp* r = dynamic_cast<const Ramp*>(other)) {
      // Add the broadcast to the start of the Ramp.
      const Expr* ret =
          new Ramp(new Add(bc->value(), r->base()), r->stride(), r->lanes());
      return ret->accept_mutator(this);
    }

    return nullptr;
  }

  /* Handles optimization cases for Broadcast() * Other */
  const Expr* handleBroadcastMul(const Broadcast* bc, const Expr* other) {
    if (bc->value()->isConstant() && immediateAs<int>(bc->value()) == 1) {
      return other;
    }

    if (const Ramp* r = dynamic_cast<const Ramp*>(other)) {
      // Multiply both start and stride by the broadcast value.
      const Expr* ret = new Ramp(
          new Mul(bc->value(), r->base()),
          new Mul(bc->value(), r->stride()),
          r->lanes());
      return ret->accept_mutator(this);
    }

    return nullptr;
  }
};
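
The Broadcast/Ramp rules above are lane-wise identities: a Ramp denotes the vector {base, base + stride, base + 2*stride, ...} and a Broadcast repeats one value per lane, so broadcast(c) + ramp(b, s, n) == ramp(b + c, s, n) and broadcast(c) * ramp(b, s, n) == ramp(c*b, c*s, n). A standalone numeric check of both identities:

#include <iostream>
#include <vector>

// Materialize ramp(base, stride, lanes) as a concrete vector of lane values.
static std::vector<int> ramp(int base, int stride, int lanes) {
  std::vector<int> out(lanes);
  for (int i = 0; i < lanes; ++i) {
    out[i] = base + i * stride;
  }
  return out;
}

int main() {
  const int c = 10, base = 1, stride = 2, lanes = 4;

  // broadcast(c) + ramp(base, stride)  ==  ramp(base + c, stride)
  std::vector<int> lhs_add = ramp(base, stride, lanes);
  for (int& v : lhs_add) v += c;
  std::cout << (lhs_add == ramp(base + c, stride, lanes)) << "\n";  // 1

  // broadcast(c) * ramp(base, stride)  ==  ramp(c * base, c * stride)
  std::vector<int> lhs_mul = ramp(base, stride, lanes);
  for (int& v : lhs_mul) v *= c;
  std::cout << (lhs_mul == ramp(c * base, c * stride, lanes)) << "\n";  // 1
}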

} // namespace tensorexpr

@ -1,7 +1,6 @@
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
@ -177,18 +176,10 @@ void IRVisitor::visit(const Cond* v) {
  }
}

void IRVisitor::visit(const Term* v) {
  v->scalar()->accept(this);
  for (auto* t : v->variables()) {
    t->accept(this);
  }
}

void IRVisitor::visit(const Polynomial* v) {
  v->scalar()->accept(this);
  for (auto* t : v->variables()) {
    t->accept(this);
  }
void IRVisitor::visit(const LinearForm* v) {
  v->getA()->accept(this);
  v->getX()->accept(this);
  v->getB()->accept(this);
}

} // namespace tensorexpr

@ -42,8 +42,7 @@ class FunctionCall;
class Allocate;
class Free;
class Cond;
class Term;
class Polynomial;
class LinearForm;

class TORCH_API IRVisitor {
 public:
@ -91,8 +90,7 @@ class TORCH_API IRVisitor {
  virtual void visit(const Allocate* v);
  virtual void visit(const Free* v);
  virtual void visit(const Cond* v);
  virtual void visit(const Term* v);
  virtual void visit(const Polynomial* v);
  virtual void visit(const LinearForm* v);
};

} // namespace tensorexpr

@ -1114,7 +1114,6 @@ void TensorExprKernel::lowerToBackend(BackendType backendType) {

  l.ApplyInlines();
  Stmt* stmt = l.root_stmt();
  // Arithmetic Simplification.
  stmt = IRSimplifier::simplify(stmt);

  // Set up formal params (inputs, then outputs) for kernel.

@ -1,6 +1,6 @@
#include <torch/csrc/jit/tensorexpr/types.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <c10/util/Logging.h>

@ -50,7 +50,7 @@ class TORCH_API Dtype {
  Dtype(Dtype type, int lanes)
      : scalar_type_(type.scalar_type_), lanes_(lanes) {
    if (type.lanes() != 1) {
      throw malformed_input("dtype lanes dont match");
      throw malformed_input();
    }
  }
  int lanes() const {
@ -108,15 +108,10 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
  return static_cast<ScalarType>(c10::promoteTypes(
      static_cast<c10::ScalarType>(a), static_cast<c10::ScalarType>(b)));
}
inline Dtype promoteTypes(Dtype a, Dtype b) {
  if (a.lanes() != b.lanes()) {
    throw malformed_input("promoting types with different lanes");
  }
  return Dtype(
      static_cast<ScalarType>(c10::promoteTypes(
          static_cast<c10::ScalarType>(a.scalar_type()),
          static_cast<c10::ScalarType>(b.scalar_type()))),
      a.lanes());
inline ScalarType promoteTypes(Dtype a, Dtype b) {
  return static_cast<ScalarType>(c10::promoteTypes(
      static_cast<c10::ScalarType>(a.scalar_type()),
      static_cast<c10::ScalarType>(b.scalar_type())));
}

inline Dtype BinaryOpDtype(
@ -132,21 +127,22 @@ inline Dtype BinaryOpDtype(
  }

  if (op1_dtype.lanes() != op2_dtype.lanes()) {
    throw malformed_input("lanes dont match");
    throw malformed_input();
  }
  int lanes = op1_dtype.lanes();

  Dtype resultType = promoteTypes(op1_dtype, op2_dtype);
  if (resultType.scalar_type() == ScalarType::Undefined) {
    throw malformed_input("scalar type doesn't match");
  ScalarType resultType = promoteTypes(op1_dtype, op2_dtype);
  if (resultType == ScalarType::Undefined) {
    throw malformed_input();
  }

  if (lanes == 1) {
    // Use the fixed scalar Dtypes.
    return ToDtype(resultType.scalar_type());
    return ToDtype(resultType);
  }

  return resultType;
  return Dtype(resultType, lanes);
}
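
The restored BinaryOpDtype first requires both operands to have the same lane count, promotes the scalar types, and then reattaches the lane count to the result. A standalone sketch of that flow, with the c10 promotion table replaced by a trivial highest-rank rule (an assumption of this sketch; the real table is much richer):

#include <algorithm>
#include <iostream>
#include <stdexcept>

enum class Scalar { Int = 0, Float = 1, Double = 2 };

struct ToyDtype {
  Scalar scalar_type;
  int lanes;
};

// Stand-in for c10::promoteTypes: the higher "rank" wins.
static Scalar promoteScalar(Scalar a, Scalar b) {
  return static_cast<Scalar>(
      std::max(static_cast<int>(a), static_cast<int>(b)));
}

static ToyDtype binaryOpDtype(const ToyDtype& a, const ToyDtype& b) {
  if (a.lanes != b.lanes) {
    throw std::runtime_error("lanes don't match");
  }
  // Promote the scalar types, then carry the (shared) lane count through.
  return ToyDtype{promoteScalar(a.scalar_type, b.scalar_type), a.lanes};
}

int main() {
  ToyDtype vecInt{Scalar::Int, 8};
  ToyDtype vecFloat{Scalar::Float, 8};
  ToyDtype out = binaryOpDtype(vecInt, vecFloat);
  std::cout << static_cast<int>(out.scalar_type) << " x" << out.lanes << "\n";
  // prints: 1 x8  (Float, 8 lanes)
}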

} // namespace tensorexpr