#include #include #include #include #include #include #include #include namespace torch { namespace jit { using namespace torch::jit::tensorexpr; using SimpleIRExprEval = ExprEval; TEST(Simplify, ConstantFoldSimple) { ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle f = (a + b); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 5); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 5.f); } TEST(Simplify, ConstantFoldTwoLayer) { ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); ExprHandle d(5.0f); ExprHandle f = (a + b) - (c + d); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), -4); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), -4.f); } TEST(Simplify, ConstantFoldShifts) { ExprHandle a(7); ExprHandle b(2); ExprHandle c(3); ExprHandle f = ((a << b) << b) >> c; ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 14); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 7 << (4 - 3)); } TEST(Simplify, ConstantFoldBitwise) { ExprHandle a(59); ExprHandle b(22); ExprHandle c(101); ExprHandle f = (a ^ b) & c; ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 37); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), (59 ^ 22) & 101); } TEST(Simplify, ConstantFoldMultiOp) { ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); ExprHandle d(5.0f); ExprHandle e(6.0f); ExprHandle f(7.0f); ExprHandle fn = ((a / e) - (c + d)) * (f / b); ExprHandle newF = IRSimplifier::simplify(fn); ASSERT_NE(newF.AsNode(), nullptr); SimpleIRExprEval eval(newF); SimpleIRExprEval ref(fn); ASSERT_EQ(eval.value(), ref.value()); } TEST(Simplify, ConstantFoldMinMax) { ExprHandle a(12.0f); ExprHandle b(15.0f); ExprHandle c(17.0f); // x = max(12, min(15, 17)). ExprHandle minHandle = Min::make(b, c, true); ExprHandle fn = Max::make(a, minHandle, false); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); ExprHandle newF = IRSimplifier::simplify(fn); ASSERT_NE(newF.AsNode(), nullptr); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 15.f); } TEST(Simplify, ConstantFoldIntrinsics) { ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); ExprHandle powHandle = Intrinsics::make(kPow, a, b); ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); ExprHandle fn = Intrinsics::make(kAbs, rndHandle); ExprHandle newF = IRSimplifier::simplify(fn); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 1); SimpleIRExprEval eval(newF); SimpleIRExprEval ref(fn); ASSERT_EQ(eval.value(), ref.value()); } TEST(Simplify, ConstantFoldCastToBool) { ExprHandle f = Cast::make(kBool, IntImm::make(0)); ExprHandle newF = IRSimplifier::simplify(f); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), false); } TEST(Simplify, ConstantFoldWithVar) { { VarHandle x("x", kInt); ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); ExprHandle newF = IRSimplifier::simplify(body); MulPtr root = newF.AsNode(); ASSERT_NE(root, nullptr); ASSERT_NE(to(root->lhs()), nullptr); SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3)); ASSERT_EQ(eval.value(), 3 * (2 + 4)); } { VarHandle x("x", kFloat); ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); ExprHandle newF = IRSimplifier::simplify(body); MulPtr root = newF.AsNode(); ASSERT_NE(root, nullptr); ASSERT_NE(to(root->rhs()), nullptr); SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3.f)); ASSERT_EQ(eval.value(), 3 * (2 + 4)); } } TEST(Simplify, ConditionalSelectFoldSimple) { ExprHandle a(3.0f); ExprHandle b(4.0f); ExprHandle c(3.0f); { ExprHandle f = (a > b); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 0); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 0); } { ExprHandle f = (a < b); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 1); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 1); } { ExprHandle f = (a == c); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 1); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 1); } { ExprHandle f = (a != c); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 0); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 0); } } TEST(Simplify, ConditionalSelectFoldTwoLayer) { ExprHandle a(3.0f); ExprHandle b(2.0f); ExprHandle c(2.0f); ExprHandle d(1.0f); { ExprHandle f = (a + b < c + d); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 0); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 0); } { ExprHandle f = (a + b > c + d); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 1); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 1); } { ExprHandle f = (a + d == b + c); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 1); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 1); } { ExprHandle f = (a + d != b + c); ExprHandle newF = IRSimplifier::simplify(f); ASSERT_NE(newF.AsNode(), nullptr); ASSERT_EQ(newF.AsNode()->value(), 0); SimpleIRExprEval eval(newF); ASSERT_EQ(eval.value(), 0); } } TEST(Simplify, ConditionalSelectFoldWithVar) { VarHandle x("x", kFloat); ExprHandle f = x < 4.f; ExprHandle newF = IRSimplifier::simplify(f); IntImmPtr folded = newF.AsNode(); ASSERT_EQ(folded, nullptr); { SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3.f)); ASSERT_EQ(eval.value(), 1); } { SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(5.f)); ASSERT_EQ(eval.value(), 0); } } TEST(Simplify, UnFoldableExpr) { VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); ExprHandle newF = IRSimplifier::simplify(body); AddPtr root = newF.AsNode(); ASSERT_NE(root, nullptr); ASSERT_EQ(to(root->lhs()), nullptr); ASSERT_EQ(to(root->rhs()), nullptr); SimpleIRExprEval eval(newF); eval.bindVar(x, ExprHandle(3.f)); eval.bindVar(y, ExprHandle(2.f)); ASSERT_EQ(eval.value(), 9 + 10); } TEST(Simplify, HashSimple) { VarHandle x("x", kFloat); ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle f = a + b * x; HashProvider hasher; auto hash_x = hasher.hash(x.node()); auto hash_a = hasher.hash(a.node()); auto hash_f = hasher.hash(f.node()); ASSERT_NE(hash_x, (size_t)0); ASSERT_NE(hash_a, (size_t)0); ASSERT_NE(hash_f, (size_t)0); ASSERT_NE(hash_x, hash_a); ASSERT_NE(hash_x, hash_f); ASSERT_NE(hash_a, hash_f); } TEST(Simplify, HashEquivalence) { VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle f = (x * y) + (x * y); AddPtr root = f.AsNode(); ASSERT_NE(root, nullptr); HashProvider hasher; auto hash_f = hasher.hash(f.node()); auto hash_l = hasher.hash(root->lhs()); auto hash_r = hasher.hash(root->rhs()); // Root not equal to either branch. ASSERT_NE(hash_f, hash_l); ASSERT_NE(hash_f, hash_r); // but branches are equal. ASSERT_EQ(hash_l, hash_r); // Still equivalent if separate. ExprHandle a(2); ExprHandle f2 = x + a / y; ExprHandle b(2); ExprHandle f3 = x + b / y; ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); // Not equivalent if different vars (even with same name). VarHandle z("x", kFloat); ExprHandle f4 = z + b / y; ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); // Intrinsics sanity check. ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); ASSERT_NE(hasher.hash(f5.node()), (size_t)0); } TEST(Simplify, HashEquivalenceRand) { ExprHandle f = Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); AddPtr root = f.AsNode(); ASSERT_NE(root, nullptr); HashProvider hasher; auto hash_f = hasher.hash(f.node()); auto hash_l = hasher.hash(root->lhs()); auto hash_r = hasher.hash(root->rhs()); // Root not equal to either branch. ASSERT_NE(hash_f, hash_l); ASSERT_NE(hash_f, hash_r); // and branches are NOT equal. ASSERT_NE(hash_l, hash_r); } TEST(Simplify, HashEquivalenceAfterFolding) { VarHandle x("x", kFloat); ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(5.0f); ExprHandle f1 = ((a + b) * x); ExprHandle f2 = (c * x); HashProvider hasher; auto hash_l = hasher.hash(f1.node()); auto hash_r = hasher.hash(f2.node()); // Root not equal to either branch, and branches not equal. ASSERT_NE(hash_l, hash_r); ExprHandle ff1 = IRSimplifier::simplify(f1); ExprHandle ff2 = IRSimplifier::simplify(f2); auto hash_l_n = hasher.hash(ff1.node()); auto hash_r_n = hasher.hash(ff2.node()); // but branches are now equal. ASSERT_EQ(hash_l_n, hash_r_n); } TEST(Simplify, HashDifferenceTypes) { HashProvider hasher; std::vector immediates; immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); // NOLINTNEXTLINE(modernize-use-bool-literals) immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); immediates.push_back(alloc(1)); // Immediates of different types are not equal. for (unsigned int i = 0; i < immediates.size(); ++i) { for (unsigned int j = i + 1; j < immediates.size(); ++j) { ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); } } // But coerced immediates are if they are the same type: ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); ExprHandle ff1 = IRSimplifier::simplify(f1); ExprHandle ff2 = IRSimplifier::simplify(f2); ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); } TEST(Simplify, HashLargeExpression) { constexpr int N = 1024; BufHandle a("A", {N}, kInt); BufHandle b("B", {N}, kInt); BufHandle c("C", {N}, kInt); VarHandle i("i", kInt); auto memcpy_stmt = For::make( i, 0, N, Store::make( c, {i}, CompareSelect::make( Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kEQ))); BufHandle d("D", {1}, kInt); BufHandle e("E", {1}, kInt); auto store_ramp_stmt = Store::make( e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); auto if_stmt = Cond::make( CompareSelect::make( Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), memcpy_stmt, store_ramp_stmt); HashProvider hasher; auto hash_r = hasher.hash(if_stmt); // We should not have to do any more work. ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); auto hash_t = hasher.hash(memcpy_stmt); ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); auto hash_f = hasher.hash(store_ramp_stmt); // Root not equal to either branch, and branches not equal. ASSERT_NE(hash_r, hash_t); ASSERT_NE(hash_r, hash_f); ASSERT_NE(hash_t, hash_f); } TEST(Simplify, HashForLoopOptions) { constexpr int N = 1024; BufHandle a("A", {N}, kInt); BufHandle b("B", {N}, kInt); BufHandle c("C", {N}, kInt); VarHandle i("i", kInt); auto for_stmt = For::make( i, 0, N, Store::make( c, {i}, CompareSelect::make( Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kEQ))); HashProvider hasher; auto hash_before = hasher.hash(for_stmt); hasher.clearCache(); for_stmt->set_gpu_block_index(LoopOptions::IDX_X); auto hash_block_idx = hasher.hash(for_stmt); hasher.clearCache(); ASSERT_NE(hash_before, hash_block_idx); for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); auto hash_reset = hasher.hash(for_stmt); hasher.clearCache(); ASSERT_EQ(hash_before, hash_reset); for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); auto hash_thread_idx = hasher.hash(for_stmt); ASSERT_NE(hash_before, hash_thread_idx); ASSERT_NE(hash_block_idx, hash_thread_idx); } /// (2 + x) + 4 => x + 6 TEST(Simplify, SimplifyAdd) { VarHandle x("x", kInt); VarHandle y("y", kInt); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) VarHandle m("m", kInt); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) VarHandle n("n", kInt); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) VarHandle n_1("n_1", kInt); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); ExprHandle simplified = IRSimplifier::simplify(body); AddPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); VarPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->name_hint(), "x"); IntImmPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->value(), 6.f); } /// (2 - x) - 4 => -2 - x TEST(Simplify, SimplifySub) { VarHandle x("x", kInt); ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); ExprHandle simplified = IRSimplifier::simplify(body); SubPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); IntImmPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->value(), -2.f); VarPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->name_hint(), "x"); } /// 2 * (1 - x) - 4 => 2 * (-3 - x) TEST(Simplify, SimplifyMultiLayer) { VarHandle x("x", kInt); ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); IS_IMM_WITH_VAL(Int, sub->lhs(), -3); IS_VAR_WITH_NAME(sub->rhs(), "x"); } /// 2 * (3 * x) - (x * 4) => 2 * x TEST(Simplify, SimplifyMultiTerm) { VarHandle x("x", kInt); ExprHandle body = (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); MulPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); IntImmPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->value(), 2); VarPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->name_hint(), "x"); } /// 2 * (3 * (long)x) - (x * 4) => 2 * x TEST(Simplify, SimplifyCasts) { VarHandle x("x", kLong); ExprHandle body = (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); MulPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); LongImmPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); ASSERT_EQ(lhs->value(), 2); VarPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); ASSERT_EQ(rhs->name_hint(), "x"); } /// (x + 0) * 1 => x TEST(Simplify, SimplifyEliminatesNoOps) { VarHandle x("x", kInt); ExprHandle body = (x + ExprHandle(0)) * 1; ExprHandle simplified = IRSimplifier::simplify(body); VarPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); ASSERT_EQ(root->name_hint(), "x"); } /// Cannot simplify this. TEST(Simplify, SimplifyMultiVar) { VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = x * 24 + y * 34; ExprHandle simplified = IRSimplifier::simplify(body); AddPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); MulPtr lhs = to(root->lhs()); ASSERT_NE(lhs, nullptr); VarPtr varX = to(lhs->rhs()); ASSERT_NE(varX, nullptr); ASSERT_EQ(varX->name_hint(), "x"); MulPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); VarPtr varY = to(rhs->rhs()); ASSERT_NE(varY, nullptr); ASSERT_EQ(varY->name_hint(), "y"); } // x + 2 + y => x + y + 2 TEST(Simplify, DISABLED_SimplifyReorderings) { VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = x + 2 + y; ExprHandle simplified = IRSimplifier::simplify(body); AddPtr root = simplified.AsNode(); ASSERT_NE(root, nullptr); IS_NODE_WITH_NAME(Add, root->lhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "x"); IS_VAR_WITH_NAME(rhs->rhs(), "y"); IS_IMM_WITH_VAL(Int, root->rhs(), 2); } /// y + x * 0 => y TEST(Simplify, SimplifyEliminatesVar) { VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = y + x * ExprHandle(0); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "y"); } TEST(Simplify, SimplifyAdds) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // (x + y) + (x + y) => 2 * (x + y) ExprHandle body = (x + y) + (x + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), root); IS_IMM_WITH_VAL(Int, root->lhs(), 2); IS_NODE_WITH_NAME(Add, root->rhs(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_VAR_WITH_NAME(add->rhs(), "y"); } { // (x * y) + (x * y) => 2 * (x * y) ExprHandle body = (x * y) + (x * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), root); IS_IMM_WITH_VAL(Int, root->lhs(), 2); IS_NODE_WITH_NAME(Mul, root->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // (x - y) + (x - y) => 2 * (x - y) ExprHandle body = (x - y) + (x - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "x"); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // (x + x + x + x) => 4 * x ExprHandle body = (x + x + x + x); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), root); IS_IMM_WITH_VAL(Int, root->lhs(), 4); IS_VAR_WITH_NAME(root->rhs(), "x"); } { // (x + 0) => x. ExprHandle body = x + 0; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // (x + 0.f) => float(x). ExprHandle body = x + 0.f; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cast, simplified.node(), cast); ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); IS_VAR_WITH_NAME(cast->src_value(), "x"); } } TEST(Simplify, SimplifyMuls) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // (x + y) * (x + y) => (x + y) * (x + y) // We don't attempt to simplify multiplication of polynomials since the // result is only very rarely more efficient. ExprHandle body = (x + y) * (x + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); IS_VAR_WITH_NAME(lhs->lhs(), "x"); IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "x"); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // x * y * x * y => x * x * y * y // These get reordered only. ExprHandle body = x * y * x * y; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); IS_VAR_WITH_NAME(mul1->rhs(), "y"); IS_VAR_WITH_NAME(mul2->rhs(), "y"); IS_VAR_WITH_NAME(mul3->lhs(), "x"); IS_VAR_WITH_NAME(mul3->rhs(), "x"); } { // 1 * (x * 1) => x // Ones cancel cleanly. ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // 1.f * (x * 1.f) => x // Even float ones cancel cleanly, but carry their type. ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cast, simplified.node(), cast); ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); IS_VAR_WITH_NAME(cast->src_value(), "x"); } { // 1 * (x * 1.f) => x // One float is enough to cast the expr. ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cast, simplified.node(), cast); ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); IS_VAR_WITH_NAME(cast->src_value(), "x"); } { // 1 * (x * 0) => 0 // Zeroes are eliminated. ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // 1 * (x * 0) => 0 // But not for Float since nan * 0 = nan. ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); IS_VAR_WITH_NAME(cast->src_value(), "x"); IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); } { // (x - y) * (x - y) => (x - y) * (x - y) // As with Add we don't attempt simplification of this. ExprHandle body = (x - y) * (x - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); IS_VAR_WITH_NAME(lhs->lhs(), "x"); IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "x"); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // (x + y) * (x - y) => (x + y) * (x - y) // Don't simplify with different ops on each side. ExprHandle body = (x + y) * (x - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); IS_VAR_WITH_NAME(lhs->lhs(), "x"); IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "x"); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // Multiply a polynomial by a term. // - term with no scalar, poly with non-identity scalar. // x * (y + 1) => x + x * y ExprHandle body = x * (y + ExprHandle(1)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // Multiply a polynomial by a term. // - term with identity scalar, poly with non-identity scalar. // (x * 1) * (y + 1) => x + x * y ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // Multiply a polynomial by a term. // - term with non-identity scalar, poly with non-identity scalar. // (x * 2) * (y + 1) => 2 * (x + x * y) ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Add, mul->rhs(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); IS_VAR_WITH_NAME(mul2->lhs(), "x"); IS_VAR_WITH_NAME(mul2->rhs(), "y"); } { // Multiply a polynomial by a term. // - term with non-identity scalar, poly with identity scalar. // (x * 2) * (y + 0) => 2 * (x * y) ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); IS_VAR_WITH_NAME(mul2->lhs(), "x"); IS_VAR_WITH_NAME(mul2->rhs(), "y"); } { // Multiply a polynomial by a term. // - term with identity scalar, poly with identity scalar. // (x * 1) * (y + 0) => x * y ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // Multiply a polynomial by a term. // - term with no scalar, poly with identity scalar. // x * (y + 0) => x * y ExprHandle body = x * (y + ExprHandle(0)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } } // Sub an expr from itself will result in zero. TEST(Simplify, SimplifySubs) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // (x + y) - (x + y) => 0 ExprHandle body = (x + y) - (x + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // (x * y) - (x * y) => 0 ExprHandle body = (x * y) - (x * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // (x - y) - (x - y) => 0 ExprHandle body = (x - y) - (x - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // (x + y) - 2 * (x + y) => -1 * x - y ExprHandle body = (x + y) - ExprHandle(2) * (x + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), -1); IS_VAR_WITH_NAME(mul->rhs(), "x"); IS_VAR_WITH_NAME(sub->rhs(), "y"); } { // (x + y) - y => x ExprHandle body = (x + y) - y; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // (x - 0) => x. ExprHandle body = x - 0; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // (x - 0.f) => x. // Simple enough to cancel in float. ExprHandle body = x - ExprHandle(0.f); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cast, simplified.node(), cast); ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); IS_VAR_WITH_NAME(cast->src_value(), "x"); } { // (x - (float)(y - y)) => x. ExprHandle body = x - Cast::make(kFloat, y - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cast, simplified.node(), cast); ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); IS_VAR_WITH_NAME(cast->src_value(), "x"); } { // (x - y) - y => x - 2 * y ExprHandle body = (x - y) - y; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_VAR_WITH_NAME(sub->lhs(), "x"); IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // 2 * x - x => x ExprHandle body = (ExprHandle(2) * x) - x; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // x - 2 * x = -1 * x // We don't have a unary negate, but this could be 0 -x I guess? ExprHandle body = x - (ExprHandle(2) * x); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), -1); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { // (x + y + 5) * (x - x) => 0 // Cancelling out one side of Mul cancels both. ExprHandle body = (x + y + 5) * (x - x); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // Cancel out opaque modulus. ExprHandle body = (x % y + 2) - (x % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 2); } { // Cancel out opaque modulus with a bit more going on. ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 2); } { // Sub where result is negative. ExprHandle body = x - (x + 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), -1); } { // Sub where result is positive due to negative scalar on RHS. ExprHandle body = x - (x - 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 1); } { // Term - Polynomial sub where RHS must be negated. ExprHandle body = (x * 2) - (x * 2 + 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), -1); } { // Term - Polynomial sub where the result is a Term. ExprHandle body = (y * x * 2) - (x * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // Term - Polynomial sub where the result is a Polynomial. ExprHandle body = (x * 2) - (x + 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_VAR_WITH_NAME(sub->lhs(), "x"); IS_IMM_WITH_VAL(Int, sub->rhs(), 1); } } TEST(Simplify, SimplifyDiv) { VarHandle x("x", kInt); { ExprHandle body = ExprHandle(0) / x; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { ExprHandle body = x / 1; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } } TEST(Simplify, SimplifyDivWithLoopContext0) { // Stmt to simplify: // for (int i = 0; i < 100; i++) { // A[i] = i / 100; //} VarHandle i("i", kInt); BufHandle a_buf("A", {100}, kInt); auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NEXT: A[i] = 0; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext1) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // A[i] = (i + 24) / 6; //} VarHandle i("i", kInt); BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NEXT: A[i] = 4; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext2) { // Stmt to simplify: // for (const auto i : c10::irange(5)) { // A[i] = (i + 25) / 6; //} VarHandle i("i", kInt); BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NEXT: A[i] = 4; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext3) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // A[i] = (i + 24) / (-6); //} VarHandle i("i", kInt); BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NOT: A[i] = -4; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext4) { // Stmt to simplify: // for (const auto i : c10::irange(5)) { // A[i] = (i - 5) / 6; //} VarHandle i("i", kInt); BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NOT: A[i] = 0; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext5) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // for (const auto j : c10::irange(10)) { // A[i, j] = (i + 6*j) / 6; // } //} VarHandle i("i", kInt); VarHandle j("j", kInt); BufHandle a_buf("A", {6, 10}, kInt); auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); auto for_i = For::make(i, 0, 6, for_j); const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK: for (int j # CHECK-NEXT: A[i, j] = j; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext6) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // for (int j = -1; j < 9; j++) { // A[i, j+1] = (i + 6*j) / 6; // } //} VarHandle i("i", kInt); VarHandle j("j", kInt); BufHandle a_buf("A", {6, 10}, kInt); auto for_j = For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); auto for_i = For::make(i, 0, 6, for_j); const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK: for (int j # CHECK-NOT: A[i, j] = j; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyDivWithLoopContext7) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // for (const auto j : c10::irange(10)) { // A[i, j] = (i + 6*j) / (-6); // } //} VarHandle i("i", kInt); VarHandle j("j", kInt); BufHandle a_buf("A", {6, 10}, kInt); auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); auto for_i = For::make(i, 0, 6, for_j); const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK: for (int j # CHECK-NOT: A[i, j] = -j; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext0) { // Stmt to simplify: // for (const auto i : c10::irange(100)) { // A[i] = i % 100; //} VarHandle i("i", kInt); BufHandle a_buf("A", {100}, kInt); auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NEXT: A[i] = i; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext1) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // A[i] = (i + 24) % 6; //} VarHandle i("i", kInt); BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NEXT: A[i] = i; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext2) { // Stmt to simplify: // for (const auto i : c10::irange(5)) { // A[i] = (i + 25) % 6; //} VarHandle i("i", kInt); BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NEXT: A[i] = i + 1; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext3) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // A[i] = (i + 24) % (-6); //} VarHandle i("i", kInt); BufHandle a_buf("A", {6}, kInt); auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NOT: A[i] = i; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext4) { // Stmt to simplify: // for (const auto i : c10::irange(5)) { // A[i] = (i - 5) % 6; //} VarHandle i("i", kInt); BufHandle a_buf("A", {5}, kInt); auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); const StmtPtr simplified = IRSimplifier::simplify(for_stmt); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK-NOT: A[i] = i - 5; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext5) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // for (const auto j : c10::irange(10)) { // A[i, j] = (i + 6*j) % 6; // } //} VarHandle i("i", kInt); VarHandle j("j", kInt); BufHandle a_buf("A", {6, 10}, kInt); auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); auto for_i = For::make(i, 0, 6, for_j); const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK: for (int j # CHECK-NEXT: A[i, j] = i; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext6) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // for (int j = -1; j < 9; j++) { // A[i, j+1] = (i + 6*j) % 6; // } //} VarHandle i("i", kInt); VarHandle j("j", kInt); BufHandle a_buf("A", {6, 10}, kInt); auto for_j = For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); auto for_i = For::make(i, 0, 6, for_j); const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK: for (int j # CHECK-NOT: A[i, j] = i; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyModWithLoopContext7) { // Stmt to simplify: // for (const auto i : c10::irange(6)) { // for (const auto j : c10::irange(10)) { // A[i, j] = (i + 6*j) % (-6); // } //} VarHandle i("i", kInt); VarHandle j("j", kInt); BufHandle a_buf("A", {6, 10}, kInt); auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); auto for_i = For::make(i, 0, 6, for_j); const StmtPtr simplified = IRSimplifier::simplify(for_i); std::ostringstream oss; oss << *(simplified); const std::string& verification_pattern = R"IR( # CHECK: for (int i # CHECK: for (int j # CHECK-NOT: A[i, j] = i; )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } TEST(Simplify, SimplifyMod) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); { // Constant folding works. ExprHandle body = ExprHandle(10) % 8; ExprHandle simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) IS_IMM_WITH_VAL(Int, simplified.node(), 2); } { // x % x => 0 ExprHandle body = x % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // 0 % x => 0 ExprHandle body = ExprHandle(0) % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // x % 1 => 0 ExprHandle body = x % 1; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // Doesn't change unknown mods. // x % y => x % y ExprHandle body = x % y; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_VAR_WITH_NAME(mod->lhs(), "x"); IS_VAR_WITH_NAME(mod->rhs(), "y"); } { // don't touch if RHS is unknown. // 4 % x => 4 % x ExprHandle body = ExprHandle(4) % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_IMM_WITH_VAL(Int, mod->lhs(), 4); IS_VAR_WITH_NAME(mod->rhs(), "x"); } { // don't touch if LHS is unknown. // x % 4 => x % 4 ExprHandle body = x % 4; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_VAR_WITH_NAME(mod->lhs(), "x"); IS_IMM_WITH_VAL(Int, mod->rhs(), 4); } { // if LHS is a multiple of RHS, mod is zero. // 2 * x % x => 0 ExprHandle body = (x * 2) % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // true even if the multiple is not constant. // x * y % x => 0 ExprHandle body = (x * y) % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // true with multiple unknown values in LHS. // x * y * z % x => 0 ExprHandle body = (x * y * z) % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // true if the denom is compound. // x * y * z % y * z => 0 ExprHandle body = (x * y * z) % (y * z); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // Sanity check true with scalars that are multiples. // 12 * x % 4 => 0 ExprHandle body = (x * 12) % 4; ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } { // Sanity check not true if the smaller scalar is on LHS. // 4 * x % 12 => 4 * x % 12 ExprHandle body = (x * 4) % 12; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 4); IS_VAR_WITH_NAME(mul->rhs(), "x"); IS_IMM_WITH_VAL(Int, mod->rhs(), 12); } { // Both scalar and symbolic in multiple. // (6 * x * y) % (3 * x * y) => 0 ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 0); } } // Test that mixing ops together simplifies as expected. TEST(Simplify, SimplifyMultiOp) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // (x * y) + (x - y) => (x + x * y) - y ExprHandle body = (x * y) + (x - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Add, sub->lhs(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); IS_VAR_WITH_NAME(sub->rhs(), "y"); } { // (x + y) - x * y => (x + y) - x * y ExprHandle body = (x + y) - x * y; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Add, sub->lhs(), add); IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_VAR_WITH_NAME(add->rhs(), "y"); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // (x - y) - (x + y) => -2 * y ExprHandle body = (x - y) - (x + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), -2); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // (x - 0) + (x * 1) - (x + 0) => x ExprHandle body = (x - 0) + (x * 1) - (x + 0); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) // Even in Float simple terms cancel out, but the variable ones cannot. ExprHandle body = (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Add, sub->lhs(), add); IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); IS_VAR_WITH_NAME(cast1->src_value(), "x"); IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); IS_VAR_WITH_NAME(cast2->src_value(), "x"); IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); IS_VAR_WITH_NAME(cast3->src_value(), "x"); } } // Test that chaining many ops together works as expected. TEST(Simplify, SimplifyManyOps) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x ExprHandle body = x + y + x + x + y + y + x + y + x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); IS_VAR_WITH_NAME(rhs->rhs(), "x"); } { // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y ExprHandle body = x - y + x + x - y - y + x - y + x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); IS_VAR_WITH_NAME(lhs->rhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // x + y + x - x - y - y + x + y + x = 3 * x ExprHandle body = x + y + x - x - y - y + x + y + x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 3); IS_VAR_WITH_NAME(mul->rhs(), "x"); } } TEST(Simplify, SimplifyFactorization) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // (2 * x) + (2 * y) => 2 * (x + y) ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Add, mul->rhs(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_VAR_WITH_NAME(add->rhs(), "y"); } { // Factorization when scalars have common divider. // (2 * x) + (4 * y) => 2 * (2 * y + x) ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Add, mul->rhs(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); IS_VAR_WITH_NAME(mul2->rhs(), "y"); } { // Factorization attempt without a common divider. // (2 * x) + (5 * y) => (5 * y) + (2 * x) ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); IS_VAR_WITH_NAME(lhs->rhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // Factorization after merging. // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + (ExprHandle(8) * x + ExprHandle(6) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 10); IS_NODE_WITH_NAME(Add, mul->rhs(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_VAR_WITH_NAME(add->rhs(), "y"); } { // Factorization with common divider but different signs. // (2 * x) + (-4 * y) => 2 * (x - 2 * y) ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); IS_VAR_WITH_NAME(sub->lhs(), "x"); IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); IS_VAR_WITH_NAME(mul2->rhs(), "y"); } { // Factorization with all negative numbers. // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); IS_VAR_WITH_NAME(mul2->rhs(), "x"); IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); IS_VAR_WITH_NAME(mul3->rhs(), "y"); } { // The following test ensures that there in no infinite recursion during // factorization when negative numbers are involved. VarHandle a("a", kInt); VarHandle b("b", kInt); VarHandle c("c", kInt); VarHandle d("d", kInt); VarHandle e("e", kInt); VarHandle f("f", kInt); VarHandle g("g", kInt); VarHandle h("h", kInt); ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + f * 32 + g * (-1024) + h * (-32); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR( simplified, "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h"); } } // (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) TEST(Simplify, SimplifyFactorizeUneven) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); ExprHandle body = (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), root); IS_IMM_WITH_VAL(Int, root->lhs(), 2); IS_NODE_WITH_NAME(Add, root->rhs(), add1); IS_NODE_WITH_NAME(Add, add1->lhs(), add2); IS_VAR_WITH_NAME(add2->lhs(), "y"); IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul); IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); IS_VAR_WITH_NAME(xmul->rhs(), "x"); IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); IS_VAR_WITH_NAME(zmul->rhs(), "z"); } // (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) // This is kind of a placeholder test for variable factorization. TEST(Simplify, SimplifyDeeperTerms) { VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); } // Tests the difference between two less trivial expressions. // (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 TEST(Simplify, SimplifyDeeperDifference) { VarHandle n("n", kInt); VarHandle n_1("n_1", kInt); VarHandle m("m", kInt); ExprHandle body = (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 1); } // Test constant folding into the difference between expressions. // 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 TEST(Simplify, SimplifyFoldComplexDifference) { VarHandle n("n", kInt); VarHandle n_1("n_1", kInt); VarHandle m("m", kInt); ExprHandle body = (IntImm::make(2) + (Cast::make( kChar, (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n)))); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 3); } TEST(Simplify, SimplifyIfComponents) { VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = IfThenElse::make( ((ExprHandle(5) - ExprHandle(4)) * x) > y, ExprHandle(2) * x - x, ExprHandle(2) * y - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); ASSERT_EQ(cmp->compare_select_op(), kGT); IS_VAR_WITH_NAME(cmp->lhs(), "x"); IS_VAR_WITH_NAME(cmp->rhs(), "y"); IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); } TEST(Simplify, SimplifyOpaqueTerms) { VarHandle x("x", kInt); VarHandle y("y", kInt); { // 2 * x/y * y - x/y * y => x/y * y ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Div, mul->lhs(), div); IS_VAR_WITH_NAME(div->lhs(), "x"); IS_VAR_WITH_NAME(div->rhs(), "y"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // x%y - (x%y - 1) => 1 ExprHandle body = (x % y) - ((x % y) - 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 1); } } TEST(Simplify, SimplifySymbolicMinMax) { { // Minimum with constant difference between terms. VarHandle x("x", kInt); ExprHandle body = Min::make(x + 3, x + 7, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_IMM_WITH_VAL(Int, add->rhs(), 3); } { // Maximum with constant difference between terms. VarHandle x("x", kInt); ExprHandle body = Max::make(x + 3, x + 7, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_IMM_WITH_VAL(Int, add->rhs(), 7); } { // Can't simplify multiples because of signedness of variable component. // TODO: maybe we could for unsigned types? VarHandle x("x", kInt); ExprHandle body = Max::make(x * 3, x * 7, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE(Max, simplified.node()); } } TEST(Simplify, SimplifyNestedMax) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); { // Max(x + y, x + y) => x + y ExprHandle body = Max::make(x + y, x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); } { // Max(x + y, Max(x + y, z)) => Max(x + y, z) ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { // Max(x + y, Max(z, x + y)) => Max(x + y, z) ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { // Max(Max(x + y, z), x + y) => Max(x + y, z) ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { // Max(Max(z, x + y), x + y) => Max(x + y, z) ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { // Max(Max(x, y), x) => Max(Max(x, y), x) // Nested Max ops with different propagate_nans should not be simplified. ExprHandle body = Max::make(Max::make(x, y, true), x, false); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y"); ASSERT_TRUE(max1->propagate_nans()); IS_VAR_WITH_NAME(max->rhs(), "x"); ASSERT_FALSE(max->propagate_nans()); } { // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(x, y, true), Min::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(x, y, true), Min::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(y, x, true), Min::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(y, x, true), Min::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) // When all the ops in the pattern do not have the same propagate_nans, // it should not be simplified. ExprHandle body = Max::make(Min::make(y, x, true), Min::make(z, x, false), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); ASSERT_TRUE(min1->propagate_nans()); IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); ASSERT_FALSE(min2->propagate_nans()); ASSERT_TRUE(max->propagate_nans()); } { // Max(5, Max(x, 8)) => Max(x, 8) ExprHandle body = Max::make(5, Max::make(x, 8, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); ASSERT_TRUE(max->propagate_nans()); } { // Max(8, Max(x, 5)) => Max(x, 8) ExprHandle body = Max::make(8, Max::make(x, 5, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); ASSERT_TRUE(max->propagate_nans()); } { // Max(Max(x, 8), 5) => Max(x, 8) ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); ASSERT_TRUE(max->propagate_nans()); } { // Max(Max(x, 5), 8) => Max(x, 8) ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); ASSERT_TRUE(max->propagate_nans()); } { // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) // Do not simplify when all the Max ops do not have the same // propagate_nans. ExprHandle body = Max::make( Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), 8, false); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); } { // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } { // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) ExprHandle body = Max::make( Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max1); IS_NODE_WITH_NAME(Max, max1->lhs(), max2); IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); ASSERT_TRUE(max3->propagate_nans()); IS_VAR_WITH_NAME(max2->rhs(), "y"); IS_VAR_WITH_NAME(max1->rhs(), "z"); } } TEST(Simplify, SimplifyNestedMin) { VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); { // Min(x + y, x + y) => x + y ExprHandle body = Min::make(x + y, x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); } { // Min(x + y, Min(x + y, z)) => Min(x + y, z) ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { // Min(x + y, Min(z, x + y)) => Min(x + y, z) ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { // Min(Min(x + y, z), x + y) => Min(x + y, z) ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { // Min(Min(z, x + y), x + y) => Min(x + y, z) ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { // Min(Min(x, y), x) => Min(Min(x, y), x) // Nested Min ops with different propagate_nans should not be simplified. ExprHandle body = Min::make(Min::make(x, y, true), x, false); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y"); ASSERT_TRUE(min2->propagate_nans()); IS_VAR_WITH_NAME(min1->rhs(), "x"); ASSERT_FALSE(min1->propagate_nans()); } { // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(x, y, true), Max::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(x, y, true), Max::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(y, x, true), Max::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(y, x, true), Max::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) // When all the ops in the pattern do not have the same propagate_nans, // it should not be simplified. ExprHandle body = Min::make(Max::make(y, x, true), Max::make(z, x, false), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y"); ASSERT_TRUE(max1->propagate_nans()); IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z"); ASSERT_FALSE(max2->propagate_nans()); ASSERT_TRUE(min->propagate_nans()); } { // Min(5, Min(x, 8)) => Min(x, 8) ExprHandle body = Min::make(5, Min::make(x, 8, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); ASSERT_TRUE(min->propagate_nans()); } { // Min(8, Min(x, 5)) => Min(x, 8) ExprHandle body = Min::make(8, Min::make(x, 5, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); ASSERT_TRUE(min->propagate_nans()); } { // Min(Min(x, 8), 5) => Min(x, 8) ExprHandle body = Min::make(Min::make(x, 8, true), 5, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); ASSERT_TRUE(min->propagate_nans()); } { // Min(Min(x, 5), 8) => Min(x, 8) ExprHandle body = Min::make(Min::make(x, 5, true), 8, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); ASSERT_TRUE(min->propagate_nans()); } { // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) // Do not simplify when all the Min ops do not have the same // propagate_nans. ExprHandle body = Min::make( Min::make(Min::make(Min::make(z, 5, true), y, false), x, true), 8, false); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)"); } { // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } { // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z) ExprHandle body = Min::make( Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min1); IS_NODE_WITH_NAME(Min, min1->lhs(), min2); IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); ASSERT_TRUE(min3->propagate_nans()); IS_VAR_WITH_NAME(min2->rhs(), "y"); IS_VAR_WITH_NAME(min1->rhs(), "z"); } } TEST(Simplify, SimplifyWontReorderFloat) { { // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) // This is an expression we can simplify. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - ExprHandle(3) * (ExprHandle(3) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 9); IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); IS_VAR_WITH_NAME(sub->lhs(), "x"); IS_VAR_WITH_NAME(sub->rhs(), "y"); } { // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). // If the vars are floating point, ops are not associative and we can't // reorder. VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - ExprHandle(3) * (ExprHandle(3) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); } { // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). // We will simplify subexprs if they dont reorder floating point ops. VarHandle x("x", kDouble); VarHandle y("y", kInt); ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - ExprHandle(3) * (ExprHandle(3) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); } { // Prevent reordering if FP propagated from dtypes. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - ExprHandle(3) * (ExprHandle(3.f) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); IS_VAR_WITH_NAME(yCast->src_value(), "y"); } { VarHandle x("x", kFloat); VarHandle y("y", kFloat); // x%y - (x%y - 1) => x%y - (x%y - 1). // We wont reorder opaque ops if they are FP. ExprHandle body = (x % y) - ((x % y) - 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); } } TEST(Simplify, SimplifyRoundModPattern) { { // (x/y)*y + x%y => x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ((x / y) * y) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // Reverse order. // x%y + (x/y)*y => x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x % y) + ((x / y) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // Non opaque denominator. // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) + (x % (y + ExprHandle(4))); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // Reverse order. // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x % (y + ExprHandle(4))) + ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // Opaque denominator. // (x / (2/y)) * (2/y)) + (x % (2/y)) => x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) + (x % (ExprHandle(2) / y)); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // Non opaque numerator // ((2*x)/y * y) + ((2*x) % y) => 2 * x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { // Opaque numerator. // ((x/2) / y * y) + (x/2 % y) => x / 2. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Div, simplified.node(), div); IS_VAR_WITH_NAME(div->lhs(), "x"); IS_IMM_WITH_VAL(Int, div->rhs(), 2); } { // Numerator and denominator. // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + ((ExprHandle(2) * x) % (ExprHandle(2) * y)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { // Reverse order. // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { // Negated Subtraction of Round Mod. // (x/y) * y - (0 - x%y) => x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // Other terms are preserved. // (x/y)*y + x%y + (y * x) => x + (y * x). VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ((x / y) * y) + (x % y) + (y * x); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // Sanity checking we wont do the optimization on floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = ((x / y) * y) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); IS_VAR_WITH_NAME(roundMul->rhs(), "y"); IS_NODE_WITH_NAME(Mod, add->rhs(), mod); IS_VAR_WITH_NAME(mod->lhs(), "x"); IS_VAR_WITH_NAME(mod->rhs(), "y"); } { // Sanity check we wont do it if the mod term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); ExprHandle body = ((x / y) * y) + (x % z); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "(x / y) * y + x % z"); } { // Sanity check we wont do it if the div term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); ExprHandle body = (y * (x / z)) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "x % y + (x / z) * y"); } { // Sanity check we wont do it if the mul term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); ExprHandle body = ((x / y) * z) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "x % y + (x / y) * z"); } } TEST(Simplify, SimplifyRoundModPatternFactorization) { { // Full factorization. // 2 * (x/y * y) + 2 * (x%y) => 2 * x. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { // Partial Factorization. // 32 * (x/8) + 4 * (x % 8) => 4 * x. VarHandle x("x", kInt); VarHandle y("y", kInt); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 4); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { // Factorization requiring constant folding. // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x. VarHandle x("x", kInt); ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) + (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 5); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { VarHandle x("x", kInt); ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "x"); } { VarHandle x("x", kInt); ExprHandle body = (x / 10) * 0 + x % 5; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_VAR_WITH_NAME(mod->lhs(), "x"); IS_IMM_WITH_VAL(Int, mod->rhs(), 5); } } TEST(Simplify, SimplifyRoundModPatternMultivar) { { // Multivar. // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "x"); IS_VAR_WITH_NAME(add->rhs(), "y"); } { // Find the right var. // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); ExprHandle body = (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Add, add->lhs(), add2); IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod); IS_VAR_WITH_NAME(xMod->lhs(), "x"); IS_IMM_WITH_VAL(Int, xMod->rhs(), 8); IS_VAR_WITH_NAME(add2->rhs(), "y"); IS_NODE_WITH_NAME(Mod, add->rhs(), zMod); IS_VAR_WITH_NAME(zMod->lhs(), "z"); IS_IMM_WITH_VAL(Int, zMod->rhs(), 8); } { // Compound. // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16) // => (z + 512 * y) + x VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "x + (z + 512 * y)"); } } TEST(Simplify, SimplifyModRoundModPattern) { { // t/7 % 9 * 7 + t % 7 => t%63 VarHandle t("t", kInt); ExprHandle body = (t / 7 % 9) * 7 + t % 7; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_IMM_WITH_VAL(Int, mod->rhs(), 63); } { // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63 VarHandle t("t", kInt); ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "t"); IS_IMM_WITH_VAL(Int, mod->rhs(), 63); } { // t/x % y * x + t % x => t%(x*y) VarHandle t("t", kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (t / x % y) * x + t % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // k*t/x % y * x + k*t % x => k*t%(x*y) VarHandle t("t", kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle k("k", kInt); ExprHandle body = (k * t / x % y) * x + k * t % x; ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "(k * t) % (x * y)"); } { // t/k/x % y * x + t/k % x => t/k%(x*y) VarHandle t("t", kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle k("k", kInt); ExprHandle body = (t / k / x % y) * x + t / k % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_NODE_WITH_NAME(Div, mod->lhs(), div); IS_VAR_WITH_NAME(div->lhs(), "t"); IS_VAR_WITH_NAME(div->rhs(), "k"); IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // Sanity checking we wont do the optimization on floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); VarHandle z("z", kFloat); ExprHandle body = ((x / y % z) * y) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), mul); IS_NODE_WITH_NAME(Mod, mul->lhs(), mod); IS_NODE_WITH_NAME(Div, mod->lhs(), div); IS_VAR_WITH_NAME(div->lhs(), "x"); IS_VAR_WITH_NAME(div->rhs(), "y"); IS_VAR_WITH_NAME(mod->rhs(), "z"); IS_VAR_WITH_NAME(mul->rhs(), "y"); IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); IS_VAR_WITH_NAME(mod2->lhs(), "x"); IS_VAR_WITH_NAME(mod2->rhs(), "y"); } } TEST(Simplify, SimplifyModRoundModPatternFactorization) { { // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) VarHandle t("t", kInt); ExprHandle body = ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_IMM_WITH_VAL(Int, mod->rhs(), 63); } { // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63) VarHandle t("t", kInt); ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_IMM_WITH_VAL(Int, mod->rhs(), 63); } { // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63 VarHandle t("t", kInt); ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_NODE_WITH_NAME(Div, mod->lhs(), div); IS_VAR_WITH_NAME(div->lhs(), "t"); IS_IMM_WITH_VAL(Int, div->rhs(), 2); IS_IMM_WITH_VAL(Int, mod->rhs(), 63); } { // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189 VarHandle t("t", kInt); ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 + t % (ExprHandle(7) * ExprHandle(3)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mod, simplified.node(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_IMM_WITH_VAL(Int, mod->rhs(), 189); } { // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y)) VarHandle t("t", kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); IS_VAR_WITH_NAME(mul2->lhs(), "x"); IS_VAR_WITH_NAME(mul2->rhs(), "y"); } } TEST(Simplify, SimplifyModRoundModPatternMultivar) { { // t/7 % 9 * 7 + t % 7 + t => t % 63 + t VarHandle t("t", kInt); ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "t % 63 + t"); } { // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72 VarHandle t("t", kInt); ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mod, add->lhs(), mod1); IS_VAR_WITH_NAME(mod1->lhs(), "t"); IS_IMM_WITH_VAL(Int, mod1->rhs(), 63); IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); IS_VAR_WITH_NAME(mod2->lhs(), "t"); IS_IMM_WITH_VAL(Int, mod2->rhs(), 72); } { // k + t/x % y * x + t % x => k + t%(x*y) VarHandle t("t", kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle k("k", kInt); ExprHandle body = k + (t / x % y) * x + t % x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_VAR_WITH_NAME(add->lhs(), "k"); IS_NODE_WITH_NAME(Mod, add->rhs(), mod); IS_VAR_WITH_NAME(mod->lhs(), "t"); IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } { // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x // => t%(x*y) + t/k % (x*y) VarHandle t("t", kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle k("k", kInt); ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; ExprHandle simplified = IRSimplifier::simplify(body); checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)"); } { // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63) // => io_flat VarHandle t("io_flat", kInt); ExprHandle body = ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "io_flat"); } { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) + // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 + // (i0_flat / (9 * 7) % 10) * 7 * 9 + // (i0_flat / 7 % 9) * 7 + // i0_flat % 7 => io_flat VarHandle t("io_flat", kInt); ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) + (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 + (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "io_flat"); } { // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) * // (i0_flat / (m * n)) => io_flat VarHandle t("io_flat", kInt); VarHandle m("m", kInt); VarHandle n("n", kInt); ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n)); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "io_flat"); } { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) + // (i0_flat / (l * n * m) % k) * m * n * l + // (i0_flat / (n * m) % l) * m * n + // (i0_flat / m % n) * m + // i0_flat % m => io_flat VarHandle t("io_flat", kInt); VarHandle m("m", kInt); VarHandle n("n", kInt); VarHandle l("l", kInt); VarHandle k("k", kInt); ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) + (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n + (t / m % n) * m + t % m; ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "io_flat"); } } TEST(Simplify, SimplifyDivisionScalarFactorization) { { // Simple factorization of numerator and denominator. // 8x / 4y => 2x / y. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x * 8) / (y * 4); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Div, simplified.node(), div); IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); IS_VAR_WITH_NAME(lhs->rhs(), "x"); IS_VAR_WITH_NAME(div->rhs(), "y"); } { // Don't change anything if we can't factorize. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x * 7) / (y * 4); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Div, simplified.node(), div); IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 7); IS_VAR_WITH_NAME(lhs->rhs(), "x"); IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { // Don't reorder floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = (x * 8) / (y * 4); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Div, simplified.node(), div); IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); IS_VAR_WITH_NAME(lhs->lhs(), "x"); IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f); IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "y"); IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f); } { // Sanity check we do nothing if there are only scalar parts. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x * 1) / (y * 1); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Div, simplified.node(), div); IS_VAR_WITH_NAME(div->lhs(), "x"); IS_VAR_WITH_NAME(div->rhs(), "y"); } { // Can factorize amounts of variables. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x + x + x + x) / (y + y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Div, simplified.node(), div); IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); IS_VAR_WITH_NAME(lhs->rhs(), "x"); IS_VAR_WITH_NAME(div->rhs(), "y"); } } TEST(Simplify, SimplifyConstantBranches) { { // If the condition is constant true then take the true_value. // 1 ? x : y => x VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle t(1); ExprHandle body = IfThenElse::make(t, x, y); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // If the condition is constant false then take the false_value. // 0 ? x : y => y VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle t(0); ExprHandle body = IfThenElse::make(t, x, y); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "y"); } { // condition is simplified before checking. // (x-x) ? x : y => y VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = IfThenElse::make(x - x, x, y); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "y"); } { // If both branches are the same then don't do the condition. // y ? x : x => x VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = IfThenElse::make(y, x, x); ExprHandle simplified = IRSimplifier::simplify(body); IS_VAR_WITH_NAME(simplified.node(), "x"); } { // If both branches simplify to the same thing it still works. // y ? (x + x) : (2 * x) => x VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_VAR_WITH_NAME(mul->rhs(), "x"); } } TEST(Simplify, SimplifyConstantCond) { { // If the condition is constant true then take the true_value. // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); ExprHandle condition(1); StmtPtr true_val = Store::make(a, {0}, 1); StmtPtr false_val = Store::make(b, {0}, 1); CondPtr body = alloc(condition.node(), true_val, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "A"); } { // If the condition is constant false then take the false_value. // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); ExprHandle condition(0); StmtPtr true_val = Store::make(a, {0}, 1); StmtPtr false_val = Store::make(b, {0}, 1); StmtPtr body = alloc(condition.node(), true_val, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "B"); } { // condition is simplified before checking. // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); ExprHandle condition(x - x); StmtPtr true_val = Store::make(a, {0}, 1); StmtPtr false_val = Store::make(b, {0}, 1); StmtPtr body = alloc(condition.node(), true_val, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "B"); } { // If both branches are the same then don't do the condition. // x ? A[0] = x : A[0] = x => A[0] = x VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); ExprHandle condition(x - x); StmtPtr true_val = Store::make(a, {0}, x); StmtPtr false_val = Store::make(a, {0}, x); StmtPtr body = alloc(condition.node(), true_val, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "A"); } { // If both branches simplify to the same thing it still works. // x ? (x + x) : (2 * x) => x VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); ExprHandle condition(x - x); StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); StmtPtr false_val = Store::make(a, {0}, x + x); StmtPtr body = alloc(condition.node(), true_val, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "A"); } { // But not if they dont // x ? x : (2 * x) => x ? x : (2 * x) VarHandle x("x", kInt); BufHandle a("A", {1}, kInt); ExprHandle condition(x); StmtPtr true_val = Store::make(a, {0}, x); StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); StmtPtr body = alloc(condition.node(), true_val, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_EQ(block, nullptr); } { StmtPtr cond = alloc( ExprHandle(false).node(), alloc(std::vector({})), nullptr); StmtPtr simplified = IRSimplifier::simplify(cond); ASSERT_EQ(simplified, nullptr); } { StmtPtr cond = alloc( ExprHandle(true).node(), nullptr, alloc(std::vector({}))); StmtPtr simplified = IRSimplifier::simplify(cond); ASSERT_EQ(simplified, nullptr); } } TEST(Simplify, SimplifyEliminateEmptyCond) { // If the branches are empty in different ways, eliminate. { VarHandle x("x", kInt); ExprHandle condition(x); StmtPtr true_val = alloc(std::vector({})); StmtPtr body = alloc(condition.node(), true_val, nullptr); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_NE(block, nullptr); ASSERT_EQ(block->nstmts(), 0); } { VarHandle x("x", kInt); ExprHandle condition(x); StmtPtr false_val = alloc(std::vector({})); StmtPtr body = alloc(condition.node(), nullptr, false_val); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_NE(block, nullptr); ASSERT_EQ(block->nstmts(), 0); } } TEST(Simplify, SimplifyConstantComparisons) { auto ComparisonTest = [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { ExprHandle body = CompareSelect::make(a, b, op); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), result); }; // Equals. ComparisonTest(2, 2, kEQ, 1); ComparisonTest(1, 2, kEQ, 0); ComparisonTest(2, 1, kEQ, 0); // Greater than. ComparisonTest(2, 2, kGT, 0); ComparisonTest(1, 2, kGT, 0); ComparisonTest(2, 1, kGT, 1); // Greater or Equal. ComparisonTest(2, 2, kGE, 1); ComparisonTest(1, 2, kGE, 0); ComparisonTest(2, 1, kGE, 1); // Less Than. ComparisonTest(2, 2, kLT, 0); ComparisonTest(1, 2, kLT, 1); ComparisonTest(2, 1, kLT, 0); // Less or Equal. ComparisonTest(2, 2, kLE, 1); ComparisonTest(1, 2, kLE, 1); ComparisonTest(2, 1, kLE, 0); // Not equal. ComparisonTest(2, 2, kNE, 0); ComparisonTest(1, 2, kNE, 1); ComparisonTest(2, 1, kNE, 1); // With specified results: ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 42); } TEST(Simplify, SimplifySymbolicComparisons) { VarHandle x("x", kInt); VarHandle y("y", kInt); auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; auto TookFalseBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 0); }; // EQ // x == x => 1 ExprHandle body = CompareSelect::make(x, x, kEQ); TookTrueBranch(IRSimplifier::simplify(body)); // x == x+1 => 0 body = CompareSelect::make(x, x + 1, kEQ); TookFalseBranch(IRSimplifier::simplify(body)); // x == x * 2 cannot simplify since we don't know x is nonzero. body = CompareSelect::make(x, x * 2, kEQ); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); // x == x * 1 => 1 body = CompareSelect::make(x, x * 1, kEQ); TookTrueBranch(IRSimplifier::simplify(body)); { // x == y => x == y body = CompareSelect::make(x, y, kEQ); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); ASSERT_EQ(cmp->compare_select_op(), kEQ); IS_VAR_WITH_NAME(cmp->lhs(), "x"); IS_VAR_WITH_NAME(cmp->rhs(), "y"); } { // x == 5 => x == 5 body = CompareSelect::make(x, 5, kEQ); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); ASSERT_EQ(cmp->compare_select_op(), kEQ); IS_VAR_WITH_NAME(cmp->lhs(), "x"); IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); } // GT // x+1 > x => 1 body = CompareSelect::make(x + 1, x, kGT); TookTrueBranch(IRSimplifier::simplify(body)); // x > x + 1 => 0 body = CompareSelect::make(x, x + 1, kGT); TookFalseBranch(IRSimplifier::simplify(body)); // x > x - 1 => 1 body = CompareSelect::make(x, x - 1, kGT); TookTrueBranch(IRSimplifier::simplify(body)); // x - 1 > x => 0 body = CompareSelect::make(x - 1, x, kGT); TookFalseBranch(IRSimplifier::simplify(body)); // x > x => 0 body = CompareSelect::make(x, x, kGT); TookFalseBranch(IRSimplifier::simplify(body)); // x * 2 > x => x * 2 > x // since we don't know the sign of x. body = CompareSelect::make(x * 2, x, kGT); IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); // GE // x+1 >= x => 1 body = CompareSelect::make(x + 1, x, kGE); TookTrueBranch(IRSimplifier::simplify(body)); // x >= x + 1 => 0 body = CompareSelect::make(x, x + 1, kGE); TookFalseBranch(IRSimplifier::simplify(body)); // x >= x => 1 body = CompareSelect::make(x, x, kGE); TookTrueBranch(IRSimplifier::simplify(body)); // x * 2 >= x => x * 2 >= x // since we don't know the sign of x. body = CompareSelect::make(x * 2, x, kGE); IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); // LT // x+1 < x => 0 body = CompareSelect::make(x + 1, x, kLT); TookFalseBranch(IRSimplifier::simplify(body)); // x < x + 1 => 1 body = CompareSelect::make(x, x + 1, kLT); TookTrueBranch(IRSimplifier::simplify(body)); // x < x => 0 body = CompareSelect::make(x, x, kLT); TookFalseBranch(IRSimplifier::simplify(body)); // LE // x+1 <= x => 0 body = CompareSelect::make(x + 1, x, kLE); TookFalseBranch(IRSimplifier::simplify(body)); // x <= x + 1 => 1 body = CompareSelect::make(x, x + 1, kLE); TookTrueBranch(IRSimplifier::simplify(body)); // x <= x => 1 body = CompareSelect::make(x, x, kLE); TookTrueBranch(IRSimplifier::simplify(body)); // NE // x+1 != x => 1 body = CompareSelect::make(x + 1, x, kNE); TookTrueBranch(IRSimplifier::simplify(body)); // x != x + 1 => 1 body = CompareSelect::make(x, x + 1, kNE); TookTrueBranch(IRSimplifier::simplify(body)); // x != x => 0 body = CompareSelect::make(x, x, kNE); TookFalseBranch(IRSimplifier::simplify(body)); } TEST(Simplify, SimplifyEliminateZeroLengthFor) { { // Will eliminate zero loop For. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } { // still works if start is not zero. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } { // works if both terms are variable. VarHandle x("x", kInt); BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } { // works if one term simplifies down. VarHandle x("x", kInt); BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); ASSERT_EQ(block->nstmts(), 0); } { // Sanity check does nothing if the condition is not met. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE(For, simplified); } } TEST(Simplify, SimplifyOneLoopFor) { { // Will remove the loop if the body is run once. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); } { // still works if start is not zero. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 2); } { // works if both terms are variable. VarHandle x("x", kInt); BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_VAR_WITH_NAME(store->flat_index(), "x"); } { // works if one term simplifies down. VarHandle x("x", kInt); BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); } { // Sanity check does nothing if the condition is not met. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE(For, simplified); } } TEST(Simplify, SimplifyForWontLoseLoopOptions) { { // Sanity check does nothing if the condition is not met. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); LoopOptions options; options.set_gpu_block_index(LoopOptions::IDX_W); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, for_); LoopOptions options2 = for_->loop_options(); ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); } } TEST(Simplify, SimplifyMultilevelFor) { { // Multiple layers of For will be simplified out. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); auto outer = For::make(j, 0, 1, body); StmtPtr simplified = IRSimplifier::simplify(outer); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); } { // Will maintain an outer loop if the inner loop is eliminated. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); auto outer = For::make(j, 0, 2, body); StmtPtr simplified = IRSimplifier::simplify(outer); ForPtr for__ = static_to(simplified); IS_NODE_WITH_NAME(For, for__, for_); IS_VAR_WITH_NAME(for_->var(), "j"); IS_IMM_WITH_VAL(Int, for_->start(), 0); IS_IMM_WITH_VAL(Int, for_->stop(), 2); BlockPtr block = to(for_->body()); ASSERT_NE(block, nullptr); IS_NODE_WITH_NAME(Store, block->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_IMM_WITH_VAL(Int, store->flat_index(), 0); } { // Will maintain inner loop if outer loops is eliminated. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); auto outer = For::make(j, 0, 1, body); StmtPtr simplified = IRSimplifier::simplify(outer); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(For, block->front(), for_); IS_VAR_WITH_NAME(for_->var(), "i"); IS_IMM_WITH_VAL(Int, for_->start(), 0); IS_IMM_WITH_VAL(Int, for_->stop(), 2); IS_NODE_WITH_NAME(Store, for_->body()->front(), store); IS_VAR_WITH_NAME(store->base_handle(), "C"); IS_VAR_WITH_NAME(store->flat_index(), "i"); } } TEST(Simplify, SimplifyForCleansUp) { { BufHandle a("a", {1, 12, 1}, kFloat); VarHandle x("x", kInt); Tensor b = Compute( "x", {1, 12, 1}, [](const VarHandle& i, const VarHandle& m, const VarHandle& n) { return i + m + n; }); LoopNest l({b}); l.prepareForCodegen(); StmtPtr body = LoopNest::sanitizeNames(l.root_stmt()); StmtPtr simplified = IRSimplifier::simplify(body); BlockPtr block = to(simplified); IS_NODE_WITH_NAME(For, block->front(), for_); // for is over "m". IS_VAR_WITH_NAME(for_->var(), "j"); // x[m] = m; IS_NODE_WITH_NAME(Store, for_->body()->front(), store); IS_VAR_WITH_NAME(store->flat_index(), "j"); IS_VAR_WITH_NAME(store->value(), "j"); } } TEST(Simplify, SimplifyEliminateEmptyFor) { { // Flatten many layers around an empty block to an empty block. StmtPtr last = alloc(std::vector({})); for ([[maybe_unused]] const auto i : c10::irange(11)) { VarHandle loopVar("loopVar", kInt); last = For::make(loopVar, 0, 10, last); } StmtPtr simplified = IRSimplifier::simplify(last); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 0); } } TEST(Simplify, SimplifyFlattenBlock) { { // Flatten multiple blocks down to one. // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } BufHandle a("A", {1}, kInt); StorePtr store1 = Store::make(a, {0}, 1); StorePtr store2 = Store::make(a, {0}, 0); BlockPtr block1 = alloc(std::vector({store1, store2})); BlockPtr block2 = alloc(std::vector({block1})); BlockPtr enclosing = alloc(std::vector({block2})); StmtPtr simplified = IRSimplifier::simplify(enclosing); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Store, block->front(), store1_); IS_NODE_WITH_NAME(Store, block->back(), store2_); ASSERT_EQ(store1->value(), store1_->value()); ASSERT_EQ(store2->value(), store2_->value()); } { // Flatten multiple sub blocks containing statements. // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } BufHandle a("A", {1}, kInt); StorePtr store1 = Store::make(a, {0}, 1); StorePtr store2 = Store::make(a, {0}, 0); BlockPtr block1 = alloc(std::vector({store1})); BlockPtr block2 = alloc(std::vector({store2})); BlockPtr enclosing = alloc(std::vector({block1, block2})); StmtPtr simplified = IRSimplifier::simplify(enclosing); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Store, block->front(), store1_); IS_NODE_WITH_NAME(Store, block->back(), store2_); ASSERT_EQ(store1->value(), store1_->value()); ASSERT_EQ(store2->value(), store2_->value()); } { // Flatten sub blocks with different depths. // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } BufHandle a("A", {1}, kInt); StorePtr store1 = Store::make(a, {0}, 1); StorePtr store2 = Store::make(a, {0}, 0); BlockPtr block1 = alloc(std::vector({store2})); BlockPtr block2 = alloc(std::vector({block1})); BlockPtr enclosing = alloc(std::vector({store1, block2})); StmtPtr simplified = IRSimplifier::simplify(enclosing); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Store, block->front(), store1_); IS_NODE_WITH_NAME(Store, block->back(), store2_); ASSERT_EQ(store1->value(), store1_->value()); ASSERT_EQ(store2->value(), store2_->value()); } { // Flatten many layers around an empty block to an empty block. StmtPtr last = alloc(std::vector({})); for ([[maybe_unused]] const auto i : c10::irange(11)) { last = alloc(std::vector({last})); } StmtPtr simplified = IRSimplifier::simplify(last); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 0); } } TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { { // Simple positive case. BufHandle b("x", {0}, kInt); AllocatePtr alloc_ = Allocate::make(b); FreePtr free_ = Free::make(b); BlockPtr block1 = alloc(std::vector({alloc_, free_})); ASSERT_EQ(block1->nstmts(), 2); StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 0); } { // Simple negative case. BufHandle b("x", {2}, kInt); AllocatePtr alloc_ = Allocate::make(b); FreePtr free_ = Free::make(b); BlockPtr block1 = alloc(std::vector({alloc_, free_})); ASSERT_EQ(block1->nstmts(), 2); StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 2); } { // Finds right Alloc/Free. BufHandle b1("x", {0}, kInt); BufHandle b2("y", {2}, kInt); AllocatePtr alloc1 = Allocate::make(b1); AllocatePtr alloc2 = Allocate::make(b2); FreePtr free2_ = Free::make(b2); FreePtr free1_ = Free::make(b1); BlockPtr block1 = alloc(std::vector({alloc1, alloc2, free2_, free1_})); ASSERT_EQ(block1->nstmts(), 4); StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 2); IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y"); IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free); ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var()); } { // Dynamic shape. VarHandle z("z", kInt); BufHandle b1("x", {0}, kInt); BufHandle b2("y", {z}, kInt); AllocatePtr alloc1 = Allocate::make(b1); AllocatePtr alloc2 = Allocate::make(b2); FreePtr free2_ = Free::make(b2); FreePtr free1_ = Free::make(b1); BlockPtr block1 = alloc(std::vector({alloc1, alloc2, free2_, free1_})); ASSERT_EQ(block1->nstmts(), 4); StmtPtr simplified = IRSimplifier::simplify(block1); IS_NODE_WITH_NAME(Block, simplified, block2); ASSERT_EQ(block2->nstmts(), 2); } } TEST(Simplify, DontSimplifyRand) { { // rand() + rand() = rand() + rand() NOT 2 * rand(). ExprHandle body = Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_RAND(add->lhs()); IS_RAND(add->rhs()); } { // rand() - rand() = rand() - rand() NOT 0. ExprHandle body = Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_RAND(sub->lhs()); IS_RAND(sub->rhs()); } { // rand() * rand() = rand() * rand(). ExprHandle body = Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_RAND(mul->lhs()); IS_RAND(mul->rhs()); } } TEST(Simplify, SimplifyReorderForCond) { BufHandle a("A", {4}, kInt); BufHandle b("B", {1}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); { // for ( if ( ... ) ) => if ( for ( ... ) ). auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(c, {i}, Load::make(a, {i})), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); } { // Can't reorder if condition is dependent on the loop var. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make(i, 2, CompareSelectOperation::kEQ), Store::make(c, {i}, Load::make(a, {i})), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } { // Can't reorder if condition is dependent on a var that is modified inside // the loop. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make( Load::make(c, {0}), 10, CompareSelectOperation::kLT), Store::make(c, {0}, Load::make(a, {i})), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } { // Condition based on buffer not referenced in body. Can reorder here. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make( Load::make(b, {0}), 10, CompareSelectOperation::kLT), Store::make(c, {0}, Load::make(a, {i})), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); } { // Condition based on buffer read only in body. Can reorder here. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make( Load::make(a, {0}), 10, CompareSelectOperation::kLT), Store::make(c, {0}, Load::make(a, {i})), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); } { // Condition depends on Let in the loop. Cannot reorder. auto body = For::make( i, 0, 4, Block::make( {Let::make(j, 3), Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(c, {0}, Load::make(a, {i})), nullptr)})); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Let, loop->body()->front(), let); IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); } { // Multi level Ifs where all conditions are distinct. Move BOTH Cond // statements outside the loop. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make( Load::make(a, {0}), 10, CompareSelectOperation::kLT), Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kEQ), Store::make(c, {0}, Load::make(a, {i})), nullptr), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2); IS_NODE_WITH_NAME(For, true_block2->front(), loop); } { // Multi level Ifs where the inner condition does depend on a loop var, // reorder only the first Cond. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make( Load::make(a, {0}), 10, CompareSelectOperation::kLT), Cond::make( CompareSelect::make(i, 3, CompareSelectOperation::kEQ), Store::make(c, {0}, Load::make(a, {i})), nullptr), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); IS_NODE_WITH_NAME(Block, loop->body(), loop_body); IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2); } { // Don't reorder if there's an else block of the Cond. // We could, but is it much better? auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(c, {0}, Load::make(a, {i})), Store::make(c, {0}, 0))); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } { // Condition uses distinct region of Tensor. // We could reorder here wih better analysis, but we don't. Included for // completeness. auto body = For::make( i, 0, 4, Cond::make( CompareSelect::make( Load::make(c, {0}), 10, CompareSelectOperation::kLT), Store::make(c, {1}, Load::make(a, {i})), nullptr)); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(For, simplified, loop); IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); } } TEST(Simplify, SimplifyFuseConditions) { BufHandle a("A", {2}, kInt); BufHandle b("B", {2}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); { // Can fuse since the conditions are identical. // if (A) { X }; if (A) { Y }; => if (A) { X; Y } auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 2); ASSERT_EQ(cond->false_stmt(), nullptr); } { // Can't fuse, conditions are not identical in lhs (i != j). auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); IS_NODE_WITH_NAME(Cond, block->back(), cond2); IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); ASSERT_EQ(true_stmt1->nstmts(), 1); ASSERT_EQ(true_stmt2->nstmts(), 1); ASSERT_EQ(cond1->false_stmt(), nullptr); ASSERT_EQ(cond2->false_stmt(), nullptr); } { // Can't fuse, conditions are not identical in rhs (10 != 11). auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(i, 11, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); IS_NODE_WITH_NAME(Cond, block->back(), cond2); IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); ASSERT_EQ(true_stmt1->nstmts(), 1); ASSERT_EQ(true_stmt2->nstmts(), 1); ASSERT_EQ(cond1->false_stmt(), nullptr); ASSERT_EQ(cond2->false_stmt(), nullptr); } { // Can't fuse, conditions are not identical in operation (LT vs GT). auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kGT), Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); IS_NODE_WITH_NAME(Cond, block->back(), cond2); IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); ASSERT_EQ(true_stmt1->nstmts(), 1); ASSERT_EQ(true_stmt2->nstmts(), 1); ASSERT_EQ(cond1->false_stmt(), nullptr); ASSERT_EQ(cond2->false_stmt(), nullptr); } { // Can't fuse, CompareSelect results are different. // Actually we totally could if we normalized CompareSelect results, but // TODO for later. auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); IS_NODE_WITH_NAME(Cond, block->back(), cond2); IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); ASSERT_EQ(true_stmt1->nstmts(), 1); ASSERT_EQ(true_stmt2->nstmts(), 1); ASSERT_EQ(cond1->false_stmt(), nullptr); ASSERT_EQ(cond2->false_stmt(), nullptr); } { // Can fuse with false stmt only. auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), nullptr, Store::make(a, {0}, i)), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), nullptr, Store::make(a, {1}, i))}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt); ASSERT_EQ(false_stmt->nstmts(), 2); ASSERT_EQ(cond->true_stmt(), nullptr); } { // Can fuse with both true and false stmt. auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), Store::make(b, {0}, i)), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, i), Store::make(b, {1}, i))}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 2); IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); ASSERT_EQ(false_stmt->nstmts(), 2); } { // Can fuse with mismatched true / false stmt existing auto body = Block::make( {Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), nullptr, Store::make(b, {1}, i))}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 1); IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); ASSERT_EQ(false_stmt->nstmts(), 1); } { // Can fuse partial block contents, ie when there are non fused stmts before // and after. // before: // if (j < 10) { A[0] = j; } // if (i < 10) { A[0] = i; } // if (i < 10) { A[1] = i; } // if (i < 11) { A[1] = j; } // // after: // // if (j < 10) { A[0] = j; } // if (i < 10) { // A[0] = i; // A[1] = i; // } // if (i < 11) { A[1] = j; } auto body = Block::make({ Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(a, {0}, j), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr), Cond::make( CompareSelect::make(i, 11, CompareSelectOperation::kLT), Store::make(a, {1}, j), nullptr), }); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); auto it = block->begin(); it++; IS_NODE_WITH_NAME(Cond, *it, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 2); ASSERT_EQ(cond->false_stmt(), nullptr); } { // Can fuse longer sequences of identical conditions. auto body = Block::make({ Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, j), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, j), nullptr), }); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 4); ASSERT_EQ(cond->false_stmt(), nullptr); } { // Can't fuse through a non condition. auto body = Block::make({ Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, j), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Store::make(b, {1}, i + j), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr), Cond::make( CompareSelect::make(i, 10, CompareSelectOperation::kLT), Store::make(a, {1}, j), nullptr), }); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 2); ASSERT_EQ(cond->false_stmt(), nullptr); IS_NODE_WITH_NAME(Cond, block->back(), cond2); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2); ASSERT_EQ(true_stmt2->nstmts(), 2); ASSERT_EQ(cond2->false_stmt(), nullptr); auto it = block->begin(); it++; IS_NODE_WITH_NAME(Store, *it, middle); } { // Can fuse if the conditions simplify to the same thing. auto body = Block::make( {Cond::make( CompareSelect::make( i * 2, ExprHandle(87) % ExprHandle(11), CompareSelectOperation::kLT), Store::make(a, {0}, i), nullptr), Cond::make( CompareSelect::make( i * 2, ExprHandle(300) / ExprHandle(30), CompareSelectOperation::kLT), Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 2); ASSERT_EQ(cond->false_stmt(), nullptr); } { // Can fuse non-CompareSelects. // if (i) { X } if (i) { Y } => if (i) { X; Y } auto body = Block::make( {Cond::make(i, Store::make(a, {0}, i), nullptr), Cond::make(i, Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); IS_NODE_WITH_NAME(Cond, block->front(), cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); ASSERT_EQ(true_stmt->nstmts(), 2); ASSERT_EQ(cond->false_stmt(), nullptr); } { // Sanity check wont fuse different non-CompareSelects. auto body = Block::make( {Cond::make(i, Store::make(a, {0}, i), nullptr), Cond::make(j, Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Cond, block->front(), cond1); IS_NODE_WITH_NAME(Cond, block->back(), cond2); } { // Sanity check constant condition elimination still occurs when merging is // possible. auto body = Block::make( {Cond::make(1, Store::make(a, {0}, i), nullptr), Cond::make(1, Store::make(a, {1}, i), nullptr)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); IS_NODE_WITH_NAME(Store, block->front(), store1); IS_NODE_WITH_NAME(Store, block->back(), store2); } { // Sanity check for-cond reordering occurs after fusing. auto body = For::make( i, 0, 4, Block::make( {Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(a, {1}, Load::make(b, {0})), nullptr), Cond::make( CompareSelect::make(j, 10, CompareSelectOperation::kLT), Store::make(a, {2}, Load::make(b, {0})), nullptr)})); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Cond, simplified, cond); IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); IS_NODE_WITH_NAME(For, true_block->front(), loop); } } TEST(Simplify, SimplifySyncThreads) { BufHandle a("A", {4}, kInt); VarHandle i("i", kInt); { // Merge two inner SyncThreads. auto body = Block::make( // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) {Store::make(a, {0}, 1), alloc(), alloc(), Store::make(a, {1}, 0)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); auto it = block->begin(); IS_NODE(Store, *it++); IS_NODE(SyncThreads, *it++); IS_NODE(Store, *it++); } { // Eliminate outer SyncThreads. auto body = Block::make( {alloc(), Store::make(a, {1}, 0), alloc()}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); auto it = block->begin(); IS_NODE(Store, *it); } { // Merge many inner SyncThreads. auto body = Block::make( {Store::make(a, {0}, 1), alloc(), alloc(), alloc(), alloc(), alloc(), Store::make(a, {1}, 0)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); auto it = block->begin(); IS_NODE(Store, *it++); IS_NODE(SyncThreads, *it++); IS_NODE(Store, *it++); } { // Merge multiple outer SyncThreads. auto body = Block::make( {alloc(), alloc(), Store::make(a, {1}, 0), alloc(), alloc(), alloc(), alloc()}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); auto it = block->begin(); IS_NODE(Store, *it); } { // Merge multiple sections; auto body = Block::make( {Store::make(a, {0}, 1), alloc(), alloc(), Store::make(a, {1}, 0), Store::make(a, {2}, 0), alloc(), alloc(), alloc(), Store::make(a, {3}, 0)}); StmtPtr simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 6); auto it = block->begin(); IS_NODE(Store, *it++); IS_NODE(SyncThreads, *it++); IS_NODE(Store, *it++); IS_NODE(Store, *it++); IS_NODE(SyncThreads, *it++); IS_NODE(Store, *it++); } } TEST(Simplify, SimplifyRampSubBroadcast) { int num_lanes = 4; ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); RampPtr newRamp = simplified.AsNode(); IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); ASSERT_EQ(base->value(), 5); IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); ASSERT_EQ(stride->value(), 6); ASSERT_EQ(newRamp->lanes(), num_lanes); } TEST(Simplify, SimplifyBroadcastTermExpander) { int num_lanes = 8; ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); // NB: We need a term in the middle which isn't simplified to trigger the // relevant path in TermExpander::mutate. The two bc1 terms are brought // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); BufHandle buf("buf", {num_lanes}, kInt); // The result isn't fully simplified currently and thus would be brittle to // match. Observe its value instead. auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); SimpleIREvaluator eval(store, {buf}); std::vector output(num_lanes); eval(output); for (const auto i : c10::irange(num_lanes)) { ASSERT_EQ(output[i], 2); } } TEST(Simplify, CompareSelectLoopBounds) { constexpr int N = 8; BufHandle b("b", {N}, kFloat); VarHandle n("n", kInt); VarHandle m("m", kInt); VarHandle var_N("var_N", kInt); VarHandle var_M("var_M", kInt); auto test_case_fn = [](const VarHandle& n, const BufHandle& b, const ExprHandle& start, const ExprHandle& stop, const int& cmp_val, const CompareSelectOperation& cmp_op, const std::string& check_string) { StmtPtr s = For::make( n, start, stop, b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; std::string target_string = "# CHECK: "; target_string += check_string; torch::jit::testing::FileCheck().run(target_string, oss.str()); }; auto test_case_nest_loops_fn = [](const VarHandle& n, const VarHandle& m, const BufHandle& b, const ExprHandle& n_start, const ExprHandle& n_stop, const ExprHandle& m_start, const ExprHandle& m_stop, const CompareSelectOperation& cmp_op, const std::string& check_string) { StmtPtr s = For::make( m, m_start, m_stop, b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); StmtPtr root_s = For::make(n, n_start, n_stop, s); root_s = IRSimplifier::simplify(root_s); std::ostringstream oss; oss << *root_s; std::string target_string = "# CHECK: "; target_string += check_string; torch::jit::testing::FileCheck().run(target_string, oss.str()); }; // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n < 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n <= 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n <= 1 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n <= 0 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n < 0 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n < 8 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n <= 7 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n <= 8 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n < 7 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n < 7 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 0 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 1 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n >= 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 7 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n >= 7 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n >= 7 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 5 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 5 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n >= 5 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n >= 5 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n > 8 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n >= 8 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, 2)) { // b[n] = n == 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, 2)) { // b[1] = 0.f; // } test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n == 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n == 1 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n == 0 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n == 7 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n == 7 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n == 8 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 1 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 7 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 7 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 5 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 5 ? 0.f : 1.f; // } test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 0 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n != 8 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 0.f; // } test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;"); // Before: // for (const auto n : c10::irange(10, 20)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n != m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(10, 20)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 20, var_N + 30, var_N + 40, kNE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 20, var_M + 30, var_M + 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 20)) { // b[n, m] = (n != m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 20)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_N + 10, var_N + 20, kNE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_M + 10, var_M + 20, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 31)) { // b[n, m] = (n != m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 31)) { // b[n, m] = (n != m) ? 0.f : 1.f; // } // } test_case_nest_loops_fn( n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_N + 10, var_N + 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_M + 10, var_M + 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(10, 31)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n != m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(10, 31)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n != m) ? 0.f : 1.f; // } // } test_case_nest_loops_fn( n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 31, var_N + 30, var_N + 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 31, var_M + 30, var_M + 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(10, 20)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n < m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(10, 20)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 20, var_N + 30, var_N + 40, kLT, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 20, var_M + 30, var_M + 40, kLT, "b[n, m] = n m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 20)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_N + 10, var_N + 20, kGT, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_M + 10, var_M + 20, kGT, "b[n, m] = n>m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(10, 31)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n > m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(10, 31)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = 1.f; // } // } test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 31, var_N + 30, var_N + 40, kGT, "b[n, m] = 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 31, var_M + 30, var_M + 40, kGT, "b[n, m] = n>m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 31)) { // b[n, m] = (n >= m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 31)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_N + 10, var_N + 31, kGE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_M + 10, var_M + 31, kGE, "b[n, m] = n>=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(10, 20)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n >= m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(10, 20)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = 1.f; // } // } test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 20, var_N + 30, var_N + 40, kGE, "b[n, m] = 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 20, var_M + 30, var_M + 40, kGE, "b[n, m] = n>=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(10, 31)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = (n <= m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(10, 31)) { // for(const auto m : c10::irange(30, 40)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 31, var_N + 30, var_N + 40, kLE, "b[n, m] = 0.f;"); test_case_nest_loops_fn( n, m, b, var_N + 10, var_N + 31, var_M + 30, var_M + 40, kLE, "b[n, m] = n<=m ? 0.f : 1.f;"); // Before: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 20)) { // b[n, m] = (n <= m) ? 0.f : 1.f; // } // } // After: // for (const auto n : c10::irange(30, 40)) { // for(const auto m : c10::irange(10, 20)) { // b[n, m] = 0.f; // } // } test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_N + 10, var_N + 20, kLE, "b[n, m] = 1.f;"); test_case_nest_loops_fn( n, m, b, var_N + 30, var_N + 40, var_M + 10, var_M + 20, kLE, "b[n, m] = n<=m ? 0.f : 1.f;"); } TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = n < 1 ? 0.f : 1.f; // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } constexpr int N = 8; BufHandle b("b", {N}, kFloat); VarHandle n("n", kInt); StmtPtr s = For::make( n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; torch::jit::testing::FileCheck().run( R"IR( # CHECK: b[n] = 1.f; )IR", oss.str()); } TEST(Simplify, IfThenCondAlwaysInLoopBounds) { // Before: // for (const auto n : c10::irange(1, N)) { // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); // } // After: // for (const auto n : c10::irange(1, N)) { // b[n] = 1.f; // } constexpr int N = 8; BufHandle b("b", {N}, kFloat); VarHandle n("n", kInt); StmtPtr s = For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; torch::jit::testing::FileCheck().run( R"IR( # CHECK: b[n] = 1.f; )IR", oss.str()); } TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { // This test mimics the unpadded region of a conv2d. We want to remove any // conditional that is provably satisfied (or unsatisfied) by the entire loop // range. // Before: // for (const auto i : c10::irange(1, 7)) { // for (const auto j : c10::irange(1, 7)) { // b[i, j] = IfThenElse( // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f); // After: // for (const auto i : c10::irange(1, 7)) { // for (const auto j : c10::irange(1, 7)) { // b[i, j] = 1.f; constexpr int N = 8; BufHandle b("b", {N, N}, kFloat); VarHandle i("i", kInt); VarHandle j("j", kInt); auto csel = CompareSelect::make(i, 1, kLT); csel = CompareSelect::make(j, 1, 1, csel, kLT); csel = CompareSelect::make(i, N - 1, 1, csel, kGE); csel = CompareSelect::make(j, N - 1, 1, csel, kGE); StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); s = For::make(j, 1, N - 1, s); s = For::make(i, 1, N - 1, s); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; torch::jit::testing::FileCheck().run( R"IR( # CHECK: b[i, j] = 1.f; )IR", oss.str()); } TEST(Simplify, DISABLED_SimplifyLoopBounds) { // This test mimics the padded region of a conv2d. We want to adjust the // loop bounds such that the condition will be always met. Note that this // could be solved by peeling, and applying the range-based conditional // simplification in the previous tests. // Before: // for (const auto i : c10::irange(3)) { // for (const auto j : c10::irange(3)) { // b[i, j] = (b[i, j]) + (IfThenElse( // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); // After: // for (const auto i : c10::irange(1, 3)) { // for (const auto j : c10::irange(1, 3)) { // b[i, j] = (b[i, j]) + 1.f; constexpr int N = 8; constexpr int K = 3; BufHandle a("a", {N, N}, kFloat); BufHandle b("b", {N, N}, kFloat); VarHandle i("i", kInt); VarHandle j("j", kInt); auto csel = CompareSelect::make(i, 1, kLT); csel = CompareSelect::make(j, 1, 1, csel, kLT); csel = CompareSelect::make(i, N - 1, 1, csel, kGE); csel = CompareSelect::make(j, N - 1, 1, csel, kGE); StmtPtr s = b.store( {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); s = For::make(j, 0, K, s); s = For::make(i, 0, K, s); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; torch::jit::testing::FileCheck().run( R"IR( # CHECK: for (const auto i : c10::irange(1, 3)) { # CHECK: for (const auto j : c10::irange(1, 3)) { # CHECK-NOT: IfThenElse )IR", oss.str()); } } // namespace jit } // namespace torch