Revert D20624571: [pytorch][PR] [TensorExpr] Extend arithmetic simplifier to work with multi variable expressions

Test Plan: revert-hammer

Differential Revision:
D20624571

Original commit changeset: e49049377bee

fbshipit-source-id: 7d8dda0c3b44be1c3236a0313bbfa128b7015de7
This commit is contained in:
Alban Desmaison
2020-03-24 16:55:57 -07:00
committed by Facebook GitHub Bot
parent ee7cd84fac
commit a7f8655314
19 changed files with 706 additions and 2441 deletions

View File

@ -481,7 +481,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_simplifier.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp

View File

@ -11,40 +11,6 @@ namespace jit {
using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
#define IS_NODE(T, node) \
{ \
auto* node_ = dynamic_cast<const T*>(node); \
EXPECT_NE(nullptr, node_) << "Expected node to be " #T; \
}
#define IS_NODE_WITH_NAME(T, node, name) \
auto* name = dynamic_cast<const T*>(node); \
EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T;
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \
const T* name = nullptr; \
{ \
auto* node_ = dynamic_cast<const Cast*>(node); \
EXPECT_NE(nullptr, node_); \
EXPECT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
name = dynamic_cast<const T*>(node_->src_value()); \
} \
EXPECT_NE(nullptr, name) << "Expected " #name " to be " #T;
#define IS_IMM_WITH_VAL(T, node, val) \
{ \
auto* node_ = dynamic_cast<const T##Imm*>(node); \
EXPECT_NE(nullptr, node_) << "Expected node to be " #T "Imm"; \
EXPECT_EQ(node_->value(), val) << "Expected Imm to be " << val; \
}
#define IS_VAR_WITH_NAME(node, name) \
{ \
auto* node_ = dynamic_cast<const Var*>(node); \
EXPECT_NE(nullptr, node_) << "Expected node to be Var"; \
EXPECT_EQ(node_->name_hint(), name) << "Expected var to be " #name; \
}
void testConstantFoldSimple() {
KernelScope kernel_scope;
ExprHandle a(2.0f);
@ -167,33 +133,17 @@ void testConstantFoldIntrinsics() {
void testConstantFoldWithVar() {
KernelScope kernel_scope;
{
VarHandle x("x", kInt);
ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));
VarHandle x("x", kFloat);
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const IntImm*>(root->lhs()), nullptr);
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ExprHandle result = Let::make(x, ExprHandle(3), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<int>(), 3 * (2 + 4));
}
{
VarHandle x("x", kFloat);
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
}
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
}
void testUnFoldableExpr() {
@ -278,22 +228,34 @@ void testHashEquivalenceAfterFolding() {
ExprHandle a(2.0f);
ExprHandle b(3.0f);
ExprHandle c(5.0f);
ExprHandle f1 = ((a + b) * x);
ExprHandle f2 = (c * x);
ExprHandle f = ((a + b) * x) * (c * x);
const Mul* root = f.AsNode<Mul>();
EXPECT_NE(root, nullptr);
HashProvider hasher;
auto hash_l = hasher.hash(f1.node());
auto hash_r = hasher.hash(f2.node());
auto hash_f = hasher.hash(f.node());
auto hash_l = hasher.hash(root->lhs());
auto hash_r = hasher.hash(root->rhs());
// Root not equal to either branch, and branches not equal.
EXPECT_NE(hash_f, hash_l);
EXPECT_NE(hash_f, hash_r);
EXPECT_NE(hash_l, hash_r);
ExprHandle ff1 = IRSimplifier::simplify(f1);
ExprHandle ff2 = IRSimplifier::simplify(f2);
ExprHandle newF = IRSimplifier::simplify(f);
auto hash_l_n = hasher.hash(ff1.node());
auto hash_r_n = hasher.hash(ff2.node());
const Mul* newRoot = newF.AsNode<Mul>();
EXPECT_NE(newRoot, nullptr);
// branches are now equal.
auto hash_f_n = hasher.hash(newF.node());
auto hash_l_n = hasher.hash(newRoot->lhs());
auto hash_r_n = hasher.hash(newRoot->rhs());
// Root not equal to either branch.
EXPECT_NE(hash_f_n, hash_l_n);
EXPECT_NE(hash_f_n, hash_r_n);
// but branches are now equal.
EXPECT_EQ(hash_l_n, hash_r_n);
}
@ -381,16 +343,11 @@ void testHashLargeExpression() {
EXPECT_NE(hash_t, hash_f);
}
/// (2 + x) + 4 => x + 6
/// (2.f + x) + 4.f => x + 6.f
void testSimplifyAdd() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle m("m", kInt);
VarHandle n("n", kInt);
VarHandle n_1("n_1", kInt);
ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
VarHandle x("x", kFloat);
ExprHandle body = (ExprHandle(2.f) + x) + ExprHandle(4.f);
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
@ -398,43 +355,51 @@ void testSimplifyAdd() {
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 6.f);
}
/// (2 - x) - 4 => -2 - x
/// (2.f - x) - 4.f => -2.f - x
void testSimplifySub() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
VarHandle x("x", kFloat);
ExprHandle body = (ExprHandle(2.f) - x) - ExprHandle(4.f);
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -2);
EXPECT_EQ(lhs->value(), -2.f);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->name_hint(), "x");
}
/// 2 * (1 - x) - 4 => -2 * (x + 3)
/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f)
void testSimplifyMultiLayer() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4));
ExprHandle simplified = IRSimplifier::simplify(body);
VarHandle x("x", kFloat);
ExprHandle body = ExprHandle(2.f) * ((ExprHandle(1.f) - x) - ExprHandle(4.f));
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_IMM_WITH_VAL(Int, add->rhs(), 3);
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -6.f);
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
const FloatImm* mulRhs = dynamic_cast<const FloatImm*>(rhs->rhs());
EXPECT_NE(mulRhs, nullptr);
EXPECT_EQ(mulRhs->value(), 2.f);
}
/// 2 * (3 * x) - (x * 4) => 2 * x
/// 2 * (3 * x) - (x * 4) => x * 2
void testSimplifyMultiTerm() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
@ -444,30 +409,30 @@ void testSimplifyMultiTerm() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), 2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_EQ(lhs->name_hint(), "x");
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->name_hint(), "x");
EXPECT_EQ(rhs->value(), 2);
}
/// 2 * (3 * (long)x) - (x * 4) => 2 * x
/// 2 * (3 * (f)x) - (x * 4) => x * 2.f
void testSimplifyCasts() {
KernelScope kernel_scope;
VarHandle x("x", kLong);
VarHandle x("x", kFloat);
ExprHandle body =
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
const LongImm* lhs = dynamic_cast<const LongImm*>(root->lhs());
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), 2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->name_hint(), "x");
EXPECT_EQ(rhs->value(), 2);
}
/// (x + 0) * 1 => x
@ -487,39 +452,20 @@ void testSimplifyMultiVar() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = x * 24 + y * 34;
ExprHandle body = y * 24 + x * 34;
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
EXPECT_NE(lhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(lhs->rhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "y");
const Var* varY = dynamic_cast<const Var*>(lhs->lhs());
EXPECT_EQ(varY->name_hint(), "y");
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
const Var* varY = dynamic_cast<const Var*>(rhs->rhs());
EXPECT_NE(varY, nullptr);
EXPECT_EQ(varY->name_hint(), "x");
}
// x + 2 + y => x + y + 2
void testSimplifyReorderings() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = x + 2 + y;
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
IS_IMM_WITH_VAL(Int, root->rhs(), 2);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
}
/// y + x * 0 => y
@ -530,621 +476,9 @@ void testSimplifyEliminatesVar() {
ExprHandle body = y + x * ExprHandle(0);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_VAR_WITH_NAME(simplified.node(), "y");
}
void testSimplifyAdds() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x + y) + (x + y) => 2 * (x + y)
ExprHandle body = (x + y) + (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
IS_NODE_WITH_NAME(Add, root->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// (x * y) + (x * y) => 2 * (x * y)
ExprHandle body = (x * y) + (x * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
IS_NODE_WITH_NAME(Mul, root->rhs(), mul);
IS_VAR_WITH_NAME(mul->lhs(), "x");
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
{
// (x - y) + (x - y) => -2 * (y - x)
ExprHandle body = (x - y) + (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "y");
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
}
void testSimplifyMuls() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x + y) * (x + y) => (x + y) * (x + y)
// We don't attempt to simplify mulitplication of polynomials since the
// result is only very rarely more efficient.
ExprHandle body = (x + y) * (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
IS_VAR_WITH_NAME(lhs->lhs(), "x");
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
{
// x * y * x * y => x * x * y * y
// These get reordered only.
ExprHandle body = x * y * x * y;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul1);
IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2);
IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3);
IS_VAR_WITH_NAME(mul1->rhs(), "y");
IS_VAR_WITH_NAME(mul2->rhs(), "y");
IS_VAR_WITH_NAME(mul3->lhs(), "x");
IS_VAR_WITH_NAME(mul3->rhs(), "x");
}
{
// (x - y) * (x - y) => (x - y) * (x - y)
// As with Add we don't attempt simplification of this.
ExprHandle body = (x - y) * (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs);
IS_VAR_WITH_NAME(lhs->lhs(), "x");
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
{
// (x + y) * (x - y) => (x - y) * (x - y)
// Don't simplify with different ops on each side.
ExprHandle body = (x + y) * (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
IS_VAR_WITH_NAME(lhs->lhs(), "x");
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
IS_VAR_WITH_NAME(rhs->lhs(), "x");
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
}
// Sub an expr from itself will result in zero.
void testSimplifySubs() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x + y) - (x + y) => 0
ExprHandle body = (x + y) - (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
{
// (x * y) - (x * y) => 0
ExprHandle body = (x * y) - (x * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
{
// (x - y) - (x - y) => 0
ExprHandle body = (x - y) - (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
{
// (x + y) - 2 * (x + y) => -1 * (x + y)
ExprHandle body = (x + y) - ExprHandle(2) * (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// (x + y) - y => x
ExprHandle body = (x + y) - y;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_VAR_WITH_NAME(simplified.node(), "x");
}
{
// (x - y) - y => x - 2 * y
ExprHandle body = (x - y) - y;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_VAR_WITH_NAME(sub->lhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
{
// 2 * x - x => x
ExprHandle body = (ExprHandle(2) * x) - x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_VAR_WITH_NAME(simplified.node(), "x");
}
{
// x - 2 * x = -1 * x
// We don't have a unary negate, but this could be 0 -x I guess?
ExprHandle body = x - (ExprHandle(2) * x);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
IS_VAR_WITH_NAME(mul->rhs(), "x");
}
{
// (x + y + 5) * (x - x) => 0
// Cancelling out one side of Mul cancels both.
ExprHandle body = (x + y + 5) * (x - x);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
}
}
// Test that mixing ops together simplifies as expected.
void testSimplifyMultiOp() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (x * y) + (x - y) => (x * y) + x - y
//
ExprHandle body = (x * y) + (x - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), mul);
IS_VAR_WITH_NAME(mul->lhs(), "x");
IS_VAR_WITH_NAME(mul->rhs(), "y");
IS_VAR_WITH_NAME(add->rhs(), "x");
IS_VAR_WITH_NAME(sub->rhs(), "y");
}
{
// (x + y) - (x * y) => x + y - (x * y)
ExprHandle body = (x + y) - (x * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Add, sub->lhs(), add);
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
IS_VAR_WITH_NAME(mul->lhs(), "x");
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
{
// (x - y) - (x + y) => -2 * y
ExprHandle body = (x - y) - (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_VAR_WITH_NAME(mul->rhs(), "y");
}
}
// Test that chaining many ops together works as expected.
void testSimplifyManyOps() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// x + y + x + x + y + y + x + y + x = 4 * y + 5 * x
ExprHandle body = x + y + x + x + y + y + x + y + x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Add, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 4);
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
{
// x - y + x + x - y - y + x - y + x = 5 * x - 4 * y
ExprHandle body = x - y + x + x - y - y + x - y + x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
IS_VAR_WITH_NAME(lhs->rhs(), "x");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
IS_VAR_WITH_NAME(rhs->rhs(), "y");
}
{
// x + y + x - x - y - y + x + y + x = 3 * x
ExprHandle body = x + y + x - x - y - y + x + y + x;
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 3);
IS_VAR_WITH_NAME(mul->rhs(), "x");
}
}
void testSimplifyFactorization() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// (2 * x) + (2 * y) => 2 * (x + y)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// Factorization when scalars have common divider.
// (2 * x) + (4 * y) => 2 * (2 * y + x)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), mul2);
IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
IS_VAR_WITH_NAME(mul2->rhs(), "y");
IS_VAR_WITH_NAME(add->rhs(), "x");
}
{
// Factorization attempt without a common divider.
// (2 * x) + (5 * y) => (5 * y) + (2 * x)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Add, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
IS_VAR_WITH_NAME(lhs->rhs(), "y");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 2);
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
{
// Factorization after merging.
// (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
(ExprHandle(8) * x + ExprHandle(6) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
IS_NODE_WITH_NAME(Add, mul->rhs(), add);
IS_VAR_WITH_NAME(add->lhs(), "x");
IS_VAR_WITH_NAME(add->rhs(), "y");
}
{
// Factorization with common divider but different signs.
// (-2 * x) + (4 * y) => -2 * (x - 2 * y)
ExprHandle body = (ExprHandle(-2) * x + ExprHandle(4) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
IS_VAR_WITH_NAME(sub->lhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
IS_VAR_WITH_NAME(mul2->rhs(), "y");
}
}
// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (3 * z + y + 4 * x)
void testSimplifyFactorizeUneven() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
ExprHandle body =
(ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), root);
IS_IMM_WITH_VAL(Int, root->lhs(), 2);
IS_NODE_WITH_NAME(Add, root->rhs(), add1);
IS_NODE_WITH_NAME(Add, add1->lhs(), add2);
IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul);
IS_NODE_WITH_NAME(Mul, add2->lhs(), zmul);
IS_IMM_WITH_VAL(Int, zmul->lhs(), 3);
IS_VAR_WITH_NAME(zmul->rhs(), "z");
IS_VAR_WITH_NAME(add2->rhs(), "y");
IS_IMM_WITH_VAL(Int, xmul->lhs(), 4);
IS_VAR_WITH_NAME(xmul->rhs(), "x");
}
// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
// This is kind of a placeholder test for variable factorization.
void testSimplifyDeeperTerms() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Add, simplified.node(), add);
IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm);
IS_VAR_WITH_NAME(xxTerm->lhs(), "x");
IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
IS_IMM_WITH_VAL(Int, rhs->lhs(), 3);
IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm);
IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
}
// Tests the difference between two less trivial expressions.
// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1
void testSimplifyDeeperDifference() {
KernelScope kernel_scope;
VarHandle n("n", kInt);
VarHandle n_1("n_1", kInt);
VarHandle m("m", kInt);
ExprHandle body =
(m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
}
// Test constant folding into the difference between expressions.
// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3
void testSimplifyFoldComplexDifference() {
KernelScope kernel_scope;
VarHandle n("n", kInt);
VarHandle n_1("n_1", kInt);
VarHandle m("m", kInt);
ExprHandle body =
(IntImm::make(2) +
(Cast::make(
kChar,
(m * (ExprHandle(1) * n_1) + (n + 1)) -
(m * (ExprHandle(1) * n_1) + n))));
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 3);
}
void testSimplifyIfComponents() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = IfThenElse::make(
((ExprHandle(5) - ExprHandle(4)) * x) > y,
ExprHandle(2) * x - x,
ExprHandle(2) * y - y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr);
IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp);
EXPECT_EQ(cmp->compare_select_op(), kGT);
IS_VAR_WITH_NAME(cmp->lhs(), "x");
IS_VAR_WITH_NAME(cmp->rhs(), "y");
IS_VAR_WITH_NAME(ifexpr->true_value(), "x");
IS_VAR_WITH_NAME(ifexpr->false_value(), "y");
}
void testSimplifyOpaqueTerms() {
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
// 2 * x/y * x - x/y * y => y * x/y
ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_VAR_WITH_NAME(mul->lhs(), "y");
IS_NODE_WITH_NAME(Div, mul->rhs(), div);
IS_VAR_WITH_NAME(div->lhs(), "x");
IS_VAR_WITH_NAME(div->rhs(), "y");
}
{
// x%y - (x%y - 1) => 1
ExprHandle body = (x % y) - ((x % y) - 1);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
}
}
void testSimplifyWontReorderFloat() {
KernelScope kernel_scope;
{
// 3 * (3 * x) - 3 * (3 * y) => -9 * (y - x)
// This is an expression we can simplify.
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
IS_IMM_WITH_VAL(Int, mul->lhs(), -9);
IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
IS_VAR_WITH_NAME(sub->lhs(), "y");
IS_VAR_WITH_NAME(sub->rhs(), "x");
}
{
// 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y).
// If the vars are floating point, ops are not associative and we can't
// reorder.
VarHandle x("x", kFloat);
VarHandle y("y", kFloat);
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y");
}
{
// 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y).
// We will simplify subexprs if they dont reorder floating point ops.
VarHandle x("x", kDouble);
VarHandle y("y", kInt);
ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double);
IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9);
IS_VAR_WITH_NAME(rhsMul->rhs(), "y");
}
{
// Prevent reordering if FP propagated from dtypes.
VarHandle x("x", kInt);
VarHandle y("y", kInt);
ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) -
ExprHandle(3) * (ExprHandle(3.f) * y);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float);
IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3);
IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast);
IS_VAR_WITH_NAME(yCast->src_value(), "y");
}
{
VarHandle x("x", kFloat);
VarHandle y("y", kFloat);
// x%y - (x%y - 1) => x%y - (x%y - 1).
// We wont reorder opaque ops if they are FP.
ExprHandle body = (x % y) - ((x % y) - 1);
ExprHandle simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod);
IS_VAR_WITH_NAME(lhsMod->lhs(), "x");
IS_VAR_WITH_NAME(lhsMod->rhs(), "y");
IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub);
IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod);
IS_VAR_WITH_NAME(rhsMod->lhs(), "x");
IS_VAR_WITH_NAME(rhsMod->rhs(), "y");
IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1);
}
const Var* root = simplified.AsNode<Var>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(root->name_hint(), "y");
}
} // namespace jit

View File

@ -9,119 +9,105 @@
namespace torch {
namespace jit {
#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyReorderings) \
_(SimplifyEliminatesVar) \
_(SimplifyAdds) \
_(SimplifyMuls) \
_(SimplifySubs) \
_(SimplifyMultiOp) \
_(SimplifyManyOps) \
_(SimplifyFactorization) \
_(SimplifyFactorizeUneven) \
_(SimplifyDeeperTerms) \
_(SimplifyDeeperDifference) \
_(SimplifyFoldComplexDifference) \
_(SimplifyIfComponents) \
_(SimplifyOpaqueTerms) \
_(SimplifyWontReorderFloat) \
#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyEliminatesVar) \
_(StmtClone)
#define TH_FORALL_TESTS_LLVM(_) \

View File

@ -205,7 +205,6 @@ libtorch_sources = [
"torch/csrc/jit/tensorexpr/ir.cpp",
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
"torch/csrc/jit/tensorexpr/ir_printer.cpp",
"torch/csrc/jit/tensorexpr/ir_simplifier.cpp",
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
"torch/csrc/jit/tensorexpr/kernel.cpp",
"torch/csrc/jit/tensorexpr/llvm_codegen.cpp",

View File

@ -59,24 +59,24 @@ class Value {
void* ptr;
};
#define VALUE_AS_DISPATCH(Type, Name) \
template <> \
inline Type Value::as<Type>() const { \
if (dtype_ != k##Name) { \
throw unsupported_dtype(); \
} \
return Name##values[0]; \
#define VALUE_AS_DISPATCH(Type, Name) \
template <> \
inline Type Value::as<Type>() const { \
if (dtype_ != k##Name) { \
throw unsupported_dtype(); \
} \
return Name##values[0]; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH);
#undef VALUE_AS_DISPATCH
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
template <> \
inline const std::vector<Type>& Value::as_vec<Type>() const { \
if (dtype_.scalar_type() != ScalarType::Name) { \
throw unsupported_dtype(); \
} \
return Name##values; \
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
template <> \
inline const std::vector<Type>& Value::as_vec<Type>() const { \
if (dtype_.scalar_type() != ScalarType::Name) { \
throw unsupported_dtype(); \
} \
return Name##values; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH);
#undef VALUE_AS_VEC_DISPATCH
@ -479,6 +479,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
throw malformed_input(v);
}
if (src_dtype != dst_dtype) {
switch (src_dtype.scalar_type()) {
#define SRC_TYPE_CASE(Type, Name) \
@ -911,28 +912,6 @@ inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) {
return stmt->accept_mutator(&var_sub);
}
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
static Expr* evaluateOp(const Expr* v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);
switch (v->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: { \
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val); \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
return nullptr;
}
return nullptr;
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch

View File

@ -31,8 +31,6 @@ enum IRNodeType {
kCompareSelect,
kLet,
kCast,
kBroadcast,
kRamp,
kNone
};

View File

@ -1,5 +1,3 @@
#pragma once
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
@ -318,13 +316,6 @@ class HashProvider : public IRVisitor {
putHash(v, hash);
}
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
SimplifierHashType seed = 0;
_hash_combine(seed, args...);
return seed;
}
private:
SimplifierHashType hashOf(const Expr* e) {
auto it = exprToHash_.find(e);
@ -380,10 +371,6 @@ class HashProvider : public IRVisitor {
(seed << 7) + (seed >> 4);
}
void _hash_combine(SimplifierHashType& seed, const Expr* e) {
_hash_combine(seed, hash(e));
}
template <typename T, typename... Types>
void _hash_combine(
SimplifierHashType& seed,
@ -393,6 +380,13 @@ class HashProvider : public IRVisitor {
_hash_combine(seed, args...);
}
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
SimplifierHashType seed = 0;
_hash_combine(seed, args...);
return seed;
}
void putHash(const KernelScopedObject* e, SimplifierHashType h) {
auto res = exprToHash_.emplace(e, h);
if (res.second == false) {

View File

@ -283,94 +283,24 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
// Get immediate by ScalarType.
template <typename T>
Expr* getImmediateByType(ScalarType immType, T initialVal) {
ExprHandle getImmediateByType(ScalarType immType, T initialVal) {
switch (immType) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return new Name##Imm(initialVal);
return Name##Imm::make(initialVal);
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
default:
throw unsupported_dtype();
}
return nullptr;
return ExprHandle();
}
template <typename T>
Expr* getImmediateByType(Dtype dtype, T initialVal) {
ExprHandle getImmediateByType(Dtype dtype, T initialVal) {
return getImmediateByType<T>(dtype.scalar_type(), initialVal);
}
template <typename T>
T immediateAs(const Expr* e) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value(); \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
return 0;
}
template <typename T>
bool immediateEquals(const Expr* e, T val) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value() == val; \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
throw unsupported_dtype();
return false;
}
template <typename T>
bool immediateIsNegative(const T* e) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value() < 0; \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
return false;
}
// Creates a new Expr of the given type with the provided lhs and rhs.
static const Expr* newBinaryOpOfType(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs,
bool option) {
switch (expr_type) {
case IRNodeType::kAdd:
return new Add(lhs, rhs);
case IRNodeType::kSub:
return new Sub(lhs, rhs);
case IRNodeType::kMul:
return new Mul(lhs, rhs);
case IRNodeType::kDiv:
return new Div(lhs, rhs);
case IRNodeType::kMod:
return new Mod(lhs, rhs);
case IRNodeType::kMax:
return new Max(lhs, rhs, option);
case IRNodeType::kMin:
return new Min(lhs, rhs, option);
case IRNodeType::kAnd:
return new And(lhs, rhs);
case IRNodeType::kXor:
return new Xor(lhs, rhs);
case IRNodeType::kLshift:
return new Lshift(lhs, rhs);
case IRNodeType::kRshift:
return new Rshift(lhs, rhs);
default:
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
return nullptr;
}
}
// Bind the value to the var and evaluate the body.
class Let : public ExprNode<Let> {
public:
@ -424,7 +354,7 @@ class Ramp : public ExprNode<Ramp> {
}
Ramp(const Expr* base, const Expr* stride, int lanes)
: ExprNodeBase(Dtype(base->dtype(), lanes), kRamp),
: ExprNodeBase(Dtype(base->dtype(), lanes)),
base_(base),
stride_(stride),
lanes_(lanes) {
@ -490,7 +420,7 @@ class Broadcast : public ExprNode<Broadcast> {
return ExprHandle(new Broadcast(value.node(), lanes));
}
Broadcast(const Expr* value, int lanes)
: ExprNodeBase(Dtype(value->dtype(), lanes), kBroadcast),
: ExprNodeBase(Dtype(value->dtype(), lanes)),
value_(value),
lanes_(lanes) {}
@ -632,7 +562,8 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
const ExprHandle& ret_val1,
const ExprHandle& ret_val2,
CompareSelectOperation cmp_op) {
if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
if (lhs.dtype() != rhs.dtype() ||
ret_val1.dtype() != ret_val2.dtype()) {
throw malformed_input();
}
return ExprHandle(new CompareSelect(
@ -859,8 +790,48 @@ class Intrinsics : public CallNode<Intrinsics> {
IntrinsicsOp op_type_;
};
class Polynomial;
class Term;
/* An internal only Expr used in IR simplification.
* Encodes relationship y = Ax + B, where A and B are Immediates.
* Not required to be implemented by codegen. */
class LinearForm : public ExprNode<LinearForm> {
public:
LinearForm(const Expr* x, const Expr* A, const Expr* B)
: ExprNodeBase(dtypeFor(x, A, B)), x_(x), A_(A), B_(B) {}
LinearForm(const Expr* x)
: ExprNodeBase(x->dtype()),
x_(x),
A_(new CharImm(1)),
B_(new CharImm(0)) {}
const Expr* getX() const {
return x_;
}
const Expr* getA() const {
return A_;
}
const Expr* getB() const {
return B_;
}
void setA(const Expr* A) {
A_ = A;
}
void setB(const Expr* B) {
B_ = B;
}
static Dtype dtypeFor(const Expr* A, const Expr* B, const Expr* C) {
return ToDtype(promoteTypes(
A->dtype().scalar_type(), promoteTypes(B->dtype(), C->dtype())));
}
private:
const Expr* x_;
const Expr* A_;
const Expr* B_;
};
class FunctionCall;

View File

@ -2,7 +2,6 @@
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
namespace torch {
namespace jit {
@ -215,7 +214,6 @@ const Expr* IRMutator::mutate(const IfThenElse* v) {
const Expr* condition_new = condition->accept_mutator(this);
const Expr* true_value_new = true_value->accept_mutator(this);
const Expr* false_value_new = false_value->accept_mutator(this);
if (condition == condition_new && true_value == true_value_new &&
false_value == false_value_new) {
return v;
@ -234,24 +232,11 @@ const Expr* IRMutator::mutate(const FunctionCall* v) {
return this->mutate(base);
}
const Expr* IRMutator::mutate(const Term* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new Term(v->hasher(), newScalar, variables);
}
const Expr* IRMutator::mutate(const Polynomial* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
std::vector<const Term*> variables;
for (const auto* t : v->variables()) {
variables.push_back(static_cast<const Term*>(t->accept_mutator(this)));
}
return new Polynomial(v->hasher(), newScalar, variables);
const Expr* IRMutator::mutate(const LinearForm* v) {
const Expr* new_x = v->getX()->accept_mutator(this);
const Expr* new_a = v->getA()->accept_mutator(this);
const Expr* new_b = v->getB()->accept_mutator(this);
return new LinearForm(new_x, new_a, new_b);
}
const Expr* IRMutator::mutate(const BaseCallNode* v) {

View File

@ -45,8 +45,7 @@ class Allocate;
class Free;
class Cond;
class Stmt;
class Term;
class Polynomial;
class LinearForm;
class TORCH_API IRMutator {
public:
@ -87,8 +86,7 @@ class TORCH_API IRMutator {
virtual const Expr* mutate(const Intrinsics* v);
virtual const Expr* mutate(const FunctionCall* v);
virtual const Expr* mutate(const Term* v);
virtual const Expr* mutate(const Polynomial* v);
virtual const Expr* mutate(const LinearForm* v);
virtual Stmt* mutate(const For* v);
virtual Stmt* mutate(const Block* v);

View File

@ -1,7 +1,5 @@
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
namespace torch {
namespace jit {
namespace tensorexpr {
@ -200,7 +198,7 @@ template <
typename T,
std::enable_if_t<!std::is_floating_point<T>::value>* = nullptr>
static void formatImm(std::ostream& os, T v) {
os << +v;
os << v;
}
// NOLINTNEXTLINE
@ -370,33 +368,9 @@ void IRPrinter::visit(const Cond* v) {
}
}
void IRPrinter::visit(const Term* v) {
os() << "Term(";
v->scalar()->accept(this);
for (auto* t : v->variables()) {
os() << ",";
t->accept(this);
}
os() << ")";
}
void IRPrinter::visit(const Polynomial* v) {
bool first = true;
os() << "Polynomial(";
for (auto* t : v->variables()) {
emitIndent();
if (!first) {
os() << " + ";
}
first = false;
t->accept(this);
}
if (!first) {
os() << " + ";
}
v->scalar()->accept(this);
os() << ")";
void IRPrinter::visit(const LinearForm* v) {
os() << "(" << *v->getA() << ") * (" << *v->getX() << ") + (" << *v->getB()
<< ")" << std::endl;
}
void IRPrinter::emitIndent() {

View File

@ -48,8 +48,7 @@ class TORCH_API IRPrinter : public IRVisitor {
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Cond* v) override;
void visit(const Term* v) override;
void visit(const Polynomial* v) override;
void visit(const LinearForm* v) override;
std::ostream& os() {
return printer_os_;

File diff suppressed because it is too large Load Diff

View File

@ -1,275 +1,367 @@
#pragma once
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/hash_server.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>
/* IR Simplification
*
* Simplfies expressions in two stages:
* 1. Recursively traverse the map combining similar operations into Terms
* (interacted via Multiplication) and Polynomials (interacted via Addition). We
* reorder the components of each Term or Polynomial into a consistent order to
* allow combination or cancelling of like terms.
* 2. Once the format of the tree is minimal, expand each Term into a sequence
* of Muls, and each Polynomial into a sequence of Ads.
*/
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
#include "torch/csrc/jit/tensorexpr/types.h"
namespace torch {
namespace jit {
namespace tensorexpr {
// A bunch of helpers for determine the Dtype of the output of a multi argument
// Term or Polynomial.
namespace {
template <class ExprType>
Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
Dtype t = s->dtype();
bool first = true;
// Uses the evaluator to fold an operation with constant terms.
// Expr v must be evaluatable without Vars.
static Expr* evaluateOp(const Expr* v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);
for (auto* e : v) {
if (first) {
t = Dtype(t.scalar_type(), e->dtype().lanes());
first = false;
switch (v->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: { \
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val).node(); \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();
return nullptr;
}
return nullptr;
} // namespace tensorexpr
static const Expr* newBinaryOpOfType(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs,
bool option) {
switch (expr_type) {
case IRNodeType::kAdd:
return new Add(lhs, rhs);
case IRNodeType::kSub:
return new Sub(lhs, rhs);
case IRNodeType::kMul:
return new Mul(lhs, rhs);
case IRNodeType::kDiv:
return new Div(lhs, rhs);
case IRNodeType::kMod:
return new Mod(lhs, rhs);
case IRNodeType::kMax:
return new Max(lhs, rhs, option);
case IRNodeType::kMin:
return new Min(lhs, rhs, option);
case IRNodeType::kAnd:
return new And(lhs, rhs);
case IRNodeType::kXor:
return new Xor(lhs, rhs);
case IRNodeType::kLshift:
return new Lshift(lhs, rhs);
case IRNodeType::kRshift:
return new Rshift(lhs, rhs);
default:
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
return nullptr;
}
}
/* Interprets expr as an Immediate and returns the value as type T. */
template <typename T>
T immediateAs(const Expr* expr) {
T val{0};
switch (expr->dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(expr)) { \
val = imm->value(); \
} else { \
LOG(FATAL) << "Bad expr: " << *expr << "\n"; \
} \
break;
AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << expr->dtype();
}
return val;
}
/* Takes a LinearForm and converts it to Mul + (Add/Sub). */
const Expr* expandLinearForm(const LinearForm* v, IRMutator* mutator) {
const Expr* mul = nullptr;
const Expr* A = v->getA();
const Expr* B = v->getB();
const Expr* X = v->getX();
// we only really care about 0 and 1, so double should be fine.
double Aval = immediateAs<double>(A);
double Bval = immediateAs<double>(B);
// First handle A.
if (Aval == 0) {
if (Bval == 0) {
return getImmediateByType(X->dtype(), 0).node();
}
t = promoteTypes(t, e->dtype());
}
return t;
}
return B;
} else if (Aval == 1) {
mul = X;
} else if (Aval == -1) {
return new Sub(B, X);
} else if (Aval < 0) {
// Negate A.
ExprHandle zero = getImmediateByType(A->dtype(), 0);
Sub* A_Sub = new Sub(zero.node(), A);
template <class ExprType>
Dtype promoteTypesVec(std::vector<const ExprType*>& v) {
if (v.empty()) {
throw malformed_input("empty list of types");
return new Sub(B, new Mul(X, evaluateOp(A_Sub)));
} else {
mul = new Mul(X, A);
}
Dtype t = v[0]->dtype();
for (auto* e : v) {
t = promoteTypes(t, e->dtype());
}
return t;
}
template <class ExprType>
Dtype promoteTypesMap(
const Expr* s,
std::unordered_map<SimplifierHashType, const ExprType*>& m) {
Dtype t = s->dtype();
bool first = true;
for (auto& e : m) {
if (first) {
t = Dtype(t.scalar_type(), e.second->dtype().lanes());
first = false;
}
t = promoteTypes(t, e.second->dtype());
}
return t;
}
template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
return e->dtype();
}
template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
Dtype lhs = e->dtype();
Dtype rhs = promoteTypesVar(es...);
if (e->isConstant()) {
lhs = Dtype(lhs.scalar_type(), rhs.lanes());
if (Bval == 0) {
return mul;
}
return promoteTypes(lhs, rhs);
return new Add(mul, B);
}
// Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast
// or Ramp).
bool isMultilanePrimitive(const Expr* e) {
return e->expr_type() == IRNodeType::kBroadcast ||
e->expr_type() == IRNodeType::kRamp;
}
} // namespace
// A Term represents a grouping of Exprs through multiplication.
// E.g. product(scalar, *variables).
class Term : public ExprNode<Term> {
/* Expand any remaining LinearTerms into their component pieces */
class LinearFormExpander : public IRMutator {
public:
template <class... Args>
Term(HashProvider& hasher, const Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addComponent(ts...);
sort();
const Expr* mutate(const LinearForm* v) {
return expandLinearForm(v, this);
}
Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
hasher_(hasher) {
sort();
}
// Convenience constructor from a map of hash -> var, used when merging Terms.
Term(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Expr*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addComponent(p.second);
}
sort();
}
const Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
return hasher_;
}
// Produce a hash of just the variable components of this term, to determine
// if it can be combined with another term.
SimplifierHashType hashVars() const;
private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
HashProvider& hasher_;
void addComponent() {}
void addComponent(const Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
// Sort by hash to normalize order of components.
void sort();
};
// Polynomial represents a grouping of Exprs by addition.
// E.g. sum(*variables, scalar).
// This would better be called Expression, but, naming conflict...
class Polynomial : public ExprNode<Polynomial> {
/* Simplify the IR by combining arithmetic expressions over a common term.
*/
class IRSimplifier : public IRMutator {
public:
template <class... Args>
Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addTerm(ts...);
sort();
}
const Expr* mutate(const Add* v) override {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
hasher_(hasher) {
sort();
}
// Helper constructor for list of terms with no scalar component.
Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
: ExprNodeBase(promoteTypesVec(terms)),
variables_(std::move(terms)),
scalar_(getImmediateByType(dtype(), 0)),
hasher_(hasher) {
sort();
}
// Convenience constructor for map of hash -> var, used when merging
// Polynomials.
Polynomial(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Term*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addTerm(p.second);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
const Expr* result = evaluateOp(v);
return result;
}
sort();
const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);
if (lhsLinear && rhsLinear) {
// Can add two LinearTerms if they reference the same Var.
if (lhsLinear->getX() == rhsLinear->getX()) {
Add* A_Add = new Add(lhsLinear->getA(), rhsLinear->getA());
Add* B_Add = new Add(lhsLinear->getB(), rhsLinear->getB());
LinearForm* linear = new LinearForm(
lhsLinear->getX(), evaluateOp(A_Add), evaluateOp(B_Add));
return linear;
}
// otherwise cannot simplify further.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
// Can add a scalar into the B term of LinearTerm.
if (lhsLinear && rhs_new->isConstant()) {
Add* B_Add = new Add(lhsLinear->getB(), rhs_new);
LinearForm* linear = new LinearForm(
lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Add));
return linear;
}
if (rhsLinear && lhs_new->isConstant()) {
Add* B_Add = new Add(rhsLinear->getB(), lhs_new);
LinearForm* linear = new LinearForm(
rhsLinear->getX(), rhsLinear->getA(), evaluateOp(B_Add));
return linear;
}
// Can create a LinearTerm over any sub expression.
if (lhs_new->isConstant()) {
LinearForm* linear = new LinearForm(rhs_new);
linear->setB(evaluateOp(lhs_new));
return linear;
}
if (rhs_new->isConstant()) {
LinearForm* linear = new LinearForm(lhs_new);
linear->setB(evaluateOp(rhs_new));
return linear;
}
/// Broadcasts are a bit more involved.
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(lhs_new)) {
if (const Expr* ret = handleBroadcastAdd(bc, rhs_new)) {
return ret;
}
}
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(rhs_new)) {
if (const Expr* ret = handleBroadcastAdd(bc, lhs_new)) {
return ret;
}
}
// No change.
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
// Cannot simplify.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
const Expr* scalar() const {
return scalar_;
}
const std::vector<const Term*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
return hasher_;
const Expr* mutate(const Sub* v) override {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
const Expr* result = evaluateOp(v);
return result;
}
const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);
if (lhsLinear && rhsLinear) {
// Can sub two LinearTerms if they reference the same Var.
if (lhsLinear->getX() == rhsLinear->getX()) {
Sub* A_Sub = new Sub(lhsLinear->getA(), rhsLinear->getA());
Sub* B_Sub = new Sub(lhsLinear->getB(), rhsLinear->getB());
LinearForm* linear = new LinearForm(
lhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub));
return linear;
}
// otherwise cannot simplify further.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
// Can just sub from B term if LHS is a LinearTerm.
if (lhsLinear && rhs_new->isConstant()) {
Sub* B_Sub = new Sub(lhsLinear->getB(), rhs_new);
LinearForm* linear = new LinearForm(
lhsLinear->getX(), lhsLinear->getA(), evaluateOp(B_Sub));
return linear;
}
// Slightly more complicated if the RHS is LinearTerm.
if (rhsLinear && lhs_new->isConstant()) {
// The linear needs to be negated.
ExprHandle zero = getImmediateByType(rhsLinear->getA()->dtype(), 0);
Sub* A_Sub = new Sub(zero.node(), rhsLinear->getA());
Sub* B_Sub = new Sub(rhsLinear->getB(), lhs_new);
LinearForm* linear = new LinearForm(
rhsLinear->getX(), evaluateOp(A_Sub), evaluateOp(B_Sub));
return linear;
}
// Can create a new LinearTerm, but since the B term is defined as Add we
// must negate it.
if (rhs_new->isConstant()) {
LinearForm* linear = new LinearForm(lhs_new);
ExprHandle zero = getImmediateByType(linear->getA()->dtype(), 0);
Sub* B_Sub = new Sub(zero.node(), rhs_new);
linear->setB(evaluateOp(B_Sub));
return linear;
}
// Can create a new LinearTerm with the A term -1 to negate the Expr.
if (lhs_new->isConstant()) {
// Negate by using -1 as the first linear.
ExprHandle negOne = getImmediateByType(rhs_new->dtype(), -1);
LinearForm* linear =
new LinearForm(rhs_new, negOne.node(), evaluateOp(lhs_new));
return linear;
}
// Nothing to do.
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
// Cannot simplify.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
SimplifierHashType hashVars() const;
const Expr* mutate(const Mul* v) override {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
private:
std::vector<const Term*> variables_;
const Expr* scalar_;
HashProvider& hasher_;
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
return evaluateOp(v);
}
void addTerm(const Term* t) {
variables_.push_back(t);
const LinearForm* lhsLinear = dynamic_cast<const LinearForm*>(lhs_new);
const LinearForm* rhsLinear = dynamic_cast<const LinearForm*>(rhs_new);
if (lhsLinear && rhsLinear) {
// Lets not get into higher order terms.
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
// Easy to simplify into an existing LinearTerm by multiplying A and B.
if (lhsLinear && rhs_new->isConstant()) {
Mul* A_Mul = new Mul(lhsLinear->getA(), rhs_new);
Mul* B_Mul = new Mul(lhsLinear->getB(), rhs_new);
LinearForm* linear = new LinearForm(
lhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul));
return linear;
}
if (rhsLinear && lhs_new->isConstant()) {
Mul* A_Mul = new Mul(rhsLinear->getA(), lhs_new);
Mul* B_Mul = new Mul(rhsLinear->getB(), lhs_new);
LinearForm* linear = new LinearForm(
rhsLinear->getX(), evaluateOp(A_Mul), evaluateOp(B_Mul));
return linear;
}
// Easy to create a new LinearTerm by setting term A.
if (lhs_new->isConstant()) {
LinearForm* linear = new LinearForm(rhs_new);
linear->setA(evaluateOp(lhs_new));
return linear;
}
if (rhs_new->isConstant()) {
LinearForm* linear = new LinearForm(lhs_new);
linear->setA(evaluateOp(rhs_new));
return linear;
}
// Broadcasts have special logic.
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(lhs_new)) {
if (const Expr* ret = handleBroadcastMul(bc, rhs_new)) {
return ret;
}
}
if (const Broadcast* bc = dynamic_cast<const Broadcast*>(rhs_new)) {
if (const Expr* ret = handleBroadcastMul(bc, lhs_new)) {
return ret;
}
}
// Cannot be simplified, just exit.
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
return expandAndRecurse(v->expr_type(), lhs_new, rhs_new);
}
template <class... Ts>
void addTerm(const Term* t, Ts... ts) {
addTerm(t);
addTerm(ts...);
}
// Sort by hash to normalize order of terms.
void sort();
};
// Simplify the IR by combining arithmetic expressions over common terms.
class TORCH_API PolynomialTransformer : public IRMutator {
public:
// Inserts term into the provided map, in the case of a hash collision
// combines the term with the existing and updates the map.
void addOrUpdateTerm(
std::unordered_map<SimplifierHashType, const Term*>& varmap,
const Term* term);
// Add Polynomial expressions, combining Terms representing the same
// variables.
const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs);
// Insert a new Term into the provided polynomial. If the new term has common
// variables to an existing term it is combined.
const Expr* insertTerm(const Polynomial* poly, const Term* term);
// Merge and simplify addition.
const Expr* mutate(const Add* v) override;
// Subtract one term from another, cancelling if necessary.
const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated);
// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
// possible.
const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs);
// Merge and simplify subtraction.
const Expr* mutate(const Sub* v) override;
// Multiply two terms together, usually creating a new term with the variable
// lists concatenated.
const Term* mulTerms(const Term* lhs, const Term* rhs);
// Multiply a Polynomial by a Term.
const Expr* polyByTerm(const Polynomial* poly, const Term* term);
// Merge and simplify multiplication.
const Expr* mutate(const Mul* v) override;
const Expr* mutate(const Div* v) override {
// TODO div simplification will require a rational node.
@ -304,9 +396,37 @@ class TORCH_API PolynomialTransformer : public IRMutator {
return mutateBinaryOp(v, this, v->propagate_nans());
}
const Expr* mutate(const Intrinsics* v) override;
const Expr* mutate(const Intrinsics* v) override {
std::vector<const Expr*> new_params;
bool changed = false;
bool allConstant = true;
for (const auto* p : v->params()) {
const Expr* new_child = p->accept_mutator(this);
new_params.push_back(new_child);
const Expr* mutate(const Cast* v) override;
changed |= p != new_child;
allConstant &= new_child->isConstant();
}
const Expr* node = v;
if (changed) {
node = new Intrinsics(v->op_type(), new_params);
}
if (!allConstant || !v->isPure()) {
return node;
}
return evaluateOp(node);
}
const Expr* mutate(const Cast* v) override {
if (v->src_value()->isConstant()) {
return evaluateOp(v);
}
return v;
}
template <typename Op>
static const Expr* mutateBinaryOp(
@ -332,40 +452,12 @@ class TORCH_API PolynomialTransformer : public IRMutator {
return evaluateOp(node);
}
static const Expr* simplify(const Expr* e);
static ExprHandle simplify(const ExprHandle& e);
static Stmt* simplify(Stmt* e);
private:
HashProvider hasher_;
}; // namespace tensorexpr
// Expands Terms and Polynomial expressions into primitive operations.
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public IRMutator {
PolynomialTransformer* simplifier_;
public:
TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {}
// Expand Terms out to a series of Muls.
const Expr* mutate(const Term* v) override;
// Trivially factorize terms by GCD of scalar components.
const Expr* factorizePolynomial(const Polynomial* poly);
// Expand Polynomials out to a series of Adds.
const Expr* mutate(const Polynomial* v);
};
class TORCH_API IRSimplifier {
public:
static const Expr* simplify(const Expr* e) {
PolynomialTransformer simplifier;
IRSimplifier simplifier;
e = e->accept_mutator(&simplifier);
// There may be terms left in the IR, expand them.
TermExpander expander(&simplifier);
LinearFormExpander expander;
e = e->accept_mutator(&expander);
return e;
@ -376,15 +468,74 @@ class TORCH_API IRSimplifier {
}
static Stmt* simplify(Stmt* s) {
PolynomialTransformer simplifier;
IRSimplifier simplifier;
s = s->accept_mutator(&simplifier);
// There may be terms left in the IR, expand them.
TermExpander expander(&simplifier);
LinearFormExpander expander;
s = s->accept_mutator(&expander);
return s;
}
private:
/* Expands lhs and rhs if they are LinearForms, wraps the results in a new
 * binary op of the requested type, and re-simplifies that op whenever the
 * expansion produced a constant operand. */
const Expr* expandAndRecurse(
    IRNodeType expr_type,
    const Expr* lhs,
    const Expr* rhs) {
  // Lower any LinearForm operand back into primitive ops first.
  if (auto* linearLhs = dynamic_cast<const LinearForm*>(lhs)) {
    lhs = expandLinearForm(linearLhs, this);
  }
  if (auto* linearRhs = dynamic_cast<const LinearForm*>(rhs)) {
    rhs = expandLinearForm(linearRhs, this);
  }
  const Expr* combined = newBinaryOpOfType(expr_type, lhs, rhs, false);
  // Expansion may have folded a side down to a constant; in that case the
  // combined op might merge further into a linear term, so give the
  // simplifier another pass over it.
  if (lhs->isConstant() || rhs->isConstant()) {
    return combined->accept_mutator(this);
  }
  return combined;
}
/* Handles optimization cases for Broadcast() + Other.
 * Returns nullptr when no rewrite applies. */
const Expr* handleBroadcastAdd(const Broadcast* bc, const Expr* other) {
  const Expr* scalar = bc->value();
  // Broadcast(0) + x  =>  x.
  if (scalar->isConstant() && immediateAs<int>(scalar) == 0) {
    return other;
  }
  // Broadcast(c) + Ramp(b, s)  =>  Ramp(c + b, s), then re-simplify.
  const Ramp* ramp = dynamic_cast<const Ramp*>(other);
  if (ramp == nullptr) {
    return nullptr;
  }
  const Expr* folded =
      new Ramp(new Add(scalar, ramp->base()), ramp->stride(), ramp->lanes());
  return folded->accept_mutator(this);
}
/* Handles optimization cases for Broadcast() * Other.
 * Returns nullptr when no rewrite applies. */
const Expr* handleBroadcastMul(const Broadcast* bc, const Expr* other) {
  const Expr* scalar = bc->value();
  // Broadcast(1) * x  =>  x.
  if (scalar->isConstant() && immediateAs<int>(scalar) == 1) {
    return other;
  }
  // Broadcast(c) * Ramp(b, s)  =>  Ramp(c * b, c * s), then re-simplify.
  const Ramp* ramp = dynamic_cast<const Ramp*>(other);
  if (ramp == nullptr) {
    return nullptr;
  }
  const Expr* folded = new Ramp(
      new Mul(scalar, ramp->base()),
      new Mul(scalar, ramp->stride()),
      ramp->lanes());
  return folded->accept_mutator(this);
}
};
} // namespace tensorexpr

View File

@ -1,7 +1,6 @@
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
namespace torch {
@ -177,18 +176,10 @@ void IRVisitor::visit(const Cond* v) {
}
}
// Default traversal for a Term node: visit the scalar coefficient, then
// every variable factor in order.
void IRVisitor::visit(const Term* v) {
  v->scalar()->accept(this);
  for (auto* component : v->variables()) {
    component->accept(this);
  }
}
void IRVisitor::visit(const Polynomial* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
// Default traversal for a LinearForm node: visits its three
// sub-expressions — a, x, b — in that order.
void IRVisitor::visit(const LinearForm* v) {
v->getA()->accept(this);
v->getX()->accept(this);
v->getB()->accept(this);
}
} // namespace tensorexpr

View File

@ -42,8 +42,7 @@ class FunctionCall;
class Allocate;
class Free;
class Cond;
class Term;
class Polynomial;
class LinearForm;
class TORCH_API IRVisitor {
public:
@ -91,8 +90,7 @@ class TORCH_API IRVisitor {
virtual void visit(const Allocate* v);
virtual void visit(const Free* v);
virtual void visit(const Cond* v);
virtual void visit(const Term* v);
virtual void visit(const Polynomial* v);
virtual void visit(const LinearForm* v);
};
} // namespace tensorexpr

View File

@ -1114,7 +1114,6 @@ void TensorExprKernel::lowerToBackend(BackendType backendType) {
l.ApplyInlines();
Stmt* stmt = l.root_stmt();
// Arithmetic Simplification.
stmt = IRSimplifier::simplify(stmt);
// Set up formal params (inputs, then outputs) for kernel.

View File

@ -1,6 +1,6 @@
#include <torch/csrc/jit/tensorexpr/types.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/types.h>
#include <c10/util/Logging.h>

View File

@ -50,7 +50,7 @@ class TORCH_API Dtype {
Dtype(Dtype type, int lanes)
: scalar_type_(type.scalar_type_), lanes_(lanes) {
if (type.lanes() != 1) {
throw malformed_input("dtype lanes dont match");
throw malformed_input();
}
}
int lanes() const {
@ -108,15 +108,10 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
return static_cast<ScalarType>(c10::promoteTypes(
static_cast<c10::ScalarType>(a), static_cast<c10::ScalarType>(b)));
}
inline Dtype promoteTypes(Dtype a, Dtype b) {
if (a.lanes() != b.lanes()) {
throw malformed_input("promoting types with different lanes");
}
return Dtype(
static_cast<ScalarType>(c10::promoteTypes(
static_cast<c10::ScalarType>(a.scalar_type()),
static_cast<c10::ScalarType>(b.scalar_type()))),
a.lanes());
inline ScalarType promoteTypes(Dtype a, Dtype b) {
return static_cast<ScalarType>(c10::promoteTypes(
static_cast<c10::ScalarType>(a.scalar_type()),
static_cast<c10::ScalarType>(b.scalar_type())));
}
inline Dtype BinaryOpDtype(
@ -132,21 +127,22 @@ inline Dtype BinaryOpDtype(
}
if (op1_dtype.lanes() != op2_dtype.lanes()) {
throw malformed_input("lanes dont match");
throw malformed_input();
}
int lanes = op1_dtype.lanes();
Dtype resultType = promoteTypes(op1_dtype, op2_dtype);
if (resultType.scalar_type() == ScalarType::Undefined) {
throw malformed_input("scalar type doesn't match");
ScalarType resultType = promoteTypes(op1_dtype, op2_dtype);
if (resultType == ScalarType::Undefined) {
throw malformed_input();
}
if (lanes == 1) {
// Use the fixed scalar Dtypes.
return ToDtype(resultType.scalar_type());
return ToDtype(resultType);
}
return resultType;
return Dtype(resultType, lanes);
}
} // namespace tensorexpr