[TensorExpr] Extend arithmetic simplifier to work with multi variable expressions (#35127)

Summary:
A new version of the IR simplifier used by the jit/tensorexpr fuser. This is capable of simplifying expressions containing (shock) multiple variables, eg:

```(m * (1 * n_1) + (n  + 1)) - (m *  (1 * n_1) + n) => 1```

Similar to the previous IR Simplifier it uses a two stage approach:
1. Traverse the tree combining subtree's of commutable operations in to a flat structure. In this implementation we have two intermediate Exprs: Term (expressing products of sub expressions) and Polynomial (expressing sums of sub expressions).
2. Traverse the tree expanding Term's and Polynomials into their component operators.

Using the example above we execute with a process like this to simplify:
```
   (m * (1 * n_1) + (n  + 1)) - (m *  (1 * n_1) + n)
# Using PolynomialTransformer:
=> Sub(Add(Mul(m, Mul(1, n_1)), Add(n, 1)), Add(Mul(m, Mul(1, n_1)), n))
=> Sub(Polynomial(Term(m, n_1), n, 1), Polynomial(Term(m, n_1), n))
=> Polynomial(Term(m, n_1), Term(-1, m, n_1), n, -n, 1)
=> Polynomial(1)
# Using TermExpander
=> 1
```

The IRSimplifier supports arithmetic simplifications of operators Add, Sub and Mul and constant folding of all binary Exprs and Intrinsics, but does not attempt expansion of multiplication of Polynomials to the canonical form since that generally leads to less efficient representations. It will do scalar factorization if it results in removal of operators, and will merge chains of multilane primitives (such as Broadcast and Ramp) down into a single operator. The ir_simplifier unit tests are a short tour of its capabilities.

The existing simplifier has a bug where it will sometimes reorder operations on floating point types which are not associative. This causes (at least) the pyhpc equation_of_state benchmark to produce incorrect results. I have fixed that issue in this version and verified that that benchmark produces the same results with and without the simplifier.

Tests: all cpp & py tensorexpr tests, and pyphc benchmark:
```
benchmarks.equation_of_state
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max   Δ
------------------------------------------------------------------------------------------------------------------
   4,194,304  pytorch           10     0.246     0.002     0.243     0.245     0.246     0.248     0.250     1.000
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35127

Differential Revision: D20624571

Pulled By: nickgg

fbshipit-source-id: e49049377beee69e02dcf26eb922bef1447ae776
This commit is contained in:
Nick Gibson
2020-03-24 14:10:58 -07:00
committed by Facebook GitHub Bot
parent 2dc2933358
commit fce67800f4
19 changed files with 2450 additions and 715 deletions

View File

@ -481,6 +481,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_simplifier.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp

View File

@ -11,6 +11,40 @@ namespace jit {
using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
#define IS_NODE(T, node) \
{ \
auto* node_ = dynamic_cast<const T*>(node); \
EXPECT_NE(nullptr, node_) << "Expected node to be " #T; \
}
#define IS_NODE_WITH_NAME(T, node, name) \
auto* name = dynamic_cast<const T*>(node); \
EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T;
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \
const T* name = nullptr; \
{ \
auto* node_ = dynamic_cast<const Cast*>(node); \
EXPECT_NE(nullptr, node_); \
EXPECT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
name = dynamic_cast<const T*>(node_->src_value()); \
} \
EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T;
#define IS_IMM_WITH_VAL(T, node, val) \
{ \
auto* node_ = dynamic_cast<const T##Imm*>(node); \
EXPECT_NE(nullptr, node_) << "Expected node to be " #T "Imm"; \
EXPECT_EQ(node_->value(), val) << "Expected Imm to be " << val; \
}
#define IS_VAR_WITH_NAME(node, name) \
{ \
auto* node_ = dynamic_cast<const Var*>(node); \
EXPECT_NE(nullptr, node_) << "Expected node to be Var"; \
EXPECT_EQ(node_->name_hint(), name) << "Expected var to be " #name; \
}
void testConstantFoldSimple() {
KernelScope kernel_scope;
ExprHandle a(2.0f);
@ -133,17 +167,33 @@ void testConstantFoldIntrinsics() {
void testConstantFoldWithVar() {
KernelScope kernel_scope;
VarHandle x("x", kFloat);
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
{
VarHandle x("x", kInt);
ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const IntImm*>(root->lhs()), nullptr);
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
ExprHandle result = Let::make(x, ExprHandle(3), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<int>(), 3 * (2 + 4));
}
{
VarHandle x("x", kFloat);
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
}
}
void testUnFoldableExpr() {
@ -228,34 +278,22 @@ void testHashEquivalenceAfterFolding() {
ExprHandle a(2.0f);
ExprHandle b(3.0f);
ExprHandle c(5.0f);
ExprHandle f = ((a + b) * x) * (c * x);
const Mul* root = f.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ExprHandle f1 = ((a + b) * x);
ExprHandle f2 = (c * x);
HashProvider hasher;
auto hash_f = hasher.hash(f.node());
auto hash_l = hasher.hash(root->lhs());
auto hash_r = hasher.hash(root->rhs());
auto hash_l = hasher.hash(f1.node());
auto hash_r = hasher.hash(f2.node());
// Root not equal to either branch, and branches not equal.
EXPECT_NE(hash_f, hash_l);
EXPECT_NE(hash_f, hash_r);
EXPECT_NE(hash_l, hash_r);
ExprHandle newF = IRSimplifier::simplify(f);
ExprHandle ff1 = IRSimplifier::simplify(f1);
ExprHandle ff2 = IRSimplifier::simplify(f2);
const Mul* newRoot = newF.AsNode<Mul>();
EXPECT_NE(newRoot, nullptr);
auto hash_l_n = hasher.hash(ff1.node());
auto hash_r_n = hasher.hash(ff2.node());
auto hash_f_n = hasher.hash(newF.node());
auto hash_l_n = hasher.hash(newRoot->lhs());
auto hash_r_n = hasher.hash(newRoot->rhs());
// Root not equal to either branch.
EXPECT_NE(hash_f_n, hash_l_n);
EXPECT_NE(hash_f_n, hash_r_n);
// but branches are now equal.
// branches are now equal.
EXPECT_EQ(hash_l_n, hash_r_n);
}
@ -343,11 +381,16 @@ void testHashLargeExpression() {
EXPECT_NE(hash_t, hash_f);
}
/// (2.f + x) + 4.f => x + 6.f
/// (2 + x) + 4 => x + 6
void testSimplifyAdd() {
KernelScope kernel_scope;
VarHandle x("x", kFloat);
ExprHandle body = (ExprHandle(2.f) + x) + ExprHandle(4.f);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle m("m", kInt);
VarHandle n("n", kInt);
VarHandle n_1("n_1", kInt);
ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
@ -355,51 +398,43 @@ void testSimplifyAdd() {
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 6.f);
}
/// (2.f - x) - 4.f => -2.f - x
/// (2 - x) - 4 => -2 - x
void testSimplifySub() {
KernelScope kernel_scope;
VarHandle x("x", kFloat);
ExprHandle body = (ExprHandle(2.f) - x) - ExprHandle(4.f);
VarHandle x("x", kInt);
ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -2.f);
EXPECT_EQ(lhs->value(), -2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->name_hint(), "x");
}
/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f)
/// 2 * (1 - x) - 4 => -2 * (x + 3)
void testSimplifyMultiLayer() {
KernelScope kernel_scope;
VarHandle x("x", kFloat);
ExprHandle body = ExprHandle(2.f) * ((ExprHandle(1.f) - x) - ExprHandle(4.f));
VarHandle x("x", kInt);
ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4));
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -6.f);
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
const FloatImm* mulRhs = dynamic_cast<const FloatImm*>(rhs->rhs());
EXPECT_NE(mulRhs, nullptr);
EXPECT_EQ(mulRhs->value(), 2.f);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_IMM_WITH_VAL(Int, add->rhs(), 3);
}
/// 2 * (3 * x) - (x * 4) => x * 2
/// 2 * (3 * x) - (x * 4) => 2 * x
void testSimplifyMultiTerm() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
@ -409,30 +444,30 @@ void testSimplifyMultiTerm() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
EXPECT_EQ(lhs->value(), 2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 2);
EXPECT_EQ(rhs->name_hint(), "x");
}
/// 2 * (3 * (f)x) - (x * 4) => x * 2.f
/// 2 * (3 * (long)x) - (x * 4) => 2 * x
void testSimplifyCasts() {
KernelScope kernel_scope;
VarHandle x("x", kFloat);
VarHandle x("x", kLong);
ExprHandle body =
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
const LongImm* lhs = dynamic_cast<const LongImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_EQ(lhs->value(), 2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 2);
EXPECT_EQ(rhs->name_hint(), "x");
}
/// (x + 0) * 1 => x
@ -452,20 +487,39 @@ void testSimplifyMultiVar() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = y * 24 + x * 34;
ExprHandle body = x * 24 + y * 34;
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
EXPECT_NE(lhs, nullptr);
const Var* varY = dynamic_cast<const Var*>(lhs->lhs());
EXPECT_EQ(varY->name_hint(), "y");
const Var* varX = dynamic_cast<const Var*>(lhs->rhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "y");
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
const Var* varY = dynamic_cast<const Var*>(rhs->rhs());
EXPECT_NE(varY, nullptr);
EXPECT_EQ(varY->name_hint(), "x");
}
// x + 2 + y => x + y + 2
void testSimplifyReorderings() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = x + 2 + y;
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
IS_IMM_WITH_VAL(Int, root->rhs(), 2);
}
/// y + x * 0 => y
@ -476,9 +530,621 @@ void testSimplifyEliminatesVar() {
ExprHandle body = y + x * ExprHandle(0);
ExprHandle simplified = IRSimplifier::simplify(body);
const Var* root = simplified.AsNode<Var>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(root->name_hint(), "y");
IS_VAR_WITH_NAME(simplified.node(), "y");
}
void testSimplifyAdds() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x + y) + (x + y) => 2 * (x + y)
ExprHandle body = (x + y) + (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
IS_NODE_WITH_NAME(Add, root->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// (x * y) + (x * y) => 2 * (x * y)
ExprHandle body = (x * y) + (x * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
IS_NODE_WITH_NAME(Mul, root->rhs(), mul);
IS_VAR_WITH_NAME(mul->lhs(), "x");
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
{
// (x - y) + (x - y) => -2 * (y - x)
ExprHandle body = (x - y) + (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "y");
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
}
void testSimplifyMuls() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x + y) * (x + y) => (x + y) * (x + y)
// We don't attempt to simplify mulitplication of polynomials since the
// result is only very rarely more efficient.
ExprHandle body = (x + y) * (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
IS_VAR_WITH_NAME(lhs->lhs(), "x");
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
{
// x * y * x * y => x * x * y * y
// These get reordered only.
ExprHandle body = x * y * x * y;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul1);
IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2);
IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3);
IS_VAR_WITH_NAME(mul1->rhs(), "y");
IS_VAR_WITH_NAME(mul2->rhs(), "y");
IS_VAR_WITH_NAME(mul3->lhs(), "x");
IS_VAR_WITH_NAME(mul3->rhs(), "x");
}
{
// (x - y) * (x - y) => (x - y) * (x - y)
// As with Add we don't attempt simplification of this.
ExprHandle body = (x - y) * (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs);
IS_VAR_WITH_NAME(lhs->lhs(), "x");
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
{
// (x + y) * (x - y) => (x - y) * (x - y)
// Don't simplify with different ops on each side.
ExprHandle body = (x + y) * (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
IS_VAR_WITH_NAME(lhs->lhs(), "x");
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
}
// Sub an expr from itself will result in zero.
void testSimplifySubs() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x + y) - (x + y) => 0
ExprHandle body = (x + y) - (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
{
// (x * y) - (x * y) => 0
ExprHandle body = (x * y) - (x * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
{
// (x - y) - (x - y) => 0
ExprHandle body = (x - y) - (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
{
// (x + y) - 2 * (x + y) => -1 * (x + y)
ExprHandle body = (x + y) - ExprHandle(2) * (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// (x + y) - y => x
ExprHandle body = (x + y) - y;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_VAR_WITH_NAME(simplified.node(), "x");
}
{
// (x - y) - y => x - 2 * y
ExprHandle body = (x - y) - y;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_VAR_WITH_NAME(sub->lhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
{
// 2 * x - x => x
ExprHandle body = (ExprHandle(2) * x) - x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_VAR_WITH_NAME(simplified.node(), "x");
}
{
// x - 2 * x = -1 * x
// We don't have a unary negate, but this could be 0 -x I guess?
ExprHandle body = x - (ExprHandle(2) * x);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
IS_VAR_WITH_NAME(mul->rhs(), "x");
}
{
// (x + y + 5) * (x - x) => 0
// Cancelling out one side of Mul cancels both.
ExprHandle body = (x + y + 5) * (x - x);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
}
// Test that mixing ops together simplifies as expected.
void testSimplifyMultiOp() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x * y) + (x - y) => (x * y) + x - y
//
ExprHandle body = (x * y) + (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), mul);
IS_VAR_WITH_NAME(mul->lhs(), "x");
IS_VAR_WITH_NAME(mul->rhs(), "y");
IS_VAR_WITH_NAME(add->rhs(), "x");
IS_VAR_WITH_NAME(sub->rhs(), "y");
}
{
// (x + y) - (x * y) => x + y - (x * y)
ExprHandle body = (x + y) - (x * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
IS_VAR_WITH_NAME(mul->lhs(), "x");
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
{
// (x - y) - (x + y) => -2 * y
ExprHandle body = (x - y) - (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
}
// Test that chaining many ops together works as expected.
void testSimplifyManyOps() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// x + y + x + x + y + y + x + y + x = 4 * y + 5 * x
ExprHandle body = x + y + x + x + y + y + x + y + x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Add, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 4);
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
{
// x - y + x + x - y - y + x - y + x = 5 * x - 4 * y
ExprHandle body = x - y + x + x - y - y + x - y + x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
IS_VAR_WITH_NAME(lhs->rhs(), "x");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
{
// x + y + x - x - y - y + x + y + x = 3 * x
ExprHandle body = x + y + x - x - y - y + x + y + x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 3);
IS_VAR_WITH_NAME(mul->rhs(), "x");
}
}
void testSimplifyFactorization() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (2 * x) + (2 * y) => 2 * (x + y)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// Factorization when scalars have common divider.
// (2 * x) + (4 * y) => 2 * (2 * y + x)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), mul2);
IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
IS_VAR_WITH_NAME(mul2->rhs(), "y");
IS_VAR_WITH_NAME(add->rhs(), "x");
}
{
// Factorization attempt without a common divider.
// (2 * x) + (5 * y) => (5 * y) + (2 * x)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Add, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 2);
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
{
// Factorization after merging.
// (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
(ExprHandle(8) * x + ExprHandle(6) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// Factorization with common divider but different signs.
// (-2 * x) + (4 * y) => -2 * (x - 2 * y)
ExprHandle body = (ExprHandle(-2) * x + ExprHandle(4) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
IS_VAR_WITH_NAME(sub->lhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
IS_VAR_WITH_NAME(mul2->rhs(), "y");
}
}
// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (3 * z + y + 4 * x)
void testSimplifyFactorizeUneven() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
ExprHandle body =
(ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
IS_NODE_WITH_NAME(Add, root->rhs(), add1);
IS_NODE_WITH_NAME(Add, add1->lhs(), add2);
IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul);
IS_NODE_WITH_NAME(Mul, add2->lhs(), zmul);
IS_IMM_WITH_VAL(Int, zmul->lhs(), 3);
IS_VAR_WITH_NAME(zmul->rhs(), "z");
IS_VAR_WITH_NAME(add2->rhs(), "y");
IS_IMM_WITH_VAL(Int, xmul->lhs(), 4);
IS_VAR_WITH_NAME(xmul->rhs(), "x");
}
// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
// This is kind of a placeholder test for variable factorization.
void testSimplifyDeeperTerms() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Add, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm);
IS_VAR_WITH_NAME(xxTerm->lhs(), "x");
IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 3);
IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm);
IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
}
// Tests the difference between two less trivial expressions.
// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1
void testSimplifyDeeperDifference() {
KernelScope kernel_scope;
VarHandle n("n", kInt);
VarHandle n_1("n_1", kInt);
VarHandle m("m", kInt);
ExprHandle body =
(m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
}
// Test constant folding into the difference between expressions.
// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3
void testSimplifyFoldComplexDifference() {
KernelScope kernel_scope;
VarHandle n("n", kInt);
VarHandle n_1("n_1", kInt);
VarHandle m("m", kInt);
ExprHandle body =
(IntImm::make(2) +
(Cast::make(
kChar,
(m * (ExprHandle(1) * n_1) + (n + 1)) -
(m * (ExprHandle(1) * n_1) + n))));
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 3);
}
void testSimplifyIfComponents() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = IfThenElse::make(
((ExprHandle(5) - ExprHandle(4)) * x) > y,
ExprHandle(2) * x - x,
ExprHandle(2) * y - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr);
IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp);
EXPECT_EQ(cmp->compare_select_op(), kGT);
IS_VAR_WITH_NAME(cmp->lhs(), "x");
IS_VAR_WITH_NAME(cmp->rhs(), "y");
IS_VAR_WITH_NAME(ifexpr->true_value(), "x");
IS_VAR_WITH_NAME(ifexpr->false_value(), "y");
}
void testSimplifyOpaqueTerms() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// 2 * x/y * x - x/y * y => y * x/y
ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_VAR_WITH_NAME(mul->lhs(), "y");
IS_NODE_WITH_NAME(Div, mul->rhs(), div);
IS_VAR_WITH_NAME(div->lhs(), "x");
IS_VAR_WITH_NAME(div->rhs(), "y");
}
{
// x%y - (x%y - 1) => 1
ExprHandle body = (x % y) - ((x % y) - 1);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
}
}
void testSimplifyWontReorderFloat() {
KernelScope kernel_scope;
{
// 3 * (3 * x) - 3 * (3 * y) => -9 * (y - x)
// This is an expression we can simplify.
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -9);
IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
IS_VAR_WITH_NAME(sub->lhs(), "y");
IS_VAR_WITH_NAME(sub->rhs(), "x");
}
{
// 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y).
// If the vars are floating point, ops are not associative and we can't
// reorder.
VarHandle x("x", kFloat);
VarHandle y("y", kFloat);
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y");
}
{
// 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y).
// We will simplify subexprs if they dont reorder floating point ops.
VarHandle x("x", kDouble);
VarHandle y("y", kInt);
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double);
IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9);
IS_VAR_WITH_NAME(rhsMul->rhs(), "y");
}
{
// Prevent reordering if FP propagated from dtypes.
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3.f) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float);
IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast);
IS_VAR_WITH_NAME(yCast->src_value(), "y");
}
{
VarHandle x("x", kFloat);
VarHandle y("y", kFloat);
// x%y - (x%y - 1) => x%y - (x%y - 1).
// We wont reorder opaque ops if they are FP.
ExprHandle body = (x % y) - ((x % y) - 1);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod);
IS_VAR_WITH_NAME(lhsMod->lhs(), "x");
IS_VAR_WITH_NAME(lhsMod->rhs(), "y");
IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub);
IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod);
IS_VAR_WITH_NAME(rhsMod->lhs(), "x");
IS_VAR_WITH_NAME(rhsMod->rhs(), "y");
IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1);
}
}
} // namespace jit

View File

@ -9,105 +9,119 @@
namespace torch {
namespace jit {
#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyEliminatesVar) \
#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyReorderings) \
_(SimplifyEliminatesVar) \
_(SimplifyAdds) \
_(SimplifyMuls) \
_(SimplifySubs) \
_(SimplifyMultiOp) \
_(SimplifyManyOps) \
_(SimplifyFactorization) \
_(SimplifyFactorizeUneven) \
_(SimplifyDeeperTerms) \
_(SimplifyDeeperDifference) \
_(SimplifyFoldComplexDifference) \
_(SimplifyIfComponents) \
_(SimplifyOpaqueTerms) \
_(SimplifyWontReorderFloat) \
_(StmtClone)
#define TH_FORALL_TESTS_LLVM(_) \

View File

@ -205,6 +205,7 @@ libtorch_sources = [
"torch/csrc/jit/tensorexpr/ir.cpp",
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
"torch/csrc/jit/tensorexpr/ir_printer.cpp",
"torch/csrc/jit/tensorexpr/ir_simplifier.cpp",
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
"torch/csrc/jit/tensorexpr/kernel.cpp",
"torch/csrc/jit/tensorexpr/llvm_codegen.cpp",

View File

@ -59,24 +59,24 @@ class Value {
void* ptr;
};
#define VALUE_AS_DISPATCH(Type, Name) \
template <> \
inline Type Value::as<Type>() const { \
if (dtype_ != k##Name) { \
throw unsupported_dtype(); \
} \
return Name##values[0]; \
#define VALUE_AS_DISPATCH(Type, Name) \
template <> \
inline Type Value::as<Type>() const { \
if (dtype_ != k##Name) { \
throw unsupported_dtype(); \
} \
return Name##values[0]; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH);
#undef VALUE_AS_DISPATCH
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
template <> \
inline const std::vector<Type>& Value::as_vec<Type>() const { \
if (dtype_.scalar_type() != ScalarType::Name) { \
throw unsupported_dtype(); \
} \
return Name##values; \
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
template <> \
inline const std::vector<Type>& Value::as_vec<Type>() const { \
if (dtype_.scalar_type() != ScalarType::Name) { \
throw unsupported_dtype(); \
} \
return Name##values; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH);
#undef VALUE_AS_VEC_DISPATCH
@ -479,7 +479,6 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
throw malformed_input(v);
}
if (src_dtype != dst_dtype) {
switch (src_dtype.scalar_type()) {
#define SRC_TYPE_CASE(Type, Name) \
@ -912,6 +911,28 @@ inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) {
return stmt->accept_mutator(&var_sub);
}
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
static Expr* evaluateOp(const Expr* v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);
switch (v->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: { \
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val); \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
return nullptr;
}
return nullptr;
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch

View File

@ -31,6 +31,8 @@ enum IRNodeType {
kCompareSelect,
kLet,
kCast,
kBroadcast,
kRamp,
kNone
};

View File

@ -1,3 +1,5 @@
#pragma once
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
@ -316,6 +318,13 @@ class HashProvider : public IRVisitor {
putHash(v, hash);
}
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
SimplifierHashType seed = 0;
_hash_combine(seed, args...);
return seed;
}
private:
SimplifierHashType hashOf(const Expr* e) {
auto it = exprToHash_.find(e);
@ -371,6 +380,10 @@ class HashProvider : public IRVisitor {
(seed << 7) + (seed >> 4);
}
void _hash_combine(SimplifierHashType& seed, const Expr* e) {
_hash_combine(seed, hash(e));
}
template <typename T, typename... Types>
void _hash_combine(
SimplifierHashType& seed,
@ -380,13 +393,6 @@ class HashProvider : public IRVisitor {
_hash_combine(seed, args...);
}
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
SimplifierHashType seed = 0;
_hash_combine(seed, args...);
return seed;
}
void putHash(const KernelScopedObject* e, SimplifierHashType h) {
auto res = exprToHash_.emplace(e, h);
if (res.second == false) {

View File

@ -283,24 +283,94 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
// Get immediate by ScalarType.
template <typename T>
ExprHandle getImmediateByType(ScalarType immType, T initialVal) {
Expr* getImmediateByType(ScalarType immType, T initialVal) {
switch (immType) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return Name##Imm::make(initialVal);
return new Name##Imm(initialVal);
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
}
return ExprHandle();
return nullptr;
}
template <typename T>
ExprHandle getImmediateByType(Dtype dtype, T initialVal) {
Expr* getImmediateByType(Dtype dtype, T initialVal) {
return getImmediateByType<T>(dtype.scalar_type(), initialVal);
}
template <typename T>
T immediateAs(const Expr* e) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value(); \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
return 0;
}
template <typename T>
bool immediateEquals(const Expr* e, T val) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value() == val; \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
return false;
}
template <typename T>
bool immediateIsNegative(const T* e) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value() < 0; \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
return false;
}
// Creates a new Expr of the given type with the provided lhs and rhs.
static const Expr* newBinaryOpOfType(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs,
bool option) {
switch (expr_type) {
case IRNodeType::kAdd:
return new Add(lhs, rhs);
case IRNodeType::kSub:
return new Sub(lhs, rhs);
case IRNodeType::kMul:
return new Mul(lhs, rhs);
case IRNodeType::kDiv:
return new Div(lhs, rhs);
case IRNodeType::kMod:
return new Mod(lhs, rhs);
case IRNodeType::kMax:
return new Max(lhs, rhs, option);
case IRNodeType::kMin:
return new Min(lhs, rhs, option);
case IRNodeType::kAnd:
return new And(lhs, rhs);
case IRNodeType::kXor:
return new Xor(lhs, rhs);
case IRNodeType::kLshift:
return new Lshift(lhs, rhs);
case IRNodeType::kRshift:
return new Rshift(lhs, rhs);
default:
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
return nullptr;
}
}
// Bind the value to the var and evaluate the body.
class Let : public ExprNode<Let> {
public:
@ -354,7 +424,7 @@ class Ramp : public ExprNode<Ramp> {
}
Ramp(const Expr* base, const Expr* stride, int lanes)
: ExprNodeBase(Dtype(base->dtype(), lanes)),
: ExprNodeBase(Dtype(base->dtype(), lanes), kRamp),
base_(base),
stride_(stride),
lanes_(lanes) {
@ -420,7 +490,7 @@ class Broadcast : public ExprNode<Broadcast> {
return ExprHandle(new Broadcast(value.node(), lanes));
}
Broadcast(const Expr* value, int lanes)
: ExprNodeBase(Dtype(value->dtype(), lanes)),
: ExprNodeBase(Dtype(value->dtype(), lanes), kBroadcast),
value_(value),
lanes_(lanes) {}
@ -562,8 +632,7 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
const ExprHandle& ret_val1,
const ExprHandle& ret_val2,
CompareSelectOperation cmp_op) {
if (lhs.dtype() != rhs.dtype() ||
ret_val1.dtype() != ret_val2.dtype()) {
if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
throw malformed_input();
}
return ExprHandle(new CompareSelect(
@ -790,48 +859,8 @@ class Intrinsics : public CallNode<Intrinsics> {
IntrinsicsOp op_type_;
};
/* An internal only Expr used in IR simplification.
* Encodes relationship y = Ax + B, where A and B are Immediates.
* Not required to be implemented by codegen. */
class LinearForm : public ExprNode<LinearForm> {
public:
LinearForm(const Expr* x, const Expr* A, const Expr* B)
: ExprNodeBase(dtypeFor(x, A, B)), x_(x), A_(A), B_(B) {}
LinearForm(const Expr* x)
: ExprNodeBase(x->dtype()),
x_(x),
A_(new CharImm(1)),
B_(new CharImm(0)) {}
const Expr* getX() const {
return x_;
}
const Expr* getA() const {
return A_;
}
const Expr* getB() const {
return B_;
}
void setA(const Expr* A) {
A_ = A;
}
void setB(const Expr* B) {
B_ = B;
}
static Dtype dtypeFor(const Expr* A, const Expr* B, const Expr* C) {
return ToDtype(promoteTypes(
A->dtype().scalar_type(), promoteTypes(B->dtype(), C->dtype())));
}
private:
const Expr* x_;
const Expr* A_;
const Expr* B_;
};
class Polynomial;
class Term;
class FunctionCall;

View File

@ -2,6 +2,7 @@
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
namespace torch {
namespace jit {
@ -214,6 +215,7 @@ const Expr* IRMutator::mutate(const IfThenElse* v) {
const Expr* condition_new = condition->accept_mutator(this);
const Expr* true_value_new = true_value->accept_mutator(this);
const Expr* false_value_new = false_value->accept_mutator(this);
if (condition == condition_new && true_value == true_value_new &&
false_value == false_value_new) {
return v;
@ -232,11 +234,24 @@ const Expr* IRMutator::mutate(const FunctionCall* v) {
return this->mutate(base);
}
const Expr* IRMutator::mutate(const LinearForm* v) {
const Expr* new_x = v->getX()->accept_mutator(this);
const Expr* new_a = v->getA()->accept_mutator(this);
const Expr* new_b = v->getB()->accept_mutator(this);
return new LinearForm(new_x, new_a, new_b);
const Expr* IRMutator::mutate(const Term* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new Term(v->hasher(), newScalar, variables);
}
const Expr* IRMutator::mutate(const Polynomial* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
std::vector<const Term*> variables;
for (const auto* t : v->variables()) {
variables.push_back(static_cast<const Term*>(t->accept_mutator(this)));
}
return new Polynomial(v->hasher(), newScalar, variables);
}
const Expr* IRMutator::mutate(const BaseCallNode* v) {

View File

@ -45,7 +45,8 @@ class Allocate;
class Free;
class Cond;
class Stmt;
class LinearForm;
class Term;
class Polynomial;
class TORCH_API IRMutator {
public:
@ -86,7 +87,8 @@ class TORCH_API IRMutator {
virtual const Expr* mutate(const Intrinsics* v);
virtual const Expr* mutate(const FunctionCall* v);
virtual const Expr* mutate(const LinearForm* v);
virtual const Expr* mutate(const Term* v);
virtual const Expr* mutate(const Polynomial* v);
virtual Stmt* mutate(const For* v);
virtual Stmt* mutate(const Block* v);

View File

@ -1,5 +1,7 @@
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
namespace torch {
namespace jit {
namespace tensorexpr {
@ -198,7 +200,7 @@ template <
typename T,
std::enable_if_t<!std::is_floating_point<T>::value>* = nullptr>
static void formatImm(std::ostream& os, T v) {
os << v;
os << +v;
}
// NOLINTNEXTLINE
@ -368,9 +370,33 @@ void IRPrinter::visit(const Cond* v) {
}
}
void IRPrinter::visit(const LinearForm* v) {
os() << "(" << *v->getA() << ") * (" << *v->getX() << ") + (" << *v->getB()
<< ")" << std::endl;
void IRPrinter::visit(const Term* v) {
os() << "Term(";
v->scalar()->accept(this);
for (auto* t : v->variables()) {
os() << ",";
t->accept(this);
}
os() << ")";
}
void IRPrinter::visit(const Polynomial* v) {
bool first = true;
os() << "Polynomial(";
for (auto* t : v->variables()) {
emitIndent();
if (!first) {
os() << " + ";
}
first = false;
t->accept(this);
}
if (!first) {
os() << " + ";
}
v->scalar()->accept(this);
os() << ")";
}
void IRPrinter::emitIndent() {

View File

@ -48,7 +48,8 @@ class TORCH_API IRPrinter : public IRVisitor {
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Cond* v) override;
void visit(const LinearForm* v) override;
void visit(const Term* v) override;
void visit(const Polynomial* v) override;
std::ostream& os() {
return printer_os_;

File diff suppressed because it is too large Load Diff

View File

@ -1,368 +1,276 @@
#pragma once
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
#include "torch/csrc/jit/tensorexpr/types.h"
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/hash_server.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>
/* IR Simplification
*
* Simplfies expressions in two stages:
* 1. Recursively traverse the map combining similar operations into Terms
* (interacted via Multiplication) and Polynomials (interacted via Addition). We
* reorder the components of each Term or Polynomial into a consistent order to
* allow combination or cancelling of like terms.
* 2. Once the format of the tree is minimal, expand each Term into a sequence
* of Muls, and each Polynomial into a sequence of Ads.
*/
namespace torch {
namespace jit {
namespace tensorexpr {
// Uses the evaluator to fold an operation with constant terms.
// Expr v must be evaluatable without Vars.
static Expr* evaluateOp(const Expr* v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);
// A bunch of helpers for determine the Dtype of the output of a multi argument
// Term or Polynomial.
namespace {
template <class ExprType>
Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
Dtype t = s->dtype();
bool first = true;
switch (v->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: { \
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val).node(); \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
return nullptr;
}
return nullptr;
} // namespace tensorexpr
static const Expr* newBinaryOpOfType(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs,
bool option) {
switch (expr_type) {
case IRNodeType::kAdd:
return new Add(lhs, rhs);
case IRNodeType::kSub:
return new Sub(lhs, rhs);
case IRNodeType::kMul:
return new Mul(lhs, rhs);
case IRNodeType::kDiv:
return new Div(lhs, rhs);
case IRNodeType::kMod:
return new Mod(lhs, rhs);
case IRNodeType::kMax:
return new Max(lhs, rhs, option);
case IRNodeType::kMin:
return new Min(lhs, rhs, option);
case IRNodeType::kAnd:
return new And(lhs, rhs);
case IRNodeType::kXor:
return new Xor(lhs, rhs);
case IRNodeType::kLshift:
return new Lshift(lhs, rhs);
case IRNodeType::kRshift:
return new Rshift(lhs, rhs);
default:
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
return nullptr;
}
}
/* Interprets expr as an Immediate and returns the value as type T. */
template <typename T>
T immediateAs(const Expr* expr) {
T val{0};
switch (expr->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(expr)) { \
val = imm->value(); \
} else { \
LOG(FATAL) << "Bad expr: " << *expr << "\n"; \
} \
break;
AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << expr->dtype();
}
return val;
}
/* Takes a LinearForm and converts it to Mul + (Add/Sub). */
const Expr* expandLinearForm(const LinearForm* v, IRMutator* mutator) {
const Expr* mul = nullptr;
const Expr* A = v->getA();
const Expr* B = v->getB();
const Expr* X = v->getX();
// we only really care about 0 and 1, so double should be fine.
double Aval = immediateAs<double>(A);
double Bval = immediateAs<double>(B);
// First handle A.
if (Aval == 0) {
if (Bval == 0) {
return getImmediateByType(X->dtype(), 0).node();
for (auto* e : v) {
if (first) {
t = Dtype(t.scalar_type(), e->dtype().lanes());
first = false;
}
return B;
} else if (Aval == 1) {
mul = X;
} else if (Aval == -1) {
return new Sub(B, X);
} else if (Aval < 0) {
// Negate A.
ExprHandle zero = getImmediateByType(A->dtype(), 0);
Sub* A_Sub = new Sub(zero.node(), A);
return new Sub(B, new Mul(X, evaluateOp(A_Sub)));
} else {
mul = new Mul(X, A);
t = promoteTypes(t, e->dtype());
}
if (Bval == 0) {
return mul;
}
return new Add(mul, B);
return t;
}
/* Expand any remaining LinearTerms into their component pieces */
class LinearFormExpander : public IRMutator {
public:
const Expr* mutate(const LinearForm* v) {
return expandLinearForm(v, this);
template <class ExprType>
Dtype promoteTypesVec(std::vector<const ExprType*>& v) {
if (v.empty()) {
throw malformed_input("empty list of types");
}
Dtype t = v[0]->dtype();
for (auto* e : v) {
t = promoteTypes(t, e->dtype());
}
return t;
}
template <class ExprType>
Dtype promoteTypesMap(
const Expr* s,
std::unordered_map<SimplifierHashType, const ExprType*>& m) {
Dtype t = s->dtype();
bool first = true;
for (auto& e : m) {
if (first) {
t = Dtype(t.scalar_type(), e.second->dtype().lanes());
first = false;
}
t = promoteTypes(t, e.second->dtype());
}
return t;
}
template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
return e->dtype();
}
template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
Dtype lhs = e->dtype();
Dtype rhs = promoteTypesVar(es...);
if (e->isConstant()) {
lhs = Dtype(lhs.scalar_type(), rhs.lanes());
}
return promoteTypes(lhs, rhs);
}
// Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast
// or Ramp).
bool isMultilanePrimitive(const Expr* e) {
return e->expr_type() == IRNodeType::kBroadcast ||
e->expr_type() == IRNodeType::kRamp;
}
} // namespace
// A Term represents a grouping of Exprs through multiplication.
// E.g. product(scalar, *variables).
class Term : public ExprNode<Term> {
public:
template <class... Args>
Term(HashProvider& hasher, const Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addComponent(ts...);
sort();
}
Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
hasher_(hasher) {
sort();
}
// Convenience constructor from a map of hash -> var, used when merging Terms.
Term(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Expr*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addComponent(p.second);
}
sort();
}
const Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
return hasher_;
}
// Produce a hash of just the variable components of this term, to determine
// if it can be combined with another term.
SimplifierHashType hashVars() const;
private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
HashProvider& hasher_;
void addComponent() {}
void addComponent(const Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
// Sort by hash to normalize order of components.
void sort();
};
/* Simplify the IR by combining arithmetic expressions over a common term.
*/
class IRSimplifier : public IRMutator {
// Polynomial represents a grouping of Exprs by addition.
// E.g. sum(*variables, scalar).
// This would better be called Expression, but, naming conflict...
class Polynomial : public ExprNode<Polynomial> {
public:
const Expr* mutate(const Add* v) override {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
const Expr* result = evaluateOp(v);
return result;
}
const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);
if (lhsLinear && rhsLinear) {
// Can add two LinearTerms if they reference the same Var.
if (lhsLinear->getX() == rhsLinear->getX()) {
Add* A_Add = new Add(lhsLinear->getA(), rhsLinear->getA());
Add* B_Add = new Add(lhsLinear->getB(), rhsLinear->getB());
LinearForm* linear = new LinearForm(
lhsLinear->getX(), evaluateOp(A_Add), evaluateOp(B_Add));
return linear;
}
// otherwise cannot simplify further.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
// Can add a scalar into the B term of LinearTerm.
if (lhsLinear && rhs_new->isConstant()) {
Add* B_Add = new Add(lhsLinear->getB(), rhs_new);
LinearForm* linear = new LinearForm(
lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Add));
return linear;
}
if (rhsLinear && lhs_new->isConstant()) {
Add* B_Add = new Add(rhsLinear->getB(), lhs_new);
LinearForm* linear = new LinearForm(
rhsLinear->getX(), rhsLinear->getA(), evaluateOp(B_Add));
return linear;
}
// Can create a LinearTerm over any sub expression.
if (lhs_new->isConstant()) {
LinearForm* linear = new LinearForm(rhs_new);
linear->setB(evaluateOp(lhs_new));
return linear;
}
if (rhs_new->isConstant()) {
LinearForm* linear = new LinearForm(lhs_new);
linear->setB(evaluateOp(rhs_new));
return linear;
}
/// Broadcasts are a bit more involved.
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(lhs_new)) {
if (const Expr* ret = handleBroadcastAdd(bc, rhs_new)) {
return ret;
}
}
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(rhs_new)) {
if (const Expr* ret = handleBroadcastAdd(bc, lhs_new)) {
return ret;
}
}
// No change.
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
// Cannot simplify.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
template <class... Args>
Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addTerm(ts...);
sort();
}
const Expr* mutate(const Sub* v) override {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
const Expr* result = evaluateOp(v);
return result;
}
const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);
if (lhsLinear && rhsLinear) {
// Can sub two LinearTerms if they reference the same Var.
if (lhsLinear->getX() == rhsLinear->getX()) {
Sub* A_Sub = new Sub(lhsLinear->getA(), rhsLinear->getA());
Sub* B_Sub = new Sub(lhsLinear->getB(), rhsLinear->getB());
LinearForm* linear = new LinearForm(
lhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub));
return linear;
}
// otherwise cannot simplify further.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
// Can just sub from B term if LHS is a LinearTerm.
if (lhsLinear && rhs_new->isConstant()) {
Sub* B_Sub = new Sub(lhsLinear->getB(), rhs_new);
LinearForm* linear = new LinearForm(
lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Sub));
return linear;
}
// Slightly more complicated if the RHS is LinearTerm.
if (rhsLinear && lhs_new->isConstant()) {
// The linear needs to be negated.
ExprHandle zero = getImmediateByType(rhsLinear->getA()->dtype(), 0);
Sub* A_Sub = new Sub(zero.node(), rhsLinear->getA());
Sub* B_Sub = new Sub(rhsLinear->getB(), lhs_new);
LinearForm* linear = new LinearForm(
rhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub));
return linear;
}
// Can create a new LinearTerm, but since the B term is defined as Add we
// must negate it.
if (rhs_new->isConstant()) {
LinearForm* linear = new LinearForm(lhs_new);
ExprHandle zero = getImmediateByType(linear->getA()->dtype(), 0);
Sub* B_Sub = new Sub(zero.node(), rhs_new);
linear->setB(evaluateOp(B_Sub));
return linear;
}
// Can create a new LinearTerm with the A term -1 to negate the Expr.
if (lhs_new->isConstant()) {
// Negate by using -1 as the first linear.
ExprHandle negOne = getImmediateByType(rhs_new->dtype(), -1);
LinearForm* linear =
new LinearForm(rhs_new, negOne.node(), evaluateOp(lhs_new));
return linear;
}
// Nothing to do.
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
// Cannot simplify.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
hasher_(hasher) {
sort();
}
const Expr* mutate(const Mul* v) override {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
return evaluateOp(v);
}
const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);
if (lhsLinear && rhsLinear) {
// Lets not get into higher order terms.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
// Easy to simplify into an existing LinearTerm by multiplying A and B.
if (lhsLinear && rhs_new->isConstant()) {
Mul* A_Mul = new Mul(lhsLinear->getA(), rhs_new);
Mul* B_Mul = new Mul(lhsLinear->getB(), rhs_new);
LinearForm* linear = new LinearForm(
lhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul));
return linear;
}
if (rhsLinear && lhs_new->isConstant()) {
Mul* A_Mul = new Mul(rhsLinear->getA(), lhs_new);
Mul* B_Mul = new Mul(rhsLinear->getB(), lhs_new);
LinearForm* linear = new LinearForm(
rhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul));
return linear;
}
// Easy to create a new LinearTerm by setting term A.
if (lhs_new->isConstant()) {
LinearForm* linear = new LinearForm(rhs_new);
linear->setA(evaluateOp(lhs_new));
return linear;
}
if (rhs_new->isConstant()) {
LinearForm* linear = new LinearForm(lhs_new);
linear->setA(evaluateOp(rhs_new));
return linear;
}
// Broadcasts have special logic.
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(lhs_new)) {
if (const Expr* ret = handleBroadcastMul(bc, rhs_new)) {
return ret;
}
}
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(rhs_new)) {
if (const Expr* ret = handleBroadcastMul(bc, lhs_new)) {
return ret;
}
}
// Cannot be simplified, just exit.
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
// Helper constructor for list of terms with no scalar component.
Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
: ExprNodeBase(promoteTypesVec(terms)),
variables_(std::move(terms)),
scalar_(getImmediateByType(dtype(), 0)),
hasher_(hasher) {
sort();
}
// Convenience constructor for map of hash -> var, used when merging
// Polynomials.
Polynomial(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Term*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addTerm(p.second);
}
sort();
}
const Expr* scalar() const {
return scalar_;
}
const std::vector<const Term*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
return hasher_;
}
SimplifierHashType hashVars() const;
private:
std::vector<const Term*> variables_;
const Expr* scalar_;
HashProvider& hasher_;
void addTerm(const Term* t) {
variables_.push_back(t);
}
template <class... Ts>
void addTerm(const Term* t, Ts... ts) {
addTerm(t);
addTerm(ts...);
}
// Sort by hash to normalize order of terms.
void sort();
};
// Simplify the IR by combining arithmetic expressions over common terms.
class TORCH_API PolynomialTransformer : public IRMutator {
public:
// Inserts term into the provided map, in the case of a hash collision
// combines the term with the existing and updates the map.
void addOrUpdateTerm(
std::unordered_map<SimplifierHashType, const Term*>& varmap,
const Term* term);
// Add Polynomial expressions, combining Terms representing the same
// variables.
const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs);
// Insert a new Term into the provided polynomial. If the new term has common
// variables to an existing term it is combined.
const Expr* insertTerm(const Polynomial* poly, const Term* term);
// Merge and simplify addition.
const Expr* mutate(const Add* v) override;
// Subtract one term from another, cancelling if necessary.
const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated);
// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
// possible.
const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs);
// Merge and simplify subtraction.
const Expr* mutate(const Sub* v) override;
// Multiply two terms together, usually creating a new term with the variable
// lists concatenated.
const Term* mulTerms(const Term* lhs, const Term* rhs);
// Multiply a Polynomial by a Term.
const Expr* polyByTerm(const Polynomial* poly, const Term* term);
// Merge and simplify multiplication.
const Expr* mutate(const Mul* v) override;
const Expr* mutate(const Div* v) override {
// TODO div simplification will require a rational node.
return mutateBinaryOp(v, this);
@ -396,37 +304,9 @@ class IRSimplifier : public IRMutator {
return mutateBinaryOp(v, this, v->propagate_nans());
}
const Expr* mutate(const Intrinsics* v) override {
std::vector<const Expr*> new_params;
bool changed = false;
bool allConstant = true;
for (const auto* p : v->params()) {
const Expr* new_child = p->accept_mutator(this);
new_params.push_back(new_child);
const Expr* mutate(const Intrinsics* v) override;
changed |= p != new_child;
allConstant &= new_child->isConstant();
}
const Expr* node = v;
if (changed) {
node = new Intrinsics(v->op_type(), new_params);
}
if (!allConstant || !v->isPure()) {
return node;
}
return evaluateOp(node);
}
const Expr* mutate(const Cast* v) override {
if (v->src_value()->isConstant()) {
return evaluateOp(v);
}
return v;
}
const Expr* mutate(const Cast* v) override;
template <typename Op>
static const Expr* mutateBinaryOp(
@ -452,12 +332,40 @@ class IRSimplifier : public IRMutator {
return evaluateOp(node);
}
static const Expr* simplify(const Expr* e);
static ExprHandle simplify(const ExprHandle& e);
static Stmt* simplify(Stmt* e);
private:
HashProvider hasher_;
}; // namespace tensorexpr
// Expands Terms and Polynomial expressions into primitive operations.
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public IRMutator {
PolynomialTransformer* simplifier_;
public:
TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {}
// Expand Terms out to a series of Muls.
const Expr* mutate(const Term* v) override;
// Trivially factorize terms by GCD of scalar components.
const Expr* factorizePolynomial(const Polynomial* poly);
// Expand Polynomials out to a series of Adds.
const Expr* mutate(const Polynomial* v);
};
class TORCH_API IRSimplifier {
public:
static const Expr* simplify(const Expr* e) {
IRSimplifier simplifier;
PolynomialTransformer simplifier;
e = e->accept_mutator(&simplifier);
// There may be terms left in the IR, expand them.
LinearFormExpander expander;
TermExpander expander(&simplifier);
e = e->accept_mutator(&expander);
return e;
@ -468,74 +376,15 @@ class IRSimplifier : public IRMutator {
}
static Stmt* simplify(Stmt* s) {
IRSimplifier simplifier;
PolynomialTransformer simplifier;
s = s->accept_mutator(&simplifier);
// There may be terms left in the IR, expand them.
LinearFormExpander expander;
TermExpander expander(&simplifier);
s = s->accept_mutator(&expander);
return s;
}
private:
/* Expands lhs and rhs if they are LinearTerms, creating a new op to hold
* them. If either side expands to a constant term, attempt simplification of
* the new op. */
const Expr* expandAndRecurse(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs) {
if (const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs)) {
lhs = expandLinearForm(lhsLinear, this);
}
if (const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs)) {
rhs = expandLinearForm(rhsLinear, this);
}
const Expr* result = newBinaryOpOfType(expr_type, lhs, rhs, false);
// lhs or rhs can become constant during expansion, if either is now
// constant we can keep merging into a linear term. Have another attempt to
// simplify the new op.
if (lhs->isConstant() || rhs->isConstant()) {
return result->accept_mutator(this);
}
return result;
}
/* Handles optimization cases for Broadcast() + Other */
const Expr* handleBroadcastAdd(const Broadcast* bc, const Expr* other) {
if (bc->value()->isConstant() && immediateAs<int>(bc->value()) == 0) {
return other;
}
if (const Ramp* r = dynamic_cast<const Ramp*>(other)) {
// Add the broadcast to the start of the Ramp.
const Expr* ret =
new Ramp(new Add(bc->value(), r->base()), r->stride(), r->lanes());
return ret->accept_mutator(this);
}
return nullptr;
}
/* Handles optimization cases for Broadcast() * Other */
const Expr* handleBroadcastMul(const Broadcast* bc, const Expr* other) {
if (bc->value()->isConstant() && immediateAs<int>(bc->value()) == 1) {
return other;
}
if (const Ramp* r = dynamic_cast<const Ramp*>(other)) {
// Multiply both start and stride by the broadcast value.
const Expr* ret = new Ramp(
new Mul(bc->value(), r->base()),
new Mul(bc->value(), r->stride()),
r->lanes());
return ret->accept_mutator(this);
}
return nullptr;
}
};
} // namespace tensorexpr

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
namespace torch {
@ -176,10 +177,18 @@ void IRVisitor::visit(const Cond* v) {
}
}
void IRVisitor::visit(const LinearForm* v) {
v->getA()->accept(this);
v->getX()->accept(this);
v->getB()->accept(this);
void IRVisitor::visit(const Term* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
}
void IRVisitor::visit(const Polynomial* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
}
} // namespace tensorexpr

View File

@ -42,7 +42,8 @@ class FunctionCall;
class Allocate;
class Free;
class Cond;
class LinearForm;
class Term;
class Polynomial;
class TORCH_API IRVisitor {
public:
@ -90,7 +91,8 @@ class TORCH_API IRVisitor {
virtual void visit(const Allocate* v);
virtual void visit(const Free* v);
virtual void visit(const Cond* v);
virtual void visit(const LinearForm* v);
virtual void visit(const Term* v);
virtual void visit(const Polynomial* v);
};
} // namespace tensorexpr

View File

@ -1114,6 +1114,7 @@ void TensorExprKernel::lowerToBackend(BackendType backendType) {
l.ApplyInlines();
Stmt* stmt = l.root_stmt();
// Arithmetic Simplification.
stmt = IRSimplifier::simplify(stmt);
// Set up formal params (inputs, then outputs) for kernel.

View File

@ -1,6 +1,6 @@
#include <torch/csrc/jit/tensorexpr/types.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/types.h>
#include <c10/util/Logging.h>

View File

@ -50,7 +50,7 @@ class TORCH_API Dtype {
Dtype(Dtype type, int lanes)
: scalar_type_(type.scalar_type_), lanes_(lanes) {
if (type.lanes() != 1) {
throw malformed_input();
throw malformed_input("dtype lanes dont match");
}
}
int lanes() const {
@ -108,10 +108,15 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
return static_cast<ScalarType>(c10::promoteTypes(
static_cast<c10::ScalarType>(a), static_cast<c10::ScalarType>(b)));
}
inline ScalarType promoteTypes(Dtype a, Dtype b) {
return static_cast<ScalarType>(c10::promoteTypes(
static_cast<c10::ScalarType>(a.scalar_type()),
static_cast<c10::ScalarType>(b.scalar_type())));
inline Dtype promoteTypes(Dtype a, Dtype b) {
if (a.lanes() != b.lanes()) {
throw malformed_input("promoting types with different lanes");
}
return Dtype(
static_cast<ScalarType>(c10::promoteTypes(
static_cast<c10::ScalarType>(a.scalar_type()),
static_cast<c10::ScalarType>(b.scalar_type()))),
a.lanes());
}
inline Dtype BinaryOpDtype(
@ -127,22 +132,21 @@ inline Dtype BinaryOpDtype(
}
if (op1_dtype.lanes() != op2_dtype.lanes()) {
throw malformed_input();
throw malformed_input("lanes dont match");
}
int lanes = op1_dtype.lanes();
ScalarType resultType = promoteTypes(op1_dtype, op2_dtype);
if (resultType == ScalarType::Undefined) {
throw malformed_input();
Dtype resultType = promoteTypes(op1_dtype, op2_dtype);
if (resultType.scalar_type() == ScalarType::Undefined) {
throw malformed_input("scalar type doesn't match");
}
if (lanes == 1) {
// Use the fixed scalar Dtypes.
return ToDtype(resultType);
return ToDtype(resultType.scalar_type());
}
return Dtype(resultType, lanes);
return resultType;
}
} // namespace tensorexpr