diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp
index 208b333ce57b..07e2eeefc550 100644
--- a/benchmarks/cpp/tensorexpr/bench_approx.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp
@@ -25,7 +25,7 @@ void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor target) {
   ln->vectorize(inner);
   ln->splitWithTail(outer, 8, &inner, &tail);
   StmtPtr unrolled;
-  LoopNest::unroll(inner, &unrolled);
+  LoopNest::fullUnroll(inner, &unrolled);
 }
 
 static void relu_nnc(benchmark::State& state) {
diff --git a/benchmarks/cpp/tensorexpr/bench_gemm.cpp b/benchmarks/cpp/tensorexpr/bench_gemm.cpp
index a860c10d01c7..6d452368fc7a 100644
--- a/benchmarks/cpp/tensorexpr/bench_gemm.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_gemm.cpp
@@ -230,7 +230,7 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
     te::ForPtr ni = loops[4];
     te::StmtPtr unrolled;
     loop.vectorize(ni);
-    loop.unroll(mi, &unrolled);
+    loop.fullUnroll(mi, &unrolled);
   }
 
   loop.prepareForCodegen();
diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp
index dd8950e8efa1..becf3bdffbac 100644
--- a/test/cpp/tensorexpr/test_loopnest.cpp
+++ b/test/cpp/tensorexpr/test_loopnest.cpp
@@ -2929,7 +2929,7 @@ std::string constantUpperBoundLoopIR(int upper_bound_val) {
   LoopNest l({A});
   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
   StmtPtr unrolled = nullptr;
-  LoopNest::unroll(loops[0], &unrolled);
+  LoopNest::fullUnroll(loops[0], &unrolled);
   std::ostringstream oss;
   oss << *unrolled;
   return oss.str();
@@ -2958,7 +2958,7 @@ TEST(LoopNest, UnrollOuter) {
   LoopNest l({A});
   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
   StmtPtr unrolled = nullptr;
-  LoopNest::unroll(loops[0], &unrolled);
+  LoopNest::fullUnroll(loops[0], &unrolled);
   checkIR(unrolled, R"IR(
 # CHECK: for (int y = 0; y < 4; y++) {
 # CHECK: A[0, y] = y;
@@ -2981,7 +2981,7 @@ TEST(LoopNest, UnrollInner) {
   LoopNest l({A});
   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
   StmtPtr unrolled = nullptr;
-  LoopNest::unroll(
+  LoopNest::fullUnroll(
       static_to<For>(loops[0]->body()->stmts().front()), &unrolled);
   checkIR(loops[0], R"IR(
 # CHECK: for (int x = 0; x < 3; x++) {
@@ -3007,7 +3007,7 @@ TEST(LoopNest, UnrollMultipleStatements) {
           Store::make(b_buf, {x}, Load::make(a_buf, {x}))}));
   auto parent_block = Block::make({f});
   StmtPtr unrolled = nullptr;
-  LoopNest::unroll(f, &unrolled);
+  LoopNest::fullUnroll(f, &unrolled);
   checkIR(unrolled, R"IR(
 # CHECK: A[0] = 0;
 # CHECK: B[0] = A[0];
@@ -3039,7 +3039,7 @@ TEST(LoopNest, UnrollNonLiteralConstantBounds) {
 
   std::vector<ForPtr> loops = {outer_for, inner_for};
   StmtPtr unrolled = nullptr;
-  LoopNest::unroll(loops[0], &unrolled);
+  LoopNest::fullUnroll(loops[0], &unrolled);
   checkIR(unrolled, R"IR(
 # CHECK: for (int j = 0; j < 4; j++) {
 # CHECK: A[1, j] = j;
@@ -3052,6 +3052,117 @@
 # CHECK: })IR");
 }
 
+TEST(LoopNest, UnrollNonConstantBounds) {
+  // Input IR:
+  //   for (int i = 0; i < M; i++) {
+  //     for (int j = 0; j < N; j++) {
+  //       A[i, j] = i * j;
+  //     }
+  //   }
+  VarHandle M("M", kInt);
+  VarHandle N("N", kInt);
+  BufHandle a_buf("A", {M, N}, kInt);
+  VarHandle i("i", kInt);
+  VarHandle j("j", kInt);
+  auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
+  auto inner_for = For::make(j, 0, N, for_body);
+  auto outer_for = For::make(i, 0, M, inner_for);
+  auto block = Block::make({outer_for});
+  LoopNest l(block, {a_buf.node()});
+
+  LoopNest::unroll(inner_for, 8);
+  l.simplify();
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int i = 0; i < M; i++) {
+    # CHECK:   for (int j_outer = 0; j_outer < N / 8; j_outer++) {
+    # CHECK:     A[i, 8 * j_outer] =
+    # CHECK:     A[i, 8 * j_outer + 1] =
+    # CHECK:     A[i, 2 * (4 * j_outer + 1)] =
+    # CHECK:     A[i, 8 * j_outer + 3] =
+    # CHECK:     A[i, 4 * (2 * j_outer + 1)] =
+    # CHECK:     A[i, 8 * j_outer + 5] =
+    # CHECK:     A[i, 8 * j_outer + 6] =
+    # CHECK:     A[i, 8 * j_outer + 7] =
+    # CHECK:   }
+    # CHECK:   for (int j_tail = 0; j_tail < N % 8; j_tail++) {
+    # CHECK:     A[i, 8 * (N / 8) + j_tail] =
+    # CHECK:   }
+    # CHECK: }
+  )IR");
+}
+
+TEST(LoopNest, UnrollByFactorsLessThan2) {
+  // Input IR:
+  //   for (int i = 0; i < M; i++) {
+  //     for (int j = 0; j < N; j++) {
+  //       A[i, j] = i * j;
+  //     }
+  //   }
+  VarHandle M("M", kInt);
+  VarHandle N("N", kInt);
+  BufHandle a_buf("A", {M, N}, kInt);
+  VarHandle i("i", kInt);
+  VarHandle j("j", kInt);
+  auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
+  auto inner_for = For::make(j, 0, N, for_body);
+  auto outer_for = For::make(i, 0, M, inner_for);
+  auto block = Block::make({outer_for});
+  LoopNest l(block, {a_buf.node()});
+
+  // Unrolling by factor = 1 should do nothing.
+  LoopNest::unroll(inner_for, 1);
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int i = 0; i < M; i++) {
+    # CHECK:   for (int j = 0; j < N; j++) {
+    # CHECK:     A[i, j] =
+    # CHECK:   }
+    # CHECK: }
+  )IR");
+
+  // Unrolling by factor = 0 should do nothing.
+  LoopNest::unroll(inner_for, 0);
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int i = 0; i < M; i++) {
+    # CHECK:   for (int j = 0; j < N; j++) {
+    # CHECK:     A[i, j] =
+    # CHECK:   }
+    # CHECK: }
+  )IR");
+
+  // Unrolling by negative factor should do nothing.
+  LoopNest::unroll(inner_for, -2);
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int i = 0; i < M; i++) {
+    # CHECK:   for (int j = 0; j < N; j++) {
+    # CHECK:     A[i, j] =
+    # CHECK:   }
+    # CHECK: }
+  )IR");
+}
+
+TEST(LoopNest, UnrollByFactorEqualToIters) {
+  // Input IR:
+  //   for (int i = 0; i < 5; i++) {
+  //     A[i] = i * i;
+  //   }
+  BufHandle a_buf("A", {5}, kInt);
+  VarHandle i("i", kInt);
+  auto for_body = Block::make({Store::make(a_buf, {i}, i * i)});
+  auto for_loop = For::make(i, 0, 5, for_body);
+  auto block = Block::make({for_loop});
+  LoopNest l(block, {a_buf.node()});
+
+  LoopNest::unroll(for_loop, 5);
+  checkIR(l.root_stmt(), R"IR(
+    # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++)
+    # CHECK:   A[5 * i_outer]
+    # CHECK:   A[5 * i_outer + 1]
+    # CHECK:   A[5 * i_outer + 2]
+    # CHECK:   A[5 * i_outer + 3]
+    # CHECK:   A[5 * i_outer + 4]
+  )IR");
+}
+
 TEST(LoopNest, UnrollEmpty) {
   const std::string actual = constantUpperBoundLoopIR(0);
   const std::string& verification_pattern = R"IR(
@@ -3069,7 +3180,7 @@ TEST(LoopNest, NoUnroll) {
   std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
   StmtPtr unrolled = nullptr;
   ASSERT_THROWS_WITH(
-      LoopNest::unroll(loops[0], &unrolled), "non-constant loop");
+      LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop");
 }
 
 TEST(LoopNest, UnrollWithLet) {
@@ -3089,7 +3200,7 @@ TEST(LoopNest, UnrollWithLet) {
           Store::make(b_buf, {x}, e + 1)}));
   auto parent_block = Block::make({f});
   StmtPtr unrolled = nullptr;
-  LoopNest::unroll(f, &unrolled);
+  LoopNest::fullUnroll(f, &unrolled);
   std::ostringstream oss;
   oss << *unrolled;
   const std::string& verification_pattern =
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index a3c2b196f1f5..cc49ef3ef82b 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -2309,7 +2309,7 @@ bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
   return true;
 }
 
-void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
+void LoopNest::fullUnroll(ForPtr f, StmtPtr* unrolled) {
   BlockPtr p = to<Block>(f->get_parent());
   if (!f) {
     throw malformed_input("unroll attempted on null loop");
@@ -2341,10 +2341,26 @@ void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
   p->replace_stmt(f, *unrolled);
 }
 
-void LoopNest::unroll(ForPtr f) {
+void LoopNest::fullUnroll(ForPtr f) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   StmtPtr unrolled;
-  unroll(f, &unrolled);
+  fullUnroll(f, &unrolled);
+}
+
+void LoopNest::unroll(ForPtr f, int factor, ForPtr* tail) {
+  if (factor < 2) {
+    return;
+  }
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  ForPtr inner;
+  splitWithTail(f, factor, &inner, tail);
+  fullUnroll(inner);
+}
+
+void LoopNest::unroll(ForPtr f, int factor) {
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  ForPtr tail;
+  unroll(f, factor, &tail);
 }
 
 bool LoopNest::isNormalized(ForPtr f) {
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 8d9432cd0fc2..25b6f29e695e 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -418,8 +418,15 @@ class TORCH_API LoopNest {
   // Returns true if the given loop has a loop-carried dependence.
   static bool hasLoopCarriedDependence(ForPtr loop);
 
-  static void unroll(ForPtr f, StmtPtr* unrolled);
-  static void unroll(ForPtr f);
+  // Unrolls all the iterations of the given loop.
+  // Requires that the loop bounds are constant.
+  static void fullUnroll(ForPtr f, StmtPtr* unrolled);
+  static void fullUnroll(ForPtr f);
+
+  // Unrolls the given loop for the specified factor.
+  // This does not require constant bounds for the loop being unrolled.
+  static void unroll(ForPtr f, int factor, ForPtr* tail);
+  static void unroll(ForPtr f, int factor);
 
   static bool normalize(ForPtr f);
   static bool isNormalized(ForPtr f);
diff --git a/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp b/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp
index c152a1da64cc..b7745eac0682 100644
--- a/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp
@@ -158,7 +158,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
   //     static void reorderAxis(ForPtr a, ForPtr b);
   //     static std::vector<ForPtr> reorder(const std::vector<ForPtr>& loops, const std::vector<size_t>& permutation);
   //     ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
-  //     static void unroll(ForPtr f);
+  //     static void fullUnroll(ForPtr f);
   //     static bool normalize(ForPtr f);
   //     static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
   //     static void compressBuffer(BufPtr buf, StmtPtr stmt);
@@ -191,7 +191,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
     REORDER_AXIS,
     REORDER,
     TILE,
-    UNROLL,
+    FULL_UNROLL,
     NORMALIZE,
     FLATTEN,
     COMPRESS_BUFFER,
@@ -512,7 +512,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
         break;
       }
 
-      case UNROLL: {
+      case FULL_UNROLL: {
        auto loops = NodeFinder<For>::find(l.root_stmt());
        if (loops.size() == 0) {
          break;
@@ -520,9 +520,9 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
        int loop_n = std::rand() % (int)loops.size();
        auto loop = loops[loop_n];
 
-        message = "unroll(loops[" + std::to_string(loop_n) + "]);\n";
+        message = "fullUnroll(loops[" + std::to_string(loop_n) + "]);\n";
        randomization_helper::printHistory(n_transform, message);
-        l.unroll(loop);
+        LoopNest::fullUnroll(loop);
        break;
      }
 
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index a22df768a1c9..d1ad1e27b8e9 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -607,13 +607,20 @@ void initTensorExprBindings(PyObject* module) {
           },
           py::return_value_policy::reference)
       .def(
-          "unroll",
-          [](const LoopNest& self, ForPtr f) {
+          "fullUnroll",
+          [](ForPtr f) {
             StmtPtr unrolled = nullptr;
-            self.unroll(f, &unrolled);
+            LoopNest::fullUnroll(f, &unrolled);
             return unrolled;
           },
           py::return_value_policy::reference)
+      .def(
+          "unroll",
+          [](ForPtr f, int factor) {
+            LoopNest::unroll(f, factor);
+            return f;
+          },
+          py::return_value_policy::reference)
       .def(
           "vectorize",
           [](ForPtr f) { LoopNest::vectorize(f); },
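
For reference, a minimal usage sketch of the two entry points, mirroring the setup in TEST(LoopNest, UnrollNonConstantBounds) above. This is not part of the patch: the standalone main() harness and the exact include list are assumptions; only LoopNest::fullUnroll and LoopNest::unroll(f, factor) come from the change itself.

// Sketch only; includes and namespaces assumed from the usual TE layout.
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <iostream>

using namespace torch::jit::tensorexpr;

int main() {
  // Same M x N nest as the new test: A[i, j] = i * j with symbolic bounds.
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  BufHandle a_buf("A", {M, N}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto inner_for =
      For::make(j, 0, N, Block::make({Store::make(a_buf, {i, j}, i * j)}));
  auto outer_for = For::make(i, 0, M, inner_for);
  LoopNest nest(Block::make({outer_for}), {a_buf.node()});

  // Partial unroll: splits j by 8 and fully unrolls the factor-8 inner loop.
  // Works even though N is not a constant; the remainder stays in a tail loop.
  LoopNest::unroll(inner_for, 8);
  nest.simplify();
  std::cout << *nest.root_stmt() << "\n";

  // fullUnroll still requires constant bounds and a parent block.
  BufHandle b_buf("B", {4}, kInt);
  VarHandle x("x", kInt);
  auto const_loop =
      For::make(x, 0, 4, Block::make({Store::make(b_buf, {x}, x + 1)}));
  auto parent = Block::make({const_loop});
  LoopNest::fullUnroll(const_loop);
  std::cout << *parent << "\n";
  return 0;
}

The new unroll(f, factor) is deliberately thin: it is splitWithTail(f, factor, &inner, tail) followed by fullUnroll(inner), so the tail loop left behind carries the non-constant remainder (the N % 8 loop in the test above), while factors below 2 leave the loop untouched.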