mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/6357 Pull Request resolved: https://github.com/pytorch/pytorch/pull/138364 Approved by: https://github.com/Skylion007, https://github.com/eqy
		
			
				
	
	
		
			5681 lines
		
	
	
		
			171 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			5681 lines
		
	
	
		
			171 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <gtest/gtest.h>
 | 
						|
#include <test/cpp/tensorexpr/test_base.h>
 | 
						|
 | 
						|
#include <c10/util/irange.h>
 | 
						|
#include <test/cpp/tensorexpr/test_utils.h>
 | 
						|
#include <torch/csrc/jit/tensorexpr/hash_provider.h>
 | 
						|
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
 | 
						|
#include <torch/csrc/jit/tensorexpr/loopnest.h>
 | 
						|
 | 
						|
#include <cmath>
 | 
						|
 | 
						|
namespace torch {
 | 
						|
namespace jit {
 | 
						|
using namespace torch::jit::tensorexpr;
 | 
						|
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
 | 
						|
 | 
						|
// Adding two float immediates must collapse into a single FloatImm.
TEST(Simplify, ConstantFoldSimple) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle sum = two + three;

  ExprHandle folded = IRSimplifier::simplify(sum);
  IS_IMM_WITH_VAL(Float, folded.node(), 5);

  // Evaluating the folded immediate gives the same number.
  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<float>(), 5.f);
}
 | 
						|
 | 
						|
// A two-level tree of float immediates, (2 + 3) - (4 + 5), folds to -4.
TEST(Simplify, ConstantFoldTwoLayer) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle five(5.0f);
  ExprHandle difference = (two + three) - (four + five);

  ExprHandle folded = IRSimplifier::simplify(difference);
  IS_IMM_WITH_VAL(Float, folded.node(), -4);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<float>(), -4.f);
}
 | 
						|
 | 
						|
// Chained shifts of int immediates, ((7 << 2) << 2) >> 3, fold to 14.
TEST(Simplify, ConstantFoldShifts) {
  ExprHandle seven(7);
  ExprHandle two(2);
  ExprHandle three(3);
  ExprHandle shifted = ((seven << two) << two) >> three;

  ExprHandle folded = IRSimplifier::simplify(shifted);
  IS_IMM_WITH_VAL(Int, folded.node(), 14);

  // Net effect is a left shift by (2 + 2 - 3) bits.
  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<int>(), 7 << (4 - 3));
}
 | 
						|
 | 
						|
// Bitwise xor/and over int immediates, (59 ^ 22) & 101, folds to 37.
TEST(Simplify, ConstantFoldBitwise) {
  ExprHandle lhs(59);
  ExprHandle rhs(22);
  ExprHandle mask(101);
  ExprHandle masked = (lhs ^ rhs) & mask;

  ExprHandle folded = IRSimplifier::simplify(masked);
  IS_IMM_WITH_VAL(Int, folded.node(), 37);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<int>(), (59 ^ 22) & 101);
}
 | 
						|
 | 
						|
// A mix of div/add/sub/mul over float immediates folds to one FloatImm whose
// value matches evaluating the unsimplified expression.
TEST(Simplify, ConstantFoldMultiOp) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle five(5.0f);
  ExprHandle six(6.0f);
  ExprHandle seven(7.0f);
  ExprHandle original = ((two / six) - (four + five)) * (seven / three);

  ExprHandle folded = IRSimplifier::simplify(original);
  ASSERT_NE(folded.AsNode<FloatImm>(), nullptr);

  SimpleIRExprEval foldedEval(folded);
  SimpleIRExprEval referenceEval(original);

  ASSERT_EQ(foldedEval.value<float>(), referenceEval.value<float>());
}
 | 
						|
 | 
						|
// max(12, min(15, 17)) over float immediates folds to the FloatImm 15.
TEST(Simplify, ConstantFoldMinMax) {
  ExprHandle twelve(12.0f);
  ExprHandle fifteen(15.0f);
  ExprHandle seventeen(17.0f);

  // Build the inner min first, then wrap it in the outer max.
  ExprHandle innerMin = Min::make(fifteen, seventeen, true);
  ExprHandle outerMax = Max::make(twelve, innerMin, false);

  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  ASSERT_EQ(outerMax.dtype().scalar_type(), ScalarType::Float);

  ExprHandle folded = IRSimplifier::simplify(outerMax);
  ASSERT_NE(folded.AsNode<FloatImm>(), nullptr);

  SimpleIRExprEval evaluator(folded);
  ASSERT_EQ(evaluator.value<float>(), 15.f);
}
 | 
						|
 | 
						|
// A chain of intrinsics over float immediates,
// abs(round(log10(fmod(4, sin(pow(2, 3)))))), folds to the FloatImm 1.
TEST(Simplify, ConstantFoldIntrinsics) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);

  // Each step wraps the previous one, innermost first.
  ExprHandle powExpr = Intrinsics::make(kPow, two, three);
  ExprHandle sinExpr = Intrinsics::make(kSin, powExpr);
  ExprHandle fmodExpr = Intrinsics::make(kFmod, four, sinExpr);
  ExprHandle log10Expr = Intrinsics::make(kLog10, fmodExpr);
  ExprHandle roundExpr = Intrinsics::make(kRound, log10Expr);
  ExprHandle original = Intrinsics::make(kAbs, roundExpr);

  ExprHandle folded = IRSimplifier::simplify(original);
  IS_IMM_WITH_VAL(Float, folded.node(), 1);

  SimpleIRExprEval foldedEval(folded);
  SimpleIRExprEval referenceEval(original);

  ASSERT_EQ(foldedEval.value<float>(), referenceEval.value<float>());
}
 | 
						|
 | 
						|
// Casting the int immediate 0 to bool still evaluates to false after
// simplification.
TEST(Simplify, ConstantFoldCastToBool) {
  ExprHandle castExpr = Cast::make(kBool, IntImm::make(0));
  ExprHandle simplified = IRSimplifier::simplify(castExpr);

  SimpleIRExprEval evaluator(simplified);
  ASSERT_EQ(evaluator.value<bool>(), false);
}
 | 
						|
 | 
						|
// The constant part of x * (2 + 4) is folded while the variable is kept, for
// both an int and a float variable.
TEST(Simplify, ConstantFoldWithVar) {
  {
    // Int case: the folded constant ends up on the left of the Mul.
    VarHandle x("x", kInt);
    ExprHandle product = x * (ExprHandle(2) + ExprHandle(4));

    ExprHandle simplified = IRSimplifier::simplify(product);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    ASSERT_NE(to<IntImm>(mul->lhs()), nullptr);

    SimpleIRExprEval evaluator(simplified);
    evaluator.bindVar(x, ExprHandle(3));
    ASSERT_EQ(evaluator.value<int>(), 3 * (2 + 4));
  }

  {
    // Float case: the folded constant is expected on the right of the Mul.
    VarHandle x("x", kFloat);
    ExprHandle product = x * (ExprHandle(2.f) + ExprHandle(4.f));

    ExprHandle simplified = IRSimplifier::simplify(product);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    ASSERT_NE(to<FloatImm>(mul->rhs()), nullptr);

    SimpleIRExprEval evaluator(simplified);
    evaluator.bindVar(x, ExprHandle(3.f));
    ASSERT_EQ(evaluator.value<float>(), 3 * (2 + 4));
  }
}
 | 
						|
 | 
						|
// Comparisons between float immediates fold to an IntImm holding 0 or 1.
TEST(Simplify, ConditionalSelectFoldSimple) {
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle alsoThree(3.0f);

  {
    // 3 > 4 is false.
    ExprHandle folded = IRSimplifier::simplify(three > four);
    IS_IMM_WITH_VAL(Int, folded.node(), 0);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 0);
  }
  {
    // 3 < 4 is true.
    ExprHandle folded = IRSimplifier::simplify(three < four);
    IS_IMM_WITH_VAL(Int, folded.node(), 1);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 1);
  }
  {
    // 3 == 3 is true.
    ExprHandle folded = IRSimplifier::simplify(three == alsoThree);
    IS_IMM_WITH_VAL(Int, folded.node(), 1);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 1);
  }
  {
    // 3 != 3 is false.
    ExprHandle folded = IRSimplifier::simplify(three != alsoThree);
    IS_IMM_WITH_VAL(Int, folded.node(), 0);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 0);
  }
}
 | 
						|
 | 
						|
// Comparisons whose operands are themselves constant sums fold to an IntImm
// holding 0 or 1.
TEST(Simplify, ConditionalSelectFoldTwoLayer) {
  ExprHandle three(3.0f);
  ExprHandle two(2.0f);
  ExprHandle alsoTwo(2.0f);
  ExprHandle one(1.0f);

  {
    // 3 + 2 < 2 + 1 is false.
    ExprHandle folded = IRSimplifier::simplify(three + two < alsoTwo + one);
    IS_IMM_WITH_VAL(Int, folded.node(), 0);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 0);
  }
  {
    // 3 + 2 > 2 + 1 is true.
    ExprHandle folded = IRSimplifier::simplify(three + two > alsoTwo + one);
    IS_IMM_WITH_VAL(Int, folded.node(), 1);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 1);
  }
  {
    // 3 + 1 == 2 + 2 is true.
    ExprHandle folded = IRSimplifier::simplify(three + one == two + alsoTwo);
    IS_IMM_WITH_VAL(Int, folded.node(), 1);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 1);
  }
  {
    // 3 + 1 != 2 + 2 is false.
    ExprHandle folded = IRSimplifier::simplify(three + one != two + alsoTwo);
    IS_IMM_WITH_VAL(Int, folded.node(), 0);

    SimpleIRExprEval evaluator(folded);
    ASSERT_EQ(evaluator.value<int>(), 0);
  }
}
 | 
						|
 | 
						|
// A comparison against a variable cannot be folded to an immediate, but it
// must still evaluate correctly once the variable is bound.
TEST(Simplify, ConditionalSelectFoldWithVar) {
  VarHandle x("x", kFloat);
  ExprHandle comparison = x < 4.f;

  ExprHandle simplified = IRSimplifier::simplify(comparison);
  // The result must not be an IntImm.
  ASSERT_EQ(simplified.AsNode<IntImm>(), nullptr);

  {
    // x = 3 satisfies x < 4.
    SimpleIRExprEval evaluator(simplified);
    evaluator.bindVar(x, ExprHandle(3.f));
    ASSERT_EQ(evaluator.value<int>(), 1);
  }
  {
    // x = 5 does not satisfy x < 4.
    SimpleIRExprEval evaluator(simplified);
    evaluator.bindVar(x, ExprHandle(5.f));
    ASSERT_EQ(evaluator.value<int>(), 0);
  }
}
 | 
						|
 | 
						|
// 3 * x + 5 * y has nothing to fold: the root stays an Add and neither operand
// becomes a float immediate.
TEST(Simplify, UnFoldableExpr) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  ExprHandle original = (ExprHandle(3) * x) + (ExprHandle(5) * y);

  ExprHandle simplified = IRSimplifier::simplify(original);
  IS_NODE_WITH_NAME(Add, simplified.node(), sum);
  ASSERT_EQ(to<FloatImm>(sum->lhs()), nullptr);
  ASSERT_EQ(to<FloatImm>(sum->rhs()), nullptr);

  // With x = 3 and y = 2 the expression evaluates to 9 + 10.
  SimpleIRExprEval evaluator(simplified);
  evaluator.bindVar(x, ExprHandle(3.f));
  evaluator.bindVar(y, ExprHandle(2.f));
  ASSERT_EQ(evaluator.value<float>(), 9 + 10);
}
 | 
						|
 | 
						|
// A variable, an immediate and a compound expression each hash to a non-zero
// value, and the three hashes are pairwise distinct.
TEST(Simplify, HashSimple) {
  VarHandle x("x", kFloat);
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle compound = two + three * x;

  HashProvider hasher;

  auto varHash = hasher.hash(x.node());
  auto immHash = hasher.hash(two.node());
  auto compoundHash = hasher.hash(compound.node());

  // None of the hashes is zero.
  ASSERT_NE(varHash, static_cast<size_t>(0));
  ASSERT_NE(immHash, static_cast<size_t>(0));
  ASSERT_NE(compoundHash, static_cast<size_t>(0));

  // All three differ from each other.
  ASSERT_NE(varHash, immHash);
  ASSERT_NE(varHash, compoundHash);
  ASSERT_NE(immHash, compoundHash);
}
 | 
						|
 | 
						|
// Structurally identical expressions hash equally; expressions over different
// variables do not, even when the variables share a name hint.
TEST(Simplify, HashEquivalence) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  ExprHandle doubledProduct = (x * y) + (x * y);

  IS_NODE_WITH_NAME(Add, doubledProduct.node(), sum);

  HashProvider hasher;
  auto rootHash = hasher.hash(doubledProduct.node());
  auto leftHash = hasher.hash(sum->lhs());
  auto rightHash = hasher.hash(sum->rhs());

  // The root differs from each operand...
  ASSERT_NE(rootHash, leftHash);
  ASSERT_NE(rootHash, rightHash);
  // ...while the two operands hash the same.
  ASSERT_EQ(leftHash, rightHash);

  // Separately built but identical expressions also hash the same.
  ExprHandle firstTwo(2);
  ExprHandle viaFirstTwo = x + firstTwo / y;
  ExprHandle secondTwo(2);
  ExprHandle viaSecondTwo = x + secondTwo / y;
  ASSERT_EQ(hasher.hash(viaFirstTwo.node()), hasher.hash(viaSecondTwo.node()));

  // A distinct variable changes the hash, even though its name hint is "x".
  VarHandle otherX("x", kFloat);
  ExprHandle viaOtherX = otherX + secondTwo / y;
  ASSERT_NE(hasher.hash(viaFirstTwo.node()), hasher.hash(viaOtherX.node()));

  // Sanity check: intrinsics hash to something non-zero.
  ExprHandle trigProduct =
      Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x);
  ASSERT_NE(hasher.hash(trigProduct.node()), static_cast<size_t>(0));
}
 | 
						|
 | 
						|
// Two kRand intrinsics (one float, one int) added together: the root and both
// operands must all hash differently.
TEST(Simplify, HashEquivalenceRand) {
  ExprHandle randSum =
      Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt);

  IS_NODE_WITH_NAME(Add, randSum.node(), sum);

  HashProvider hasher;
  auto rootHash = hasher.hash(randSum.node());
  auto leftHash = hasher.hash(sum->lhs());
  auto rightHash = hasher.hash(sum->rhs());

  // The root differs from each operand...
  ASSERT_NE(rootHash, leftHash);
  ASSERT_NE(rootHash, rightHash);
  // ...and here the operands differ from each other as well.
  ASSERT_NE(leftHash, rightHash);
}
 | 
						|
 | 
						|
// (2 + 3) * x and 5 * x hash differently as written, but hash the same once
// both have been simplified.
TEST(Simplify, HashEquivalenceAfterFolding) {
  VarHandle x("x", kFloat);
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle five(5.0f);

  ExprHandle unfolded = (two + three) * x;
  ExprHandle alreadyFolded = five * x;

  HashProvider hasher;
  auto unfoldedHash = hasher.hash(unfolded.node());
  auto alreadyFoldedHash = hasher.hash(alreadyFolded.node());

  // Before simplification the two expressions are not equivalent.
  ASSERT_NE(unfoldedHash, alreadyFoldedHash);

  ExprHandle unfoldedSimplified = IRSimplifier::simplify(unfolded);
  ExprHandle alreadyFoldedSimplified = IRSimplifier::simplify(alreadyFolded);

  auto unfoldedSimplifiedHash = hasher.hash(unfoldedSimplified.node());
  auto alreadyFoldedSimplifiedHash =
      hasher.hash(alreadyFoldedSimplified.node());
  // After simplification they are.
  ASSERT_EQ(unfoldedSimplifiedHash, alreadyFoldedSimplifiedHash);
}
 | 
						|
 | 
						|
// Immediates holding the same value but of different scalar types hash
// differently; immediates that simplify to the same type and value hash
// equally.
TEST(Simplify, HashDifferenceTypes) {
  HashProvider hasher;

  // One immediate with value 1 per scalar type under test.
  std::vector<ExprPtr> immediates = {
      alloc<DoubleImm>(1),
      alloc<FloatImm>(1),
      alloc<HalfImm>(1),
      // NOLINTNEXTLINE(modernize-use-bool-literals)
      alloc<BoolImm>(1),
      alloc<CharImm>(1),
      alloc<ByteImm>(1),
      alloc<ShortImm>(1),
      alloc<IntImm>(1),
      alloc<LongImm>(1),
  };

  // Every unordered pair of distinct types must hash differently.
  const size_t count = immediates.size();
  for (size_t first = 0; first < count; ++first) {
    for (size_t second = first + 1; second < count; ++second) {
      ASSERT_NE(
          hasher.hash(immediates[first]), hasher.hash(immediates[second]));
    }
  }

  // 2.f + char(1) and float(int 3) should hash the same once simplified.
  ExprHandle mixedSum = ExprHandle(2.f) + CharImm::make(1);
  ExprHandle castInt = Cast::make(kFloat, IntImm::make(3));

  ExprHandle mixedSumSimplified = IRSimplifier::simplify(mixedSum);
  ExprHandle castIntSimplified = IRSimplifier::simplify(castInt);

  ASSERT_EQ(
      hasher.hash(mixedSumSimplified.node()),
      hasher.hash(castIntSimplified.node()));
}
 | 
						|
 | 
						|
// Hashing a Cond statement also caches the hashes of its two branch
// statements, and the root and both branches hash to distinct values.
TEST(Simplify, HashLargeExpression) {
  constexpr int N = 1024;
  BufHandle bufA("A", {N}, kInt);
  BufHandle bufB("B", {N}, kInt);
  BufHandle bufC("C", {N}, kInt);
  VarHandle idx("i", kInt);

  // True branch: for i in [0, N): C[i] = (A[i] == B[i]).
  auto equalSelect = CompareSelect::make(
      Load::make(bufA, {idx}),
      Load::make(bufB, {idx}),
      CompareSelectOperation::kEQ);
  auto memcpy_stmt =
      For::make(idx, 0, N, Store::make(bufC, {idx}, equalSelect));

  // False branch: a ramp-indexed store from D into E.
  BufHandle bufD("D", {1}, kInt);
  BufHandle bufE("E", {1}, kInt);
  auto store_ramp_stmt = Store::make(
      bufE,
      {Ramp::make(0, 1, 4)},
      Load::make(bufD, {Ramp::make(0, 1, 4)}));

  // Condition: A[i] >= B[i].
  auto condition = CompareSelect::make(
      Load::make(bufA, {idx}),
      Load::make(bufB, {idx}),
      CompareSelectOperation::kGE);
  auto if_stmt = Cond::make(condition, memcpy_stmt, store_ramp_stmt);

  HashProvider hasher;
  auto rootHash = hasher.hash(if_stmt);

  // Both branches should already be cached after hashing the root.
  ASSERT_TRUE(hasher.cachedHash(memcpy_stmt));
  auto trueBranchHash = hasher.hash(memcpy_stmt);
  ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt));
  auto falseBranchHash = hasher.hash(store_ramp_stmt);

  // All three hashes are pairwise distinct.
  ASSERT_NE(rootHash, trueBranchHash);
  ASSERT_NE(rootHash, falseBranchHash);
  ASSERT_NE(trueBranchHash, falseBranchHash);
}
 | 
						|
 | 
						|
// The hash of a For statement changes with its GPU block/thread index options
// and returns to the original value when the option is unset again.
TEST(Simplify, HashForLoopOptions) {
  constexpr int N = 1024;
  BufHandle bufA("A", {N}, kInt);
  BufHandle bufB("B", {N}, kInt);
  BufHandle bufC("C", {N}, kInt);
  VarHandle idx("i", kInt);

  // for i in [0, N): C[i] = (A[i] == B[i]).
  auto equalSelect = CompareSelect::make(
      Load::make(bufA, {idx}),
      Load::make(bufB, {idx}),
      CompareSelectOperation::kEQ);
  auto loop = For::make(idx, 0, N, Store::make(bufC, {idx}, equalSelect));

  HashProvider hasher;
  auto plainHash = hasher.hash(loop);
  hasher.clearCache();

  // Setting a block index changes the hash.
  loop->set_gpu_block_index(LoopOptions::IDX_X);
  auto blockIdxHash = hasher.hash(loop);
  hasher.clearCache();

  ASSERT_NE(plainHash, blockIdxHash);

  // Unsetting it restores the original hash.
  loop->set_gpu_block_index(LoopOptions::IDX_UNSET);
  auto resetHash = hasher.hash(loop);
  hasher.clearCache();

  ASSERT_EQ(plainHash, resetHash);

  // A thread index differs from both the plain and the block-index hash.
  loop->set_gpu_thread_index(LoopOptions::IDX_X);
  auto threadIdxHash = hasher.hash(loop);

  ASSERT_NE(plainHash, threadIdxHash);
  ASSERT_NE(blockIdxHash, threadIdxHash);
}
 | 
						|
 | 
						|
/// (2 + x) + 4 => x + 6
 | 
						|
TEST(Simplify, SimplifyAdd) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  VarHandle y("y", kInt);
 | 
						|
 | 
						|
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
 | 
						|
  VarHandle m("m", kInt);
 | 
						|
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
 | 
						|
  VarHandle n("n", kInt);
 | 
						|
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
 | 
						|
  VarHandle n_1("n_1", kInt);
 | 
						|
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
 | 
						|
  ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  AddPtr root = simplified.AsNode<Add>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
  VarPtr lhs = to<Var>(root->lhs());
 | 
						|
  ASSERT_NE(lhs, nullptr);
 | 
						|
  ASSERT_EQ(lhs->name_hint(), "x");
 | 
						|
  IntImmPtr rhs = to<IntImm>(root->rhs());
 | 
						|
  ASSERT_NE(rhs, nullptr);
 | 
						|
  ASSERT_EQ(rhs->value(), 6.f);
 | 
						|
}
 | 
						|
 | 
						|
/// (2 - x) - 4 => -2 - x
 | 
						|
TEST(Simplify, SimplifySub) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  SubPtr root = simplified.AsNode<Sub>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
  IntImmPtr lhs = to<IntImm>(root->lhs());
 | 
						|
  ASSERT_NE(lhs, nullptr);
 | 
						|
  ASSERT_EQ(lhs->value(), -2.f);
 | 
						|
  VarPtr rhs = to<Var>(root->rhs());
 | 
						|
  ASSERT_NE(rhs, nullptr);
 | 
						|
  ASSERT_EQ(rhs->name_hint(), "x");
 | 
						|
}
 | 
						|
 | 
						|
/// 2 * (1 - x) - 4 => 2 * (-3 - x)
 | 
						|
TEST(Simplify, SimplifyMultiLayer) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4));
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
  IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
 | 
						|
  IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
 | 
						|
  IS_IMM_WITH_VAL(Int, sub->lhs(), -3);
 | 
						|
  IS_VAR_WITH_NAME(sub->rhs(), "x");
 | 
						|
}
 | 
						|
 | 
						|
/// 2 * (3 * x) - (x * 4) => 2 * x
 | 
						|
TEST(Simplify, SimplifyMultiTerm) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  ExprHandle body =
 | 
						|
      (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  MulPtr root = simplified.AsNode<Mul>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
  IntImmPtr lhs = to<IntImm>(root->lhs());
 | 
						|
  ASSERT_NE(lhs, nullptr);
 | 
						|
  ASSERT_EQ(lhs->value(), 2);
 | 
						|
  VarPtr rhs = to<Var>(root->rhs());
 | 
						|
  ASSERT_NE(rhs, nullptr);
 | 
						|
  ASSERT_EQ(rhs->name_hint(), "x");
 | 
						|
}
 | 
						|
 | 
						|
/// 2 * (3 * (long)x) - (x * 4) => 2 * x
 | 
						|
TEST(Simplify, SimplifyCasts) {
 | 
						|
  VarHandle x("x", kLong);
 | 
						|
  ExprHandle body =
 | 
						|
      (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  MulPtr root = simplified.AsNode<Mul>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
  LongImmPtr lhs = to<LongImm>(root->lhs());
 | 
						|
  ASSERT_NE(lhs, nullptr);
 | 
						|
  ASSERT_EQ(lhs->value(), 2);
 | 
						|
  VarPtr rhs = to<Var>(root->rhs());
 | 
						|
  ASSERT_NE(rhs, nullptr);
 | 
						|
  ASSERT_EQ(rhs->name_hint(), "x");
 | 
						|
}
 | 
						|
 | 
						|
/// (x + 0) * 1 => x
 | 
						|
TEST(Simplify, SimplifyEliminatesNoOps) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  ExprHandle body = (x + ExprHandle(0)) * 1;
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  VarPtr root = simplified.AsNode<Var>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
  ASSERT_EQ(root->name_hint(), "x");
 | 
						|
}
 | 
						|
 | 
						|
/// Cannot simplify this.
 | 
						|
TEST(Simplify, SimplifyMultiVar) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  VarHandle y("y", kInt);
 | 
						|
  ExprHandle body = x * 24 + y * 34;
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
  AddPtr root = simplified.AsNode<Add>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
  MulPtr lhs = to<Mul>(root->lhs());
 | 
						|
  ASSERT_NE(lhs, nullptr);
 | 
						|
  VarPtr varX = to<Var>(lhs->rhs());
 | 
						|
  ASSERT_NE(varX, nullptr);
 | 
						|
  ASSERT_EQ(varX->name_hint(), "x");
 | 
						|
  MulPtr rhs = to<Mul>(root->rhs());
 | 
						|
  ASSERT_NE(rhs, nullptr);
 | 
						|
  VarPtr varY = to<Var>(rhs->rhs());
 | 
						|
  ASSERT_NE(varY, nullptr);
 | 
						|
  ASSERT_EQ(varY->name_hint(), "y");
 | 
						|
}
 | 
						|
 | 
						|
// x + 2 + y => x + y + 2
 | 
						|
TEST(Simplify, DISABLED_SimplifyReorderings) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  VarHandle y("y", kInt);
 | 
						|
  ExprHandle body = x + 2 + y;
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
  AddPtr root = simplified.AsNode<Add>();
 | 
						|
  ASSERT_NE(root, nullptr);
 | 
						|
 | 
						|
  IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
 | 
						|
  IS_VAR_WITH_NAME(rhs->lhs(), "x");
 | 
						|
  IS_VAR_WITH_NAME(rhs->rhs(), "y");
 | 
						|
  IS_IMM_WITH_VAL(Int, root->rhs(), 2);
 | 
						|
}
 | 
						|
 | 
						|
/// y + x * 0 => y
 | 
						|
TEST(Simplify, SimplifyEliminatesVar) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  VarHandle y("y", kInt);
 | 
						|
  ExprHandle body = y + x * ExprHandle(0);
 | 
						|
 | 
						|
  ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
  IS_VAR_WITH_NAME(simplified.node(), "y");
 | 
						|
}
 | 
						|
 | 
						|
// Simplification of sums: repeated terms become a multiplication by a
// constant, and adding zero is dropped (keeping the float type when the zero
// is a float).
TEST(Simplify, SimplifyAdds) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  {
    // (x + y) + (x + y) => 2 * (x + y)
    ExprHandle simplified = IRSimplifier::simplify((x + y) + (x + y));

    IS_NODE_WITH_NAME(Mul, simplified.node(), doubled);
    IS_IMM_WITH_VAL(Int, doubled->lhs(), 2);
    IS_NODE_WITH_NAME(Add, doubled->rhs(), term);
    IS_VAR_WITH_NAME(term->lhs(), "x");
    IS_VAR_WITH_NAME(term->rhs(), "y");
  }

  {
    // (x * y) + (x * y) => 2 * (x * y)
    ExprHandle simplified = IRSimplifier::simplify((x * y) + (x * y));

    IS_NODE_WITH_NAME(Mul, simplified.node(), doubled);
    IS_IMM_WITH_VAL(Int, doubled->lhs(), 2);
    IS_NODE_WITH_NAME(Mul, doubled->rhs(), term);
    IS_VAR_WITH_NAME(term->lhs(), "x");
    IS_VAR_WITH_NAME(term->rhs(), "y");
  }

  {
    // (x - y) + (x - y) => 2 * (x - y)
    ExprHandle simplified = IRSimplifier::simplify((x - y) + (x - y));

    IS_NODE_WITH_NAME(Mul, simplified.node(), doubled);
    IS_IMM_WITH_VAL(Int, doubled->lhs(), 2);

    IS_NODE_WITH_NAME(Sub, doubled->rhs(), term);
    IS_VAR_WITH_NAME(term->lhs(), "x");
    IS_VAR_WITH_NAME(term->rhs(), "y");
  }

  {
    // (x + x + x + x) => 4 * x
    ExprHandle simplified = IRSimplifier::simplify((x + x + x + x));

    IS_NODE_WITH_NAME(Mul, simplified.node(), quadrupled);
    IS_IMM_WITH_VAL(Int, quadrupled->lhs(), 4);
    IS_VAR_WITH_NAME(quadrupled->rhs(), "x");
  }

  {
    // (x + 0) => x.
    ExprHandle simplified = IRSimplifier::simplify(x + 0);

    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // (x + 0.f) => float(x).
    ExprHandle simplified = IRSimplifier::simplify(x + 0.f);

    // The zero is dropped but the promotion to float is kept as a Cast.
    IS_NODE_WITH_NAME(Cast, simplified.node(), asFloat);
    ASSERT_EQ(asFloat->dtype().scalar_type(), ScalarType::Float);
    IS_VAR_WITH_NAME(asFloat->src_value(), "x");
  }
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyMuls) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  VarHandle y("y", kInt);
 | 
						|
 | 
						|
  {
 | 
						|
    // (x + y) * (x + y) => (x + y) * (x + y)
 | 
						|
    // We don't attempt to simplify multiplication of polynomials since the
 | 
						|
    // result is only very rarely more efficient.
 | 
						|
    ExprHandle body = (x + y) * (x + y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
 | 
						|
    IS_VAR_WITH_NAME(lhs->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(lhs->rhs(), "y");
 | 
						|
    IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
 | 
						|
    IS_VAR_WITH_NAME(rhs->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(rhs->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // x * y * x * y => x * x * y * y
 | 
						|
    // These get reordered only.
 | 
						|
    ExprHandle body = x * y * x * y;
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul1);
 | 
						|
    IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2);
 | 
						|
    IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3);
 | 
						|
    IS_VAR_WITH_NAME(mul1->rhs(), "y");
 | 
						|
    IS_VAR_WITH_NAME(mul2->rhs(), "y");
 | 
						|
    IS_VAR_WITH_NAME(mul3->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul3->rhs(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // 1 * (x * 1) => x
 | 
						|
    // Ones cancel cleanly.
 | 
						|
    ExprHandle body = ExprHandle(1) * (x * ExprHandle(1));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_VAR_WITH_NAME(simplified.node(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // 1.f * (x * 1.f) => x
 | 
						|
    // Even float ones cancel cleanly, but carry their type.
 | 
						|
    ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
 | 
						|
    ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
 | 
						|
    IS_VAR_WITH_NAME(cast->src_value(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // 1 * (x * 1.f) => x
 | 
						|
    // One float is enough to cast the expr.
 | 
						|
    ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
 | 
						|
    ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
 | 
						|
    IS_VAR_WITH_NAME(cast->src_value(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // 1 * (x * 0) => 0
 | 
						|
    // Zeroes are eliminated.
 | 
						|
    ExprHandle body = ExprHandle(1) * (x * ExprHandle(0));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 0);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // 1 * (x * 0) => 0
 | 
						|
    // But not for Float since nan * 0 = nan.
 | 
						|
    ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_NODE_WITH_NAME(Cast, mul->lhs(), cast);
 | 
						|
    ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
 | 
						|
    IS_VAR_WITH_NAME(cast->src_value(), "x");
 | 
						|
    IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x - y) * (x - y) => (x - y) * (x - y)
 | 
						|
    // As with Add we don't attempt simplification of this.
 | 
						|
    ExprHandle body = (x - y) * (x - y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs);
 | 
						|
    IS_VAR_WITH_NAME(lhs->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(lhs->rhs(), "y");
 | 
						|
    IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
 | 
						|
    IS_VAR_WITH_NAME(rhs->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(rhs->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x + y) * (x - y) => (x + y) * (x - y)
 | 
						|
    // Don't simplify with different ops on each side.
 | 
						|
    ExprHandle body = (x + y) * (x - y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
 | 
						|
    IS_VAR_WITH_NAME(lhs->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(lhs->rhs(), "y");
 | 
						|
    IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
 | 
						|
    IS_VAR_WITH_NAME(rhs->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(rhs->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Multiply a polynomial by a term.
 | 
						|
    //   - term with no scalar, poly with non-identity scalar.
 | 
						|
    // x * (y + 1) => x + x * y
 | 
						|
    ExprHandle body = x * (y + ExprHandle(1));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
 | 
						|
    IS_VAR_WITH_NAME(add->lhs(), "x");
 | 
						|
    IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
 | 
						|
    IS_VAR_WITH_NAME(mul->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Multiply a polynomial by a term.
 | 
						|
    //   - term with identity scalar, poly with non-identity scalar.
 | 
						|
    // (x * 1) * (y + 1) => x + x * y
 | 
						|
    ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
 | 
						|
    IS_VAR_WITH_NAME(add->lhs(), "x");
 | 
						|
    IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
 | 
						|
    IS_VAR_WITH_NAME(mul->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Multiply a polynomial by a term.
 | 
						|
    //   - term with non-identity scalar, poly with non-identity scalar.
 | 
						|
    // (x * 2) * (y + 1) => 2 * (x + x * y)
 | 
						|
    ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Add, mul->rhs(), add);
 | 
						|
    IS_VAR_WITH_NAME(add->lhs(), "x");
 | 
						|
    IS_NODE_WITH_NAME(Mul, add->rhs(), mul2);
 | 
						|
    IS_VAR_WITH_NAME(mul2->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul2->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Multiply a polynomial by a term.
 | 
						|
    //   - term with non-identity scalar, poly with identity scalar.
 | 
						|
    // (x * 2) * (y + 0) => 2 * (x * y)
 | 
						|
    ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2);
 | 
						|
    IS_VAR_WITH_NAME(mul2->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul2->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Multiply a polynomial by a term.
 | 
						|
    //   - term with identity scalar, poly with identity scalar.
 | 
						|
    // (x * 1) * (y + 0) => x * y
 | 
						|
    ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_VAR_WITH_NAME(mul->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Multiply a polynomial by a term.
 | 
						|
    //   - term with no scalar, poly with identity scalar.
 | 
						|
    // x * (y + 0) => x * y
 | 
						|
    ExprHandle body = x * (y + ExprHandle(0));
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
    IS_VAR_WITH_NAME(mul->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "y");
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
// Sub an expr from itself will result in zero.
 | 
						|
TEST(Simplify, SimplifySubs) {
 | 
						|
  VarHandle x("x", kInt);
 | 
						|
  VarHandle y("y", kInt);
 | 
						|
 | 
						|
  {
 | 
						|
    // (x + y) - (x + y) => 0
 | 
						|
    ExprHandle body = (x + y) - (x + y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 0);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x * y) - (x * y) => 0
 | 
						|
    ExprHandle body = (x * y) - (x * y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 0);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x - y) - (x - y) => 0
 | 
						|
    ExprHandle body = (x - y) - (x - y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 0);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x + y) - 2 * (x + y) => -1 * x - y
 | 
						|
    ExprHandle body = (x + y) - ExprHandle(2) * (x + y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
 | 
						|
    IS_NODE_WITH_NAME(Mul, sub->lhs(), mul);
 | 
						|
    IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(sub->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x + y) - y => x
 | 
						|
    ExprHandle body = (x + y) - y;
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_VAR_WITH_NAME(simplified.node(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x - 0) => x.
 | 
						|
    ExprHandle body = x - 0;
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_VAR_WITH_NAME(simplified.node(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x - 0.f) => x.
 | 
						|
    // Simple enough to cancel in float.
 | 
						|
    ExprHandle body = x - ExprHandle(0.f);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
 | 
						|
    ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
 | 
						|
    IS_VAR_WITH_NAME(cast->src_value(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x - (float)(y - y)) => x.
 | 
						|
    ExprHandle body = x - Cast::make(kFloat, y - y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
 | 
						|
    ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
 | 
						|
    IS_VAR_WITH_NAME(cast->src_value(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x - y) - y => x - 2 * y
 | 
						|
    ExprHandle body = (x - y) - y;
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
 | 
						|
    IS_VAR_WITH_NAME(sub->lhs(), "x");
 | 
						|
    IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
 | 
						|
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // 2 * x - x => x
 | 
						|
    ExprHandle body = (ExprHandle(2) * x) - x;
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_VAR_WITH_NAME(simplified.node(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // x - 2 * x = -1 * x
 | 
						|
    // We don't have a unary negate, but this could be 0 -x I guess?
 | 
						|
    ExprHandle body = x - (ExprHandle(2) * x);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
 | 
						|
    IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "x");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // (x + y + 5) * (x - x) => 0
 | 
						|
    // Cancelling out one side of Mul cancels both.
 | 
						|
    ExprHandle body = (x + y + 5) * (x - x);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 0);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Cancel out opaque modulus.
 | 
						|
    ExprHandle body = (x % y + 2) - (x % y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 2);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Cancel out opaque modulus with a bit more going on.
 | 
						|
    ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 2);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Sub where result is negative.
 | 
						|
    ExprHandle body = x - (x + 1);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), -1);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Sub where result is positive due to negative scalar on RHS.
 | 
						|
    ExprHandle body = x - (x - 1);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), 1);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Term - Polynomial sub where RHS must be negated.
 | 
						|
    ExprHandle body = (x * 2) - (x * 2 + 1);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_IMM_WITH_VAL(Int, simplified.node(), -1);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Term - Polynomial sub where the result is a Term.
 | 
						|
    ExprHandle body = (y * x * 2) - (x * y);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
 | 
						|
 | 
						|
    IS_VAR_WITH_NAME(mul->lhs(), "x");
 | 
						|
    IS_VAR_WITH_NAME(mul->rhs(), "y");
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Term - Polynomial sub where the result is a Polynomial.
 | 
						|
    ExprHandle body = (x * 2) - (x + 1);
 | 
						|
    ExprHandle simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
 | 
						|
 | 
						|
    IS_VAR_WITH_NAME(sub->lhs(), "x");
 | 
						|
    IS_IMM_WITH_VAL(Int, sub->rhs(), 1);
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDiv) {
  VarHandle x("x", kInt);

  {
    // 0 / x => 0
    ExprHandle expr = ExprHandle(0) / x;
    auto result = IRSimplifier::simplify(expr);

    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // x / 1 => x
    ExprHandle expr = x / 1;
    auto result = IRSimplifier::simplify(expr);

    IS_VAR_WITH_NAME(result.node(), "x");
  }
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext0) {
  // Loop under test:
  //   for (int i = 0; i < 100; i++) {
  //     A[i] = i / 100;
  //   }
  // The stored value is expected to fold to the constant 0.
  VarHandle i("i", kInt);
  BufHandle out("A", {100}, kInt);
  auto store = Store::make(out, {i}, (i / 100));
  auto loop = For::make(i, 0, 100, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = 0;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext1) {
  // Loop under test:
  //   for (int i = 0; i < 6; i++) {
  //     A[i] = (i + 24) / 6;
  //   }
  // The stored value is expected to fold to the constant 4.
  VarHandle i("i", kInt);
  BufHandle out("A", {6}, kInt);
  auto store = Store::make(out, {i}, (i + 24) / 6);
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = 4;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext2) {
  // Loop under test:
  //   for (int i = 0; i < 5; i++) {
  //     A[i] = (i + 25) / 6;
  //   }
  // The stored value is expected to fold to the constant 4.
  VarHandle i("i", kInt);
  BufHandle out("A", {5}, kInt);
  auto store = Store::make(out, {i}, (i + 25) / 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = 4;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext3) {
  // Loop under test:
  //   for (int i = 0; i < 6; i++) {
  //     A[i] = (i + 24) / (-6);
  //   }
  // With a negative divisor the store must not be folded to -4.
  VarHandle i("i", kInt);
  BufHandle out("A", {6}, kInt);
  auto store = Store::make(out, {i}, (i + 24) / (-6));
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = -4;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext4) {
  // Loop under test:
  //   for (int i = 0; i < 5; i++) {
  //     A[i] = (i - 5) / 6;
  //   }
  // The dividend is built as i + (-5); the store must not be folded to 0.
  VarHandle i("i", kInt);
  BufHandle out("A", {5}, kInt);
  auto store = Store::make(out, {i}, (i + (-5)) / 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = 0;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext5) {
  // Loop nest under test:
  //   for (int i = 0; i < 6; i++) {
  //     for (int j = 0; j < 10; j++) {
  //       A[i, j] = (i + 6*j) / 6;
  //     }
  //   }
  // The stored value is expected to reduce to j.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle out("A", {6, 10}, kInt);
  auto store = Store::make(out, {i, j}, (i + j * 6) / 6);
  auto inner = For::make(j, 0, 10, store);
  auto outer = For::make(i, 0, 6, inner);

  const StmtPtr result = IRSimplifier::simplify(outer);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NEXT:   A[i, j] = j;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext6) {
  // Stmt to simplify:
  // for (int i = 0; i < 6; i++) {
  //  for (int j = -1; j < 9; j++) {
  //    A[i, j+1] = (i + 6*j) / 6;
  //  }
  //}
  // The inner loop starts at j = -1, and the test asserts that the division
  // is NOT rewritten to plain `j`.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto for_j =
      For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6));
  auto for_i = For::make(i, 0, 6, for_j);

  const StmtPtr simplified = IRSimplifier::simplify(for_i);

  std::ostringstream oss;
  oss << *(simplified);
  // The store targets A[i, j + 1], so the forbidden pattern has to use that
  // index as well: a pattern spelled `A[i, j] = j;` could never match this
  // store and would make the negative check pass vacuously.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j + 1] = j;
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyDivWithLoopContext7) {
  // Loop nest under test:
  //   for (int i = 0; i < 6; i++) {
  //     for (int j = 0; j < 10; j++) {
  //       A[i, j] = (i + 6*j) / (-6);
  //     }
  //   }
  // With a negative divisor the store must not be folded to -j.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle out("A", {6, 10}, kInt);
  auto store = Store::make(out, {i, j}, (i + j * 6) / (-6));
  auto inner = For::make(j, 0, 10, store);
  auto outer = For::make(i, 0, 6, inner);

  const StmtPtr result = IRSimplifier::simplify(outer);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j] = -j;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext0) {
  // Loop under test:
  //   for (int i = 0; i < 100; i++) {
  //     A[i] = i % 100;
  //   }
  // The stored value is expected to reduce to i.
  VarHandle i("i", kInt);
  BufHandle out("A", {100}, kInt);
  auto store = Store::make(out, {i}, (i % 100));
  auto loop = For::make(i, 0, 100, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext1) {
  // Loop under test:
  //   for (int i = 0; i < 6; i++) {
  //     A[i] = (i + 24) % 6;
  //   }
  // The stored value is expected to reduce to i.
  VarHandle i("i", kInt);
  BufHandle out("A", {6}, kInt);
  auto store = Store::make(out, {i}, (i + 24) % 6);
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext2) {
  // Loop under test:
  //   for (int i = 0; i < 5; i++) {
  //     A[i] = (i + 25) % 6;
  //   }
  // The stored value is expected to reduce to i + 1.
  VarHandle i("i", kInt);
  BufHandle out("A", {5}, kInt);
  auto store = Store::make(out, {i}, (i + 25) % 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NEXT:   A[i] = i + 1;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext3) {
  // Loop under test:
  //   for (int i = 0; i < 6; i++) {
  //     A[i] = (i + 24) % (-6);
  //   }
  // With a negative modulus the store must not be folded to i.
  VarHandle i("i", kInt);
  BufHandle out("A", {6}, kInt);
  auto store = Store::make(out, {i}, (i + 24) % (-6));
  auto loop = For::make(i, 0, 6, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext4) {
  // Loop under test:
  //   for (int i = 0; i < 5; i++) {
  //     A[i] = (i - 5) % 6;
  //   }
  // The dividend is built as i + (-5); the store must not be folded to i - 5.
  VarHandle i("i", kInt);
  BufHandle out("A", {5}, kInt);
  auto store = Store::make(out, {i}, (i + (-5)) % 6);
  auto loop = For::make(i, 0, 5, store);

  const StmtPtr result = IRSimplifier::simplify(loop);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK-NOT:   A[i] = i - 5;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext5) {
  // Loop nest under test:
  //   for (int i = 0; i < 6; i++) {
  //     for (int j = 0; j < 10; j++) {
  //       A[i, j] = (i + 6*j) % 6;
  //     }
  //   }
  // The stored value is expected to reduce to i.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle out("A", {6, 10}, kInt);
  auto store = Store::make(out, {i, j}, (i + j * 6) % 6);
  auto inner = For::make(j, 0, 10, store);
  auto outer = For::make(i, 0, 6, inner);

  const StmtPtr result = IRSimplifier::simplify(outer);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NEXT:   A[i, j] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext6) {
  // Stmt to simplify:
  // for (int i = 0; i < 6; i++) {
  //  for (int j = -1; j < 9; j++) {
  //    A[i, j+1] = (i + 6*j) % 6;
  //  }
  //}
  // The inner loop starts at j = -1, and the test asserts that the modulus
  // is NOT rewritten to plain `i`.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle a_buf("A", {6, 10}, kInt);
  auto for_j =
      For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6));
  auto for_i = For::make(i, 0, 6, for_j);

  const StmtPtr simplified = IRSimplifier::simplify(for_i);

  std::ostringstream oss;
  oss << *(simplified);
  // The store targets A[i, j + 1], so the forbidden pattern has to use that
  // index as well: a pattern spelled `A[i, j] = i;` could never match this
  // store and would make the negative check pass vacuously.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j + 1] = i;
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyModWithLoopContext7) {
  // Loop nest under test:
  //   for (int i = 0; i < 6; i++) {
  //     for (int j = 0; j < 10; j++) {
  //       A[i, j] = (i + 6*j) % (-6);
  //     }
  //   }
  // With a negative modulus the store must not be folded to i.
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle out("A", {6, 10}, kInt);
  auto store = Store::make(out, {i, j}, (i + j * 6) % (-6));
  auto inner = For::make(j, 0, 10, store);
  auto outer = For::make(i, 0, 6, inner);

  const StmtPtr result = IRSimplifier::simplify(outer);

  std::ostringstream printed;
  printed << *result;
  const std::string expected = R"IR(
# CHECK: for (int i
# CHECK:   for (int j
# CHECK-NOT:   A[i, j] = i;
      )IR";
  torch::jit::testing::FileCheck().run(expected, printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyMod) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  {
    // Two constants fold: 10 % 8 => 2.
    ExprHandle expr = ExprHandle(10) % 8;
    auto result = IRSimplifier::simplify(expr);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    IS_IMM_WITH_VAL(Int, result.node(), 2);
  }

  {
    // A value modulo itself: x % x => 0.
    ExprHandle expr = x % x;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // Zero modulo anything: 0 % x => 0.
    ExprHandle expr = ExprHandle(0) % x;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // Anything modulo one: x % 1 => 0.
    ExprHandle expr = x % 1;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // Two unrelated variables are left alone: x % y => x % y.
    ExprHandle expr = x % y;
    auto result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, result.node(), remainder);
    IS_VAR_WITH_NAME(remainder->lhs(), "x");
    IS_VAR_WITH_NAME(remainder->rhs(), "y");
  }

  {
    // Constant LHS with a variable RHS is left alone: 4 % x => 4 % x.
    ExprHandle expr = ExprHandle(4) % x;
    auto result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, result.node(), remainder);
    IS_IMM_WITH_VAL(Int, remainder->lhs(), 4);
    IS_VAR_WITH_NAME(remainder->rhs(), "x");
  }

  {
    // Variable LHS with a constant RHS is left alone: x % 4 => x % 4.
    ExprHandle expr = x % 4;
    auto result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, result.node(), remainder);
    IS_VAR_WITH_NAME(remainder->lhs(), "x");
    IS_IMM_WITH_VAL(Int, remainder->rhs(), 4);
  }

  {
    // LHS is a constant multiple of RHS: (x * 2) % x => 0.
    ExprHandle expr = (x * 2) % x;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // The multiplier may itself be a variable: (x * y) % x => 0.
    ExprHandle expr = (x * y) % x;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // Several variables in the LHS product: (x * y * z) % x => 0.
    ExprHandle expr = (x * y * z) % x;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // The RHS may be a product too: (x * y * z) % (y * z) => 0.
    ExprHandle expr = (x * y * z) % (y * z);
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // LHS scalar is a multiple of the RHS scalar: (x * 12) % 4 => 0.
    ExprHandle expr = (x * 12) % 4;
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }

  {
    // The reverse does not hold when the smaller scalar is on the LHS:
    // (x * 4) % 12 stays as (4 * x) % 12.
    ExprHandle expr = (x * 4) % 12;
    auto result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, result.node(), remainder);
    IS_NODE_WITH_NAME(Mul, remainder->lhs(), scaled);
    IS_IMM_WITH_VAL(Int, scaled->lhs(), 4);
    IS_VAR_WITH_NAME(scaled->rhs(), "x");
    IS_IMM_WITH_VAL(Int, remainder->rhs(), 12);
  }

  {
    // Scalar and symbolic factors together:
    // (6 * x * y) % (3 * x * y) => 0.
    ExprHandle expr = (ExprHandle(6) * x * y) % (x * y * 3);
    auto result = IRSimplifier::simplify(expr);
    IS_IMM_WITH_VAL(Int, result.node(), 0);
  }
}
 | 
						|
 | 
						|
// Test that mixing ops together simplifies as expected.
 | 
						|
TEST(Simplify, SimplifyMultiOp) {
  // Every case below is built from these two integer variables.
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  {
    // Input:    (x * y) + (x - y)
    // Expected: (x + x * y) - y
    ExprHandle result = IRSimplifier::simplify((x * y) + (x - y));

    IS_NODE_WITH_NAME(Sub, result.node(), outerSub);
    IS_NODE_WITH_NAME(Add, outerSub->lhs(), sum);
    IS_VAR_WITH_NAME(sum->lhs(), "x");
    IS_NODE_WITH_NAME(Mul, sum->rhs(), product);
    IS_VAR_WITH_NAME(product->lhs(), "x");
    IS_VAR_WITH_NAME(product->rhs(), "y");
    IS_VAR_WITH_NAME(outerSub->rhs(), "y");
  }

  {
    // Input:    (x + y) - x * y
    // Expected: the same shape, (x + y) - x * y
    ExprHandle result = IRSimplifier::simplify((x + y) - x * y);

    IS_NODE_WITH_NAME(Sub, result.node(), outerSub);
    IS_NODE_WITH_NAME(Add, outerSub->lhs(), sum);
    IS_NODE_WITH_NAME(Mul, outerSub->rhs(), product);
    IS_VAR_WITH_NAME(sum->lhs(), "x");
    IS_VAR_WITH_NAME(sum->rhs(), "y");
    IS_VAR_WITH_NAME(product->lhs(), "x");
    IS_VAR_WITH_NAME(product->rhs(), "y");
  }

  {
    // Input:    (x - y) - (x + y)
    // Expected: -2 * y
    ExprHandle result = IRSimplifier::simplify((x - y) - (x + y));

    IS_NODE_WITH_NAME(Mul, result.node(), scaled);
    IS_IMM_WITH_VAL(Int, scaled->lhs(), -2);
    IS_VAR_WITH_NAME(scaled->rhs(), "y");
  }

  {
    // Input:    (x - 0) + (x * 1) - (x + 0)
    // Expected: x
    ExprHandle result = IRSimplifier::simplify((x - 0) + (x * 1) - (x + 0));

    IS_VAR_WITH_NAME(result.node(), "x");
  }

  {
    // Input:    (x - 0.f) + (x * 1.f) - (x + 0.f)
    // Expected: (float(x) + float(x)) - float(x)
    // With float operands the identity constants disappear, but the three
    // casted variable terms are expected to remain as separate nodes.
    ExprHandle mixedExpr =
        (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f));
    ExprHandle result = IRSimplifier::simplify(mixedExpr);

    IS_NODE_WITH_NAME(Sub, result.node(), outerSub);
    IS_NODE_WITH_NAME(Add, outerSub->lhs(), sum);
    IS_NODE_WITH_NAME(Cast, sum->lhs(), firstCast);
    IS_VAR_WITH_NAME(firstCast->src_value(), "x");
    IS_NODE_WITH_NAME(Cast, sum->rhs(), secondCast);
    IS_VAR_WITH_NAME(secondCast->src_value(), "x");
    IS_NODE_WITH_NAME(Cast, outerSub->rhs(), subtractedCast);
    IS_VAR_WITH_NAME(subtractedCast->src_value(), "x");
  }
}
 | 
						|
 | 
						|
// Test that chaining many ops together works as expected.
 | 
						|
TEST(Simplify, SimplifyManyOps) {
  // Long chains of + and - over the same two integer variables.
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  {
    // Input:    x + y + x + x + y + y + x + y + x
    // Expected: 4 * y + 5 * x
    ExprHandle result =
        IRSimplifier::simplify(x + y + x + x + y + y + x + y + x);

    IS_NODE_WITH_NAME(Add, result.node(), sum);

    IS_NODE_WITH_NAME(Mul, sum->lhs(), yTerm);
    IS_IMM_WITH_VAL(Int, yTerm->lhs(), 4);
    IS_VAR_WITH_NAME(yTerm->rhs(), "y");

    IS_NODE_WITH_NAME(Mul, sum->rhs(), xTerm);
    IS_IMM_WITH_VAL(Int, xTerm->lhs(), 5);
    IS_VAR_WITH_NAME(xTerm->rhs(), "x");
  }

  {
    // Input:    x - y + x + x - y - y + x - y + x
    // Expected: 5 * x - 4 * y
    ExprHandle result =
        IRSimplifier::simplify(x - y + x + x - y - y + x - y + x);

    // The root is a subtraction here, not an addition.
    IS_NODE_WITH_NAME(Sub, result.node(), difference);

    IS_NODE_WITH_NAME(Mul, difference->lhs(), xTerm);
    IS_IMM_WITH_VAL(Int, xTerm->lhs(), 5);
    IS_VAR_WITH_NAME(xTerm->rhs(), "x");

    IS_NODE_WITH_NAME(Mul, difference->rhs(), yTerm);
    IS_IMM_WITH_VAL(Int, yTerm->lhs(), 4);
    IS_VAR_WITH_NAME(yTerm->rhs(), "y");
  }

  {
    // Input:    x + y + x - x - y - y + x + y + x
    // Expected: 3 * x  (the y terms sum to zero)
    ExprHandle result =
        IRSimplifier::simplify(x + y + x - x - y - y + x + y + x);

    IS_NODE_WITH_NAME(Mul, result.node(), xTerm);
    IS_IMM_WITH_VAL(Int, xTerm->lhs(), 3);
    IS_VAR_WITH_NAME(xTerm->rhs(), "x");
  }
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyFactorization) {
  // Verifies how the simplifier factors a common integer scalar out of a sum
  // of scaled variables. Each case states the input and the exact tree shape
  // that the assertions below it check.
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  {
    // (2 * x) + (2 * y) => 2 * (x + y)
    ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);

    IS_NODE_WITH_NAME(Add, mul->rhs(), add);
    IS_VAR_WITH_NAME(add->lhs(), "x");
    IS_VAR_WITH_NAME(add->rhs(), "y");
  }

  {
    // Factorization when scalars have common divider.
    // (2 * x) + (4 * y) => 2 * (x + 2 * y)
    ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);

    IS_NODE_WITH_NAME(Add, mul->rhs(), add);
    IS_VAR_WITH_NAME(add->lhs(), "x");
    IS_NODE_WITH_NAME(Mul, add->rhs(), mul2);
    IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
    IS_VAR_WITH_NAME(mul2->rhs(), "y");
  }

  {
    // Factorization attempt without a common divider.
    // (2 * x) + (5 * y) => (2 * x) + (5 * y)
    ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Add, simplified.node(), add);

    IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
    IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
    IS_VAR_WITH_NAME(lhs->rhs(), "x");

    IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
    IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
    IS_VAR_WITH_NAME(rhs->rhs(), "y");
  }

  {
    // Factorization after merging.
    // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
    ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
        (ExprHandle(8) * x + ExprHandle(6) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 10);

    IS_NODE_WITH_NAME(Add, mul->rhs(), add);
    IS_VAR_WITH_NAME(add->lhs(), "x");
    IS_VAR_WITH_NAME(add->rhs(), "y");
  }

  {
    // Factorization with common divider but different signs.
    // (2 * x) + (-4 * y) => 2 * (x - 2 * y)
    ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);

    IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
    IS_VAR_WITH_NAME(sub->lhs(), "x");
    IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
    IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
    IS_VAR_WITH_NAME(mul2->rhs(), "y");
  }

  {
    // Factorization with all negative numbers.
    // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y)
    ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y;
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);

    IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
    IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2);
    IS_IMM_WITH_VAL(Int, mul2->lhs(), -1);
    IS_VAR_WITH_NAME(mul2->rhs(), "x");
    IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3);
    IS_IMM_WITH_VAL(Int, mul3->lhs(), 2);
    IS_VAR_WITH_NAME(mul3->rhs(), "y");
  }

  {
    // The following test ensures that there is no infinite recursion during
    // factorization when negative numbers are involved.
    VarHandle a("a", kInt);
    VarHandle b("b", kInt);
    VarHandle c("c", kInt);
    VarHandle d("d", kInt);
    VarHandle e("e", kInt);
    VarHandle f("f", kInt);
    VarHandle g("g", kInt);
    VarHandle h("h", kInt);

    ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 +
        f * 32 + g * (-1024) + h * (-32);
    ExprHandle simplified = IRSimplifier::simplify(body);
    // The result is compared against its printed form rather than walked
    // node by node.
    checkExprIR(
        simplified,
        "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h");
  }
}
 | 
						|
 | 
						|
// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x)
 | 
						|
TEST(Simplify, SimplifyFactorizeUneven) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  // The two operands differ only in the coefficient of z (2 versus 4).
  ExprHandle firstSum = ExprHandle(4) * x + y + z * 2;
  ExprHandle secondSum = ExprHandle(4) * x + y + z * 4;
  ExprHandle result = IRSimplifier::simplify(firstSum + secondSum);

  // Expected tree: 2 * ((y + 3 * z) + 4 * x)
  IS_NODE_WITH_NAME(Mul, result.node(), factored);
  IS_IMM_WITH_VAL(Int, factored->lhs(), 2);
  IS_NODE_WITH_NAME(Add, factored->rhs(), outerSum);
  IS_NODE_WITH_NAME(Add, outerSum->lhs(), innerSum);

  IS_VAR_WITH_NAME(innerSum->lhs(), "y");
  IS_NODE_WITH_NAME(Mul, innerSum->rhs(), zTerm);
  IS_NODE_WITH_NAME(Mul, outerSum->rhs(), xTerm);

  IS_IMM_WITH_VAL(Int, xTerm->lhs(), 4);
  IS_VAR_WITH_NAME(xTerm->rhs(), "x");

  IS_IMM_WITH_VAL(Int, zTerm->lhs(), 3);
  IS_VAR_WITH_NAME(zTerm->rhs(), "z");
}
 | 
						|
 | 
						|
// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
 | 
						|
// This is kind of a placeholder test for variable factorization.
 | 
						|
TEST(Simplify, SimplifyDeeperTerms) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Input: (x * y) + (2 * x) * (x + y)
  ExprHandle result =
      IRSimplifier::simplify((x * y) + (ExprHandle(2) * x) * (x + y));

  // Expected root: <2 * (x * x)> + <3 * (x * y)>
  IS_NODE_WITH_NAME(Add, result.node(), sum);

  // Left operand: 2 * (x * x)
  IS_NODE_WITH_NAME(Mul, sum->lhs(), squareTerm);
  IS_IMM_WITH_VAL(Int, squareTerm->lhs(), 2);
  IS_NODE_WITH_NAME(Mul, squareTerm->rhs(), xTimesX);
  IS_VAR_WITH_NAME(xTimesX->lhs(), "x");
  IS_VAR_WITH_NAME(xTimesX->rhs(), "x");

  // Right operand: 3 * (x * y)
  IS_NODE_WITH_NAME(Mul, sum->rhs(), crossTerm);
  IS_IMM_WITH_VAL(Int, crossTerm->lhs(), 3);
  IS_NODE_WITH_NAME(Mul, crossTerm->rhs(), xTimesY);
  IS_VAR_WITH_NAME(xTimesY->lhs(), "x");
  IS_VAR_WITH_NAME(xTimesY->rhs(), "y");
}
 | 
						|
 | 
						|
// Tests the difference between two less trivial expressions.
 | 
						|
// (m * (1 * n_1) + (n  + 1)) - (m *  (1 * n_1) + n) => 1
 | 
						|
TEST(Simplify, SimplifyDeeperDifference) {
  VarHandle n("n", kInt);
  VarHandle n_1("n_1", kInt);
  VarHandle m("m", kInt);

  // The minuend and subtrahend share m * (1 * n_1) and n; only the "+ 1"
  // inside the minuend is left over.
  ExprHandle minuend = m * (ExprHandle(1) * n_1) + (n + 1);
  ExprHandle subtrahend = m * (ExprHandle(1) * n_1) + n;
  ExprHandle result = IRSimplifier::simplify(minuend - subtrahend);

  // Expected: the integer immediate 1.
  IS_IMM_WITH_VAL(Int, result.node(), 1);
}
 | 
						|
 | 
						|
// Test constant folding into the difference between expressions.
 | 
						|
// 2 + char((m * (1 * n_1) + (n  + 1)) - (m *  (1 * n_1) + n)) => 3
 | 
						|
TEST(Simplify, SimplifyFoldComplexDifference) {
  VarHandle n("n", kInt);
  VarHandle n_1("n_1", kInt);
  VarHandle m("m", kInt);

  // Same difference as in SimplifyDeeperDifference, here wrapped in a cast to
  // kChar and then added to the integer constant 2.
  ExprHandle difference =
      (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n);
  ExprHandle result =
      IRSimplifier::simplify(IntImm::make(2) + Cast::make(kChar, difference));

  // Expected: everything folds to the integer immediate 3.
  IS_IMM_WITH_VAL(Int, result.node(), 3);
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyIfComponents) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Each of the three IfThenElse operands can be reduced on its own:
  //   condition: ((5 - 4) * x) > y
  //   true arm:  2 * x - x
  //   false arm: 2 * y - y
  ExprHandle condition = ((ExprHandle(5) - ExprHandle(4)) * x) > y;
  ExprHandle whenTrue = ExprHandle(2) * x - x;
  ExprHandle whenFalse = ExprHandle(2) * y - y;
  ExprHandle result =
      IRSimplifier::simplify(IfThenElse::make(condition, whenTrue, whenFalse));

  // The IfThenElse node itself is expected to remain.
  IS_NODE_WITH_NAME(IfThenElse, result.node(), ifNode);

  // Condition is expected to become x > y.
  IS_NODE_WITH_NAME(CompareSelect, ifNode->condition(), comparison);
  ASSERT_EQ(comparison->compare_select_op(), kGT);
  IS_VAR_WITH_NAME(comparison->lhs(), "x");
  IS_VAR_WITH_NAME(comparison->rhs(), "y");

  // Arms are expected to become the bare variables.
  IS_VAR_WITH_NAME(ifNode->true_value(), "x");
  IS_VAR_WITH_NAME(ifNode->false_value(), "y");
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyOpaqueTerms) {
  // Division and modulo sub-expressions appear as repeated terms inside
  // larger sums and differences.
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  {
    // Input:    2 * (x / y) * y - (x / y) * y
    // Expected: (x / y) * y
    ExprHandle result = IRSimplifier::simplify(
        ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y));

    IS_NODE_WITH_NAME(Mul, result.node(), product);
    IS_NODE_WITH_NAME(Div, product->lhs(), quotient);
    IS_VAR_WITH_NAME(quotient->lhs(), "x");
    IS_VAR_WITH_NAME(quotient->rhs(), "y");
    IS_VAR_WITH_NAME(product->rhs(), "y");
  }

  {
    // Input:    (x % y) - ((x % y) - 1)
    // Expected: 1
    ExprHandle result = IRSimplifier::simplify((x % y) - ((x % y) - 1));

    IS_IMM_WITH_VAL(Int, result.node(), 1);
  }
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifySymbolicMinMax) {
  {
    // Min over two terms that differ by a constant offset:
    // Min(x + 3, x + 7) is expected to become x + 3.
    VarHandle x("x", kInt);
    ExprHandle result = IRSimplifier::simplify(Min::make(x + 3, x + 7, true));

    IS_NODE_WITH_NAME(Add, result.node(), smaller);
    IS_VAR_WITH_NAME(smaller->lhs(), "x");
    IS_IMM_WITH_VAL(Int, smaller->rhs(), 3);
  }

  {
    // Max over the same two terms:
    // Max(x + 3, x + 7) is expected to become x + 7.
    VarHandle x("x", kInt);
    ExprHandle result = IRSimplifier::simplify(Max::make(x + 3, x + 7, true));

    IS_NODE_WITH_NAME(Add, result.node(), larger);
    IS_VAR_WITH_NAME(larger->lhs(), "x");
    IS_IMM_WITH_VAL(Int, larger->rhs(), 7);
  }

  {
    // Multiples of a variable cannot be ordered because the sign of the
    // variable is unknown, so the Max node is expected to remain.
    // TODO: maybe we could for unsigned types?
    VarHandle x("x", kInt);
    ExprHandle result = IRSimplifier::simplify(Max::make(x * 3, x * 7, true));

    IS_NODE(Max, result.node());
  }
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyNestedMax) {
  // Exercises the simplifier on nested Max expressions (and Max over Min)
  // built from three integer variables. The trailing bool passed to
  // Max::make / Min::make is the value later read back through
  // propagate_nans(); cases that mix different values for it are expected to
  // be left unsimplified.
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  {
    // Max(x + y, x + y) => x + y
    ExprHandle body = Max::make(x + y, x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
  }

  {
    // Max(x + y, Max(x + y, z)) => Max(x + y, z)
    ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(x + y, Max(z, x + y)) => Max(x + y, z)
    ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(Max(x + y, z), x + y) => Max(x + y, z)
    ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(Max(z, x + y), x + y) => Max(x + y, z)
    ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(Max(x, y), x) => Max(Max(x, y), x)
    // Nested Max ops with different propagate_nans should not be simplified.
    ExprHandle body = Max::make(Max::make(x, y, true), x, false);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y");
    ASSERT_TRUE(max1->propagate_nans());
    IS_VAR_WITH_NAME(max->rhs(), "x");
    ASSERT_FALSE(max->propagate_nans());
  }

  {
    // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(x, y, true), Min::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(x, y, true), Min::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(y, x, true), Min::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(y, x, true), Min::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z))
    // When all the ops in the pattern do not have the same propagate_nans,
    // it should not be simplified.
    ExprHandle body =
        Max::make(Min::make(y, x, true), Min::make(z, x, false), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y");
    ASSERT_TRUE(min1->propagate_nans());
    IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z");
    ASSERT_FALSE(min2->propagate_nans());
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(5, Max(x, 8)) => Max(x, 8)
    ExprHandle body = Max::make(5, Max::make(x, 8, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(8, Max(x, 5)) => Max(x, 8)
    ExprHandle body = Max::make(8, Max::make(x, 5, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(Max(x, 8), 5) => Max(x, 8)
    ExprHandle body = Max::make(Max::make(x, 8, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(Max(x, 5), 8) => Max(x, 8)
    ExprHandle body = Max::make(Max::make(x, 5, true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(y, Max(z, 8)), x), 5) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(Max(Max(z, 5), y), x), 8)
    // Do not simplify when all the Max ops do not have the same
    // propagate_nans.
    ExprHandle body = Max::make(
        Max::make(Max::make(Max::make(z, 5, true), y, false), x, true),
        8,
        false);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)");
  }

  {
    // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }
}
 | 
						|
 | 
						|
// Exercises IRSimplifier::simplify on nested Min expressions: repeated
// operands, Min-of-Max with a shared operand, constant operands, and chains
// whose propagate_nans flags differ (which the assertions below expect to be
// left unsimplified).
TEST(Simplify, SimplifyNestedMin) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  {
    // Min(x + y, x + y) => x + y
    ExprHandle body = Min::make(x + y, x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
  }

  {
    // Min(x + y, Min(x + y, z)) => Min(x + y, z)
    ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(x + y, Min(z, x + y)) => Min(x + y, z)
    ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(Min(x + y, z), x + y) => Min(x + y, z)
    ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(Min(z, x + y), x + y) => Min(x + y, z)
    ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(Min(x, y), x) => Min(Min(x, y), x)
    // Nested Min ops with different propagate_nans should not be simplified.
    ExprHandle body = Min::make(Min::make(x, y, true), x, false);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y");
    ASSERT_TRUE(min2->propagate_nans());
    IS_VAR_WITH_NAME(min1->rhs(), "x");
    ASSERT_FALSE(min1->propagate_nans());
  }

  {
    // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(x, y, true), Max::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(x, y, true), Max::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(y, x, true), Max::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(y, x, true), Max::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z))
    // When all the ops in the pattern do not have the same propagate_nans,
    // it should not be simplified.
    ExprHandle body =
        Min::make(Max::make(y, x, true), Max::make(z, x, false), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y");
    ASSERT_TRUE(max1->propagate_nans());
    IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z");
    ASSERT_FALSE(max2->propagate_nans());
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(5, Min(x, 8)) => Min(x, 5)
    ExprHandle body = Min::make(5, Min::make(x, 8, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(8, Min(x, 5)) => Min(x, 5)
    ExprHandle body = Min::make(8, Min::make(x, 5, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(Min(x, 8), 5) => Min(x, 5)
    ExprHandle body = Min::make(Min::make(x, 8, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(Min(x, 5), 8) => Min(x, 5)
    ExprHandle body = Min::make(Min::make(x, 5, true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(y, Min(z, 8)), x), 5) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(Min(z, 8), y), x), 5) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8)
    // Do not simplify when all the Min ops do not have the same
    // propagate_nans.
    ExprHandle body = Min::make(
        Min::make(Min::make(Min::make(z, 5, true), y, false), x, true),
        8,
        false);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)");
  }

  {
    // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }
}
 | 
						|
 | 
						|
// Checks which expressions the simplifier regroups: the integer case is
// expected to be factored, while the cases involving floating-point operands
// are expected to keep their original operation order.
TEST(Simplify, SimplifyWontReorderFloat) {
  {
    // All-integer input: 3 * (3 * x) - 3 * (3 * y) is expected to become
    // 9 * (x - y).
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);

    ExprHandle expr = ExprHandle(3) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3) * y);
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Mul, reduced.node(), scaled);
    IS_IMM_WITH_VAL(Int, scaled->lhs(), 9);
    IS_NODE_WITH_NAME(Sub, scaled->rhs(), diff);
    IS_VAR_WITH_NAME(diff->lhs(), "x");
    IS_VAR_WITH_NAME(diff->rhs(), "y");
  }

  {
    // Float variables: the same expression must come back structurally
    // unchanged, since floating-point ops are not associative and so cannot
    // be reordered.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);

    ExprHandle expr = ExprHandle(3) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3) * y);
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Sub, reduced.node(), diff);
    IS_NODE_WITH_NAME(Mul, diff->lhs(), leftOuter);
    IS_IMM_WITH_VAL(Float, leftOuter->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, leftOuter->rhs(), leftInner);
    IS_IMM_WITH_VAL(Float, leftInner->lhs(), 3);
    IS_VAR_WITH_NAME(leftInner->rhs(), "x");

    IS_NODE_WITH_NAME(Mul, diff->rhs(), rightOuter);
    IS_IMM_WITH_VAL(Float, rightOuter->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, rightOuter->rhs(), rightInner);
    IS_IMM_WITH_VAL(Float, rightInner->lhs(), 3);
    IS_VAR_WITH_NAME(rightInner->rhs(), "y");
  }

  {
    // Mixed dtypes (x is double, y is int): the double side is left as
    // 3 * (3 * x), while the integer side is expected to fold to 9 * y
    // underneath a cast to double.
    VarHandle x("x", kDouble);
    VarHandle y("y", kInt);

    ExprHandle expr = ExprHandle(3) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3) * y);
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Sub, reduced.node(), diff);
    IS_NODE_WITH_NAME(Mul, diff->lhs(), leftOuter);
    IS_IMM_WITH_VAL(Double, leftOuter->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, leftOuter->rhs(), leftInner);
    IS_IMM_WITH_VAL(Double, leftInner->lhs(), 3);
    IS_VAR_WITH_NAME(leftInner->rhs(), "x");

    IS_NODE_WITH_NAME_AND_CAST(Mul, diff->rhs(), rightFolded, Double);
    IS_IMM_WITH_VAL(Int, rightFolded->lhs(), 9);
    IS_VAR_WITH_NAME(rightFolded->rhs(), "y");
  }

  {
    // Integer variables combined with float immediates: the float dtype
    // propagates through the expression, so no reordering is expected.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);

    ExprHandle expr = ExprHandle(3.f) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3.f) * y);
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Sub, reduced.node(), diff);
    IS_NODE_WITH_NAME(Mul, diff->lhs(), leftOuter);
    IS_IMM_WITH_VAL(Float, leftOuter->lhs(), 3);
    IS_NODE_WITH_NAME_AND_CAST(Mul, leftOuter->rhs(), leftInner, Float);
    IS_IMM_WITH_VAL(Int, leftInner->lhs(), 3);
    IS_VAR_WITH_NAME(leftInner->rhs(), "x");

    IS_NODE_WITH_NAME(Mul, diff->rhs(), rightOuter);
    IS_IMM_WITH_VAL(Float, rightOuter->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, rightOuter->rhs(), rightInner);
    IS_IMM_WITH_VAL(Float, rightInner->lhs(), 3);
    IS_NODE_WITH_NAME(Cast, rightInner->rhs(), castOfY);
    IS_VAR_WITH_NAME(castOfY->src_value(), "y");
  }

  {
    // Float Mod terms: x%y - (x%y - 1) must be returned as written; opaque
    // ops are not reordered when they are floating point.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);

    ExprHandle expr = (x % y) - ((x % y) - 1);
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Sub, reduced.node(), outerDiff);
    IS_NODE_WITH_NAME(Mod, outerDiff->lhs(), leftMod);
    IS_VAR_WITH_NAME(leftMod->lhs(), "x");
    IS_VAR_WITH_NAME(leftMod->rhs(), "y");

    IS_NODE_WITH_NAME(Sub, outerDiff->rhs(), innerDiff);
    IS_NODE_WITH_NAME(Mod, innerDiff->lhs(), rightMod);
    IS_VAR_WITH_NAME(rightMod->lhs(), "x");
    IS_VAR_WITH_NAME(rightMod->rhs(), "y");
    IS_IMM_WITH_VAL(Float, innerDiff->rhs(), 1);
  }
}
 | 
						|
 | 
						|
// Covers the "round + mod" identity (a / b) * b + a % b => a for integer
// operands, plus negative cases where the pattern must not fire.
TEST(Simplify, SimplifyRoundModPattern) {
  {
    // Basic form: (x / y) * y + x % y => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = ((x / y) * y) + (x % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(reduced.node(), "x");
  }

  {
    // Operands swapped: x % y + (x / y) * y => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x % y) + ((x / y) * y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(reduced.node(), "x");
  }

  {
    // Denominator is a non-opaque sum:
    // (x / (4 + y)) * (4 + y) + x % (y + 4) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) +
        (x % (y + ExprHandle(4)));
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(reduced.node(), "x");
  }

  {
    // Same as above with the two addends swapped:
    // x % (y + 4) + (x / (4 + y)) * (4 + y) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x % (y + ExprHandle(4))) +
        ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y));
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(reduced.node(), "x");
  }

  {
    // Denominator is an opaque division:
    // (x / (2 / y)) * (2 / y) + x % (2 / y) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) +
        (x % (ExprHandle(2) / y));
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(reduced.node(), "x");
  }

  {
    // Numerator is a non-opaque product:
    // ((2 * x) / y) * y + (2 * x) % y => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr =
        (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mul, reduced.node(), product);
    IS_IMM_WITH_VAL(Int, product->lhs(), 2);
    IS_VAR_WITH_NAME(product->rhs(), "x");
  }

  {
    // Numerator is an opaque division:
    // ((x / 2) / y) * y + (x / 2) % y => x / 2.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr =
        (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Div, reduced.node(), quotient);
    IS_VAR_WITH_NAME(quotient->lhs(), "x");
    IS_IMM_WITH_VAL(Int, quotient->rhs(), 2);
  }

  {
    // Compound numerator and denominator:
    // ((2 * x) / (2 * y)) * (2 * y) + (2 * x) % (2 * y) => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr =
        (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) +
        ((ExprHandle(2) * x) % (ExprHandle(2) * y));
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Mul, reduced.node(), product);
    IS_IMM_WITH_VAL(Int, product->lhs(), 2);
    IS_VAR_WITH_NAME(product->rhs(), "x");
  }

  {
    // Same as above with the two addends swapped:
    // (2 * x) % (2 * y) + ((2 * x) / (2 * y)) * (2 * y) => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) +
        (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y));
    ExprHandle reduced = IRSimplifier::simplify(expr);

    IS_NODE_WITH_NAME(Mul, reduced.node(), product);
    IS_IMM_WITH_VAL(Int, product->lhs(), 2);
    IS_VAR_WITH_NAME(product->rhs(), "x");
  }

  {
    // Mod term arrives through a negated subtraction:
    // (x / y) * y - (0 - x % y) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = ((x / y) * y) - (ExprHandle(0) - (x % y));
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(reduced.node(), "x");
  }

  {
    // Unrelated terms survive the rewrite:
    // (x / y) * y + x % y + (y * x) => x + (x * y).
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = ((x / y) * y) + (x % y) + (y * x);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, reduced.node(), sum);
    IS_VAR_WITH_NAME(sum->lhs(), "x");
    IS_NODE_WITH_NAME(Mul, sum->rhs(), product);
    IS_VAR_WITH_NAME(product->lhs(), "x");
    IS_VAR_WITH_NAME(product->rhs(), "y");
  }

  {
    // Float operands: the rewrite must not be applied, so the original
    // (x / y) * y + x % y tree is expected back.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);
    ExprHandle expr = ((x / y) * y) + (x % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, reduced.node(), sum);
    IS_NODE_WITH_NAME(Mul, sum->lhs(), roundProduct);
    IS_NODE_WITH_NAME(Div, roundProduct->lhs(), roundQuotient);
    IS_VAR_WITH_NAME(roundQuotient->lhs(), "x");
    IS_VAR_WITH_NAME(roundQuotient->rhs(), "y");
    IS_VAR_WITH_NAME(roundProduct->rhs(), "y");
    IS_NODE_WITH_NAME(Mod, sum->rhs(), remainder);
    IS_VAR_WITH_NAME(remainder->lhs(), "x");
    IS_VAR_WITH_NAME(remainder->rhs(), "y");
  }

  {
    // Mismatched mod divisor (z instead of y): no rewrite expected.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle expr = ((x / y) * y) + (x % z);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    checkExprIR(reduced, "(x / y) * y + x % z");
  }

  {
    // Mismatched div divisor (z instead of y): no rewrite expected.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle expr = (y * (x / z)) + (x % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    checkExprIR(reduced, "x % y + (x / z) * y");
  }

  {
    // Mismatched multiplier (z instead of y): no rewrite expected.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle expr = ((x / y) * z) + (x % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    checkExprIR(reduced, "x % y + (x / y) * z");
  }
}
 | 
						|
 | 
						|
// Checks that the round + mod rewrite still fires when the round and mod
// terms carry constant multipliers that have to be factored out first.
TEST(Simplify, SimplifyRoundModPatternFactorization) {
  {
    // Full factorization.
    // 2 * (x/y * y) + 2 * (x%y) => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Partial Factorization.
    // 32 * (x/8) + 4 * (x % 8) => 4 * x.
    // Only x appears in this expression, so no second variable is declared.
    VarHandle x("x", kInt);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
    ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Factorization requiring constant folding.
    // 40 * (x / (16 / 2)) + (11 % 6) * (x % (7 + 1)) => 5 * x.
    VarHandle x("x", kInt);
    ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) +
        (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1));
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 5);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Multiplier on the right of the round term.
    // (x / 5) * 10 + 2 * (x % 5) => 2 * x.
    VarHandle x("x", kInt);
    ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Round term multiplied by zero.
    // (x / 10) * 0 + x % 5 => x % 5.
    VarHandle x("x", kInt);
    ExprHandle body = (x / 10) * 0 + x % 5;
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
    IS_VAR_WITH_NAME(mod->lhs(), "x");
    IS_IMM_WITH_VAL(Int, mod->rhs(), 5);
  }
}
 | 
						|
 | 
						|
// Round + mod rewrite when several variables (or compound terms) take part
// in the same sum.
TEST(Simplify, SimplifyRoundModPatternMultivar) {
  {
    // Two independent round/mod pairs:
    // (x / 8) * 8 + (y / 5) * 5 + x % 8 + y % 5 => x + y.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x / ExprHandle(8) * ExprHandle(8)) +
        (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, reduced.node(), sum);
    IS_VAR_WITH_NAME(sum->lhs(), "x");
    IS_VAR_WITH_NAME(sum->rhs(), "y");
  }

  {
    // Only y has a matching round term, so only y's pair may be merged:
    // (y / 8) * 8 + x % 8 + y % 8 + z % 8 => x % 8 + y + z % 8.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle expr =
        (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, reduced.node(), outerSum);
    IS_NODE_WITH_NAME(Add, outerSum->lhs(), innerSum);
    IS_NODE_WITH_NAME(Mod, innerSum->lhs(), modOfX);
    IS_VAR_WITH_NAME(modOfX->lhs(), "x");
    IS_IMM_WITH_VAL(Int, modOfX->rhs(), 8);
    IS_VAR_WITH_NAME(innerSum->rhs(), "y");
    IS_NODE_WITH_NAME(Mod, outerSum->rhs(), modOfZ);
    IS_VAR_WITH_NAME(modOfZ->lhs(), "z");
    IS_IMM_WITH_VAL(Int, modOfZ->rhs(), 8);
  }

  {
    // Compound numerator:
    // x + (z + 512 * y) % 16 + ((z + 512 * y) / 16) * 16
    // => x + (z + 512 * y).
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);

    ExprHandle expr = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    checkExprIR(reduced, "x + (z + 512 * y)");
  }
}
 | 
						|
 | 
						|
// Covers the "mod of round + mod" identity
// ((a / b) % c) * b + a % b => a % (b * c) for integer operands, and checks
// that it is not applied to floats.
TEST(Simplify, SimplifyModRoundModPattern) {
  {
    // Constant divisors: (t / 7 % 9) * 7 + t % 7 => t % 63.
    VarHandle t("t", kInt);
    ExprHandle expr = (t / 7 % 9) * 7 + t % 7;
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, reduced.node(), remainder);
    IS_VAR_WITH_NAME(remainder->lhs(), "t");
    IS_IMM_WITH_VAL(Int, remainder->rhs(), 63);
  }

  {
    // Scaled numerator: (2 * t / 7 % 9) * 7 + 2 * t % 7 => (2 * t) % 63.
    VarHandle t("t", kInt);
    ExprHandle expr = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7;
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, reduced.node(), remainder);
    IS_NODE_WITH_NAME(Mul, remainder->lhs(), scaledT);
    IS_IMM_WITH_VAL(Int, scaledT->lhs(), 2);
    IS_VAR_WITH_NAME(scaledT->rhs(), "t");
    IS_IMM_WITH_VAL(Int, remainder->rhs(), 63);
  }

  {
    // Variable divisors: (t / x % y) * x + t % x => t % (x * y).
    VarHandle t("t", kInt);
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (t / x % y) * x + t % x;
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, reduced.node(), remainder);
    IS_VAR_WITH_NAME(remainder->lhs(), "t");
    IS_NODE_WITH_NAME(Mul, remainder->rhs(), divisor);
    IS_VAR_WITH_NAME(divisor->lhs(), "x");
    IS_VAR_WITH_NAME(divisor->rhs(), "y");
  }

  {
    // Product numerator:
    // (k * t / x % y) * x + k * t % x => (k * t) % (x * y).
    VarHandle t("t", kInt);
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle k("k", kInt);
    ExprHandle expr = (k * t / x % y) * x + k * t % x;
    ExprHandle reduced = IRSimplifier::simplify(expr);
    checkExprIR(reduced, "(k * t) % (x * y)");
  }

  {
    // Quotient numerator:
    // (t / k / x % y) * x + t / k % x => (t / k) % (x * y).
    VarHandle t("t", kInt);
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle k("k", kInt);
    ExprHandle expr = (t / k / x % y) * x + t / k % x;
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, reduced.node(), remainder);
    IS_NODE_WITH_NAME(Div, remainder->lhs(), quotient);
    IS_VAR_WITH_NAME(quotient->lhs(), "t");
    IS_VAR_WITH_NAME(quotient->rhs(), "k");
    IS_NODE_WITH_NAME(Mul, remainder->rhs(), divisor);
    IS_VAR_WITH_NAME(divisor->lhs(), "x");
    IS_VAR_WITH_NAME(divisor->rhs(), "y");
  }

  {
    // Float operands: the rewrite must not be applied, so the original
    // ((x / y) % z) * y + x % y tree is expected back.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);
    VarHandle z("z", kFloat);
    ExprHandle expr = ((x / y % z) * y) + (x % y);
    ExprHandle reduced = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, reduced.node(), sum);
    IS_NODE_WITH_NAME(Mul, sum->lhs(), product);
    IS_NODE_WITH_NAME(Mod, product->lhs(), innerMod);
    IS_NODE_WITH_NAME(Div, innerMod->lhs(), quotient);
    IS_VAR_WITH_NAME(quotient->lhs(), "x");
    IS_VAR_WITH_NAME(quotient->rhs(), "y");
    IS_VAR_WITH_NAME(innerMod->rhs(), "z");
    IS_VAR_WITH_NAME(product->rhs(), "y");
    IS_NODE_WITH_NAME(Mod, sum->rhs(), trailingMod);
    IS_VAR_WITH_NAME(trailingMod->lhs(), "x");
    IS_VAR_WITH_NAME(trailingMod->rhs(), "y");
  }
}
 | 
						|
 | 
						|
// Exercises the mod / round-mod rewrite when the two halves of the pattern
// carry a shared scalar multiplier or constants that must be folded first.
TEST(Simplify, SimplifyModRoundModPatternFactorization) {
  {
    // Shared factor of 2 on both terms:
    // 2 * (t / 7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63)
    VarHandle t("t", kInt);
    ExprHandle expr =
        ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mul, result.node(), scaled);
    IS_IMM_WITH_VAL(Int, scaled->lhs(), 2);
    IS_NODE_WITH_NAME(Mod, scaled->rhs(), rem);
    IS_VAR_WITH_NAME(rem->lhs(), "t");
    IS_IMM_WITH_VAL(Int, rem->rhs(), 63);
  }

  {
    // The factor of 2 is already merged into the multiplier (14 == 2 * 7):
    // t / 7 % 9 * 14 + 2 * (t % 7) => 2 * (t % 63)
    VarHandle t("t", kInt);
    ExprHandle expr = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mul, result.node(), scaled);
    IS_IMM_WITH_VAL(Int, scaled->lhs(), 2);
    IS_NODE_WITH_NAME(Mod, scaled->rhs(), rem);
    IS_VAR_WITH_NAME(rem->lhs(), "t");
    IS_IMM_WITH_VAL(Int, rem->rhs(), 63);
  }

  {
    // The base term is itself a division:
    // t / 14 % 9 * 7 + t / 2 % 7 => t / 2 % 63
    VarHandle t("t", kInt);
    ExprHandle expr = (t / 14 % 9) * 7 + t / 2 % 7;
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, result.node(), rem);
    IS_NODE_WITH_NAME(Div, rem->lhs(), quot);
    IS_VAR_WITH_NAME(quot->lhs(), "t");
    IS_IMM_WITH_VAL(Int, quot->rhs(), 2);
    IS_IMM_WITH_VAL(Int, rem->rhs(), 63);
  }

  {
    // Divisor written as an unfolded product:
    // t / (7 * 3) % 9 * 7 * 3 + t % (7 * 3) => t % 189
    VarHandle t("t", kInt);
    ExprHandle expr = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 +
        t % (ExprHandle(7) * ExprHandle(3));
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mod, result.node(), rem);
    IS_VAR_WITH_NAME(rem->lhs(), "t");
    IS_IMM_WITH_VAL(Int, rem->rhs(), 189);
  }

  {
    // Symbolic divisors with a shared factor of 2:
    // 2 * (t / x % y * x) + 2 * (t % x) => 2 * (t % (x * y))
    VarHandle t("t", kInt);
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr =
        ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mul, result.node(), scaled);
    IS_IMM_WITH_VAL(Int, scaled->lhs(), 2);
    IS_NODE_WITH_NAME(Mod, scaled->rhs(), rem);
    IS_VAR_WITH_NAME(rem->lhs(), "t");
    IS_NODE_WITH_NAME(Mul, rem->rhs(), modulus);
    IS_VAR_WITH_NAME(modulus->lhs(), "x");
    IS_VAR_WITH_NAME(modulus->rhs(), "y");
  }
}
 | 
						|
 | 
						|
// Exercises the mod / round-mod rewrite when the pattern is mixed with other
// terms, appears more than once, or is chained to flatten a multi-dimensional
// index back into a single variable.
TEST(Simplify, SimplifyModRoundModPatternMultivar) {
  {
    // An unrelated trailing term is left in place:
    // t / 7 % 9 * 7 + t % 7 + t => t % 63 + t
    VarHandle t("t", kInt);
    ExprHandle expr = (t / 7 % 9) * 7 + t % 7 + t;
    ExprHandle result = IRSimplifier::simplify(expr);
    checkExprIR(result, "t % 63 + t");
  }

  {
    // Two interleaved instances of the pattern:
    // t / 7 % 9 * 7 + t / 8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72
    VarHandle t("t", kInt);
    ExprHandle expr = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8;
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, result.node(), sum);
    IS_NODE_WITH_NAME(Mod, sum->lhs(), remA);
    IS_VAR_WITH_NAME(remA->lhs(), "t");
    IS_IMM_WITH_VAL(Int, remA->rhs(), 63);
    IS_NODE_WITH_NAME(Mod, sum->rhs(), remB);
    IS_VAR_WITH_NAME(remB->lhs(), "t");
    IS_IMM_WITH_VAL(Int, remB->rhs(), 72);
  }

  {
    // Leading unrelated variable with symbolic divisors:
    // k + t / x % y * x + t % x => k + t % (x * y)
    VarHandle t("t", kInt);
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle k("k", kInt);
    ExprHandle expr = k + (t / x % y) * x + t % x;
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Add, result.node(), sum);
    IS_VAR_WITH_NAME(sum->lhs(), "k");
    IS_NODE_WITH_NAME(Mod, sum->rhs(), rem);
    IS_VAR_WITH_NAME(rem->lhs(), "t");
    IS_NODE_WITH_NAME(Mul, rem->rhs(), modulus);
    IS_VAR_WITH_NAME(modulus->lhs(), "x");
    IS_VAR_WITH_NAME(modulus->rhs(), "y");
  }

  {
    // Two symbolic instances, one over t and one over t / k:
    // t / x % y * x + t % x + (t / k / x % y) * x + t / k % x
    // => (t / k) % (x * y) + t % (x * y)
    VarHandle t("t", kInt);
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle k("k", kInt);
    ExprHandle expr =
        (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x;
    ExprHandle result = IRSimplifier::simplify(expr);
    checkExprIR(result, "(t / k) % (x * y) + t % (x * y)");
  }

  {
    // 3D, constant extents:
    // (7 * ((io_flat / 7) % 9) + io_flat % 7) + 63 * (io_flat / 63)
    // => io_flat
    VarHandle t("io_flat", kInt);
    ExprHandle expr =
        ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "io_flat");
  }

  {
    // 5D, constant extents:
    // io_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) +
    // (io_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 +
    // (io_flat / (9 * 7) % 10) * 7 * 9 +
    // (io_flat / 7 % 9) * 7 +
    // io_flat % 7 => io_flat
    VarHandle t("io_flat", kInt);
    ExprHandle expr = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) +
        (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 +
        (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7;
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "io_flat");
  }

  {
    // 3D, symbolic extents:
    // (m * ((io_flat / m) % n) + io_flat % m) + (m * n) *
    // (io_flat / (m * n)) => io_flat
    VarHandle t("io_flat", kInt);
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    ExprHandle expr = m * (t / m % n) + t % m + (m * n) * (t / (m * n));
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "io_flat");
  }

  {
    // 5D, symbolic extents:
    // io_flat / (k * l * n * m) * (m * n * l * k) +
    // (io_flat / (l * n * m) % k) * m * n * l +
    // (io_flat / (n * m) % l) * m * n +
    // (io_flat / m % n) * m +
    // io_flat % m => io_flat
    VarHandle t("io_flat", kInt);
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    VarHandle l("l", kInt);
    VarHandle k("k", kInt);
    ExprHandle expr = (t / (k * l * n * m)) * (m * n * l * k) +
        (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n +
        (t / m % n) * m + t % m;
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "io_flat");
  }
}
 | 
						|
 | 
						|
// Checks how common scalar factors are cancelled between the numerator and
// denominator of a division, and the cases where nothing is cancelled.
TEST(Simplify, SimplifyDivisionScalarFactorization) {
  {
    // Both scalars share a factor of 4: 8x / 4y => 2x / y.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x * 8) / (y * 4);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Div, result.node(), quot);
    IS_NODE_WITH_NAME(Mul, quot->lhs(), numer);
    IS_IMM_WITH_VAL(Int, numer->lhs(), 2);
    IS_VAR_WITH_NAME(numer->rhs(), "x");
    IS_VAR_WITH_NAME(quot->rhs(), "y");
  }

  {
    // 7 and 4 have no common factor, so both scalars are kept
    // (each is moved in front of its variable).
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x * 7) / (y * 4);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Div, result.node(), quot);
    IS_NODE_WITH_NAME(Mul, quot->lhs(), numer);
    IS_IMM_WITH_VAL(Int, numer->lhs(), 7);
    IS_VAR_WITH_NAME(numer->rhs(), "x");
    IS_NODE_WITH_NAME(Mul, quot->rhs(), denom);
    IS_IMM_WITH_VAL(Int, denom->lhs(), 4);
    IS_VAR_WITH_NAME(denom->rhs(), "y");
  }

  {
    // Float operands are neither factored nor reordered.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);
    ExprHandle expr = (x * 8) / (y * 4);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Div, result.node(), quot);
    IS_NODE_WITH_NAME(Mul, quot->lhs(), numer);
    IS_VAR_WITH_NAME(numer->lhs(), "x");
    IS_IMM_WITH_VAL(Float, numer->rhs(), 8.f);
    IS_NODE_WITH_NAME(Mul, quot->rhs(), denom);
    IS_VAR_WITH_NAME(denom->lhs(), "y");
    IS_IMM_WITH_VAL(Float, denom->rhs(), 4.f);
  }

  {
    // Scalars of 1 on both sides: the result is just x / y.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x * 1) / (y * 1);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Div, result.node(), quot);
    IS_VAR_WITH_NAME(quot->lhs(), "x");
    IS_VAR_WITH_NAME(quot->rhs(), "y");
  }

  {
    // Repeated additions are counted as scalar multiples first:
    // (x + x + x + x) / (y + y) => 2x / y.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = (x + x + x + x) / (y + y);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Div, result.node(), quot);
    IS_NODE_WITH_NAME(Mul, quot->lhs(), numer);
    IS_IMM_WITH_VAL(Int, numer->lhs(), 2);
    IS_VAR_WITH_NAME(numer->rhs(), "x");
    IS_VAR_WITH_NAME(quot->rhs(), "y");
  }
}
 | 
						|
 | 
						|
// Checks simplification of IfThenElse expressions whose condition is (or
// folds to) a constant, or whose two branches are equivalent.
TEST(Simplify, SimplifyConstantBranches) {
  {
    // Constant-true condition selects the true value.
    // 1 ? x : y => x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle cond(1);
    ExprHandle expr = IfThenElse::make(cond, x, y);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "x");
  }

  {
    // Constant-false condition selects the false value.
    // 0 ? x : y => y
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle cond(0);
    ExprHandle expr = IfThenElse::make(cond, x, y);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "y");
  }

  {
    // The condition is simplified before it is inspected.
    // (x - x) ? x : y => y
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = IfThenElse::make(x - x, x, y);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "y");
  }

  {
    // Identical branches make the condition irrelevant.
    // y ? x : x => x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = IfThenElse::make(y, x, x);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_VAR_WITH_NAME(result.node(), "x");
  }

  {
    // Branches that only become identical after simplification also collapse.
    // y ? (x + x) : (2 * x) => 2 * x
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle expr = IfThenElse::make(y, x + x, ExprHandle(2) * x);
    ExprHandle result = IRSimplifier::simplify(expr);
    IS_NODE_WITH_NAME(Mul, result.node(), doubled);
    IS_IMM_WITH_VAL(Int, doubled->lhs(), 2);
    IS_VAR_WITH_NAME(doubled->rhs(), "x");
  }
}
 | 
						|
 | 
						|
// Checks simplification of Cond statements whose condition is (or folds to)
// a constant, or whose branches are equivalent.
//
// Every result that is dereferenced as a Block is first asserted to be
// non-null, so a simplifier regression shows up as a test failure rather than
// a null-pointer crash.
TEST(Simplify, SimplifyConstantCond) {
  {
    // If the condition is constant true then take the true_value.
    // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1
    BufHandle a("A", {1}, kInt);
    BufHandle b("B", {1}, kInt);
    ExprHandle condition(1);
    StmtPtr true_val = Store::make(a, {0}, 1);
    StmtPtr false_val = Store::make(b, {0}, 1);

    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "A");
  }

  {
    // If the condition is constant false then take the false_value.
    // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1
    BufHandle a("A", {1}, kInt);
    BufHandle b("B", {1}, kInt);
    ExprHandle condition(0);
    StmtPtr true_val = Store::make(a, {0}, 1);
    StmtPtr false_val = Store::make(b, {0}, 1);

    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "B");
  }

  {
    // The condition is simplified before checking.
    // (x - x) ? A[0] = 1 : B[0] = 1 => B[0] = 1
    VarHandle x("x", kInt);
    BufHandle a("A", {1}, kInt);
    BufHandle b("B", {1}, kInt);
    ExprHandle condition(x - x);
    StmtPtr true_val = Store::make(a, {0}, 1);
    StmtPtr false_val = Store::make(b, {0}, 1);

    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "B");
  }

  {
    // Both branches are the same store, so the result is a store to A.
    // (x - x) ? A[0] = x : A[0] = x => A[0] = x
    // NOTE(review): the condition is x - x, which the previous case shows
    // folds to constant false, so this may be resolved by constant folding
    // rather than by comparing the branches. Confirm whether a plain `x`
    // condition was intended.
    VarHandle x("x", kInt);
    BufHandle a("A", {1}, kInt);
    ExprHandle condition(x - x);
    StmtPtr true_val = Store::make(a, {0}, x);
    StmtPtr false_val = Store::make(a, {0}, x);

    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "A");
  }

  {
    // Both branches store equivalent values (2 * x and x + x) to A.
    // (x - x) ? A[0] = 2 * x : A[0] = x + x => store to A
    // NOTE(review): as above, the x - x condition folds to constant false;
    // confirm whether a non-constant condition was intended here.
    VarHandle x("x", kInt);
    BufHandle a("A", {1}, kInt);
    ExprHandle condition(x - x);
    StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x);
    StmtPtr false_val = Store::make(a, {0}, x + x);

    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "A");
  }

  {
    // But not if the branches differ and the condition is not constant:
    // x ? A[0] = x : A[0] = 2 * x is not reduced to a Block.
    VarHandle x("x", kInt);
    BufHandle a("A", {1}, kInt);
    ExprHandle condition(x);
    StmtPtr true_val = Store::make(a, {0}, x);
    StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x);

    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_EQ(block, nullptr);
  }

  {
    // Constant-false condition with no false branch: the result is null.
    StmtPtr cond = alloc<Cond>(
        ExprHandle(false).node(),
        alloc<Block>(std::vector<StmtPtr>({})),
        nullptr);
    StmtPtr simplified = IRSimplifier::simplify(cond);
    ASSERT_EQ(simplified, nullptr);
  }

  {
    // Constant-true condition with no true branch: the result is null.
    StmtPtr cond = alloc<Cond>(
        ExprHandle(true).node(),
        nullptr,
        alloc<Block>(std::vector<StmtPtr>({})));
    StmtPtr simplified = IRSimplifier::simplify(cond);
    ASSERT_EQ(simplified, nullptr);
  }
}
 | 
						|
 | 
						|
// A Cond whose branches are both empty (one an empty Block, the other
// absent) simplifies to an empty Block, even with a non-constant condition.
TEST(Simplify, SimplifyEliminateEmptyCond) {
  {
    // Empty true branch, missing false branch.
    VarHandle x("x", kInt);
    ExprHandle condition(x);
    StmtPtr emptyTrue = alloc<Block>(std::vector<StmtPtr>({}));

    StmtPtr stmt = alloc<Cond>(condition.node(), emptyTrue, nullptr);
    StmtPtr result = IRSimplifier::simplify(stmt);
    BlockPtr resultBlock = to<Block>(result);
    ASSERT_NE(resultBlock, nullptr);
    ASSERT_EQ(resultBlock->nstmts(), 0);
  }

  {
    // Missing true branch, empty false branch.
    VarHandle x("x", kInt);
    ExprHandle condition(x);
    StmtPtr emptyFalse = alloc<Block>(std::vector<StmtPtr>({}));

    StmtPtr stmt = alloc<Cond>(condition.node(), nullptr, emptyFalse);
    StmtPtr result = IRSimplifier::simplify(stmt);
    BlockPtr resultBlock = to<Block>(result);
    ASSERT_NE(resultBlock, nullptr);
    ASSERT_EQ(resultBlock->nstmts(), 0);
  }
}
 | 
						|
 | 
						|
// Comparisons between integer constants must fold to an integer immediate.
TEST(Simplify, SimplifyConstantComparisons) {
  // Builds `lhs <op> rhs`, simplifies it and expects the immediate `expected`.
  auto expectFoldsTo = [](ExprHandle lhs,
                          ExprHandle rhs,
                          CompareSelectOperation op,
                          int expected) {
    ExprHandle cmpExpr = CompareSelect::make(lhs, rhs, op);
    ExprHandle folded = IRSimplifier::simplify(cmpExpr);
    IS_IMM_WITH_VAL(Int, folded.node(), expected);
  };

  struct ComparisonCase {
    int lhs;
    int rhs;
    CompareSelectOperation op;
    int expected;
  };

  // For each operator: equal operands, lhs < rhs, lhs > rhs.
  const ComparisonCase cases[] = {
      // Equals.
      {2, 2, kEQ, 1},
      {1, 2, kEQ, 0},
      {2, 1, kEQ, 0},
      // Greater than.
      {2, 2, kGT, 0},
      {1, 2, kGT, 0},
      {2, 1, kGT, 1},
      // Greater or equal.
      {2, 2, kGE, 1},
      {1, 2, kGE, 0},
      {2, 1, kGE, 1},
      // Less than.
      {2, 2, kLT, 0},
      {1, 2, kLT, 1},
      {2, 1, kLT, 0},
      // Less or equal.
      {2, 2, kLE, 1},
      {1, 2, kLE, 1},
      {2, 1, kLE, 0},
      // Not equal.
      {2, 2, kNE, 0},
      {1, 2, kNE, 1},
      {2, 1, kNE, 1},
  };
  for (const ComparisonCase& c : cases) {
    expectFoldsTo(c.lhs, c.rhs, c.op, c.expected);
  }

  // With explicit result values: 2 != 2 is false, so the second value (42)
  // is selected.
  ExprHandle selectExpr = CompareSelect::make(2, 2, 5, 42, kNE);
  ExprHandle selected = IRSimplifier::simplify(selectExpr);
  IS_IMM_WITH_VAL(Int, selected.node(), 42);
}
 | 
						|
 | 
						|
// Comparisons between symbolic expressions fold to 1 or 0 when the two sides
// differ by a known constant, and are left as a CompareSelect otherwise.
TEST(Simplify, SimplifySymbolicComparisons) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Builds `lhs <op> rhs` and returns its simplified form.
  auto simplifyCompare =
      [](ExprHandle lhs, ExprHandle rhs, CompareSelectOperation op) {
        ExprHandle cmpExpr = CompareSelect::make(lhs, rhs, op);
        return IRSimplifier::simplify(cmpExpr);
      };
  // The comparison folded to the immediate 1.
  auto expectAlwaysTrue = [](ExprHandle e) {
    IS_IMM_WITH_VAL(Int, e.node(), 1);
  };
  // The comparison folded to the immediate 0.
  auto expectAlwaysFalse = [](ExprHandle e) {
    IS_IMM_WITH_VAL(Int, e.node(), 0);
  };

  // ---- EQ ----

  // x == x => 1
  expectAlwaysTrue(simplifyCompare(x, x, kEQ));

  // x == x + 1 => 0
  expectAlwaysFalse(simplifyCompare(x, x + 1, kEQ));

  // x == x * 2 cannot simplify since we don't know x is nonzero.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  IS_NODE(CompareSelect, simplifyCompare(x, x * 2, kEQ).node());

  // x == x * 1 => 1
  expectAlwaysTrue(simplifyCompare(x, x * 1, kEQ));

  {
    // Unrelated variables: x == y stays x == y.
    ExprHandle kept = simplifyCompare(x, y, kEQ);
    IS_NODE_WITH_NAME(CompareSelect, kept.node(), cmp);
    ASSERT_EQ(cmp->compare_select_op(), kEQ);
    IS_VAR_WITH_NAME(cmp->lhs(), "x");
    IS_VAR_WITH_NAME(cmp->rhs(), "y");
  }

  {
    // Variable against a constant: x == 5 stays x == 5.
    ExprHandle kept = simplifyCompare(x, 5, kEQ);
    IS_NODE_WITH_NAME(CompareSelect, kept.node(), cmp);
    ASSERT_EQ(cmp->compare_select_op(), kEQ);
    IS_VAR_WITH_NAME(cmp->lhs(), "x");
    IS_IMM_WITH_VAL(Int, cmp->rhs(), 5);
  }

  // ---- GT ----

  // x + 1 > x => 1
  expectAlwaysTrue(simplifyCompare(x + 1, x, kGT));

  // x > x + 1 => 0
  expectAlwaysFalse(simplifyCompare(x, x + 1, kGT));

  // x > x - 1 => 1
  expectAlwaysTrue(simplifyCompare(x, x - 1, kGT));

  // x - 1 > x => 0
  expectAlwaysFalse(simplifyCompare(x - 1, x, kGT));

  // x > x => 0
  expectAlwaysFalse(simplifyCompare(x, x, kGT));

  // x * 2 > x is left alone since we don't know the sign of x.
  IS_NODE(CompareSelect, simplifyCompare(x * 2, x, kGT).node());

  // ---- GE ----

  // x + 1 >= x => 1
  expectAlwaysTrue(simplifyCompare(x + 1, x, kGE));

  // x >= x + 1 => 0
  expectAlwaysFalse(simplifyCompare(x, x + 1, kGE));

  // x >= x => 1
  expectAlwaysTrue(simplifyCompare(x, x, kGE));

  // x * 2 >= x is left alone since we don't know the sign of x.
  IS_NODE(CompareSelect, simplifyCompare(x * 2, x, kGE).node());

  // ---- LT ----

  // x + 1 < x => 0
  expectAlwaysFalse(simplifyCompare(x + 1, x, kLT));

  // x < x + 1 => 1
  expectAlwaysTrue(simplifyCompare(x, x + 1, kLT));

  // x < x => 0
  expectAlwaysFalse(simplifyCompare(x, x, kLT));

  // ---- LE ----

  // x + 1 <= x => 0
  expectAlwaysFalse(simplifyCompare(x + 1, x, kLE));

  // x <= x + 1 => 1
  expectAlwaysTrue(simplifyCompare(x, x + 1, kLE));

  // x <= x => 1
  expectAlwaysTrue(simplifyCompare(x, x, kLE));

  // ---- NE ----

  // x + 1 != x => 1
  expectAlwaysTrue(simplifyCompare(x + 1, x, kNE));

  // x != x + 1 => 1
  expectAlwaysTrue(simplifyCompare(x, x + 1, kNE));

  // x != x => 0
  expectAlwaysFalse(simplifyCompare(x, x, kNE));
}
 | 
						|
 | 
						|
// A For loop whose start and stop are provably equal is replaced by an empty
// Block.
//
// Each result is asserted to be a non-null Block before nstmts() is called,
// so a simplifier regression is reported as a test failure instead of a
// null-pointer dereference.
TEST(Simplify, SimplifyEliminateZeroLengthFor) {
  {
    // Will eliminate zero loop For.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // Still works if start is not zero.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // Works if both terms are variable.
    VarHandle x("x", kInt);
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // Works if one term simplifies down (x - x is the stop bound here).
    VarHandle x("x", kInt);
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    ASSERT_EQ(block->nstmts(), 0);
  }

  {
    // Sanity check: a loop from 0 to 3 is still a For after simplification.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE(For, simplified);
  }
}
 | 
						|
 | 
						|
// A For loop that provably runs exactly once is replaced by its body, with
// the loop variable substituted by the start value (checked through the
// store's flat index).
//
// Each result is asserted to be a non-null Block before front() is called,
// so a simplifier regression is reported as a test failure instead of a
// null-pointer dereference.
TEST(Simplify, SimplifyOneLoopFor) {
  {
    // Will remove the loop if the body is run once.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
  }

  {
    // Still works if start is not zero; the index becomes the start value.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 2);
  }

  {
    // Works if both terms are variable: [x, x + 1) runs once with i == x.
    VarHandle x("x", kInt);
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_VAR_WITH_NAME(store->flat_index(), "x");
  }

  {
    // Works if one term simplifies down (x - x + 1 is the stop bound here).
    VarHandle x("x", kInt);
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body =
        For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    BlockPtr block = to<Block>(simplified);
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
  }

  {
    // Sanity check: a loop from 0 to 3 is still a For after simplification.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE(For, simplified);
  }
}
 | 
						|
 | 
						|
// A single-iteration loop that carries a GPU block index binding is expected
// to survive simplification as a For, with the binding intact.
TEST(Simplify, SimplifyForWontLoseLoopOptions) {
  {
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);

    // Bind the loop to a GPU block index before building it.
    LoopOptions bound_options;
    bound_options.set_gpu_block_index(LoopOptions::IDX_W);

    auto copy_elem = Store::make(c, {i}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 1, copy_elem, bound_options);

    StmtPtr result = IRSimplifier::simplify(loop_stmt);

    // Still a For, and still bound to the same block index.
    IS_NODE_WITH_NAME(For, result, kept_for);
    LoopOptions kept_options = kept_for->loop_options();
    ASSERT_EQ(
        bound_options.gpu_block_index(), kept_options.gpu_block_index());
  }
}
 | 
						|
 | 
						|
// Nested loops: every level with a trip count of one is expected to be
// removed, while levels with a larger trip count are kept, regardless of
// whether they are the inner or the outer loop.
TEST(Simplify, SimplifyMultilevelFor) {
  {
    // Multiple layers of For will be simplified out.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
    auto outer = For::make(j, 0, 1, body);
    StmtPtr simplified = IRSimplifier::simplify(outer);
    // Checked cast: fail the test instead of dereferencing null if the
    // result is not a Block.
    IS_NODE_WITH_NAME(Block, simplified, block);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
  }

  {
    // Will maintain an outer loop if the inner loop is eliminated.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
    auto outer = For::make(j, 0, 2, body);
    StmtPtr simplified = IRSimplifier::simplify(outer);
    // Check the node type on the simplified statement itself. Casting to
    // For with static_to first would make the For check below vacuous.
    IS_NODE_WITH_NAME(For, simplified, for_);
    IS_VAR_WITH_NAME(for_->var(), "j");
    IS_IMM_WITH_VAL(Int, for_->start(), 0);
    IS_IMM_WITH_VAL(Int, for_->stop(), 2);
    BlockPtr block = to<Block>(for_->body());
    ASSERT_NE(block, nullptr);
    IS_NODE_WITH_NAME(Store, block->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
  }

  {
    // Will maintain the inner loop if the outer loop is eliminated.
    BufHandle a("A", {4}, kInt);
    BufHandle c("C", {4}, kInt);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i})));
    auto outer = For::make(j, 0, 1, body);
    StmtPtr simplified = IRSimplifier::simplify(outer);
    IS_NODE_WITH_NAME(Block, simplified, block);
    IS_NODE_WITH_NAME(For, block->front(), for_);
    IS_VAR_WITH_NAME(for_->var(), "i");
    IS_IMM_WITH_VAL(Int, for_->start(), 0);
    IS_IMM_WITH_VAL(Int, for_->stop(), 2);
    IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
    IS_VAR_WITH_NAME(store->base_handle(), "C");
    // The inner loop variable is still live, so the index stays symbolic.
    IS_VAR_WITH_NAME(store->flat_index(), "i");
  }
}
 | 
						|
 | 
						|
// Builds a {1, 12, 1} Compute, lowers it through LoopNest and checks that
// simplification leaves a single loop whose store index and value are both
// just the loop variable.
TEST(Simplify, SimplifyForCleansUp) {
  {
    Tensor b = Compute(
        "x",
        {1, 12, 1},
        [](const VarHandle& i, const VarHandle& m, const VarHandle& n) {
          return i + m + n;
        });
    LoopNest l({b});
    l.prepareForCodegen();

    StmtPtr body = LoopNest::sanitizeNames(l.root_stmt());
    StmtPtr simplified = IRSimplifier::simplify(body);

    // Checked cast: fail the test instead of dereferencing null if the
    // result is not a Block.
    IS_NODE_WITH_NAME(Block, simplified, block);
    IS_NODE_WITH_NAME(For, block->front(), for_);
    // The remaining loop is expected to be the one over the size-12 axis
    // (`m` in the lambda); after sanitizeNames its variable is named "j".
    IS_VAR_WITH_NAME(for_->var(), "j");
    // x[j] = j;
    IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
    IS_VAR_WITH_NAME(store->flat_index(), "j");
    IS_VAR_WITH_NAME(store->value(), "j");
  }
}
 | 
						|
 | 
						|
// Loops whose body is empty are expected to disappear entirely, even when
// nested many levels deep.
TEST(Simplify, SimplifyEliminateEmptyFor) {
  {
    // Wrap an empty block in eleven nested For loops; the whole nest should
    // simplify down to an empty block.
    StmtPtr nest = alloc<Block>(std::vector<StmtPtr>({}));
    for ([[maybe_unused]] const auto depth : c10::irange(11)) {
      VarHandle loopVar("loopVar", kInt);
      nest = For::make(loopVar, 0, 10, nest);
    }

    StmtPtr result = IRSimplifier::simplify(nest);
    IS_NODE_WITH_NAME(Block, result, remaining);
    ASSERT_EQ(remaining->nstmts(), 0);
  }
}
 | 
						|
 | 
						|
// Nested Blocks are expected to be flattened into a single Block that holds
// the original statements in their original order.
TEST(Simplify, SimplifyFlattenBlock) {
  {
    // A chain of single-child blocks collapses onto its innermost contents.
    // { { { stmt1, stmt2 } } } =>  { stmt1, stmt2 }
    BufHandle a("A", {1}, kInt);
    StorePtr first = Store::make(a, {0}, 1);
    StorePtr second = Store::make(a, {0}, 0);

    BlockPtr innermost = alloc<Block>(std::vector<StmtPtr>({first, second}));
    BlockPtr middle = alloc<Block>(std::vector<StmtPtr>({innermost}));
    BlockPtr outermost = alloc<Block>(std::vector<StmtPtr>({middle}));

    StmtPtr result = IRSimplifier::simplify(outermost);

    IS_NODE_WITH_NAME(Block, result, flat);
    ASSERT_EQ(flat->nstmts(), 2);

    IS_NODE_WITH_NAME(Store, flat->front(), first_out);
    IS_NODE_WITH_NAME(Store, flat->back(), second_out);

    ASSERT_EQ(first->value(), first_out->value());
    ASSERT_EQ(second->value(), second_out->value());
  }

  {
    // Sibling sub-blocks are merged into their parent.
    // { { stmt1 }, { stmt2 } } =>  { stmt1, stmt2 }
    BufHandle a("A", {1}, kInt);
    StorePtr first = Store::make(a, {0}, 1);
    StorePtr second = Store::make(a, {0}, 0);

    BlockPtr left = alloc<Block>(std::vector<StmtPtr>({first}));
    BlockPtr right = alloc<Block>(std::vector<StmtPtr>({second}));
    BlockPtr outermost = alloc<Block>(std::vector<StmtPtr>({left, right}));

    StmtPtr result = IRSimplifier::simplify(outermost);

    IS_NODE_WITH_NAME(Block, result, flat);
    ASSERT_EQ(flat->nstmts(), 2);

    IS_NODE_WITH_NAME(Store, flat->front(), first_out);
    IS_NODE_WITH_NAME(Store, flat->back(), second_out);

    ASSERT_EQ(first->value(), first_out->value());
    ASSERT_EQ(second->value(), second_out->value());
  }

  {
    // Statements nested at different depths end up side by side.
    // { stmt1 , { { stmt2 } } } =>  { stmt1, stmt2 }
    BufHandle a("A", {1}, kInt);
    StorePtr first = Store::make(a, {0}, 1);
    StorePtr second = Store::make(a, {0}, 0);

    BlockPtr deep = alloc<Block>(std::vector<StmtPtr>({second}));
    BlockPtr wrapper = alloc<Block>(std::vector<StmtPtr>({deep}));
    BlockPtr outermost = alloc<Block>(std::vector<StmtPtr>({first, wrapper}));

    StmtPtr result = IRSimplifier::simplify(outermost);

    IS_NODE_WITH_NAME(Block, result, flat);
    ASSERT_EQ(flat->nstmts(), 2);

    IS_NODE_WITH_NAME(Store, flat->front(), first_out);
    IS_NODE_WITH_NAME(Store, flat->back(), second_out);

    ASSERT_EQ(first->value(), first_out->value());
    ASSERT_EQ(second->value(), second_out->value());
  }

  {
    // Eleven wrapper blocks around an empty block reduce to an empty block.
    StmtPtr nest = alloc<Block>(std::vector<StmtPtr>({}));
    for ([[maybe_unused]] const auto depth : c10::irange(11)) {
      nest = alloc<Block>(std::vector<StmtPtr>({nest}));
    }

    StmtPtr result = IRSimplifier::simplify(nest);
    IS_NODE_WITH_NAME(Block, result, flat);
    ASSERT_EQ(flat->nstmts(), 0);
  }
}
 | 
						|
 | 
						|
// Allocate/Free pairs for buffers with a zero-sized dimension are expected to
// be dropped; pairs for non-empty or dynamically sized buffers are kept.
TEST(Simplify, SimplifyEliminateZeroLengthAlloc) {
  {
    // Zero-length buffer: both the Allocate and the Free go away.
    BufHandle buf("x", {0}, kInt);

    AllocatePtr alloc_stmt = Allocate::make(buf);
    FreePtr free_stmt = Free::make(buf);

    BlockPtr input = alloc<Block>(std::vector<StmtPtr>({alloc_stmt, free_stmt}));
    ASSERT_EQ(input->nstmts(), 2);

    StmtPtr result = IRSimplifier::simplify(input);
    IS_NODE_WITH_NAME(Block, result, output);
    ASSERT_EQ(output->nstmts(), 0);
  }

  {
    // Non-empty buffer: nothing is removed.
    BufHandle buf("x", {2}, kInt);

    AllocatePtr alloc_stmt = Allocate::make(buf);
    FreePtr free_stmt = Free::make(buf);

    BlockPtr input = alloc<Block>(std::vector<StmtPtr>({alloc_stmt, free_stmt}));
    ASSERT_EQ(input->nstmts(), 2);

    StmtPtr result = IRSimplifier::simplify(input);
    IS_NODE_WITH_NAME(Block, result, output);
    ASSERT_EQ(output->nstmts(), 2);
  }

  {
    // Mixed case: only the zero-length buffer "x" is removed, and the
    // surviving Allocate/Free both refer to "y".
    BufHandle empty_buf("x", {0}, kInt);
    BufHandle live_buf("y", {2}, kInt);

    AllocatePtr alloc_empty = Allocate::make(empty_buf);
    AllocatePtr alloc_live = Allocate::make(live_buf);
    FreePtr free_live = Free::make(live_buf);
    FreePtr free_empty = Free::make(empty_buf);

    BlockPtr input = alloc<Block>(std::vector<StmtPtr>(
        {alloc_empty, alloc_live, free_live, free_empty}));
    ASSERT_EQ(input->nstmts(), 4);

    StmtPtr result = IRSimplifier::simplify(input);
    IS_NODE_WITH_NAME(Block, result, output);
    ASSERT_EQ(output->nstmts(), 2);
    IS_NODE_WITH_NAME(Allocate, output->stmts().front(), kept_alloc);
    IS_VAR_WITH_NAME(kept_alloc->buffer_var(), "y");
    IS_NODE_WITH_NAME(Free, output->stmts().back(), kept_free);
    ASSERT_EQ(kept_alloc->buffer_var(), kept_free->buffer_var());
  }

  {
    // Dynamic shape: a buffer sized by a variable is not treated as empty.
    VarHandle z("z", kInt);
    BufHandle empty_buf("x", {0}, kInt);
    BufHandle dyn_buf("y", {z}, kInt);

    AllocatePtr alloc_empty = Allocate::make(empty_buf);
    AllocatePtr alloc_dyn = Allocate::make(dyn_buf);
    FreePtr free_dyn = Free::make(dyn_buf);
    FreePtr free_empty = Free::make(empty_buf);

    BlockPtr input = alloc<Block>(std::vector<StmtPtr>(
        {alloc_empty, alloc_dyn, free_dyn, free_empty}));
    ASSERT_EQ(input->nstmts(), 4);
    StmtPtr result = IRSimplifier::simplify(input);
    IS_NODE_WITH_NAME(Block, result, output);
    ASSERT_EQ(output->nstmts(), 2);
  }
}
 | 
						|
 | 
						|
// Two rand() calls are distinct values, so arithmetic over them must not be
// folded as if both operands were the same expression.
TEST(Simplify, DontSimplifyRand) {
  {
    // rand() + rand() = rand() + rand() NOT 2 * rand().
    auto first = Intrinsics::make(kRand, kInt);
    auto second = Intrinsics::make(kRand, kInt);
    ExprHandle sum_expr = first + second;
    ExprHandle result = IRSimplifier::simplify(sum_expr);
    IS_NODE_WITH_NAME(Add, result.node(), kept_add);
    IS_RAND(kept_add->lhs());
    IS_RAND(kept_add->rhs());
  }

  {
    // rand() - rand() = rand() - rand() NOT 0.
    auto first = Intrinsics::make(kRand, kFloat);
    auto second = Intrinsics::make(kRand, kFloat);
    ExprHandle diff_expr = first - second;
    ExprHandle result = IRSimplifier::simplify(diff_expr);
    IS_NODE_WITH_NAME(Sub, result.node(), kept_sub);
    IS_RAND(kept_sub->lhs());
    IS_RAND(kept_sub->rhs());
  }

  {
    // rand() * rand() = rand() * rand().
    auto first = Intrinsics::make(kRand, kInt);
    auto second = Intrinsics::make(kRand, kInt);
    ExprHandle prod_expr = first * second;
    ExprHandle result = IRSimplifier::simplify(prod_expr);
    IS_NODE_WITH_NAME(Mul, result.node(), kept_mul);
    IS_RAND(kept_mul->lhs());
    IS_RAND(kept_mul->rhs());
  }
}
 | 
						|
 | 
						|
// Checks when a Cond nested directly inside a For is expected to be hoisted
// outside the loop (for(if(...)) => if(for(...))) and when the original
// nesting is expected to be kept.
TEST(Simplify, SimplifyReorderForCond) {
  BufHandle a("A", {4}, kInt);
  BufHandle b("B", {1}, kInt);
  BufHandle c("C", {4}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  {
    // for ( if ( ... ) ) => if ( for ( ... ) ).
    auto guard = CompareSelect::make(j, 10, CompareSelectOperation::kLT);
    auto payload = Store::make(c, {i}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 4, Cond::make(guard, payload, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(Cond, result, hoisted);
    IS_NODE_WITH_NAME(Block, hoisted->true_stmt(), then_block);
    IS_NODE_WITH_NAME(For, then_block->front(), inner_for);
  }

  {
    // Can't reorder if condition is dependent on the loop var.
    auto guard = CompareSelect::make(i, 2, CompareSelectOperation::kEQ);
    auto payload = Store::make(c, {i}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 4, Cond::make(guard, payload, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(For, result, kept_for);
    IS_NODE_WITH_NAME(Cond, kept_for->body()->front(), kept_cond);
  }

  {
    // Can't reorder if condition is dependent on a var that is modified inside
    // the loop.
    auto guard = CompareSelect::make(
        Load::make(c, {0}), 10, CompareSelectOperation::kLT);
    auto payload = Store::make(c, {0}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 4, Cond::make(guard, payload, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(For, result, kept_for);
    IS_NODE_WITH_NAME(Cond, kept_for->body()->front(), kept_cond);
  }

  {
    // Condition based on buffer not referenced in body. Can reorder here.
    auto guard = CompareSelect::make(
        Load::make(b, {0}), 10, CompareSelectOperation::kLT);
    auto payload = Store::make(c, {0}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 4, Cond::make(guard, payload, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(Cond, result, hoisted);
    IS_NODE_WITH_NAME(Block, hoisted->true_stmt(), then_block);
    IS_NODE_WITH_NAME(For, then_block->front(), inner_for);
  }

  {
    // Condition based on buffer read only in body. Can reorder here.
    auto guard = CompareSelect::make(
        Load::make(a, {0}), 10, CompareSelectOperation::kLT);
    auto payload = Store::make(c, {0}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 4, Cond::make(guard, payload, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(Cond, result, hoisted);
    IS_NODE_WITH_NAME(Block, hoisted->true_stmt(), then_block);
    IS_NODE_WITH_NAME(For, then_block->front(), inner_for);
  }

  {
    // Condition depends on Let in the loop. Cannot reorder.
    auto bind_j = Let::make(j, 3);
    auto guarded = Cond::make(
        CompareSelect::make(j, 10, CompareSelectOperation::kLT),
        Store::make(c, {0}, Load::make(a, {i})),
        nullptr);
    auto loop_stmt = For::make(i, 0, 4, Block::make({bind_j, guarded}));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(For, result, kept_for);
    IS_NODE_WITH_NAME(Let, kept_for->body()->front(), kept_let);
    IS_NODE_WITH_NAME(Cond, kept_for->body()->back(), kept_cond);
  }

  {
    // Multi level Ifs where all conditions are distinct. Move BOTH Cond
    // statements outside the loop.
    auto outer_guard = CompareSelect::make(
        Load::make(a, {0}), 10, CompareSelectOperation::kLT);
    auto inner_guard = CompareSelect::make(j, 10, CompareSelectOperation::kEQ);
    auto payload = Store::make(c, {0}, Load::make(a, {i}));
    auto inner_cond = Cond::make(inner_guard, payload, nullptr);
    auto loop_stmt =
        For::make(i, 0, 4, Cond::make(outer_guard, inner_cond, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(Cond, result, hoisted_outer);
    IS_NODE_WITH_NAME(Block, hoisted_outer->true_stmt(), outer_then);
    IS_NODE_WITH_NAME(Cond, outer_then->front(), hoisted_inner);
    IS_NODE_WITH_NAME(Block, hoisted_inner->true_stmt(), inner_then);
    IS_NODE_WITH_NAME(For, inner_then->front(), inner_for);
  }

  {
    // Multi level Ifs where the inner condition does depend on a loop var,
    // reorder only the first Cond.
    auto outer_guard = CompareSelect::make(
        Load::make(a, {0}), 10, CompareSelectOperation::kLT);
    auto inner_guard = CompareSelect::make(i, 3, CompareSelectOperation::kEQ);
    auto payload = Store::make(c, {0}, Load::make(a, {i}));
    auto inner_cond = Cond::make(inner_guard, payload, nullptr);
    auto loop_stmt =
        For::make(i, 0, 4, Cond::make(outer_guard, inner_cond, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(Cond, result, hoisted_outer);
    IS_NODE_WITH_NAME(Block, hoisted_outer->true_stmt(), outer_then);
    IS_NODE_WITH_NAME(For, outer_then->front(), inner_for);
    IS_NODE_WITH_NAME(Block, inner_for->body(), for_body);
    IS_NODE_WITH_NAME(Cond, for_body->front(), kept_inner);
  }

  {
    // Don't reorder if there's an else block of the Cond.
    // We could, but is it much better?
    auto guard = CompareSelect::make(j, 10, CompareSelectOperation::kLT);
    auto then_store = Store::make(c, {0}, Load::make(a, {i}));
    auto else_store = Store::make(c, {0}, 0);
    auto loop_stmt =
        For::make(i, 0, 4, Cond::make(guard, then_store, else_store));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(For, result, kept_for);
    IS_NODE_WITH_NAME(Cond, kept_for->body()->front(), kept_cond);
  }

  {
    // Condition uses distinct region of Tensor.
    // We could reorder here with better analysis, but we don't. Included for
    // completeness.
    auto guard = CompareSelect::make(
        Load::make(c, {0}), 10, CompareSelectOperation::kLT);
    auto payload = Store::make(c, {1}, Load::make(a, {i}));
    auto loop_stmt = For::make(i, 0, 4, Cond::make(guard, payload, nullptr));

    StmtPtr result = IRSimplifier::simplify(loop_stmt);
    IS_NODE_WITH_NAME(For, result, kept_for);
    IS_NODE_WITH_NAME(Cond, kept_for->body()->front(), kept_cond);
  }
}
 | 
						|
 | 
						|
TEST(Simplify, SimplifyFuseConditions) {
 | 
						|
  BufHandle a("A", {2}, kInt);
 | 
						|
  BufHandle b("B", {2}, kInt);
 | 
						|
  VarHandle i("i", kInt);
 | 
						|
  VarHandle j("j", kInt);
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse since the conditions are identical.
 | 
						|
    // if (A) { X }; if (A) { Y }; => if (A) { X; Y }
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can't fuse, conditions are not identical in lhs (i != j).
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(j, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
 | 
						|
    ASSERT_EQ(true_stmt1->nstmts(), 1);
 | 
						|
    ASSERT_EQ(true_stmt2->nstmts(), 1);
 | 
						|
 | 
						|
    ASSERT_EQ(cond1->false_stmt(), nullptr);
 | 
						|
    ASSERT_EQ(cond2->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
  {
 | 
						|
    // Can't fuse, conditions are not identical in rhs (10 != 11).
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(i, 11, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
 | 
						|
    ASSERT_EQ(true_stmt1->nstmts(), 1);
 | 
						|
    ASSERT_EQ(true_stmt2->nstmts(), 1);
 | 
						|
 | 
						|
    ASSERT_EQ(cond1->false_stmt(), nullptr);
 | 
						|
    ASSERT_EQ(cond2->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can't fuse, conditions are not identical in operation (LT vs GT).
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kGT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
 | 
						|
    ASSERT_EQ(true_stmt1->nstmts(), 1);
 | 
						|
    ASSERT_EQ(true_stmt2->nstmts(), 1);
 | 
						|
 | 
						|
    ASSERT_EQ(cond1->false_stmt(), nullptr);
 | 
						|
    ASSERT_EQ(cond2->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can't fuse, CompareSelect results are different.
 | 
						|
    // Actually we totally could if we normalized CompareSelect results, but
 | 
						|
    // TODO for later.
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
 | 
						|
    ASSERT_EQ(true_stmt1->nstmts(), 1);
 | 
						|
    ASSERT_EQ(true_stmt2->nstmts(), 1);
 | 
						|
 | 
						|
    ASSERT_EQ(cond1->false_stmt(), nullptr);
 | 
						|
    ASSERT_EQ(cond2->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse with false stmt only.
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             nullptr,
 | 
						|
             Store::make(a, {0}, i)),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             nullptr,
 | 
						|
             Store::make(a, {1}, i))});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt);
 | 
						|
    ASSERT_EQ(false_stmt->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond->true_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse with both true and false stmt.
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             Store::make(b, {0}, i)),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             Store::make(b, {1}, i))});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
 | 
						|
    ASSERT_EQ(false_stmt->nstmts(), 2);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse with mismatched true / false stmt existing
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
             nullptr,
 | 
						|
             Store::make(b, {1}, i))});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
 | 
						|
    ASSERT_EQ(false_stmt->nstmts(), 1);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse partial block contents, ie when there are non fused stmts before
 | 
						|
    // and after.
 | 
						|
    // before:
 | 
						|
    // if (j < 10) { A[0] = j; }
 | 
						|
    // if (i < 10) { A[0] = i; }
 | 
						|
    // if (i < 10) { A[1] = i; }
 | 
						|
    // if (i < 11) { A[1] = j; }
 | 
						|
    //
 | 
						|
    // after:
 | 
						|
    //
 | 
						|
    // if (j < 10) { A[0] = j; }
 | 
						|
    // if (i < 10) {
 | 
						|
    //   A[0] = i;
 | 
						|
    //   A[1] = i;
 | 
						|
    // }
 | 
						|
    // if (i < 11) { A[1] = j; }
 | 
						|
 | 
						|
    auto body = Block::make({
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(j, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {0}, j),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {0}, i),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {1}, i),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 11, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {1}, j),
 | 
						|
            nullptr),
 | 
						|
    });
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 3);
 | 
						|
    auto it = block->begin();
 | 
						|
    it++;
 | 
						|
    IS_NODE_WITH_NAME(Cond, *it, cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse longer sequences of identical conditions.
 | 
						|
    auto body = Block::make({
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {0}, j),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {0}, i),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {1}, i),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {1}, j),
 | 
						|
            nullptr),
 | 
						|
    });
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 4);
 | 
						|
    ASSERT_EQ(cond->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can't fuse through a non condition.
 | 
						|
    auto body = Block::make({
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {0}, j),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {0}, i),
 | 
						|
            nullptr),
 | 
						|
        Store::make(b, {1}, i + j),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {1}, i),
 | 
						|
            nullptr),
 | 
						|
        Cond::make(
 | 
						|
            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
 | 
						|
            Store::make(a, {1}, j),
 | 
						|
            nullptr),
 | 
						|
    });
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 3);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond->false_stmt(), nullptr);
 | 
						|
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2);
 | 
						|
    ASSERT_EQ(true_stmt2->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond2->false_stmt(), nullptr);
 | 
						|
 | 
						|
    auto it = block->begin();
 | 
						|
    it++;
 | 
						|
    IS_NODE_WITH_NAME(Store, *it, middle);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse if the conditions simplify to the same thing.
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(
 | 
						|
             CompareSelect::make(
 | 
						|
                 i * 2,
 | 
						|
                 ExprHandle(87) % ExprHandle(11),
 | 
						|
                 CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {0}, i),
 | 
						|
             nullptr),
 | 
						|
         Cond::make(
 | 
						|
             CompareSelect::make(
 | 
						|
                 i * 2,
 | 
						|
                 ExprHandle(300) / ExprHandle(30),
 | 
						|
                 CompareSelectOperation::kLT),
 | 
						|
             Store::make(a, {1}, i),
 | 
						|
             nullptr)});
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Can fuse non-CompareSelects.
 | 
						|
    // if (i) { X } if (i) { Y } => if (i) { X; Y }
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(i, Store::make(a, {0}, i), nullptr),
 | 
						|
         Cond::make(i, Store::make(a, {1}, i), nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
 | 
						|
    ASSERT_EQ(true_stmt->nstmts(), 2);
 | 
						|
    ASSERT_EQ(cond->false_stmt(), nullptr);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Sanity check wont fuse different non-CompareSelects.
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(i, Store::make(a, {0}, i), nullptr),
 | 
						|
         Cond::make(j, Store::make(a, {1}, i), nullptr)});
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
 | 
						|
    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Sanity check constant condition elimination still occurs when merging is
 | 
						|
    // possible.
 | 
						|
    auto body = Block::make(
 | 
						|
        {Cond::make(1, Store::make(a, {0}, i), nullptr),
 | 
						|
         Cond::make(1, Store::make(a, {1}, i), nullptr)});
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Block, simplified, block);
 | 
						|
    ASSERT_EQ(block->nstmts(), 2);
 | 
						|
    IS_NODE_WITH_NAME(Store, block->front(), store1);
 | 
						|
    IS_NODE_WITH_NAME(Store, block->back(), store2);
 | 
						|
  }
 | 
						|
 | 
						|
  {
 | 
						|
    // Sanity check for-cond reordering occurs after fusing.
 | 
						|
    auto body = For::make(
 | 
						|
        i,
 | 
						|
        0,
 | 
						|
        4,
 | 
						|
        Block::make(
 | 
						|
            {Cond::make(
 | 
						|
                 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
 | 
						|
                 Store::make(a, {1}, Load::make(b, {0})),
 | 
						|
                 nullptr),
 | 
						|
             Cond::make(
 | 
						|
                 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
 | 
						|
                 Store::make(a, {2}, Load::make(b, {0})),
 | 
						|
                 nullptr)}));
 | 
						|
 | 
						|
    StmtPtr simplified = IRSimplifier::simplify(body);
 | 
						|
    IS_NODE_WITH_NAME(Cond, simplified, cond);
 | 
						|
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
 | 
						|
    IS_NODE_WITH_NAME(For, true_block->front(), loop);
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
// Exercises how IRSimplifier handles SyncThreads statements inside a Block.
// The scenarios below expect that:
//   * a run of adjacent SyncThreads between two other statements collapses
//     into a single SyncThreads, and
//   * SyncThreads at the very start or very end of the Block are removed.
TEST(Simplify, SimplifySyncThreads) {
  BufHandle a("A", {4}, kInt);

  {
    // Merge two inner SyncThreads.
    auto body = Block::make(
        // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
        {Store::make(a, {0}, 1),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0)});
    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    // Store, one SyncThreads, Store.
    ASSERT_EQ(block->nstmts(), 3);
    auto it = block->begin();
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
  }

  {
    // Eliminate outer SyncThreads.
    auto body = Block::make(
        {alloc<SyncThreads>(), Store::make(a, {1}, 0), alloc<SyncThreads>()});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    // Only the Store is expected to remain.
    ASSERT_EQ(block->nstmts(), 1);
    auto it = block->begin();
    IS_NODE(Store, *it);
  }

  {
    // Merge many inner SyncThreads.
    auto body = Block::make(
        {Store::make(a, {0}, 1),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0)});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    // Five adjacent barriers collapse to one.
    ASSERT_EQ(block->nstmts(), 3);
    auto it = block->begin();
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
  }

  {
    // Merge multiple outer SyncThreads.
    auto body = Block::make(
        {alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>()});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    // Leading and trailing runs are both dropped entirely.
    ASSERT_EQ(block->nstmts(), 1);
    auto it = block->begin();
    IS_NODE(Store, *it);
  }

  {
    // Merge multiple sections.
    auto body = Block::make(
        {Store::make(a, {0}, 1),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0),
         Store::make(a, {2}, 0),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {3}, 0)});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    // Each run of barriers is merged independently:
    // Store, Sync, Store, Store, Sync, Store.
    ASSERT_EQ(block->nstmts(), 6);
    auto it = block->begin();
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
  }
}
 | 
						|
 | 
						|
// Subtracting a Broadcast from a Ramp with the same number of lanes is
// expected to simplify to a single Ramp: here Ramp(0, 6, 4) - Broadcast(-5, 4)
// should become a Ramp with base 5, stride 6 and 4 lanes.
TEST(Simplify, SimplifyRampSubBroadcast) {
  int num_lanes = 4;
  ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes);
  ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes);
  ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast);
  RampPtr newRamp = simplified.AsNode<Ramp>();
  // Fail the test cleanly instead of dereferencing a null pointer below if
  // the simplified expression is not a Ramp.
  ASSERT_NE(newRamp, nullptr);
  IS_NODE_WITH_NAME(IntImm, newRamp->base(), base);
  ASSERT_EQ(base->value(), 5);
  IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride);
  ASSERT_EQ(stride->value(), 6);
  ASSERT_EQ(newRamp->lanes(), num_lanes);
}
 | 
						|
 | 
						|
// Regression test for simplifying a sum of multi-lane Broadcast terms.
// The simplified expression is evaluated rather than pattern-matched, and
// every lane of the result must equal 2.
TEST(Simplify, SimplifyBroadcastTermExpander) {
  const int lane_count = 8;
  ExprHandle zeros = Broadcast::make(ExprHandle(0), lane_count);
  ExprHandle ones = Broadcast::make(ExprHandle(1), lane_count);
  ExprHandle twos = Broadcast::make(ExprHandle(2), lane_count);

  // NB: The middle term must be one that is not simplified away so that the
  // relevant path in TermExpander::mutate is taken. The two `ones` terms get
  // combined into 2 * ones, after which the scalar 2 has to be made
  // multi-lane.
  ExprHandle unsimplified = ones + (zeros / twos) + ones;
  ExprHandle result_expr = IRSimplifier::simplify(unsimplified);

  // The simplified form is not fully reduced today, so matching its structure
  // would be brittle. Store it into a buffer and check the values instead.
  BufHandle buf("buf", {lane_count}, kInt);
  auto write_all_lanes =
      Store::make(buf, {Ramp::make(0, 1, lane_count)}, result_expr);
  SimpleIREvaluator evaluator(write_all_lanes, {buf});
  std::vector<int> lanes(lane_count);
  evaluator(lanes);

  for (const int lane_value : lanes) {
    ASSERT_EQ(lane_value, 2);
  }
}
 | 
						|
 | 
						|
TEST(Simplify, CompareSelectLoopBounds) {
 | 
						|
  constexpr int N = 8;
 | 
						|
  BufHandle b("b", {N}, kFloat);
 | 
						|
  VarHandle n("n", kInt);
 | 
						|
  VarHandle m("m", kInt);
 | 
						|
  VarHandle var_N("var_N", kInt);
 | 
						|
  VarHandle var_M("var_M", kInt);
 | 
						|
 | 
						|
  auto test_case_fn = [](const VarHandle& n,
 | 
						|
                         const BufHandle& b,
 | 
						|
                         const ExprHandle& start,
 | 
						|
                         const ExprHandle& stop,
 | 
						|
                         const int& cmp_val,
 | 
						|
                         const CompareSelectOperation& cmp_op,
 | 
						|
                         const std::string& check_string) {
 | 
						|
    StmtPtr s = For::make(
 | 
						|
        n,
 | 
						|
        start,
 | 
						|
        stop,
 | 
						|
        b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op)));
 | 
						|
    s = IRSimplifier::simplify(s);
 | 
						|
    std::ostringstream oss;
 | 
						|
    oss << *s;
 | 
						|
    std::string target_string = "# CHECK: ";
 | 
						|
    target_string += check_string;
 | 
						|
    torch::jit::testing::FileCheck().run(target_string, oss.str());
 | 
						|
  };
 | 
						|
 | 
						|
  auto test_case_nest_loops_fn = [](const VarHandle& n,
 | 
						|
                                    const VarHandle& m,
 | 
						|
                                    const BufHandle& b,
 | 
						|
                                    const ExprHandle& n_start,
 | 
						|
                                    const ExprHandle& n_stop,
 | 
						|
                                    const ExprHandle& m_start,
 | 
						|
                                    const ExprHandle& m_stop,
 | 
						|
                                    const CompareSelectOperation& cmp_op,
 | 
						|
                                    const std::string& check_string) {
 | 
						|
    StmtPtr s = For::make(
 | 
						|
        m,
 | 
						|
        m_start,
 | 
						|
        m_stop,
 | 
						|
        b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op)));
 | 
						|
    StmtPtr root_s = For::make(n, n_start, n_stop, s);
 | 
						|
    root_s = IRSimplifier::simplify(root_s);
 | 
						|
    std::ostringstream oss;
 | 
						|
    oss << *root_s;
 | 
						|
    std::string target_string = "# CHECK: ";
 | 
						|
    target_string += check_string;
 | 
						|
    torch::jit::testing::FileCheck().run(target_string, oss.str());
 | 
						|
  };
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n < 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n <= 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n <= 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n <= 0 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n < 0 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n < 8 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n <= 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n <= 8 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n < 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n < 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 0 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n >= 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n >= 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n >= 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 5 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 5 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n >= 5 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n >= 5 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n > 8 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n >= 8 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, 2)) {
 | 
						|
  //     b[n] = n == 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, 2)) {
 | 
						|
  //     b[1] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n == 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n == 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n == 0 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n == 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n == 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n == 8 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 1 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 7 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 5 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 5 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 0 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = n != 8 ? 0.f : 1.f;
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(1, N)) {
 | 
						|
  //     b[n] = 0.f;
 | 
						|
  //   }
 | 
						|
  test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(10, 20)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n != m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(10, 20)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = 0.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      var_M + 30,
 | 
						|
      var_M + 40,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 20)) {
 | 
						|
  //       b[n, m] = (n != m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 20)) {
 | 
						|
  //       b[n, m] = 0.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_M + 10,
 | 
						|
      var_M + 20,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 31)) {
 | 
						|
  //       b[n, m] = (n != m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 31)) {
 | 
						|
  //       b[n, m] = (n != m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_M + 10,
 | 
						|
      var_M + 31,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(10, 31)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n != m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(10, 31)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n != m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      var_M + 30,
 | 
						|
      var_M + 40,
 | 
						|
      kNE,
 | 
						|
      "b[n, m] = n!=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(10, 20)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n < m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(10, 20)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = 0.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      kLT,
 | 
						|
      "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      var_M + 30,
 | 
						|
      var_M + 40,
 | 
						|
      kLT,
 | 
						|
      "b[n, m] = n<m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 31)) {
 | 
						|
  //       b[n, m] = (n < m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 31)) {
 | 
						|
  //       b[n, m] = 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kLT, "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      kLT,
 | 
						|
      "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_M + 10,
 | 
						|
      var_M + 31,
 | 
						|
      kLT,
 | 
						|
      "b[n, m] = n<m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 20)) {
 | 
						|
  //       b[n, m] = (n > m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 20)) {
 | 
						|
  //       b[n, m] = 0.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      kGT,
 | 
						|
      "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_M + 10,
 | 
						|
      var_M + 20,
 | 
						|
      kGT,
 | 
						|
      "b[n, m] = n>m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(10, 31)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n > m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(10, 31)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      kGT,
 | 
						|
      "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      var_M + 30,
 | 
						|
      var_M + 40,
 | 
						|
      kGT,
 | 
						|
      "b[n, m] = n>m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 31)) {
 | 
						|
  //       b[n, m] = (n >= m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 31)) {
 | 
						|
  //       b[n, m] = 0.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      kGE,
 | 
						|
      "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_M + 10,
 | 
						|
      var_M + 31,
 | 
						|
      kGE,
 | 
						|
      "b[n, m] = n>=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(10, 20)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n >= m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(10, 20)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      kGE,
 | 
						|
      "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      var_M + 30,
 | 
						|
      var_M + 40,
 | 
						|
      kGE,
 | 
						|
      "b[n, m] = n>=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(10, 31)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = (n <= m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(10, 31)) {
 | 
						|
  //     for(const auto m : c10::irange(30, 40)) {
 | 
						|
  //       b[n, m] = 0.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      kLE,
 | 
						|
      "b[n, m] = 0.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 31,
 | 
						|
      var_M + 30,
 | 
						|
      var_M + 40,
 | 
						|
      kLE,
 | 
						|
      "b[n, m] = n<=m ? 0.f : 1.f;");
 | 
						|
 | 
						|
  // Before:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 20)) {
 | 
						|
  //       b[n, m] = (n <= m) ? 0.f : 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  // After:
 | 
						|
  //   for (const auto n : c10::irange(30, 40)) {
 | 
						|
  //     for(const auto m : c10::irange(10, 20)) {
 | 
						|
  //       b[n, m] = 1.f;
 | 
						|
  //     }
 | 
						|
  //   }
 | 
						|
  test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_N + 10,
 | 
						|
      var_N + 20,
 | 
						|
      kLE,
 | 
						|
      "b[n, m] = 1.f;");
 | 
						|
  test_case_nest_loops_fn(
 | 
						|
      n,
 | 
						|
      m,
 | 
						|
      b,
 | 
						|
      var_N + 30,
 | 
						|
      var_N + 40,
 | 
						|
      var_M + 10,
 | 
						|
      var_M + 20,
 | 
						|
      kLE,
 | 
						|
      "b[n, m] = n<=m ? 0.f : 1.f;");
 | 
						|
}
 | 
						|
 | 
						|
TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) {
  // The loop variable starts at 1, so the comparison `n < 1` is false for
  // every iteration and the CompareSelect should fold to its second value.
  //
  // Input:
  //   for (int n = 1; n < 8; n++) {
  //     b[n] = n < 1 ? 0.f : 1.f;
  //   }
  // Expected after simplification:
  //   for (int n = 1; n < 8; n++) {
  //     b[n] = 1.f;
  //   }
  constexpr int kExtent = 8;
  BufHandle buf("b", {kExtent}, kFloat);
  VarHandle idx("n", kInt);

  // Build the statement inside-out: select -> store -> loop.
  ExprHandle select = CompareSelect::make(idx, 1, 0.f, 1.0f, kLT);
  StmtPtr storeStmt = buf.store({idx}, select);
  StmtPtr loop = For::make(idx, 1, kExtent, storeStmt);

  StmtPtr simplified = IRSimplifier::simplify(loop);

  // Print the simplified IR and verify the store no longer has a condition.
  std::ostringstream printed;
  printed << *simplified;
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: b[n] = 1.f;
)IR",
      printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, IfThenCondAlwaysInLoopBounds) {
  // Same scenario as the CompareSelect variant, but the condition feeds an
  // IfThenElse: with n starting at 1, `n < 1` never holds, so only the
  // else-value should remain.
  //
  // Input:
  //   for (int n = 1; n < 8; n++) {
  //     b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f);
  //   }
  // Expected after simplification:
  //   for (int n = 1; n < 8; n++) {
  //     b[n] = 1.f;
  //   }
  constexpr int kExtent = 8;
  BufHandle buf("b", {kExtent}, kFloat);
  VarHandle idx("n", kInt);

  // Build the statement inside-out: condition -> value -> store -> loop.
  ExprHandle belowStart = idx < 1;
  ExprHandle value = IfThenElse::make(belowStart, 0.f, 1.0f);
  StmtPtr storeStmt = buf.store({idx}, value);
  StmtPtr loop = For::make(idx, 1, kExtent, storeStmt);

  StmtPtr simplified = IRSimplifier::simplify(loop);

  // Print the simplified IR and verify the store is unconditional.
  std::ostringstream printed;
  printed << *simplified;
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: b[n] = 1.f;
)IR",
      printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) {
  // Models the unpadded (interior) region of a conv2d. Every clause of the
  // chained boundary check is decided by the loop ranges alone, so the
  // simplifier should drop the whole conditional.
  //
  // Input:
  //   for (int i = 1; i < 7; i++) {
  //     for (int j = 1; j < 7; j++) {
  //       b[i, j] = IfThenElse(
  //         j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f);
  //     }
  //   }
  // Expected after simplification:
  //   for (int i = 1; i < 7; i++) {
  //     for (int j = 1; j < 7; j++) {
  //       b[i, j] = 1.f;
  //     }
  //   }
  constexpr int kSize = 8;
  constexpr int kLimit = kSize - 1;
  BufHandle out("b", {kSize, kSize}, kFloat);
  VarHandle row("i", kInt);
  VarHandle col("j", kInt);

  // Chain the four boundary clauses; each one wraps the previous clause as
  // its fall-through value, innermost first.
  ExprHandle rowTooLow = CompareSelect::make(row, 1, kLT);
  ExprHandle colTooLow = CompareSelect::make(col, 1, 1, rowTooLow, kLT);
  ExprHandle rowTooHigh = CompareSelect::make(row, kLimit, 1, colTooLow, kGE);
  ExprHandle outOfBounds =
      CompareSelect::make(col, kLimit, 1, rowTooHigh, kGE);

  // Store nested in the j loop, which is nested in the i loop.
  StmtPtr storeStmt =
      out.store({row, col}, IfThenElse::make(outOfBounds, 0.f, 1.0f));
  StmtPtr innerLoop = For::make(col, 1, kLimit, storeStmt);
  StmtPtr outerLoop = For::make(row, 1, kLimit, innerLoop);

  StmtPtr simplified = IRSimplifier::simplify(outerLoop);

  // Print the simplified IR and verify the store is unconditional.
  std::ostringstream printed;
  printed << *simplified;
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: b[i, j] = 1.f;
)IR",
      printed.str());
}
 | 
						|
 | 
						|
TEST(Simplify, DISABLED_SimplifyLoopBounds) {
  // This test mimics the padded region of a conv2d.  We want to adjust the
  // loop bounds such that the condition will be always met.  Note that this
  // could be solved by peeling, and applying the range-based conditional
  // simplification in the previous tests.
  //
  // The DISABLED_ prefix makes GoogleTest skip this test by default.
  //
  // Before:
  //   for (int i = 0; i < 3; i++) {
  //     for (int j = 0; j < 3; j++) {
  //       b[i, j] = (b[i, j]) + (IfThenElse(
  //         j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j]));
  //     }
  //   }
  // After:
  //   for (int i = 1; i < 3; i++) {
  //     for (int j = 1; j < 3; j++) {
  //       b[i, j] = (b[i, j]) + (a[i, j]);
  //     }
  //   }
  constexpr int N = 8;
  constexpr int K = 3;
  BufHandle a("a", {N, N}, kFloat);
  BufHandle b("b", {N, N}, kFloat);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  // Chained boundary check: each clause wraps the previous one as its
  // fall-through value, innermost first.
  auto csel = CompareSelect::make(i, 1, kLT);
  csel = CompareSelect::make(j, 1, 1, csel, kLT);
  csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
  csel = CompareSelect::make(j, N - 1, 1, csel, kGE);

  // Accumulate a[i, j] into b[i, j], substituting 0.f where the check fires.
  StmtPtr s = b.store(
      {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j})));
  s = For::make(j, 0, K, s);
  s = For::make(i, 0, K, s);
  s = IRSimplifier::simplify(s);

  std::ostringstream oss;
  oss << *s;
  // The patterns must use the IR printer's loop syntax (`for (int v = ...`),
  // not C++ range-for syntax, otherwise they can never match the output.
  torch::jit::testing::FileCheck().run(
      R"IR(
# CHECK: for (int i = 1; i < 3; i++) {
# CHECK: for (int j = 1; j < 3; j++) {
# CHECK-NOT: IfThenElse
)IR",
      oss.str());
}
 | 
						|
 | 
						|
} // namespace jit
 | 
						|
} // namespace torch
 |