diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 9ad5de242426..b9dec0b87c9c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -481,7 +481,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp - ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_simplifier.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index c2b16fd6dd56..f97f6da343ab 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -11,40 +11,6 @@ namespace jit { using namespace torch::jit::tensorexpr; using SimpleIRExprEval = ExprEval; -#define IS_NODE(T, node) \ - { \ - auto* node_ = dynamic_cast(node); \ - EXPECT_NE(nullptr, node_) << "Expected node to be " #T; \ - } - -#define IS_NODE_WITH_NAME(T, node, name) \ - auto* name = dynamic_cast(node); \ - EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T; - -#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ - const T* name = nullptr; \ - { \ - auto* node_ = dynamic_cast(node); \ - EXPECT_NE(nullptr, node_); \ - EXPECT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ - name = dynamic_cast(node_->src_value()); \ - } \ - EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T; - -#define IS_IMM_WITH_VAL(T, node, val) \ - { \ - auto* node_ = dynamic_cast(node); \ - EXPECT_NE(nullptr, node_) << "Expected node to be " #T "Imm"; \ - EXPECT_EQ(node_->value(), val) << "Expected Imm to be " << val; \ - } - -#define IS_VAR_WITH_NAME(node, name) \ - { \ - auto* node_ = dynamic_cast(node); \ - EXPECT_NE(nullptr, node_) << "Expected node to be Var"; \ - EXPECT_EQ(node_->name_hint(), name) << "Expected var to be " #name; \ - } - void testConstantFoldSimple() { KernelScope kernel_scope; ExprHandle a(2.0f); @@ -167,33 +133,17 @@ void testConstantFoldIntrinsics() { void testConstantFoldWithVar() { KernelScope kernel_scope; - { - VarHandle x("x", kInt); - ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); + VarHandle x("x", kFloat); + ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); - ExprHandle newF = IRSimplifier::simplify(body); - const Mul* root = newF.AsNode(); - EXPECT_NE(root, nullptr); - EXPECT_NE(dynamic_cast(root->lhs()), nullptr); + ExprHandle newF = IRSimplifier::simplify(body); + const Mul* root = newF.AsNode(); + EXPECT_NE(root, nullptr); + EXPECT_NE(dynamic_cast(root->rhs()), nullptr); - ExprHandle result = Let::make(x, ExprHandle(3), newF); - SimpleIRExprEval eval(result); - EXPECT_EQ(eval.value(), 3 * (2 + 4)); - } - - { - VarHandle x("x", kFloat); - ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); - - ExprHandle newF = IRSimplifier::simplify(body); - const Mul* root = newF.AsNode(); - EXPECT_NE(root, nullptr); - EXPECT_NE(dynamic_cast(root->rhs()), nullptr); - - ExprHandle result = Let::make(x, ExprHandle(3.f), newF); - SimpleIRExprEval eval(result); - EXPECT_EQ(eval.value(), 3 * (2 + 4)); - } + ExprHandle result = Let::make(x, ExprHandle(3.f), newF); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 3 * (2 + 4)); } void testUnFoldableExpr() { @@ -278,22 +228,34 @@ void testHashEquivalenceAfterFolding() { ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(5.0f); - ExprHandle f1 = ((a + b) * x); - ExprHandle f2 = (c * x); + ExprHandle f = ((a + b) * x) * (c * x); + + const Mul* root = f.AsNode(); + EXPECT_NE(root, nullptr); HashProvider hasher; - auto hash_l = hasher.hash(f1.node()); - auto hash_r = hasher.hash(f2.node()); + auto hash_f = hasher.hash(f.node()); + auto hash_l = hasher.hash(root->lhs()); + auto hash_r = hasher.hash(root->rhs()); + // Root not equal to either branch, and branches not equal. + EXPECT_NE(hash_f, hash_l); + EXPECT_NE(hash_f, hash_r); EXPECT_NE(hash_l, hash_r); - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); + ExprHandle newF = IRSimplifier::simplify(f); - auto hash_l_n = hasher.hash(ff1.node()); - auto hash_r_n = hasher.hash(ff2.node()); + const Mul* newRoot = newF.AsNode(); + EXPECT_NE(newRoot, nullptr); - // branches are now equal. + auto hash_f_n = hasher.hash(newF.node()); + auto hash_l_n = hasher.hash(newRoot->lhs()); + auto hash_r_n = hasher.hash(newRoot->rhs()); + + // Root not equal to either branch. + EXPECT_NE(hash_f_n, hash_l_n); + EXPECT_NE(hash_f_n, hash_r_n); + // but branches are now equal. EXPECT_EQ(hash_l_n, hash_r_n); } @@ -381,16 +343,11 @@ void testHashLargeExpression() { EXPECT_NE(hash_t, hash_f); } -/// (2 + x) + 4 => x + 6 +/// (2.f + x) + 4.f => x + 6.f void testSimplifyAdd() { KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); + VarHandle x("x", kFloat); + ExprHandle body = (ExprHandle(2.f) + x) + ExprHandle(4.f); ExprHandle simplified = IRSimplifier::simplify(body); const Add* root = simplified.AsNode(); @@ -398,43 +355,51 @@ void testSimplifyAdd() { const Var* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); EXPECT_EQ(lhs->name_hint(), "x"); - const IntImm* rhs = dynamic_cast(root->rhs()); + const FloatImm* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); EXPECT_EQ(rhs->value(), 6.f); } -/// (2 - x) - 4 => -2 - x +/// (2.f - x) - 4.f => -2.f - x void testSimplifySub() { KernelScope kernel_scope; - VarHandle x("x", kInt); - ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); + VarHandle x("x", kFloat); + ExprHandle body = (ExprHandle(2.f) - x) - ExprHandle(4.f); ExprHandle simplified = IRSimplifier::simplify(body); const Sub* root = simplified.AsNode(); EXPECT_NE(root, nullptr); - const IntImm* lhs = dynamic_cast(root->lhs()); + const FloatImm* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->value(), -2); + EXPECT_EQ(lhs->value(), -2.f); const Var* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); EXPECT_EQ(rhs->name_hint(), "x"); } -/// 2 * (1 - x) - 4 => -2 * (x + 3) +/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f) void testSimplifyMultiLayer() { KernelScope kernel_scope; - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); - ExprHandle simplified = IRSimplifier::simplify(body); + VarHandle x("x", kFloat); + ExprHandle body = ExprHandle(2.f) * ((ExprHandle(1.f) - x) - ExprHandle(4.f)); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 3); + ExprHandle simplified = IRSimplifier::simplify(body); + const Sub* root = simplified.AsNode(); + EXPECT_NE(root, nullptr); + const FloatImm* lhs = dynamic_cast(root->lhs()); + EXPECT_NE(lhs, nullptr); + EXPECT_EQ(lhs->value(), -6.f); + const Mul* rhs = dynamic_cast(root->rhs()); + EXPECT_NE(rhs, nullptr); + const Var* varX = dynamic_cast(rhs->lhs()); + EXPECT_NE(varX, nullptr); + EXPECT_EQ(varX->name_hint(), "x"); + const FloatImm* mulRhs = dynamic_cast(rhs->rhs()); + EXPECT_NE(mulRhs, nullptr); + EXPECT_EQ(mulRhs->value(), 2.f); } -/// 2 * (3 * x) - (x * 4) => 2 * x +/// 2 * (3 * x) - (x * 4) => x * 2 void testSimplifyMultiTerm() { KernelScope kernel_scope; VarHandle x("x", kInt); @@ -444,30 +409,30 @@ void testSimplifyMultiTerm() { ExprHandle simplified = IRSimplifier::simplify(body); const Mul* root = simplified.AsNode(); EXPECT_NE(root, nullptr); - const IntImm* lhs = dynamic_cast(root->lhs()); + const Var* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->value(), 2); - const Var* rhs = dynamic_cast(root->rhs()); + EXPECT_EQ(lhs->name_hint(), "x"); + const IntImm* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); - EXPECT_EQ(rhs->name_hint(), "x"); + EXPECT_EQ(rhs->value(), 2); } -/// 2 * (3 * (long)x) - (x * 4) => 2 * x +/// 2 * (3 * (f)x) - (x * 4) => x * 2.f void testSimplifyCasts() { KernelScope kernel_scope; - VarHandle x("x", kLong); + VarHandle x("x", kFloat); ExprHandle body = (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); const Mul* root = simplified.AsNode(); EXPECT_NE(root, nullptr); - const LongImm* lhs = dynamic_cast(root->lhs()); + const Var* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->value(), 2); - const Var* rhs = dynamic_cast(root->rhs()); + EXPECT_EQ(lhs->name_hint(), "x"); + const FloatImm* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); - EXPECT_EQ(rhs->name_hint(), "x"); + EXPECT_EQ(rhs->value(), 2); } /// (x + 0) * 1 => x @@ -487,39 +452,20 @@ void testSimplifyMultiVar() { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); - ExprHandle body = x * 24 + y * 34; + ExprHandle body = y * 24 + x * 34; ExprHandle simplified = IRSimplifier::simplify(body); - const Add* root = simplified.AsNode(); EXPECT_NE(root, nullptr); const Mul* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - const Var* varX = dynamic_cast(lhs->rhs()); - EXPECT_NE(varX, nullptr); - EXPECT_EQ(varX->name_hint(), "y"); + const Var* varY = dynamic_cast(lhs->lhs()); + EXPECT_EQ(varY->name_hint(), "y"); const Mul* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); - const Var* varY = dynamic_cast(rhs->rhs()); - EXPECT_NE(varY, nullptr); - EXPECT_EQ(varY->name_hint(), "x"); -} - -// x + 2 + y => x + y + 2 -void testSimplifyReorderings() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x + 2 + y; - ExprHandle simplified = IRSimplifier::simplify(body); - - const Add* root = simplified.AsNode(); - EXPECT_NE(root, nullptr); - - IS_NODE_WITH_NAME(Add, root->lhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - IS_IMM_WITH_VAL(Int, root->rhs(), 2); + const Var* varX = dynamic_cast(rhs->lhs()); + EXPECT_NE(varX, nullptr); + EXPECT_EQ(varX->name_hint(), "x"); } /// y + x * 0 => y @@ -530,621 +476,9 @@ void testSimplifyEliminatesVar() { ExprHandle body = y + x * ExprHandle(0); ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); -} - -void testSimplifyAdds() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) + (x + y) => 2 * (x + y) - ExprHandle body = (x + y) + (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // (x * y) + (x * y) => 2 * (x * y) - ExprHandle body = (x * y) + (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Mul, root->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) + (x - y) => -2 * (y - x) - ExprHandle body = (x - y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "y"); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); - } -} - -void testSimplifyMuls() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) * (x + y) => (x + y) * (x + y) - // We don't attempt to simplify mulitplication of polynomials since the - // result is only very rarely more efficient. - ExprHandle body = (x + y) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x * y * x * y => x * x * y * y - // These get reordered only. - ExprHandle body = x * y * x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); - IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); - IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); - IS_VAR_WITH_NAME(mul1->rhs(), "y"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - IS_VAR_WITH_NAME(mul3->lhs(), "x"); - IS_VAR_WITH_NAME(mul3->rhs(), "x"); - } - - { - // (x - y) * (x - y) => (x - y) * (x - y) - // As with Add we don't attempt simplification of this. - ExprHandle body = (x - y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + y) * (x - y) => (x - y) * (x - y) - // Don't simplify with different ops on each side. - ExprHandle body = (x + y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } -} - -// Sub an expr from itself will result in zero. -void testSimplifySubs() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) - (x + y) => 0 - ExprHandle body = (x + y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x * y) - (x * y) => 0 - ExprHandle body = (x * y) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x - y) - (x - y) => 0 - ExprHandle body = (x - y) - (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x + y) - 2 * (x + y) => -1 * (x + y) - ExprHandle body = (x + y) - ExprHandle(2) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // (x + y) - y => x - ExprHandle body = (x + y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - y) - y => x - 2 * y - ExprHandle body = (x - y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // 2 * x - x => x - ExprHandle body = (ExprHandle(2) * x) - x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // x - 2 * x = -1 * x - // We don't have a unary negate, but this could be 0 -x I guess? - ExprHandle body = x - (ExprHandle(2) * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // (x + y + 5) * (x - x) => 0 - // Cancelling out one side of Mul cancels both. - ExprHandle body = (x + y + 5) * (x - x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } -} - -// Test that mixing ops together simplifies as expected. -void testSimplifyMultiOp() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x * y) + (x - y) => (x * y) + x - y - // - ExprHandle body = (x * y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - (x * y) => x + y - (x * y) - ExprHandle body = (x + y) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) - (x + y) => -2 * y - ExprHandle body = (x - y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } -} - -// Test that chaining many ops together works as expected. -void testSimplifyManyOps() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x - ExprHandle body = x + y + x + x + y + y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); - } - - { - // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y - ExprHandle body = x - y + x + x - y - y + x - y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x + y + x - x - y - y + x + y + x = 3 * x - ExprHandle body = x + y + x - x - y - y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 3); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } -} - -void testSimplifyFactorization() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (2 * x) + (2 * y) => 2 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization when scalars have common divider. - // (2 * x) + (4 * y) => 2 * (2 * y + x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); - } - - { - // Factorization attempt without a common divider. - // (2 * x) + (5 * y) => (5 * y) + (2 * x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 2); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); - } - - { - // Factorization after merging. - // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + - (ExprHandle(8) * x + ExprHandle(6) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 10); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization with common divider but different signs. - // (-2 * x) + (4 * y) => -2 * (x - 2 * y) - ExprHandle body = (ExprHandle(-2) * x + ExprHandle(4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } -} - -// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (3 * z + y + 4 * x) -void testSimplifyFactorizeUneven() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = - (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add1); - IS_NODE_WITH_NAME(Add, add1->lhs(), add2); - - IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); - IS_NODE_WITH_NAME(Mul, add2->lhs(), zmul); - - IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); - IS_VAR_WITH_NAME(zmul->rhs(), "z"); - - IS_VAR_WITH_NAME(add2->rhs(), "y"); - - IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); - IS_VAR_WITH_NAME(xmul->rhs(), "x"); -} - -// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) -// This is kind of a placeholder test for variable factorization. -void testSimplifyDeeperTerms() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); - IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); - IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); -} - -// Tests the difference between two less trivial expressions. -// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 -void testSimplifyDeeperDifference() { - KernelScope kernel_scope; - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); -} - -// Test constant folding into the difference between expressions. -// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 -void testSimplifyFoldComplexDifference() { - KernelScope kernel_scope; - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (IntImm::make(2) + - (Cast::make( - kChar, - (m * (ExprHandle(1) * n_1) + (n + 1)) - - (m * (ExprHandle(1) * n_1) + n)))); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 3); -} - -void testSimplifyIfComponents() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make( - ((ExprHandle(5) - ExprHandle(4)) * x) > y, - ExprHandle(2) * x - x, - ExprHandle(2) * y - y); - - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); - - IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); - EXPECT_EQ(cmp->compare_select_op(), kGT); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_VAR_WITH_NAME(cmp->rhs(), "y"); - - IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); - IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); -} - -void testSimplifyOpaqueTerms() { - KernelScope kernel_scope; - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // 2 * x/y * x - x/y * y => y * x/y - ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "y"); - IS_NODE_WITH_NAME(Div, mul->rhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } - - { - // x%y - (x%y - 1) => 1 - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } -} - -void testSimplifyWontReorderFloat() { - KernelScope kernel_scope; - - { - // 3 * (3 * x) - 3 * (3 * y) => -9 * (y - x) - // This is an expression we can simplify. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -9); - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "y"); - IS_VAR_WITH_NAME(sub->rhs(), "x"); - } - - { - // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). - // If the vars are floating point, ops are not associative and we can't - // reorder. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); - IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); - IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); - IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); - } - - { - // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). - // We will simplify subexprs if they dont reorder floating point ops. - VarHandle x("x", kDouble); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); - IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); - IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); - IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); - } - - { - // Prevent reordering if FP propagated from dtypes. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3.f) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); - IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); - IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); - IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); - IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); - IS_VAR_WITH_NAME(yCast->src_value(), "y"); - } - - { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - // x%y - (x%y - 1) => x%y - (x%y - 1). - // We wont reorder opaque ops if they are FP. - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); - IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); - IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); - - IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); - IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); - IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); - IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); - IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); - } + const Var* root = simplified.AsNode(); + EXPECT_NE(root, nullptr); + EXPECT_EQ(root->name_hint(), "y"); } } // namespace jit diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 4afb905e2032..4e43ac16d4ed 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -9,119 +9,105 @@ namespace torch { namespace jit { -#define TH_FORALL_TESTS(_) \ - _(ExprBasicValueTest) \ - _(ExprBasicValueTest02) \ - _(ExprLetTest01) \ - _(ExprLetStmtTest01) \ - _(ExprLetTest02) \ - _(ExprIntTest) \ - _(ExprFloatTest) \ - _(ExprByteTest) \ - _(ExprCharTest) \ - _(ExprShortTest) \ - _(ExprLongTest) \ - _(ExprHalfTest) \ - _(ExprDoubleTest) \ - _(ExprVectorAdd01) \ - _(ExprCompareSelectEQ) \ - _(ExprSubstitute01) \ - _(ExprMath01) \ - _(ExprUnaryMath01) \ - _(ExprBinaryMath01) \ - _(ExprDynamicShapeAdd) \ - _(ExprBitwiseOps) \ - _(IRPrinterBasicValueTest) \ - _(IRPrinterBasicValueTest02) \ - _(IRPrinterLetTest01) \ - _(IRPrinterLetTest02) \ - _(IRPrinterCastTest) \ - _(ExprSimple01) \ - _(ExprLower01) \ - _(ExprSimple02) \ - _(ExprSplitWithTailNone) \ - _(ExprSplitWithMask01) \ - _(ScheduleBroadcastAddBuffer) \ - _(ScheduleFunctionCall01) \ - _(ScheduleInlineFunc01) \ - _(ScheduleFuserStyle) \ - _(ScheduleFuserThreeArg) \ - _(ScheduleDynamicShape2D) \ - _(TypeTest01) \ - _(TypePropagation) \ - _(Cond01) \ - _(IfThenElse01) \ - _(IfThenElse02) \ - _(ATen_cast_Float) \ - _(ATennegInt) \ - _(ATennegFloat) \ - _(ATenaddInt) \ - _(ATenaddFloat) \ - _(ATensubInt) \ - _(ATensubFloat) \ - _(ATenlerp) \ - _(ATenaddcmulInt) \ - _(ATenaddcmulFloat) \ - _(ATenmulInt) \ - _(ATenmulFloat) \ - _(ATendivInt) \ - _(ATendivFloat) \ - _(ATenmaxInt) \ - _(ATenmaxFloat) \ - _(ATenminInt) \ - _(ATenminFloat) \ - _(ATen_sigmoid_backward) \ - _(ATen_tanh_backward) \ - _(ATenreciprocal) \ - _(ATenreluInt) \ - _(ATenreluFloat) \ - _(ATenlogFloat) \ - _(ATenlog10Float) \ - _(ATenlog2Float) \ - _(ATenexpFloat) \ - _(ATenerfFloat) \ - _(ATencosFloat) \ - _(ATeneqInt) \ - _(ATengeInt) \ - _(ATengtInt) \ - _(ATenleInt) \ - _(ATenltInt) \ - _(ConstantFoldSimple) \ - _(ConstantFoldTwoLayer) \ - _(ConstantFoldShifts) \ - _(ConstantFoldBitwise) \ - _(ConstantFoldMultiOp) \ - _(ConstantFoldMinMax) \ - _(ConstantFoldIntrinsics) \ - _(ConstantFoldWithVar) \ - _(UnFoldableExpr) \ - _(HashSimple) \ - _(HashEquivalence) \ - _(HashEquivalenceAfterFolding) \ - _(HashDifferenceTypes) \ - _(HashLargeExpression) \ - _(SimplifyAdd) \ - _(SimplifySub) \ - _(SimplifyMultiLayer) \ - _(SimplifyMultiTerm) \ - _(SimplifyCasts) \ - _(SimplifyEliminatesNoOps) \ - _(SimplifyMultiVar) \ - _(SimplifyReorderings) \ - _(SimplifyEliminatesVar) \ - _(SimplifyAdds) \ - _(SimplifyMuls) \ - _(SimplifySubs) \ - _(SimplifyMultiOp) \ - _(SimplifyManyOps) \ - _(SimplifyFactorization) \ - _(SimplifyFactorizeUneven) \ - _(SimplifyDeeperTerms) \ - _(SimplifyDeeperDifference) \ - _(SimplifyFoldComplexDifference) \ - _(SimplifyIfComponents) \ - _(SimplifyOpaqueTerms) \ - _(SimplifyWontReorderFloat) \ +#define TH_FORALL_TESTS(_) \ + _(ExprBasicValueTest) \ + _(ExprBasicValueTest02) \ + _(ExprLetTest01) \ + _(ExprLetStmtTest01) \ + _(ExprLetTest02) \ + _(ExprIntTest) \ + _(ExprFloatTest) \ + _(ExprByteTest) \ + _(ExprCharTest) \ + _(ExprShortTest) \ + _(ExprLongTest) \ + _(ExprHalfTest) \ + _(ExprDoubleTest) \ + _(ExprVectorAdd01) \ + _(ExprCompareSelectEQ) \ + _(ExprSubstitute01) \ + _(ExprMath01) \ + _(ExprUnaryMath01) \ + _(ExprBinaryMath01) \ + _(ExprDynamicShapeAdd) \ + _(ExprBitwiseOps) \ + _(IRPrinterBasicValueTest) \ + _(IRPrinterBasicValueTest02) \ + _(IRPrinterLetTest01) \ + _(IRPrinterLetTest02) \ + _(IRPrinterCastTest) \ + _(ExprSimple01) \ + _(ExprLower01) \ + _(ExprSimple02) \ + _(ExprSplitWithTailNone) \ + _(ExprSplitWithMask01) \ + _(ScheduleBroadcastAddBuffer) \ + _(ScheduleFunctionCall01) \ + _(ScheduleInlineFunc01) \ + _(ScheduleFuserStyle) \ + _(ScheduleFuserThreeArg) \ + _(ScheduleDynamicShape2D) \ + _(TypeTest01) \ + _(TypePropagation) \ + _(Cond01) \ + _(IfThenElse01) \ + _(IfThenElse02) \ + _(ATen_cast_Float) \ + _(ATennegInt) \ + _(ATennegFloat) \ + _(ATenaddInt) \ + _(ATenaddFloat) \ + _(ATensubInt) \ + _(ATensubFloat) \ + _(ATenlerp) \ + _(ATenaddcmulInt) \ + _(ATenaddcmulFloat) \ + _(ATenmulInt) \ + _(ATenmulFloat) \ + _(ATendivInt) \ + _(ATendivFloat) \ + _(ATenmaxInt) \ + _(ATenmaxFloat) \ + _(ATenminInt) \ + _(ATenminFloat) \ + _(ATen_sigmoid_backward) \ + _(ATen_tanh_backward) \ + _(ATenreciprocal) \ + _(ATenreluInt) \ + _(ATenreluFloat) \ + _(ATenlogFloat) \ + _(ATenlog10Float) \ + _(ATenlog2Float) \ + _(ATenexpFloat) \ + _(ATenerfFloat) \ + _(ATencosFloat) \ + _(ATeneqInt) \ + _(ATengeInt) \ + _(ATengtInt) \ + _(ATenleInt) \ + _(ATenltInt) \ + _(ConstantFoldSimple) \ + _(ConstantFoldTwoLayer) \ + _(ConstantFoldShifts) \ + _(ConstantFoldBitwise) \ + _(ConstantFoldMultiOp) \ + _(ConstantFoldMinMax) \ + _(ConstantFoldIntrinsics) \ + _(ConstantFoldWithVar) \ + _(UnFoldableExpr) \ + _(HashSimple) \ + _(HashEquivalence) \ + _(HashEquivalenceAfterFolding) \ + _(HashDifferenceTypes) \ + _(HashLargeExpression) \ + _(SimplifyAdd) \ + _(SimplifySub) \ + _(SimplifyMultiLayer) \ + _(SimplifyMultiTerm) \ + _(SimplifyCasts) \ + _(SimplifyEliminatesNoOps) \ + _(SimplifyMultiVar) \ + _(SimplifyEliminatesVar) \ _(StmtClone) #define TH_FORALL_TESTS_LLVM(_) \ diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index b4e56c0e83c4..de4a15339ae1 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -205,7 +205,6 @@ libtorch_sources = [ "torch/csrc/jit/tensorexpr/ir.cpp", "torch/csrc/jit/tensorexpr/ir_mutator.cpp", "torch/csrc/jit/tensorexpr/ir_printer.cpp", - "torch/csrc/jit/tensorexpr/ir_simplifier.cpp", "torch/csrc/jit/tensorexpr/ir_visitor.cpp", "torch/csrc/jit/tensorexpr/kernel.cpp", "torch/csrc/jit/tensorexpr/llvm_codegen.cpp", diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 58360b766655..68c8ad17534e 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -59,24 +59,24 @@ class Value { void* ptr; }; -#define VALUE_AS_DISPATCH(Type, Name) \ - template <> \ - inline Type Value::as() const { \ - if (dtype_ != k##Name) { \ - throw unsupported_dtype(); \ - } \ - return Name##values[0]; \ +#define VALUE_AS_DISPATCH(Type, Name) \ + template <> \ + inline Type Value::as() const { \ + if (dtype_ != k##Name) { \ + throw unsupported_dtype(); \ + } \ + return Name##values[0]; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH); #undef VALUE_AS_DISPATCH -#define VALUE_AS_VEC_DISPATCH(Type, Name) \ - template <> \ - inline const std::vector& Value::as_vec() const { \ - if (dtype_.scalar_type() != ScalarType::Name) { \ - throw unsupported_dtype(); \ - } \ - return Name##values; \ +#define VALUE_AS_VEC_DISPATCH(Type, Name) \ + template <> \ + inline const std::vector& Value::as_vec() const { \ + if (dtype_.scalar_type() != ScalarType::Name) { \ + throw unsupported_dtype(); \ + } \ + return Name##values; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH); #undef VALUE_AS_VEC_DISPATCH @@ -479,6 +479,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor { throw malformed_input(v); } + if (src_dtype != dst_dtype) { switch (src_dtype.scalar_type()) { #define SRC_TYPE_CASE(Type, Name) \ @@ -911,28 +912,6 @@ inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) { return stmt->accept_mutator(&var_sub); } -// Uses the evaluator to fold an Expression with constant terms. -// E.g. evaluateOp(Add(3, 4)) => 7. -// Expr v must not have any unbound Vars. -static Expr* evaluateOp(const Expr* v) { - ExprHandle handle(v); - ExprEval eval(handle); - - switch (v->dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: { \ - Type val = eval.value(); \ - return getImmediateByType(v->dtype().scalar_type(), val); \ - } - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); -#undef TYPE_CASE - default: - LOG(FATAL) << "Unsupported datatype: " << v->dtype(); - return nullptr; - } - return nullptr; -} - } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 00ffdaca6a74..ccda96cc1d81 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -31,8 +31,6 @@ enum IRNodeType { kCompareSelect, kLet, kCast, - kBroadcast, - kRamp, kNone }; diff --git a/torch/csrc/jit/tensorexpr/hash_server.h b/torch/csrc/jit/tensorexpr/hash_server.h index 5f4b7c00c924..7f5b5097489f 100644 --- a/torch/csrc/jit/tensorexpr/hash_server.h +++ b/torch/csrc/jit/tensorexpr/hash_server.h @@ -1,5 +1,3 @@ -#pragma once - #include #include #include @@ -318,13 +316,6 @@ class HashProvider : public IRVisitor { putHash(v, hash); } - template - SimplifierHashType hash_combine(const Types&... args) { - SimplifierHashType seed = 0; - _hash_combine(seed, args...); - return seed; - } - private: SimplifierHashType hashOf(const Expr* e) { auto it = exprToHash_.find(e); @@ -380,10 +371,6 @@ class HashProvider : public IRVisitor { (seed << 7) + (seed >> 4); } - void _hash_combine(SimplifierHashType& seed, const Expr* e) { - _hash_combine(seed, hash(e)); - } - template void _hash_combine( SimplifierHashType& seed, @@ -393,6 +380,13 @@ class HashProvider : public IRVisitor { _hash_combine(seed, args...); } + template + SimplifierHashType hash_combine(const Types&... args) { + SimplifierHashType seed = 0; + _hash_combine(seed, args...); + return seed; + } + void putHash(const KernelScopedObject* e, SimplifierHashType h) { auto res = exprToHash_.emplace(e, h); if (res.second == false) { diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 136339f32e8f..a71c5d30b204 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -283,94 +283,24 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); // Get immediate by ScalarType. template -Expr* getImmediateByType(ScalarType immType, T initialVal) { +ExprHandle getImmediateByType(ScalarType immType, T initialVal) { switch (immType) { #define TYPE_CASE(Type, Name) \ case ScalarType::Name: \ - return new Name##Imm(initialVal); + return Name##Imm::make(initialVal); AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); } - return nullptr; + return ExprHandle(); } template -Expr* getImmediateByType(Dtype dtype, T initialVal) { +ExprHandle getImmediateByType(Dtype dtype, T initialVal) { return getImmediateByType(dtype.scalar_type(), initialVal); } -template -T immediateAs(const Expr* e) { -#define TYPE_CASE(Type, Name) \ - if (const Name##Imm* imm = dynamic_cast(e)) { \ - return imm->value(); \ - } - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); -#undef TYPE_CASE - throw unsupported_dtype(); - return 0; -} - -template -bool immediateEquals(const Expr* e, T val) { -#define TYPE_CASE(Type, Name) \ - if (const Name##Imm* imm = dynamic_cast(e)) { \ - return imm->value() == val; \ - } - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); -#undef TYPE_CASE - throw unsupported_dtype(); - return false; -} - -template -bool immediateIsNegative(const T* e) { -#define TYPE_CASE(Type, Name) \ - if (const Name##Imm* imm = dynamic_cast(e)) { \ - return imm->value() < 0; \ - } - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); -#undef TYPE_CASE - return false; -} - -// Creates a new Expr of the given type with the provided lhs and rhs. -static const Expr* newBinaryOpOfType( - IRNodeType expr_type, - const Expr* lhs, - const Expr* rhs, - bool option) { - switch (expr_type) { - case IRNodeType::kAdd: - return new Add(lhs, rhs); - case IRNodeType::kSub: - return new Sub(lhs, rhs); - case IRNodeType::kMul: - return new Mul(lhs, rhs); - case IRNodeType::kDiv: - return new Div(lhs, rhs); - case IRNodeType::kMod: - return new Mod(lhs, rhs); - case IRNodeType::kMax: - return new Max(lhs, rhs, option); - case IRNodeType::kMin: - return new Min(lhs, rhs, option); - case IRNodeType::kAnd: - return new And(lhs, rhs); - case IRNodeType::kXor: - return new Xor(lhs, rhs); - case IRNodeType::kLshift: - return new Lshift(lhs, rhs); - case IRNodeType::kRshift: - return new Rshift(lhs, rhs); - default: - LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); - return nullptr; - } -} - // Bind the value to the var and evaluate the body. class Let : public ExprNode { public: @@ -424,7 +354,7 @@ class Ramp : public ExprNode { } Ramp(const Expr* base, const Expr* stride, int lanes) - : ExprNodeBase(Dtype(base->dtype(), lanes), kRamp), + : ExprNodeBase(Dtype(base->dtype(), lanes)), base_(base), stride_(stride), lanes_(lanes) { @@ -490,7 +420,7 @@ class Broadcast : public ExprNode { return ExprHandle(new Broadcast(value.node(), lanes)); } Broadcast(const Expr* value, int lanes) - : ExprNodeBase(Dtype(value->dtype(), lanes), kBroadcast), + : ExprNodeBase(Dtype(value->dtype(), lanes)), value_(value), lanes_(lanes) {} @@ -632,7 +562,8 @@ class TORCH_API CompareSelect : public ExprNode { const ExprHandle& ret_val1, const ExprHandle& ret_val2, CompareSelectOperation cmp_op) { - if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) { + if (lhs.dtype() != rhs.dtype() || + ret_val1.dtype() != ret_val2.dtype()) { throw malformed_input(); } return ExprHandle(new CompareSelect( @@ -859,8 +790,48 @@ class Intrinsics : public CallNode { IntrinsicsOp op_type_; }; -class Polynomial; -class Term; +/* An internal only Expr used in IR simplification. + * Encodes relationship y = Ax + B, where A and B are Immediates. + * Not required to be implemented by codegen. */ +class LinearForm : public ExprNode { + public: + LinearForm(const Expr* x, const Expr* A, const Expr* B) + : ExprNodeBase(dtypeFor(x, A, B)), x_(x), A_(A), B_(B) {} + + LinearForm(const Expr* x) + : ExprNodeBase(x->dtype()), + x_(x), + A_(new CharImm(1)), + B_(new CharImm(0)) {} + + const Expr* getX() const { + return x_; + } + const Expr* getA() const { + return A_; + } + const Expr* getB() const { + return B_; + } + + void setA(const Expr* A) { + A_ = A; + } + + void setB(const Expr* B) { + B_ = B; + } + + static Dtype dtypeFor(const Expr* A, const Expr* B, const Expr* C) { + return ToDtype(promoteTypes( + A->dtype().scalar_type(), promoteTypes(B->dtype(), C->dtype()))); + } + + private: + const Expr* x_; + const Expr* A_; + const Expr* B_; +}; class FunctionCall; diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 041e714c170e..716658e33c2c 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -2,7 +2,6 @@ #include #include -#include namespace torch { namespace jit { @@ -215,7 +214,6 @@ const Expr* IRMutator::mutate(const IfThenElse* v) { const Expr* condition_new = condition->accept_mutator(this); const Expr* true_value_new = true_value->accept_mutator(this); const Expr* false_value_new = false_value->accept_mutator(this); - if (condition == condition_new && true_value == true_value_new && false_value == false_value_new) { return v; @@ -234,24 +232,11 @@ const Expr* IRMutator::mutate(const FunctionCall* v) { return this->mutate(base); } -const Expr* IRMutator::mutate(const Term* v) { - const Expr* newScalar = v->scalar()->accept_mutator(this); - - std::vector variables; - for (const auto* t : v->variables()) { - variables.push_back(t->accept_mutator(this)); - } - return new Term(v->hasher(), newScalar, variables); -} - -const Expr* IRMutator::mutate(const Polynomial* v) { - const Expr* newScalar = v->scalar()->accept_mutator(this); - - std::vector variables; - for (const auto* t : v->variables()) { - variables.push_back(static_cast(t->accept_mutator(this))); - } - return new Polynomial(v->hasher(), newScalar, variables); +const Expr* IRMutator::mutate(const LinearForm* v) { + const Expr* new_x = v->getX()->accept_mutator(this); + const Expr* new_a = v->getA()->accept_mutator(this); + const Expr* new_b = v->getB()->accept_mutator(this); + return new LinearForm(new_x, new_a, new_b); } const Expr* IRMutator::mutate(const BaseCallNode* v) { diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index 29e0a1f5dd20..a9aff366ad95 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -45,8 +45,7 @@ class Allocate; class Free; class Cond; class Stmt; -class Term; -class Polynomial; +class LinearForm; class TORCH_API IRMutator { public: @@ -87,8 +86,7 @@ class TORCH_API IRMutator { virtual const Expr* mutate(const Intrinsics* v); virtual const Expr* mutate(const FunctionCall* v); - virtual const Expr* mutate(const Term* v); - virtual const Expr* mutate(const Polynomial* v); + virtual const Expr* mutate(const LinearForm* v); virtual Stmt* mutate(const For* v); virtual Stmt* mutate(const Block* v); diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 7c377ebb1430..1260629c6a88 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -1,7 +1,5 @@ #include -#include - namespace torch { namespace jit { namespace tensorexpr { @@ -200,7 +198,7 @@ template < typename T, std::enable_if_t::value>* = nullptr> static void formatImm(std::ostream& os, T v) { - os << +v; + os << v; } // NOLINTNEXTLINE @@ -370,33 +368,9 @@ void IRPrinter::visit(const Cond* v) { } } -void IRPrinter::visit(const Term* v) { - os() << "Term("; - v->scalar()->accept(this); - for (auto* t : v->variables()) { - os() << ","; - t->accept(this); - } - os() << ")"; -} - -void IRPrinter::visit(const Polynomial* v) { - bool first = true; - os() << "Polynomial("; - for (auto* t : v->variables()) { - emitIndent(); - if (!first) { - os() << " + "; - } - first = false; - t->accept(this); - } - - if (!first) { - os() << " + "; - } - v->scalar()->accept(this); - os() << ")"; +void IRPrinter::visit(const LinearForm* v) { + os() << "(" << *v->getA() << ") * (" << *v->getX() << ") + (" << *v->getB() + << ")" << std::endl; } void IRPrinter::emitIndent() { diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h index f0af7b6dc45f..82ccf086258b 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.h +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -48,8 +48,7 @@ class TORCH_API IRPrinter : public IRVisitor { void visit(const Allocate* v) override; void visit(const Free* v) override; void visit(const Cond* v) override; - void visit(const Term* v) override; - void visit(const Polynomial* v) override; + void visit(const LinearForm* v) override; std::ostream& os() { return printer_os_; diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp deleted file mode 100644 index aa308204aa53..000000000000 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ /dev/null @@ -1,1086 +0,0 @@ -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -SimplifierHashType Term::hashVars() const { - SimplifierHashType hash = 0; - for (auto* v : variables_) { - hash = hasher_.hash_combine(hash, hasher_.hash(v)); - } - - return hash; -} - -void Term::sort() { - // order of ops important for float - if (dtype().is_floating_point()) { - throw std::logic_error("reordering FP ops"); - } - std::sort( - variables_.begin(), variables_.end(), [&](const Expr* a, const Expr* b) { - return hasher_.hash(a) < hasher_.hash(b); - }); -} - -SimplifierHashType Polynomial::hashVars() const { - SimplifierHashType hash = 0; - for (auto* v : variables_) { - hash = hasher_.hash_combine(hash, hasher_.hash(v)); - } - return hash; -} - -void Polynomial::sort() { - if (dtype().is_floating_point()) { - throw std::logic_error("reordering FP ops"); - } - std::sort( - variables_.begin(), variables_.end(), [&](const Expr* a, const Expr* b) { - return hasher_.hash(a) < hasher_.hash(b); - }); -} - -// Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp -template -const Expr* combineMultilane(const Expr* lhs, const Expr* rhs) { - if (const Broadcast* bc = dynamic_cast(lhs)) { - if (const Broadcast* bcother = dynamic_cast(rhs)) { - if (bc->lanes() != bcother->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = - new Broadcast(new Op(bc->value(), bcother->value()), bc->lanes()); - return ret; - } - - if (const Ramp* r = dynamic_cast(rhs)) { - if (bc->lanes() != r->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = - new Ramp(new Op(bc->value(), r->base()), r->stride(), r->lanes()); - return ret; - } - } else if (const Ramp* ramp = dynamic_cast(lhs)) { - if (const Ramp* rother = dynamic_cast(rhs)) { - if (ramp->lanes() != rother->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = new Ramp( - new Op(ramp->base(), rother->base()), - new Op(ramp->stride(), rother->stride()), - ramp->lanes()); - return ret; - } - - if (const Broadcast* bc = dynamic_cast(rhs)) { - if (ramp->lanes() != bc->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - const Expr* ret = new Ramp( - new Op(bc->value(), ramp->base()), ramp->stride(), ramp->lanes()); - return ret; - } - } - - return nullptr; -} - -// Handles optimization cases for Broadcast/Ramp * Broadcast/Ramp -const Expr* mulMultilane(const Expr* lhs, const Expr* rhs) { - if (const Broadcast* bc = dynamic_cast(lhs)) { - if (const Broadcast* bcother = dynamic_cast(rhs)) { - if (bc->lanes() != bcother->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = - new Broadcast(new Mul(bc->value(), bcother->value()), bc->lanes()); - return ret; - } - - if (const Ramp* r = dynamic_cast(rhs)) { - if (bc->lanes() != r->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = new Ramp( - new Mul(bc->value(), r->base()), - new Mul(bc->value(), r->stride()), - r->lanes()); - return ret; - } - } else if (const Ramp* ramp = dynamic_cast(lhs)) { - if (const Ramp* r = dynamic_cast(rhs)) { - if (ramp->lanes() != r->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = new Ramp( - new Mul(ramp->base(), r->base()), - new Mul(ramp->stride(), r->stride()), - r->lanes()); - return ret; - } - - if (const Broadcast* bc = dynamic_cast(rhs)) { - if (ramp->lanes() != bc->lanes()) { - throw malformed_input("multilane lane mismatch"); - } - - const Expr* ret = new Ramp( - new Mul(bc->value(), ramp->base()), - new Mul(bc->value(), ramp->stride()), - ramp->lanes()); - return ret; - } - } - - return nullptr; -} - -void PolynomialTransformer::addOrUpdateTerm( - std::unordered_map& varmap, - const Term* term) { - SimplifierHashType hash = term->hashVars(); - auto insertRes = varmap.emplace(hash, term); - if (insertRes.second == false) { - const Term* lt = insertRes.first->second; - const Expr* termScalar = evaluateOp(new Add(lt->scalar(), term->scalar())); - - // If the term is canceled out, remove from the map. - if (immediateEquals(termScalar, 0)) { - varmap.erase(hash); - return; - } - - varmap[hash] = new Term(hasher_, termScalar, lt->variables()); - } -} - -const Expr* PolynomialTransformer::addPolynomials( - const Polynomial* lhs, - const Polynomial* rhs) { - // simplify common components - // The key here is the variable hash, not the term's hash since we do want - // to combine terms that have the same vars but different scalar components. - std::unordered_map varmap; - - for (auto* lt : lhs->variables()) { - addOrUpdateTerm(varmap, lt); - } - for (auto* rt : rhs->variables()) { - addOrUpdateTerm(varmap, rt); - } - - const Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar())); - return new Polynomial(hasher_, newScalar, varmap); -} - -// Insert a new Term into the provided polynomial. If the new term has common -// variables to an existing term it is combined. -const Expr* PolynomialTransformer::insertTerm( - const Polynomial* poly, - const Term* term) { - SimplifierHashType tHash = term->hashVars(); - std::vector newVars; - - bool found = false; - for (auto* v : poly->variables()) { - if (v->hashVars() == tHash) { - const Expr* newScalar = evaluateOp(new Add(term->scalar(), v->scalar())); - found = true; - // Skip this term if we cancelled it out. - if (immediateEquals(newScalar, 0)) { - continue; - } - auto* term = new Term(hasher_, newScalar, v->variables()); - newVars.push_back(term); - } else { - newVars.push_back(v); - } - } - - if (!found) { - newVars.push_back(term); - } - - if (newVars.empty()) { - return poly->scalar(); - } - - auto* Poly = new Polynomial(hasher_, poly->scalar(), newVars); - return Poly; -} - -const Expr* PolynomialTransformer::mutate(const Add* v) { - const Expr* lhs_new = v->lhs()->accept_mutator(this); - const Expr* rhs_new = v->rhs()->accept_mutator(this); - - // Constant Folding. - if (lhs_new->isConstant() && rhs_new->isConstant()) { - const Expr* result = evaluateOp(new Add(lhs_new, rhs_new)); - return result; - } - - // Multilane folding. - if (isMultilanePrimitive(lhs_new)) { - if (auto* ret = combineMultilane(lhs_new, rhs_new)) { - return ret->accept_mutator(this); - } - } - - // If this is a floating point Add then order of operations is important, we - // dont want to combine ops. - if (lhs_new->dtype().is_floating_point() || - rhs_new->dtype().is_floating_point()) { - return new Add(lhs_new, rhs_new); - } - - const Polynomial* lhsPoly = dynamic_cast(lhs_new); - const Polynomial* rhsPoly = dynamic_cast(rhs_new); - - if (lhsPoly && rhsPoly) { - return addPolynomials(lhsPoly, rhsPoly); - } - - const Term* lhsTerm = dynamic_cast(lhs_new); - const Term* rhsTerm = dynamic_cast(rhs_new); - - if (lhsPoly && rhsTerm) { - return insertTerm(lhsPoly, rhsTerm); - } - - if (rhsPoly && lhsTerm) { - return insertTerm(rhsPoly, lhsTerm); - } - - if (lhsTerm && rhsTerm) { - // If the terms refer to the same variables: combine them. - if (lhsTerm->hashVars() == rhsTerm->hashVars()) { - const Expr* newScalar = - evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())); - - // If the terms cancelled out, return zero. - if (immediateEquals(newScalar, 0)) { - return newScalar->accept_mutator(this); - } - - return new Term(hasher_, newScalar, lhsTerm->variables()); - } - - // Otherwise this is a new polynomial with no scalar and two variable - // terms. - return new Polynomial( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); - } - - const Expr* scalar = nullptr; - const Expr* variable = nullptr; - if (lhs_new->isConstant()) { - scalar = evaluateOp(lhs_new); - variable = rhs_new; - } else if (rhs_new->isConstant()) { - scalar = evaluateOp(rhs_new); - variable = lhs_new; - } - - // If there is a scalar, and it's zero: short circuit and return the other - // side. - if (scalar && immediateEquals(scalar, 0)) { - return variable; - } - - // Adds are commutative. - const Polynomial* poly = lhsPoly ? lhsPoly : rhsPoly; - - // Add to Polynomial->scalar(). - if (scalar && poly) { - const Expr* newScalar = evaluateOp(new Add(scalar, poly->scalar())); - return new Polynomial(hasher_, newScalar, poly->variables()); - } - - // Simple Polynomial with a scalar and Term. - const Term* term = lhsTerm ? lhsTerm : rhsTerm; - if (scalar && term) { - return new Polynomial(hasher_, scalar, term); - } - - // Simple Term with a scalar and variable type. - if (scalar) { - return new Polynomial( - hasher_, - scalar, - new Term(hasher_, getImmediateByType(v->dtype(), 1), variable)); - } - - // If LHS is neither Term not Polynomial, wrap it in a Term. - if (!lhsTerm && !lhsPoly) { - lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); - } - - // Same for RHS. - if (!rhsTerm && !rhsPoly) { - rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), rhs_new); - } - - // If we now have a poly and a term, we can insert. - if (poly) { - return insertTerm(poly, lhsTerm ? lhsTerm : rhsTerm); - } - - // If all else fails we have a new Polynomial with two new variable Terms. - return new Polynomial( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); -} - -const Expr* PolynomialTransformer::subTerms( - const Term* lhs, - const Term* rhs, - bool negated) { - // If RHS not already negated, negate it. - if (!negated) { - const Expr* minusOne = getImmediateByType(rhs->dtype(), -1); - const Expr* negateScalar = evaluateOp(new Mul(minusOne, rhs->scalar())); - rhs = new Term(hasher_, negateScalar, rhs->variables()); - } - - if (lhs->hashVars() == rhs->hashVars()) { - const Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar())); - - // If the terms cancel out, return zero. - if (immediateEquals(newScalar, 0)) { - return newScalar; - } - - return new Term(hasher_, newScalar, lhs->variables()); - } - - return new Polynomial( - hasher_, - getImmediateByType(promoteTypes(lhs->dtype(), rhs->dtype()), 0), - lhs, - rhs); -} - -// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where -// possible. -const Expr* PolynomialTransformer::subPolynomials( - const Polynomial* lhs, - const Polynomial* rhs) { - // simplify common components - // The key here is the variable hash, not the term's hash since we do want - // to combine terms that have the same vars but different scalar components. - std::unordered_map varmap; - - for (auto* lt : lhs->variables()) { - addOrUpdateTerm(varmap, lt); - } - - for (auto* rt : rhs->variables()) { - // Polynomials add their terms, so negate the RHS's Terms. - const Expr* negated = - evaluateOp(new Mul(getImmediateByType(rt->dtype(), -1), rt->scalar())); - Term* newRHS = new Term(hasher_, negated, rt->variables()); - addOrUpdateTerm(varmap, newRHS); - } - - const Expr* newScalar = evaluateOp(new Sub(lhs->scalar(), rhs->scalar())); - - // No vars means this cancelled out to a scalar, return it unwrapped. - if (varmap.empty()) { - return newScalar; - } - - // If there is no scalar and zero or one terms, don't wrap. - if (immediateEquals(newScalar, 0)) { - if (varmap.empty()) { - return nullptr; - } - if (varmap.size() == 1) { - return varmap.begin()->second; - } - } - - // Wrap new variables in a Polynomial. - return new Polynomial(hasher_, newScalar, varmap); -} - -const Expr* PolynomialTransformer::mutate(const Sub* v) { - const Expr* lhs_new = v->lhs()->accept_mutator(this); - const Expr* rhs_new = v->rhs()->accept_mutator(this); - - // Constant Folding. - if (lhs_new->isConstant() && rhs_new->isConstant()) { - const Expr* result = evaluateOp(new Sub(lhs_new, rhs_new)); - return result; - } - - // Multilane folding. - if (isMultilanePrimitive(lhs_new)) { - if (auto* ret = combineMultilane(lhs_new, rhs_new)) { - return ret->accept_mutator(this); - } - } - - // If this is a floating point Sub then order of operations is important, we - // dont want to combine ops. - if (lhs_new->dtype().is_floating_point() || - rhs_new->dtype().is_floating_point()) { - return new Sub(lhs_new, rhs_new); - } - - const Polynomial* lhsPoly = dynamic_cast(lhs_new); - const Polynomial* rhsPoly = dynamic_cast(rhs_new); - - if (lhsPoly && rhsPoly) { - auto* ret = subPolynomials(lhsPoly, rhsPoly); - if (!ret) { - // Cancelled out completely. - return getImmediateByType(v->dtype(), 0); - } - return ret; - } - - const Term* lhsTerm = dynamic_cast(lhs_new); - const Term* rhsTerm = dynamic_cast(rhs_new); - - // Polynomial - Term. - if (lhsPoly && rhsTerm) { - // Negate the term. - const Expr* negate = evaluateOp( - new Mul(getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar())); - const Term* newTerm = new Term(hasher_, negate, rhsTerm->variables()); - return insertTerm(lhsPoly, newTerm); - } - - // Term - Polynomial. - if (rhsPoly && lhsTerm) { - // Negate every part of the Polynomial. - const Expr* minusOne = getImmediateByType(lhsTerm->dtype(), -1); - const Expr* negateScalar = evaluateOp(new Mul(minusOne, lhsTerm->scalar())); - - std::vector variables; - for (auto* t : lhsPoly->variables()) { - const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); - variables.push_back(new Term(hasher_, negate, t->variables())); - } - - Polynomial* newPoly = new Polynomial(hasher_, negateScalar, variables); - return insertTerm(newPoly, lhsTerm); - } - - if (lhsTerm && rhsTerm) { - return subTerms(lhsTerm, rhsTerm, false); - } - - bool lhsScalar = lhs_new->isConstant(); - bool rhsScalar = rhs_new->isConstant(); - - if (lhsPoly && rhsScalar) { - // Easy path, just sub the scalar component. - const Expr* newScalar = evaluateOp(new Sub(lhsPoly->scalar(), rhs_new)); - return new Polynomial(hasher_, newScalar, lhsPoly->variables()); - } - - if (lhsScalar && rhsPoly) { - // Sub the scalar component. - const Expr* newScalar = evaluateOp(new Sub(lhs_new, rhsPoly->scalar())); - - // Negate each term in the Polynomial RHS. - const Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1); - std::vector variables; - for (auto* t : rhsPoly->variables()) { - const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); - variables.push_back(new Term(hasher_, negate, t->variables())); - } - - return new Polynomial(hasher_, newScalar, variables); - } - - if (lhsTerm && rhsScalar) { - // Negate the constant. - const Expr* negate = - evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); - return new Polynomial(hasher_, negate, lhsTerm); - } - - if (lhsScalar && rhsTerm) { - // Negate the RHS Term. - const Expr* negate = evaluateOp(new Mul( - getImmediateByType(rhsTerm->scalar()->dtype(), -1), rhsTerm->scalar())); - - return new Polynomial( - hasher_, lhs_new, new Term(hasher_, negate, rhsTerm->variables())); - } - - // simple term with a scalar and variable type. - if (lhsScalar) { - // Create a negated term. - return new Polynomial( - hasher_, - lhs_new, - new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new)); - } - - if (rhsScalar) { - // Negate the scalar. - const Expr* negate = - evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); - return new Polynomial( - hasher_, - negate, - new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new)); - } - - // no scalar... - if (!lhsTerm && !lhsPoly) { - lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); - } - - bool createdRHSnegated = false; - if (!rhsTerm && !rhsPoly) { - rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new); - createdRHSnegated = true; - } - - if (lhsTerm && rhsTerm) { - return subTerms(lhsTerm, rhsTerm, createdRHSnegated); - } - - // Insert wrapped Term into LHS Polynomial. - if (lhsPoly) { - CHECK(rhsTerm); - return insertTerm(lhsPoly, rhsTerm); - } - - // Insert wrapper Term into negated RHS Poly. - if (rhsPoly) { - CHECK(lhsTerm); - const Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1); - const Expr* newScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar())); - - // Negate each term in the Polynomial RHS. - std::vector variables; - for (auto* t : rhsPoly->variables()) { - const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); - variables.push_back(new Term(hasher_, negate, t->variables())); - } - - auto* poly = new Polynomial(hasher_, newScalar, variables); - return insertTerm(poly, lhsTerm); - } - - return new Polynomial( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); -} - -// Multiply two terms together, usually creating a new term with the variable -// lists concatenated. -const Term* PolynomialTransformer::mulTerms(const Term* lhs, const Term* rhs) { - const Expr* scalar = evaluateOp(new Mul(lhs->scalar(), rhs->scalar())); - if (immediateEquals(scalar, 0)) { - return nullptr; - } - - // Can reorder here since floating point ops don't get put into Terms. - std::vector variables; - std::vector multilaneVariables; - // For now don't handle exponents. - for (auto* c : lhs->variables()) { - if (isMultilanePrimitive(c)) { - multilaneVariables.push_back(c); - } else { - variables.push_back(c); - } - } - for (auto* c : rhs->variables()) { - if (isMultilanePrimitive(c)) { - multilaneVariables.push_back(c); - } else { - variables.push_back(c); - } - } - - // Merge all the multilane vars: - const Expr* lastNode{nullptr}; - for (auto* node : multilaneVariables) { - if (lastNode == nullptr) { - lastNode = node; - } else { - if (auto* next = mulMultilane(lastNode, node)) { - lastNode = next->accept_mutator(this); - } else { - variables.push_back(lastNode); - lastNode = node; - } - } - } - if (lastNode) { - variables.push_back(lastNode); - } - - return new Term(hasher_, scalar, variables); -} - -// Multiply a Polynomial by a Term. -const Expr* PolynomialTransformer::polyByTerm( - const Polynomial* poly, - const Term* term) { - std::vector newTerms; - - // scalar Term - const Expr* scalar = evaluateOp(new Mul(poly->scalar(), term->scalar())); - - for (auto* var : poly->variables()) { - const Term* newTerm = mulTerms(var, term); - if (newTerm) { - newTerms.push_back(newTerm); - } - } - - if (newTerms.empty()) { - return scalar; - } - - return new Polynomial(hasher_, scalar, std::move(newTerms)); -} - -const Expr* PolynomialTransformer::mutate(const Mul* v) { - const Expr* lhs_new = v->lhs()->accept_mutator(this); - const Expr* rhs_new = v->rhs()->accept_mutator(this); - - // Constant Folding. - if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(new Mul(lhs_new, rhs_new)); - } - - // Multilane folding. - if (isMultilanePrimitive(lhs_new)) { - if (auto* ret = mulMultilane(lhs_new, rhs_new)) { - return ret->accept_mutator(this); - } - } - - // If this is a floating point Mul then order of operations is important, we - // dont want to combine ops. - if (lhs_new->dtype().is_floating_point() || - rhs_new->dtype().is_floating_point()) { - return new Mul(lhs_new, rhs_new); - } - - const Polynomial* lhsPoly = dynamic_cast(lhs_new); - const Polynomial* rhsPoly = dynamic_cast(rhs_new); - - if (lhsPoly && rhsPoly) { - // This expands to more terms that we can't generally fix without variable - // factorization, it's more efficient to just leave these as Muls. - return new Mul(lhsPoly, rhsPoly); - } - - const Term* lhsTerm = dynamic_cast(lhs_new); - const Term* rhsTerm = dynamic_cast(rhs_new); - - if (lhsPoly && rhsTerm) { - return polyByTerm(lhsPoly, rhsTerm); - } - - if (rhsPoly && lhsTerm) { - return polyByTerm(rhsPoly, lhsTerm); - } - - if (lhsTerm && rhsTerm) { - return mulTerms(lhsTerm, rhsTerm); - } - - const Expr* scalar = nullptr; - const Expr* variable = nullptr; - if (lhs_new->isConstant()) { - scalar = lhs_new; - variable = rhs_new; - } else if (rhs_new->isConstant()) { - scalar = rhs_new; - variable = lhs_new; - } - - if (scalar && lhsTerm) { - const Expr* newScalar = evaluateOp(new Mul(scalar, lhsTerm->scalar())); - if (immediateEquals(newScalar, 0)) { - return newScalar; - } - return new Term(hasher_, newScalar, lhsTerm->variables()); - } - - if (scalar && rhsTerm) { - const Expr* newScalar = evaluateOp(new Mul(scalar, rhsTerm->scalar())); - - if (immediateEquals(newScalar, 0)) { - return newScalar; - } - return new Term(hasher_, newScalar, rhsTerm->variables()); - } - - // If this is a scalar * a Polynomial, push the scalar term down. - // We can wrap the scalar with a Term and use polyByTerm. - if (scalar && lhsPoly) { - return polyByTerm(lhsPoly, new Term(hasher_, scalar)); - } - if (scalar && rhsPoly) { - return polyByTerm(rhsPoly, new Term(hasher_, scalar)); - } - - // simple term with a scalar and variable type. - if (scalar) { - return new Term(hasher_, scalar, variable); - } - - // Multiplying Polynomial by variable can be wrapped in a term and handled - // by polyByTerm also. - if (lhsPoly) { - auto* term = - new Term(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new); - return polyByTerm(lhsPoly, term); - } - if (rhsPoly) { - auto* term = - new Term(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new); - return polyByTerm(rhsPoly, term); - } - - // Multiplying Term by a variable is equivalent to adding the variable to - // the term's list of vars. - if (lhsTerm) { - std::vector vars = lhsTerm->variables(); - vars.push_back(rhs_new); - return new Term(hasher_, lhsTerm->scalar(), vars); - } - if (rhsTerm) { - std::vector vars = rhsTerm->variables(); - vars.push_back(lhs_new); - return new Term(hasher_, rhsTerm->scalar(), vars); - } - - // Two variables, create a new Term. - return new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new); -} - -const Expr* PolynomialTransformer::mutate(const Intrinsics* v) { - std::vector new_params; - bool changed = false; - bool allConstant = true; - for (const auto* p : v->params()) { - const Expr* new_child = p->accept_mutator(this); - new_params.push_back(new_child); - - changed |= p != new_child; - allConstant &= new_child->isConstant(); - } - - const Expr* node = v; - if (changed) { - node = new Intrinsics(v->op_type(), new_params); - } - - if (!allConstant || !v->isPure()) { - return node; - } - - // we're evaluating, but the evaluator only supports float intrinsics. - std::vector const_params; - changed = false; - for (const auto* p : new_params) { - if (p->dtype().scalar_type() == ScalarType::Float) { - const_params.push_back(p); - } else { - const_params.push_back( - new Cast(Dtype(ScalarType::Float, p->dtype().lanes()), p)); - changed = true; - } - } - - if (changed) { - node = new Intrinsics(v->op_type(), const_params); - } - return evaluateOp(node); -} - -const Expr* PolynomialTransformer::mutate(const Cast* v) { - const Expr* node = v->src_value()->accept_mutator(this); - if (node->isConstant()) { - return evaluateOp(new Cast(v->dtype(), node)); - } - - return new Cast(v->dtype(), node); -} - -// TermExpander - -const Expr* TermExpander::mutate(const Term* v) { - const Expr* newScalar = v->scalar()->accept_mutator(this); - if (immediateEquals(newScalar, 0)) { - return newScalar; - } - - std::vector vars; - std::vector multilaneVars; - - // Assume we can reorder here because we wont merge floating terms. - const Expr* lastNode{nullptr}; - for (auto* var : v->variables()) { - const Expr* node = var->accept_mutator(this); - if (const Mul* mul = dynamic_cast(node)) { - // If the sub-Expr resolved to a multiplication, lift it into this - // term. - if (isMultilanePrimitive(mul->lhs())) { - multilaneVars.push_back(mul->lhs()); - } else { - vars.push_back(mul->lhs()); - } - - if (isMultilanePrimitive(mul->rhs())) { - multilaneVars.push_back(mul->rhs()); - } else { - vars.push_back(mul->lhs()); - } - } else { - if (isMultilanePrimitive(node)) { - multilaneVars.push_back(node); - } else { - vars.push_back(node); - } - } - } - - for (auto* node : multilaneVars) { - if (lastNode == nullptr) { - lastNode = node; - } else { - lastNode = mulMultilane(lastNode, node); - // simplify first, then re-expand. - lastNode = lastNode->accept_mutator(simplifier_); - lastNode = lastNode->accept_mutator(this); - } - } - - for (auto* node : vars) { - if (lastNode == nullptr) { - lastNode = node; - } else { - lastNode = new Mul(lastNode, node); - } - } - - if (!immediateEquals(newScalar, 1)) { - if (lastNode) { - // We want to avoid a leaving a CastNode on the scalar, so handle that - // now. - if (v->scalar()->dtype() != lastNode->dtype()) { - lastNode = new Mul( - evaluateOp(new Cast(lastNode->dtype(), v->scalar())), lastNode); - } else { - lastNode = new Mul(v->scalar(), lastNode); - } - } else { - lastNode = v->scalar(); - } - } - - return lastNode; -} - -// Simple recursive GCD. -template -T gcd(T a, T b) { - if (b == 0) { - return a; - } - return gcd(b, a % b); -} - -// Returns an immediate containing the greatest common divisor of all terms -// (inc. the scalar term) in the polynomial. If the GCD is uninteresting -// (e.g. 1) then returns nullptr. -const Expr* polyGCD(const Polynomial* poly) { - const Expr* scalar = poly->scalar(); - const std::vector& variables = poly->variables(); - - // We ony want to factorize if we're saving complete operations, i.e. no - // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work. - int opsSaved = 1; // default to saving the scalar. - long GCD = immediateAs(scalar); - for (auto* t : variables) { - long termScalar = immediateAs(t->scalar()); - long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar)); - if (newGCD == 1) { - return nullptr; - } - - if (GCD != newGCD) { - opsSaved = 0; - GCD = newGCD; - } - - if (GCD == termScalar) { - opsSaved++; - } - } - - if (opsSaved == 0) { - return nullptr; - } - - // Not worth, can be a Sub. - if (GCD == -1 && opsSaved == 1) { - return nullptr; - } - - return getImmediateByType(poly->dtype(), GCD); -} - -// Trivially factorize terms by GCD of scalar components. -const Expr* TermExpander::factorizePolynomial(const Polynomial* poly) { - const Expr* scalar = poly->scalar(); - const std::vector& variables = poly->variables(); - bool floatScalars = false; - - // Check types. - for (auto& p : variables) { - if (is_floating_point(p->dtype().scalar_type()) || - is_floating_point(p->scalar()->dtype().scalar_type())) { - floatScalars = true; - } - } - if (is_floating_point(scalar->dtype().scalar_type())) { - floatScalars = true; - } - - // floating point isn't generally distributive. - if (floatScalars) { - return nullptr; - } - - // Compute the GCD of terms. - const Expr* GCD = polyGCD(poly); - - // No GCD means 0 or 1 and can't be factored. - if (!GCD) { - return nullptr; - } - - // Create new struture. - std::vector newPolyTerms; - for (auto* t : variables) { - // New term with the scalar divided by the GCD. - newPolyTerms.push_back(new Term( - poly->hasher(), evaluateOp(new Div(t->scalar(), GCD)), t->variables())); - } - - Polynomial* newPoly = new Polynomial( - poly->hasher(), evaluateOp(new Div(scalar, GCD)), newPolyTerms); - - return new Term(poly->hasher(), GCD, newPoly); -} - -const Expr* TermExpander::mutate(const Polynomial* v) { - if (v->variables().empty()) { - return v->scalar(); - } - - // If this Polynomial can be factorized: do it, then expand the result. - if (const Expr* factorized = factorizePolynomial(v)) { - return factorized->accept_mutator(this); - } - - std::vector addTerms; - std::vector subTerms; - - // partition the terms into a list to add and list to subtract. - for (auto* node : v->variables()) { - if (immediateIsNegative(node->scalar())) { - subTerms.push_back(node); - } else if (!immediateEquals(node->scalar(), 0)) { - addTerms.push_back(node); - } - // Skip terms with a scalar of zero. - } - - // The last node constructed. - const Expr* lastNode{nullptr}; - - for (auto* node : addTerms) { - const Expr* simpleNode = node->accept_mutator(this); - - if (lastNode == nullptr) { - lastNode = simpleNode; - continue; - } - - if (isMultilanePrimitive(simpleNode)) { - auto* ret = combineMultilane(lastNode, simpleNode); - if (ret) { - // simplify result first, then expand. - lastNode = ret->accept_mutator(simplifier_); - lastNode = lastNode->accept_mutator(this); - continue; - } - } - - lastNode = new Add(lastNode, simpleNode); - } - - // If we have no add terms the scalar should go first. - // E.g. 1 - x. - bool scalarWritten = false; - if (lastNode == nullptr) { - auto* scalarNode = v->scalar()->accept_mutator(simplifier_); - - if (!immediateEquals(scalarNode, 0)) { - lastNode = scalarNode; - scalarWritten = true; - } - } - - for (auto* node : subTerms) { - // Can still be first node if scalarVal is 0. - if (lastNode == nullptr) { - lastNode = node->accept_mutator(this); - continue; - } - - // Negate the term back to positive since we'll be subtracting it. - const Expr* negated = evaluateOp(new Mul( - getImmediateByType(node->scalar()->dtype(), -1), node->scalar())); - Term* newRHS = new Term(node->hasher(), negated, node->variables()); - lastNode = new Sub(lastNode, newRHS->accept_mutator(this)); - } - - if (scalarWritten || immediateEquals(v->scalar(), 0)) { - return lastNode; - } - - if (immediateIsNegative(v->scalar())) { - // Negate the scalar and subtract. - const Expr* negated = evaluateOp( - new Mul(getImmediateByType(lastNode->dtype(), -1), v->scalar())); - lastNode = new Sub(lastNode, evaluateOp(negated)); - } else { - // we want to avoid a cast to the scalar if it would happen. - if (v->scalar()->dtype() != lastNode->dtype()) { - lastNode = new Add( - lastNode, evaluateOp(new Cast(lastNode->dtype(), v->scalar()))); - } else { - lastNode = new Add(lastNode, v->scalar()); - } - } - - return lastNode; -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h index 5364db846e27..1fa1757161de 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.h +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h @@ -1,275 +1,367 @@ #pragma once -#include -#include -#include -#include -#include -#include - -/* IR Simplification - * - * Simplfies expressions in two stages: - * 1. Recursively traverse the map combining similar operations into Terms - * (interacted via Multiplication) and Polynomials (interacted via Addition). We - * reorder the components of each Term or Polynomial into a consistent order to - * allow combination or cancelling of like terms. - * 2. Once the format of the tree is minimal, expand each Term into a sequence - * of Muls, and each Polynomial into a sequence of Ads. - */ +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir_mutator.h" +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" +#include "torch/csrc/jit/tensorexpr/types.h" namespace torch { namespace jit { namespace tensorexpr { -// A bunch of helpers for determine the Dtype of the output of a multi argument -// Term or Polynomial. -namespace { -template -Dtype promoteTypesVec(const Expr* s, std::vector& v) { - Dtype t = s->dtype(); - bool first = true; +// Uses the evaluator to fold an operation with constant terms. +// Expr v must be evaluatable without Vars. +static Expr* evaluateOp(const Expr* v) { + ExprHandle handle(v); + ExprEval eval(handle); - for (auto* e : v) { - if (first) { - t = Dtype(t.scalar_type(), e->dtype().lanes()); - first = false; + switch (v->dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + Type val = eval.value(); \ + return getImmediateByType(v->dtype().scalar_type(), val).node(); \ + } + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unsupported datatype: " << v->dtype(); + return nullptr; + } + return nullptr; +} // namespace tensorexpr + +static const Expr* newBinaryOpOfType( + IRNodeType expr_type, + const Expr* lhs, + const Expr* rhs, + bool option) { + switch (expr_type) { + case IRNodeType::kAdd: + return new Add(lhs, rhs); + case IRNodeType::kSub: + return new Sub(lhs, rhs); + case IRNodeType::kMul: + return new Mul(lhs, rhs); + case IRNodeType::kDiv: + return new Div(lhs, rhs); + case IRNodeType::kMod: + return new Mod(lhs, rhs); + case IRNodeType::kMax: + return new Max(lhs, rhs, option); + case IRNodeType::kMin: + return new Min(lhs, rhs, option); + case IRNodeType::kAnd: + return new And(lhs, rhs); + case IRNodeType::kXor: + return new Xor(lhs, rhs); + case IRNodeType::kLshift: + return new Lshift(lhs, rhs); + case IRNodeType::kRshift: + return new Rshift(lhs, rhs); + default: + LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); + return nullptr; + } +} + +/* Interprets expr as an Immediate and returns the value as type T. */ +template +T immediateAs(const Expr* expr) { + T val{0}; + switch (expr->dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + if (const Name##Imm* imm = dynamic_cast(expr)) { \ + val = imm->value(); \ + } else { \ + LOG(FATAL) << "Bad expr: " << *expr << "\n"; \ + } \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unsupported datatype: " << expr->dtype(); + } + + return val; +} + +/* Takes a LinearForm and converts it to Mul + (Add/Sub). */ +const Expr* expandLinearForm(const LinearForm* v, IRMutator* mutator) { + const Expr* mul = nullptr; + const Expr* A = v->getA(); + const Expr* B = v->getB(); + const Expr* X = v->getX(); + // we only really care about 0 and 1, so double should be fine. + double Aval = immediateAs(A); + double Bval = immediateAs(B); + + // First handle A. + if (Aval == 0) { + if (Bval == 0) { + return getImmediateByType(X->dtype(), 0).node(); } - t = promoteTypes(t, e->dtype()); - } - return t; -} + return B; + } else if (Aval == 1) { + mul = X; + } else if (Aval == -1) { + return new Sub(B, X); + } else if (Aval < 0) { + // Negate A. + ExprHandle zero = getImmediateByType(A->dtype(), 0); + Sub* A_Sub = new Sub(zero.node(), A); -template -Dtype promoteTypesVec(std::vector& v) { - if (v.empty()) { - throw malformed_input("empty list of types"); + return new Sub(B, new Mul(X, evaluateOp(A_Sub))); + } else { + mul = new Mul(X, A); } - Dtype t = v[0]->dtype(); - for (auto* e : v) { - t = promoteTypes(t, e->dtype()); - } - return t; -} - -template -Dtype promoteTypesMap( - const Expr* s, - std::unordered_map& m) { - Dtype t = s->dtype(); - bool first = true; - for (auto& e : m) { - if (first) { - t = Dtype(t.scalar_type(), e.second->dtype().lanes()); - first = false; - } - t = promoteTypes(t, e.second->dtype()); - } - return t; -} - -template -Dtype promoteTypesVar(const ExprType* e) { - return e->dtype(); -} - -template -Dtype promoteTypesVar(const ExprType* e, Args... es) { - Dtype lhs = e->dtype(); - Dtype rhs = promoteTypesVar(es...); - if (e->isConstant()) { - lhs = Dtype(lhs.scalar_type(), rhs.lanes()); + if (Bval == 0) { + return mul; } - return promoteTypes(lhs, rhs); + return new Add(mul, B); } -// Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast -// or Ramp). -bool isMultilanePrimitive(const Expr* e) { - return e->expr_type() == IRNodeType::kBroadcast || - e->expr_type() == IRNodeType::kRamp; -} -} // namespace - -// A Term represents a grouping of Exprs through multiplication. -// E.g. product(scalar, *variables). -class Term : public ExprNode { +/* Expand any remaining LinearTerms into their component pieces */ +class LinearFormExpander : public IRMutator { public: - template - Term(HashProvider& hasher, const Expr* s, Args... ts) - : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) { - CHECK(s->isConstant()); - addComponent(ts...); - sort(); + const Expr* mutate(const LinearForm* v) { + return expandLinearForm(v, this); } - - Term(HashProvider& hasher, const Expr* s, std::vector v) - : ExprNodeBase(promoteTypesVec(s, v)), - variables_(std::move(v)), - scalar_(s), - hasher_(hasher) { - sort(); - } - - // Convenience constructor from a map of hash -> var, used when merging Terms. - Term( - HashProvider& hasher, - const Expr* s, - std::unordered_map varmap) - : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) { - for (auto& p : varmap) { - addComponent(p.second); - } - sort(); - } - - const Expr* scalar() const { - return scalar_; - } - const std::vector& variables() const { - return variables_; - } - HashProvider& hasher() const { - return hasher_; - } - - // Produce a hash of just the variable components of this term, to determine - // if it can be combined with another term. - SimplifierHashType hashVars() const; - - private: - std::vector variables_; - const Expr* scalar_; - HashProvider& hasher_; - - void addComponent() {} - void addComponent(const Expr* e) { - variables_.push_back(e); - } - template - void addComponent(const Expr* e, Es... es) { - addComponent(e); - addComponent(es...); - } - - // Sort by hash to normalize order of components. - void sort(); }; -// Polynomial represents a grouping of Exprs by addition. -// E.g. sum(*variables, scalar). -// This would better be called Expression, but, naming conflict... -class Polynomial : public ExprNode { +/* Simplify the IR by combining arithmetic expressions over a common term. + */ +class IRSimplifier : public IRMutator { public: - template - Polynomial(HashProvider& hasher, const Expr* s, Args... ts) - : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) { - CHECK(s->isConstant()); - addTerm(ts...); - sort(); - } + const Expr* mutate(const Add* v) override { + const Expr* lhs = v->lhs(); + const Expr* rhs = v->rhs(); + const Expr* lhs_new = lhs->accept_mutator(this); + const Expr* rhs_new = rhs->accept_mutator(this); - Polynomial(HashProvider& hasher, const Expr* s, std::vector v) - : ExprNodeBase(promoteTypesVec(s, v)), - variables_(std::move(v)), - scalar_(s), - hasher_(hasher) { - sort(); - } - - // Helper constructor for list of terms with no scalar component. - Polynomial(HashProvider& hasher, std::vector terms) - : ExprNodeBase(promoteTypesVec(terms)), - variables_(std::move(terms)), - scalar_(getImmediateByType(dtype(), 0)), - hasher_(hasher) { - sort(); - } - - // Convenience constructor for map of hash -> var, used when merging - // Polynomials. - Polynomial( - HashProvider& hasher, - const Expr* s, - std::unordered_map varmap) - : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) { - for (auto& p : varmap) { - addTerm(p.second); + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + const Expr* result = evaluateOp(v); + return result; } - sort(); + + const LinearForm* lhsLinear = dynamic_cast(lhs_new); + const LinearForm* rhsLinear = dynamic_cast(rhs_new); + + if (lhsLinear && rhsLinear) { + // Can add two LinearTerms if they reference the same Var. + if (lhsLinear->getX() == rhsLinear->getX()) { + Add* A_Add = new Add(lhsLinear->getA(), rhsLinear->getA()); + Add* B_Add = new Add(lhsLinear->getB(), rhsLinear->getB()); + + LinearForm* linear = new LinearForm( + lhsLinear->getX(), evaluateOp(A_Add), evaluateOp(B_Add)); + return linear; + } + + // otherwise cannot simplify further. + return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); + } + + // Can add a scalar into the B term of LinearTerm. + if (lhsLinear && rhs_new->isConstant()) { + Add* B_Add = new Add(lhsLinear->getB(), rhs_new); + LinearForm* linear = new LinearForm( + lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Add)); + return linear; + } + + if (rhsLinear && lhs_new->isConstant()) { + Add* B_Add = new Add(rhsLinear->getB(), lhs_new); + LinearForm* linear = new LinearForm( + rhsLinear->getX(), rhsLinear->getA(), evaluateOp(B_Add)); + return linear; + } + + // Can create a LinearTerm over any sub expression. + if (lhs_new->isConstant()) { + LinearForm* linear = new LinearForm(rhs_new); + linear->setB(evaluateOp(lhs_new)); + return linear; + } + + if (rhs_new->isConstant()) { + LinearForm* linear = new LinearForm(lhs_new); + linear->setB(evaluateOp(rhs_new)); + return linear; + } + + /// Broadcasts are a bit more involved. + if (const Broadcast* bc = dynamic_cast(lhs_new)) { + if (const Expr* ret = handleBroadcastAdd(bc, rhs_new)) { + return ret; + } + } + + if (const Broadcast* bc = dynamic_cast(rhs_new)) { + if (const Expr* ret = handleBroadcastAdd(bc, lhs_new)) { + return ret; + } + } + + // No change. + if (lhs == lhs_new && rhs == rhs_new) { + return v; + } + + // Cannot simplify. + return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); } - const Expr* scalar() const { - return scalar_; - } - const std::vector& variables() const { - return variables_; - } - HashProvider& hasher() const { - return hasher_; + const Expr* mutate(const Sub* v) override { + const Expr* lhs = v->lhs(); + const Expr* rhs = v->rhs(); + const Expr* lhs_new = lhs->accept_mutator(this); + const Expr* rhs_new = rhs->accept_mutator(this); + + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + const Expr* result = evaluateOp(v); + return result; + } + + const LinearForm* lhsLinear = dynamic_cast(lhs_new); + const LinearForm* rhsLinear = dynamic_cast(rhs_new); + + if (lhsLinear && rhsLinear) { + // Can sub two LinearTerms if they reference the same Var. + if (lhsLinear->getX() == rhsLinear->getX()) { + Sub* A_Sub = new Sub(lhsLinear->getA(), rhsLinear->getA()); + Sub* B_Sub = new Sub(lhsLinear->getB(), rhsLinear->getB()); + + LinearForm* linear = new LinearForm( + lhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub)); + return linear; + } + + // otherwise cannot simplify further. + return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); + } + + // Can just sub from B term if LHS is a LinearTerm. + if (lhsLinear && rhs_new->isConstant()) { + Sub* B_Sub = new Sub(lhsLinear->getB(), rhs_new); + LinearForm* linear = new LinearForm( + lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Sub)); + return linear; + } + + // Slightly more complicated if the RHS is LinearTerm. + if (rhsLinear && lhs_new->isConstant()) { + // The linear needs to be negated. + ExprHandle zero = getImmediateByType(rhsLinear->getA()->dtype(), 0); + Sub* A_Sub = new Sub(zero.node(), rhsLinear->getA()); + Sub* B_Sub = new Sub(rhsLinear->getB(), lhs_new); + LinearForm* linear = new LinearForm( + rhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub)); + return linear; + } + + // Can create a new LinearTerm, but since the B term is defined as Add we + // must negate it. + if (rhs_new->isConstant()) { + LinearForm* linear = new LinearForm(lhs_new); + + ExprHandle zero = getImmediateByType(linear->getA()->dtype(), 0); + Sub* B_Sub = new Sub(zero.node(), rhs_new); + linear->setB(evaluateOp(B_Sub)); + return linear; + } + + // Can create a new LinearTerm with the A term -1 to negate the Expr. + if (lhs_new->isConstant()) { + // Negate by using -1 as the first linear. + ExprHandle negOne = getImmediateByType(rhs_new->dtype(), -1); + LinearForm* linear = + new LinearForm(rhs_new, negOne.node(), evaluateOp(lhs_new)); + return linear; + } + + // Nothing to do. + if (lhs == lhs_new && rhs == rhs_new) { + return v; + } + + // Cannot simplify. + return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); } - SimplifierHashType hashVars() const; + const Expr* mutate(const Mul* v) override { + const Expr* lhs = v->lhs(); + const Expr* rhs = v->rhs(); + const Expr* lhs_new = lhs->accept_mutator(this); + const Expr* rhs_new = rhs->accept_mutator(this); - private: - std::vector variables_; - const Expr* scalar_; - HashProvider& hasher_; + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + return evaluateOp(v); + } - void addTerm(const Term* t) { - variables_.push_back(t); + const LinearForm* lhsLinear = dynamic_cast(lhs_new); + const LinearForm* rhsLinear = dynamic_cast(rhs_new); + + if (lhsLinear && rhsLinear) { + // Lets not get into higher order terms. + return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); + } + + // Easy to simplify into an existing LinearTerm by multiplying A and B. + if (lhsLinear && rhs_new->isConstant()) { + Mul* A_Mul = new Mul(lhsLinear->getA(), rhs_new); + Mul* B_Mul = new Mul(lhsLinear->getB(), rhs_new); + LinearForm* linear = new LinearForm( + lhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul)); + return linear; + } + + if (rhsLinear && lhs_new->isConstant()) { + Mul* A_Mul = new Mul(rhsLinear->getA(), lhs_new); + Mul* B_Mul = new Mul(rhsLinear->getB(), lhs_new); + LinearForm* linear = new LinearForm( + rhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul)); + return linear; + } + + // Easy to create a new LinearTerm by setting term A. + if (lhs_new->isConstant()) { + LinearForm* linear = new LinearForm(rhs_new); + linear->setA(evaluateOp(lhs_new)); + return linear; + } + + if (rhs_new->isConstant()) { + LinearForm* linear = new LinearForm(lhs_new); + linear->setA(evaluateOp(rhs_new)); + return linear; + } + + // Broadcasts have special logic. + if (const Broadcast* bc = dynamic_cast(lhs_new)) { + if (const Expr* ret = handleBroadcastMul(bc, rhs_new)) { + return ret; + } + } + + if (const Broadcast* bc = dynamic_cast(rhs_new)) { + if (const Expr* ret = handleBroadcastMul(bc, lhs_new)) { + return ret; + } + } + + // Cannot be simplified, just exit. + if (lhs == lhs_new && rhs == rhs_new) { + return v; + } + + return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); } - template - void addTerm(const Term* t, Ts... ts) { - addTerm(t); - addTerm(ts...); - } - - // Sort by hash to normalize order of terms. - void sort(); -}; - -// Simplify the IR by combining arithmetic expressions over common terms. -class TORCH_API PolynomialTransformer : public IRMutator { - public: - // Inserts term into the provided map, in the case of a hash collision - // combines the term with the existing and updates the map. - void addOrUpdateTerm( - std::unordered_map& varmap, - const Term* term); - - // Add Polynomial expressions, combining Terms representing the same - // variables. - const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs); - - // Insert a new Term into the provided polynomial. If the new term has common - // variables to an existing term it is combined. - const Expr* insertTerm(const Polynomial* poly, const Term* term); - - // Merge and simplify addition. - const Expr* mutate(const Add* v) override; - - // Subtract one term from another, cancelling if necessary. - const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated); - - // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where - // possible. - const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs); - - // Merge and simplify subtraction. - const Expr* mutate(const Sub* v) override; - - // Multiply two terms together, usually creating a new term with the variable - // lists concatenated. - const Term* mulTerms(const Term* lhs, const Term* rhs); - - // Multiply a Polynomial by a Term. - const Expr* polyByTerm(const Polynomial* poly, const Term* term); - - // Merge and simplify multiplication. - const Expr* mutate(const Mul* v) override; const Expr* mutate(const Div* v) override { // TODO div simplification will require a rational node. @@ -304,9 +396,37 @@ class TORCH_API PolynomialTransformer : public IRMutator { return mutateBinaryOp(v, this, v->propagate_nans()); } - const Expr* mutate(const Intrinsics* v) override; + const Expr* mutate(const Intrinsics* v) override { + std::vector new_params; + bool changed = false; + bool allConstant = true; + for (const auto* p : v->params()) { + const Expr* new_child = p->accept_mutator(this); + new_params.push_back(new_child); - const Expr* mutate(const Cast* v) override; + changed |= p != new_child; + allConstant &= new_child->isConstant(); + } + + const Expr* node = v; + if (changed) { + node = new Intrinsics(v->op_type(), new_params); + } + + if (!allConstant || !v->isPure()) { + return node; + } + + return evaluateOp(node); + } + + const Expr* mutate(const Cast* v) override { + if (v->src_value()->isConstant()) { + return evaluateOp(v); + } + + return v; + } template static const Expr* mutateBinaryOp( @@ -332,40 +452,12 @@ class TORCH_API PolynomialTransformer : public IRMutator { return evaluateOp(node); } - static const Expr* simplify(const Expr* e); - static ExprHandle simplify(const ExprHandle& e); - static Stmt* simplify(Stmt* e); - - private: - HashProvider hasher_; -}; // namespace tensorexpr - -// Expands Terms and Polynomial expressions into primitive operations. -// Does some simple factorization and reordering. -class TORCH_API TermExpander : public IRMutator { - PolynomialTransformer* simplifier_; - - public: - TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {} - - // Expand Terms out to a series of Muls. - const Expr* mutate(const Term* v) override; - - // Trivially factorize terms by GCD of scalar components. - const Expr* factorizePolynomial(const Polynomial* poly); - - // Expand Polynomials out to a series of Adds. - const Expr* mutate(const Polynomial* v); -}; - -class TORCH_API IRSimplifier { - public: static const Expr* simplify(const Expr* e) { - PolynomialTransformer simplifier; + IRSimplifier simplifier; e = e->accept_mutator(&simplifier); // There may be terms left in the IR, expand them. - TermExpander expander(&simplifier); + LinearFormExpander expander; e = e->accept_mutator(&expander); return e; @@ -376,15 +468,74 @@ class TORCH_API IRSimplifier { } static Stmt* simplify(Stmt* s) { - PolynomialTransformer simplifier; + IRSimplifier simplifier; s = s->accept_mutator(&simplifier); // There may be terms left in the IR, expand them. - TermExpander expander(&simplifier); + LinearFormExpander expander; s = s->accept_mutator(&expander); - return s; } + + private: + /* Expands lhs and rhs if they are LinearTerms, creating a new op to hold + * them. If either side expands to a constant term, attempt simplification of + * the new op. */ + const Expr* expandAndRecurse( + IRNodeType expr_type, + const Expr* lhs, + const Expr* rhs) { + if (const LinearForm* lhsLinear = dynamic_cast(lhs)) { + lhs = expandLinearForm(lhsLinear, this); + } + if (const LinearForm* rhsLinear = dynamic_cast(rhs)) { + rhs = expandLinearForm(rhsLinear, this); + } + const Expr* result = newBinaryOpOfType(expr_type, lhs, rhs, false); + + // lhs or rhs can become constant during expansion, if either is now + // constant we can keep merging into a linear term. Have another attempt to + // simplify the new op. + if (lhs->isConstant() || rhs->isConstant()) { + return result->accept_mutator(this); + } + + return result; + } + + /* Handles optimization cases for Broadcast() + Other */ + const Expr* handleBroadcastAdd(const Broadcast* bc, const Expr* other) { + if (bc->value()->isConstant() && immediateAs(bc->value()) == 0) { + return other; + } + + if (const Ramp* r = dynamic_cast(other)) { + // Add the broadcast to the start of the Ramp. + const Expr* ret = + new Ramp(new Add(bc->value(), r->base()), r->stride(), r->lanes()); + return ret->accept_mutator(this); + } + + return nullptr; + } + + /* Handles optimization cases for Broadcast() * Other */ + const Expr* handleBroadcastMul(const Broadcast* bc, const Expr* other) { + if (bc->value()->isConstant() && immediateAs(bc->value()) == 1) { + return other; + } + + if (const Ramp* r = dynamic_cast(other)) { + // Multiply both start and stride by the broadcast value. + const Expr* ret = new Ramp( + new Mul(bc->value(), r->base()), + new Mul(bc->value(), r->stride()), + r->lanes()); + return ret->accept_mutator(this); + } + + return nullptr; + } }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index dd989a33ab89..43a94706212d 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -1,7 +1,6 @@ #include #include -#include #include namespace torch { @@ -177,18 +176,10 @@ void IRVisitor::visit(const Cond* v) { } } -void IRVisitor::visit(const Term* v) { - v->scalar()->accept(this); - for (auto* t : v->variables()) { - t->accept(this); - } -} - -void IRVisitor::visit(const Polynomial* v) { - v->scalar()->accept(this); - for (auto* t : v->variables()) { - t->accept(this); - } +void IRVisitor::visit(const LinearForm* v) { + v->getA()->accept(this); + v->getX()->accept(this); + v->getB()->accept(this); } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index 35797dc9c786..b83b405757ca 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -42,8 +42,7 @@ class FunctionCall; class Allocate; class Free; class Cond; -class Term; -class Polynomial; +class LinearForm; class TORCH_API IRVisitor { public: @@ -91,8 +90,7 @@ class TORCH_API IRVisitor { virtual void visit(const Allocate* v); virtual void visit(const Free* v); virtual void visit(const Cond* v); - virtual void visit(const Term* v); - virtual void visit(const Polynomial* v); + virtual void visit(const LinearForm* v); }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 05720f5d88d5..1ff09be225b7 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1114,7 +1114,6 @@ void TensorExprKernel::lowerToBackend(BackendType backendType) { l.ApplyInlines(); Stmt* stmt = l.root_stmt(); - // Arithmetic Simplification. stmt = IRSimplifier::simplify(stmt); // Set up formal params (inputs, then outputs) for kernel. diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index 2651240061c3..1738f7534bde 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -1,6 +1,6 @@ +#include #include #include -#include #include diff --git a/torch/csrc/jit/tensorexpr/types.h b/torch/csrc/jit/tensorexpr/types.h index 6a916ab3ef24..cb6b9f712262 100644 --- a/torch/csrc/jit/tensorexpr/types.h +++ b/torch/csrc/jit/tensorexpr/types.h @@ -50,7 +50,7 @@ class TORCH_API Dtype { Dtype(Dtype type, int lanes) : scalar_type_(type.scalar_type_), lanes_(lanes) { if (type.lanes() != 1) { - throw malformed_input("dtype lanes dont match"); + throw malformed_input(); } } int lanes() const { @@ -108,15 +108,10 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) { return static_cast(c10::promoteTypes( static_cast(a), static_cast(b))); } -inline Dtype promoteTypes(Dtype a, Dtype b) { - if (a.lanes() != b.lanes()) { - throw malformed_input("promoting types with different lanes"); - } - return Dtype( - static_cast(c10::promoteTypes( - static_cast(a.scalar_type()), - static_cast(b.scalar_type()))), - a.lanes()); +inline ScalarType promoteTypes(Dtype a, Dtype b) { + return static_cast(c10::promoteTypes( + static_cast(a.scalar_type()), + static_cast(b.scalar_type()))); } inline Dtype BinaryOpDtype( @@ -132,21 +127,22 @@ inline Dtype BinaryOpDtype( } if (op1_dtype.lanes() != op2_dtype.lanes()) { - throw malformed_input("lanes dont match"); + throw malformed_input(); } int lanes = op1_dtype.lanes(); - Dtype resultType = promoteTypes(op1_dtype, op2_dtype); - if (resultType.scalar_type() == ScalarType::Undefined) { - throw malformed_input("scalar type doesn't match"); + ScalarType resultType = promoteTypes(op1_dtype, op2_dtype); + if (resultType == ScalarType::Undefined) { + throw malformed_input(); } + if (lanes == 1) { // Use the fixed scalar Dtypes. - return ToDtype(resultType.scalar_type()); + return ToDtype(resultType); } - return resultType; + return Dtype(resultType, lanes); } } // namespace tensorexpr