[nnc][tests] Convert a bunch of FileCheck to checkIR

Summary:
I added a helper to convert a Stmt to string and FileCheck it, so
started using it in a bunch of places.  I replaced about half the current uses,
got tired, started to write a Perl script to automate it, realized that was
hard, and decided to give up for a bit.  But this cleans up some of the tests a
bit, so seems easy to review and worth landing.

Test Plan: test_tensorexpr --gtest_filter=LoopNest.*

Reviewed By: navahgar

Differential Revision: D27375866

fbshipit-source-id: 15894b9089dec5cf25f340fe17e6e54546a64257
This commit is contained in:
Bert Maher
2021-03-26 20:23:14 -07:00
committed by Facebook GitHub Bot
parent 24f589df44
commit e4d19798f3

View File

@ -22,6 +22,12 @@ namespace jit {
using namespace torch::jit::tensorexpr;
void checkIR(Stmt* s, const std::string& pattern) {
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(pattern, oss.str());
}
TEST(LoopNest, ExprSimple01) {
KernelScope kernel_scope;
Tensor* tensor = Compute(
@ -627,20 +633,16 @@ TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) {
l.splitWithMask(outer, 4, &outer, &mid);
Stmt* stmt1 = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *stmt1;
// Two splits mean 3 loops, but should need no masks in this case.
const std::string& verification_pattern =
R"IR(
checkIR(stmt1, R"IR(
# CHECK: for (
# CHECK-NOT: if (
# CHECK: for (
# CHECK-NOT: if (
# CHECK: for (
# CHECK-NOT: if (
# CHECK: f[)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: f[)IR");
}
TEST(LoopNest, SplitWithTailWithLoopOptions) {
@ -1027,18 +1029,14 @@ TEST(LoopNest, ScheduleInlineRandom) {
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
std::ostringstream oss;
oss << *stmt1;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(stmt1, R"IR(
# CHECK: for (int m2 = 0; m2 < 4; m2++)
# CHECK: for (int n2 = 0; n2 < 5; n2++)
# CHECK: for (int k2 = 0; k2 < 6; k2++)
# CHECK: int x = rand();
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR");
}
// Make sure we don't cache random vars that are not being inlined.
@ -1068,17 +1066,13 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) {
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
std::ostringstream oss;
oss << *stmt1;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(stmt1, R"IR(
# CHECK: for (int m2 = 0; m2 < 4; m2++)
# CHECK: for (int n2 = 0; n2 < 5; n2++)
# CHECK: for (int k2 = 0; k2 < 6; k2++)
# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR");
}
// Make sure we generate the right number of random values == the dimensionality
@ -1105,18 +1099,14 @@ TEST(LoopNest, ScheduleInlineRandomLowerDimensions) {
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
std::ostringstream oss;
oss << *stmt1;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(stmt1, R"IR(
# CHECK: for (int m2 = 0; m2 < 4; m2++)
# CHECK: int x = rand();
# CHECK: for (int n2 = 0; n2 < 5; n2++)
# CHECK: for (int k2 = 0; k2 < 6; k2++)
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR");
}
// Make sure we don't screw up intrinsics thinking they're rand.
@ -1205,18 +1195,13 @@ TEST(LoopNest, ScheduleInlineRandWithIntrinsics) {
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
std::ostringstream oss;
oss << *stmt1;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(stmt1, R"IR(
# CHECK: for (int m2 = 0; m2 < 4; m2++)
# CHECK: for (int n2 = 0; n2 < 5; n2++)
# CHECK: for (int k2 = 0; k2 < 6; k2++)
# CHECK: float x = rand();
# CHECK: y[m2, n2, k2] = sqrt(x);)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: y[m2, n2, k2] = sqrt(x);)IR");
}
// Split a Compute then inline it into another compute.
@ -1503,12 +1488,9 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
std::ostringstream oss;
oss << *stmt1;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(stmt1, R"IR(
# CHECK: for (int m1 = 0; m1 < 4; m1++)
# CHECK: for (int n1 = 0; n1 < 5; n1++)
# CHECK: for (int k1 = 0; k1 < 6; k1++)
@ -1516,8 +1498,7 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
# CHECK: for (int m2 = 0; m2 < 4; m2++)
# CHECK: for (int n2 = 0; n2 < 5; n2++)
# CHECK: for (int k2 = 0; k2 < 6; k2++)
# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR");
}
TEST(LoopNest, ScheduleFuserStyle) {
@ -1644,19 +1625,13 @@ TEST(LoopNest, LoopNestComputeAt_1) {
l.prepareForCodegen();
Stmt* s = l.root_stmt();
std::ostringstream oss;
oss << *s;
const std::string& verification_pattern =
R"IR(
checkIR(s, R"IR(
# CHECK: for (int i_b = 0; i_b < N; i_b++)
# CHECK: Allocate(temp); // dtype=int, dims=[1]
# CHECK: temp[
# CHECK-NOT: A[
# CHECK: B[i_b] = temp[0]
# CHECK: Free(temp))IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: Free(temp))IR");
// Now check that the loop still produces the correct result.
std::vector<int> b_data(100, 0);
@ -1717,12 +1692,8 @@ TEST(LoopNest, LoopNestComputeAt_2) {
l.prepareForCodegen();
Stmt* s = l.root_stmt();
std::ostringstream oss;
oss << *s;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(s, R"IR(
# CHECK: for (int cy = 0; cy < H; cy++)
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
# CHECK: for
@ -1730,8 +1701,7 @@ TEST(LoopNest, LoopNestComputeAt_2) {
# CHECK: for (int cx = 0; cx < W; cx++)
# CHECK-NOT: prod[
# CHECK: cons[
# CHECK: Free(temp))IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: Free(temp))IR");
// Now check that the loop still produces the correct result.
std::vector<int> c_data(kW * kH, 0);
@ -1748,12 +1718,8 @@ TEST(LoopNest, LoopNestComputeAt_2) {
l.prepareForCodegen();
Stmt* s = l.root_stmt();
std::ostringstream oss;
oss << *s;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(s, R"IR(
# CHECK: for (int cy = 0; cy < H; cy++)
# CHECK: for (int cx = 0; cx < W; cx++)
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
@ -1761,8 +1727,7 @@ TEST(LoopNest, LoopNestComputeAt_2) {
# CHECK: for
# CHECK-NOT: prod[
# CHECK: cons[
# CHECK: Free(temp))IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: Free(temp))IR");
// Now check that the loop still produces the correct result.
std::vector<int> c_data(kW * kH, 0);
@ -1826,12 +1791,8 @@ TEST(LoopNest, LoopNestComputeAt_3) {
l.prepareForCodegen();
Stmt* s = l.root_stmt();
std::ostringstream oss;
oss << *s;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(s, R"IR(
# CHECK: for (int ay = 0; ay < H + 1; ay++)
# CHECK: for (int ax = 0; ax < W + 1; ax++)
# CHECK: A[
@ -1844,8 +1805,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
# CHECK: for (int dy = 0; dy < H; dy++)
# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
# CHECK: for (int dx = 0; dx < W; dx++)
# CHECK-NOT: A[)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK-NOT: A[)IR");
// Now check that the loop still produces the correct result.
std::vector<int> c_data(kW * kH, 0);
@ -1862,12 +1822,8 @@ TEST(LoopNest, LoopNestComputeAt_3) {
l.prepareForCodegen();
Stmt* s = l.root_stmt();
std::ostringstream oss;
oss << *s;
// Check the IR we produced
const std::string& verification_pattern =
R"IR(
checkIR(s, R"IR(
# CHECK: for (int ay = 0; ay < H + 1; ay++)
# CHECK: for (int ax = 0; ax < W + 1; ax++)
# CHECK: A[
@ -1880,8 +1836,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
# CHECK: for (int dy = 0; dy < H; dy++)
# CHECK: for (int dx = 0; dx < W; dx++)
# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
# CHECK-NOT: A[)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK-NOT: A[)IR");
// Now check that the loop still produces the correct result.
std::vector<int> c_data(kW * kH, 0);
@ -1894,12 +1849,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
using Axis = const VarHandle&;
void checkIR(Stmt* s, const std::string& pattern) {
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(pattern, oss.str());
}
TEST(LoopNest, Reduce2dComputeAt) {
KernelScope kernel_scope;
@ -2365,12 +2314,8 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
l.reorderAxis(loops[1], loops[2]);
Stmt* stmt2 = Stmt::clone(l.root_stmt());
std::ostringstream oss;
oss << *l.root_stmt();
// Check the IR we produced
const std::string& verification_pattern1 =
R"IR(
checkIR(stmt2, R"IR(
# CHECK: for (int x
# CHECK: res[x, 0] = 1
# CHECK: for (int y
@ -2380,8 +2325,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
# CHECK: f[
# CHECK: for (int y
# CHECK: res[x, 2] = 4
)IR";
torch::jit::testing::FileCheck().run(verification_pattern1, oss.str());
)IR");
std::vector<int> extra2(6, 0);
std::vector<int> res2(24, 0);
@ -2418,12 +2362,8 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
l.reorderAxis(loops[0], loops[2]);
Stmt* stmt3 = Stmt::clone(l.root_stmt());
std::ostringstream oss2;
oss2 << *stmt3;
// Check the IR we produced
const std::string& verification_pattern2 =
R"IR(
checkIR(stmt3, R"IR(
# CHECK: for (int x
# CHECK: res[x, 0] = 1
# CHECK: for (int y
@ -2435,8 +2375,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
# CHECK: for (int x
# CHECK: for (int y
# CHECK: res[x, 2] = 4
)IR";
torch::jit::testing::FileCheck().run(verification_pattern2, oss2.str());
)IR");
std::vector<int> extra3(6, 0);
std::vector<int> res3(24, 0);
@ -2615,13 +2554,9 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) {
l.prepareForCodegen();
Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *stmt;
// Check the IR we produced has the 3 nests in the right order, but k and m
// swapped in the middle.
const std::string& verification_pattern =
R"IR(
checkIR(stmt, R"IR(
# CHECK: for (int m1
# CHECK: for (int n1
# CHECK: for (int k1
@ -2630,8 +2565,7 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) {
# CHECK: for (int m2
# CHECK: for (int m3
# CHECK: for (int n3
# CHECK: for (int k3)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: for (int k3)IR");
{
PaddedBuffer<float> a_v(M, N);
@ -2744,8 +2678,7 @@ TEST(LoopNest, UnrollOuter) {
std::vector<For*> loops = l.getLoopStmtsFor(A);
Stmt* unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
const std::string& verification_pattern =
R"IR(
checkIR(unrolled, R"IR(
# CHECK: for (int y = 0; y < 4; y++) {
# CHECK: A[0, y] = y;
# CHECK: }
@ -2754,11 +2687,7 @@ TEST(LoopNest, UnrollOuter) {
# CHECK: }
# CHECK: for (int y = 0; y < 4; y++) {
# CHECK: A[2, y] = y + 2;
# CHECK: })IR";
std::ostringstream oss;
oss << *unrolled;
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: })IR");
}
TEST(LoopNest, UnrollInner) {
@ -2774,18 +2703,13 @@ TEST(LoopNest, UnrollInner) {
Stmt* unrolled = nullptr;
LoopNest::unroll(
static_cast<For*>(loops[0]->body()->stmts().front()), &unrolled);
const std::string& verification_pattern =
R"IR(
checkIR(loops[0], R"IR(
# CHECK: for (int x = 0; x < 3; x++) {
# CHECK: A[x, 0] = x;
# CHECK: A[x, 1] = x + 1;
# CHECK: A[x, 2] = x + 2;
# CHECK: A[x, 3] = x + 3;
# CHECK: })IR";
std::ostringstream oss;
oss << *loops[0];
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: })IR");
}
TEST(LoopNest, UnrollMultipleStatements) {
@ -2805,18 +2729,13 @@ TEST(LoopNest, UnrollMultipleStatements) {
Block::make({f});
Stmt* unrolled = nullptr;
LoopNest::unroll(f, &unrolled);
std::ostringstream oss;
oss << *unrolled;
const std::string& verification_pattern =
R"IR(
checkIR(unrolled, R"IR(
# CHECK: A[0] = 0;
# CHECK: B[0] = A[0];
# CHECK: A[1] = 2;
# CHECK: B[1] = A[1];
# CHECK: A[2] = 4
# CHECK: B[2] = A[2];)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: B[2] = A[2];)IR");
}
TEST(LoopNest, UnrollNonLiteralConstantBounds) {
@ -2843,8 +2762,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
std::vector<For*> loops = {outer_for, inner_for};
Stmt* unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
const std::string& verification_pattern =
R"IR(
checkIR(unrolled, R"IR(
# CHECK: for (int j = 0; j < 4; j++) {
# CHECK: A[1, j] = j;
# CHECK: }
@ -2853,11 +2771,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
# CHECK: }
# CHECK: for (int j = 0; j < 4; j++) {
# CHECK: A[3, j] = 3 * j;
# CHECK: })IR";
std::ostringstream oss;
oss << *unrolled;
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
# CHECK: })IR");
}
TEST(LoopNest, UnrollEmpty) {
@ -3337,14 +3251,10 @@ TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) {
ASSERT_TRUE(success);
auto result = IRSimplifier::simplify(flattened);
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
# CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
# CHECK: A[i_flat / 5, i_flat % 5] =
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
{
SimpleIREvaluator eval1(loops[0], {a_buf});
@ -3384,16 +3294,12 @@ TEST(LoopNest, FlattenImperfectLoopNest) {
ASSERT_FALSE(success);
auto result = IRSimplifier::simplify(flattened);
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
# CHECK: for (int i = 0; i < 10; i++) {
# CHECK-NEXT: A[i, i] =
# CHECK-NEXT: for (int j = 0; j < 15; j++) {
# CHECK-NEXT: A[i, j] =
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
}
TEST(LoopNest, FlattenReductionLoopNest) {
@ -3425,16 +3331,12 @@ TEST(LoopNest, FlattenReductionLoopNest) {
ASSERT_FALSE(success);
auto result = IRSimplifier::simplify(flattened);
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
# CHECK: for (int i = 0; i < 10; i++) {
# CHECK-NEXT: S[i] =
# CHECK-NEXT: for (int j = 0; j < 15; j++) {
# CHECK-NEXT: S[i] = (S[i]) + (A[i, j])
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
}
TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
@ -3452,16 +3354,12 @@ TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
ASSERT_FALSE(success);
auto result = IRSimplifier::simplify(flattened);
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
# CHECK: for (int m = 0; m < 3; m++) {
# CHECK-NEXT: sum[m] =
# CHECK-NEXT: for (int n = 0; n < 7; n++) {
# CHECK-NEXT: sum[m] =
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
}
TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
@ -3500,15 +3398,11 @@ TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
ASSERT_FALSE(success);
auto result = IRSimplifier::simplify(flattened);
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
# CHECK: for (int i = 0; i < 10; i++) {
# CHECK-NEXT: for (int j = 0; j < 5; j++) {
# CHECK-NEXT: A[i, j] = i * j
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
}
TEST(LoopNest, DetectInlineRankMismatch) {
@ -3552,12 +3446,8 @@ TEST(LoopNest, CacheReadsSimple) {
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
// just this once: verify the whole thing.
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
#CHECK: for (int i
#CHECK: for (int j
@ -3580,8 +3470,7 @@ TEST(LoopNest, CacheReadsSimple) {
#CHECK: }
#CHECK: }
#CHECK: Free(A);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
std::vector<int> b_data(200, 0);
std::vector<int> c_data(200, 0);
@ -3625,15 +3514,11 @@ TEST(LoopNest, CacheReadsOuter) {
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
#CHECK: A_local[j_1 + 11 * i_1] =
#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
std::vector<int> b_data(200, 0);
std::vector<int> c_data(200, 0);
@ -3676,15 +3561,11 @@ TEST(LoopNest, CacheReadsInternal) {
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
#CHECK: A_local[j_1 + 11 * i_2] =
#CHECK: B[10 * i_1 + j_2] = (A_local[j_2]) + (A_local[j_2 + 12]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
std::vector<int> b_data(200, 0);
std::vector<int> c_data(200, 0);
@ -3728,15 +3609,11 @@ TEST(LoopNest, CacheReadsInner) {
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
#CHECK: A_local[2 * i_2 + j_2] =
#CHECK: B[10 * i_1 + j_1] = (A_local[8]) + (A_local[1]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
std::vector<int> b_data(200, 0);
std::vector<int> c_data(200, 0);
@ -3780,10 +3657,7 @@ TEST(LoopNest, CacheWritesSimple) {
l.prepareForCodegen();
Stmt* result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
const std::string& expected_ir =
R"IR(
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
#CHECK: for (int j = 0; j < 64
#CHECK: A_local[j] = i * j;
@ -3791,8 +3665,7 @@ TEST(LoopNest, CacheWritesSimple) {
#CHECK: A[64 * i + j_1] = A_local[
#CHECK: Free(A_local);
#CHECK-NOT: A_local
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
std::vector<int> b_data(200, 0);
std::vector<int> c_data(200, 0);
@ -3839,29 +3712,19 @@ TEST(LoopNest, DeadStoreElimination) {
LoopNest loop(stmt, {f.node()});
loop.eliminateDeadStores();
std::ostringstream oss;
oss << *loop.root_stmt();
const std::string& expected_ir =
R"IR(
checkIR(loop.root_stmt(), R"IR(
#CHECK: f[x_tail + 5 * 4, y]
#CHECK-NOT: g[x_tail + 5 * 4, y]
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
// But won't eliminate if used by different outputs.
LoopNest loop2(stmt, {f.node(), g.node()});
loop2.eliminateDeadStores();
oss.clear();
oss << *loop2.root_stmt();
const std::string& expected_ir2 =
R"IR(
checkIR(loop2.root_stmt(), R"IR(
#CHECK: f[x_tail + 5 * 4, y]
#CHECK: g[x_tail + 5 * 4, y]
)IR";
torch::jit::testing::FileCheck().run(expected_ir2, oss.str());
)IR");
}
TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
@ -3894,31 +3757,21 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
LoopNest loop(stmt, {h.node()});
loop.eliminateDeadStores();
std::ostringstream oss;
oss << *loop.root_stmt();
const std::string& expected_ir =
R"IR(
checkIR(loop.root_stmt(), R"IR(
#CHECK: f[x] = x;
#CHECK-NOT: g[z] =
#CHECK: h[x, y] = f[x * y];
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
)IR");
// Sanity check won't eliminate if g is an output.
LoopNest loop2(stmt, {h.node(), g.node()});
loop2.eliminateDeadStores();
oss.clear();
oss << *loop2.root_stmt();
const std::string& expected_ir2 =
R"IR(
checkIR(loop2.root_stmt(), R"IR(
#CHECK: f[x] = x;
#CHECK: g[z] = z + 1;
#CHECK: h[x, y] = f[x * y];
)IR";
torch::jit::testing::FileCheck().run(expected_ir2, oss.str());
)IR");
}
TEST(LoopNest, CompoundTensorSimple) {
@ -4109,7 +3962,6 @@ static Stmt* splitMaskReorder(Tensor* b) {
nest.splitWithMask(loops[0], kVectorWidth, &outer, &inner);
loops = nest.getLoopStmtsFor(b);
nest.reorderAxis(loops[1], loops[2]);
std::clog << *nest.root_stmt() << "\n";
nest.prepareForCodegen();
return nest.root_stmt();
}