From fce67800f4beba8fa90e9189847547f1b13c59d3 Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Tue, 24 Mar 2020 14:10:58 -0700 Subject: [PATCH] [TensorExpr] Extend arithmetic simplifier to work with multi variable expressions (#35127) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: A new version of the IR simplifier used by the jit/tensorexpr fuser. This is capable of simplifying expressions containing (shock) multiple variables, eg: ```(m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1``` Similar to the previous IR Simplifier it uses a two stage approach: 1. Traverse the tree combining subtree's of commutable operations in to a flat structure. In this implementation we have two intermediate Exprs: Term (expressing products of sub expressions) and Polynomial (expressing sums of sub expressions). 2. Traverse the tree expanding Term's and Polynomials into their component operators. Using the example above we execute with a process like this to simplify: ``` (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) # Using PolynomialTransformer: => Sub(Add(Mul(m, Mul(1, n_1)), Add(n, 1)), Add(Mul(m, Mul(1, n_1)), n)) => Sub(Polynomial(Term(m, n_1), n, 1), Polynomial(Term(m, n_1), n)) => Polynomial(Term(m, n_1), Term(-1, m, n_1), n, -n, 1) => Polynomial(1) # Using TermExpander => 1 ``` The IRSimplifier supports arithmetic simplifications of operators Add, Sub and Mul and constant folding of all binary Exprs and Intrinsics, but does not attempt expansion of multiplication of Polynomials to the canonical form since that generally leads to less efficient representations. It will do scalar factorization if it results in removal of operators, and will merge chains of multilane primitives (such as Broadcast and Ramp) down into a single operator. The ir_simplifier unit tests are a short tour of its capabilities. The existing simplifier has a bug where it will sometimes reorder operations on floating point types which are not associative. This causes (at least) the pyhpc equation_of_state benchmark to produce incorrect results. I have fixed that issue in this version and verified that that benchmark produces the same results with and without the simplifier. Tests: all cpp & py tensorexpr tests, and pyphc benchmark: ``` benchmarks.equation_of_state ============================ Running on CPU size backend calls mean stdev min 25% median 75% max Δ ------------------------------------------------------------------------------------------------------------------ 4,194,304 pytorch 10 0.246 0.002 0.243 0.245 0.246 0.248 0.250 1.000 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/35127 Differential Revision: D20624571 Pulled By: nickgg fbshipit-source-id: e49049377beee69e02dcf26eb922bef1447ae776 --- caffe2/CMakeLists.txt | 1 + test/cpp/tensorexpr/test_simplify.cpp | 818 ++++++++++++-- test/cpp/tensorexpr/tests.h | 212 ++-- tools/build_variables.bzl | 1 + torch/csrc/jit/tensorexpr/eval.h | 51 +- torch/csrc/jit/tensorexpr/expr.h | 2 + torch/csrc/jit/tensorexpr/hash_server.h | 20 +- torch/csrc/jit/tensorexpr/ir.h | 129 ++- torch/csrc/jit/tensorexpr/ir_mutator.cpp | 25 +- torch/csrc/jit/tensorexpr/ir_mutator.h | 6 +- torch/csrc/jit/tensorexpr/ir_printer.cpp | 34 +- torch/csrc/jit/tensorexpr/ir_printer.h | 3 +- torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 1086 +++++++++++++++++++ torch/csrc/jit/tensorexpr/ir_simplifier.h | 723 +++++------- torch/csrc/jit/tensorexpr/ir_visitor.cpp | 17 +- torch/csrc/jit/tensorexpr/ir_visitor.h | 6 +- torch/csrc/jit/tensorexpr/kernel.cpp | 1 + torch/csrc/jit/tensorexpr/types.cpp | 2 +- torch/csrc/jit/tensorexpr/types.h | 28 +- 19 files changed, 2450 insertions(+), 715 deletions(-) create mode 100644 torch/csrc/jit/tensorexpr/ir_simplifier.cpp diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b9dec0b87c9c..9ad5de242426 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -481,6 +481,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_simplifier.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index f97f6da343ab..c2b16fd6dd56 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -11,6 +11,40 @@ namespace jit { using namespace torch::jit::tensorexpr; using SimpleIRExprEval = ExprEval; +#define IS_NODE(T, node) \ + { \ + auto* node_ = dynamic_cast(node); \ + EXPECT_NE(nullptr, node_) << "Expected node to be " #T; \ + } + +#define IS_NODE_WITH_NAME(T, node, name) \ + auto* name = dynamic_cast(node); \ + EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T; + +#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ + const T* name = nullptr; \ + { \ + auto* node_ = dynamic_cast(node); \ + EXPECT_NE(nullptr, node_); \ + EXPECT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ + name = dynamic_cast(node_->src_value()); \ + } \ + EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T; + +#define IS_IMM_WITH_VAL(T, node, val) \ + { \ + auto* node_ = dynamic_cast(node); \ + EXPECT_NE(nullptr, node_) << "Expected node to be " #T "Imm"; \ + EXPECT_EQ(node_->value(), val) << "Expected Imm to be " << val; \ + } + +#define IS_VAR_WITH_NAME(node, name) \ + { \ + auto* node_ = dynamic_cast(node); \ + EXPECT_NE(nullptr, node_) << "Expected node to be Var"; \ + EXPECT_EQ(node_->name_hint(), name) << "Expected var to be " #name; \ + } + void testConstantFoldSimple() { KernelScope kernel_scope; ExprHandle a(2.0f); @@ -133,17 +167,33 @@ void testConstantFoldIntrinsics() { void testConstantFoldWithVar() { KernelScope kernel_scope; - VarHandle x("x", kFloat); - ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); + { + VarHandle x("x", kInt); + ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); - ExprHandle newF = IRSimplifier::simplify(body); - const Mul* root = newF.AsNode(); - EXPECT_NE(root, nullptr); - EXPECT_NE(dynamic_cast(root->rhs()), nullptr); + ExprHandle newF = IRSimplifier::simplify(body); + const Mul* root = newF.AsNode(); + EXPECT_NE(root, nullptr); + EXPECT_NE(dynamic_cast(root->lhs()), nullptr); - ExprHandle result = Let::make(x, ExprHandle(3.f), newF); - SimpleIRExprEval eval(result); - EXPECT_EQ(eval.value(), 3 * (2 + 4)); + ExprHandle result = Let::make(x, ExprHandle(3), newF); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 3 * (2 + 4)); + } + + { + VarHandle x("x", kFloat); + ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); + + ExprHandle newF = IRSimplifier::simplify(body); + const Mul* root = newF.AsNode(); + EXPECT_NE(root, nullptr); + EXPECT_NE(dynamic_cast(root->rhs()), nullptr); + + ExprHandle result = Let::make(x, ExprHandle(3.f), newF); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 3 * (2 + 4)); + } } void testUnFoldableExpr() { @@ -228,34 +278,22 @@ void testHashEquivalenceAfterFolding() { ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(5.0f); - ExprHandle f = ((a + b) * x) * (c * x); - - const Mul* root = f.AsNode(); - EXPECT_NE(root, nullptr); + ExprHandle f1 = ((a + b) * x); + ExprHandle f2 = (c * x); HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); + auto hash_l = hasher.hash(f1.node()); + auto hash_r = hasher.hash(f2.node()); - // Root not equal to either branch, and branches not equal. - EXPECT_NE(hash_f, hash_l); - EXPECT_NE(hash_f, hash_r); EXPECT_NE(hash_l, hash_r); - ExprHandle newF = IRSimplifier::simplify(f); + ExprHandle ff1 = IRSimplifier::simplify(f1); + ExprHandle ff2 = IRSimplifier::simplify(f2); - const Mul* newRoot = newF.AsNode(); - EXPECT_NE(newRoot, nullptr); + auto hash_l_n = hasher.hash(ff1.node()); + auto hash_r_n = hasher.hash(ff2.node()); - auto hash_f_n = hasher.hash(newF.node()); - auto hash_l_n = hasher.hash(newRoot->lhs()); - auto hash_r_n = hasher.hash(newRoot->rhs()); - - // Root not equal to either branch. - EXPECT_NE(hash_f_n, hash_l_n); - EXPECT_NE(hash_f_n, hash_r_n); - // but branches are now equal. + // branches are now equal. EXPECT_EQ(hash_l_n, hash_r_n); } @@ -343,11 +381,16 @@ void testHashLargeExpression() { EXPECT_NE(hash_t, hash_f); } -/// (2.f + x) + 4.f => x + 6.f +/// (2 + x) + 4 => x + 6 void testSimplifyAdd() { KernelScope kernel_scope; - VarHandle x("x", kFloat); - ExprHandle body = (ExprHandle(2.f) + x) + ExprHandle(4.f); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); ExprHandle simplified = IRSimplifier::simplify(body); const Add* root = simplified.AsNode(); @@ -355,51 +398,43 @@ void testSimplifyAdd() { const Var* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); EXPECT_EQ(lhs->name_hint(), "x"); - const FloatImm* rhs = dynamic_cast(root->rhs()); + const IntImm* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); EXPECT_EQ(rhs->value(), 6.f); } -/// (2.f - x) - 4.f => -2.f - x +/// (2 - x) - 4 => -2 - x void testSimplifySub() { KernelScope kernel_scope; - VarHandle x("x", kFloat); - ExprHandle body = (ExprHandle(2.f) - x) - ExprHandle(4.f); + VarHandle x("x", kInt); + ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); ExprHandle simplified = IRSimplifier::simplify(body); const Sub* root = simplified.AsNode(); EXPECT_NE(root, nullptr); - const FloatImm* lhs = dynamic_cast(root->lhs()); + const IntImm* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->value(), -2.f); + EXPECT_EQ(lhs->value(), -2); const Var* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); EXPECT_EQ(rhs->name_hint(), "x"); } -/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f) +/// 2 * (1 - x) - 4 => -2 * (x + 3) void testSimplifyMultiLayer() { KernelScope kernel_scope; - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) * ((ExprHandle(1.f) - x) - ExprHandle(4.f)); - + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); ExprHandle simplified = IRSimplifier::simplify(body); - const Sub* root = simplified.AsNode(); - EXPECT_NE(root, nullptr); - const FloatImm* lhs = dynamic_cast(root->lhs()); - EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->value(), -6.f); - const Mul* rhs = dynamic_cast(root->rhs()); - EXPECT_NE(rhs, nullptr); - const Var* varX = dynamic_cast(rhs->lhs()); - EXPECT_NE(varX, nullptr); - EXPECT_EQ(varX->name_hint(), "x"); - const FloatImm* mulRhs = dynamic_cast(rhs->rhs()); - EXPECT_NE(mulRhs, nullptr); - EXPECT_EQ(mulRhs->value(), 2.f); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_IMM_WITH_VAL(Int, add->rhs(), 3); } -/// 2 * (3 * x) - (x * 4) => x * 2 +/// 2 * (3 * x) - (x * 4) => 2 * x void testSimplifyMultiTerm() { KernelScope kernel_scope; VarHandle x("x", kInt); @@ -409,30 +444,30 @@ void testSimplifyMultiTerm() { ExprHandle simplified = IRSimplifier::simplify(body); const Mul* root = simplified.AsNode(); EXPECT_NE(root, nullptr); - const Var* lhs = dynamic_cast(root->lhs()); + const IntImm* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->name_hint(), "x"); - const IntImm* rhs = dynamic_cast(root->rhs()); + EXPECT_EQ(lhs->value(), 2); + const Var* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); - EXPECT_EQ(rhs->value(), 2); + EXPECT_EQ(rhs->name_hint(), "x"); } -/// 2 * (3 * (f)x) - (x * 4) => x * 2.f +/// 2 * (3 * (long)x) - (x * 4) => 2 * x void testSimplifyCasts() { KernelScope kernel_scope; - VarHandle x("x", kFloat); + VarHandle x("x", kLong); ExprHandle body = (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); const Mul* root = simplified.AsNode(); EXPECT_NE(root, nullptr); - const Var* lhs = dynamic_cast(root->lhs()); + const LongImm* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - EXPECT_EQ(lhs->name_hint(), "x"); - const FloatImm* rhs = dynamic_cast(root->rhs()); + EXPECT_EQ(lhs->value(), 2); + const Var* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); - EXPECT_EQ(rhs->value(), 2); + EXPECT_EQ(rhs->name_hint(), "x"); } /// (x + 0) * 1 => x @@ -452,20 +487,39 @@ void testSimplifyMultiVar() { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); - ExprHandle body = y * 24 + x * 34; + ExprHandle body = x * 24 + y * 34; ExprHandle simplified = IRSimplifier::simplify(body); + const Add* root = simplified.AsNode(); EXPECT_NE(root, nullptr); const Mul* lhs = dynamic_cast(root->lhs()); EXPECT_NE(lhs, nullptr); - const Var* varY = dynamic_cast(lhs->lhs()); - EXPECT_EQ(varY->name_hint(), "y"); + const Var* varX = dynamic_cast(lhs->rhs()); + EXPECT_NE(varX, nullptr); + EXPECT_EQ(varX->name_hint(), "y"); const Mul* rhs = dynamic_cast(root->rhs()); EXPECT_NE(rhs, nullptr); - const Var* varX = dynamic_cast(rhs->lhs()); - EXPECT_NE(varX, nullptr); - EXPECT_EQ(varX->name_hint(), "x"); + const Var* varY = dynamic_cast(rhs->rhs()); + EXPECT_NE(varY, nullptr); + EXPECT_EQ(varY->name_hint(), "x"); +} + +// x + 2 + y => x + y + 2 +void testSimplifyReorderings() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = x + 2 + y; + ExprHandle simplified = IRSimplifier::simplify(body); + + const Add* root = simplified.AsNode(); + EXPECT_NE(root, nullptr); + + IS_NODE_WITH_NAME(Add, root->lhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + IS_IMM_WITH_VAL(Int, root->rhs(), 2); } /// y + x * 0 => y @@ -476,9 +530,621 @@ void testSimplifyEliminatesVar() { ExprHandle body = y + x * ExprHandle(0); ExprHandle simplified = IRSimplifier::simplify(body); - const Var* root = simplified.AsNode(); - EXPECT_NE(root, nullptr); - EXPECT_EQ(root->name_hint(), "y"); + IS_VAR_WITH_NAME(simplified.node(), "y"); +} + +void testSimplifyAdds() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) + (x + y) => 2 * (x + y) + ExprHandle body = (x + y) + (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Add, root->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // (x * y) + (x * y) => 2 * (x * y) + ExprHandle body = (x * y) + (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Mul, root->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - y) + (x - y) => -2 * (y - x) + ExprHandle body = (x - y) + (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "y"); + IS_VAR_WITH_NAME(rhs->rhs(), "x"); + } +} + +void testSimplifyMuls() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) * (x + y) => (x + y) * (x + y) + // We don't attempt to simplify mulitplication of polynomials since the + // result is only very rarely more efficient. + ExprHandle body = (x + y) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // x * y * x * y => x * x * y * y + // These get reordered only. + ExprHandle body = x * y * x * y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); + IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); + IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); + IS_VAR_WITH_NAME(mul1->rhs(), "y"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + IS_VAR_WITH_NAME(mul3->lhs(), "x"); + IS_VAR_WITH_NAME(mul3->rhs(), "x"); + } + + { + // (x - y) * (x - y) => (x - y) * (x - y) + // As with Add we don't attempt simplification of this. + ExprHandle body = (x - y) * (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // (x + y) * (x - y) => (x - y) * (x - y) + // Don't simplify with different ops on each side. + ExprHandle body = (x + y) * (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } +} + +// Sub an expr from itself will result in zero. +void testSimplifySubs() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) - (x + y) => 0 + ExprHandle body = (x + y) - (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x * y) - (x * y) => 0 + ExprHandle body = (x * y) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x - y) - (x - y) => 0 + ExprHandle body = (x - y) - (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x + y) - 2 * (x + y) => -1 * (x + y) + ExprHandle body = (x + y) - ExprHandle(2) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -1); + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // (x + y) - y => x + ExprHandle body = (x + y) - y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - y) - y => x - 2 * y + ExprHandle body = (x - y) - y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // 2 * x - x => x + ExprHandle body = (ExprHandle(2) * x) - x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // x - 2 * x = -1 * x + // We don't have a unary negate, but this could be 0 -x I guess? + ExprHandle body = x - (ExprHandle(2) * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + + IS_IMM_WITH_VAL(Int, mul->lhs(), -1); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // (x + y + 5) * (x - x) => 0 + // Cancelling out one side of Mul cancels both. + ExprHandle body = (x + y + 5) * (x - x); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } +} + +// Test that mixing ops together simplifies as expected. +void testSimplifyMultiOp() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x * y) + (x - y) => (x * y) + x - y + // + ExprHandle body = (x * y) + (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_NODE_WITH_NAME(Mul, add->lhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // (x + y) - (x * y) => x + y - (x * y) + ExprHandle body = (x + y) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - y) - (x + y) => -2 * y + ExprHandle body = (x - y) - (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } +} + +// Test that chaining many ops together works as expected. +void testSimplifyManyOps() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x + ExprHandle body = x + y + x + x + y + y + x + y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "x"); + } + + { + // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y + ExprHandle body = x - y + x + x - y - y + x - y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // x + y + x - x - y - y + x + y + x = 3 * x + ExprHandle body = x + y + x - x - y - y + x + y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 3); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } +} + +void testSimplifyFactorization() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (2 * x) + (2 * y) => 2 * (x + y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Factorization when scalars have common divider. + // (2 * x) + (4 * y) => 2 * (2 * y + x) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_NODE_WITH_NAME(Mul, add->lhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + IS_VAR_WITH_NAME(add->rhs(), "x"); + } + + { + // Factorization attempt without a common divider. + // (2 * x) + (5 * y) => (5 * y) + (2 * x) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 2); + IS_VAR_WITH_NAME(rhs->rhs(), "x"); + } + + { + // Factorization after merging. + // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + + (ExprHandle(8) * x + ExprHandle(6) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 10); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Factorization with common divider but different signs. + // (-2 * x) + (4 * y) => -2 * (x - 2 * y) + ExprHandle body = (ExprHandle(-2) * x + ExprHandle(4) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } +} + +// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (3 * z + y + 4 * x) +void testSimplifyFactorizeUneven() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = + (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Add, root->rhs(), add1); + IS_NODE_WITH_NAME(Add, add1->lhs(), add2); + + IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); + IS_NODE_WITH_NAME(Mul, add2->lhs(), zmul); + + IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); + IS_VAR_WITH_NAME(zmul->rhs(), "z"); + + IS_VAR_WITH_NAME(add2->rhs(), "y"); + + IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); + IS_VAR_WITH_NAME(xmul->rhs(), "x"); +} + +// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) +// This is kind of a placeholder test for variable factorization. +void testSimplifyDeeperTerms() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); + IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); + IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); +} + +// Tests the difference between two less trivial expressions. +// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 +void testSimplifyDeeperDifference() { + KernelScope kernel_scope; + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + VarHandle m("m", kInt); + ExprHandle body = + (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 1); +} + +// Test constant folding into the difference between expressions. +// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 +void testSimplifyFoldComplexDifference() { + KernelScope kernel_scope; + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + VarHandle m("m", kInt); + ExprHandle body = + (IntImm::make(2) + + (Cast::make( + kChar, + (m * (ExprHandle(1) * n_1) + (n + 1)) - + (m * (ExprHandle(1) * n_1) + n)))); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 3); +} + +void testSimplifyIfComponents() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make( + ((ExprHandle(5) - ExprHandle(4)) * x) > y, + ExprHandle(2) * x - x, + ExprHandle(2) * y - y); + + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); + + IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); + EXPECT_EQ(cmp->compare_select_op(), kGT); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_VAR_WITH_NAME(cmp->rhs(), "y"); + + IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); + IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); +} + +void testSimplifyOpaqueTerms() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // 2 * x/y * x - x/y * y => y * x/y + ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "y"); + IS_NODE_WITH_NAME(Div, mul->rhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } + + { + // x%y - (x%y - 1) => 1 + ExprHandle body = (x % y) - ((x % y) - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 1); + } +} + +void testSimplifyWontReorderFloat() { + KernelScope kernel_scope; + + { + // 3 * (3 * x) - 3 * (3 * y) => -9 * (y - x) + // This is an expression we can simplify. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -9); + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "y"); + IS_VAR_WITH_NAME(sub->rhs(), "x"); + } + + { + // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). + // If the vars are floating point, ops are not associative and we can't + // reorder. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); + IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); + IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); + IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); + } + + { + // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). + // We will simplify subexprs if they dont reorder floating point ops. + VarHandle x("x", kDouble); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); + IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); + IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); + IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); + } + + { + // Prevent reordering if FP propagated from dtypes. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3.f) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); + IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); + IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); + IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); + IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); + IS_VAR_WITH_NAME(yCast->src_value(), "y"); + } + + { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + // x%y - (x%y - 1) => x%y - (x%y - 1). + // We wont reorder opaque ops if they are FP. + ExprHandle body = (x % y) - ((x % y) - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); + IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); + IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); + + IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); + IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); + IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); + IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); + IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); + } } } // namespace jit diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 4e43ac16d4ed..4afb905e2032 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -9,105 +9,119 @@ namespace torch { namespace jit { -#define TH_FORALL_TESTS(_) \ - _(ExprBasicValueTest) \ - _(ExprBasicValueTest02) \ - _(ExprLetTest01) \ - _(ExprLetStmtTest01) \ - _(ExprLetTest02) \ - _(ExprIntTest) \ - _(ExprFloatTest) \ - _(ExprByteTest) \ - _(ExprCharTest) \ - _(ExprShortTest) \ - _(ExprLongTest) \ - _(ExprHalfTest) \ - _(ExprDoubleTest) \ - _(ExprVectorAdd01) \ - _(ExprCompareSelectEQ) \ - _(ExprSubstitute01) \ - _(ExprMath01) \ - _(ExprUnaryMath01) \ - _(ExprBinaryMath01) \ - _(ExprDynamicShapeAdd) \ - _(ExprBitwiseOps) \ - _(IRPrinterBasicValueTest) \ - _(IRPrinterBasicValueTest02) \ - _(IRPrinterLetTest01) \ - _(IRPrinterLetTest02) \ - _(IRPrinterCastTest) \ - _(ExprSimple01) \ - _(ExprLower01) \ - _(ExprSimple02) \ - _(ExprSplitWithTailNone) \ - _(ExprSplitWithMask01) \ - _(ScheduleBroadcastAddBuffer) \ - _(ScheduleFunctionCall01) \ - _(ScheduleInlineFunc01) \ - _(ScheduleFuserStyle) \ - _(ScheduleFuserThreeArg) \ - _(ScheduleDynamicShape2D) \ - _(TypeTest01) \ - _(TypePropagation) \ - _(Cond01) \ - _(IfThenElse01) \ - _(IfThenElse02) \ - _(ATen_cast_Float) \ - _(ATennegInt) \ - _(ATennegFloat) \ - _(ATenaddInt) \ - _(ATenaddFloat) \ - _(ATensubInt) \ - _(ATensubFloat) \ - _(ATenlerp) \ - _(ATenaddcmulInt) \ - _(ATenaddcmulFloat) \ - _(ATenmulInt) \ - _(ATenmulFloat) \ - _(ATendivInt) \ - _(ATendivFloat) \ - _(ATenmaxInt) \ - _(ATenmaxFloat) \ - _(ATenminInt) \ - _(ATenminFloat) \ - _(ATen_sigmoid_backward) \ - _(ATen_tanh_backward) \ - _(ATenreciprocal) \ - _(ATenreluInt) \ - _(ATenreluFloat) \ - _(ATenlogFloat) \ - _(ATenlog10Float) \ - _(ATenlog2Float) \ - _(ATenexpFloat) \ - _(ATenerfFloat) \ - _(ATencosFloat) \ - _(ATeneqInt) \ - _(ATengeInt) \ - _(ATengtInt) \ - _(ATenleInt) \ - _(ATenltInt) \ - _(ConstantFoldSimple) \ - _(ConstantFoldTwoLayer) \ - _(ConstantFoldShifts) \ - _(ConstantFoldBitwise) \ - _(ConstantFoldMultiOp) \ - _(ConstantFoldMinMax) \ - _(ConstantFoldIntrinsics) \ - _(ConstantFoldWithVar) \ - _(UnFoldableExpr) \ - _(HashSimple) \ - _(HashEquivalence) \ - _(HashEquivalenceAfterFolding) \ - _(HashDifferenceTypes) \ - _(HashLargeExpression) \ - _(SimplifyAdd) \ - _(SimplifySub) \ - _(SimplifyMultiLayer) \ - _(SimplifyMultiTerm) \ - _(SimplifyCasts) \ - _(SimplifyEliminatesNoOps) \ - _(SimplifyMultiVar) \ - _(SimplifyEliminatesVar) \ +#define TH_FORALL_TESTS(_) \ + _(ExprBasicValueTest) \ + _(ExprBasicValueTest02) \ + _(ExprLetTest01) \ + _(ExprLetStmtTest01) \ + _(ExprLetTest02) \ + _(ExprIntTest) \ + _(ExprFloatTest) \ + _(ExprByteTest) \ + _(ExprCharTest) \ + _(ExprShortTest) \ + _(ExprLongTest) \ + _(ExprHalfTest) \ + _(ExprDoubleTest) \ + _(ExprVectorAdd01) \ + _(ExprCompareSelectEQ) \ + _(ExprSubstitute01) \ + _(ExprMath01) \ + _(ExprUnaryMath01) \ + _(ExprBinaryMath01) \ + _(ExprDynamicShapeAdd) \ + _(ExprBitwiseOps) \ + _(IRPrinterBasicValueTest) \ + _(IRPrinterBasicValueTest02) \ + _(IRPrinterLetTest01) \ + _(IRPrinterLetTest02) \ + _(IRPrinterCastTest) \ + _(ExprSimple01) \ + _(ExprLower01) \ + _(ExprSimple02) \ + _(ExprSplitWithTailNone) \ + _(ExprSplitWithMask01) \ + _(ScheduleBroadcastAddBuffer) \ + _(ScheduleFunctionCall01) \ + _(ScheduleInlineFunc01) \ + _(ScheduleFuserStyle) \ + _(ScheduleFuserThreeArg) \ + _(ScheduleDynamicShape2D) \ + _(TypeTest01) \ + _(TypePropagation) \ + _(Cond01) \ + _(IfThenElse01) \ + _(IfThenElse02) \ + _(ATen_cast_Float) \ + _(ATennegInt) \ + _(ATennegFloat) \ + _(ATenaddInt) \ + _(ATenaddFloat) \ + _(ATensubInt) \ + _(ATensubFloat) \ + _(ATenlerp) \ + _(ATenaddcmulInt) \ + _(ATenaddcmulFloat) \ + _(ATenmulInt) \ + _(ATenmulFloat) \ + _(ATendivInt) \ + _(ATendivFloat) \ + _(ATenmaxInt) \ + _(ATenmaxFloat) \ + _(ATenminInt) \ + _(ATenminFloat) \ + _(ATen_sigmoid_backward) \ + _(ATen_tanh_backward) \ + _(ATenreciprocal) \ + _(ATenreluInt) \ + _(ATenreluFloat) \ + _(ATenlogFloat) \ + _(ATenlog10Float) \ + _(ATenlog2Float) \ + _(ATenexpFloat) \ + _(ATenerfFloat) \ + _(ATencosFloat) \ + _(ATeneqInt) \ + _(ATengeInt) \ + _(ATengtInt) \ + _(ATenleInt) \ + _(ATenltInt) \ + _(ConstantFoldSimple) \ + _(ConstantFoldTwoLayer) \ + _(ConstantFoldShifts) \ + _(ConstantFoldBitwise) \ + _(ConstantFoldMultiOp) \ + _(ConstantFoldMinMax) \ + _(ConstantFoldIntrinsics) \ + _(ConstantFoldWithVar) \ + _(UnFoldableExpr) \ + _(HashSimple) \ + _(HashEquivalence) \ + _(HashEquivalenceAfterFolding) \ + _(HashDifferenceTypes) \ + _(HashLargeExpression) \ + _(SimplifyAdd) \ + _(SimplifySub) \ + _(SimplifyMultiLayer) \ + _(SimplifyMultiTerm) \ + _(SimplifyCasts) \ + _(SimplifyEliminatesNoOps) \ + _(SimplifyMultiVar) \ + _(SimplifyReorderings) \ + _(SimplifyEliminatesVar) \ + _(SimplifyAdds) \ + _(SimplifyMuls) \ + _(SimplifySubs) \ + _(SimplifyMultiOp) \ + _(SimplifyManyOps) \ + _(SimplifyFactorization) \ + _(SimplifyFactorizeUneven) \ + _(SimplifyDeeperTerms) \ + _(SimplifyDeeperDifference) \ + _(SimplifyFoldComplexDifference) \ + _(SimplifyIfComponents) \ + _(SimplifyOpaqueTerms) \ + _(SimplifyWontReorderFloat) \ _(StmtClone) #define TH_FORALL_TESTS_LLVM(_) \ diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index de4a15339ae1..b4e56c0e83c4 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -205,6 +205,7 @@ libtorch_sources = [ "torch/csrc/jit/tensorexpr/ir.cpp", "torch/csrc/jit/tensorexpr/ir_mutator.cpp", "torch/csrc/jit/tensorexpr/ir_printer.cpp", + "torch/csrc/jit/tensorexpr/ir_simplifier.cpp", "torch/csrc/jit/tensorexpr/ir_visitor.cpp", "torch/csrc/jit/tensorexpr/kernel.cpp", "torch/csrc/jit/tensorexpr/llvm_codegen.cpp", diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 68c8ad17534e..58360b766655 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -59,24 +59,24 @@ class Value { void* ptr; }; -#define VALUE_AS_DISPATCH(Type, Name) \ - template <> \ - inline Type Value::as() const { \ - if (dtype_ != k##Name) { \ - throw unsupported_dtype(); \ - } \ - return Name##values[0]; \ +#define VALUE_AS_DISPATCH(Type, Name) \ + template <> \ + inline Type Value::as() const { \ + if (dtype_ != k##Name) { \ + throw unsupported_dtype(); \ + } \ + return Name##values[0]; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH); #undef VALUE_AS_DISPATCH -#define VALUE_AS_VEC_DISPATCH(Type, Name) \ - template <> \ - inline const std::vector& Value::as_vec() const { \ - if (dtype_.scalar_type() != ScalarType::Name) { \ - throw unsupported_dtype(); \ - } \ - return Name##values; \ +#define VALUE_AS_VEC_DISPATCH(Type, Name) \ + template <> \ + inline const std::vector& Value::as_vec() const { \ + if (dtype_.scalar_type() != ScalarType::Name) { \ + throw unsupported_dtype(); \ + } \ + return Name##values; \ } AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH); #undef VALUE_AS_VEC_DISPATCH @@ -479,7 +479,6 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor { throw malformed_input(v); } - if (src_dtype != dst_dtype) { switch (src_dtype.scalar_type()) { #define SRC_TYPE_CASE(Type, Name) \ @@ -912,6 +911,28 @@ inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) { return stmt->accept_mutator(&var_sub); } +// Uses the evaluator to fold an Expression with constant terms. +// E.g. evaluateOp(Add(3, 4)) => 7. +// Expr v must not have any unbound Vars. +static Expr* evaluateOp(const Expr* v) { + ExprHandle handle(v); + ExprEval eval(handle); + + switch (v->dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + Type val = eval.value(); \ + return getImmediateByType(v->dtype().scalar_type(), val); \ + } + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unsupported datatype: " << v->dtype(); + return nullptr; + } + return nullptr; +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index ccda96cc1d81..00ffdaca6a74 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -31,6 +31,8 @@ enum IRNodeType { kCompareSelect, kLet, kCast, + kBroadcast, + kRamp, kNone }; diff --git a/torch/csrc/jit/tensorexpr/hash_server.h b/torch/csrc/jit/tensorexpr/hash_server.h index 7f5b5097489f..5f4b7c00c924 100644 --- a/torch/csrc/jit/tensorexpr/hash_server.h +++ b/torch/csrc/jit/tensorexpr/hash_server.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -316,6 +318,13 @@ class HashProvider : public IRVisitor { putHash(v, hash); } + template + SimplifierHashType hash_combine(const Types&... args) { + SimplifierHashType seed = 0; + _hash_combine(seed, args...); + return seed; + } + private: SimplifierHashType hashOf(const Expr* e) { auto it = exprToHash_.find(e); @@ -371,6 +380,10 @@ class HashProvider : public IRVisitor { (seed << 7) + (seed >> 4); } + void _hash_combine(SimplifierHashType& seed, const Expr* e) { + _hash_combine(seed, hash(e)); + } + template void _hash_combine( SimplifierHashType& seed, @@ -380,13 +393,6 @@ class HashProvider : public IRVisitor { _hash_combine(seed, args...); } - template - SimplifierHashType hash_combine(const Types&... args) { - SimplifierHashType seed = 0; - _hash_combine(seed, args...); - return seed; - } - void putHash(const KernelScopedObject* e, SimplifierHashType h) { auto res = exprToHash_.emplace(e, h); if (res.second == false) { diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index a71c5d30b204..136339f32e8f 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -283,24 +283,94 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); // Get immediate by ScalarType. template -ExprHandle getImmediateByType(ScalarType immType, T initialVal) { +Expr* getImmediateByType(ScalarType immType, T initialVal) { switch (immType) { #define TYPE_CASE(Type, Name) \ case ScalarType::Name: \ - return Name##Imm::make(initialVal); + return new Name##Imm(initialVal); AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); } - return ExprHandle(); + return nullptr; } template -ExprHandle getImmediateByType(Dtype dtype, T initialVal) { +Expr* getImmediateByType(Dtype dtype, T initialVal) { return getImmediateByType(dtype.scalar_type(), initialVal); } +template +T immediateAs(const Expr* e) { +#define TYPE_CASE(Type, Name) \ + if (const Name##Imm* imm = dynamic_cast(e)) { \ + return imm->value(); \ + } + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + throw unsupported_dtype(); + return 0; +} + +template +bool immediateEquals(const Expr* e, T val) { +#define TYPE_CASE(Type, Name) \ + if (const Name##Imm* imm = dynamic_cast(e)) { \ + return imm->value() == val; \ + } + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + throw unsupported_dtype(); + return false; +} + +template +bool immediateIsNegative(const T* e) { +#define TYPE_CASE(Type, Name) \ + if (const Name##Imm* imm = dynamic_cast(e)) { \ + return imm->value() < 0; \ + } + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + return false; +} + +// Creates a new Expr of the given type with the provided lhs and rhs. +static const Expr* newBinaryOpOfType( + IRNodeType expr_type, + const Expr* lhs, + const Expr* rhs, + bool option) { + switch (expr_type) { + case IRNodeType::kAdd: + return new Add(lhs, rhs); + case IRNodeType::kSub: + return new Sub(lhs, rhs); + case IRNodeType::kMul: + return new Mul(lhs, rhs); + case IRNodeType::kDiv: + return new Div(lhs, rhs); + case IRNodeType::kMod: + return new Mod(lhs, rhs); + case IRNodeType::kMax: + return new Max(lhs, rhs, option); + case IRNodeType::kMin: + return new Min(lhs, rhs, option); + case IRNodeType::kAnd: + return new And(lhs, rhs); + case IRNodeType::kXor: + return new Xor(lhs, rhs); + case IRNodeType::kLshift: + return new Lshift(lhs, rhs); + case IRNodeType::kRshift: + return new Rshift(lhs, rhs); + default: + LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); + return nullptr; + } +} + // Bind the value to the var and evaluate the body. class Let : public ExprNode { public: @@ -354,7 +424,7 @@ class Ramp : public ExprNode { } Ramp(const Expr* base, const Expr* stride, int lanes) - : ExprNodeBase(Dtype(base->dtype(), lanes)), + : ExprNodeBase(Dtype(base->dtype(), lanes), kRamp), base_(base), stride_(stride), lanes_(lanes) { @@ -420,7 +490,7 @@ class Broadcast : public ExprNode { return ExprHandle(new Broadcast(value.node(), lanes)); } Broadcast(const Expr* value, int lanes) - : ExprNodeBase(Dtype(value->dtype(), lanes)), + : ExprNodeBase(Dtype(value->dtype(), lanes), kBroadcast), value_(value), lanes_(lanes) {} @@ -562,8 +632,7 @@ class TORCH_API CompareSelect : public ExprNode { const ExprHandle& ret_val1, const ExprHandle& ret_val2, CompareSelectOperation cmp_op) { - if (lhs.dtype() != rhs.dtype() || - ret_val1.dtype() != ret_val2.dtype()) { + if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) { throw malformed_input(); } return ExprHandle(new CompareSelect( @@ -790,48 +859,8 @@ class Intrinsics : public CallNode { IntrinsicsOp op_type_; }; -/* An internal only Expr used in IR simplification. - * Encodes relationship y = Ax + B, where A and B are Immediates. - * Not required to be implemented by codegen. */ -class LinearForm : public ExprNode { - public: - LinearForm(const Expr* x, const Expr* A, const Expr* B) - : ExprNodeBase(dtypeFor(x, A, B)), x_(x), A_(A), B_(B) {} - - LinearForm(const Expr* x) - : ExprNodeBase(x->dtype()), - x_(x), - A_(new CharImm(1)), - B_(new CharImm(0)) {} - - const Expr* getX() const { - return x_; - } - const Expr* getA() const { - return A_; - } - const Expr* getB() const { - return B_; - } - - void setA(const Expr* A) { - A_ = A; - } - - void setB(const Expr* B) { - B_ = B; - } - - static Dtype dtypeFor(const Expr* A, const Expr* B, const Expr* C) { - return ToDtype(promoteTypes( - A->dtype().scalar_type(), promoteTypes(B->dtype(), C->dtype()))); - } - - private: - const Expr* x_; - const Expr* A_; - const Expr* B_; -}; +class Polynomial; +class Term; class FunctionCall; diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 716658e33c2c..041e714c170e 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace torch { namespace jit { @@ -214,6 +215,7 @@ const Expr* IRMutator::mutate(const IfThenElse* v) { const Expr* condition_new = condition->accept_mutator(this); const Expr* true_value_new = true_value->accept_mutator(this); const Expr* false_value_new = false_value->accept_mutator(this); + if (condition == condition_new && true_value == true_value_new && false_value == false_value_new) { return v; @@ -232,11 +234,24 @@ const Expr* IRMutator::mutate(const FunctionCall* v) { return this->mutate(base); } -const Expr* IRMutator::mutate(const LinearForm* v) { - const Expr* new_x = v->getX()->accept_mutator(this); - const Expr* new_a = v->getA()->accept_mutator(this); - const Expr* new_b = v->getB()->accept_mutator(this); - return new LinearForm(new_x, new_a, new_b); +const Expr* IRMutator::mutate(const Term* v) { + const Expr* newScalar = v->scalar()->accept_mutator(this); + + std::vector variables; + for (const auto* t : v->variables()) { + variables.push_back(t->accept_mutator(this)); + } + return new Term(v->hasher(), newScalar, variables); +} + +const Expr* IRMutator::mutate(const Polynomial* v) { + const Expr* newScalar = v->scalar()->accept_mutator(this); + + std::vector variables; + for (const auto* t : v->variables()) { + variables.push_back(static_cast(t->accept_mutator(this))); + } + return new Polynomial(v->hasher(), newScalar, variables); } const Expr* IRMutator::mutate(const BaseCallNode* v) { diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index a9aff366ad95..29e0a1f5dd20 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -45,7 +45,8 @@ class Allocate; class Free; class Cond; class Stmt; -class LinearForm; +class Term; +class Polynomial; class TORCH_API IRMutator { public: @@ -86,7 +87,8 @@ class TORCH_API IRMutator { virtual const Expr* mutate(const Intrinsics* v); virtual const Expr* mutate(const FunctionCall* v); - virtual const Expr* mutate(const LinearForm* v); + virtual const Expr* mutate(const Term* v); + virtual const Expr* mutate(const Polynomial* v); virtual Stmt* mutate(const For* v); virtual Stmt* mutate(const Block* v); diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 1260629c6a88..7c377ebb1430 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -1,5 +1,7 @@ #include +#include + namespace torch { namespace jit { namespace tensorexpr { @@ -198,7 +200,7 @@ template < typename T, std::enable_if_t::value>* = nullptr> static void formatImm(std::ostream& os, T v) { - os << v; + os << +v; } // NOLINTNEXTLINE @@ -368,9 +370,33 @@ void IRPrinter::visit(const Cond* v) { } } -void IRPrinter::visit(const LinearForm* v) { - os() << "(" << *v->getA() << ") * (" << *v->getX() << ") + (" << *v->getB() - << ")" << std::endl; +void IRPrinter::visit(const Term* v) { + os() << "Term("; + v->scalar()->accept(this); + for (auto* t : v->variables()) { + os() << ","; + t->accept(this); + } + os() << ")"; +} + +void IRPrinter::visit(const Polynomial* v) { + bool first = true; + os() << "Polynomial("; + for (auto* t : v->variables()) { + emitIndent(); + if (!first) { + os() << " + "; + } + first = false; + t->accept(this); + } + + if (!first) { + os() << " + "; + } + v->scalar()->accept(this); + os() << ")"; } void IRPrinter::emitIndent() { diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h index 82ccf086258b..f0af7b6dc45f 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.h +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -48,7 +48,8 @@ class TORCH_API IRPrinter : public IRVisitor { void visit(const Allocate* v) override; void visit(const Free* v) override; void visit(const Cond* v) override; - void visit(const LinearForm* v) override; + void visit(const Term* v) override; + void visit(const Polynomial* v) override; std::ostream& os() { return printer_os_; diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp new file mode 100644 index 000000000000..aa308204aa53 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -0,0 +1,1086 @@ +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +SimplifierHashType Term::hashVars() const { + SimplifierHashType hash = 0; + for (auto* v : variables_) { + hash = hasher_.hash_combine(hash, hasher_.hash(v)); + } + + return hash; +} + +void Term::sort() { + // order of ops important for float + if (dtype().is_floating_point()) { + throw std::logic_error("reordering FP ops"); + } + std::sort( + variables_.begin(), variables_.end(), [&](const Expr* a, const Expr* b) { + return hasher_.hash(a) < hasher_.hash(b); + }); +} + +SimplifierHashType Polynomial::hashVars() const { + SimplifierHashType hash = 0; + for (auto* v : variables_) { + hash = hasher_.hash_combine(hash, hasher_.hash(v)); + } + return hash; +} + +void Polynomial::sort() { + if (dtype().is_floating_point()) { + throw std::logic_error("reordering FP ops"); + } + std::sort( + variables_.begin(), variables_.end(), [&](const Expr* a, const Expr* b) { + return hasher_.hash(a) < hasher_.hash(b); + }); +} + +// Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp +template +const Expr* combineMultilane(const Expr* lhs, const Expr* rhs) { + if (const Broadcast* bc = dynamic_cast(lhs)) { + if (const Broadcast* bcother = dynamic_cast(rhs)) { + if (bc->lanes() != bcother->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = + new Broadcast(new Op(bc->value(), bcother->value()), bc->lanes()); + return ret; + } + + if (const Ramp* r = dynamic_cast(rhs)) { + if (bc->lanes() != r->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = + new Ramp(new Op(bc->value(), r->base()), r->stride(), r->lanes()); + return ret; + } + } else if (const Ramp* ramp = dynamic_cast(lhs)) { + if (const Ramp* rother = dynamic_cast(rhs)) { + if (ramp->lanes() != rother->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = new Ramp( + new Op(ramp->base(), rother->base()), + new Op(ramp->stride(), rother->stride()), + ramp->lanes()); + return ret; + } + + if (const Broadcast* bc = dynamic_cast(rhs)) { + if (ramp->lanes() != bc->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + const Expr* ret = new Ramp( + new Op(bc->value(), ramp->base()), ramp->stride(), ramp->lanes()); + return ret; + } + } + + return nullptr; +} + +// Handles optimization cases for Broadcast/Ramp * Broadcast/Ramp +const Expr* mulMultilane(const Expr* lhs, const Expr* rhs) { + if (const Broadcast* bc = dynamic_cast(lhs)) { + if (const Broadcast* bcother = dynamic_cast(rhs)) { + if (bc->lanes() != bcother->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = + new Broadcast(new Mul(bc->value(), bcother->value()), bc->lanes()); + return ret; + } + + if (const Ramp* r = dynamic_cast(rhs)) { + if (bc->lanes() != r->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = new Ramp( + new Mul(bc->value(), r->base()), + new Mul(bc->value(), r->stride()), + r->lanes()); + return ret; + } + } else if (const Ramp* ramp = dynamic_cast(lhs)) { + if (const Ramp* r = dynamic_cast(rhs)) { + if (ramp->lanes() != r->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = new Ramp( + new Mul(ramp->base(), r->base()), + new Mul(ramp->stride(), r->stride()), + r->lanes()); + return ret; + } + + if (const Broadcast* bc = dynamic_cast(rhs)) { + if (ramp->lanes() != bc->lanes()) { + throw malformed_input("multilane lane mismatch"); + } + + const Expr* ret = new Ramp( + new Mul(bc->value(), ramp->base()), + new Mul(bc->value(), ramp->stride()), + ramp->lanes()); + return ret; + } + } + + return nullptr; +} + +void PolynomialTransformer::addOrUpdateTerm( + std::unordered_map& varmap, + const Term* term) { + SimplifierHashType hash = term->hashVars(); + auto insertRes = varmap.emplace(hash, term); + if (insertRes.second == false) { + const Term* lt = insertRes.first->second; + const Expr* termScalar = evaluateOp(new Add(lt->scalar(), term->scalar())); + + // If the term is canceled out, remove from the map. + if (immediateEquals(termScalar, 0)) { + varmap.erase(hash); + return; + } + + varmap[hash] = new Term(hasher_, termScalar, lt->variables()); + } +} + +const Expr* PolynomialTransformer::addPolynomials( + const Polynomial* lhs, + const Polynomial* rhs) { + // simplify common components + // The key here is the variable hash, not the term's hash since we do want + // to combine terms that have the same vars but different scalar components. + std::unordered_map varmap; + + for (auto* lt : lhs->variables()) { + addOrUpdateTerm(varmap, lt); + } + for (auto* rt : rhs->variables()) { + addOrUpdateTerm(varmap, rt); + } + + const Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar())); + return new Polynomial(hasher_, newScalar, varmap); +} + +// Insert a new Term into the provided polynomial. If the new term has common +// variables to an existing term it is combined. +const Expr* PolynomialTransformer::insertTerm( + const Polynomial* poly, + const Term* term) { + SimplifierHashType tHash = term->hashVars(); + std::vector newVars; + + bool found = false; + for (auto* v : poly->variables()) { + if (v->hashVars() == tHash) { + const Expr* newScalar = evaluateOp(new Add(term->scalar(), v->scalar())); + found = true; + // Skip this term if we cancelled it out. + if (immediateEquals(newScalar, 0)) { + continue; + } + auto* term = new Term(hasher_, newScalar, v->variables()); + newVars.push_back(term); + } else { + newVars.push_back(v); + } + } + + if (!found) { + newVars.push_back(term); + } + + if (newVars.empty()) { + return poly->scalar(); + } + + auto* Poly = new Polynomial(hasher_, poly->scalar(), newVars); + return Poly; +} + +const Expr* PolynomialTransformer::mutate(const Add* v) { + const Expr* lhs_new = v->lhs()->accept_mutator(this); + const Expr* rhs_new = v->rhs()->accept_mutator(this); + + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + const Expr* result = evaluateOp(new Add(lhs_new, rhs_new)); + return result; + } + + // Multilane folding. + if (isMultilanePrimitive(lhs_new)) { + if (auto* ret = combineMultilane(lhs_new, rhs_new)) { + return ret->accept_mutator(this); + } + } + + // If this is a floating point Add then order of operations is important, we + // dont want to combine ops. + if (lhs_new->dtype().is_floating_point() || + rhs_new->dtype().is_floating_point()) { + return new Add(lhs_new, rhs_new); + } + + const Polynomial* lhsPoly = dynamic_cast(lhs_new); + const Polynomial* rhsPoly = dynamic_cast(rhs_new); + + if (lhsPoly && rhsPoly) { + return addPolynomials(lhsPoly, rhsPoly); + } + + const Term* lhsTerm = dynamic_cast(lhs_new); + const Term* rhsTerm = dynamic_cast(rhs_new); + + if (lhsPoly && rhsTerm) { + return insertTerm(lhsPoly, rhsTerm); + } + + if (rhsPoly && lhsTerm) { + return insertTerm(rhsPoly, lhsTerm); + } + + if (lhsTerm && rhsTerm) { + // If the terms refer to the same variables: combine them. + if (lhsTerm->hashVars() == rhsTerm->hashVars()) { + const Expr* newScalar = + evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())); + + // If the terms cancelled out, return zero. + if (immediateEquals(newScalar, 0)) { + return newScalar->accept_mutator(this); + } + + return new Term(hasher_, newScalar, lhsTerm->variables()); + } + + // Otherwise this is a new polynomial with no scalar and two variable + // terms. + return new Polynomial( + hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + } + + const Expr* scalar = nullptr; + const Expr* variable = nullptr; + if (lhs_new->isConstant()) { + scalar = evaluateOp(lhs_new); + variable = rhs_new; + } else if (rhs_new->isConstant()) { + scalar = evaluateOp(rhs_new); + variable = lhs_new; + } + + // If there is a scalar, and it's zero: short circuit and return the other + // side. + if (scalar && immediateEquals(scalar, 0)) { + return variable; + } + + // Adds are commutative. + const Polynomial* poly = lhsPoly ? lhsPoly : rhsPoly; + + // Add to Polynomial->scalar(). + if (scalar && poly) { + const Expr* newScalar = evaluateOp(new Add(scalar, poly->scalar())); + return new Polynomial(hasher_, newScalar, poly->variables()); + } + + // Simple Polynomial with a scalar and Term. + const Term* term = lhsTerm ? lhsTerm : rhsTerm; + if (scalar && term) { + return new Polynomial(hasher_, scalar, term); + } + + // Simple Term with a scalar and variable type. + if (scalar) { + return new Polynomial( + hasher_, + scalar, + new Term(hasher_, getImmediateByType(v->dtype(), 1), variable)); + } + + // If LHS is neither Term not Polynomial, wrap it in a Term. + if (!lhsTerm && !lhsPoly) { + lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + } + + // Same for RHS. + if (!rhsTerm && !rhsPoly) { + rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), rhs_new); + } + + // If we now have a poly and a term, we can insert. + if (poly) { + return insertTerm(poly, lhsTerm ? lhsTerm : rhsTerm); + } + + // If all else fails we have a new Polynomial with two new variable Terms. + return new Polynomial( + hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); +} + +const Expr* PolynomialTransformer::subTerms( + const Term* lhs, + const Term* rhs, + bool negated) { + // If RHS not already negated, negate it. + if (!negated) { + const Expr* minusOne = getImmediateByType(rhs->dtype(), -1); + const Expr* negateScalar = evaluateOp(new Mul(minusOne, rhs->scalar())); + rhs = new Term(hasher_, negateScalar, rhs->variables()); + } + + if (lhs->hashVars() == rhs->hashVars()) { + const Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar())); + + // If the terms cancel out, return zero. + if (immediateEquals(newScalar, 0)) { + return newScalar; + } + + return new Term(hasher_, newScalar, lhs->variables()); + } + + return new Polynomial( + hasher_, + getImmediateByType(promoteTypes(lhs->dtype(), rhs->dtype()), 0), + lhs, + rhs); +} + +// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where +// possible. +const Expr* PolynomialTransformer::subPolynomials( + const Polynomial* lhs, + const Polynomial* rhs) { + // simplify common components + // The key here is the variable hash, not the term's hash since we do want + // to combine terms that have the same vars but different scalar components. + std::unordered_map varmap; + + for (auto* lt : lhs->variables()) { + addOrUpdateTerm(varmap, lt); + } + + for (auto* rt : rhs->variables()) { + // Polynomials add their terms, so negate the RHS's Terms. + const Expr* negated = + evaluateOp(new Mul(getImmediateByType(rt->dtype(), -1), rt->scalar())); + Term* newRHS = new Term(hasher_, negated, rt->variables()); + addOrUpdateTerm(varmap, newRHS); + } + + const Expr* newScalar = evaluateOp(new Sub(lhs->scalar(), rhs->scalar())); + + // No vars means this cancelled out to a scalar, return it unwrapped. + if (varmap.empty()) { + return newScalar; + } + + // If there is no scalar and zero or one terms, don't wrap. + if (immediateEquals(newScalar, 0)) { + if (varmap.empty()) { + return nullptr; + } + if (varmap.size() == 1) { + return varmap.begin()->second; + } + } + + // Wrap new variables in a Polynomial. + return new Polynomial(hasher_, newScalar, varmap); +} + +const Expr* PolynomialTransformer::mutate(const Sub* v) { + const Expr* lhs_new = v->lhs()->accept_mutator(this); + const Expr* rhs_new = v->rhs()->accept_mutator(this); + + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + const Expr* result = evaluateOp(new Sub(lhs_new, rhs_new)); + return result; + } + + // Multilane folding. + if (isMultilanePrimitive(lhs_new)) { + if (auto* ret = combineMultilane(lhs_new, rhs_new)) { + return ret->accept_mutator(this); + } + } + + // If this is a floating point Sub then order of operations is important, we + // dont want to combine ops. + if (lhs_new->dtype().is_floating_point() || + rhs_new->dtype().is_floating_point()) { + return new Sub(lhs_new, rhs_new); + } + + const Polynomial* lhsPoly = dynamic_cast(lhs_new); + const Polynomial* rhsPoly = dynamic_cast(rhs_new); + + if (lhsPoly && rhsPoly) { + auto* ret = subPolynomials(lhsPoly, rhsPoly); + if (!ret) { + // Cancelled out completely. + return getImmediateByType(v->dtype(), 0); + } + return ret; + } + + const Term* lhsTerm = dynamic_cast(lhs_new); + const Term* rhsTerm = dynamic_cast(rhs_new); + + // Polynomial - Term. + if (lhsPoly && rhsTerm) { + // Negate the term. + const Expr* negate = evaluateOp( + new Mul(getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar())); + const Term* newTerm = new Term(hasher_, negate, rhsTerm->variables()); + return insertTerm(lhsPoly, newTerm); + } + + // Term - Polynomial. + if (rhsPoly && lhsTerm) { + // Negate every part of the Polynomial. + const Expr* minusOne = getImmediateByType(lhsTerm->dtype(), -1); + const Expr* negateScalar = evaluateOp(new Mul(minusOne, lhsTerm->scalar())); + + std::vector variables; + for (auto* t : lhsPoly->variables()) { + const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); + variables.push_back(new Term(hasher_, negate, t->variables())); + } + + Polynomial* newPoly = new Polynomial(hasher_, negateScalar, variables); + return insertTerm(newPoly, lhsTerm); + } + + if (lhsTerm && rhsTerm) { + return subTerms(lhsTerm, rhsTerm, false); + } + + bool lhsScalar = lhs_new->isConstant(); + bool rhsScalar = rhs_new->isConstant(); + + if (lhsPoly && rhsScalar) { + // Easy path, just sub the scalar component. + const Expr* newScalar = evaluateOp(new Sub(lhsPoly->scalar(), rhs_new)); + return new Polynomial(hasher_, newScalar, lhsPoly->variables()); + } + + if (lhsScalar && rhsPoly) { + // Sub the scalar component. + const Expr* newScalar = evaluateOp(new Sub(lhs_new, rhsPoly->scalar())); + + // Negate each term in the Polynomial RHS. + const Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1); + std::vector variables; + for (auto* t : rhsPoly->variables()) { + const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); + variables.push_back(new Term(hasher_, negate, t->variables())); + } + + return new Polynomial(hasher_, newScalar, variables); + } + + if (lhsTerm && rhsScalar) { + // Negate the constant. + const Expr* negate = + evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + return new Polynomial(hasher_, negate, lhsTerm); + } + + if (lhsScalar && rhsTerm) { + // Negate the RHS Term. + const Expr* negate = evaluateOp(new Mul( + getImmediateByType(rhsTerm->scalar()->dtype(), -1), rhsTerm->scalar())); + + return new Polynomial( + hasher_, lhs_new, new Term(hasher_, negate, rhsTerm->variables())); + } + + // simple term with a scalar and variable type. + if (lhsScalar) { + // Create a negated term. + return new Polynomial( + hasher_, + lhs_new, + new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new)); + } + + if (rhsScalar) { + // Negate the scalar. + const Expr* negate = + evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + return new Polynomial( + hasher_, + negate, + new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new)); + } + + // no scalar... + if (!lhsTerm && !lhsPoly) { + lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + } + + bool createdRHSnegated = false; + if (!rhsTerm && !rhsPoly) { + rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new); + createdRHSnegated = true; + } + + if (lhsTerm && rhsTerm) { + return subTerms(lhsTerm, rhsTerm, createdRHSnegated); + } + + // Insert wrapped Term into LHS Polynomial. + if (lhsPoly) { + CHECK(rhsTerm); + return insertTerm(lhsPoly, rhsTerm); + } + + // Insert wrapper Term into negated RHS Poly. + if (rhsPoly) { + CHECK(lhsTerm); + const Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1); + const Expr* newScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar())); + + // Negate each term in the Polynomial RHS. + std::vector variables; + for (auto* t : rhsPoly->variables()) { + const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); + variables.push_back(new Term(hasher_, negate, t->variables())); + } + + auto* poly = new Polynomial(hasher_, newScalar, variables); + return insertTerm(poly, lhsTerm); + } + + return new Polynomial( + hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); +} + +// Multiply two terms together, usually creating a new term with the variable +// lists concatenated. +const Term* PolynomialTransformer::mulTerms(const Term* lhs, const Term* rhs) { + const Expr* scalar = evaluateOp(new Mul(lhs->scalar(), rhs->scalar())); + if (immediateEquals(scalar, 0)) { + return nullptr; + } + + // Can reorder here since floating point ops don't get put into Terms. + std::vector variables; + std::vector multilaneVariables; + // For now don't handle exponents. + for (auto* c : lhs->variables()) { + if (isMultilanePrimitive(c)) { + multilaneVariables.push_back(c); + } else { + variables.push_back(c); + } + } + for (auto* c : rhs->variables()) { + if (isMultilanePrimitive(c)) { + multilaneVariables.push_back(c); + } else { + variables.push_back(c); + } + } + + // Merge all the multilane vars: + const Expr* lastNode{nullptr}; + for (auto* node : multilaneVariables) { + if (lastNode == nullptr) { + lastNode = node; + } else { + if (auto* next = mulMultilane(lastNode, node)) { + lastNode = next->accept_mutator(this); + } else { + variables.push_back(lastNode); + lastNode = node; + } + } + } + if (lastNode) { + variables.push_back(lastNode); + } + + return new Term(hasher_, scalar, variables); +} + +// Multiply a Polynomial by a Term. +const Expr* PolynomialTransformer::polyByTerm( + const Polynomial* poly, + const Term* term) { + std::vector newTerms; + + // scalar Term + const Expr* scalar = evaluateOp(new Mul(poly->scalar(), term->scalar())); + + for (auto* var : poly->variables()) { + const Term* newTerm = mulTerms(var, term); + if (newTerm) { + newTerms.push_back(newTerm); + } + } + + if (newTerms.empty()) { + return scalar; + } + + return new Polynomial(hasher_, scalar, std::move(newTerms)); +} + +const Expr* PolynomialTransformer::mutate(const Mul* v) { + const Expr* lhs_new = v->lhs()->accept_mutator(this); + const Expr* rhs_new = v->rhs()->accept_mutator(this); + + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + return evaluateOp(new Mul(lhs_new, rhs_new)); + } + + // Multilane folding. + if (isMultilanePrimitive(lhs_new)) { + if (auto* ret = mulMultilane(lhs_new, rhs_new)) { + return ret->accept_mutator(this); + } + } + + // If this is a floating point Mul then order of operations is important, we + // dont want to combine ops. + if (lhs_new->dtype().is_floating_point() || + rhs_new->dtype().is_floating_point()) { + return new Mul(lhs_new, rhs_new); + } + + const Polynomial* lhsPoly = dynamic_cast(lhs_new); + const Polynomial* rhsPoly = dynamic_cast(rhs_new); + + if (lhsPoly && rhsPoly) { + // This expands to more terms that we can't generally fix without variable + // factorization, it's more efficient to just leave these as Muls. + return new Mul(lhsPoly, rhsPoly); + } + + const Term* lhsTerm = dynamic_cast(lhs_new); + const Term* rhsTerm = dynamic_cast(rhs_new); + + if (lhsPoly && rhsTerm) { + return polyByTerm(lhsPoly, rhsTerm); + } + + if (rhsPoly && lhsTerm) { + return polyByTerm(rhsPoly, lhsTerm); + } + + if (lhsTerm && rhsTerm) { + return mulTerms(lhsTerm, rhsTerm); + } + + const Expr* scalar = nullptr; + const Expr* variable = nullptr; + if (lhs_new->isConstant()) { + scalar = lhs_new; + variable = rhs_new; + } else if (rhs_new->isConstant()) { + scalar = rhs_new; + variable = lhs_new; + } + + if (scalar && lhsTerm) { + const Expr* newScalar = evaluateOp(new Mul(scalar, lhsTerm->scalar())); + if (immediateEquals(newScalar, 0)) { + return newScalar; + } + return new Term(hasher_, newScalar, lhsTerm->variables()); + } + + if (scalar && rhsTerm) { + const Expr* newScalar = evaluateOp(new Mul(scalar, rhsTerm->scalar())); + + if (immediateEquals(newScalar, 0)) { + return newScalar; + } + return new Term(hasher_, newScalar, rhsTerm->variables()); + } + + // If this is a scalar * a Polynomial, push the scalar term down. + // We can wrap the scalar with a Term and use polyByTerm. + if (scalar && lhsPoly) { + return polyByTerm(lhsPoly, new Term(hasher_, scalar)); + } + if (scalar && rhsPoly) { + return polyByTerm(rhsPoly, new Term(hasher_, scalar)); + } + + // simple term with a scalar and variable type. + if (scalar) { + return new Term(hasher_, scalar, variable); + } + + // Multiplying Polynomial by variable can be wrapped in a term and handled + // by polyByTerm also. + if (lhsPoly) { + auto* term = + new Term(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new); + return polyByTerm(lhsPoly, term); + } + if (rhsPoly) { + auto* term = + new Term(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new); + return polyByTerm(rhsPoly, term); + } + + // Multiplying Term by a variable is equivalent to adding the variable to + // the term's list of vars. + if (lhsTerm) { + std::vector vars = lhsTerm->variables(); + vars.push_back(rhs_new); + return new Term(hasher_, lhsTerm->scalar(), vars); + } + if (rhsTerm) { + std::vector vars = rhsTerm->variables(); + vars.push_back(lhs_new); + return new Term(hasher_, rhsTerm->scalar(), vars); + } + + // Two variables, create a new Term. + return new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new); +} + +const Expr* PolynomialTransformer::mutate(const Intrinsics* v) { + std::vector new_params; + bool changed = false; + bool allConstant = true; + for (const auto* p : v->params()) { + const Expr* new_child = p->accept_mutator(this); + new_params.push_back(new_child); + + changed |= p != new_child; + allConstant &= new_child->isConstant(); + } + + const Expr* node = v; + if (changed) { + node = new Intrinsics(v->op_type(), new_params); + } + + if (!allConstant || !v->isPure()) { + return node; + } + + // we're evaluating, but the evaluator only supports float intrinsics. + std::vector const_params; + changed = false; + for (const auto* p : new_params) { + if (p->dtype().scalar_type() == ScalarType::Float) { + const_params.push_back(p); + } else { + const_params.push_back( + new Cast(Dtype(ScalarType::Float, p->dtype().lanes()), p)); + changed = true; + } + } + + if (changed) { + node = new Intrinsics(v->op_type(), const_params); + } + return evaluateOp(node); +} + +const Expr* PolynomialTransformer::mutate(const Cast* v) { + const Expr* node = v->src_value()->accept_mutator(this); + if (node->isConstant()) { + return evaluateOp(new Cast(v->dtype(), node)); + } + + return new Cast(v->dtype(), node); +} + +// TermExpander + +const Expr* TermExpander::mutate(const Term* v) { + const Expr* newScalar = v->scalar()->accept_mutator(this); + if (immediateEquals(newScalar, 0)) { + return newScalar; + } + + std::vector vars; + std::vector multilaneVars; + + // Assume we can reorder here because we wont merge floating terms. + const Expr* lastNode{nullptr}; + for (auto* var : v->variables()) { + const Expr* node = var->accept_mutator(this); + if (const Mul* mul = dynamic_cast(node)) { + // If the sub-Expr resolved to a multiplication, lift it into this + // term. + if (isMultilanePrimitive(mul->lhs())) { + multilaneVars.push_back(mul->lhs()); + } else { + vars.push_back(mul->lhs()); + } + + if (isMultilanePrimitive(mul->rhs())) { + multilaneVars.push_back(mul->rhs()); + } else { + vars.push_back(mul->lhs()); + } + } else { + if (isMultilanePrimitive(node)) { + multilaneVars.push_back(node); + } else { + vars.push_back(node); + } + } + } + + for (auto* node : multilaneVars) { + if (lastNode == nullptr) { + lastNode = node; + } else { + lastNode = mulMultilane(lastNode, node); + // simplify first, then re-expand. + lastNode = lastNode->accept_mutator(simplifier_); + lastNode = lastNode->accept_mutator(this); + } + } + + for (auto* node : vars) { + if (lastNode == nullptr) { + lastNode = node; + } else { + lastNode = new Mul(lastNode, node); + } + } + + if (!immediateEquals(newScalar, 1)) { + if (lastNode) { + // We want to avoid a leaving a CastNode on the scalar, so handle that + // now. + if (v->scalar()->dtype() != lastNode->dtype()) { + lastNode = new Mul( + evaluateOp(new Cast(lastNode->dtype(), v->scalar())), lastNode); + } else { + lastNode = new Mul(v->scalar(), lastNode); + } + } else { + lastNode = v->scalar(); + } + } + + return lastNode; +} + +// Simple recursive GCD. +template +T gcd(T a, T b) { + if (b == 0) { + return a; + } + return gcd(b, a % b); +} + +// Returns an immediate containing the greatest common divisor of all terms +// (inc. the scalar term) in the polynomial. If the GCD is uninteresting +// (e.g. 1) then returns nullptr. +const Expr* polyGCD(const Polynomial* poly) { + const Expr* scalar = poly->scalar(); + const std::vector& variables = poly->variables(); + + // We ony want to factorize if we're saving complete operations, i.e. no + // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work. + int opsSaved = 1; // default to saving the scalar. + long GCD = immediateAs(scalar); + for (auto* t : variables) { + long termScalar = immediateAs(t->scalar()); + long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar)); + if (newGCD == 1) { + return nullptr; + } + + if (GCD != newGCD) { + opsSaved = 0; + GCD = newGCD; + } + + if (GCD == termScalar) { + opsSaved++; + } + } + + if (opsSaved == 0) { + return nullptr; + } + + // Not worth, can be a Sub. + if (GCD == -1 && opsSaved == 1) { + return nullptr; + } + + return getImmediateByType(poly->dtype(), GCD); +} + +// Trivially factorize terms by GCD of scalar components. +const Expr* TermExpander::factorizePolynomial(const Polynomial* poly) { + const Expr* scalar = poly->scalar(); + const std::vector& variables = poly->variables(); + bool floatScalars = false; + + // Check types. + for (auto& p : variables) { + if (is_floating_point(p->dtype().scalar_type()) || + is_floating_point(p->scalar()->dtype().scalar_type())) { + floatScalars = true; + } + } + if (is_floating_point(scalar->dtype().scalar_type())) { + floatScalars = true; + } + + // floating point isn't generally distributive. + if (floatScalars) { + return nullptr; + } + + // Compute the GCD of terms. + const Expr* GCD = polyGCD(poly); + + // No GCD means 0 or 1 and can't be factored. + if (!GCD) { + return nullptr; + } + + // Create new struture. + std::vector newPolyTerms; + for (auto* t : variables) { + // New term with the scalar divided by the GCD. + newPolyTerms.push_back(new Term( + poly->hasher(), evaluateOp(new Div(t->scalar(), GCD)), t->variables())); + } + + Polynomial* newPoly = new Polynomial( + poly->hasher(), evaluateOp(new Div(scalar, GCD)), newPolyTerms); + + return new Term(poly->hasher(), GCD, newPoly); +} + +const Expr* TermExpander::mutate(const Polynomial* v) { + if (v->variables().empty()) { + return v->scalar(); + } + + // If this Polynomial can be factorized: do it, then expand the result. + if (const Expr* factorized = factorizePolynomial(v)) { + return factorized->accept_mutator(this); + } + + std::vector addTerms; + std::vector subTerms; + + // partition the terms into a list to add and list to subtract. + for (auto* node : v->variables()) { + if (immediateIsNegative(node->scalar())) { + subTerms.push_back(node); + } else if (!immediateEquals(node->scalar(), 0)) { + addTerms.push_back(node); + } + // Skip terms with a scalar of zero. + } + + // The last node constructed. + const Expr* lastNode{nullptr}; + + for (auto* node : addTerms) { + const Expr* simpleNode = node->accept_mutator(this); + + if (lastNode == nullptr) { + lastNode = simpleNode; + continue; + } + + if (isMultilanePrimitive(simpleNode)) { + auto* ret = combineMultilane(lastNode, simpleNode); + if (ret) { + // simplify result first, then expand. + lastNode = ret->accept_mutator(simplifier_); + lastNode = lastNode->accept_mutator(this); + continue; + } + } + + lastNode = new Add(lastNode, simpleNode); + } + + // If we have no add terms the scalar should go first. + // E.g. 1 - x. + bool scalarWritten = false; + if (lastNode == nullptr) { + auto* scalarNode = v->scalar()->accept_mutator(simplifier_); + + if (!immediateEquals(scalarNode, 0)) { + lastNode = scalarNode; + scalarWritten = true; + } + } + + for (auto* node : subTerms) { + // Can still be first node if scalarVal is 0. + if (lastNode == nullptr) { + lastNode = node->accept_mutator(this); + continue; + } + + // Negate the term back to positive since we'll be subtracting it. + const Expr* negated = evaluateOp(new Mul( + getImmediateByType(node->scalar()->dtype(), -1), node->scalar())); + Term* newRHS = new Term(node->hasher(), negated, node->variables()); + lastNode = new Sub(lastNode, newRHS->accept_mutator(this)); + } + + if (scalarWritten || immediateEquals(v->scalar(), 0)) { + return lastNode; + } + + if (immediateIsNegative(v->scalar())) { + // Negate the scalar and subtract. + const Expr* negated = evaluateOp( + new Mul(getImmediateByType(lastNode->dtype(), -1), v->scalar())); + lastNode = new Sub(lastNode, evaluateOp(negated)); + } else { + // we want to avoid a cast to the scalar if it would happen. + if (v->scalar()->dtype() != lastNode->dtype()) { + lastNode = new Add( + lastNode, evaluateOp(new Cast(lastNode->dtype(), v->scalar()))); + } else { + lastNode = new Add(lastNode, v->scalar()); + } + } + + return lastNode; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h index 1fa1757161de..5364db846e27 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.h +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h @@ -1,368 +1,276 @@ #pragma once -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir_mutator.h" -#include "torch/csrc/jit/tensorexpr/ir_visitor.h" -#include "torch/csrc/jit/tensorexpr/types.h" +#include +#include +#include +#include +#include +#include + +/* IR Simplification + * + * Simplfies expressions in two stages: + * 1. Recursively traverse the map combining similar operations into Terms + * (interacted via Multiplication) and Polynomials (interacted via Addition). We + * reorder the components of each Term or Polynomial into a consistent order to + * allow combination or cancelling of like terms. + * 2. Once the format of the tree is minimal, expand each Term into a sequence + * of Muls, and each Polynomial into a sequence of Ads. + */ namespace torch { namespace jit { namespace tensorexpr { -// Uses the evaluator to fold an operation with constant terms. -// Expr v must be evaluatable without Vars. -static Expr* evaluateOp(const Expr* v) { - ExprHandle handle(v); - ExprEval eval(handle); +// A bunch of helpers for determine the Dtype of the output of a multi argument +// Term or Polynomial. +namespace { +template +Dtype promoteTypesVec(const Expr* s, std::vector& v) { + Dtype t = s->dtype(); + bool first = true; - switch (v->dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: { \ - Type val = eval.value(); \ - return getImmediateByType(v->dtype().scalar_type(), val).node(); \ - } - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); -#undef TYPE_CASE - default: - LOG(FATAL) << "Unsupported datatype: " << v->dtype(); - return nullptr; - } - return nullptr; -} // namespace tensorexpr - -static const Expr* newBinaryOpOfType( - IRNodeType expr_type, - const Expr* lhs, - const Expr* rhs, - bool option) { - switch (expr_type) { - case IRNodeType::kAdd: - return new Add(lhs, rhs); - case IRNodeType::kSub: - return new Sub(lhs, rhs); - case IRNodeType::kMul: - return new Mul(lhs, rhs); - case IRNodeType::kDiv: - return new Div(lhs, rhs); - case IRNodeType::kMod: - return new Mod(lhs, rhs); - case IRNodeType::kMax: - return new Max(lhs, rhs, option); - case IRNodeType::kMin: - return new Min(lhs, rhs, option); - case IRNodeType::kAnd: - return new And(lhs, rhs); - case IRNodeType::kXor: - return new Xor(lhs, rhs); - case IRNodeType::kLshift: - return new Lshift(lhs, rhs); - case IRNodeType::kRshift: - return new Rshift(lhs, rhs); - default: - LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); - return nullptr; - } -} - -/* Interprets expr as an Immediate and returns the value as type T. */ -template -T immediateAs(const Expr* expr) { - T val{0}; - switch (expr->dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - if (const Name##Imm* imm = dynamic_cast(expr)) { \ - val = imm->value(); \ - } else { \ - LOG(FATAL) << "Bad expr: " << *expr << "\n"; \ - } \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE); -#undef TYPE_CASE - default: - LOG(FATAL) << "Unsupported datatype: " << expr->dtype(); - } - - return val; -} - -/* Takes a LinearForm and converts it to Mul + (Add/Sub). */ -const Expr* expandLinearForm(const LinearForm* v, IRMutator* mutator) { - const Expr* mul = nullptr; - const Expr* A = v->getA(); - const Expr* B = v->getB(); - const Expr* X = v->getX(); - // we only really care about 0 and 1, so double should be fine. - double Aval = immediateAs(A); - double Bval = immediateAs(B); - - // First handle A. - if (Aval == 0) { - if (Bval == 0) { - return getImmediateByType(X->dtype(), 0).node(); + for (auto* e : v) { + if (first) { + t = Dtype(t.scalar_type(), e->dtype().lanes()); + first = false; } - return B; - } else if (Aval == 1) { - mul = X; - } else if (Aval == -1) { - return new Sub(B, X); - } else if (Aval < 0) { - // Negate A. - ExprHandle zero = getImmediateByType(A->dtype(), 0); - Sub* A_Sub = new Sub(zero.node(), A); - - return new Sub(B, new Mul(X, evaluateOp(A_Sub))); - } else { - mul = new Mul(X, A); + t = promoteTypes(t, e->dtype()); } - - if (Bval == 0) { - return mul; - } - - return new Add(mul, B); + return t; } -/* Expand any remaining LinearTerms into their component pieces */ -class LinearFormExpander : public IRMutator { - public: - const Expr* mutate(const LinearForm* v) { - return expandLinearForm(v, this); +template +Dtype promoteTypesVec(std::vector& v) { + if (v.empty()) { + throw malformed_input("empty list of types"); } + + Dtype t = v[0]->dtype(); + for (auto* e : v) { + t = promoteTypes(t, e->dtype()); + } + return t; +} + +template +Dtype promoteTypesMap( + const Expr* s, + std::unordered_map& m) { + Dtype t = s->dtype(); + bool first = true; + for (auto& e : m) { + if (first) { + t = Dtype(t.scalar_type(), e.second->dtype().lanes()); + first = false; + } + t = promoteTypes(t, e.second->dtype()); + } + return t; +} + +template +Dtype promoteTypesVar(const ExprType* e) { + return e->dtype(); +} + +template +Dtype promoteTypesVar(const ExprType* e, Args... es) { + Dtype lhs = e->dtype(); + Dtype rhs = promoteTypesVar(es...); + if (e->isConstant()) { + lhs = Dtype(lhs.scalar_type(), rhs.lanes()); + } + + return promoteTypes(lhs, rhs); +} + +// Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast +// or Ramp). +bool isMultilanePrimitive(const Expr* e) { + return e->expr_type() == IRNodeType::kBroadcast || + e->expr_type() == IRNodeType::kRamp; +} +} // namespace + +// A Term represents a grouping of Exprs through multiplication. +// E.g. product(scalar, *variables). +class Term : public ExprNode { + public: + template + Term(HashProvider& hasher, const Expr* s, Args... ts) + : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) { + CHECK(s->isConstant()); + addComponent(ts...); + sort(); + } + + Term(HashProvider& hasher, const Expr* s, std::vector v) + : ExprNodeBase(promoteTypesVec(s, v)), + variables_(std::move(v)), + scalar_(s), + hasher_(hasher) { + sort(); + } + + // Convenience constructor from a map of hash -> var, used when merging Terms. + Term( + HashProvider& hasher, + const Expr* s, + std::unordered_map varmap) + : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) { + for (auto& p : varmap) { + addComponent(p.second); + } + sort(); + } + + const Expr* scalar() const { + return scalar_; + } + const std::vector& variables() const { + return variables_; + } + HashProvider& hasher() const { + return hasher_; + } + + // Produce a hash of just the variable components of this term, to determine + // if it can be combined with another term. + SimplifierHashType hashVars() const; + + private: + std::vector variables_; + const Expr* scalar_; + HashProvider& hasher_; + + void addComponent() {} + void addComponent(const Expr* e) { + variables_.push_back(e); + } + template + void addComponent(const Expr* e, Es... es) { + addComponent(e); + addComponent(es...); + } + + // Sort by hash to normalize order of components. + void sort(); }; -/* Simplify the IR by combining arithmetic expressions over a common term. - */ -class IRSimplifier : public IRMutator { +// Polynomial represents a grouping of Exprs by addition. +// E.g. sum(*variables, scalar). +// This would better be called Expression, but, naming conflict... +class Polynomial : public ExprNode { public: - const Expr* mutate(const Add* v) override { - const Expr* lhs = v->lhs(); - const Expr* rhs = v->rhs(); - const Expr* lhs_new = lhs->accept_mutator(this); - const Expr* rhs_new = rhs->accept_mutator(this); - - // Constant Folding. - if (lhs_new->isConstant() && rhs_new->isConstant()) { - const Expr* result = evaluateOp(v); - return result; - } - - const LinearForm* lhsLinear = dynamic_cast(lhs_new); - const LinearForm* rhsLinear = dynamic_cast(rhs_new); - - if (lhsLinear && rhsLinear) { - // Can add two LinearTerms if they reference the same Var. - if (lhsLinear->getX() == rhsLinear->getX()) { - Add* A_Add = new Add(lhsLinear->getA(), rhsLinear->getA()); - Add* B_Add = new Add(lhsLinear->getB(), rhsLinear->getB()); - - LinearForm* linear = new LinearForm( - lhsLinear->getX(), evaluateOp(A_Add), evaluateOp(B_Add)); - return linear; - } - - // otherwise cannot simplify further. - return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); - } - - // Can add a scalar into the B term of LinearTerm. - if (lhsLinear && rhs_new->isConstant()) { - Add* B_Add = new Add(lhsLinear->getB(), rhs_new); - LinearForm* linear = new LinearForm( - lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Add)); - return linear; - } - - if (rhsLinear && lhs_new->isConstant()) { - Add* B_Add = new Add(rhsLinear->getB(), lhs_new); - LinearForm* linear = new LinearForm( - rhsLinear->getX(), rhsLinear->getA(), evaluateOp(B_Add)); - return linear; - } - - // Can create a LinearTerm over any sub expression. - if (lhs_new->isConstant()) { - LinearForm* linear = new LinearForm(rhs_new); - linear->setB(evaluateOp(lhs_new)); - return linear; - } - - if (rhs_new->isConstant()) { - LinearForm* linear = new LinearForm(lhs_new); - linear->setB(evaluateOp(rhs_new)); - return linear; - } - - /// Broadcasts are a bit more involved. - if (const Broadcast* bc = dynamic_cast(lhs_new)) { - if (const Expr* ret = handleBroadcastAdd(bc, rhs_new)) { - return ret; - } - } - - if (const Broadcast* bc = dynamic_cast(rhs_new)) { - if (const Expr* ret = handleBroadcastAdd(bc, lhs_new)) { - return ret; - } - } - - // No change. - if (lhs == lhs_new && rhs == rhs_new) { - return v; - } - - // Cannot simplify. - return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); + template + Polynomial(HashProvider& hasher, const Expr* s, Args... ts) + : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) { + CHECK(s->isConstant()); + addTerm(ts...); + sort(); } - const Expr* mutate(const Sub* v) override { - const Expr* lhs = v->lhs(); - const Expr* rhs = v->rhs(); - const Expr* lhs_new = lhs->accept_mutator(this); - const Expr* rhs_new = rhs->accept_mutator(this); - - // Constant Folding. - if (lhs_new->isConstant() && rhs_new->isConstant()) { - const Expr* result = evaluateOp(v); - return result; - } - - const LinearForm* lhsLinear = dynamic_cast(lhs_new); - const LinearForm* rhsLinear = dynamic_cast(rhs_new); - - if (lhsLinear && rhsLinear) { - // Can sub two LinearTerms if they reference the same Var. - if (lhsLinear->getX() == rhsLinear->getX()) { - Sub* A_Sub = new Sub(lhsLinear->getA(), rhsLinear->getA()); - Sub* B_Sub = new Sub(lhsLinear->getB(), rhsLinear->getB()); - - LinearForm* linear = new LinearForm( - lhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub)); - return linear; - } - - // otherwise cannot simplify further. - return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); - } - - // Can just sub from B term if LHS is a LinearTerm. - if (lhsLinear && rhs_new->isConstant()) { - Sub* B_Sub = new Sub(lhsLinear->getB(), rhs_new); - LinearForm* linear = new LinearForm( - lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Sub)); - return linear; - } - - // Slightly more complicated if the RHS is LinearTerm. - if (rhsLinear && lhs_new->isConstant()) { - // The linear needs to be negated. - ExprHandle zero = getImmediateByType(rhsLinear->getA()->dtype(), 0); - Sub* A_Sub = new Sub(zero.node(), rhsLinear->getA()); - Sub* B_Sub = new Sub(rhsLinear->getB(), lhs_new); - LinearForm* linear = new LinearForm( - rhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub)); - return linear; - } - - // Can create a new LinearTerm, but since the B term is defined as Add we - // must negate it. - if (rhs_new->isConstant()) { - LinearForm* linear = new LinearForm(lhs_new); - - ExprHandle zero = getImmediateByType(linear->getA()->dtype(), 0); - Sub* B_Sub = new Sub(zero.node(), rhs_new); - linear->setB(evaluateOp(B_Sub)); - return linear; - } - - // Can create a new LinearTerm with the A term -1 to negate the Expr. - if (lhs_new->isConstant()) { - // Negate by using -1 as the first linear. - ExprHandle negOne = getImmediateByType(rhs_new->dtype(), -1); - LinearForm* linear = - new LinearForm(rhs_new, negOne.node(), evaluateOp(lhs_new)); - return linear; - } - - // Nothing to do. - if (lhs == lhs_new && rhs == rhs_new) { - return v; - } - - // Cannot simplify. - return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); + Polynomial(HashProvider& hasher, const Expr* s, std::vector v) + : ExprNodeBase(promoteTypesVec(s, v)), + variables_(std::move(v)), + scalar_(s), + hasher_(hasher) { + sort(); } - const Expr* mutate(const Mul* v) override { - const Expr* lhs = v->lhs(); - const Expr* rhs = v->rhs(); - const Expr* lhs_new = lhs->accept_mutator(this); - const Expr* rhs_new = rhs->accept_mutator(this); - - // Constant Folding. - if (lhs_new->isConstant() && rhs_new->isConstant()) { - return evaluateOp(v); - } - - const LinearForm* lhsLinear = dynamic_cast(lhs_new); - const LinearForm* rhsLinear = dynamic_cast(rhs_new); - - if (lhsLinear && rhsLinear) { - // Lets not get into higher order terms. - return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); - } - - // Easy to simplify into an existing LinearTerm by multiplying A and B. - if (lhsLinear && rhs_new->isConstant()) { - Mul* A_Mul = new Mul(lhsLinear->getA(), rhs_new); - Mul* B_Mul = new Mul(lhsLinear->getB(), rhs_new); - LinearForm* linear = new LinearForm( - lhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul)); - return linear; - } - - if (rhsLinear && lhs_new->isConstant()) { - Mul* A_Mul = new Mul(rhsLinear->getA(), lhs_new); - Mul* B_Mul = new Mul(rhsLinear->getB(), lhs_new); - LinearForm* linear = new LinearForm( - rhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul)); - return linear; - } - - // Easy to create a new LinearTerm by setting term A. - if (lhs_new->isConstant()) { - LinearForm* linear = new LinearForm(rhs_new); - linear->setA(evaluateOp(lhs_new)); - return linear; - } - - if (rhs_new->isConstant()) { - LinearForm* linear = new LinearForm(lhs_new); - linear->setA(evaluateOp(rhs_new)); - return linear; - } - - // Broadcasts have special logic. - if (const Broadcast* bc = dynamic_cast(lhs_new)) { - if (const Expr* ret = handleBroadcastMul(bc, rhs_new)) { - return ret; - } - } - - if (const Broadcast* bc = dynamic_cast(rhs_new)) { - if (const Expr* ret = handleBroadcastMul(bc, lhs_new)) { - return ret; - } - } - - // Cannot be simplified, just exit. - if (lhs == lhs_new && rhs == rhs_new) { - return v; - } - - return expandAndRecurse(v->expr_type(), lhs_new, rhs_new); + // Helper constructor for list of terms with no scalar component. + Polynomial(HashProvider& hasher, std::vector terms) + : ExprNodeBase(promoteTypesVec(terms)), + variables_(std::move(terms)), + scalar_(getImmediateByType(dtype(), 0)), + hasher_(hasher) { + sort(); } + // Convenience constructor for map of hash -> var, used when merging + // Polynomials. + Polynomial( + HashProvider& hasher, + const Expr* s, + std::unordered_map varmap) + : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) { + for (auto& p : varmap) { + addTerm(p.second); + } + sort(); + } + + const Expr* scalar() const { + return scalar_; + } + const std::vector& variables() const { + return variables_; + } + HashProvider& hasher() const { + return hasher_; + } + + SimplifierHashType hashVars() const; + + private: + std::vector variables_; + const Expr* scalar_; + HashProvider& hasher_; + + void addTerm(const Term* t) { + variables_.push_back(t); + } + template + void addTerm(const Term* t, Ts... ts) { + addTerm(t); + addTerm(ts...); + } + + // Sort by hash to normalize order of terms. + void sort(); +}; + +// Simplify the IR by combining arithmetic expressions over common terms. +class TORCH_API PolynomialTransformer : public IRMutator { + public: + // Inserts term into the provided map, in the case of a hash collision + // combines the term with the existing and updates the map. + void addOrUpdateTerm( + std::unordered_map& varmap, + const Term* term); + + // Add Polynomial expressions, combining Terms representing the same + // variables. + const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs); + + // Insert a new Term into the provided polynomial. If the new term has common + // variables to an existing term it is combined. + const Expr* insertTerm(const Polynomial* poly, const Term* term); + + // Merge and simplify addition. + const Expr* mutate(const Add* v) override; + + // Subtract one term from another, cancelling if necessary. + const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated); + + // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where + // possible. + const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs); + + // Merge and simplify subtraction. + const Expr* mutate(const Sub* v) override; + + // Multiply two terms together, usually creating a new term with the variable + // lists concatenated. + const Term* mulTerms(const Term* lhs, const Term* rhs); + + // Multiply a Polynomial by a Term. + const Expr* polyByTerm(const Polynomial* poly, const Term* term); + + // Merge and simplify multiplication. + const Expr* mutate(const Mul* v) override; + const Expr* mutate(const Div* v) override { // TODO div simplification will require a rational node. return mutateBinaryOp(v, this); @@ -396,37 +304,9 @@ class IRSimplifier : public IRMutator { return mutateBinaryOp(v, this, v->propagate_nans()); } - const Expr* mutate(const Intrinsics* v) override { - std::vector new_params; - bool changed = false; - bool allConstant = true; - for (const auto* p : v->params()) { - const Expr* new_child = p->accept_mutator(this); - new_params.push_back(new_child); + const Expr* mutate(const Intrinsics* v) override; - changed |= p != new_child; - allConstant &= new_child->isConstant(); - } - - const Expr* node = v; - if (changed) { - node = new Intrinsics(v->op_type(), new_params); - } - - if (!allConstant || !v->isPure()) { - return node; - } - - return evaluateOp(node); - } - - const Expr* mutate(const Cast* v) override { - if (v->src_value()->isConstant()) { - return evaluateOp(v); - } - - return v; - } + const Expr* mutate(const Cast* v) override; template static const Expr* mutateBinaryOp( @@ -452,12 +332,40 @@ class IRSimplifier : public IRMutator { return evaluateOp(node); } + static const Expr* simplify(const Expr* e); + static ExprHandle simplify(const ExprHandle& e); + static Stmt* simplify(Stmt* e); + + private: + HashProvider hasher_; +}; // namespace tensorexpr + +// Expands Terms and Polynomial expressions into primitive operations. +// Does some simple factorization and reordering. +class TORCH_API TermExpander : public IRMutator { + PolynomialTransformer* simplifier_; + + public: + TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {} + + // Expand Terms out to a series of Muls. + const Expr* mutate(const Term* v) override; + + // Trivially factorize terms by GCD of scalar components. + const Expr* factorizePolynomial(const Polynomial* poly); + + // Expand Polynomials out to a series of Adds. + const Expr* mutate(const Polynomial* v); +}; + +class TORCH_API IRSimplifier { + public: static const Expr* simplify(const Expr* e) { - IRSimplifier simplifier; + PolynomialTransformer simplifier; e = e->accept_mutator(&simplifier); // There may be terms left in the IR, expand them. - LinearFormExpander expander; + TermExpander expander(&simplifier); e = e->accept_mutator(&expander); return e; @@ -468,74 +376,15 @@ class IRSimplifier : public IRMutator { } static Stmt* simplify(Stmt* s) { - IRSimplifier simplifier; + PolynomialTransformer simplifier; s = s->accept_mutator(&simplifier); // There may be terms left in the IR, expand them. - LinearFormExpander expander; + TermExpander expander(&simplifier); s = s->accept_mutator(&expander); + return s; } - - private: - /* Expands lhs and rhs if they are LinearTerms, creating a new op to hold - * them. If either side expands to a constant term, attempt simplification of - * the new op. */ - const Expr* expandAndRecurse( - IRNodeType expr_type, - const Expr* lhs, - const Expr* rhs) { - if (const LinearForm* lhsLinear = dynamic_cast(lhs)) { - lhs = expandLinearForm(lhsLinear, this); - } - if (const LinearForm* rhsLinear = dynamic_cast(rhs)) { - rhs = expandLinearForm(rhsLinear, this); - } - const Expr* result = newBinaryOpOfType(expr_type, lhs, rhs, false); - - // lhs or rhs can become constant during expansion, if either is now - // constant we can keep merging into a linear term. Have another attempt to - // simplify the new op. - if (lhs->isConstant() || rhs->isConstant()) { - return result->accept_mutator(this); - } - - return result; - } - - /* Handles optimization cases for Broadcast() + Other */ - const Expr* handleBroadcastAdd(const Broadcast* bc, const Expr* other) { - if (bc->value()->isConstant() && immediateAs(bc->value()) == 0) { - return other; - } - - if (const Ramp* r = dynamic_cast(other)) { - // Add the broadcast to the start of the Ramp. - const Expr* ret = - new Ramp(new Add(bc->value(), r->base()), r->stride(), r->lanes()); - return ret->accept_mutator(this); - } - - return nullptr; - } - - /* Handles optimization cases for Broadcast() * Other */ - const Expr* handleBroadcastMul(const Broadcast* bc, const Expr* other) { - if (bc->value()->isConstant() && immediateAs(bc->value()) == 1) { - return other; - } - - if (const Ramp* r = dynamic_cast(other)) { - // Multiply both start and stride by the broadcast value. - const Expr* ret = new Ramp( - new Mul(bc->value(), r->base()), - new Mul(bc->value(), r->stride()), - r->lanes()); - return ret->accept_mutator(this); - } - - return nullptr; - } }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index 43a94706212d..dd989a33ab89 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -1,6 +1,7 @@ #include #include +#include #include namespace torch { @@ -176,10 +177,18 @@ void IRVisitor::visit(const Cond* v) { } } -void IRVisitor::visit(const LinearForm* v) { - v->getA()->accept(this); - v->getX()->accept(this); - v->getB()->accept(this); +void IRVisitor::visit(const Term* v) { + v->scalar()->accept(this); + for (auto* t : v->variables()) { + t->accept(this); + } +} + +void IRVisitor::visit(const Polynomial* v) { + v->scalar()->accept(this); + for (auto* t : v->variables()) { + t->accept(this); + } } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index b83b405757ca..35797dc9c786 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -42,7 +42,8 @@ class FunctionCall; class Allocate; class Free; class Cond; -class LinearForm; +class Term; +class Polynomial; class TORCH_API IRVisitor { public: @@ -90,7 +91,8 @@ class TORCH_API IRVisitor { virtual void visit(const Allocate* v); virtual void visit(const Free* v); virtual void visit(const Cond* v); - virtual void visit(const LinearForm* v); + virtual void visit(const Term* v); + virtual void visit(const Polynomial* v); }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 1ff09be225b7..05720f5d88d5 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1114,6 +1114,7 @@ void TensorExprKernel::lowerToBackend(BackendType backendType) { l.ApplyInlines(); Stmt* stmt = l.root_stmt(); + // Arithmetic Simplification. stmt = IRSimplifier::simplify(stmt); // Set up formal params (inputs, then outputs) for kernel. diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index 1738f7534bde..2651240061c3 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include diff --git a/torch/csrc/jit/tensorexpr/types.h b/torch/csrc/jit/tensorexpr/types.h index cb6b9f712262..6a916ab3ef24 100644 --- a/torch/csrc/jit/tensorexpr/types.h +++ b/torch/csrc/jit/tensorexpr/types.h @@ -50,7 +50,7 @@ class TORCH_API Dtype { Dtype(Dtype type, int lanes) : scalar_type_(type.scalar_type_), lanes_(lanes) { if (type.lanes() != 1) { - throw malformed_input(); + throw malformed_input("dtype lanes dont match"); } } int lanes() const { @@ -108,10 +108,15 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) { return static_cast(c10::promoteTypes( static_cast(a), static_cast(b))); } -inline ScalarType promoteTypes(Dtype a, Dtype b) { - return static_cast(c10::promoteTypes( - static_cast(a.scalar_type()), - static_cast(b.scalar_type()))); +inline Dtype promoteTypes(Dtype a, Dtype b) { + if (a.lanes() != b.lanes()) { + throw malformed_input("promoting types with different lanes"); + } + return Dtype( + static_cast(c10::promoteTypes( + static_cast(a.scalar_type()), + static_cast(b.scalar_type()))), + a.lanes()); } inline Dtype BinaryOpDtype( @@ -127,22 +132,21 @@ inline Dtype BinaryOpDtype( } if (op1_dtype.lanes() != op2_dtype.lanes()) { - throw malformed_input(); + throw malformed_input("lanes dont match"); } int lanes = op1_dtype.lanes(); - ScalarType resultType = promoteTypes(op1_dtype, op2_dtype); - if (resultType == ScalarType::Undefined) { - throw malformed_input(); + Dtype resultType = promoteTypes(op1_dtype, op2_dtype); + if (resultType.scalar_type() == ScalarType::Undefined) { + throw malformed_input("scalar type doesn't match"); } - if (lanes == 1) { // Use the fixed scalar Dtypes. - return ToDtype(resultType); + return ToDtype(resultType.scalar_type()); } - return Dtype(resultType, lanes); + return resultType; } } // namespace tensorexpr