mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nnc][tests] Convert a bunch of FileCheck to checkIR
Summary:
I added a helper to convert a Stmt to string and FileCheck it, so I started using it in a bunch of places. I replaced about half the current uses, got tired, started to write a Perl script to automate it, realized that was hard, and decided to give up for a bit. But this cleans up some of the tests a bit, so seems easy to review and worth landing.

Test Plan: test_tensorexpr --gtest_filter=LoopNest.*

Reviewed By: navahgar

Differential Revision: D27375866

fbshipit-source-id: 15894b9089dec5cf25f340fe17e6e54546a64257
This commit is contained in:
committed by
Facebook GitHub Bot
parent
24f589df44
commit
e4d19798f3
@ -22,6 +22,12 @@ namespace jit {
|
||||
|
||||
using namespace torch::jit::tensorexpr;
|
||||
|
||||
// Test helper: streams the statement `s` into a string and runs FileCheck
// on that text using the directives in `pattern`.
//
// `s` is dereferenced unconditionally, so callers must pass a non-null
// statement.
void checkIR(Stmt* s, const std::string& pattern) {
  std::ostringstream ir_stream;
  ir_stream << *s;
  const std::string rendered_ir = ir_stream.str();
  torch::jit::testing::FileCheck().run(pattern, rendered_ir);
}
|
||||
|
||||
TEST(LoopNest, ExprSimple01) {
|
||||
KernelScope kernel_scope;
|
||||
Tensor* tensor = Compute(
|
||||
@ -627,20 +633,16 @@ TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) {
|
||||
l.splitWithMask(outer, 4, &outer, &mid);
|
||||
|
||||
Stmt* stmt1 = IRSimplifier::simplify(l.root_stmt());
|
||||
std::ostringstream oss;
|
||||
oss << *stmt1;
|
||||
|
||||
// Two splits mean 3 loops, but should need no masks in this case.
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt1, R"IR(
|
||||
# CHECK: for (
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: for (
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: for (
|
||||
# CHECK-NOT: if (
|
||||
# CHECK: f[)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: f[)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, SplitWithTailWithLoopOptions) {
|
||||
@ -1027,18 +1029,14 @@ TEST(LoopNest, ScheduleInlineRandom) {
|
||||
// would normally compare results but Rand isn't implemented in the
|
||||
// SimpleIREvaluator, even if we could seed it.
|
||||
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
|
||||
std::ostringstream oss;
|
||||
oss << *stmt1;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt1, R"IR(
|
||||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: int x = rand();
|
||||
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR");
|
||||
}
|
||||
|
||||
// Make sure we don't cache random vars that are not being inlined.
|
||||
@ -1068,17 +1066,13 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) {
|
||||
// would normally compare results but Rand isn't implemented in the
|
||||
// SimpleIREvaluator, even if we could seed it.
|
||||
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
|
||||
std::ostringstream oss;
|
||||
oss << *stmt1;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt1, R"IR(
|
||||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR");
|
||||
}
|
||||
|
||||
// Make sure we generate the right number of random values == the dimensionality
|
||||
@ -1105,18 +1099,14 @@ TEST(LoopNest, ScheduleInlineRandomLowerDimensions) {
|
||||
// would normally compare results but Rand isn't implemented in the
|
||||
// SimpleIREvaluator, even if we could seed it.
|
||||
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
|
||||
std::ostringstream oss;
|
||||
oss << *stmt1;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt1, R"IR(
|
||||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: int x = rand();
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: y[m2, n2, k2] = 2 * (x % 5);)IR");
|
||||
}
|
||||
|
||||
// Make sure we don't screw up intrinsics thinking they're rand.
|
||||
@ -1205,18 +1195,13 @@ TEST(LoopNest, ScheduleInlineRandWithIntrinsics) {
|
||||
|
||||
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *stmt1;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt1, R"IR(
|
||||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: float x = rand();
|
||||
# CHECK: y[m2, n2, k2] = sqrt(x);)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: y[m2, n2, k2] = sqrt(x);)IR");
|
||||
}
|
||||
|
||||
// Split a Compute then inline it into another compute.
|
||||
@ -1503,12 +1488,9 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
|
||||
// would normally compare results but Rand isn't implemented in the
|
||||
// SimpleIREvaluator, even if we could seed it.
|
||||
Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
|
||||
std::ostringstream oss;
|
||||
oss << *stmt1;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt1, R"IR(
|
||||
# CHECK: for (int m1 = 0; m1 < 4; m1++)
|
||||
# CHECK: for (int n1 = 0; n1 < 5; n1++)
|
||||
# CHECK: for (int k1 = 0; k1 < 6; k1++)
|
||||
@ -1516,8 +1498,7 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
|
||||
# CHECK: for (int m2 = 0; m2 < 4; m2++)
|
||||
# CHECK: for (int n2 = 0; n2 < 5; n2++)
|
||||
# CHECK: for (int k2 = 0; k2 < 6; k2++)
|
||||
# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, ScheduleFuserStyle) {
|
||||
@ -1644,19 +1625,13 @@ TEST(LoopNest, LoopNestComputeAt_1) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: for (int i_b = 0; i_b < N; i_b++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1]
|
||||
# CHECK: temp[
|
||||
# CHECK-NOT: A[
|
||||
# CHECK: B[i_b] = temp[0]
|
||||
# CHECK: Free(temp))IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: Free(temp))IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> b_data(100, 0);
|
||||
@ -1717,12 +1692,8 @@ TEST(LoopNest, LoopNestComputeAt_2) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: for (int cy = 0; cy < H; cy++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
||||
# CHECK: for
|
||||
@ -1730,8 +1701,7 @@ TEST(LoopNest, LoopNestComputeAt_2) {
|
||||
# CHECK: for (int cx = 0; cx < W; cx++)
|
||||
# CHECK-NOT: prod[
|
||||
# CHECK: cons[
|
||||
# CHECK: Free(temp))IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: Free(temp))IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> c_data(kW * kH, 0);
|
||||
@ -1748,12 +1718,8 @@ TEST(LoopNest, LoopNestComputeAt_2) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: for (int cy = 0; cy < H; cy++)
|
||||
# CHECK: for (int cx = 0; cx < W; cx++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
||||
@ -1761,8 +1727,7 @@ TEST(LoopNest, LoopNestComputeAt_2) {
|
||||
# CHECK: for
|
||||
# CHECK-NOT: prod[
|
||||
# CHECK: cons[
|
||||
# CHECK: Free(temp))IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: Free(temp))IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> c_data(kW * kH, 0);
|
||||
@ -1826,12 +1791,8 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: for (int ay = 0; ay < H + 1; ay++)
|
||||
# CHECK: for (int ax = 0; ax < W + 1; ax++)
|
||||
# CHECK: A[
|
||||
@ -1844,8 +1805,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
# CHECK: for (int dy = 0; dy < H; dy++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
|
||||
# CHECK: for (int dx = 0; dx < W; dx++)
|
||||
# CHECK-NOT: A[)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK-NOT: A[)IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> c_data(kW * kH, 0);
|
||||
@ -1862,12 +1822,8 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: for (int ay = 0; ay < H + 1; ay++)
|
||||
# CHECK: for (int ax = 0; ax < W + 1; ax++)
|
||||
# CHECK: A[
|
||||
@ -1880,8 +1836,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
# CHECK: for (int dy = 0; dy < H; dy++)
|
||||
# CHECK: for (int dx = 0; dx < W; dx++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
|
||||
# CHECK-NOT: A[)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK-NOT: A[)IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> c_data(kW * kH, 0);
|
||||
@ -1894,12 +1849,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
|
||||
using Axis = const VarHandle&;
|
||||
|
||||
// Test helper: streams the statement `s` into a string stream and runs
// FileCheck with `pattern` against the resulting text. `s` is dereferenced
// unconditionally.
void checkIR(Stmt* s, const std::string& pattern) {
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
torch::jit::testing::FileCheck().run(pattern, oss.str());
|
||||
}
|
||||
|
||||
TEST(LoopNest, Reduce2dComputeAt) {
|
||||
KernelScope kernel_scope;
|
||||
|
||||
@ -2365,12 +2314,8 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
|
||||
l.reorderAxis(loops[1], loops[2]);
|
||||
Stmt* stmt2 = Stmt::clone(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *l.root_stmt();
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern1 =
|
||||
R"IR(
|
||||
checkIR(stmt2, R"IR(
|
||||
# CHECK: for (int x
|
||||
# CHECK: res[x, 0] = 1
|
||||
# CHECK: for (int y
|
||||
@ -2380,8 +2325,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
|
||||
# CHECK: f[
|
||||
# CHECK: for (int y
|
||||
# CHECK: res[x, 2] = 4
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern1, oss.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> extra2(6, 0);
|
||||
std::vector<int> res2(24, 0);
|
||||
@ -2418,12 +2362,8 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
|
||||
l.reorderAxis(loops[0], loops[2]);
|
||||
Stmt* stmt3 = Stmt::clone(l.root_stmt());
|
||||
|
||||
std::ostringstream oss2;
|
||||
oss2 << *stmt3;
|
||||
|
||||
// Check the IR we produced
|
||||
const std::string& verification_pattern2 =
|
||||
R"IR(
|
||||
checkIR(stmt3, R"IR(
|
||||
# CHECK: for (int x
|
||||
# CHECK: res[x, 0] = 1
|
||||
# CHECK: for (int y
|
||||
@ -2435,8 +2375,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) {
|
||||
# CHECK: for (int x
|
||||
# CHECK: for (int y
|
||||
# CHECK: res[x, 2] = 4
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern2, oss2.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> extra3(6, 0);
|
||||
std::vector<int> res3(24, 0);
|
||||
@ -2615,13 +2554,9 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *stmt;
|
||||
|
||||
// Check the IR we produced has the 3 nests in the right order, but k and m
|
||||
// swapped in the middle.
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(stmt, R"IR(
|
||||
# CHECK: for (int m1
|
||||
# CHECK: for (int n1
|
||||
# CHECK: for (int k1
|
||||
@ -2630,8 +2565,7 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) {
|
||||
# CHECK: for (int m2
|
||||
# CHECK: for (int m3
|
||||
# CHECK: for (int n3
|
||||
# CHECK: for (int k3)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: for (int k3)IR");
|
||||
|
||||
{
|
||||
PaddedBuffer<float> a_v(M, N);
|
||||
@ -2744,8 +2678,7 @@ TEST(LoopNest, UnrollOuter) {
|
||||
std::vector<For*> loops = l.getLoopStmtsFor(A);
|
||||
Stmt* unrolled = nullptr;
|
||||
LoopNest::unroll(loops[0], &unrolled);
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(unrolled, R"IR(
|
||||
# CHECK: for (int y = 0; y < 4; y++) {
|
||||
# CHECK: A[0, y] = y;
|
||||
# CHECK: }
|
||||
@ -2754,11 +2687,7 @@ TEST(LoopNest, UnrollOuter) {
|
||||
# CHECK: }
|
||||
# CHECK: for (int y = 0; y < 4; y++) {
|
||||
# CHECK: A[2, y] = y + 2;
|
||||
# CHECK: })IR";
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *unrolled;
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: })IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, UnrollInner) {
|
||||
@ -2774,18 +2703,13 @@ TEST(LoopNest, UnrollInner) {
|
||||
Stmt* unrolled = nullptr;
|
||||
LoopNest::unroll(
|
||||
static_cast<For*>(loops[0]->body()->stmts().front()), &unrolled);
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(loops[0], R"IR(
|
||||
# CHECK: for (int x = 0; x < 3; x++) {
|
||||
# CHECK: A[x, 0] = x;
|
||||
# CHECK: A[x, 1] = x + 1;
|
||||
# CHECK: A[x, 2] = x + 2;
|
||||
# CHECK: A[x, 3] = x + 3;
|
||||
# CHECK: })IR";
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *loops[0];
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: })IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, UnrollMultipleStatements) {
|
||||
@ -2805,18 +2729,13 @@ TEST(LoopNest, UnrollMultipleStatements) {
|
||||
Block::make({f});
|
||||
Stmt* unrolled = nullptr;
|
||||
LoopNest::unroll(f, &unrolled);
|
||||
std::ostringstream oss;
|
||||
oss << *unrolled;
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(unrolled, R"IR(
|
||||
# CHECK: A[0] = 0;
|
||||
# CHECK: B[0] = A[0];
|
||||
# CHECK: A[1] = 2;
|
||||
# CHECK: B[1] = A[1];
|
||||
# CHECK: A[2] = 4
|
||||
# CHECK: B[2] = A[2];)IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: B[2] = A[2];)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, UnrollNonLiteralConstantBounds) {
|
||||
@ -2843,8 +2762,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
|
||||
std::vector<For*> loops = {outer_for, inner_for};
|
||||
Stmt* unrolled = nullptr;
|
||||
LoopNest::unroll(loops[0], &unrolled);
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
checkIR(unrolled, R"IR(
|
||||
# CHECK: for (int j = 0; j < 4; j++) {
|
||||
# CHECK: A[1, j] = j;
|
||||
# CHECK: }
|
||||
@ -2853,11 +2771,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
|
||||
# CHECK: }
|
||||
# CHECK: for (int j = 0; j < 4; j++) {
|
||||
# CHECK: A[3, j] = 3 * j;
|
||||
# CHECK: })IR";
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *unrolled;
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
# CHECK: })IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, UnrollEmpty) {
|
||||
@ -3337,14 +3251,10 @@ TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) {
|
||||
ASSERT_TRUE(success);
|
||||
|
||||
auto result = IRSimplifier::simplify(flattened);
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
# CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
|
||||
# CHECK: A[i_flat / 5, i_flat % 5] =
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
{
|
||||
SimpleIREvaluator eval1(loops[0], {a_buf});
|
||||
@ -3384,16 +3294,12 @@ TEST(LoopNest, FlattenImperfectLoopNest) {
|
||||
ASSERT_FALSE(success);
|
||||
|
||||
auto result = IRSimplifier::simplify(flattened);
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
# CHECK: for (int i = 0; i < 10; i++) {
|
||||
# CHECK-NEXT: A[i, i] =
|
||||
# CHECK-NEXT: for (int j = 0; j < 15; j++) {
|
||||
# CHECK-NEXT: A[i, j] =
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, FlattenReductionLoopNest) {
|
||||
@ -3425,16 +3331,12 @@ TEST(LoopNest, FlattenReductionLoopNest) {
|
||||
ASSERT_FALSE(success);
|
||||
|
||||
auto result = IRSimplifier::simplify(flattened);
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
# CHECK: for (int i = 0; i < 10; i++) {
|
||||
# CHECK-NEXT: S[i] =
|
||||
# CHECK-NEXT: for (int j = 0; j < 15; j++) {
|
||||
# CHECK-NEXT: S[i] = (S[i]) + (A[i, j])
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
|
||||
@ -3452,16 +3354,12 @@ TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
|
||||
ASSERT_FALSE(success);
|
||||
|
||||
auto result = IRSimplifier::simplify(flattened);
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
# CHECK: for (int m = 0; m < 3; m++) {
|
||||
# CHECK-NEXT: sum[m] =
|
||||
# CHECK-NEXT: for (int n = 0; n < 7; n++) {
|
||||
# CHECK-NEXT: sum[m] =
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
|
||||
@ -3500,15 +3398,11 @@ TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
|
||||
ASSERT_FALSE(success);
|
||||
|
||||
auto result = IRSimplifier::simplify(flattened);
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
# CHECK: for (int i = 0; i < 10; i++) {
|
||||
# CHECK-NEXT: for (int j = 0; j < 5; j++) {
|
||||
# CHECK-NEXT: A[i, j] = i * j
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, DetectInlineRankMismatch) {
|
||||
@ -3552,12 +3446,8 @@ TEST(LoopNest, CacheReadsSimple) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
|
||||
// just this once: verify the whole thing.
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
|
||||
#CHECK: for (int i
|
||||
#CHECK: for (int j
|
||||
@ -3580,8 +3470,7 @@ TEST(LoopNest, CacheReadsSimple) {
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: Free(A);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
std::vector<int> c_data(200, 0);
|
||||
@ -3625,15 +3514,11 @@ TEST(LoopNest, CacheReadsOuter) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
|
||||
#CHECK: A_local[j_1 + 11 * i_1] =
|
||||
#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
std::vector<int> c_data(200, 0);
|
||||
@ -3676,15 +3561,11 @@ TEST(LoopNest, CacheReadsInternal) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
|
||||
#CHECK: A_local[j_1 + 11 * i_2] =
|
||||
#CHECK: B[10 * i_1 + j_2] = (A_local[j_2]) + (A_local[j_2 + 12]);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
std::vector<int> c_data(200, 0);
|
||||
@ -3728,15 +3609,11 @@ TEST(LoopNest, CacheReadsInner) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
|
||||
#CHECK: A_local[2 * i_2 + j_2] =
|
||||
#CHECK: B[10 * i_1 + j_1] = (A_local[8]) + (A_local[1]);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
std::vector<int> c_data(200, 0);
|
||||
@ -3780,10 +3657,7 @@ TEST(LoopNest, CacheWritesSimple) {
|
||||
l.prepareForCodegen();
|
||||
Stmt* result = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
|
||||
#CHECK: for (int j = 0; j < 64
|
||||
#CHECK: A_local[j] = i * j;
|
||||
@ -3791,8 +3665,7 @@ TEST(LoopNest, CacheWritesSimple) {
|
||||
#CHECK: A[64 * i + j_1] = A_local[
|
||||
#CHECK: Free(A_local);
|
||||
#CHECK-NOT: A_local
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
std::vector<int> b_data(200, 0);
|
||||
std::vector<int> c_data(200, 0);
|
||||
@ -3839,29 +3712,19 @@ TEST(LoopNest, DeadStoreElimination) {
|
||||
LoopNest loop(stmt, {f.node()});
|
||||
loop.eliminateDeadStores();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *loop.root_stmt();
|
||||
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(loop.root_stmt(), R"IR(
|
||||
#CHECK: f[x_tail + 5 * 4, y]
|
||||
#CHECK-NOT: g[x_tail + 5 * 4, y]
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
// But won't eliminate if used by different outputs.
|
||||
LoopNest loop2(stmt, {f.node(), g.node()});
|
||||
loop2.eliminateDeadStores();
|
||||
|
||||
oss.clear();
|
||||
oss << *loop2.root_stmt();
|
||||
|
||||
const std::string& expected_ir2 =
|
||||
R"IR(
|
||||
checkIR(loop2.root_stmt(), R"IR(
|
||||
#CHECK: f[x_tail + 5 * 4, y]
|
||||
#CHECK: g[x_tail + 5 * 4, y]
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir2, oss.str());
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
|
||||
@ -3894,31 +3757,21 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
|
||||
LoopNest loop(stmt, {h.node()});
|
||||
loop.eliminateDeadStores();
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *loop.root_stmt();
|
||||
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
checkIR(loop.root_stmt(), R"IR(
|
||||
#CHECK: f[x] = x;
|
||||
#CHECK-NOT: g[z] =
|
||||
#CHECK: h[x, y] = f[x * y];
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
)IR");
|
||||
|
||||
// Sanity check won't eliminate if g is an output.
|
||||
LoopNest loop2(stmt, {h.node(), g.node()});
|
||||
loop2.eliminateDeadStores();
|
||||
|
||||
oss.clear();
|
||||
oss << *loop2.root_stmt();
|
||||
|
||||
const std::string& expected_ir2 =
|
||||
R"IR(
|
||||
checkIR(loop2.root_stmt(), R"IR(
|
||||
#CHECK: f[x] = x;
|
||||
#CHECK: g[z] = z + 1;
|
||||
#CHECK: h[x, y] = f[x * y];
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir2, oss.str());
|
||||
)IR");
|
||||
}
|
||||
|
||||
TEST(LoopNest, CompoundTensorSimple) {
|
||||
@ -4109,7 +3962,6 @@ static Stmt* splitMaskReorder(Tensor* b) {
|
||||
nest.splitWithMask(loops[0], kVectorWidth, &outer, &inner);
|
||||
loops = nest.getLoopStmtsFor(b);
|
||||
nest.reorderAxis(loops[1], loops[2]);
|
||||
std::clog << *nest.root_stmt() << "\n";
|
||||
nest.prepareForCodegen();
|
||||
return nest.root_stmt();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user