mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This reverts commit d742a2896c571a535003d5928fe80397325575a5. Reverted https://github.com/pytorch/pytorch/pull/158928 on behalf of https://github.com/yangw-dev due to this breaks bunch of internal dependency since some tests are still using the deleted test files from this pr, the internal reviewer please help fix this using codev ([comment](https://github.com/pytorch/pytorch/pull/158928#issuecomment-3134378616))
3703 lines
87 KiB
C++
3703 lines
87 KiB
C++
#include <gtest/gtest.h>
|
|
#include "test/cpp/tensorexpr/test_base.h"
|
|
|
|
#include "test/cpp/tensorexpr/test_utils.h"
|
|
#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
|
|
#include "torch/csrc/jit/tensorexpr/registerizer.h"
|
|
|
|
#include <iostream>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
// Can replace a simple scalar access with a local variable.
|
|
TEST(Registerizer, RegisterizerSimple) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't do replacement of a loop access.
|
|
TEST(Registerizer, RegisterizerLoop) {
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK-NOT: int
|
|
# CHECK: A[0] = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A_
|
|
# CHECK: A[x] =
|
|
# CHECK-NOT: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't replace even if the load is a fixed scalar, since the store could
|
|
// invalidate it.
|
|
TEST(Registerizer, RegisterizerLoopFixedLoad) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK-NOT: int
|
|
# CHECK: A[0] = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A_
|
|
# CHECK: A[x] =
|
|
# CHECK-NOT: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// We can registerize accesses that occur entirely within inner scopes, even if
|
|
// they depend on the loop var.
|
|
TEST(Registerizer, RegisterizerLoopInternal) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (A[x]) + x;
|
|
* A[x] = (A[x]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
// TODO: the order of terms in addition changes and in general depends on
|
|
// some hash value. This results in unpredictable swaps of the operands from
|
|
// random changes, which is not great. Ideally, we should ensure some
|
|
// specific order (ideally, the original one).
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* A_1 = x + A_1;
|
|
* A_1 = x + A_1;
|
|
* A[x] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: A[x] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// An access can be overlapped by another read in the same Expr. In this case
|
|
// B[z] and B[y] overlap and prevent registerization of both accesses.
|
|
TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = (B[y]) + (B[z]);
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// The same pair of accesses (read A[1], write A[0]) repeated across two
// sibling loops is registerized once, with the scalars declared before the
// first loop and A[0] written back after the second.
TEST(Registerizer, RegisterizerLoopInternalRepeated) {
  // NOTE(review): A is declared with a single element but index 1 is read
  // below; the output checked here does not appear to depend on the declared
  // bound - confirm whether the dimension was meant to be {2}.
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
                Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
                Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}))

      });

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[1]);
   *   A[0] = x + (A[1]);
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[1]);
   *   A[0] = x + (A[1]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[1];
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   *   A_2 = A_1 + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   *   A_2 = A_1 + x;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  // A[1] is only read, so it must never be stored back after the loops.
  const std::string verification_pattern =
      R"IR(
# CHECK: int A_1 = A[1];
# CHECK: int A_2 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: A_2 = A_1 + x;
# CHECK: A_2 = A_1 + x;
# CHECK: }
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: A_2 = A_1 + x;
# CHECK: A_2 = A_1 + x;
# CHECK: }
# CHECK-NOT: A[1]
# CHECK: A[0] = A_2;
# CHECK-NOT: A[1]
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Two sibling loops each store A[0] while loading A[x], where x is the loop
// variable. The test expects the registerizer to leave the program untouched.
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Builds a fresh copy of:
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[x]) + x;
  //     A[0] = (A[x]) + x;
  //   }
  const auto make_loop = [&]() {
    return For::make(
        x,
        0,
        10,
        Block::make(
            {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
             Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}));
  };

  StmtPtr stmt = Block::make({make_loop(), make_loop()});
  stmt = IRSimplifier::simplify(stmt);

  // Snapshot the printed IR before running the registerizer.
  std::ostringstream unregisterized;
  unregisterized << *stmt;

  // Expected to be a no-op.
  stmt = registerize(stmt);

  std::ostringstream registerized;
  registerized << *stmt;

  ASSERT_EQ(unregisterized.str(), registerized.str());
}
|
|
|
|
// Two sibling loops each store A[0] while loading A[y], where y is a free
// variable unrelated to the loop. The test expects the registerizer to leave
// the program untouched.
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Builds a fresh copy of (as constructed, before simplification):
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = x + (A[y]);
  //     A[0] = x + (A[y]);
  //   }
  const auto make_loop = [&]() {
    return For::make(
        x,
        0,
        10,
        Block::make(
            {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
             Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}));
  };

  StmtPtr stmt =
      IRSimplifier::simplify(Block::make({make_loop(), make_loop()}));

  // Snapshot the printed IR before running the registerizer.
  std::ostringstream unregisterized;
  unregisterized << *stmt;

  // Expected to be a no-op.
  stmt = registerize(stmt);

  std::ostringstream registerized;
  registerized << *stmt;

  ASSERT_EQ(unregisterized.str(), registerized.str());
}
|
|
|
|
// Will registerize multiple accesses of different items of the same buffer.
|
|
TEST(Registerizer, RegisterizerMultiVar) {
|
|
BufHandle a("A", {2}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {0}, 0),
|
|
Store::make(a, {1}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
|
|
});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* A[1] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + x;
|
|
* A[1] = (A[1]) - x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* int A_2 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = x + A_2;
|
|
* A_1 = A_1 - x;
|
|
* }
|
|
* A[1] = A_2;
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: int A_2 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A_2 =
|
|
# CHECK: A[1] = A_2
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Will registerize the valid accesses while skipping invalid replacements.
|
|
TEST(Registerizer, RegisterizerVariableLoad) {
|
|
BufHandle a("A", {1}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle x2("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(x, 0, 10, Store::make(b, {x}, x)),
|
|
For::make(
|
|
x2,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = x;
|
|
* }
|
|
* for (int x_1 = 0; x_1 < 10; x_1++) {
|
|
* A[0] = (A[0]) + (B[x_1]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = x;
|
|
* }
|
|
* for (int x_1 = 0; x_1 < 10; x_1++) {
|
|
* A_1 = A_1 + (B[x_1]);
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: B[x] = x
|
|
# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize variable accesses so long as the variable does not change.
|
|
TEST(Registerizer, RegisterizerSymbolicIndices) {
|
|
VarHandle i("i", kInt);
|
|
VarHandle N("N", kInt);
|
|
BufHandle a("A", {N}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {i}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))});
|
|
|
|
/*
|
|
* A[i] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[i] = (A[i]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[i] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[i] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize accesses dependent on multiple loop vars.
|
|
TEST(Registerizer, RegisterizerMultiLoop) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a,
|
|
{0},
|
|
Mul::make(Add::make(Load::make(a, {0}), x), y))})))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = x * y + (A[0]) * y;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = x * y + y * A_1;
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK: for (int y = 0; y < 10; y++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize correctly if scalars already exist in the program.
|
|
TEST(Registerizer, RegisterizerRepeated) {
|
|
BufHandle a("A", {2}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {0}, 0),
|
|
Store::make(a, {1}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
|
|
});
|
|
|
|
// Registerize manually to make sure we only replace a single target.
|
|
{
|
|
registerizer::RegisterizerAnalysis analysis;
|
|
stmt->accept(&analysis);
|
|
auto candidates = analysis.getCandidates();
|
|
ASSERT_EQ(candidates.size(), 2);
|
|
|
|
candidates.pop_back();
|
|
registerizer::RegisterizerReplacer replacer(candidates);
|
|
stmt = stmt->accept_mutator(&replacer);
|
|
}
|
|
|
|
// Re-analyze and replace the second target.
|
|
{
|
|
registerizer::RegisterizerAnalysis analysis;
|
|
stmt->accept(&analysis);
|
|
auto candidates = analysis.getCandidates();
|
|
ASSERT_EQ(candidates.size(), 1);
|
|
|
|
registerizer::RegisterizerReplacer replacer(candidates);
|
|
stmt = stmt->accept_mutator(&replacer);
|
|
}
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: int A_1_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A_1_1 =
|
|
# CHECK: A[1] = A_1_1;
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize the load of A.
|
|
TEST(Registerizer, RegisterizerNoLoads) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = x + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can registerize the load of A but not the store of B.
|
|
TEST(Registerizer, RegisterizerNoRepeatedStores) {
|
|
BufHandle a("A", {1}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});
|
|
|
|
/*
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = (A[0]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
// TODO: its unnecessary to reorder the initializer of A[0], but it's not
|
|
// actually worse so lets not worry for now.
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = x + A_1;
|
|
* }
|
|
* A[0] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A_
|
|
# CHECK: B[x] =
|
|
# CHECK: A[0] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't registerize if there are multiple accesses which may overlap.
|
|
TEST(Registerizer, RegisterizerMultiVarOverlap) {
|
|
BufHandle a("A", {2}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make({
|
|
Store::make(a, {0}, 0),
|
|
Store::make(a, {1}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),
|
|
Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),
|
|
});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Registerization in the presence of Allocate/Free, where the allocated
// buffer's dimension is itself a load (C[0]) that gets registerized.
TEST(Registerizer, RegisterizerAllocs) {
  BufHandle a("A", {2}, kInt);
  BufHandle c("C", {1}, kInt);
  VarHandle x("x", kInt);

  BufHandle b("B", {Load::make(c, {0})}, kInt);

  StmtPtr stmt = Block::make(
      {Allocate::make(b),
       Store::make(a, {0}, Load::make(c, {0})),
       Store::make(b, {0}, 0),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)),
                Store::make(a, {0}, Load::make(c, {0}))})),
       Free::make(b)});

  /*
   * Allocate(B, int, {C[0]});
   * A[0] = C[0];
   * B[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[0] = (B[0]) + x;
   *   A[0] = C[0];
   * }
   * Free(B);
   */

  stmt = registerize(stmt);

  /*
   * int C_1 = C[0];
   * Allocate(B, int, {C_1});
   * int A_1 = C_1;
   * int B_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   B_1 = B_1 + x;
   *   A_1 = C_1;
   * }
   * B[0] = B_1;
   * A[0] = A_1;
   * Free(B);
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string verification_pattern =
      R"IR(
# CHECK: int C_1 = C[0];
# CHECK: Allocate(B
# CHECK: int A_1 = C_1;
# CHECK: int B_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: B_1 =
# CHECK: A_1 = C_
# CHECK: B[0] = B_1;
# CHECK: A[0] = A_1;
# CHECK: Free(B)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// When there is no store before the loop to act as an initializer, the scalar
// is initialized by loading A[0].
TEST(Registerizer, RegisterizerNoInitializer) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + A_1;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// A single read-modify-write of A[x] per iteration, indexed by the loop
// variable. The test expects the registerizer to leave the program untouched.
TEST(Registerizer, RegisterizerNoInitializerLoopVar) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   * }
   */
  auto loop_body =
      Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))});
  StmtPtr stmt = Block::make({For::make(x, 0, 10, loop_body)});
  stmt = IRSimplifier::simplify(stmt);

  // Snapshot the printed IR before running the registerizer.
  std::ostringstream unregisterized;
  unregisterized << *stmt;

  // Expected to be a no-op.
  stmt = registerize(stmt);

  std::ostringstream registerized;
  registerized << *stmt;

  ASSERT_EQ(unregisterized.str(), registerized.str());
}
|
|
|
|
// Two buffers feeding each other inside a loop (B[0] from A[0], then A[0]
// from B[0]) are both registerized.
TEST(Registerizer, RegisterizerLoadThenStore) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)),
           Store::make(a, {0}, Load::make(b, {0}))}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   B[0] = (A[0]) + x;
   *   A[0] = B[0];
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * int B_1 = B[0];
   * for (int x = 0; x < 10; x++) {
   *   B_1 = x + A_1;
   *   A_1 = B_1;
   * }
   * B[0] = B_1;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: int B_1 = B[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: B[
# CHECK: B_1 =
# CHECK-NOT: A[
# CHECK: A_1 = B_
# CHECK: B[0] = B_
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Running the registerizer on a loop that still carries a GPU block index is
// expected to throw.
TEST(Registerizer, RegisterizerParallelized) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Mark the loop as bound to GPU block index 0.
  LoopOptions gpu_block_opts;
  gpu_block_opts.set_gpu_block_index(0);

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */
  auto loop_body =
      Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))});
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 0),
       For::make(x, 0, 10, loop_body, gpu_block_opts)});

  ASSERT_THROWS_WITH(
      registerize(stmt),
      "Registerization must occur after parallelism flattening");
}
|
|
|
|
// Should be able to registerize this since the scalar would exist before the
|
|
// branch.
|
|
TEST(Registerizer, RegisterizerConditionAfter) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Should be able to registerize this since the scalar exists in the same form
|
|
// after the branch and there is no overlap.
|
|
TEST(Registerizer, RegisterizerConditionBefore) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr),
|
|
Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x}))});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_ 1 = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Should be able to registerize this as the combination of the two above rules.
|
|
TEST(Registerizer, RegisterizerConditionInside) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x})),
|
|
Store::make(a, {x}, Load::make(c, {x}))});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x];
|
|
* C[x] = A_1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* B[x] = A_1;
|
|
* A_1 = C[x];
|
|
* A[x] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: B[x] = A_1;
|
|
# CHECK: A_1 = C[x];
|
|
# CHECK: A[x] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// An example where an access is cut by an overlapping access inside a
|
|
// condition, and both sides are large enough to be registerized but cannot be
|
|
// because there is no safe place to put the initializer or finalizer.
|
|
TEST(Registerizer, RegisterizerConditionInsideOverlap1) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
Store::make(a, {0}, 3),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
}),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x})),
|
|
Store::make(a, {x}, Load::make(c, {x}))});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* A[0] = 3;
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
// The A[0] store overlaps, A[x] cutting the region that can be registerized
|
|
// into two groups.
|
|
// Each group has 2 loads and 2 stores however, so we could registerize it,
|
|
// but the first group would need to be finalized inside the condition block,
|
|
// the second would need to be initialized inside the condition block. There's
|
|
// no safe place to put these that's visible to the other uses in the group
|
|
// and so neither registerization is possible.
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// Same as the above, but the access group before the condition (and after the
|
|
// condition) are large enough to be registerized without needing the access
|
|
// from the loop. Registerization occurs but does not include any accesses in
|
|
// the condition, and the first group must be finalized before the Cond, the
|
|
// second initialized after it.
|
|
TEST(Registerizer, RegisterizerConditionInsideOverlap2) {
|
|
BufHandle a("A", {5}, kInt);
|
|
BufHandle b("B", {5}, kInt);
|
|
BufHandle c("C", {5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
|
|
StmtPtr stmt = Block::make(
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
{Store::make(a, {x}, Load::make(b, {x})),
|
|
Store::make(a, {x}, Load::make(b, {x + 1})),
|
|
Store::make(c, {x}, Load::make(a, {x})),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make({
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
Store::make(a, {0}, 3),
|
|
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
|
|
}),
|
|
nullptr),
|
|
Store::make(b, {x}, Load::make(a, {x})),
|
|
Store::make(b, {x + 1}, Load::make(a, {x})),
|
|
Store::make(a, {x}, Load::make(c, {x}))});
|
|
|
|
/*
|
|
* A[x] = B[x];
|
|
* A[x] = B[x + 1];
|
|
* C[x] = A[x];
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* A[0] = 3;
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* B[x] = A[x];
|
|
* B[x + 1] = A[x];
|
|
* A[x] = C[x];
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = B[x]; // A_1 initializer
|
|
* A_1 = B[x + 1]; //
|
|
* C[x] = A_1; //
|
|
* A[x] = A_1; // A_1 finalizer
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = (A[x]) + 1;
|
|
* A[0] = 3;
|
|
* A[x] = (A[x]) + 1;
|
|
* }
|
|
* int A_2 = A[x]; // A_2 initializer
|
|
* B[x] = A_2; //
|
|
* B[x + 1] = A_2; //
|
|
* A_2 = C[x]; //
|
|
* A[x] = A_2; // A_2 finalizer
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = B[x];
|
|
# CHECK: A_1 = B[x + 1];
|
|
# CHECK: C[x] = A_1;
|
|
# CHECK: A[x] = A_1;
|
|
# CHECK: if (
|
|
# CHECK-NOT: A_1 = A_1 + 1;
|
|
# CHECK: A[x] = (A[x]
|
|
# CHECK: A[0] =
|
|
# CHECK: A[x] = (A[x]
|
|
# CHECK: }
|
|
# CHECK: int A_2 = A[x];
|
|
# CHECK: B[x] = A_2;
|
|
# CHECK: B[x + 1] = A_2;
|
|
# CHECK: A_2 = C[x];
|
|
# CHECK: A[x] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// When accesses are within conditional blocks they are not visible to the wider
|
|
// program, because we don't know if the branch would be taken and if it isn't
|
|
// the accesses in it don't need to be valid (think size checks on the index).
|
|
// In this case the accesses cannot be registerized.
|
|
TEST(Registerizer, RegisterizerConditionHidden) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // Every access to A[x] lives inside the body of a Cond; the enclosing
  // Block contains no unconditional access to A[x].
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
|
|
|
|
// But... if the same access is found in a non-conditional scope, that means
// that that access is valid in the higher scope (or at least, if it's not,
// it's the user's fault). It "unhides" the conditional accesses, allowing
// registerization to occur.
|
|
TEST(Registerizer, RegisterizerConditionUnhidden) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // Same two conditional updates as the hidden case, but with an
  // unconditional update of A[x] placed between them.
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr),
       Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * A[x] = (A[x]) + 1;            <-- this is doing the unhiding.
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (x<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A_1 = A_1 + 1;
   * if (x>5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (x<5
# CHECK: A_1 = A_1 + 1;
# CHECK: }
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x>5
# CHECK: A_1 = A_1 + 1;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Can registerize a load that occurs in the condition of a Cond.
|
|
TEST(Registerizer, RegisterizerCondCondition) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // A[x] is stored, then loaded once in a plain statement and once in the
  // condition expression of the Cond.
  StmtPtr stmt = Block::make(
      {Store::make(a, {x}, Load::make(b, {x})),
       Store::make(c, {x}, Load::make(a, {x})),
       Cond::make(
           CompareSelect::make(
               Load::make(a, {x}), 5, CompareSelectOperation::kLT),
           Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
           nullptr)});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if ((A[x])<5 ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];
   * int C_1 = A_1;
   * if (A_1<5 ? 1 : 0) {
   *   C_1 = C_1 + 1;
   * }
   * C[x] = C_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: int C_1 = A_1;
# CHECK: if (A_1<5
# CHECK: C_1 = C_1 + 1;
# CHECK: C[x] = C_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Appearing in the condition of a Cond makes it visible to the enclosing scope,
|
|
// and so we can registerize internal usages.
|
|
TEST(Registerizer, RegisterizerCondConditionUnhidden) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // A[x] is read in the Cond's condition and updated in both branches.
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});

  /*
   * if ((A[x])<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * } else {
   *   A[x] = (A[x]) + 10;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (A_1<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * } else {
   *   A_1 = A_1 + 10;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (A_1<5
# CHECK: A_1 = A_1 + 1;
# CHECK: } else {
# CHECK: A_1 = A_1 + 10;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Conditional hiding also works for IfThenElse exprs.
|
|
TEST(Registerizer, RegisterizerIfThenElseHidden) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Two stores whose values are identical IfThenElse expressions; A is only
  // ever loaded inside the IfThenElse branches.
  StmtPtr stmt = Block::make(
      {Store::make(
           b,
           {y},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}), 1),
               Add::make(Load::make(a, {x + 1}), 2))),
       Store::make(
           b,
           {y + 1},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}), 1),
               Add::make(Load::make(a, {x + 1}), 2)))});

  /*
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
|
|
|
|
// Conditional unhiding also works for IfThenElse exprs.
|
|
TEST(Registerizer, RegisterizerIfThenElseUnhidden) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Same IfThenElse stores as the hidden case, preceded by an unconditional
  // store to A[x].
  StmtPtr stmt = Block::make({
      Store::make(a, {x}, 0),
      Store::make(
          b,
          {y},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}), 1),
              Add::make(Load::make(a, {x + 1}), 2))),
      Store::make(
          b,
          {y + 1},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}), 1),
              Add::make(Load::make(a, {x + 1}), 2))),
  });

  /*
   * A[x] = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Nested IfThenElse exprs can't promote to higher level scopes.
|
|
TEST(Registerizer, RegisterizerIfThenElseNested) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  BufHandle d("D", {5}, kInt);
  VarHandle x("x", kInt);

  // D[x] is loaded in two different inner IfThenElse exprs, one under each
  // branch of the outer IfThenElse.
  StmtPtr stmt = Block::make({Store::make(
      a,
      {x},
      IfThenElse::make(
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
          IfThenElse::make(
              CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
              Load::make(d, {x}),
              Load::make(b, {x})),
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kEQ),
              Load::make(c, {x}),
              Load::make(d, {x}))))});

  /*
   * A[x] = IfThenElse(x<3 ? 1 : 0,
   *          IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
   *          IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
|
|
|
|
// Cannot registerize an access completely contained within an IfThenElse
|
|
// branch, since it is not a Stmt and cannot hold variable definitions. We need
|
|
// to check that we don't promote the initializer/finalizer to the enclosing
|
|
// Block.
|
|
TEST(Registerizer, RegisterizerIfThenElseInternal) {
  // Making these floats so they don't get simplified to a single access.
  BufHandle a("A", {5}, kFloat);
  BufHandle b("B", {5}, kFloat);
  VarHandle x("x", kInt);

  // Part 1: B[x] is loaded twice inside the true branch of an IfThenElse
  // expression.
  StmtPtr stmt = Block::make({Store::make(
      a,
      {x},
      IfThenElse::make(
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
          Add::make(Load::make(b, {x}), Load::make(b, {x})),
          Load::make(b, {x})))});

  /*
   * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // If this was a Cond instead of an IfThenElse then we could registerize the
  // two accesses to B[x] in the True branch.

  // Actually lets verify that.

  // Part 2: the same computation expressed as a Cond statement.
  stmt = Block::make({Cond::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
      Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),
      Store::make(a, {x}, Load::make(b, {x})))});

  /*
   * if (x<3 ? 1 : 0) {
   *   A[x] = (B[x]) + (B[x]);
   * } else {
   *   A[x] = B[x];
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<3 ? 1 : 0) {
   *   float B_1 = B[x];
   *   A[x] = B_1 + B_1;
   * } else {
   *   A[x] = B[x];
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  // The leading CHECK-NOTs require that no scalar declaration is printed
  // before the `if`.
  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK-NOT: float
# CHECK: if (x<3
# CHECK: float B_1 =
# CHECK: A[x] = B_1 + B_1
# CHECK: } else {
# CHECK: A[x] = B[x]
# CHECK: }
# CHECK-NOT: A[x]
# CHECK-NOT: B[x])IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Can registerize a load that occurs in the condition of an IfThenElse.
|
|
TEST(Registerizer, RegisterizerIfThenElseCondition) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // The second store loads A[x] only in the IfThenElse condition; its
  // branches read B[0] and C[0].
  StmtPtr stmt = Block::make(
      {Store::make(a, {x}, Load::make(a, {x})),
       Store::make(
           a,
           {x},
           IfThenElse::make(
               CompareSelect::make(
                   Load::make(a, {x}), 5, CompareSelectOperation::kLT),
               Load::make(b, {0}),
               Load::make(c, {0})))});

  /*
   * A[x] = A[x];       <---- just here so there are enough accesses to combine.
   * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * A_1 = A_1;
   * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Appearing in the condition of an IfThenElse makes it visible to the
// enclosing scope, and so we can registerize internal usages.
|
|
TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);

  // A[x] is loaded in the IfThenElse condition and in both of its branches.
  StmtPtr stmt = Block::make({Store::make(
      b,
      {x},
      IfThenElse::make(
          CompareSelect::make(
              Load::make(a, {x}), 5, CompareSelectOperation::kLT),
          Add::make(Load::make(a, {x}), 1),
          Add::make(Load::make(a, {x}), 10)))});

  /*
   * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Cannot promote accesses internal to IfThenElse branches even if the enclosing
// scope is conditional.
|
|
TEST(Registerizer, RegisterizerConditionBranchOnly) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // Inside the loop, each branch of the Cond stores an IfThenElse expression
  // that loads A[x] in one arm and A[x - 5] in the other.
  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({
          Cond::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Store::make(
                  a,
                  {x},
                  IfThenElse::make(
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
                      Add::make(Load::make(a, {x}), x),
                      Add::make(Load::make(a, {x - 5}), x))),
              Store::make(
                  a,
                  {x - 5},
                  IfThenElse::make(
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
                      Add::make(Load::make(a, {x}), x),
                      Add::make(Load::make(a, {x - 5}), x)))),
      }))});
  // Simplify first so that the "before" snapshot is taken from the
  // simplified statement.
  stmt = IRSimplifier::simplify(stmt);

  std::ostringstream before;
  before << *stmt;

  /* for (int x = 0; x < 10; x++) {
   *   if (x<5 ? 1 : 0) {
   *     A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   } else {
   *     A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   }
   * }
   */

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
|
|
|
|
// We can registerize an IfThenElse that appears in the condition branch of a
|
|
// Cond. This is a weird but valid thing to do.
|
|
TEST(Registerizer, RegisterizerCondIfThenElse) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // The Cond's condition compares an IfThenElse expression against x; that
  // IfThenElse loads A[x] in its own condition and in its true branch.
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(
          IfThenElse::make(
              CompareSelect::make(
                  Load::make(a, {x}), 5, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(b, {x})),
          x,
          CompareSelectOperation::kEQ),
      Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
      nullptr)});

  /*
   * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  // access to A can be registerized, but not B or C

  /*
   * int A_1 = A[x];
   * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
# CHECK: C[x] = (C[x]) + 1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Can registerize a conditional access in the RHS of a store unhidden by its
// LHS, and hoist it out of a loop.
|
|
TEST(Registerizer, RegisterizerIfThenElseLoop) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // The loop variable is y, while the stored-to index A[x] does not depend
  // on it; the RHS loads A[x] only inside an IfThenElse branch.
  StmtPtr stmt = For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(b, {y}))));

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * for (int y = 0; y < 10; y++) {
   *   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: for (
# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Cannot registerize if the RHS overlaps the access creating visibility.
|
|
TEST(Registerizer, RegisterizerIfThenElseLoopCut) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Like the IfThenElseLoop case, except the false branch loads A[y] (the
  // same buffer, indexed by the loop variable) instead of another buffer.
  StmtPtr stmt = Block::make({For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(a, {y}))))});

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
|
|
|
|
// Simple case where an access is cut by an overlapping access later in the
|
|
// program, we can registerize up until the overlap.
|
|
TEST(Registerizer, RegisterizerPartialAfter) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // An accumulation into A[0], followed by a loop that accesses A[x] and
  // A[x - 1].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 0),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})),
       For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK: A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK: A[x] = A[x - 1];
# CHECK: }
# CHECK-NOT: A)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// We can registerize an access which overlaps a previous access, the
|
|
// initializer must be inserted after the previous access.
|
|
TEST(Registerizer, RegisterizerPartialBefore) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // The loop over A[x] / A[x - 1] comes first, followed by the accumulation
  // into A[0].
  StmtPtr stmt = Block::make(
      {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
       Store::make(a, {0}, 0),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});

  /*
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  // The leading CHECK-NOT requires that no `int` is printed before the
  // first loop.
  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK: for (
# CHECK: A[x] = A[x - 1];
# CHECK: }
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK: A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// The combination of the previous two tests, an access is cut by an overlapping
|
|
// access in both directions.
|
|
TEST(Registerizer, RegisterizerPartialInside) {
  BufHandle a("A", {1}, kInt);
  VarHandle x1("x1", kInt);
  VarHandle x2("x2", kInt);
  VarHandle x3("x3", kInt);

  // Two accumulations into A[0] separated by a loop that accesses A[x2] and
  // A[x2 - 1].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 2),
       For::make(
           x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),
       For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),
       For::make(
           x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))});

  /*
   * A[0] = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A[0] = (A[0]) + x1;
   * }
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A[0] = (A[0]) + x3;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A_1 = A_1 + x1;
   * }
   * A[0] = A_1;
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * int A_2 = A[0];
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A_2 = A_2 + x3;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK: A_1 = A_1 + x1;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK: A[x2] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK: A_2 = A_2 + x3;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// An element could be registerized program wide but is cut by a conditional
|
|
// access, we should break this into two scalars and write back to the buffer
|
|
// before the condition.
|
|
TEST(Registerizer, RegisterizerPartialCondition) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Two accumulations into A[0] separated by a Cond whose body accesses A[x]
  // and A[x - 1].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 2),
       For::make(
           x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Load::make(a, {x - 1})),
           nullptr),
       For::make(
           x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});

  /*
   * A[0] = 2;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * if (x<5 ? 1 : 0) {
   *   A[x] = A[x - 1];
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 2;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * if (x<5 ? 1 : 0) {
   *   A[x] = A[x - 1];
   * }
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_2 + x;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK: A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: if (
# CHECK: A[x] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK: A_2 = A_2 + x;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Tests case where an access is cut by an internal conditional access which
|
|
// itself is registerized.
|
|
TEST(Registerizer, RegisterizerPartialConditionInternalCut) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Pairs of stores to A[0] before and after a Cond whose body holds a pair
  // of stores to A[x].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 1),
       Store::make(a, {0}, 3),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
           nullptr),
       Store::make(a, {0}, 4),
       Store::make(a, {0}, 6)});

  /*
   * A[0] = 1;
   * A[0] = 3;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   A[x] = 3;
   * }
   * A[0] = 4;
   * A[0] = 6;
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 1;
   * A_1 = 3;
   * A[0] = A_1;
   * if (x<5 ? 1 : 0) {
   *   int A_2 = 1;
   *   A_2 = 3;
   *   A[x] = A_2;
   * }
   * int A_3 = 4;
   * A_3 = 6;
   * A[0] = A_3;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 1;
# CHECK: A_1 = 3
# CHECK: A[0] = A_1;
# CHECK: if (
# CHECK: int A_2 = 1;
# CHECK: A_2 = 3;
# CHECK: A[x] = A_2;
# CHECK: }
# CHECK: int A_3 = 4;
# CHECK: A_3 = 6;
# CHECK: A[0] = A_3;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// First statement in condition closes outer access, but can be registerized
|
|
// with later statements.
|
|
TEST(Registerizer, RegisterizerPartialConditionInternalStart) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Stores to A[0], then a Cond whose body stores to A[x], then further
  // unconditional stores to A[x].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 1),
       Store::make(a, {0}, 3),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
           nullptr),
       Store::make(a, {x}, 4),
       Store::make(a, {x}, 6)});

  /*
   * A[0] = 1;
   * A[0] = 3;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   A[x] = 3;
   * }
   * A[x] = 4;
   * A[x] = 6;
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 1;
   * A_1 = 3;
   * A[0] = A_1;
   * int A_2 = A[x];    <--- must read from the input here.
   * if (x<5 ? 1 : 0) {
   *   A_2 = 1;
   *   A_2 = 3;
   * }
   * A_2 = 4;
   * A_2 = 6;
   * A[x] = A_2;
   */

  // TODO: I suppose we could refactor with a conditional initializer?

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 1;
# CHECK: A_1 = 3
# CHECK: A[0] = A_1;
# CHECK: int A_2 = A[x];
# CHECK: if (
# CHECK: A_2 = 1;
# CHECK: A_2 = 3;
# CHECK: }
# CHECK: A_2 = 4;
# CHECK: A_2 = 6;
# CHECK: A[x] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// An access cuts two open overlaps and creates four scalar variables.
|
|
TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // The same three statements over A[0] and A[1] appear before and after a
  // loop that stores to A[x].
  StmtPtr stmt = Block::make(
      {Store::make(a, {1}, Load::make(a, {0})),
       Store::make(a, {0}, Load::make(a, {1})),
       Store::make(a, {0}, Load::make(a, {1})),
       For::make(x, 1, 10, Store::make(a, {x}, x)),
       Store::make(a, {1}, Load::make(a, {0})),
       Store::make(a, {0}, Load::make(a, {1})),
       Store::make(a, {0}, Load::make(a, {1}))});

  /*
   * A[1] = A[0];
   * A[0] = A[1];
   * A[0] = A[1];
   * for (int x = 1; x < 10; x++) {
   *   A[x] = x;
   * }
   * A[1] = A[0];
   * A[0] = A[1];
   * A[0] = A[1];
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * int A_2 = A_1;
   * A_1 = A_2;
   * A_1 = A_2;
   * A[1] = A_2;
   * A[0] = A_1;
   * for (int x = 1; x < 10; x++) {
   *   A[x] = x;
   * }
   * int A_3 = A[0];
   * int A_4 = A_3;
   * A_3 = A_4;
   * A_3 = A_4;
   * A[1] = A_4;
   * A[0] = A_3;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: int A_2 = A_1;
# CHECK: A_1 = A_2;
# CHECK: A_1 = A_2;
# CHECK: A[1] = A_2;
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK: A[x] = x;
# CHECK: }
# CHECK: int A_3 = A[0];
# CHECK: int A_4 = A_3;
# CHECK: A_3 = A_4;
# CHECK: A_3 = A_4;
# CHECK: A[1] = A_4;
# CHECK: A[0] = A_3;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// Nested blocks will automatically be flattened and do not prevent
// registerization of enclosed accesses.
|
|
TEST(Registerizer, RegisterizerNestedBlocks) {
  BufHandle a("A", {1}, kInt);

  // Four increments of A[0], spread over a top-level Block, a sibling Block
  // and a Block nested two levels deep.
  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
       Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
       Block::make(
           {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
            Block::make(
                {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});

  /*
   * A[0] = (A[0]) + 1;
   * {
   *   A[0] = (A[0]) + 2;
   * }
   * {
   *   A[0] = (A[0]) + 3;
   *   {
   *     A[0] = (A[0]) + 4;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * A_1 = A_1 + 1;
   * A_1 = A_1 + 2;
   * A_1 = A_1 + 3;
   * A_1 = A_1 + 4;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: A_1 = A_1 + 1;
# CHECK: A_1 = A_1 + 2;
# CHECK: A_1 = A_1 + 3;
# CHECK: A_1 = A_1 + 4;
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// The access can be registerized internally to a condition, but must ensure
|
|
// that both initializer and finalizer are within the same condition.
|
|
TEST(Registerizer, RegisterizerNestedConditions) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // All accesses to A[0] are inside the outer Cond: one directly in its
  // body and one in a nested Cond.
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
      nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<5 ? 1 : 0) {
   *   int A_1 = A[0];
   *   A_1 = A_1 + 1;
   *   if (x==2 ? 1 : 0) {
   *     A_1 = A_1 + 1;
   *   }
   *   A[0] = A_1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: if (x<5
# CHECK: int A_1 = A[0];
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x==2
# CHECK: A_1 = A_1 + 1;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// If an access exists outside the scope of the condition then we can lift
|
|
// nested conditional usages into the same scalar.
|
|
TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // An unconditional update of A[0], followed by another update of A[0]
  // nested two Conds deep; the outer Cond body also stores to A[1].
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make(
               {Store::make(a, {1}, 1),
                Cond::make(
                    CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
                    Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
                    nullptr)}),
           nullptr)});

  /*
   * A[0] = (A[0]) + 1;
   * if (x<5 ? 1 : 0) {
   *   A[1] = 1;
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * A_1 = A_1 + 1;
   * if (x<5 ? 1 : 0) {
   *   A[1] = 1;
   *   if (x==2 ? 1 : 0) {
   *     A_1 = A_1 + 1;
   *   }
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x<5
# CHECK: A[1] = 1;
# CHECK: if (x==2
# CHECK: A_1 = A_1 + 1;
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
|
|
|
|
// A[0] is only ever accessed behind a condition: first behind a single
// condition, then behind two nested ones. The test expects registerize to
// leave the statement unchanged.
TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Creates a fresh "if (x == 2) A[0] = A[0] + 1" statement on every call.
  const auto guarded_bump = [&a, &x]() {
    return Cond::make(
        CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
        nullptr);
  };

  StmtPtr stmt = Block::make(
      {guarded_bump(),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({guarded_bump()}),
           nullptr)});

  /*
   * Input:
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   *   if (x<5 ? 1 : 0) {
   *     if (x==2 ? 1 : 0) {
   *       A[0] = (A[0]) + 1;
   *     }
   *   }
   */

  std::ostringstream ir_before;
  ir_before << *stmt;

  // The printed IR must be identical before and after the pass.
  stmt = registerize(stmt);

  std::ostringstream ir_after;
  ir_after << *stmt;

  ASSERT_EQ(ir_before.str(), ir_after.str());

  // Run the pass once more; the result is not inspected.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  stmt = registerize(stmt);
}
|
|
|
|
// A[0] is only ever accessed behind a condition: first behind two nested
// conditions, then behind a single one. The test expects registerize to leave
// the statement unchanged.
TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Creates a fresh "if (x == 2) A[0] = A[0] + 1" statement on every call.
  const auto guarded_bump = [&a, &x]() {
    return Cond::make(
        CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
        nullptr);
  };

  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({guarded_bump()}),
           nullptr),
       guarded_bump()});

  /*
   * Input:
   *   if (x<5 ? 1 : 0) {
   *     if (x==2 ? 1 : 0) {
   *       A[0] = (A[0]) + 1;
   *     }
   *   }
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   */

  std::ostringstream ir_before;
  ir_before << *stmt;

  // The printed IR must be identical before and after the pass.
  stmt = registerize(stmt);

  std::ostringstream ir_after;
  ir_after << *stmt;

  ASSERT_EQ(ir_before.str(), ir_after.str());

  // Run the pass once more; the result is not inspected.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  stmt = registerize(stmt);
}
|
|
|
|
// If an access is cut by another access internal to a condition block, it still
|
|
// cuts the access.
|
|
TEST(Registerizer, RegisterizerNestedConditionsCut) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
Block::make(
|
|
{Store::make(a, {x}, 1),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
nullptr)}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[0] = (A[0]) + 1;
|
|
* if (x<5 ? 1 : 0) {
|
|
* A[x] = 1;
|
|
* if (x==2 ? 1 : 0) {
|
|
*
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// A[0] is accessed only behind conditions: once at the top level and once
// inside a loop body. The test expects registerize to leave the statement
// unchanged.
TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);

  // Creates a fresh "if (x == 2) A[0] = A[0] + 1" statement on every call.
  const auto guarded_bump = [&a, &x]() {
    return Cond::make(
        CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
        nullptr);
  };

  StmtPtr loop = For::make(
      x, 0, 10, Block::make({Store::make(b, {x}, 0), guarded_bump()}));
  StmtPtr stmt = Block::make({guarded_bump(), loop});

  /*
   * Input:
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   *   for (int x = 0; x < 10; x++) {
   *     B[x] = 0;     <-- this is only here to prevent Loop/Cond reordering.
   *     if (x==2 ? 1 : 0) {
   *       A[0] = (A[0]) + 1;
   *     }
   *   }
   */

  std::ostringstream ir_before;
  ir_before << *stmt;

  // The printed IR must be identical before and after the pass.
  stmt = registerize(stmt);

  std::ostringstream ir_after;
  ir_after << *stmt;

  ASSERT_EQ(ir_before.str(), ir_after.str());
}
|
|
|
|
// Three loops and four element regions, three of which should be registerized
|
|
// at different levels of the IR.
|
|
TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {4}, 0),
|
|
Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kGT),
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kGT),
|
|
Block::make({
|
|
Cond::make(
|
|
CompareSelect::make(x, 4, CompareSelectOperation::kGT),
|
|
Block::make({
|
|
Store::make(
|
|
a, {1}, Add::make(Load::make(a, {1}), 1)),
|
|
Store::make(
|
|
a, {2}, Add::make(Load::make(a, {2}), 1)),
|
|
Store::make(
|
|
a, {3}, Add::make(Load::make(a, {3}), 1)),
|
|
Store::make(
|
|
a, {4}, Add::make(Load::make(a, {4}), 1)),
|
|
Store::make(
|
|
a, {1}, Add::make(Load::make(a, {1}), 1)),
|
|
}),
|
|
nullptr),
|
|
Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),
|
|
}),
|
|
nullptr),
|
|
nullptr)});
|
|
|
|
/*
|
|
* A[4] = 0;
|
|
* if (x>2 ? 1 : 0) {
|
|
* if (x>3 ? 1 : 0) {
|
|
* if (x>4 ? 1 : 0) {
|
|
* A[1] = (A[1]) + 1;
|
|
* A[2] = (A[2]) + 1;
|
|
* A[3] = (A[3]) + 1;
|
|
* A[4] = (A[4]) + 1;
|
|
* A[1] = (A[1]) + 1;
|
|
* }
|
|
* A[2] = (A[2]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* if (x>2 ? 1 : 0) {
|
|
* if (x>3 ? 1 : 0) {
|
|
* int A_3 = A[2];
|
|
* if (x>4 ? 1 : 0) {
|
|
* int A_2 = A[1];
|
|
* A_2 = A_2 + 1;
|
|
* A_3 = A_3 + 1;
|
|
* A[3] = (A[3]) + 1;
|
|
* A_1 = A_1 + 1;
|
|
* A_2 = A_2 + 1;
|
|
* A[1] = A_2;
|
|
* }
|
|
* A_3 = A_3 + 1;
|
|
* A[2] = A_3;
|
|
* }
|
|
* }
|
|
* A[4] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: if (x>2 ? 1 : 0) {
|
|
# CHECK: if (x>3 ? 1 : 0) {
|
|
# CHECK: int A_3 = A[2];
|
|
# CHECK: if (x>4 ? 1 : 0) {
|
|
# CHECK: int A_2 = A[1];
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: A_3 = A_3 + 1;
|
|
# CHECK: A[3] = (A[3]) + 1;
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: A[1] = A_2;
|
|
# CHECK: }
|
|
# CHECK: A_3 = A_3 + 1;
|
|
# CHECK: A[2] = A_3;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[4] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Can replace a simple scalar access with a local variable even when that
|
|
// variable is an outer loop var.
|
|
TEST(Registerizer, RegisterizerNestedLoopSimple) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[y] = (A[y]) + x;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int A_1 = A[y];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[y] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int y
|
|
# CHECK: int A_1 = A[y];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: }
|
|
# CHECK: A[y] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Test the positive case of the hiddenAccess split, where an internal
|
|
// conditional access can be hoisted up through a loop to match an existing
|
|
// access in a higher scope and the two can be registerized.
|
|
TEST(Registerizer, RegisterizerHiddenAccessYes) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0),
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)}))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x
|
|
# CHECK: B[x] = 0;
|
|
# CHECK: if (x==3
|
|
# CHECK: for (int y
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Test the negative case of the hiddenAccess split, where the hoisted access is
|
|
// never unhidden at a higher scope and registerization occurs at the lower
|
|
// scope.
|
|
TEST(Registerizer, RegisterizerHiddenAccessNo) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(b, {x}, 0),
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
Cond::make(
|
|
CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)}))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* B[x] = 0;
|
|
* if (x==3 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: for (int x
|
|
# CHECK: B[x] = 0;
|
|
# CHECK: if (x==3
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int y
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// In this case the conditional access must be hoisted by two loops, there are
|
|
// two accesses here one is unhidden and the other isn't. A[0] can be
|
|
// registerized but B[0] cannot.
|
|
TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make({Cond::make(
|
|
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(a, {0}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
Block::make({Cond::make(
|
|
CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
|
|
Block::make(
|
|
{Store::make(
|
|
a, {0}, Add::make(Load::make(a, {0}), 1)),
|
|
Store::make(
|
|
b, {0}, Add::make(Load::make(b, {0}), 1))}),
|
|
nullptr)})))}),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* A[0] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* if (y==3 ? 1 : 0) {
|
|
* A[0] = (A[0]) + 1;
|
|
* B[0] = (B[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x==2 ? 1 : 0) {
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* if (y==3 ? 1 : 0) {
|
|
* A_1 = A_1 + 1;
|
|
* B[0] = (B[0]) + 1;
|
|
* }
|
|
* }
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x==2
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x
|
|
# CHECK: for (int y
|
|
# CHECK: if (y==3
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: B[0] = (B[0]) + 1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Accesses are registerized inside two conditions, but the immediate parent is
|
|
// not a condition.
|
|
TEST(Registerizer, RegisterizerTwoConditionalLoops) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + 1;
|
|
* }
|
|
* A[0] = A_2;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: if (x>5
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Accesses are registerized inside two conditions, cut in the middle.
|
|
TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {
|
|
BufHandle a("A", {1}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr),
|
|
For::make(x, 0, 10, Store::make(a, {x}, 1)),
|
|
Cond::make(
|
|
CompareSelect::make(x, 5, CompareSelectOperation::kGT),
|
|
For::make(
|
|
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
|
|
nullptr)});
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0] = (A[0]) + 1;
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* if (x<5 ? 1 : 0) {
|
|
* int A_1 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + 1;
|
|
* }
|
|
* A[0] = A_1;
|
|
* }
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[x] = 1;
|
|
* }
|
|
* if (x>5 ? 1 : 0) {
|
|
* int A_2 = A[0];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_2 + 1;
|
|
* }
|
|
* A[0] = A_2;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: if (x<5
|
|
# CHECK: int A_1 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_1;
|
|
# CHECK: }
|
|
# CHECK: for (int x
|
|
# CHECK: A[x] = 1;
|
|
# CHECK: if (x>5
|
|
# CHECK: int A_2 = A[0];
|
|
# CHECK: for (int x
|
|
# CHECK: A_2 = A_2 + 1;
|
|
# CHECK: }
|
|
# CHECK: A[0] = A_2;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// references a Let var in a local scope which cannot be hoisted out of the
|
|
// loop.
|
|
TEST(Registerizer, RegisterizerLoopLetVar) {
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Let::make(y, 30),
|
|
Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int y = 30;
|
|
* A[y] = x + (A[y]);
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// references a Let var in an outer scope that does not prevent hoisting the
|
|
// initializer.
|
|
TEST(Registerizer, RegisterizerLoopLetVarOuter) {
|
|
BufHandle a("A", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Let::make(y, 30),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make(
|
|
{Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});
|
|
|
|
/*
|
|
* int y = 30;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[y] = x + (A[y]);
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int y = 30;
|
|
* int A_1 = A[y];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = A_1 + x;
|
|
* }
|
|
* A[y] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int y = 30;
|
|
# CHECK: int A_1 = A[y];
|
|
# CHECK: for (int x
|
|
# CHECK: A_1 = A_1 + x;
|
|
# CHECK: A[y] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Okay so the registerizer generally goes after index flattening, but just in
|
|
// case. Test multi index registerization.
|
|
TEST(Registerizer, RegisterizerMultiDim) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, 1, 2] = (A[0, 1, 2]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* int A_1 = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_1 = x + A_1;
|
|
* }
|
|
* A[0, 1, 2] = A_1;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int A_1 = 0;
|
|
# CHECK: for (int x = 0; x < 10; x++)
|
|
# CHECK-NOT: A[
|
|
# CHECK: A_1 =
|
|
# CHECK: A[0, 1, 2] = A_1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// Won't registerize if only some dims match, but will still registerize
|
|
// distinct elements.
|
|
TEST(Registerizer, RegisterizerMultiDimPartial) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, 2, 2] = (A[0, 1, 4]) + x;
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* int A_1 = A[0, 1, 4];
|
|
* int A_2 = A[0, 2, 2];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A_2 = A_1 + x;
|
|
* }
|
|
* A[0, 2, 2] = A_2;
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0, 1, 2] = 0;
|
|
# CHECK: int A_1 = A[0, 1, 4];
|
|
# CHECK: int A_2 = A[0, 2, 2];
|
|
# CHECK: for (
|
|
# CHECK: A_2 = A_1 + x;
|
|
# CHECK: A[0, 2, 2] = A_2;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// If they could overlap across all dimensions we cannot registerize.
|
|
TEST(Registerizer, RegisterizerMultiDimOverlap) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))});
|
|
stmt = IRSimplifier::simplify(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = (A[y, 2, 2]) + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream before;
|
|
before << *stmt;
|
|
|
|
// No change.
|
|
stmt = registerize(stmt);
|
|
|
|
std::ostringstream after;
|
|
after << *stmt;
|
|
|
|
ASSERT_EQ(before.str(), after.str());
|
|
}
|
|
|
|
// But, if one dimension is known to be distinct they do not overlap.
|
|
TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
|
|
BufHandle a("A", {3, 4, 5}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
StmtPtr stmt = Block::make(
|
|
{Store::make(a, {0, 1, 2}, 0),
|
|
For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
Block::make({Store::make(
|
|
a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))});
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0; <---- 2nd dim overlaps with store.
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff.
|
|
* }
|
|
*/
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* A[0, 1, 2] = 0;
|
|
* int A_1 = A[y, 2, 4];
|
|
* for (int x = 0; x < 10; x++) {
|
|
* A[0, x, 2] = A_1 + x;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0, 1, 2] = 0;
|
|
# CHECK: int A_1 = A[y, 2, 4];
|
|
# CHECK: for (
|
|
# CHECK: A[0, x, 2] = A_1 + x;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// A 3D reduction with different input dimensionality.
|
|
TEST(Registerizer, RegisterizerMultiDim3DReduction1) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10, 10}, kInt);
|
|
BufHandle c("C", {10, 10, 10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
StmtPtr stmt = For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
z,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
c,
|
|
{x, y, z},
|
|
Add::make(
|
|
Load::make(c, {x, y, z}),
|
|
Mul::make(Load::make(b, {x, y}), Load::make(a, {x})))))));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// We can registerize the A and B access since they can be hoisted before
|
|
// hitting a dependent loop var.
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int B_1 = B[x, y];
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: for (int y
|
|
# CHECK: int B_1 = B[x, y];
|
|
# CHECK: for (int z
|
|
# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
// A 3D reduction with the same smaller dimensionality using different loop
|
|
// vars.
|
|
TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
|
|
BufHandle a("A", {10}, kInt);
|
|
BufHandle b("B", {10}, kInt);
|
|
BufHandle c("C", {10}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
StmtPtr stmt = For::make(
|
|
x,
|
|
0,
|
|
10,
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
|
For::make(
|
|
y,
|
|
0,
|
|
10,
|
|
For::make(
|
|
z,
|
|
0,
|
|
10,
|
|
Store::make(
|
|
c,
|
|
{x},
|
|
Add::make(
|
|
Load::make(c, {x}),
|
|
Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* for (int y = 0; y < 10; y++) {
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C[x] = (C[x]) + (B[y]) * (A[x]);
|
|
* }
|
|
* }
|
|
* }
|
|
*/
|
|
|
|
// We can registerize all accesses, the A and C access can be hoisted to the
|
|
// outer loop since they depend only on it's loop var while the B can only be
|
|
// raised to the loop of y.
|
|
|
|
stmt = registerize(stmt);
|
|
|
|
/*
|
|
* for (int x = 0; x < 10; x++) {
|
|
* int A_1 = A[x];
|
|
* int C_1 = C[x];
|
|
* for (int y = 0; y < 10; y++) {
|
|
* int B_1 = B[y];
|
|
* for (int z = 0; z < 10; z++) {
|
|
* C_1 = A_1 * B_1 + C_1;
|
|
* }
|
|
* }
|
|
* C[x] = C_1;
|
|
* }
|
|
*/
|
|
|
|
std::ostringstream oss;
|
|
oss << *stmt;
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int x
|
|
# CHECK: int A_1 = A[x];
|
|
# CHECK: int C_1 = C[x];
|
|
# CHECK: for (int y
|
|
# CHECK: int B_1 = B[y];
|
|
# CHECK: for (int z
|
|
# CHECK: C_1 = A_1 * B_1 + C_1;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: C[x] = C_1;
|
|
# CHECK: })IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|