#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <test/cpp/tensorexpr/padded_buffer.h>
#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

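// Helpers: print a Stmt/Expr and verify the printed IR against a FileCheck
// pattern.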
void checkIR(StmtPtr s, const std::string& pattern) {
  std::ostringstream oss;
  oss << *s;
  torch::jit::testing::FileCheck().run(pattern, oss.str());
}

void checkExprIR(ExprPtr e, const std::string& pattern) {
  std::string prefixed_pattern = "# CHECK: " + pattern + "\n";
  std::ostringstream oss;
  oss << *e << "\n";
  torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str());
}

void checkExprIR(const ExprHandle& e, const std::string& pattern) {
  checkExprIR(e.node(), pattern);
}

TEST(LoopNest, ExprSimple01) {
  Tensor tensor =
      Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
      });
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  LoopNest::splitWithTail(loops[0], 2);
  LoopNest::splitWithTail(loops[0], 2);
}

TEST(LoopNest, ExprLower01) {
  Tensor tensor =
      Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
      });
  LoopNest l({tensor});
  StmtPtr stmt = l.root_stmt();
  std::ostringstream oss;
  oss << *stmt;
  ASSERT_GT(oss.str().size(), 20);
  ASSERT_LT(oss.str().size(), 200);
}

TEST(LoopNest, ExprSimple02) {
  auto func = [](const ExprHandle& x, const ExprHandle& y) {
    return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
  };
  Tensor tensor = Compute("f", {26, 5}, func);
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  LoopNest::splitWithTail(loops[0], 4);

  StmtPtr stmt = l.root_stmt();
  std::ostringstream oss;
  oss << *stmt;
  ASSERT_GT(oss.str().size(), 200);
  ASSERT_LT(oss.str().size(), 600);

  {
    // Compare to a reference loop structure.
    VarHandle x_outer("i_outer", kInt);
    VarHandle x_inner("i_inner", kInt);
    VarHandle y("i", kInt);
    VarHandle x_tail("i_tail", kInt);
    BufHandle f("f", {26, 5}, kFloat);
    ExprHandle x_1 = x_outer * 4 + x_inner;
    ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4;
    ForPtr stmt1 = For::make(
        x_outer,
        0,
        x_outer_end,
        For::make(
            x_inner,
            0,
            4,
            For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))));
    ExprHandle x_2 = x_tail + x_outer_end * 4;
    ForPtr stmt2 = For::make(
        x_tail,
        0,
        (ExprHandle(26) - 0) % 4,
        For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y))));
    StmtPtr stmt = Block::make({stmt1, stmt2});

    std::ostringstream oss_ref;
    oss_ref << *stmt;
    ASSERT_EQ(oss.str(), oss_ref.str());
  }

  {
    PaddedBuffer<float> f_v(26, 5, "f_v");
    PaddedBuffer<float> f_ref(26, 5, "f_res");

    stmt = FlattenIndexes(stmt);
    SimpleIREvaluator ir_eval(stmt, {tensor});
    ir_eval(f_v);

    for (int x = 0; x < 26; x++) {
      for (int y = 0; y < 5; y++) {
        f_ref(x, y) = 1 + x * x + y * y;
      }
    }

    ExpectAllNear(f_v, f_ref, 1e-5);
  }
}

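// Helpers for the slicing/splitting tests below: simplify the loop nest and
// check that each top-level For loop has the expected constant start/stop.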
BlockPtr getSimplifiedBody(const LoopNest& l) {
  StmtPtr stmt = l.root_stmt();
  StmtPtr simplified = IRSimplifier::simplify(stmt);
  return to<Block>(simplified);
}

void assertForRange(ForPtr f, int expected_start, int expected_stop) {
  ASSERT_NE(f, nullptr);
  IntImmPtr start = to<IntImm>(f->start());
  ASSERT_NE(start, nullptr);
  ASSERT_EQ(start->value(), expected_start);
  IntImmPtr stop = to<IntImm>(f->stop());
  ASSERT_NE(stop, nullptr);
  ASSERT_EQ(stop->value(), expected_stop);
}

void assertForRanges(
    BlockPtr body,
    const std::vector<std::pair<int, int>>& start_stops) {
  ASSERT_EQ(body->nstmts(), start_stops.size());

  auto it = body->begin();
  for (size_t i = 0; i < start_stops.size(); i++, it++) {
    ForPtr loop = to<For>(*it);
    assertForRange(loop, start_stops[i].first, start_stops[i].second);
  }
}

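// Slicing the head of a loop that is bound to a GPU block index: the tail
// (which reuses the original loop) keeps the binding, the new head does not.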
TEST(LoopNest, ExprSliceHeadWithLoopOptions) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
  LoopNest::sliceHead(loops[0], 2, &head, &tail);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 2}, {0, 8}});

  ASSERT_TRUE(tail->loop_options().is_gpu_block_index());
  ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  ASSERT_TRUE(head->loop_options().isDefault());
}

TEST(LoopNest, ExprSliceTailWithLoopOptions) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceTail(loops[0], 4, &head, &tail);

  ForPtr tail_head;
  ForPtr tail_tail;
  tail->set_gpu_block_index(LoopOptions::IDX_Y);
  LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}});

  ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index());
  ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  ASSERT_TRUE(head->loop_options().isDefault());
  ASSERT_TRUE(tail_tail->loop_options().isDefault());
}

TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) {
  // When factor equals the For loop's original size, keep using the original
  // For loop.
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceHead(loops[0], 10, &head, &tail);

  ASSERT_EQ(head, loops[0]);
  ASSERT_EQ(tail, nullptr);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 10}});
}

TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceHead(loops[0], 100, &head, &tail);

  ASSERT_EQ(head, loops[0]);
  ASSERT_EQ(tail, nullptr);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 10}});
}

TEST(LoopNest, ExprSliceHead) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceHead(loops[0], 4, &head, &tail);

  ASSERT_NE(head, nullptr);
  ASSERT_NE(head, loops[0]);
  ASSERT_NE(tail, nullptr);
  ASSERT_EQ(tail, loops[0]);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 4}, {4, 10}});
}

TEST(LoopNest, ExprSliceHeadWithNonZeroStart) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head;
  ForPtr tail;
  LoopNest::sliceTail(loops[0], 4, &head, &tail);
  // head: [0, 6)
  // tail: [6, 10)

  LoopNest::sliceHead(tail, 2);
  // tail_head: [6, 8)
  // tail_tail: [8, 10)

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}});
}

TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) {
  // When factor equals the For loop's original size, keep using the original
  // For loop.
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceTail(loops[0], 10, &head, &tail);

  ASSERT_EQ(head, nullptr);
  ASSERT_EQ(tail, loops[0]);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 10}});
}

TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) {
  // When factor is larger than the For loop's original size, keep using the
  // original For loop.
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceTail(loops[0], 100, &head, &tail);

  ASSERT_EQ(head, nullptr);
  ASSERT_EQ(tail, loops[0]);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 10}});
}

TEST(LoopNest, ExprSliceTail) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  ForPtr head;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::sliceTail(loops[0], 4, &head, &tail);

  ASSERT_NE(head, nullptr);
  ASSERT_EQ(head, loops[0]);
  ASSERT_NE(tail, nullptr);
  ASSERT_NE(tail, loops[0]);

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 6}, {6, 10}});
}

TEST(LoopNest, ExprSplitAndSlice) {
  // 0: splitWithTail
  // 1: sliceTail on inner loop
  // 2: sliceHead on outer loop
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {100}, func);
  LoopNest l({tensor});

  ForPtr inner;
  ForPtr tail;
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // outer: [0, 4)
  // inner: [0, 21)
  // tail: [84, 100)
  LoopNest::splitWithTail(loops[0], 21, &inner, &tail);
  LoopNest::sliceTail(inner, 2);
  LoopNest::sliceHead(loops[0], 2);

  // for (int x_outer = 0; x_outer < 2; x_outer++) {
  //   for (int x_inner = 0; x_inner < 19; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  //   for (int x_inner = 19; x_inner < 21; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  // }
  // for (int x_outer = 2; x_outer < 4; x_outer++) {
  //   for (int x_inner = 0; x_inner < 19; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  //   for (int x_inner = 19; x_inner < 21; x_inner++) {
  //     f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
  //   }
  // }
  // for (int x_tail = 0; x_tail < 16; x_tail++) {
  //   f[x_tail + 84] = 1.f + float(x_tail + 84);
  // }
  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}});

  auto biter = body->begin();

  ForPtr loop = to<For>(*biter++);
  assertForRanges(loop->body(), {{0, 19}, {19, 21}});

  loop = to<For>(*biter);
  assertForRanges(loop->body(), {{0, 19}, {19, 21}});
}

TEST(LoopNest, ExprSliceAndNormalize) {
  // 0: sliceHead
  // 1: normalize tail
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {10}, func);
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

  ForPtr head;
  ForPtr tail;
  LoopNest::sliceHead(loops[0], 2, &head, &tail);
  // head: [0, 2)
  // tail: [2, 10)

  LoopNest::normalize(tail);
  // normalized_tail: [0, 8)

  BlockPtr body = getSimplifiedBody(l);
  assertForRanges(body, {{0, 2}, {0, 8}});
}

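// Evaluates an expression that depends on a single variable at the given
// value.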
template <typename T>
T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) {
  ExprEval<SimpleIREvaluator> eval(expr, {var});
  return eval.value<T>(value);
}

TEST(LoopNest, ExprSliceWithVariableDimension) {
  auto testWithDimension =
      [](int dimension,
         const std::vector<std::pair<int, int>>& expected_for_ranges) {
        VarHandle dim("dim", kInt);
        Tensor tensor =
            Compute("f", {dim}, [](const ExprHandle& x) { return x; });
        LoopNest l({tensor});
        std::vector<ForPtr> loops =
            l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);

        ForPtr head;
        ForPtr tail;
        LoopNest::sliceHead(loops[0], 2, &head, &tail);

        LoopNest::sliceTail(tail, 2);

        BlockPtr body = getSimplifiedBody(l);
        ASSERT_EQ(expected_for_ranges.size(), 3);
        auto it = body->begin();
        for (auto& start_stop : expected_for_ranges) {
          ForPtr loop = to<For>(*it++);
          int start = evalExpr<int>(ExprHandle(loop->start()), dim, dimension);
          int stop = evalExpr<int>(ExprHandle(loop->stop()), dim, dimension);
          ASSERT_EQ(start, start_stop.first);
          ASSERT_EQ(stop, start_stop.second);
        }
      };

  testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}});
  testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}});
  testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}});
  testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}});
  testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}});
  testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}});
}

TEST(LoopNest, ExprSplitWithTail) {
  auto func = [](const ExprHandle& x) {
    return ExprHandle(1.0f) + cast<float>(x);
  };
  Tensor tensor = Compute("f", {199}, func);
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  LoopNest::splitWithTail(loops[0], 17);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  LoopNest::splitWithTail(loops[0], 7);

  StmtPtr stmt = l.root_stmt();
  StmtPtr simplified = IRSimplifier::simplify(stmt);
  BlockPtr body = to<Block>(simplified);
  ASSERT_EQ(body->nstmts(), 3);
  auto biter = body->begin();

  // Verify that the split loops are ordered correctly.
  ForPtr loop = to<For>(*biter++);
  assertForRange(loop, 0, 7);

  loop = to<For>(*biter++);
  assertForRange(loop, 0, 4);

  loop = to<For>(*biter);
  assertForRange(loop, 0, 12);
}

TEST(LoopNest, ExprSplitWithTailNone) {
  auto func = [](const ExprHandle& x, const ExprHandle& y) {
    return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
  };
  Tensor tensor = Compute("f", {24, 5}, func);
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::splitWithTail(loops[0], 4);

  StmtPtr stmt = l.root_stmt();
  std::ostringstream oss;
  oss << *stmt;
  ASSERT_GT(oss.str().size(), 200);
  ASSERT_LT(oss.str().size(), 600);

  {
    // Compare to a reference loop structure.
    VarHandle x_outer("i_outer", kInt);
    VarHandle x_inner("i_inner", kInt);
    VarHandle y("i", kInt);
    VarHandle x_tail("i_tail", kInt);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
    BufHandle f("f", {24, 5}, kFloat);
    ExprHandle x_1 = x_outer * 4 + x_inner;
    ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4;
    StmtPtr stmt = alloc<Block>(std::vector<StmtPtr>({For::make(
        x_outer,
        0,
        x_outer_end,
        For::make(
            x_inner,
            0,
            4,
            For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))}));

    std::ostringstream oss_ref;
    oss_ref << *stmt;
    ASSERT_EQ(oss.str(), oss_ref.str());
  }

  {
    PaddedBuffer<float> f_v(24, 5, "f_v");
    PaddedBuffer<float> f_ref(24, 5, "f_res");

    SimpleIREvaluator ir_eval(stmt, {tensor});
    ir_eval(f_v);

    for (int x = 0; x < 24; x++) {
      for (int y = 0; y < 5; y++) {
        f_ref(x, y) = 1 + x * x + y * y;
      }
    }

    ExpectAllNear(f_v, f_ref, 1e-5);
  }
}

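// splitWithMask on the inner loop of a 2-D elementwise compute; results are
// verified numerically against a reference buffer.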
TEST(LoopNest, ExprSplitWithMask01) {
  const int M = 26;
  const int N = 5;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {M, N}, kFloat);
  Tensor tensor =
      Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f;
      });

  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::splitWithMask(loops[1], 4);

  StmtPtr stmt = l.root_stmt();

  PaddedBuffer<float> a_v(M, N, "a");
  PaddedBuffer<float> b_v(M, N, "b");
  PaddedBuffer<float> c_v(M, N, "c");
  PaddedBuffer<float> c_ref(M, N, "c_ref");
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      a_v(m, n) = 2 * m;
      b_v(m, n) = 3 * n;
      c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f;
    }
  }

  SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);

  ExpectAllNear(c_v, c_ref, 1e-5);
}

// Tests the case where we split a loop cleanly multiple times; no masks should
// be inserted.
TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) {
  const int M = 64;
  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M}, kFloat);
  Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
    return a_buf.load(m) + b_buf.load(m) + 1.0f;
  });

  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 4);
  LoopNest::splitWithMask(loops[0], 4);

  StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt());

  // Two splits mean 3 loops, but should need no masks in this case.
  checkIR(stmt1, R"IR(
# CHECK: for (
# CHECK-NOT: if (
# CHECK: for (
# CHECK-NOT: if (
# CHECK: for (
# CHECK-NOT: if (
# CHECK: f[)IR");
}

TEST(LoopNest, getLoopAt) {
  // Input IR:
  //  for (int i = 0; i < 100; i++) {
  //    for (int j = 0; j < 100; j++) {
  //      A[i, j] = sin(i * j);
  //      for (int k1 = 0; k1 < 200; k1++) {
  //        B[i, j, k1] = (A[i, j]) / (k1 + 1);
  //      }
  //      for (int k2 = 0; k2 < 300; k2++) {
  //        C[i, j, k2] = (A[i, j]) * (k2 + 1);
  //      }
  //    }
  //  }
  BufPtr A = alloc<Buf>(
      "A",
      std::vector<ExprPtr>({alloc<IntImm>(100), alloc<IntImm>(100)}),
      kInt);
  BufPtr B = alloc<Buf>(
      "B",
      std::vector<ExprPtr>(
          {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(200)}),
      kInt);
  BufPtr C = alloc<Buf>(
      "C",
      std::vector<ExprPtr>(
          {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(300)}),
      kInt);
  BufHandle a_buf(A);
  BufHandle b_buf(B);
  BufHandle c_buf(C);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k1("k1", kInt);
  VarHandle k2("k2", kInt);
  auto store1 = Store::make(a_buf, {i, j}, sin(i * j));
  auto store2 = Store::make(
      b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1)));
  auto store3 = Store::make(
      c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1)));
  auto for_k2 = For::make(k2, 0, 300, Block::make({store3}));
  auto for_k1 = For::make(k1, 0, 200, Block::make({store2}));
  auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2}));
  auto for_i = For::make(i, 0, 100, for_j);
  LoopNest l(Block::make({for_i}), {B, C});
  auto ret_k2 = l.getLoopAt(for_i, {0, 2});
  TORCH_CHECK(ret_k2 == for_k2);

  std::ostringstream oss;
  oss << *ret_k2;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int k2
# CHECK-NEXT: C[i, j, k2] =
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

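// tile() with factors that evenly divide both loop extents: no tail loops
// should be generated.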
TEST(LoopNest, TileSimple) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  const int M = 64, N = 64;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {M, N}, kFloat);
  Tensor tensor =
      Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f;
      });

  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  l.tile(loops[0], loops[1], 4, 8);

  // IR check
  StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
  checkIR(stmt, R"IR(
# CHECK: for (int i_outer
# CHECK: for (int i_outer_1
# CHECK: for (int i_inner
# CHECK: for (int i_inner_1
# CHECK: f[
# CHECK-NOT: for (int i_tail
# CHECK-NOT: for (int i_tail)IR");

  // Correctness check
  PaddedBuffer<float> a_v(M, N, "a");
  PaddedBuffer<float> b_v(M, N, "b");
  PaddedBuffer<float> c_v(M, N, "c");
  PaddedBuffer<float> c_ref(M, N, "c_ref");
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      a_v(m, n) = 2 * m;
      b_v(m, n) = 3 * n;
      c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f;
    }
  }

  SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(c_v, c_ref, 1e-5);
}

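// tile() with factors that do not divide the loop extents: tail loops are
// expected.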
TEST(LoopNest, TileWithTails) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  const int M = 64, N = 64;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {M, N}, kFloat);
  Tensor tensor =
      Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f;
      });

  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  l.tile(loops[0], loops[1], 5, 9);

  // IR check
  StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
  checkIR(stmt, R"IR(
# CHECK: for (int i_outer
# CHECK: for (int i_outer_1
# CHECK: for (int i_inner
# CHECK: for (int i_inner_1
# CHECK: f[
# CHECK: for (int i_inner
# CHECK: f[
# CHECK: for (int i_tail)IR");

  // Correctness check
  PaddedBuffer<float> a_v(M, N, "a");
  PaddedBuffer<float> b_v(M, N, "b");
  PaddedBuffer<float> c_v(M, N, "c");
  PaddedBuffer<float> c_ref(M, N, "c_ref");
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      a_v(m, n) = 2 * m;
      b_v(m, n) = 3 * n;
      c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f;
    }
  }

  SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(c_v, c_ref, 1e-5);
}

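// tile() applied to the two middle loops of a 4-D loop nest.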
TEST(LoopNest, TileInMiddle) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  const int M = 8, N = 8, L = 8, K = 8;
  BufHandle a_buf("a", {M, N, L, K}, kFloat);
  BufHandle b_buf("b", {M, N, L, K}, kFloat);
  Tensor tensor = Compute(
      "f",
      {M, N, L, K},
      [&](const ExprHandle& m,
          const ExprHandle& n,
          const ExprHandle& l,
          const ExprHandle& k) {
        return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f;
      });

  LoopNest nest({tensor});
  std::vector<ForPtr> loops =
      nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  nest.tile(loops[1], loops[2], 3, 3);

  // IR check
  StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt());
  checkIR(stmt, R"IR(
# CHECK: for (int i
# CHECK: for (int i_outer
# CHECK: for (int i_outer_1
# CHECK: for (int i_inner
# CHECK: for (int i_inner_1
# CHECK: for (int i_1
# CHECK: f[
# CHECK: for (int i_tail_1
# CHECK: for (int i_inner_1
# CHECK: for (int i_1
# CHECK: f[
# CHECK: for (int i_tail)IR");

  // Correctness check
  PaddedBuffer<float> a_v(M, N, L, K, "a");
  PaddedBuffer<float> b_v(M, N, L, K, "b");
  PaddedBuffer<float> c_v(M, N, L, K, "c");
  PaddedBuffer<float> c_ref(M, N, L, K, "c_ref");
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      for (int l = 0; l < L; l++) {
        for (int k = 0; k < K; k++) {
          a_v(m, n, l, k) = 2 * (m + l);
          b_v(m, n, l, k) = 3 * (n + k);
          c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f;
        }
      }
    }
  }

  SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(c_v, c_ref, 1e-5);
}

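// splitWithTail / splitWithMask on a loop bound to a GPU block index: only the
// outer loop should keep the axis binding.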
TEST(LoopNest, SplitWithTailWithLoopOptions) {
  const int M = 21;
  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M}, kFloat);
  Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
    return a_buf.load(m) + b_buf.load(m) + 1.0f;
  });
  ForPtr inner, tail;

  LoopNest l({tensor});
  auto loops = NodeFinder<For>::find(l.root_stmt());
  ASSERT_GT(loops.size(), 0);
  loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
  LoopNest::splitWithTail(loops[0], 4, &inner, &tail);
  ASSERT_NE(inner, nullptr);
  ASSERT_NE(tail, nullptr);
  ForPtr outer = loops[0];

  // Outer loop carries loop axis bindings.
  ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
  ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  // Inner loop has none.
  ASSERT_TRUE(inner->loop_options().isDefault());

  // Tail loop has none.
  ASSERT_TRUE(tail->loop_options().isDefault());
}

TEST(LoopNest, SplitWithMaskWithLoopOptions) {
  const int M = 21;
  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M}, kFloat);
  Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
    return a_buf.load(m) + b_buf.load(m) + 1.0f;
  });
  ForPtr inner;

  LoopNest l({tensor});
  auto loops = NodeFinder<For>::find(l.root_stmt());
  loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
  LoopNest::splitWithMask(loops[0], 4, &inner);
  ForPtr outer = loops[0];

  // Outer loop carries loop axis bindings.
  ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
  ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);

  // Inner loop has none.
  ASSERT_TRUE(inner->loop_options().isDefault());
}

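// Broadcast-add two 2-D inputs into a 3-D output and check, via the padded
// buffers' Backup/CheckBackup, that the inputs are not modified.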
TEST(LoopNest, ScheduleBroadcastAddBuffer) {
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) + b_buf.load(n, k);
      });
  LoopNest l({c});
  StmtPtr stmt = l.root_stmt();

  PaddedBuffer<float> a_v(M, N, "a_v");
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      a_v(m, n) = 7 * m * n;
    }
  }
  a_v.Backup();

  PaddedBuffer<float> b_v(N, K, "b_v");
  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k++) {
      b_v(n, k) = 11 * n * k;
    }
  }
  b_v.Backup();

  PaddedBuffer<float> c_v(M, N, K, "c_buf");
  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c});
  ir_eval(a_v, b_v, c_v);

  a_v.CheckBackup();
  b_v.CheckBackup();
  PaddedBuffer<float> c_ref(M, N, K, "c_ref");
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        c_ref(m, n, k) = 7 * m * n + 11 * n * k;
      }
    }
  }
  ExpectAllNear(c_v, c_ref, 1e-5);
}

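// A Compute (d) that consumes another Compute (c); after prepareForCodegen the
// whole nest is evaluated and d is checked against a reference.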
TEST(LoopNest, ScheduleFunctionCall01) {
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) + b_buf.load(n, k);
      });
  Tensor d = Compute(
      "d",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return c.load(m, n, k) + 1;
      });

  LoopNest l({d}, {c, d});
  l.prepareForCodegen();
  StmtPtr stmt = l.root_stmt();
  std::ostringstream oss;
  oss << *stmt;
  ASSERT_GT(oss.str().size(), 100);

  PaddedBuffer<float> a_v(M, N);
  PaddedBuffer<float> b_v(N, K);
  PaddedBuffer<float> c_v(M, N, K);
  PaddedBuffer<float> d_v(M, N, K);
  PaddedBuffer<float> d_ref(M, N, K);

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      a_v(i, j) = i * i;
    }
  }
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < K; j++) {
      b_v(i, j) = j * j;
    }
  }
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      for (int k = 0; k < K; k++) {
        d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1;
      }
    }
  }

  SimpleIREvaluator eval(stmt, {a_buf, b_buf, d});
  eval(a_v, b_v, d_v);

  ExpectAllNear(d_v, d_ref, 1e-5);
}

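// computeInline of a producer into its consumer: the inlined and uninlined
// versions must produce the same values, and the inlined IR must be shorter.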
TEST(LoopNest, ScheduleInlineSimple) {
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);
  BufHandle c_buf("c", {M, N}, kFloat);
  BufHandle d_buf("d", {M, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) * b_buf.load(n, k);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
      });

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y});

  PaddedBuffer<float> a_v(M, N);
  PaddedBuffer<float> b_v(N, K);
  PaddedBuffer<float> c_v(M, N);
  PaddedBuffer<float> d_v(M, K);

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      a_v(i, j) = i * i;
    }
  }
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < K; j++) {
      b_v(i, j) = j * j;
    }
  }
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      c_v(i, j) = i + j;
    }
  }
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < K; j++) {
      d_v(i, j) = i * j;
    }
  }

  PaddedBuffer<float> y_1(M, N, K);
  PaddedBuffer<float> y_2(M, N, K);

  eval1(a_v, b_v, c_v, d_v, y_1);
  eval2(a_v, b_v, c_v, d_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}

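// Strips all whitespace from a string so that printed IR can be compared
// structurally.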
static std::string remove_space(const std::string& str) {
  std::string str_new = str;
  str_new.erase(
      remove_if(str_new.begin(), str_new.end(), isspace), str_new.end());
  return str_new;
}

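// Inlines x and/or y into z in the given order; when both are inlined, the
// resulting IR must match the IR of an equivalent directly-fused Compute.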
void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);
  BufHandle c_buf("c", {M, N}, kFloat);
  BufHandle d_buf("d", {M, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) * b_buf.load(n, k);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
      });
  Tensor z = Compute(
      "z",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return x.load(m, n, k) + y.load(m, n, k);
      });

  LoopNest l({z}, {x, y, z});
  for (const std::string& order : inline_order) {
    if (order == "x") {
      l.computeInline(x.buf());
    } else if (order == "y") {
      l.computeInline(y.buf());
    } else {
      throw std::runtime_error("Invalid order: " + order);
    }
  }
  l.prepareForCodegen();
  StmtPtr stmt = l.root_stmt();

  std::ostringstream oss;
  oss << *stmt;
  std::string str1 = remove_space(oss.str());

  {
    PaddedBuffer<float> a_v(M, N);
    PaddedBuffer<float> b_v(N, K);
    PaddedBuffer<float> c_v(M, N);
    PaddedBuffer<float> d_v(M, K);

    for (int i = 0; i < M; i++) {
      for (int j = 0; j < N; j++) {
        a_v(i, j) = i * i;
      }
    }
    for (int i = 0; i < N; i++) {
      for (int j = 0; j < K; j++) {
        b_v(i, j) = j * j;
      }
    }
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < N; j++) {
        c_v(i, j) = i + j;
      }
    }
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < K; j++) {
        d_v(i, j) = i * j;
      }
    }

    PaddedBuffer<float> z_v(M, N, K);
    PaddedBuffer<float> z_ref(M, N, K);
    for (int m = 0; m < M; m++) {
      for (int n = 0; n < N; n++) {
        for (int k = 0; k < K; k++) {
          z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
        }
      }
    }

    SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z});
    eval(a_v, b_v, c_v, d_v, z_v);
    ExpectAllNear(z_v, z_ref, 1e-5);
  }

  if (inline_order.size() == 2) {
    Tensor z2 = Compute(
        "z",
        {M, N, K},
        [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
          return a_buf.load(m, n) * b_buf.load(n, k) +
              (c_buf.load(m, n) * d_buf.load(m, k) +
               a_buf.load(m, n) * b_buf.load(n, k));
        });
    LoopNest l2({z2});
    l2.prepareForCodegen();
    StmtPtr stmt2 = l2.root_stmt();

    std::ostringstream oss2;
    oss2 << *stmt2;
    std::string str2 = remove_space(oss2.str());

    ASSERT_EQ(str1, str2);
    ASSERT_GT(str1.size(), 100);
  }
}

TEST(LoopNest, ScheduleInlineFunc01) {
  InlineFunc01Helper({"x", "y"});
  InlineFunc01Helper({"y", "x"});
  InlineFunc01Helper({"x"});
  InlineFunc01Helper({"y"});
  InlineFunc01Helper({});
}

// Make sure we cache random vars if we should.
TEST(LoopNest, ScheduleInlineRandom) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return Mod::make(Intrinsics::make(kRand, kInt), 5);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return x.load(m, n, k) + x.load(m, n, k);
      });

  LoopNest l1({y}, {x, y});
  l1.computeInline(x.buf());

  // would normally compare results but Rand isn't implemented in the
  // SimpleIREvaluator, even if we could seed it.
  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());

  // Check the IR we produced
  checkIR(stmt1, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK: int x = rand();
# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR");
}

// Make sure we don't cache random vars that are not being inlined.
TEST(LoopNest, ScheduleInlineRandomUnrelated) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return m * n * k;
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return x.load(m, n, k) + Intrinsics::make(kRand, kInt) +
            Intrinsics::make(kRand, kInt);
      });

  LoopNest l1({y}, {x, y});
  l1.computeInline(x.buf());

  // would normally compare results but Rand isn't implemented in the
  // SimpleIREvaluator, even if we could seed it.
  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());

  // Check the IR we produced
  checkIR(stmt1, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR");
}

// Make sure the number of random values generated matches the dimensionality
// of the producer tensor.
TEST(LoopNest, ScheduleInlineRandomLowerDimensions) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  Tensor x = Compute("x", {M}, [&](const VarHandle& m) {
    return Mod::make(Intrinsics::make(kRand, kInt), 5);
  });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return x.load(m) + x.load(m);
      });

  LoopNest l1({y}, {x, y});
  l1.computeInline(x.buf());

  // would normally compare results but Rand isn't implemented in the
  // SimpleIREvaluator, even if we could seed it.
  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());

  // Check the IR we produced
  checkIR(stmt1, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK: int x = rand();
# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR");
}

// Make sure we don't screw up intrinsics thinking they're rand.
TEST(LoopNest, ScheduleInlineIntrinsics) {
  const int M = 4;
  const int N = 5;
  const int K = 6;
  BufHandle a_buf("a", {M, N}, kFloat);
  BufHandle b_buf("b", {N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) * b_buf.load(n, k);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return Intrinsics::make(kSqrt, x.load(m, n, k));
      });

  PaddedBuffer<float> a_v(M, N);
  PaddedBuffer<float> b_v(N, K);

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      a_v(i, j) = i * i;
    }
  }
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < K; j++) {
      b_v(i, j) = j * j;
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M, N, K);
  PaddedBuffer<float> y_2(M, N, K);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}

// Make sure we can handle rand and non-rand intrinsics.
TEST(LoopNest, ScheduleInlineRandWithIntrinsics) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return Intrinsics::make(kRand, kFloat);
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return Intrinsics::make(kSqrt, x.load(m, n, k));
      });

  LoopNest l1({y}, {x, y});
  l1.computeInline(x.buf());

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());

  // Check the IR we produced
  checkIR(stmt1, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK: float x = rand();
# CHECK: y[i, i_1, i_2] = sqrt(x);)IR");
}

// Split a Compute then inline it into another compute.
TEST(LoopNest, ScheduleSplitAThenInline) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });

  LoopNest l({b}, {a, b});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 4);
  ASSERT_FALSE(l.computeInline(a.buf()));
}

// Split a Compute then inline another Compute into it.
TEST(LoopNest, ScheduleSplitBThenInline) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });

  LoopNest l({b}, {a, b});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 3);
  l.computeInline(a.buf());
  l.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(l.root_stmt());

  std::vector<int> output(6, 0);
  SimpleIREvaluator eval(s, {b});
  eval(output);

  for (int i = 0; i < 6; ++i) {
    ASSERT_EQ(output[i], (i + 8) * (i + 8));
  }
}

// Split a Compute twice then inline it.
TEST(LoopNest, ScheduleSplitTwiceThenInline) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  ForPtr i_inner;

  LoopNest l({b}, {a, b});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 4, &i_inner);
  LoopNest::splitWithMask(i_inner, 2);
  ASSERT_FALSE(l.computeInline(a.buf()));
}

// Inline a Compute, then split.
TEST(LoopNest, ScheduleInlineThenSplit) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });

  LoopNest l({b}, {a, b});
  l.computeInline(a.buf());

  std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());
  LoopNest::splitWithMask(loops.back(), 3);
  l.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  std::vector<int> output(6, 0);
  SimpleIREvaluator eval(s, {b});
  eval(output);

  for (int i = 0; i < 6; ++i) {
    ASSERT_EQ(output[i], (i + 8) * (i + 8));
  }
}

// Split a Compute, inline it, then split the result.
TEST(LoopNest, ScheduleSplitInlineThenSplit) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });

  LoopNest l({b}, {a, b});
  auto loops = NodeFinder<For>::find(l.root_stmt());
  LoopNest::splitWithMask(loops.back(), 2);
  l.computeInline(a.buf());

  loops = NodeFinder<For>::find(l.root_stmt());
  LoopNest::splitWithMask(loops.front(), 2);
  l.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  std::vector<int> output(16, 0);
  SimpleIREvaluator eval(s, {b});
  eval(output);

  for (int i = 0; i < 16; ++i) {
    ASSERT_EQ(output[i], (i + 8) * (i + 8));
  }
}

// Oversplit a loop that is simplified out after inlining.
TEST(LoopNest, ScheduleSplitInlineSimplify) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) {
    return ExprHandle(4) * i - ExprHandle(2) * i;
  });
  Tensor b = Compute(
      "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); });

  LoopNest l({b}, {a, b});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 4);
  ASSERT_FALSE(l.computeInline(a.buf()));
}

// Inline a Compute with two consumers.
TEST(LoopNest, ScheduleInlineThreeMixedOnce) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest l({c}, {a, b, c});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  l.computeInline(a.buf());
  l.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  std::vector<int> output(4 * 3, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);

  for (int k = 0; k < 4; ++k) {
    for (int l = 0; l < 3; ++l) {
      ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8));
    }
  }
}

// Inline Compute A into B, then inline B into C.
TEST(LoopNest, ScheduleInlineThreeMixedTwice) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest l({c}, {a, b, c});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  l.computeInline(a.buf());
  l.computeInline(b.buf());
  l.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  std::vector<int> output(4 * 3, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);

  for (int k = 0; k < 4; ++k) {
    for (int l = 0; l < 3; ++l) {
      ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8));
    }
  }
}

// Inline a Compute that is both a producer and consumer.
TEST(LoopNest, ScheduleInlineThreeMixedInner) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest l({c}, {a, b, c});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  l.computeInline(b.buf());
  l.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  std::vector<int> output(4 * 3, 0);
  SimpleIREvaluator eval(s, {c});
  eval(output);

  for (int k = 0; k < 4; ++k) {
    for (int l = 0; l < 3; ++l) {
      ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8));
    }
  }
}

// Split all three Computes, then attempt to inline the first one; the inline
// is expected to fail.
TEST(LoopNest, ScheduleInlineThreeMixedSplit) {
  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
  Tensor b = Compute(
      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
    return a.load(k) * b.load(l);
  });

  LoopNest l({c}, {a, b, c});
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 4);
  loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 3);
  loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
  LoopNest::splitWithMask(loops[0], 2);

  ASSERT_FALSE(l.computeInline(a.buf()));
}

// Check that inlining works for output tensors too
TEST(LoopNest, ScheduleInlineOutputTensors) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return m * n * k;
      });
  Tensor y = Compute(
      "y",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return x.load(m, n, k) + m;
      });

  LoopNest l1({x, y});
  l1.computeInline(x.buf());

  // would normally compare results but Rand isn't implemented in the
  // SimpleIREvaluator, even if we could seed it.
  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());

  // Check the IR we produced
  checkIR(stmt1, R"IR(
# CHECK: for (int i = 0; i < 4; i++)
# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2;
# CHECK: for (int i_3 = 0; i_3 < 4; i_3++)
# CHECK: for (int i_4 = 0; i_4 < 5; i_4++)
# CHECK: for (int i_5 = 0; i_5 < 6; i_5++)
# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR");
}

TEST(LoopNest, ScheduleInlineWithCompoundIndices) {
  // Input IR:
  //   for (int64_t i = 0; i < 100; i++) {
  //     A[i*2,i] = i * 500ll;
  //   }
  //   for (int64_t j = 0; j < 100; j++) {
  //     B[0ll,j] = A[0, j] + j * 100ll;
  //   }
  BufHandle a_buf("A", {20, 100}, kLong);
  BufHandle b_buf("B", {20, 100}, kLong);
  VarHandle i("i", kLong);
  VarHandle j("j", kLong);
  auto forI = For::make(
      i,
      0,
      100,
      Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast<int64_t>(500))));
  auto forJ = For::make(
      j,
      0,
      100,
      Store::make(
          b_buf,
          {static_cast<int64_t>(0), j},
          Add::make(
              Load::make(a_buf, {static_cast<int64_t>(0), j}),
              Mul::make(j, static_cast<int64_t>(100)))));
  auto par = Block::make({forI, forJ});

  LoopNest l(par, {b_buf.node()});
  // Inlining should fail since the producer has compound expr as index.
  ASSERT_FALSE(l.computeInline(a_buf.node()));

  // The input statement must remain as is.
  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int64_t i = 0;
# CHECK-NEXT: A[
# CHECK: for (int64_t j = 0;
# CHECK-NEXT: B[)IR");
}

TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) {
  // Input IR:
  //   for (int64_t i = 0; i < 100; i++) {
  //     A[0ll,i] = i * 500ll;
  //   }
  //   for (int64_t j = 0; j < 100; j++) {
  //     B[0ll,j] = A[(int64_t)0, j] + j * 100ll;
  //   }
  BufHandle a_buf("A", {20, 100}, kLong);
  BufHandle b_buf("B", {20, 100}, kLong);
  VarHandle i("i", kLong);
  VarHandle j("j", kLong);
  auto forI = For::make(
      i,
      0,
      100,
      Store::make(
          a_buf,
          {static_cast<int64_t>(0), i},
          Mul::make(i, static_cast<int64_t>(500))));
  auto forJ = For::make(
      j,
      0,
      100,
      Store::make(
          b_buf,
          {static_cast<int64_t>(0), j},
          Add::make(
              Load::make(a_buf, {0, j}),
              Mul::make(j, static_cast<int64_t>(100)))));
  auto par = Block::make({forI, forJ});

  LoopNest l(par, {b_buf.node()});
  ASSERT_TRUE(l.computeInline(a_buf.node()));

  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int64_t j = 0; j < 100; j++) {
# CHECK: B[0ll, j] = j * 500ll + j * 100ll;
# CHECK: })IR");
}

TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) {
  // Input IR:
  //   for (int64_t i = 0; i < 100; i++) {
  //     A[(int64_t)0,i] = i * 500ll;
  //   }
  //   for (int64_t j = 0; j < 100; j++) {
  //     B[0ll,j] = A[0ll, j] + j * 100ll;
  //   }
  BufHandle a_buf("A", {20, 100}, kLong);
  BufHandle b_buf("B", {20, 100}, kLong);
  VarHandle i("i", kLong);
  VarHandle j("j", kLong);
  auto forI = For::make(
      i,
      0,
      100,
      Store::make(a_buf, {0, i}, Mul::make(i, static_cast<int64_t>(500))));
  auto forJ = For::make(
      j,
      0,
      100,
      Store::make(
          b_buf,
          {static_cast<int64_t>(0), j},
          Add::make(
              Load::make(a_buf, {static_cast<int64_t>(0), j}),
              Mul::make(j, static_cast<int64_t>(100)))));
  auto par = Block::make({forI, forJ});

  LoopNest l(par, {b_buf.node()});
  ASSERT_TRUE(l.computeInline(a_buf.node()));

  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int64_t j = 0; j < 100; j++) {
# CHECK: B[0ll, j] = j * 500ll + j * 100ll;
# CHECK: })IR");
}

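// Two chained pointwise Computes written fuser-style over a flat buffer; both
// outputs are evaluated and checked.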
TEST(LoopNest, ScheduleFuserStyle) {
  const int kVectorSize = 8;
  const int kVectorCount = 128;
  const int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);

  Tensor b =
      Compute("f", {kTotalSize}, [&](const std::vector<VarHandle>& axes) {
        return a_buf.load(axes[0]) + 11.0f;
      });

  Tensor c =
      Compute("g", {kTotalSize}, [&](const std::vector<VarHandle>& axes) {
        return b.load(axes[0]) + 1.0f;
      });

  LoopNest l({b, c});
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  std::vector<float> a_data(kTotalSize, 7.0f);
  std::vector<float> b_data(kTotalSize, 0.0f);
  std::vector<float> c_data(kTotalSize, 0.0f);
  SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data);

  for (int i = 0; i < kTotalSize; i++) {
    ASSERT_EQ(b_data[i], 18.0f);
    ASSERT_EQ(c_data[i], 19.0f);
  }
}

TEST(LoopNest, ScheduleFuserThreeArg) {
  const int kVectorSize = 8;
  const int kVectorCount = 128;
  const int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat);

  Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) {
    return a.load(i) + b.load(i);
  });
  Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) {
    return e.load(i) + c.load(i);
  });
  Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) {
    return f.load(i) + d.load(i);
  });

  LoopNest l({g}, {e, f, g});
  l.computeInline(l.getLoopBodyFor(e));
  l.computeInline(l.getLoopBodyFor(f));
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  std::vector<float> a_data(kTotalSize, 1.0f);
  std::vector<float> b_data(kTotalSize, 2.0f);
  std::vector<float> c_data(kTotalSize, 3.0f);
  std::vector<float> d_data(kTotalSize, 4.0f);
  std::vector<float> g_data(kTotalSize, 0.0f);
  SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data);

  for (int i = 0; i < kTotalSize; i++) {
    ASSERT_EQ(g_data[i], 10.0f);
  }
}

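// Element-wise add with dynamic (runtime-specified) 2-D shapes, evaluated for
// several concrete sizes.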
TEST(LoopNest, ScheduleDynamicShape2D) {
|
|
auto testWithSize = [](int32_t M, int32_t N) {
|
|
VarHandle m("m", kInt);
|
|
VarHandle n("n", kInt);
|
|
BufHandle a("a", {m, n}, kFloat);
|
|
BufHandle b("b", {m, n}, kFloat);
|
|
Tensor c =
|
|
Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return a.load(i, j) + b.load(i, j);
|
|
});
|
|
LoopNest l({c});
|
|
StmtPtr s = l.root_stmt();
|
|
SimpleIREvaluator cg(s, {a, b, c, m, n});
|
|
std::vector<float> aData(M * N, 1.0f);
|
|
std::vector<float> bData(M * N, 2.0f);
|
|
std::vector<float> cData(M * N, 0.0f);
|
|
cg.call({aData, bData, cData, M, N});
|
|
ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7);
|
|
};
|
|
testWithSize(1, 8);
|
|
testWithSize(16, 32);
|
|
testWithSize(37, 11);
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestComputeAt_1) {
|
|
// Verify that compute_at works on the following example:
|
|
//
|
|
// for (int i_a = 0; i_a < N; i_a++) {
|
|
// A[i_a] = i_a * i_a
|
|
// }
|
|
// for (int i_b = 0; i_b < N; i_b++) {
|
|
// B[i_b] = A[i_b]
|
|
// }
|
|
//
|
|
// After the transformation the i_b loop should have an allocation for a temp
|
|
// buffer and that buffer should be used in computation of B. No use of A
|
|
// should be in that loop after the transformation. Also, computation of A
|
|
// should not be inlined into B. Instead, it should be computed into the temp,
|
|
// and the temp should be used in B.
|
|
VarHandle N("N", kInt);
|
|
Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; });
|
|
Tensor B =
|
|
Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); });
|
|
LoopNest l({B}, {A, B});
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {B, N});
|
|
StmtPtr s = cg.stmt();
|
|
|
|
checkIR(s, R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[1]
|
|
# CHECK: for (int i = 0; i < N; i++)
|
|
# CHECK: temp[
|
|
# CHECK-NOT: A[
|
|
# CHECK: B[i_1] = temp[0]
|
|
# CHECK: Free(temp))IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> b_data(100, 0);
|
|
cg.call({b_data, 100});
|
|
|
|
std::vector<int> b_ref(100, 0);
|
|
for (int i = 0; i < 100; i++) {
|
|
b_ref[i] = i * i;
|
|
}
|
|
assertAllEqual(b_data, b_ref);
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestComputeAt_2) {
|
|
// Verify that compute_at works on the following example:
|
|
//
|
|
// for (int py = 0; py < H+1; py++) {
|
|
// for (int px = 0; px < W+1; px++) {
|
|
// p[py, px] = py*px
|
|
// }
|
|
// }
|
|
// for (int cy = 0; cy < H; cy++) {
|
|
// for (int cx = 0; cx < W; cx++) {
|
|
  //     c[cy, cx] = p[cy,cx]   + p[cy+1,cx] +
|
|
// p[cy,cx+1] + p[cy+1,cx+1]
|
|
// }
|
|
// }
|
|
|
|
const int kW = 16, kH = 16;
|
|
VarHandle W("W", kInt);
|
|
VarHandle H("H", kInt);
|
|
Tensor p = Compute(
|
|
"prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) {
|
|
return px * py;
|
|
});
|
|
Tensor c =
|
|
Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
|
|
return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) +
|
|
p.load(y + 1, x + 1);
|
|
});
|
|
|
|
std::vector<int> c_ref(kW * kH, 0);
|
|
for (int y = 0; y < kH; y++) {
|
|
for (int x = 0; x < kW; x++) {
|
|
c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
|
|
}
|
|
}
|
|
LoopNest orig_loopnest({c}, {p, c});
|
|
|
|
{
|
|
// First let's try to compute P at axis cy (the outer loop)
|
|
LoopNest l(orig_loopnest);
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]);
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
|
|
StmtPtr s = cg.stmt();
|
|
|
|
// Check the IR we produced
|
|
checkIR(s, R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
|
# CHECK: for (int i_2 = 0; i_2 < H; i_2++)
|
|
# CHECK: for
|
|
# CHECK: for
|
|
# CHECK: for (int i_3 = 0; i_3 < W; i_3++)
|
|
# CHECK-NOT: prod[
|
|
# CHECK: cons[
|
|
# CHECK: Free(temp))IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> c_data(kW * kH, 0);
|
|
cg.call({c_data, kW, kH});
|
|
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
{
|
|
// Now let's try to compute P at axis cx (the inner loop)
|
|
LoopNest l(orig_loopnest);
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
|
|
StmtPtr s = cg.stmt();
|
|
|
|
// Check the IR we produced
|
|
checkIR(s, R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
|
# CHECK: for (int i_2 = 0; i_2 < H; i_2++)
|
|
# CHECK: for (int i_3 = 0; i_3 < W; i_3++)
|
|
# CHECK: for
|
|
# CHECK: for
|
|
# CHECK-NOT: prod[
|
|
# CHECK: cons[
|
|
# CHECK: Free(temp))IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> c_data(kW * kH, 0);
|
|
cg.call({c_data, kW, kH});
|
|
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestComputeAt_3) {
|
|
// Verify that compute_at works on the following example:
|
|
//
|
|
// A(x,y) = x*y
|
|
// B(x,y) = A(x, y)
|
|
// C(x,y) = B(x+1, y)
|
|
// D(x,y) = A(x, y+1) + C(x, y)
|
|
//
|
|
// i.e. when 'A' comes to 'D' directly and indirectly through 'C'.
|
|
|
|
const int kW = 16, kH = 16;
|
|
VarHandle W("W", kInt);
|
|
VarHandle H("H", kInt);
|
|
Tensor A = Compute(
|
|
"A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) {
|
|
return ax * ay;
|
|
});
|
|
Tensor B = Compute(
|
|
"B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) {
|
|
return A.load(by, bx);
|
|
});
|
|
Tensor C =
|
|
Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) {
|
|
return B.load(cy, cx + 1);
|
|
});
|
|
Tensor D =
|
|
Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) {
|
|
return A.load(dy + 1, dx) + C.load(dy, dx);
|
|
});
|
|
|
|
std::vector<int> c_ref(kW * kH, 0);
|
|
for (int y = 0; y < kH; y++) {
|
|
for (int x = 0; x < kW; x++) {
|
|
c_ref[y * kW + x] = (y + 1) * x + y * (x + 1);
|
|
}
|
|
}
|
|
|
|
LoopNest orig_loopnest({D}, {A, B, C, D});
|
|
{
|
|
// First let's try to compute A at axis dy (the outer loop)
|
|
LoopNest l(orig_loopnest);
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {D, W, H});
|
|
StmtPtr s = cg.stmt();
|
|
|
|
// Check the IR we produced
|
|
checkIR(s, R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
|
|
# CHECK: for (int i = 0; i < H + 1; i++)
|
|
# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++)
|
|
# CHECK: A[
|
|
# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++)
|
|
# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++)
|
|
# CHECK: B[
|
|
# CHECK: for (int i_4 = 0; i_4 < H; i_4++)
|
|
# CHECK: for (int i_5 = 0; i_5 < W; i_5++)
|
|
# CHECK: C[
|
|
# CHECK: for (int i_6 = 0; i_6 < H; i_6++)
|
|
# CHECK: for (int i_7 = 0; i_7 < W; i_7++)
|
|
# CHECK-NOT: A[)IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> c_data(kW * kH, 0);
|
|
cg.call({c_data, kW, kH});
|
|
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
{
|
|
// Now let's try to compute A at axis dx (the inner loop)
|
|
LoopNest l(orig_loopnest);
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]);
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {D, W, H});
|
|
StmtPtr s = cg.stmt();
|
|
|
|
// Check the IR we produced
|
|
checkIR(s, R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
|
|
# CHECK: for (int i = 0; i < H + 1; i++)
|
|
# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++)
|
|
# CHECK: A[
|
|
# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++)
|
|
# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++)
|
|
# CHECK: B[
|
|
# CHECK: for (int i_4 = 0; i_4 < H; i_4++)
|
|
# CHECK: for (int i_5 = 0; i_5 < W; i_5++)
|
|
# CHECK: C[
|
|
# CHECK: for (int i_6 = 0; i_6 < H; i_6++)
|
|
# CHECK: for (int i_7 = 0; i_7 < W; i_7++)
|
|
# CHECK-NOT: A[)IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> c_data(kW * kH, 0);
|
|
cg.call({c_data, kW, kH});
|
|
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
}
|
|
|
|
using Axis = const VarHandle&;
|
|
|
|
TEST(LoopNest, Reduce2dComputeAt) {
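  // A 2x2 box-sum reduction: cons[y, x] accumulates prod over a 2x2 window.
  // computeAt is applied to the producer at the outer and at the inner
  // consumer axis, and both the generated IR and the result are checked.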
|
|
const int kW = 16, kH = 16;
|
|
VarHandle W("W", kInt);
|
|
VarHandle H("H", kInt);
|
|
|
|
Tensor p = Compute(
|
|
"prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; });
|
|
Tensor c = Reduce(
|
|
"cons",
|
|
{H, W},
|
|
Sum(),
|
|
[&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); },
|
|
{2, 2});
|
|
|
|
std::vector<int> c_ref(kW * kH, 0);
|
|
for (int y = 0; y < kH; y++) {
|
|
for (int x = 0; x < kW; x++) {
|
|
c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
|
|
}
|
|
}
|
|
LoopNest orig_loopnest({c}, {p, c});
|
|
checkIR(orig_loopnest.root_stmt(), R"IR(
|
|
# CHECK: for (int i = 0; i < H + 1; i++) {
|
|
# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) {
|
|
# CHECK: prod[i, i_1] = i_1 * i;
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: for (int i_2 = 0; i_2 < H; i_2++) {
|
|
# CHECK: for (int i_3 = 0; i_3 < W; i_3++) {
|
|
# CHECK: cons[i_2, i_3] = int(0);
|
|
# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) {
|
|
# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) {
|
|
# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5});
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
|
|
{
|
|
// First let's try to compute P at axis cy (the outer loop)
|
|
LoopNest l(orig_loopnest);
|
|
auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]);
|
|
// FIXME: Calling simplify here breaks the IR:
|
|
// MALFORMED INPUT: could not find base node in Load - temp[...]
|
|
// l.simplify();
|
|
l.eliminateDeadStores();
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
|
|
checkIR(cg.stmt(), R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
|
# CHECK: for (int i = 0; i < H; i++) {
|
|
# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
|
|
# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) {
|
|
# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0);
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: for (int i_1 = 0; i_1 < W; i_1++) {
|
|
# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0);
|
|
# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) {
|
|
# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) {
|
|
# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]);
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: Free(temp);
|
|
)IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> c_data(kW * kH, 0);
|
|
cg.call({c_data, kW, kH});
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
{
|
|
// Now let's try to compute P at axis cx (the inner loop)
|
|
LoopNest l(orig_loopnest);
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
|
|
l.simplify();
|
|
l.eliminateDeadStores();
|
|
l.prepareForCodegen();
|
|
SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
|
|
checkIR(cg.stmt(), R"IR(
|
|
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
|
# CHECK: for (int i = 0; i < H; i++) {
|
|
# CHECK: for (int i_1 = 0; i_1 < W; i_1++) {
|
|
# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
|
|
# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) {
|
|
# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1);
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0;
|
|
# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) {
|
|
# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) {
|
|
# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]);
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: Free(temp);
|
|
)IR");
|
|
|
|
// Now check that the loop still produces the correct result.
|
|
std::vector<int> c_data(kW * kH, 0);
|
|
cg.call({c_data, kW, kH});
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, DISABLED_Conv1d_NH) {
|
|
// Lots of stuff is broken here. The computeAt swaps the axes for some odd
|
|
// reason. Even without that, the index flattener fails due to "dimensions
|
|
// mismatch in flatten index".
|
|
|
|
int N = 4;
|
|
int H = 256;
|
|
int R = 3;
|
|
int Pad = 1;
|
|
BufHandle IP("input", {H}, kFloat);
|
|
|
|
Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) {
|
|
auto cond = CompareSelect::make(h, Pad, 1, 0, kLT);
|
|
cond = CompareSelect::make(h, H + Pad, 1, cond, kGE);
|
|
return ifThenElse(cond, 0.f, IP.load(n, h - Pad));
|
|
});
|
|
Tensor B = Reduce(
|
|
"B",
|
|
{N, H},
|
|
Sum(),
|
|
[&](Axis n, Axis h, Axis r) { return A.load(n, h + r); },
|
|
{R});
|
|
LoopNest l({B});
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int np = 0; np < 4; np++) {
|
|
# CHECK: for (int hp = 0; hp < 258; hp++) {
|
|
# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]);
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: for (int n = 0; n < 4; n++) {
|
|
# CHECK: for (int h = 0; h < 256; h++) {
|
|
# CHECK: B[n, h] = float(0);
|
|
# CHECK: for (int r = 0; r < 3; r++) {
|
|
# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r});
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0);
|
|
LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
|
|
// FIXME: The current IR is totally broken. The body of the inlined loop is:
|
|
|
|
// temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0),
|
|
// 0.f, input[idx1 + 0, (idx0 + n) - 1]);
|
|
|
|
// Which seems to mix up the axes. The CHECK below is my best guess at what
|
|
// the input "should" look like
|
|
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int n = 0; n < 4; n++) {
|
|
# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) {
|
|
# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) {
|
|
temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]);
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: for (int h = 0; h < 256; h++) {
|
|
# CHECK: B[n, h] = float(0);
|
|
# CHECK: for (int r = 0; r < 3; r++) {
|
|
# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r});
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
|
|
l.simplify();
|
|
l.prepareForCodegen();
|
|
StmtPtr s = l.root_stmt();
|
|
|
|
SimpleIREvaluator cg(s, {IP, B});
|
|
// auto At = at::ones({N, H}, at::kFloat);
|
|
auto At = at::arange(N * H, at::kFloat).reshape({N, H});
|
|
auto Rt = at::conv1d(
|
|
At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3);
|
|
auto Bt = at::empty_like(Rt);
|
|
cg.call({At.data_ptr<float>(), Bt.data_ptr<float>()});
|
|
ASSERT_TRUE(at::allclose(Rt, Bt));
|
|
}
|
|
|
|
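// Collects the loop variables of a statement in visitation order (e.g.
// "i,j,k," for a three-level nest); used below to verify reorderAxis.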
class LoopOrderHelper : public IRVisitor {
|
|
std::stringstream ordering;
|
|
|
|
public:
|
|
std::string getOrder(StmtPtr s) {
|
|
ordering.str("");
|
|
s->accept(this);
|
|
return ordering.str();
|
|
}
|
|
|
|
void visit(const ForPtr& v) final {
|
|
ordering << v->var()->name_hint() << ",";
|
|
IRVisitor::visit(v);
|
|
}
|
|
};
|
|
|
|
TEST(LoopNest, LoopNestReorderAxis1) {
|
|
Tensor tensor =
|
|
Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) {
|
|
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
|
|
});
|
|
LoopNest l({tensor});
|
|
StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
|
|
std::vector<int> stmt1_output(6, 0);
|
|
SimpleIREvaluator cg(stmt1, {tensor});
|
|
cg.call({stmt1_output});
|
|
|
|
auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[0], loops[1]);
|
|
StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
|
|
ASSERT_NE(stmt1, stmt2);
|
|
LoopOrderHelper loopOrderHelper;
|
|
std::string order1 = loopOrderHelper.getOrder(stmt1);
|
|
std::string order2 = loopOrderHelper.getOrder(stmt2);
|
|
|
|
ASSERT_EQ(order1, "j,i,");
|
|
ASSERT_EQ(order2, "i,j,");
|
|
|
|
std::vector<int> stmt2_output(6, 0);
|
|
SimpleIREvaluator cg2(stmt2, {tensor});
|
|
cg.call({stmt2_output});
|
|
|
|
for (int i = 0; i < 6; ++i) {
|
|
ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
|
|
}
|
|
|
|
// Reorder them back.
|
|
loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[0], loops[1]);
|
|
StmtPtr stmt3 = l.root_stmt();
|
|
|
|
std::string order3 = loopOrderHelper.getOrder(stmt3);
|
|
ASSERT_EQ(order3, order1);
|
|
|
|
std::ostringstream oss1, oss2;
|
|
oss1 << *stmt1;
|
|
oss2 << *stmt3;
|
|
|
|
// Should be identical to the unreordered statement.
|
|
ASSERT_EQ(oss1.str(), oss2.str());
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderPartialAxes) {
|
|
Tensor tensor = Compute(
|
|
"f",
|
|
{2, 3, 4},
|
|
[](const VarHandle& x, const VarHandle& y, const VarHandle& z) {
|
|
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y +
|
|
cast<float>(z) * z;
|
|
});
|
|
LoopNest l({tensor});
|
|
|
|
LoopOrderHelper loopOrderHelper;
|
|
StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,");
|
|
|
|
std::vector<int> stmt1_output(24, 0);
|
|
SimpleIREvaluator cg(stmt1, {tensor});
|
|
cg.call({stmt1_output});
|
|
|
|
auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[0], loops[1]);
|
|
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,");
|
|
|
|
StmtPtr stmt2 = Stmt::clone(l.root_stmt());
|
|
|
|
std::vector<int> stmt2_output(24, 0);
|
|
SimpleIREvaluator cg2(stmt2, {tensor});
|
|
cg2.call({stmt2_output});
|
|
|
|
for (int i = 0; i < 24; ++i) {
|
|
ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
|
|
}
|
|
|
|
loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[1], loops[2]);
|
|
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,");
|
|
|
|
StmtPtr stmt3 = Stmt::clone(l.root_stmt());
|
|
|
|
std::vector<int> stmt3_output(24, 0);
|
|
SimpleIREvaluator cg3(stmt3, {tensor});
|
|
cg3.call({stmt3_output});
|
|
|
|
for (int i = 0; i < 24; ++i) {
|
|
ASSERT_EQ(stmt1_output[i], stmt3_output[i]);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderInternalAxis) {
|
|
Tensor tensor = Compute(
|
|
"f",
|
|
{1, 2, 3, 4},
|
|
[](const VarHandle& w,
|
|
const VarHandle& x,
|
|
const VarHandle& y,
|
|
const VarHandle& z) {
|
|
return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y +
|
|
cast<float>(z) * z;
|
|
});
|
|
LoopNest l({tensor});
|
|
|
|
LoopOrderHelper loopOrderHelper;
|
|
StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,");
|
|
|
|
std::vector<int> stmt1_output(24, 0);
|
|
SimpleIREvaluator cg(stmt1, {tensor});
|
|
cg.call({stmt1_output});
|
|
|
|
auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[2], loops[1]);
|
|
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,");
|
|
|
|
StmtPtr stmt2 = l.root_stmt();
|
|
|
|
std::vector<int> stmt2_output(24, 0);
|
|
SimpleIREvaluator cg2(stmt2, {tensor});
|
|
cg2.call({stmt2_output});
|
|
|
|
for (int i = 0; i < 24; ++i) {
|
|
ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderEnclosingAxis) {
|
|
Tensor tensor = Compute(
|
|
"f",
|
|
{1, 2, 3, 4},
|
|
[](const VarHandle& w,
|
|
const VarHandle& x,
|
|
const VarHandle& y,
|
|
const VarHandle& z) {
|
|
return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y +
|
|
cast<float>(z) * z;
|
|
});
|
|
LoopNest l({tensor});
|
|
|
|
LoopOrderHelper loopOrderHelper;
|
|
StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
|
|
std::vector<int> stmt1_output(24, 0);
|
|
SimpleIREvaluator cg(stmt1, {tensor});
|
|
cg.call({stmt1_output});
|
|
|
|
auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[0], loops[3]);
|
|
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,");
|
|
|
|
StmtPtr stmt2 = l.root_stmt();
|
|
|
|
std::vector<int> stmt2_output(24, 0);
|
|
SimpleIREvaluator cg2(stmt2, {tensor});
|
|
cg2.call({stmt2_output});
|
|
|
|
for (int i = 0; i < 24; ++i) {
|
|
ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderSameAxis) {
  Tensor tensor =
      Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) {
        return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
      });
  LoopNest l({tensor});
  StmtPtr stmt1 = Stmt::clone(l.root_stmt());

  auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
  LoopNest::reorderAxis(loops[1], loops[1]);
  StmtPtr stmt2 = Stmt::clone(l.root_stmt());

  std::ostringstream oss, oss2;
  oss << *stmt1;
  oss2 << *stmt2;
  ASSERT_EQ(oss.str(), oss2.str());
}
|
|
|
|
TEST(LoopNest, LoopNestReorderExtraStatements) {
|
|
/* We're going for a structure like this:
|
|
* for i in ...
|
|
* Stmt 1
|
|
* for j in ...
|
|
* Stmt 2
|
|
* for k in ...
|
|
* Stmt 3
|
|
* Stmt 4
|
|
*/
|
|
|
|
Tensor tensor = Compute(
|
|
"f",
|
|
{2, 3, 4},
|
|
[](const VarHandle& x, const VarHandle& y, const VarHandle& z) {
|
|
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y +
|
|
cast<float>(z) * z;
|
|
});
|
|
LoopNest l({tensor});
|
|
|
|
BufHandle extra("res", {6, 3}, kFloat);
|
|
|
|
auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
|
|
VarHandle i = VarHandle(loops[0]->var());
|
|
|
|
StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f);
|
|
StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f);
|
|
  // Stmt 3 is the tensor's compute body; store_3 below becomes Stmt 4.
|
|
StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f);
|
|
|
|
loops[0]->body()->prepend_stmt(store_1);
|
|
loops[1]->body()->prepend_stmt(store_2);
|
|
loops[1]->body()->append_stmt(store_3);
|
|
StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
|
|
std::vector<int> extra1(6, 0);
|
|
std::vector<int> res1(24, 0);
|
|
SimpleIREvaluator cg(stmt1, {tensor, extra});
|
|
cg.call({res1, extra1});
|
|
|
|
/* Then we reorder loop y and z, we want it to look like:
|
|
*
|
|
* for i in ...
|
|
* Stmt 1
|
|
* for j in ...
|
|
* Stmt 2
|
|
* for j_1 in ...
|
|
* for k in ...
|
|
* Stmt 3
|
|
* for j_2 in ...
|
|
* Stmt 4
|
|
*
|
|
* We need extra loops because we don't have dependency info about stmt 3
|
|
* and 4.
|
|
*
|
|
*/
|
|
|
|
LoopNest::reorderAxis(loops[1], loops[2]);
|
|
StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
|
|
// Check the IR we produced
|
|
checkIR(stmt2, R"IR(
|
|
# CHECK: for
|
|
# CHECK: res[i, 0] = 1
|
|
# CHECK: for
|
|
# CHECK: res[i, 1] = 2
|
|
# CHECK: for
|
|
# CHECK: for
|
|
# CHECK: f[
|
|
# CHECK: for
|
|
# CHECK: res[i, 2] = 4
|
|
)IR");
|
|
|
|
std::vector<int> extra2(6, 0);
|
|
std::vector<int> res2(24, 0);
|
|
SimpleIREvaluator cg2(stmt2, {tensor, extra});
|
|
cg2.call({res2, extra2});
|
|
|
|
for (int i = 0; i < 24; ++i) {
|
|
ASSERT_EQ(res1[i], res2[i]);
|
|
}
|
|
for (int i = 0; i < 6; ++i) {
|
|
ASSERT_EQ(extra1[i], extra2[i]);
|
|
}
|
|
|
|
/* Now reorder x and the y above stmt 3:
|
|
*
|
|
*
|
|
* for x in ...
|
|
* Stmt 1
|
|
* for y in ...
|
|
* Stmt 2
|
|
*
|
|
* for y in ...
|
|
* for z in ...
|
|
* for x in ...
|
|
* Stmt 3
|
|
*
|
|
* for x in ...
|
|
* for y in ...
|
|
* Stmt 4
|
|
*
|
|
*
|
|
*/
|
|
loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[0], loops[2]);
|
|
StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
|
|
|
|
// Check the IR we produced
|
|
checkIR(stmt3, R"IR(
|
|
# CHECK: for
|
|
# CHECK: res[i, 0] = 1
|
|
# CHECK: for
|
|
# CHECK: res[i, 1] = 2
|
|
# CHECK: for
|
|
# CHECK: for
|
|
# CHECK: for
|
|
# CHECK: f[
|
|
# CHECK: for
|
|
# CHECK: for
|
|
# CHECK: res[i_2, 2] = 4
|
|
)IR");
|
|
|
|
std::vector<int> extra3(6, 0);
|
|
std::vector<int> res3(24, 0);
|
|
SimpleIREvaluator cg3(stmt3, {tensor, extra});
|
|
cg3.call({res3, extra3});
|
|
|
|
for (int i = 0; i < 24; ++i) {
|
|
ASSERT_EQ(res1[i], res3[i]);
|
|
}
|
|
for (int i = 0; i < 6; ++i) {
|
|
ASSERT_EQ(extra1[i], extra3[i]);
|
|
}
|
|
}
|
|
|
|
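// Builds a 5-D compute, optionally prepends and/or appends a counter update
// at every loop level, reorders the loops at index1 and index2, and verifies
// that the per-level counters and the output are unaffected by the reorder.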
void LoopNestReorderTestHelper(
|
|
bool prepend,
|
|
bool append,
|
|
int index1,
|
|
int index2) {
|
|
Tensor c = Compute(
|
|
"5d", {2, 3, 2, 3, 2}, [](const std::vector<VarHandle>&) { return -1; });
|
|
LoopNest l({c});
|
|
|
|
BufHandle extra("extra", {5}, kInt);
|
|
|
|
auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
|
|
int j = 0;
|
|
for (auto l : loops) {
|
|
// Add an increment at each layer of the loop which counts the number of
|
|
// times the loop executes.
|
|
LoadPtr load =
|
|
alloc<Load>(extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}));
|
|
AddPtr add = alloc<Add>(load, alloc<IntImm>(1));
|
|
StmtPtr store = alloc<Store>(
|
|
extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}), add);
|
|
if (prepend) {
|
|
l->body()->prepend_stmt(store);
|
|
}
|
|
if (append) {
|
|
l->body()->append_stmt(Stmt::clone(store));
|
|
}
|
|
|
|
j++;
|
|
}
|
|
|
|
StmtPtr stmt1 = Stmt::clone(l.root_stmt());
|
|
|
|
std::vector<int> extra1(5, 0);
|
|
std::vector<int> res1(2 * 3 * 2 * 3 * 2, 0);
|
|
SimpleIREvaluator cg(stmt1, {c, extra});
|
|
cg.call({res1, extra1});
|
|
|
|
std::vector<int> loopExtents = {2, 3, 2, 3, 2};
|
|
|
|
int expected_loops = 0;
|
|
if (prepend) {
|
|
expected_loops++;
|
|
}
|
|
if (append) {
|
|
expected_loops++;
|
|
}
|
|
for (int i = 0; i < 5; ++i) {
|
|
expected_loops *= loopExtents[i];
|
|
ASSERT_EQ(extra1[i], expected_loops);
|
|
}
|
|
|
|
loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
|
|
LoopNest::reorderAxis(loops[index1], loops[index2]);
|
|
StmtPtr stmt2 = Stmt::clone(l.root_stmt());
|
|
|
|
std::ostringstream oss, oss2;
|
|
oss << *stmt1;
|
|
oss2 << *stmt2;
|
|
ASSERT_NE(oss.str(), oss2.str());
|
|
|
|
std::vector<int> extra2(5, 0);
|
|
std::vector<int> res2(2 * 3 * 2 * 3 * 2, 0);
|
|
SimpleIREvaluator cg2(stmt2, {c, extra});
|
|
cg2.call({res2, extra2});
|
|
|
|
expected_loops = 0;
|
|
if (prepend) {
|
|
expected_loops++;
|
|
}
|
|
if (append) {
|
|
expected_loops++;
|
|
}
|
|
|
|
for (int i = 0; i < 5; ++i) {
|
|
expected_loops *= loopExtents[i];
|
|
ASSERT_EQ(extra2[i], expected_loops);
|
|
}
|
|
|
|
for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) {
|
|
ASSERT_EQ(res2[i], res1[i]);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) {
|
|
for (int i = 0; i < 5; ++i) {
|
|
for (int j = 0; j < 5; ++j) {
|
|
// skip noops, since we check the loop isn't the same after reordering.
|
|
if (i != j) {
|
|
LoopNestReorderTestHelper(true, false, i, j);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) {
|
|
for (int i = 0; i < 5; ++i) {
|
|
for (int j = 0; j < 5; ++j) {
|
|
// skip noops, since we check the loop isn't the same after reordering.
|
|
if (i != j) {
|
|
LoopNestReorderTestHelper(false, true, i, j);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderLongStringFull) {
|
|
for (int i = 0; i < 5; ++i) {
|
|
for (int j = 0; j < 5; ++j) {
|
|
// skip noops, since we check the loop isn't the same after reordering.
|
|
if (i != j) {
|
|
LoopNestReorderTestHelper(true, true, i, j);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, LoopNestReorderInternalLoopNest) {
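  // Three chained 3-D computes x, y, z. The innermost (k) and outermost (m)
  // loops of y are swapped with reorderAxis, and the result is checked
  // against a directly computed reference.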
|
|
const int M = 4;
|
|
const int N = 5;
|
|
const int K = 6;
|
|
BufHandle a_buf("a", {M, N}, kFloat);
|
|
BufHandle b_buf("b", {N, K}, kFloat);
|
|
BufHandle c_buf("c", {M, N}, kFloat);
|
|
BufHandle d_buf("d", {M, K}, kFloat);
|
|
|
|
Tensor x = Compute(
|
|
"x",
|
|
{M, N, K},
|
|
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
|
return a_buf.load(m, n) * b_buf.load(n, k);
|
|
});
|
|
Tensor y = Compute(
|
|
"y",
|
|
{M, N, K},
|
|
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
|
return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
|
|
});
|
|
Tensor z = Compute(
|
|
"z",
|
|
{M, N, K},
|
|
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
|
return x.load(m, n, k) + y.load(m, n, k);
|
|
});
|
|
|
|
LoopNest l({z}, {x, y, z});
|
|
ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2];
|
|
ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0];
|
|
LoopNest::reorderAxis(a, b);
|
|
|
|
l.prepareForCodegen();
|
|
StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
|
|
|
|
// Check the IR we produced has the 3 nests in the right order, but k and m
|
|
// swapped in the middle.
|
|
checkIR(stmt, R"IR(
|
|
# CHECK: < 4
|
|
# CHECK: < 5
|
|
# CHECK: < 6
|
|
# CHECK: < 6
|
|
# CHECK: < 5
|
|
# CHECK: < 4
|
|
# CHECK: < 4
|
|
# CHECK: < 5
|
|
# CHECK: < 6)IR");
|
|
|
|
{
|
|
PaddedBuffer<float> a_v(M, N);
|
|
PaddedBuffer<float> b_v(N, K);
|
|
PaddedBuffer<float> c_v(M, N);
|
|
PaddedBuffer<float> d_v(M, K);
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
for (int j = 0; j < N; j++) {
|
|
a_v(i, j) = i * i;
|
|
}
|
|
}
|
|
for (int i = 0; i < N; i++) {
|
|
for (int j = 0; j < K; j++) {
|
|
b_v(i, j) = j * j;
|
|
}
|
|
}
|
|
for (int i = 0; i < M; i++) {
|
|
for (int j = 0; j < N; j++) {
|
|
c_v(i, j) = i + j;
|
|
}
|
|
}
|
|
for (int i = 0; i < M; i++) {
|
|
for (int j = 0; j < K; j++) {
|
|
d_v(i, j) = i * j;
|
|
}
|
|
}
|
|
|
|
PaddedBuffer<float> z_v(M, N, K);
|
|
PaddedBuffer<float> z_ref(M, N, K);
|
|
for (int m = 0; m < M; m++) {
|
|
for (int n = 0; n < N; n++) {
|
|
for (int k = 0; k < K; k++) {
|
|
z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
|
|
}
|
|
}
|
|
}
|
|
|
|
SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z});
|
|
eval(a_v, b_v, c_v, d_v, z_v);
|
|
ExpectAllNear(z_v, z_ref, 1e-5);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, OuterLoopVectorization) {
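  // Vectorizing the outer loop of an 8x8 compute should leave a single
  // remaining For loop (the former inner loop) whose body contains no
  // further loops.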
|
|
Tensor tensor =
|
|
Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) {
|
|
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
|
|
});
|
|
LoopNest l({tensor});
|
|
|
|
ASSERT_TRUE(
|
|
LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0]));
|
|
|
|
StmtPtr root_stmt = l.root_stmt();
|
|
BlockPtr outer_block = to<Block>(root_stmt);
|
|
ASSERT_NE(outer_block, nullptr);
|
|
while (BlockPtr inner_block = to<Block>(outer_block->front())) {
|
|
outer_block = inner_block;
|
|
}
|
|
|
|
// Verify that we have only a single loop level remaining after
|
|
// vectorization.
|
|
ASSERT_EQ(outer_block->nstmts(), 1);
|
|
ForPtr for_loop = to<For>(outer_block->front());
|
|
ASSERT_NE(for_loop, nullptr);
|
|
BlockPtr for_body = for_loop->body();
|
|
ASSERT_EQ(for_body->nstmts(), 1);
|
|
ASSERT_EQ(to<For>(for_body->front()), nullptr);
|
|
}
|
|
|
|
TEST(LoopNest, VectorizeLoopNotNormalized) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 10; i++) {
|
|
// for (int j = 1; j < 5; j++) {
|
|
// A[i,j] = i * j;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {10, 5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for = For::make(j, 1, 5, for_body);
|
|
auto outer_for = For::make(i, 0, 10, inner_for);
|
|
auto block = Block::make({outer_for});
|
|
LoopNest l(block, {a_buf.node()});
|
|
|
|
ASSERT_TRUE(LoopNest::vectorize(inner_for));
|
|
ASSERT_EQ(outer_for->body()->nstmts(), 1);
|
|
ASSERT_EQ(to<For>(outer_for->body()->front()), nullptr);
|
|
}
|
|
|
|
namespace {
|
|
|
|
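// Builds a 1-D compute with the given constant extent, fully unrolls its only
// loop, and returns the unrolled IR as a string for FileCheck verification.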
std::string constantUpperBoundLoopIR(int upper_bound_val) {
|
|
ExprHandle upper_bound(upper_bound_val);
|
|
Tensor A =
|
|
Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; });
|
|
LoopNest l({A});
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
|
StmtPtr unrolled = nullptr;
|
|
LoopNest::fullUnroll(loops[0], &unrolled);
|
|
std::ostringstream oss;
|
|
oss << *unrolled;
|
|
return oss.str();
|
|
}
|
|
|
|
} // namespace
|
|
|
|
TEST(LoopNest, Unroll) {
|
|
const std::string actual = constantUpperBoundLoopIR(3);
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: A[0] = 0;
|
|
# CHECK: A[1] = 2;
|
|
# CHECK: A[2] = 4)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, actual);
|
|
}
|
|
|
|
TEST(LoopNest, UnrollOuter) {
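  // Fully unrolling the outer loop of a 3x4 compute should yield three copies
  // of the inner loop, one per outer index value.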
|
|
ExprHandle outer_bound(3);
|
|
ExprHandle inner_bound(4);
|
|
Tensor A = Compute(
|
|
"A",
|
|
{outer_bound, inner_bound},
|
|
[&](const VarHandle& x, const VarHandle& y) { return x + y; });
|
|
LoopNest l({A});
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
|
StmtPtr unrolled = nullptr;
|
|
LoopNest::fullUnroll(loops[0], &unrolled);
|
|
checkIR(unrolled, R"IR(
|
|
# CHECK: for (int i = 0; i < 4; i++) {
|
|
# CHECK: A[0, i] = i;
|
|
# CHECK: }
|
|
# CHECK: for (int i = 0; i < 4; i++) {
|
|
# CHECK: A[1, i] = i + 1;
|
|
# CHECK: }
|
|
# CHECK: for (int i = 0; i < 4; i++) {
|
|
# CHECK: A[2, i] = i + 2;
|
|
# CHECK: })IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollInner) {
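  // Fully unrolling the inner loop should keep the outer loop and expand its
  // body into four consecutive stores.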
|
|
ExprHandle outer_bound(3);
|
|
ExprHandle inner_bound(4);
|
|
Tensor A = Compute(
|
|
"A",
|
|
{outer_bound, inner_bound},
|
|
[&](const VarHandle& x, const VarHandle& y) { return x + y; });
|
|
LoopNest l({A});
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
|
StmtPtr unrolled = nullptr;
|
|
LoopNest::fullUnroll(
|
|
static_to<For>(loops[0]->body()->stmts().front()), &unrolled);
|
|
checkIR(loops[0], R"IR(
|
|
# CHECK: for (int i = 0; i < 3; i++) {
|
|
# CHECK: A[i, 0] = i;
|
|
# CHECK: A[i, 1] = i + 1;
|
|
# CHECK: A[i, 2] = i + 2;
|
|
# CHECK: A[i, 3] = i + 3;
|
|
# CHECK: })IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollMultipleStatements) {
|
|
const int kTotalSize = 3;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
|
|
|
|
VarHandle x("x", kInt);
|
|
auto f = For::make(
|
|
x,
|
|
0,
|
|
kTotalSize,
|
|
Block::make(
|
|
{Store::make(a_buf, {x}, x * 2),
|
|
Store::make(b_buf, {x}, Load::make(a_buf, {x}))}));
|
|
auto parent_block = Block::make({f});
|
|
StmtPtr unrolled = nullptr;
|
|
LoopNest::fullUnroll(f, &unrolled);
|
|
checkIR(unrolled, R"IR(
|
|
# CHECK: A[0] = 0;
|
|
# CHECK: B[0] = A[0];
|
|
# CHECK: A[1] = 2;
|
|
# CHECK: B[1] = A[1];
|
|
# CHECK: A[2] = 4
|
|
# CHECK: B[2] = A[2];)IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollNonLiteralConstantBounds) {
|
|
// Input IR:
|
|
// for (int i = 2 - 1; i < 12 / 3; i++) {
|
|
// for (int j = 0; j < 4; j++) {
|
|
// A[i,j] = i * j;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {3, 4}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for = For::make(j, 0, 4, for_body);
|
|
auto outer_for = For::make(
|
|
i,
|
|
IntImm::make(2) - IntImm::make(1),
|
|
IntImm::make(12) / IntImm::make(3),
|
|
inner_for);
|
|
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
|
|
auto b = Block::make({outer_for});
|
|
|
|
std::vector<ForPtr> loops = {outer_for, inner_for};
|
|
StmtPtr unrolled = nullptr;
|
|
LoopNest::fullUnroll(loops[0], &unrolled);
|
|
checkIR(unrolled, R"IR(
|
|
# CHECK: for (int j = 0; j < 4; j++) {
|
|
# CHECK: A[1, j] = j;
|
|
# CHECK: }
|
|
# CHECK: for (int j = 0; j < 4; j++) {
|
|
# CHECK: A[2, j] = 2 * j;
|
|
# CHECK: }
|
|
# CHECK: for (int j = 0; j < 4; j++) {
|
|
# CHECK: A[3, j] = 3 * j;
|
|
# CHECK: })IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollNonConstantBounds) {
|
|
// Input IR:
|
|
// for (int i = 0; i < M; i++) {
|
|
// for (int j = 0; j < N; j++) {
|
|
// A[i, j] = i * j;
|
|
// }
|
|
// }
|
|
VarHandle M("M", kInt);
|
|
VarHandle N("N", kInt);
|
|
BufHandle a_buf("A", {M, N}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for = For::make(j, 0, N, for_body);
|
|
auto outer_for = For::make(i, 0, M, inner_for);
|
|
auto block = Block::make({outer_for});
|
|
LoopNest l(block, {a_buf.node()});
|
|
|
|
LoopNest::unroll(inner_for, 8);
|
|
l.simplify();
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int i = 0; i < M; i++) {
|
|
# CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) {
|
|
# CHECK: A[i, 8 * j_outer] =
|
|
# CHECK: A[i, 8 * j_outer + 1] =
|
|
# CHECK: A[i, 2 * (4 * j_outer + 1)] =
|
|
# CHECK: A[i, 8 * j_outer + 3] =
|
|
# CHECK: A[i, 4 * (2 * j_outer + 1)] =
|
|
# CHECK: A[i, 8 * j_outer + 5] =
|
|
# CHECK: A[i, 8 * j_outer + 6] =
|
|
# CHECK: A[i, 8 * j_outer + 7] =
|
|
# CHECK: }
|
|
# CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) {
|
|
# CHECK: A[i, 8 * (N / 8) + j_tail] =
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollByFactorsLessThan2) {
|
|
// Input IR:
|
|
// for (int i = 0; i < M; i++) {
|
|
// for (int j = 0; j < N; j++) {
|
|
// A[i, j] = i * j;
|
|
// }
|
|
// }
|
|
VarHandle M("M", kInt);
|
|
VarHandle N("N", kInt);
|
|
BufHandle a_buf("A", {M, N}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for = For::make(j, 0, N, for_body);
|
|
auto outer_for = For::make(i, 0, M, inner_for);
|
|
auto block = Block::make({outer_for});
|
|
LoopNest l(block, {a_buf.node()});
|
|
|
|
// Unrolling by factor = 1 should do nothing.
|
|
LoopNest::unroll(inner_for, 1);
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int i = 0; i < M; i++) {
|
|
# CHECK: for (int j = 0; j < N; j++) {
|
|
# CHECK: A[i, j] =
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
|
|
// Unrolling by factor = 0 should do nothing.
|
|
LoopNest::unroll(inner_for, 0);
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int i = 0; i < M; i++) {
|
|
# CHECK: for (int j = 0; j < N; j++) {
|
|
# CHECK: A[i, j] =
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
|
|
// Unrolling by negative factor should do nothing.
|
|
LoopNest::unroll(inner_for, -2);
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int i = 0; i < M; i++) {
|
|
# CHECK: for (int j = 0; j < N; j++) {
|
|
# CHECK: A[i, j] =
|
|
# CHECK: }
|
|
# CHECK: }
|
|
)IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollByFactorEqualToIters) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 5; i++) {
|
|
// A[i] = i * i;
|
|
// }
|
|
BufHandle a_buf("A", {5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i}, i * i)});
|
|
auto for_loop = For::make(i, 0, 5, for_body);
|
|
auto block = Block::make({for_loop});
|
|
LoopNest l(block, {a_buf.node()});
|
|
|
|
LoopNest::unroll(for_loop, 5);
|
|
checkIR(l.root_stmt(), R"IR(
|
|
# CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++)
|
|
# CHECK: A[5 * i_outer]
|
|
# CHECK: A[5 * i_outer + 1]
|
|
# CHECK: A[5 * i_outer + 2]
|
|
# CHECK: A[5 * i_outer + 3]
|
|
# CHECK: A[5 * i_outer + 4]
|
|
)IR");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollEmpty) {
|
|
const std::string actual = constantUpperBoundLoopIR(0);
|
|
const std::string& verification_pattern = R"IR(
|
|
# CHECK-NOT: A[
|
|
)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, actual);
|
|
}
|
|
|
|
TEST(LoopNest, NoUnroll) {
|
|
VarHandle upper_bound("N", kInt);
|
|
Tensor A =
|
|
Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; });
|
|
LoopNest l({A});
|
|
std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
|
|
StmtPtr unrolled = nullptr;
|
|
ASSERT_THROWS_WITH(
|
|
LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop");
|
|
}
|
|
|
|
TEST(LoopNest, UnrollWithLet) {
|
|
const int kTotalSize = 3;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
|
|
|
|
VarHandle e("e", kInt);
|
|
VarHandle x("x", kInt);
|
|
auto f = For::make(
|
|
x,
|
|
0,
|
|
kTotalSize,
|
|
Block::make(
|
|
{Let::make(e, 7),
|
|
Store::make(a_buf, {x}, e),
|
|
Store::make(b_buf, {x}, e + 1)}));
|
|
auto parent_block = Block::make({f});
|
|
StmtPtr unrolled = nullptr;
|
|
LoopNest::fullUnroll(f, &unrolled);
|
|
std::ostringstream oss;
|
|
oss << *unrolled;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: int e = 7;
|
|
# CHECK: A[0] = e;
|
|
# CHECK: B[0] = e + 1;
|
|
# CHECK: A[1] = e;
|
|
# CHECK: B[1] = e + 1;
|
|
# CHECK: A[2] = e;
|
|
# CHECK: B[2] = e + 1;)IR";
|
|
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
std::vector<int> a_v(kTotalSize, 0);
|
|
std::vector<int> b_v(kTotalSize, 0);
|
|
SimpleIREvaluator eval(unrolled, {a_buf, b_buf});
|
|
eval(a_v, b_v);
|
|
for (int i = 0; i < kTotalSize; ++i) {
|
|
ASSERT_EQ(a_v[i], 7);
|
|
ASSERT_EQ(b_v[i], 8);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, IsNormalized) {
|
|
// Input IR:
|
|
// for (int i = 50; i < 100; i++) {
|
|
// A[i] = B[i];
|
|
// }
|
|
BufHandle a_buf("A", {ExprHandle(100)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(100)}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto for_stmt =
|
|
For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i})));
|
|
Block::make({for_stmt});
|
|
ASSERT_FALSE(LoopNest::isNormalized(for_stmt));
|
|
|
|
for_stmt->set_start(alloc<IntImm>(0));
|
|
ASSERT_TRUE(LoopNest::isNormalized(for_stmt));
|
|
|
|
VarHandle N("N", kInt);
|
|
for_stmt->set_start(N.node());
|
|
ASSERT_FALSE(LoopNest::isNormalized(for_stmt));
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeStartPositive) {
|
|
// Input IR:
|
|
// for (int x = 50; x < 100; x++) {
|
|
// A[x] = B[x];
|
|
// B[x] = x * 2;
|
|
// }
|
|
const int kTotalSize = 50;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
auto for_body = Block::make(
|
|
{Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
|
|
Store::make(b_buf, {x}, x * 2)});
|
|
auto for_stmt = For::make(x, 50, 100, for_body);
|
|
Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(for_stmt);
|
|
|
|
auto result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 50; x++) {
|
|
# CHECK: A[x + 50] = B[x + 50];
|
|
# CHECK: B[x + 50] = 2 * (x + 50);
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeStartNegative) {
|
|
// Input IR:
|
|
// for (int x = -50; x < 100; x++) {
|
|
// A[x + 50] = B[x + 50];
|
|
// B[x + 50] = x * 2;
|
|
// }
|
|
const int kTotalSize = 150;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
auto for_body = Block::make(
|
|
{Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})),
|
|
Store::make(b_buf, {x + 50}, x * 2)});
|
|
auto for_stmt = For::make(x, -50, 100, for_body);
|
|
Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(for_stmt);
|
|
|
|
auto result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 150; x++) {
|
|
# CHECK: A[x] = B[x];
|
|
# CHECK: B[x] = 2 * (x - 50);
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeStartZero) {
|
|
// Input IR:
|
|
// for (int x = 0; x < 100; x++) {
|
|
// A[x] = B[x];
|
|
// B[x] = x * 2;
|
|
// }
|
|
// Should not be modified.
|
|
|
|
const int kTotalSize = 100;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
auto for_body = Block::make(
|
|
{Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
|
|
Store::make(b_buf, {x}, x * 2)});
|
|
auto for_stmt = For::make(x, 0, 100, for_body);
|
|
Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(for_stmt);
|
|
|
|
auto result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 100; x++) {
|
|
# CHECK: A[x] = B[x];
|
|
# CHECK: B[x] = 2 * x;
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeStartVariable) {
|
|
// Input IR:
|
|
// for (int x = y; x < 100; x++) {
|
|
// A[x] = B[x];
|
|
// B[x] = x * 2;
|
|
// }
|
|
|
|
const int kTotalSize = 100;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
auto for_body = Block::make(
|
|
{Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
|
|
Store::make(b_buf, {x}, x * 2)});
|
|
auto for_stmt = For::make(x, y, 100, for_body);
|
|
auto parent_block = Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(for_stmt);
|
|
|
|
auto result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 100 - y; x++) {
|
|
# CHECK: A[x + y] = B[x + y];
|
|
# CHECK: B[x + y] = 2 * (x + y);
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeOnNestedOuterLoop) {
|
|
// Input IR:
|
|
// for (int x = 50; x < 100; x++) {
|
|
// for (int y = 10; y < 100; y++) {
|
|
// A[x] = A[x] + B[y] + y * 2;
|
|
// }
|
|
// }
|
|
|
|
BufHandle a_buf("A", {ExprHandle(50)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(100)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
auto inner_for_body = Store::make(
|
|
a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2);
|
|
auto inner_for = For::make(y, 10, 100, inner_for_body);
|
|
auto for_stmt = For::make(x, 50, 100, inner_for);
|
|
Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(for_stmt);
|
|
|
|
auto result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int x = 0; x < 50; x++) {
|
|
# CHECK: for (int y = 10; y < 100; y++) {
|
|
# CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y;
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeOnNestedInnerLoop) {
|
|
// Input IR:
|
|
// for (int x = 50; x < 100; x++) {
|
|
// for (int y = 10; y < 100; y++) {
|
|
// A[x] = A[x] + B[y] + y * 2;
|
|
// }
|
|
// }
|
|
|
|
BufHandle a_buf("A", {ExprHandle(50)}, kInt);
|
|
BufHandle b_buf("B", {ExprHandle(100)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
auto inner_for_body = Store::make(
|
|
a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2);
|
|
auto inner_for = For::make(y, 10, 100, inner_for_body);
|
|
auto for_stmt = For::make(x, 50, 100, inner_for);
|
|
Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(inner_for);
|
|
|
|
auto result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int x = 50; x < 100; x++) {
|
|
# CHECK: for (int y = 0; y < 90; y++) {
|
|
# CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20;
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, NormalizeAndSplitWithTail) {
|
|
// Create a dummy tensor to construct LoopNest.
|
|
ExprHandle n(100);
|
|
BufHandle a("a", {n}, kFloat);
|
|
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
|
|
LoopNest l({b});
|
|
|
|
// Input IR:
|
|
// for (int x = 5; x < 10; x++) {
|
|
// A[x] = x * 2;
|
|
// }
|
|
const int kTotalSize = 5;
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
|
|
VarHandle x("x", kInt);
|
|
auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2));
|
|
auto parent_block = Block::make({for_stmt});
|
|
|
|
LoopNest::normalize(for_stmt);
|
|
|
|
ForPtr x_inner;
|
|
ForPtr x_tail;
|
|
LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail);
|
|
|
|
auto x_outer_result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss_outer;
|
|
oss_outer << *x_outer_result;
|
|
const std::string& expected_outer_ir =
|
|
R"IR(
|
|
# CHECK: {
|
|
# CHECK: }
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str());
|
|
|
|
auto x_tail_result = IRSimplifier::simplify(x_tail);
|
|
std::ostringstream oss_tail;
|
|
oss_tail << *x_tail_result;
|
|
const std::string& expected_tail_ir =
|
|
R"IR(
|
|
# CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) {
|
|
# CHECK: A[x_tail + 5] = 2 * (x_tail + 5);
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str());
|
|
}
|
|
|
|
TEST(LoopNest, NotNormalizeAndSplitWithTail) {
|
|
// Create a dummy tensor to construct LoopNest.
|
|
ExprHandle n(100);
|
|
BufHandle a("a", {n}, kFloat);
|
|
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
|
|
LoopNest l({b});
|
|
|
|
// Input IR:
|
|
// for (int x = 5; x < 15; x++) {
|
|
// A[x] = x * 2;
|
|
// }
|
|
const int kTotalSize = 10;
|
|
BufHandle a_buf("A", {kTotalSize}, kInt);
|
|
VarHandle x("x", kInt);
|
|
auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2));
|
|
auto parent_block = Block::make({for_stmt});
|
|
|
|
ForPtr x_inner;
|
|
ForPtr x_tail;
|
|
LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail);
|
|
|
|
auto x_outer_result = IRSimplifier::simplify(for_stmt);
|
|
std::ostringstream oss_outer;
|
|
oss_outer << *x_outer_result;
|
|
const std::string& expected_outer_ir =
|
|
R"IR(
|
|
# CHECK: {
|
|
# CHECK: }
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str());
|
|
|
|
auto x_tail_result = IRSimplifier::simplify(x_tail);
|
|
std::ostringstream oss_tail;
|
|
oss_tail << *x_tail_result;
|
|
const std::string& expected_tail_ir =
|
|
R"IR(
|
|
# CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) {
|
|
# CHECK: A[x_tail + 13] = 2 * (x_tail + 13);
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str());
|
|
}
|
|
|
|
TEST(LoopNest, FlattenSimpleLoopNest2D) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 10; i++) {
|
|
// for (int j = 0; j < 5; j++) {
|
|
// A[i,j] = i * j;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {10, 5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for = For::make(j, 0, 5, for_body);
|
|
auto outer_for = For::make(i, 0, 10, inner_for);
|
|
auto parent_block = Block::make({outer_for});
|
|
|
|
std::vector<ForPtr> loops = {outer_for, inner_for};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, loops.front());
|
|
|
|
auto result = IRSimplifier::simplify(flattened);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
|
|
# CHECK: A[i_flat / 5, i_flat % 5] =
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
|
|
{
|
|
SimpleIREvaluator eval1(loops[0], {a_buf});
|
|
PaddedBuffer<int> inp1(10, 5);
|
|
eval1(inp1);
|
|
SimpleIREvaluator eval2(flattened, {a_buf});
|
|
PaddedBuffer<int> inp2(10, 5);
|
|
eval2(inp2);
|
|
ExpectAllNear(inp1, inp2, 1e-5);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, FlattenSimpleLoopNest3D) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 10; i++) {
|
|
// for (int j = 0; j < 5; j++) {
|
|
// for (int k = 0; k < 7; k++) {
|
|
// A[i,j,k] = i + j * k;
|
|
// }
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {10, 5, 7}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)});
|
|
auto for1 = For::make(k, 0, 7, for_body);
|
|
auto for2 = For::make(j, 0, 5, for1);
|
|
auto for3 = For::make(i, 0, 10, for2);
|
|
auto parent_block = Block::make({for3});
|
|
|
|
std::vector<ForPtr> loops = {for3, for2, for1};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, loops.front());
|
|
|
|
auto result = IRSimplifier::simplify(flattened);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) {
|
|
# CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] =
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
|
|
{
|
|
SimpleIREvaluator eval1(loops[0], {a_buf});
|
|
PaddedBuffer<int> inp1(10, 5, 7);
|
|
eval1(inp1);
|
|
SimpleIREvaluator eval2(flattened, {a_buf});
|
|
PaddedBuffer<int> inp2(10, 5, 7);
|
|
eval2(inp2);
|
|
ExpectAllNear(inp1, inp2, 1e-5);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, FlattenLoopNestAfterNormalize) {
|
|
// Input IR:
|
|
// for (int i = 2; i < 10; i++) {
|
|
// for (int j = 3; j < 15; j++) {
|
|
// A[i - 2,j - 3] = i * j;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {8, 12}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)});
|
|
auto inner_for = For::make(j, 3, 15, for_body);
|
|
auto outer_for = For::make(i, 2, 10, inner_for);
|
|
auto parent_block = Block::make({outer_for});
|
|
|
|
std::vector<ForPtr> loops = {outer_for, inner_for};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, loops.front());
|
|
|
|
auto result = IRSimplifier::simplify(flattened);
|
|
std::ostringstream oss;
|
|
oss << *result;
|
|
const std::string& expected_ir =
|
|
R"IR(
|
|
# CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) {
|
|
# CHECK: A[i_flat / 12, i_flat % 12] =
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
|
|
|
{
|
|
SimpleIREvaluator eval1(loops[0], {a_buf});
|
|
PaddedBuffer<int> inp1(8, 12);
|
|
eval1(inp1);
|
|
SimpleIREvaluator eval2(flattened, {a_buf});
|
|
PaddedBuffer<int> inp2(8, 12);
|
|
eval2(inp2);
|
|
ExpectAllNear(inp1, inp2, 1e-5);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 15-5; i++) {
|
|
// for (int j = 0; j < 20/4; j++) {
|
|
// A[i,j] = i * j;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {10, 5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for =
|
|
For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body);
|
|
auto outer_for =
|
|
For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for);
|
|
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
|
|
auto b = Block::make({outer_for});
|
|
|
|
std::vector<ForPtr> loops = {outer_for, inner_for};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, loops.front());
|
|
|
|
auto result = IRSimplifier::simplify(flattened);
|
|
checkIR(result, R"IR(
|
|
# CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
|
|
# CHECK: A[i_flat / 5, i_flat % 5] =
|
|
)IR");
|
|
|
|
{
|
|
SimpleIREvaluator eval1(loops[0], {a_buf});
|
|
PaddedBuffer<int> inp1(10, 5);
|
|
eval1(inp1);
|
|
SimpleIREvaluator eval2(flattened, {a_buf});
|
|
PaddedBuffer<int> inp2(10, 5);
|
|
eval2(inp2);
|
|
ExpectAllNear(inp1, inp2, 1e-5);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, FlattenImperfectLoopNest) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 10; i++) {
|
|
// A[i, i] = 0;
|
|
// for (int j = 0; j < 15; j++) {
|
|
// A[i,j] = i * j;
|
|
// }
|
|
// }
|
|
// Do not flatten.
|
|
|
|
BufHandle a_buf("A", {10, 15}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for = For::make(j, 0, 15, for_body);
|
|
auto outer_for = For::make(
|
|
i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for}));
|
|
auto par = Block::make({outer_for});
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(par);
|
|
|
|
std::vector<ForPtr> loops = {outer_for, inner_for};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, nullptr);
|
|
auto hash_after = hasher.hash(par);
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, FlattenReductionLoopNest) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 10; i++) {
|
|
// S[i] = 0;
|
|
// for (int j = 0; j < 15; j++) {
|
|
// S[i] = S[i] + A[i,j];
|
|
// }
|
|
// }
|
|
// Do not flatten.
|
|
|
|
BufHandle a_buf("A", {10, 15}, kInt);
|
|
BufHandle s_buf("S", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto for_body = Block::make({Store::make(
|
|
s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))});
|
|
auto inner_for = For::make(j, 0, 15, for_body);
|
|
auto outer_for =
|
|
For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for}));
|
|
auto par = Block::make({outer_for});
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(par);
|
|
|
|
std::vector<ForPtr> loops = {outer_for, inner_for};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, nullptr);
|
|
auto hash_after = hasher.hash(par);
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
|
|
const int M = 3;
|
|
const int N = 7;
|
|
VarHandle m("m", kInt);
|
|
VarHandle n("n", kInt);
|
|
BufHandle b("b", {m, n}, kFloat);
|
|
Tensor c = Reduce("sum", {M}, Sum(), b, {N});
|
|
LoopNest loop({c});
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(loop.root_stmt());
|
|
|
|
auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1];
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, nullptr);
|
|
auto hash_after = hasher.hash(loop.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 10; i++) {
|
|
// for (int j = 0; j < 5; j++) {
|
|
// A[i,j] = i * j;
|
|
// }
|
|
// }
|
|
// for (int x = 0; x < 10; x++) {
|
|
// for (int y = 0; y < 5; y++) {
|
|
// A[x,y] = A[x,y] + x + y;
|
|
// }
|
|
// }
|
|
// Flatten({For_i, For_y}) => should not succeed
|
|
|
|
BufHandle a_buf("A", {10, 5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for1 = For::make(j, 0, 5, for_body1);
|
|
auto outer_for1 = For::make(i, 0, 10, inner_for1);
|
|
auto for_body2 = Block::make(
|
|
{Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
|
|
auto inner_for2 = For::make(y, 0, 5, for_body2);
|
|
auto outer_for2 = For::make(x, 0, 10, inner_for2);
|
|
auto par = Block::make({outer_for1, outer_for2});
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(par);
|
|
|
|
std::vector<ForPtr> loops = {outer_for1, inner_for2};
|
|
ForPtr flattened = nullptr;
|
|
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
|
|
ASSERT_EQ(flattened, nullptr);
|
|
auto hash_after = hasher.hash(par);
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, DetectInlineRankMismatch) {
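  // 'reshape' loads 'a' with two indices even though 'a' has rank 1, so
  // inlining the loop body of 'a' must fail.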
|
|
const int kTotalSize = 8;
|
|
|
|
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
|
|
Tensor a = Compute(
|
|
"a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); });
|
|
Tensor reshape = Compute(
|
|
"reshape",
|
|
{kTotalSize / 2, 2},
|
|
[&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); });
|
|
LoopNest l({reshape}, {a, reshape});
|
|
ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a)));
|
|
}
|
|
|
|
TEST(LoopNest, CacheReadsSimple) {
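  // Cache the reads of A inside B's inner (j) loop: a 1x10 A_local buffer
  // should be allocated and B should read from it, while C's reads of A are
  // left untouched.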
|
|
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
|
return i * j;
|
|
});
|
|
Tensor B =
|
|
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 30, j + 3);
|
|
});
|
|
Tensor C =
|
|
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
|
|
});
|
|
|
|
LoopNest l({B, C}, {A, B, C});
|
|
StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1];
|
|
LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
|
|
|
|
l.prepareForCodegen();
|
|
StmtPtr result =
|
|
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
|
SimpleIREvaluator cg(result, {B, C});
|
|
result = cg.stmt();
|
|
|
|
// just this once: verify the whole thing.
|
|
checkIR(result, R"IR(
|
|
#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
|
|
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
|
|
#CHECK: for (int i
|
|
#CHECK: for (int j
|
|
#CHECK: A[
|
|
#CHECK: }
|
|
#CHECK: }
|
|
#CHECK: for (int i_1
|
|
#CHECK: for (int j_1
|
|
#CHECK: A_local[j_1] = A[
|
|
#CHECK: }
|
|
#CHECK: for (int j_2
|
|
#CHECK: B[j_2 + 10 * i_1] = A_local[j_2];
|
|
#CHECK: }
|
|
#CHECK: }
|
|
#CHECK: for (int i_2
|
|
#CHECK: for (int j_3
|
|
#CHECK: C[
|
|
#CHECK: }
|
|
#CHECK: }
|
|
#CHECK: Free(A_local);
|
|
#CHECK: Free(A);
|
|
)IR");
|
|
|
|
std::vector<int> b_data(200, 0);
|
|
std::vector<int> c_data(200, 0);
|
|
cg.call({b_data, c_data});
|
|
|
|
std::vector<int> b_ref(200, 0);
|
|
std::vector<int> c_ref(200, 0);
|
|
|
|
for (int i = 0; i < 20; ++i) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
b_ref[i * 10 + j] = (i + 30) * (j + 3);
|
|
c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
|
|
}
|
|
}
|
|
|
|
assertAllEqual(b_data, b_ref);
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
|
|
TEST(LoopNest, CacheReadsOuter) {
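  // Caching A's reads at B's outer (i) loop needs a 21x11 A_local buffer,
  // since the two loads together span 21 rows and 11 columns of A.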
|
|
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
|
return i * j;
|
|
});
|
|
Tensor B =
|
|
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
|
|
});
|
|
Tensor C =
|
|
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
|
|
});
|
|
|
|
LoopNest l({B, C}, {A, B, C});
|
|
StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0];
|
|
LoopNest::cacheAccesses(A.buf(), "A_local", i_loop);
|
|
|
|
l.prepareForCodegen();
|
|
StmtPtr result =
|
|
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
|
SimpleIREvaluator cg(result, {B, C});
|
|
result = cg.stmt();
|
|
|
|
checkIR(result, R"IR(
|
|
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
|
|
#CHECK: A_local[j_1 + 11 * i_1] =
|
|
#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]);
|
|
)IR");
|
|
|
|
std::vector<int> b_data(200, 0);
|
|
std::vector<int> c_data(200, 0);
|
|
cg.call({b_data, c_data});
|
|
|
|
std::vector<int> b_ref(200, 0);
|
|
std::vector<int> c_ref(200, 0);
|
|
|
|
for (int i = 0; i < 20; ++i) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
|
|
c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
|
|
}
|
|
}
|
|
|
|
assertAllEqual(b_data, b_ref);
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
|
|
TEST(LoopNest, CacheReadsInternal) {
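  // Caching A's reads at B's inner (j) loop: for a fixed i the two loads
  // differ by one row, so the local buffer is 2x11.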
|
|
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
|
return i * j;
|
|
});
|
|
Tensor B =
|
|
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
|
|
});
|
|
Tensor C =
|
|
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
|
|
});
|
|
|
|
LoopNest l({B, C}, {A, B, C});
|
|
StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1];
|
|
LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
|
|
l.prepareForCodegen();
|
|
StmtPtr result =
|
|
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
|
SimpleIREvaluator cg(result, {B, C});
|
|
result = cg.stmt();
|
|
|
|
checkIR(result, R"IR(
|
|
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
|
|
#CHECK: A_local[k + 11 * j_1] =
|
|
#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
|
|
)IR");
|
|
|
|
std::vector<int> b_data(200, 0);
|
|
std::vector<int> c_data(200, 0);
|
|
cg.call({b_data, c_data});
|
|
|
|
std::vector<int> b_ref(200, 0);
|
|
std::vector<int> c_ref(200, 0);
|
|
|
|
for (int i = 0; i < 20; ++i) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
|
|
c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
|
|
}
|
|
}
|
|
|
|
assertAllEqual(b_data, b_ref);
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
|
|
TEST(LoopNest, CacheReadsInner) {
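  // Caching at the innermost statement: each store reads rows i+30..i+34 and
  // columns j+40..j+41 of A, so the local buffer is 5x2.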
|
|
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
|
return i * j;
|
|
});
|
|
  // Note: I'm changing the offset of the first arg of the first call to A.
|
|
Tensor B =
|
|
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 34, j + 40) + A.load(i + 30, j + 41);
|
|
});
|
|
Tensor C =
|
|
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
|
|
});
|
|
|
|
LoopNest l({B, C}, {A, B, C});
|
|
StmtPtr body = l.getLoopBodyFor(B);
|
|
LoopNest::cacheAccesses(A.buf(), "A_local", body);
|
|
l.prepareForCodegen();
|
|
StmtPtr result =
|
|
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
|
SimpleIREvaluator cg(result, {B, C});
|
|
result = cg.stmt();
|
|
|
|
checkIR(result, R"IR(
|
|
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
|
|
#CHECK: A_local[l + 2 * k] =
|
|
#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
|
|
)IR");
|
|
|
|
std::vector<int> b_data(200, 0);
|
|
std::vector<int> c_data(200, 0);
|
|
cg.call({b_data, c_data});
|
|
|
|
std::vector<int> b_ref(200, 0);
|
|
std::vector<int> c_ref(200, 0);
|
|
|
|
for (int i = 0; i < 20; ++i) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41);
|
|
c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
|
|
}
|
|
}
|
|
|
|
assertAllEqual(b_data, b_ref);
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
|
|
TEST(LoopNest, CacheWritesSimple) {
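  // Cache the writes to A in its inner (j) loop: values are first written to
  // a 1x64 A_local buffer and then copied out to A.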
|
|
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
|
return i * j;
|
|
});
|
|
Tensor B =
|
|
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
|
|
});
|
|
Tensor C =
|
|
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
|
|
});
|
|
|
|
LoopNest l({B, C}, {A, B, C});
|
|
StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1];
|
|
LoopNest::cacheAccesses(A.buf(), "A_local", a_loop);
|
|
|
|
l.prepareForCodegen();
|
|
StmtPtr result =
|
|
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
|
SimpleIREvaluator cg(result, {B, C});
|
|
result = cg.stmt();
|
|
|
|
checkIR(result, R"IR(
|
|
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
|
|
#CHECK: for (int j = 0; j < 64
|
|
#CHECK: A_local[j] = i * j;
|
|
#CHECK: for (int j_1 = 0; j_1 < 64
|
|
#CHECK: A[j_1 + 64 * i] = A_local[
|
|
#CHECK: Free(A_local);
|
|
#CHECK-NOT: A_local
|
|
)IR");
|
|
|
|
std::vector<int> b_data(200, 0);
|
|
std::vector<int> c_data(200, 0);
|
|
cg.call({b_data, c_data});
|
|
|
|
std::vector<int> b_ref(200, 0);
|
|
std::vector<int> c_ref(200, 0);
|
|
|
|
for (int i = 0; i < 20; ++i) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
|
|
c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
|
|
}
|
|
}
|
|
|
|
assertAllEqual(b_data, b_ref);
|
|
assertAllEqual(c_data, c_ref);
|
|
}
|
|
|
|
TEST(LoopNest, DeadStoreElimination) {
|
|
VarHandle y("y", kInt);
|
|
VarHandle x("x_tail", kInt);
|
|
BufHandle f("f", {26, 5}, kInt);
|
|
BufHandle g("g", {26, 5}, kInt);
|
|
ExprHandle x_outer_end = 5;
|
|
ExprHandle x_2 = x + x_outer_end * 4;
|
|
ForPtr stmt1 = For::make(
|
|
x,
|
|
0,
|
|
5,
|
|
For::make(
|
|
y,
|
|
0,
|
|
5,
|
|
Block::make({
|
|
Store::make(f, {x_2, y}, (x_2 + y)),
|
|
Store::make(g, {x_2, y}, (x_2 * y)),
|
|
})));
|
|
StmtPtr stmt = Block::make({stmt1});
|
|
|
|
  // Will eliminate the store to g if it is not used by an output.
|
|
LoopNest loop(Stmt::clone(stmt), {f.node()});
|
|
loop.eliminateDeadStores();
|
|
|
|
checkIR(loop.root_stmt(), R"IR(
|
|
#CHECK: f[x_tail + 5 * 4, y]
|
|
#CHECK-NOT: g[x_tail + 5 * 4, y]
|
|
)IR");
|
|
|
|
// But won't eliminate if used by different outputs.
|
|
LoopNest loop2(stmt, {f.node(), g.node()});
|
|
loop2.eliminateDeadStores();
|
|
|
|
checkIR(loop2.root_stmt(), R"IR(
|
|
#CHECK: f[x_tail + 5 * 4, y]
|
|
#CHECK: g[x_tail + 5 * 4, y]
|
|
)IR");
|
|
}
|
|
|
|
TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
VarHandle z("z", kInt);
|
|
BufHandle f("f", {26 * 5}, kInt);
|
|
BufHandle g("g", {26 * 5}, kInt);
|
|
BufHandle h("h", {26, 5}, kInt);
|
|
ExprHandle x_outer_end = 5;
|
|
ExprHandle x_2 = x + x_outer_end * 4;
|
|
ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x));
|
|
ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1));
|
|
ForPtr stmt3 = For::make(
|
|
x,
|
|
0,
|
|
5,
|
|
For::make(
|
|
y,
|
|
0,
|
|
5,
|
|
Block::make({
|
|
Store::make(h, {x, y}, Load::make(f, {x * y})),
|
|
})));
|
|
StmtPtr stmt = Block::make({stmt1, stmt2, stmt3});
|
|
|
|
  // Will eliminate the write to g, but not f since it is used by the producer of
|
|
// h.
|
|
LoopNest loop(Stmt::clone(stmt), {h.node()});
|
|
loop.eliminateDeadStores();
|
|
|
|
checkIR(loop.root_stmt(), R"IR(
|
|
#CHECK: f[x] = x;
|
|
#CHECK-NOT: g[z] =
|
|
#CHECK: h[x, y] = f[x * y];
|
|
)IR");
|
|
|
|
  // Sanity check: the store to g won't be eliminated when g is an output.
|
|
LoopNest loop2(stmt, {h.node(), g.node()});
|
|
loop2.eliminateDeadStores();
|
|
|
|
checkIR(loop2.root_stmt(), R"IR(
|
|
#CHECK: f[x] = x;
|
|
#CHECK: g[z] = z + 1;
|
|
#CHECK: h[x, y] = f[x * y];
|
|
)IR");
|
|
}
|
|
|
|
TEST(LoopNest, CompoundTensorSimple) {
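  // A Tensor whose statement is a hand-built Block with two loop nests that
  // both write to A; verify that codegen computes A[i,j] = i*j + i + j.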
|
|
BufHandle a_buf("A", {10, 5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for1 = For::make(j, 0, 5, for_body1);
|
|
auto outer_for1 = For::make(i, 0, 10, inner_for1);
|
|
auto for_body2 = Block::make(
|
|
{Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
|
|
auto inner_for2 = For::make(y, 0, 5, for_body2);
|
|
auto outer_for2 = For::make(x, 0, 10, inner_for2);
|
|
BlockPtr body = Block::make({outer_for1, outer_for2});
|
|
|
|
Tensor A = Tensor(a_buf.node(), body);
|
|
|
|
LoopNest l({A});
|
|
l.prepareForCodegen();
|
|
|
|
std::vector<int> a_data(50, 0);
|
|
|
|
StmtPtr s = IRSimplifier::simplify(l.root_stmt());
|
|
SimpleIREvaluator cg(s, {A});
|
|
|
|
std::vector<int> a_ref(50, 0);
|
|
|
|
for (int i = 0; i < 10; ++i) {
|
|
for (int j = 0; j < 5; ++j) {
|
|
a_ref[i * 5 + j] = (i * j) + i + j;
|
|
}
|
|
}
|
|
cg.call({a_data});
|
|
|
|
assertAllEqual(a_data, a_ref);
|
|
}
|
|
|
|
TEST(LoopNest, InlineConstantIndex) {
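  // After simplification, 'y' must still be inlinable into 'z' even though
  // two of its dimensions are the constant 1.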
|
|
const int N = 10;
|
|
BufHandle x_buf("a", {1, N, 1}, kFloat);
|
|
Tensor y = Compute(
|
|
"f",
|
|
{1, N, 1},
|
|
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
|
|
return x_buf.load(m, n, o);
|
|
});
|
|
Tensor z = Compute(
|
|
"f",
|
|
{1, N, 1},
|
|
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
|
|
return y.load(m, n, o);
|
|
});
|
|
|
|
LoopNest l({z}, {y, z});
|
|
l.simplify();
|
|
ASSERT_TRUE(l.computeInline(y.buf()));
|
|
}
|
|
|
|
TEST(LoopNest, CompoundTensorUsed) {
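  // A is a compound Tensor (its buffer is written by two loop nests), so it
  // cannot be inlined into B; verify B's values computed from A instead.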
|
|
BufHandle a_buf("A", {10, 5}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle x("x", kInt);
|
|
VarHandle y("y", kInt);
|
|
auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
|
|
auto inner_for1 = For::make(j, 0, 5, for_body1);
|
|
auto outer_for1 = For::make(i, 0, 10, inner_for1);
|
|
auto for_body2 = Block::make(
|
|
{Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
|
|
auto inner_for2 = For::make(y, 0, 5, for_body2);
|
|
auto outer_for2 = For::make(x, 0, 10, inner_for2);
|
|
BlockPtr body = Block::make({outer_for1, outer_for2});
|
|
|
|
Tensor A = Tensor(a_buf.node(), body);
|
|
Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i, j + 1) + A.load(i, j + 2);
|
|
});
|
|
|
|
LoopNest l({B}, {A, B});
|
|
ASSERT_FALSE(l.computeInline(A.buf()));
|
|
l.prepareForCodegen();
|
|
|
|
std::vector<int> a_data(50, 0);
|
|
std::vector<int> b_data(50, 0);
|
|
|
|
StmtPtr s = IRSimplifier::simplify(l.root_stmt());
|
|
SimpleIREvaluator cg(s, {B});
|
|
|
|
std::vector<int> b_ref(50, 0);
|
|
|
|
auto AT = [](int i, int j) { return i * j + i + j; };
|
|
for (int i = 0; i < 10; ++i) {
|
|
for (int j = 0; j < 3; ++j) {
|
|
b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2);
|
|
}
|
|
}
|
|
cg.call({b_data});
|
|
|
|
assertAllEqual(b_data, b_ref);
|
|
}
|
|
|
|
TEST(LoopNest, InlineFromLoad) {
|
|
constexpr int N = 1024;
|
|
BufHandle a("A", {N}, kInt);
|
|
BufHandle b("B", {N}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto store_a = For::make(i, 0, N, Store::make(a, {i}, i));
|
|
auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j})));
|
|
LoopNest l(Block::make({store_a, store_b}), {b.node()});
|
|
|
|
l.computeInline(a.node());
|
|
|
|
// Check that A[j] is replaced with j after inlining
|
|
std::ostringstream oss;
|
|
oss << *l.root_stmt();
|
|
torch::jit::testing::FileCheck().run(
|
|
R"IR(
|
|
# CHECK: for (int j
|
|
# CHECK-NOT: B[j] = A[j]
|
|
# CHECK-NEXT: B[j] = j
|
|
)IR",
|
|
oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsSimple) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
|
|
// }
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {15}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
|
|
LoopNest nest(par, {a_buf.node()});
|
|
nest.optimizeConditionals();
|
|
|
|
std::ostringstream oss;
|
|
oss << *nest.root_stmt();
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i = 0; i < 5
|
|
# CHECK-NEXT: A[i] = B[i]
|
|
# CHECK: for (int i = 0; i < 15
|
|
# CHECK-NEXT: A[i + 5] = C[i]
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsNestedConditions) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 10, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
|
|
LoopNest nest(par, {a_buf.node()});
|
|
nest.optimizeConditionals();
|
|
|
|
std::ostringstream oss;
|
|
oss << *nest.root_stmt();
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i = 0; i < 5
|
|
# CHECK-NEXT: A[i] = B[i]
|
|
# CHECK: for (int i = 0; i < 5
|
|
# CHECK-NEXT: A[i + 5] = C[i]
|
|
# CHECK: for (int i = 0; i < 10
|
|
# CHECK-NEXT: A[i + 10] = D[i]
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsMultipleStores) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
|
|
// }
|
|
// for (int j = 0; j < 100; j++) {
|
|
// B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j])
|
|
// }
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {100}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {100}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto storeA = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, storeA);
|
|
auto storeB = Store::make(
|
|
b_buf,
|
|
{j},
|
|
IfThenElse::make(
|
|
CompareSelect::make(j, 30, kLT),
|
|
Load::make(c_buf, {j}),
|
|
Load::make(d_buf, {j})));
|
|
auto forJ = For::make(j, 0, 100, storeB);
|
|
auto par = Block::make({forI, forJ});
|
|
|
|
LoopNest nest(par, {a_buf.node()});
|
|
nest.optimizeConditionals();
|
|
|
|
std::ostringstream oss;
|
|
oss << *nest.root_stmt();
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i = 0; i < 5
|
|
# CHECK-NEXT: A[i] = B[i]
|
|
# CHECK: for (int i = 0; i < 15
|
|
# CHECK-NEXT: A[i + 5] = C[i]
|
|
# CHECK: for (int j = 0; j < 30
|
|
# CHECK-NEXT: B[j] = C[j]
|
|
# CHECK: for (int j = 0; j < 70
|
|
# CHECK-NEXT: B[j + 30] = D[j + 30]
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 50; i++) {
|
|
// A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
|
|
  //     B[i] = IfThenElse(i<30 ? 1 : 0, C[i], D[i])
|
|
// }
|
|
// Only the first conditional, in the write to A, will be optimized.
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {100}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {100}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {100}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {100}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto storeA = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})));
|
|
auto storeB = Store::make(
|
|
b_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 30, kLT),
|
|
Load::make(c_buf, {i}),
|
|
Load::make(d_buf, {i})));
|
|
auto forI = For::make(i, 0, 50, Block::make({storeA, storeB}));
|
|
auto par = Block::make({forI});
|
|
|
|
LoopNest nest(par, {a_buf.node()});
|
|
nest.optimizeConditionals();
|
|
|
|
std::ostringstream oss;
|
|
oss << *nest.root_stmt();
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i = 0; i < 5
|
|
# CHECK-NEXT: A[i] = B[i]
|
|
# CHECK-NEXT: B[i] = C[i]
|
|
# CHECK: for (int i = 0; i < 45
|
|
# CHECK-NEXT: A[i + 5] = C[i]
|
|
# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5])
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsOuterLoopVar) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
// }
|
|
  // Currently, this case, where the condition variable `i` is not the
|
|
// inner-most loop variable, is not optimized.
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 10, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store));
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
  // No optimization should be done here because the comparison values of the
  // nested conditions (outer: 5, inner: 10) are not in the order required for
  // this optimization.
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 10, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<N, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
  // No optimization should be done here because one of the comparison values
  // ('N') is not a constant.
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle N("N", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, N, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsInvalidCondition) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
  // No optimization should be done here because one of the conditions uses '>'.
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 10, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kGT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsInvalidCondition2) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(10<i, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
// No optimization should be done here because of the invalid condition:
|
|
// "10 < i".
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(10, i, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsInvalidCondition3) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
// No optimization should be done here because the conditions use different
|
|
// variables: "i < 10" and "k < 5"
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle k("k", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 10, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(k, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsInvalidCondition4) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = IfThenElse(k<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10])
|
|
// }
|
|
// No optimization should be done here because the conditions use the
|
|
// variable 'k' which is not a loop variable.
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle d_buf("D", {10}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle k("k", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(k, 10, kLT),
|
|
IfThenElse::make(
|
|
CompareSelect::make(k, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})),
|
|
Load::make(d_buf, {i - 10})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 0, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
TEST(LoopNest, OptimizeConditionalsNotNormalized) {
|
|
// Input IR:
|
|
// for (int i = 2; i < 20; i++) {
|
|
// A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
|
|
// }
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle b_buf("B", {5}, kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
BufHandle c_buf("C", {15}, kInt);
|
|
VarHandle i("i", kInt);
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto store = Store::make(
|
|
a_buf,
|
|
{i},
|
|
IfThenElse::make(
|
|
CompareSelect::make(i, 5, kLT),
|
|
Load::make(b_buf, {i}),
|
|
Load::make(c_buf, {i - 5})));
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
auto forI = For::make(i, 2, 20, store);
|
|
auto par = Block::make({forI});
|
|
LoopNest nest(par, {a_buf.node()});
|
|
|
|
HashProvider hasher;
|
|
auto hash_before = hasher.hash(nest.root_stmt());
|
|
nest.optimizeConditionals();
|
|
auto hash_after = hasher.hash(nest.root_stmt());
|
|
ASSERT_EQ(hash_before, hash_after);
|
|
}
|
|
|
|
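// Builds a column-wise reduction: b[n] = sum over m in [0, M) of a[m, n].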
static std::pair<BufHandle, Tensor> colReduce(int M, int N) {
|
|
BufHandle a("a", {M, N}, kFloat);
|
|
Tensor t = Reduce(
|
|
"b",
|
|
{N},
|
|
Sum(),
|
|
[&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); },
|
|
{M});
|
|
return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))};
|
|
}
|
|
|
|
static StmtPtr splitTailReorder(Tensor b) {
|
|
constexpr int kVectorWidth = 8;
|
|
LoopNest nest({b});
|
|
auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0];
|
|
nest.splitWithTail(loops[0], kVectorWidth);
|
|
// Now the loopnests will look like:
|
|
//
|
|
// for (int i_outer = 0; ...
|
|
// for (int i_inner = 0; ...
|
|
// b[i_outer * 8 + i_inner] = float(0);
|
|
// for (int j = 0; ...
|
|
// b[i_outer * 8 + i_inner] = ReduceOp(...);
|
|
//
|
|
// for (int i_tail = 0; ...
|
|
// b[i_tail + ((100 - 0) / 8) * 8] = float(0);
|
|
// for (int j = 0; ...
|
|
// b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...);
|
|
//
|
|
// Since there are 4 writes to b, we will get 4 loopnests from the
|
|
// call to `getAllLoopNestsWritingToBuf` below.
|
|
//
|
|
// Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)"
|
|
// Loopnest #2: {i_outer, i_inner, j};
|
|
// We will have to reorder i_inner and j.
|
|
auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf());
|
|
LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]);
|
|
nest.prepareForCodegen();
|
|
return nest.root_stmt();
|
|
}
|
|
|
|
static StmtPtr splitMaskReorder(Tensor b) {
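  // Like splitTailReorder above, but split with a mask (no tail loop is
  // created) and then reorder i_inner and j in the reduction loopnest.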
|
|
constexpr int kVectorWidth = 8;
|
|
LoopNest nest({b});
|
|
auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1];
|
|
nest.splitWithMask(loops[0], kVectorWidth);
|
|
loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1];
|
|
LoopNest::reorderAxis(loops[1], loops[2]);
|
|
nest.prepareForCodegen();
|
|
return nest.root_stmt();
|
|
}
|
|
|
|
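// Runs 's' with all inputs set to 1.0f and checks every output column against
// the expected column sum. Note: the 76.0f reference assumes M == 76, which is
// what the callers below use.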
static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) {
|
|
int M = immediateAs<int>(p.dim(0));
|
|
int N = immediateAs<int>(p.dim(1));
|
|
PaddedBuffer<float> a(M, N);
|
|
PaddedBuffer<float> b(N);
|
|
PaddedBuffer<float> ref(N);
|
|
for (int i = 0; i < M; i++) {
|
|
for (int j = 0; j < N; j++) {
|
|
a(i, j) = 1.0f;
|
|
}
|
|
}
|
|
for (int i = 0; i < N; i++) {
|
|
b(i) = 0.0f;
|
|
}
|
|
for (int i = 0; i < N; i++) {
|
|
ref(i) = 76.0f;
|
|
}
|
|
SimpleIREvaluator(s, {p, t}).call({a, b});
|
|
ExpectAllNear(b, ref, 1e-5);
|
|
}
|
|
|
|
TEST(LoopNest, ColReduceSplitTailEvenReorder) {
|
|
constexpr int M = 76, N = 128;
|
|
auto p = colReduce(M, N);
|
|
StmtPtr s = splitTailReorder(p.second);
|
|
|
|
std::ostringstream oss;
|
|
oss << *s;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i_outer
|
|
# CHECK-NEXT: for (int i_inner
|
|
# CHECK-NEXT: b[
|
|
# CHECK: for (int j
|
|
# CHECK-NEXT: for (int i_inner
|
|
# CHECK-NEXT: b[
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
checkColReduce(s, p.first, p.second);
|
|
}
|
|
|
|
TEST(LoopNest, ColReduceSplitTailUnevenReorder) {
|
|
constexpr int M = 76, N = 100;
|
|
auto p = colReduce(M, N);
|
|
StmtPtr s = splitTailReorder(p.second);
|
|
|
|
std::ostringstream oss;
|
|
oss << *s;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i_outer
|
|
# CHECK-NEXT: for (int i_inner
|
|
# CHECK-NEXT: b[
|
|
# CHECK: for (int j
|
|
# CHECK-NEXT: for (int i_inner
|
|
# CHECK-NEXT: b[
|
|
# CHECK: for (int i_tail
|
|
# CHECK-NEXT: b[
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: b[
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
checkColReduce(s, p.first, p.second);
|
|
}
|
|
|
|
TEST(LoopNest, ColReduceSplitMaskEvenReorder) {
|
|
constexpr int M = 76, N = 128;
|
|
auto p = colReduce(M, N);
|
|
StmtPtr s = splitMaskReorder(p.second);
|
|
checkColReduce(s, p.first, p.second);
|
|
}
|
|
|
|
TEST(LoopNest, ColReduceSplitMaskUnevenReorder) {
|
|
constexpr int M = 76, N = 100;
|
|
auto p = colReduce(M, N);
|
|
StmtPtr s = splitMaskReorder(p.second);
|
|
checkColReduce(s, p.first, p.second);
|
|
}
|
|
|
|
TEST(LoopNest, ReorderAxisWithMultipleConds) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// if i > 5 {
|
|
// if i < 10 {
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i] = i * j;
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j)));
|
|
auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr);
|
|
auto outer_cond =
|
|
Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr);
|
|
auto forI = For::make(i, 0, 20, outer_cond);
|
|
StmtPtr par = Block::make({forI});
|
|
LoopNest l(par, {a_buf.node()});
|
|
LoopNest::reorderAxis(forI, forJ);
|
|
ASSERT_EQ(par, l.root_stmt());
|
|
par = IRSimplifier::simplify(par);
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int j
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: if (i>5
|
|
# CHECK-NEXT: if (i<10
|
|
# CHECK-NEXT: A[i] = i * j
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, VectorizeUse) {
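  // Vectorize both elementwise loops; the resulting stores should use Ramp
  // indices (checked for 'c' below).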
|
|
constexpr int N = 8;
|
|
BufHandle a("a", {N}, kFloat);
|
|
Tensor b =
|
|
Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; });
|
|
Tensor c =
|
|
Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; });
|
|
LoopNest nest({c}, {b, c});
|
|
auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0];
|
|
ASSERT_TRUE(LoopNest::vectorize(loops[0]));
|
|
loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0];
|
|
ASSERT_TRUE(LoopNest::vectorize(loops[0]));
|
|
nest.prepareForCodegen();
|
|
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
|
|
StmtPtr s = nest.root_stmt();
|
|
std::ostringstream oss;
|
|
oss << *nest.root_stmt();
|
|
torch::jit::testing::FileCheck().run(
|
|
R"IR(
|
|
# CHECK: c[Ramp
|
|
)IR",
|
|
oss.str());
|
|
}
|
|
|
|
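// Shared FileCheck pattern for the Int64 tests below: the loop variable and
// the immediates must stay 64-bit (int64_t, 0ll, 1ll).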
const char* int64Loop = R"IR(
|
|
# CHECK: for (int64_t i = 0ll; i < 12ll; i++) {
|
|
# CHECK: b[i] = (a[i]) + 1ll;
|
|
# CHECK: }
|
|
)IR";
|
|
|
|
TEST(LoopNest, Int64Direct) {
|
|
constexpr int64_t N = 12;
|
|
BufHandle a("a", {N}, kLong);
|
|
BufHandle b("b", {N}, kLong);
|
|
VarHandle n("i", kLong);
|
|
StmtPtr s = For::make(
|
|
n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l)));
|
|
s = IRSimplifier::simplify(s);
|
|
std::ostringstream oss;
|
|
oss << *s;
|
|
torch::jit::testing::FileCheck().run(int64Loop, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, Int64Compute) {
|
|
constexpr int64_t N = 12;
|
|
BufHandle a("a", {N}, kLong);
|
|
Tensor b = Compute("b", {N}, [&](const VarHandle& n) {
|
|
return a.load(n) + LongImm::make(1l);
|
|
});
|
|
LoopNest nest({b});
|
|
nest.prepareForCodegen();
|
|
nest.simplify();
|
|
std::ostringstream oss;
|
|
oss << *nest.root_stmt();
|
|
torch::jit::testing::FileCheck().run(int64Loop, oss.str());
|
|
}
|
|
|
|
TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = 0;
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i] = A[i] + i * j;
|
|
// }
|
|
// B[i] = A[i];
|
|
// for (int k = 0; k < 50; k++) {
|
|
// B[i] = B[i] + i * k;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
BufHandle b_buf("B", {20}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto initA = Store::make(a_buf, {i}, 0);
|
|
auto forJ = For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(
|
|
a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
|
|
auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
|
|
auto forK = For::make(
|
|
k,
|
|
0,
|
|
50,
|
|
Store::make(
|
|
b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
|
|
auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
|
|
auto par = Block::make({forI});
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: A[i] = 0
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[i] =
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: B[i] = A[i]
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
|
|
LoopNest nest(par, {a_buf.node(), b_buf.node()});
|
|
auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB});
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The first loop after distribution must be same as the original For.
|
|
ASSERT_EQ(new_loops.front(), forI);
|
|
}
|
|
|
|
TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = 0;
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i] = A[i] + i * j;
|
|
// }
|
|
// B[i] = A[i];
|
|
// for (int k = 0; k < 50; k++) {
|
|
// B[i] = B[i] + i * k;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
BufHandle b_buf("B", {20}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto initA = Store::make(a_buf, {i}, 0);
|
|
auto forJ = For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(
|
|
a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
|
|
auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
|
|
auto forK = For::make(
|
|
k,
|
|
0,
|
|
50,
|
|
Store::make(
|
|
b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
|
|
auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
|
|
auto par = Block::make({forI});
|
|
|
|
LoopNest nest(par, {a_buf.node(), b_buf.node()});
|
|
auto new_loops = LoopNest::distributeLoop(forI, {forJ});
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: A[i] = 0
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[i] =
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: B[i] = A[i]
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The first loop after distribution must be same as the original For.
|
|
ASSERT_EQ(new_loops.front(), forI);
|
|
}
|
|
|
|
TEST(LoopNest, DistributeLoopWithoutAnyPivot) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = 0;
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i] = A[i] + i * j;
|
|
// }
|
|
// B[i] = A[i];
|
|
// for (int k = 0; k < 50; k++) {
|
|
// B[i] = B[i] + i * k;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
BufHandle b_buf("B", {20}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto initA = Store::make(a_buf, {i}, 0);
|
|
auto forJ = For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(
|
|
a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
|
|
auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
|
|
auto forK = For::make(
|
|
k,
|
|
0,
|
|
50,
|
|
Store::make(
|
|
b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
|
|
auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
|
|
auto par = Block::make({forI});
|
|
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: A[i] = 0
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[i] =
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: B[i] = A[i]
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
|
|
LoopNest nest(par, {a_buf.node(), b_buf.node()});
|
|
auto new_loops = LoopNest::distributeLoop(forI);
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The first loop after distribution must be same as the original For.
|
|
ASSERT_EQ(new_loops.front(), forI);
|
|
}
|
|
|
|
TEST(LoopNest, DistributeLoopOverInnerLoops) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[i] = 0;
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i] = A[i] + i * j;
|
|
// }
|
|
// B[i] = A[i];
|
|
// for (int k = 0; k < 50; k++) {
|
|
// B[i] = B[i] + i * k;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20}, kInt);
|
|
BufHandle b_buf("B", {20}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto initA = Store::make(a_buf, {i}, 0);
|
|
auto forJ = For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(
|
|
a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
|
|
auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
|
|
auto forK = For::make(
|
|
k,
|
|
0,
|
|
50,
|
|
Store::make(
|
|
b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
|
|
auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
|
|
auto par = Block::make({forI});
|
|
|
|
LoopNest nest(par, {a_buf.node(), b_buf.node()});
|
|
auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI);
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: A[i] = 0
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[i] =
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: B[i] = A[i]
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The first loop after distribution must be same as the original For.
|
|
ASSERT_EQ(new_loops.front(), forI);
|
|
}
|
|
|
|
TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) {
|
|
// Input IR:
|
|
// for (int m = 0; m < 50; m++) {
|
|
// for (int i = 0; i < 20; i++) {
|
|
// A[m,i] = 0;
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[m,i] = A[m,i] + i * j;
|
|
// }
|
|
// B[m,i] = A[m,i];
|
|
// for (int k = 0; k < 50; k++) {
|
|
// B[m,i] = B[m,i] + i * k;
|
|
// }
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {100, 100}, kInt);
|
|
BufHandle b_buf("B", {100, 100}, kInt);
|
|
VarHandle m("m", kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto initA = Store::make(a_buf, {m, i}, 0);
|
|
auto forJ = For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(
|
|
a_buf,
|
|
{m, i},
|
|
Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j))));
|
|
auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i}));
|
|
auto forK = For::make(
|
|
k,
|
|
0,
|
|
50,
|
|
Store::make(
|
|
b_buf,
|
|
{m, i},
|
|
Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k))));
|
|
auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
|
|
|
|
{
|
|
// Check the case of distributing loop and its parents over all the
|
|
// statements in the loop.
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: A[m, i] = 0
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[m, i] =
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: B[m, i] = A[m, i]
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[m, i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
|
|
auto newForI = to<For>(Stmt::clone(forI));
|
|
auto forM = For::make(m, 0, 50, newForI);
|
|
auto par = Block::make({forM});
|
|
LoopNest nest(par, {a_buf.node(), b_buf.node()});
|
|
auto newLoops = LoopNest::distributeLoopAndParents(newForI);
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The first loop after distribution must be same as the original For.
|
|
ASSERT_EQ(newLoops.front(), forM);
|
|
}
|
|
|
|
{
|
|
// Check the case of distributing loop and its parents over all the inner
|
|
// loops.
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: A[m, i] = 0
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[m, i] =
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: for (int i
|
|
# CHECK-NEXT: B[m, i] = A[m, i]
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[m, i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
|
|
auto newForI = to<For>(Stmt::clone(forI));
|
|
auto forM = For::make(m, 0, 50, newForI);
|
|
auto par = Block::make({forM});
|
|
LoopNest nest(par, {a_buf.node(), b_buf.node()});
|
|
auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI);
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The first loop after distribution must be same as the original For.
|
|
ASSERT_EQ(newLoops.front(), forM);
|
|
}
|
|
}
|
|
|
|
TEST(LoopNest, fuseLoopsSimple) {
|
|
// Input IR:
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[j] = 10 * j;
|
|
// }
|
|
// for (int k = 0; k < 100; k++) {
|
|
// B[k] = 20 * k;
|
|
// }
|
|
BufHandle a_buf("A", {100}, kInt);
|
|
BufHandle b_buf("B", {100}, kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
|
|
auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
|
|
auto par = Block::make({forJ, forK});
|
|
ForPtr fused_loop;
|
|
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int j
|
|
# CHECK-NEXT: A[j] =
|
|
# CHECK-NEXT: B[j] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The fused loop must be the same as the first loop.
|
|
ASSERT_EQ(fused_loop, forJ);
|
|
}
|
|
|
|
TEST(LoopNest, fuseLoopsMultiple) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 100; i++) {
|
|
// A[i+100] = 20 + i;
|
|
// }
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[j] = 10 * j;
|
|
// }
|
|
// for (int k = 0; k < 100; k++) {
|
|
// B[k] = 20 * k;
|
|
// }
|
|
BufHandle a_buf("A", {200}, kInt);
|
|
BufHandle b_buf("B", {100}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto forI =
|
|
For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i)));
|
|
auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
|
|
auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
|
|
auto par = Block::make({forI, forJ, forK});
|
|
ForPtr fused_loop;
|
|
ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop));
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: A[i + 100] =
|
|
# CHECK-NEXT: A[i] =
|
|
# CHECK-NEXT: B[i] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The fused loop must be the same as the first loop.
|
|
ASSERT_EQ(fused_loop, forI);
|
|
}
|
|
|
|
TEST(LoopNest, fuseLoopsNested) {
|
|
// Input IR:
|
|
// for (int m = 0; m < 20; m++) {
|
|
// A[m] = 0;
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[m] = A[m] + m * j;
|
|
// }
|
|
// }
|
|
// for (int n = 0; n < 20; n++) {
|
|
// B[n] = A[n];
|
|
// for (int k = 0; k < 50; k++) {
|
|
// B[n] = B[n] + n * k;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20, 100}, kInt);
|
|
BufHandle b_buf("B", {20, 100}, kInt);
|
|
VarHandle m("m", kInt);
|
|
VarHandle n("n", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle k("k", kInt);
|
|
auto initA = Store::make(a_buf, {m}, 0);
|
|
auto forJ = For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(
|
|
a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j))));
|
|
auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n}));
|
|
auto forK = For::make(
|
|
k,
|
|
0,
|
|
50,
|
|
Store::make(
|
|
b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k))));
|
|
auto forM = For::make(m, 0, 20, Block::make({initA, forJ}));
|
|
auto forN = For::make(n, 0, 20, Block::make({initB, forK}));
|
|
auto par = Block::make({forM, forN});
|
|
ForPtr fused_loop;
|
|
ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop));
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int m
|
|
# CHECK-NEXT: A[m] = 0
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[m] =
|
|
# CHECK: B[m] = A[m]
|
|
# CHECK-NEXT: for (int k
|
|
# CHECK-NEXT: B[m] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The fused loop must be the same as the first loop.
|
|
ASSERT_EQ(fused_loop, forM);
|
|
}
|
|
|
|
TEST(LoopNest, fuseLoopsNested2D) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i,j] = i * j * 500;
|
|
// }
|
|
// }
|
|
// for (int m = 0; m < 20; m++) {
|
|
// for (int n = 0; n < 50; n++) {
|
|
// B[m,n] = m + n * 100;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20, 100}, kInt);
|
|
BufHandle b_buf("B", {20, 100}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle m("m", kInt);
|
|
VarHandle n("n", kInt);
|
|
auto forI = For::make(
|
|
i,
|
|
0,
|
|
20,
|
|
For::make(
|
|
j,
|
|
0,
|
|
100,
|
|
Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))));
|
|
auto forM = For::make(
|
|
m,
|
|
0,
|
|
20,
|
|
For::make(
|
|
n,
|
|
0,
|
|
50,
|
|
Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100)))));
|
|
auto par = Block::make({forI, forM});
|
|
ForPtr fused_loop;
|
|
ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
|
|
|
|
std::ostringstream oss;
|
|
oss << *par;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[i, j] =
|
|
# CHECK: for (int n
|
|
# CHECK-NEXT: B[i, n] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The fused loop must be the same as the first loop.
|
|
ASSERT_EQ(fused_loop, forI);
|
|
}
|
|
|
|
TEST(LoopNest, fuseLoopsNested2DInner) {
|
|
// Input IR:
|
|
// for (int i = 0; i < 20; i++) {
|
|
// for (int j = 0; j < 100; j++) {
|
|
// A[i,j] = i * j * 500;
|
|
// }
|
|
// for (int n = 0; n < 100; n++) {
|
|
  //     B[i,n] = i + n * 100;
|
|
// }
|
|
// }
|
|
BufHandle a_buf("A", {20, 100}, kInt);
|
|
BufHandle b_buf("B", {20, 100}, kInt);
|
|
VarHandle i("i", kInt);
|
|
VarHandle j("j", kInt);
|
|
VarHandle n("n", kInt);
|
|
auto forJ = For::make(
|
|
j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)));
|
|
auto forN = For::make(
|
|
n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100))));
|
|
auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));
|
|
ForPtr fused_loop;
|
|
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop));
|
|
|
|
std::ostringstream oss;
|
|
oss << *forI;
|
|
const std::string& verification_pattern =
|
|
R"IR(
|
|
# CHECK: for (int i
|
|
# CHECK-NEXT: for (int j
|
|
# CHECK-NEXT: A[i, j] =
|
|
# CHECK-NEXT: B[i, j] =
|
|
# CHECK-NOT: for (
|
|
)IR";
|
|
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
|
|
|
// The fused loop must be the same as the first loop.
|
|
ASSERT_EQ(fused_loop, forJ);
|
|
}
|
|
|
|
TEST(LoopNest, fuseLoopsDifferentStopBounds) {
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < 50; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

TEST(LoopNest, fuseLoopsDifferentStartBounds) {
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 50; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

TEST(LoopNest, fuseLoopsNotContiguous) {
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   B[0] = 0;
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto initB = Store::make(b_buf, {0}, 0);
  auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, initB, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

TEST(LoopNest, fuseLoopsWithDifferentParents) {
  // Input IR:
  //   for (int i = 0; i < 50; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  //   B[0] = 0;
  //   for (int k = 50; k < 100; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {50, 100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j)));
  auto forI = For::make(i, 0, 50, forJ);
  auto initB = Store::make(b_buf, {0}, 0);
  auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI, initB, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

TEST(LoopNest, fuseLoopsWithVariableBounds) {
  // Input IR:
  //   for (int j = 0; j < N; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < N; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle N("N", kInt);
  auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j)));
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
  auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}

TEST(LoopNest, fuseLoopsWithExprBounds) {
  // Input IR:
  //   for (int j = 0; j < M + N; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < M + N; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}

TEST(LoopNest, fuseLoopsWithDifferentExprBounds) {
  // Input IR:
  //   for (int j = M; j < N * 2; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = M; k < N + N; k++) {
  //     B[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j)));
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
  auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: B[j] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}

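// Note: in the next test, fusion is expected to succeed because the two loops
// touch disjoint regions of A: the first writes A[10..99] while the second
// writes A[110..199], so sharing one loop cannot reorder conflicting accesses.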
TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) {
  // Input IR:
  //   for (int j = 10; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 10; k < 100; k++) {
  //     A[k+100] = 30 * k;
  //   }
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k)));
  auto par = Block::make({forJ, forK});

  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int j
# CHECK-NEXT: A[j] =
# CHECK-NEXT: A[j + 100] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forJ);
}

TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 50; n++) {
  //       A[m+20,n+100] = m + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 =
      Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK: for (int n
# CHECK-NEXT: A[i + 20, n + 100] =
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forI);
}

TEST(LoopNest, fuseLoopsWithReductions) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[i] = A[i] + B[i,j];
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     C[m] = A[m];
  //   }
  BufHandle a_buf("A", {20}, kInt);
  BufHandle b_buf("B", {20, 100}, kInt);
  BufHandle c_buf("C", {20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  auto initA = Store::make(a_buf, {i}, 0);
  auto sumA = Store::make(
      a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j})));
  auto forJ = For::make(j, 0, 100, sumA);
  auto forI = For::make(i, 0, 20, Block::make({initA, forJ}));
  auto forM =
      For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m})));
  auto par = Block::make({forI, forM});
  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: A[i] =
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i] = (A[i]) +
# CHECK-NOT: for (
# CHECK: C[i] = A[i]
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forI);
}

TEST(LoopNest, fuseLoopsWith2DReductions) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 50; j++) {
  //       A[i,j] = 0;
  //       for (int k = 0; k < 100; k++) {
  //         A[i,j] = A[i,j] + B[i,j,k];
  //       }
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 40; n++) {
  //       C[m,n] = A[m,n];
  //     }
  //   }
  BufHandle a_buf("A", {20, 50}, kInt);
  BufHandle b_buf("B", {20, 50, 100}, kInt);
  BufHandle c_buf("C", {20, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto initA = Store::make(a_buf, {i, j}, 0);
  auto sumA = Store::make(
      a_buf,
      {i, j},
      Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k})));
  auto forK = For::make(k, 0, 100, sumA);
  auto forJ = For::make(j, 0, 50, Block::make({initA, forK}));
  auto forI = For::make(i, 0, 20, forJ);
  auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n}));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC));
  auto par = Block::make({forI, forM});

  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK-NEXT: for (int k
# CHECK-NEXT: A[i, j] = (A[i, j]) +
# CHECK: for (int n
# CHECK-NEXT: C[i, n] = A[i, n]
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forI);
}

TEST(LoopNest, fuseLoopsWithComplexIndices) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 20; j++) {
  //       A[i,j*20+j+2] = i + j;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 20; n++) {
  //       B[m,n] = A[m,n*20+n+2];
  //     }
  //   }
  BufHandle a_buf("A", {20, 400}, kInt);
  BufHandle b_buf("B", {20, 400}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j);
  auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
  auto storeB =
      Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2}));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
  auto par = Block::make({forI, forM});

  ForPtr fused_loop;
  ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j
# CHECK: for (int n
# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2]
# CHECK-NOT: for (
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // The fused loop must be the same as the first loop.
  ASSERT_EQ(fused_loop, forI);
}

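// Note: in the next test the read A[m, m*20+n] uses the fused loop variable in
// both indices. The dependency analysis cannot prove that each fused iteration
// only reads locations produced by the same or earlier iterations, so fusion
// is expected to be rejected; the test only asserts that fuseLoops returns
// false.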
TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 20; j++) {
  //       A[i,i*20+j] = i + j;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 20; n++) {
  //       B[m,n] = A[m,m*20+n]; // Both indices of A use m
  //     }
  //   }
  BufHandle a_buf("A", {20, 500}, kInt);
  BufHandle b_buf("B", {20, 500}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j);
  auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
  auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n}));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
  auto par = Block::make({forI, forM});

  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
}

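// Note: in the next test the second loop reads A transposed. Iteration m would
// read A[n, m] for every n, including rows that the first loop only writes in
// later iterations, so fusing the two loops must be refused.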
TEST(LoopNest, fuseLoopsWithTranspose) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 20; j++) {
  //       A[i,j] = i + j;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 20; n++) {
  //       B[m,n] = A[n,m]; // Transpose
  //     }
  //   }
  BufHandle a_buf("A", {20, 20}, kInt);
  BufHandle b_buf("B", {20, 20}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto writeA = Store::make(a_buf, {i, j}, i + j);
  auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
  auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m}));
  auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
  auto par = Block::make({forI, forM});

  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
}

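// The "ViolateDependencies" tests below check that fuseLoops rejects fusions
// that could reorder conflicting accesses. Here both loops write overlapping
// elements of A (A[j] vs. A[k-1]), which the dependency analysis treats as
// fusion-preventing.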
TEST(LoopNest, fuseLoopsThatViolateDependencies1) {
  // Input IR:
  //   for (int j = 10; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 10; k < 100; k++) {
  //     A[k-1] = 20 * k;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

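// Note: here the overlapping elements A[60..99] are written by both loops.
// Fusing would flip the order of those writes (the second body's store to
// A[i+50] would no longer follow all stores from the first loop), changing
// the final values, so fusion must be rejected.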
TEST(LoopNest, fuseLoopsThatViolateDependencies2) {
  // Input IR:
  //   for (int j = 10; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 10; k < 100; k++) {
  //     A[k+50] = 20 * k;
  //   }
  BufHandle a_buf("A", {150}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

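// Note: here the second loop initializes B[n] from A[n+1]. In the original
// order all of A is fully computed before any B; after fusion, iteration n
// would read A[n+1] before iteration n+1 of the first body has produced it,
// so fusion must be rejected.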
TEST(LoopNest, fuseLoopsThatViolateDependencies3) {
  // Input IR:
  //   for (int m = 0; m < 20; m++) {
  //     A[m] = 0;
  //     for (int j = 0; j < 100; j++) {
  //       A[m] = A[m] + m * j;
  //     }
  //   }
  //   for (int n = 0; n < 20; n++) {
  //     B[n] = A[n+1];
  //     for (int k = 0; k < 50; k++) {
  //       B[n] = B[n] + n * k;
  //     }
  //   }
  BufHandle a_buf("A", {25, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto initA = Store::make(a_buf, {m}, 0);
  auto forJ = For::make(
      j,
      0,
      100,
      Store::make(
          a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j))));
  auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1}));
  auto forK = For::make(
      k,
      0,
      50,
      Store::make(
          b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k))));
  auto forM = For::make(m, 0, 20, Block::make({initA, forJ}));
  auto forN = For::make(n, 0, 20, Block::make({initB, forK}));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forM, forN});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop));
}

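// Note: here the second loop nest writes row m+1 of A while the first writes
// row m. Fusing the outer loops would interleave writes to the same rows in a
// different order than the original program, so fusion must be rejected.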
TEST(LoopNest, fuseLoopsThatViolateDependencies4) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (int m = 0; m < 20; m++) {
  //     for (int n = 0; n < 50; n++) {
  //       A[m+1,n] = m + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {30, 100}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto forI = For::make(
      i,
      0,
      20,
      For::make(
          j,
          0,
          100,
          Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))));
  auto forM = For::make(
      m,
      0,
      20,
      For::make(
          n,
          0,
          50,
          Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100)))));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI, forM});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
}

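// Note: same idea as above, but for two sibling inner loops: within each
// iteration of i, both bodies write overlapping columns of row i
// (A[i, j] and A[i, n+1]), so fusing them would reorder those writes.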
TEST(LoopNest, fuseLoopsThatViolateDependencies5) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 100; j++) {
  //       A[i,j] = i * j * 500;
  //     }
  //     for (int n = 0; n < 100; n++) {
  //       A[i,n+1] = i + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {20, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle n("n", kInt);
  auto forJ = For::make(
      j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)));
  auto forN = For::make(
      n,
      0,
      100,
      Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100))));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers)
  auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop));
}

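// Note: here the second loop reads A in reverse order (A[99-k]). After fusion,
// iteration k = 0 would read A[99] long before the first body writes it at
// iteration 99, so the fusion is invalid.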
TEST(LoopNest, fuseLoopsThatViolateDependencies6) {
  // Input IR:
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * A[99-k];
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(
      k,
      0,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forJ, forK});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}

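// Note: the mirror image of the previous test: the reads of A[99-k] come first
// and must observe A's original values. After fusion, writes from earlier
// fused iterations would clobber elements that later iterations still need to
// read, so the fusion is invalid.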
TEST(LoopNest, fuseLoopsThatViolateDependencies7) {
  // Input IR:
  //   for (int k = 0; k < 100; k++) {
  //     B[k] = 20 * A[99-k];
  //   }
  //   for (int j = 0; j < 100; j++) {
  //     A[j] = 10 * j;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forK = For::make(
      k,
      0,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forK, forJ});
  ForPtr fused_loop;
  ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop));
}

TEST(LoopNest, areLoopsPerfectlyNested) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI});
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));

  // Specifying the loops in any other order fails.
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK}));
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ}));
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI}));

  // Adding a statement to forK body should be OK.
  auto init = Store::make(a_buf, {i, j}, 0);
  forK->body()->insert_stmt_before(init, store);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));

  // Adding a statement in forJ body should fail this test.
  forK->body()->remove_stmt(init);
  forJ->body()->insert_stmt_before(init, forK);
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));

  // Similarly, adding a statement in forI body should fail this test.
  forJ->body()->remove_stmt(init);
  forI->body()->insert_stmt_before(init, forJ);
  ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
}

TEST(LoopNest, reorderNestedLoops2D) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       A[i,j] = i * j;
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store = Store::make(a_buf, {i, j}, Mul::make(i, j));
  auto forJ = For::make(j, 0, 30, store);
  auto forI = For::make(i, 0, 20, forJ);
  auto par = Block::make({forI});

  auto reordered = LoopNest::reorder({forI, forJ}, {1, 0});

  ASSERT_EQ(reordered[0], forJ);
  ASSERT_EQ(reordered[1], forI);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI}));
  ASSERT_EQ(forJ->get_parent(), par);
  ASSERT_EQ(store->get_parent(), forI->body());
}

TEST(LoopNest, reorderNestedLoops3D) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  auto par = Block::make({forI});

  auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1});

  ASSERT_EQ(reordered[0], forK);
  ASSERT_EQ(reordered[1], forI);
  ASSERT_EQ(reordered[2], forJ);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ}));
  ASSERT_EQ(forK->get_parent(), par);
  ASSERT_EQ(store->get_parent(), forJ->body());
}

TEST(LoopNest, reorderNestedLoops4D) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         for (int l = 0; l < 50; l++) {
  //           A[i,j,k,l] = i * j * k * l * 500;
  //         }
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle l("l", kInt);
  auto store = Store::make(
      a_buf,
      {i, j, k, l},
      Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500));
  auto forL = For::make(l, 0, 50, store);
  auto forK = For::make(k, 0, 40, forL);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  auto par = Block::make({forI});

  auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1});

  ASSERT_EQ(reordered[0], forK);
  ASSERT_EQ(reordered[1], forI);
  ASSERT_EQ(reordered[2], forL);
  ASSERT_EQ(reordered[3], forJ);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ}));
  ASSERT_EQ(forK->get_parent(), par);
  ASSERT_EQ(store->get_parent(), forJ->body());
}

TEST(LoopNest, reorderTrivialPermutation) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  auto par = Block::make({forI});

  auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2});

  ASSERT_EQ(reordered[0], forI);
  ASSERT_EQ(reordered[1], forJ);
  ASSERT_EQ(reordered[2], forK);
  ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
  ASSERT_EQ(forI->get_parent(), par);
  ASSERT_EQ(store->get_parent(), forK->body());
}

TEST(LoopNest, reorderInvalidPermutations) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI});

  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}),
      "invalid permutation size");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 2}),
      "invalid permutation size");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}),
      "invalid permutation for reorder");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}),
      "invalid permutation for reorder");
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}),
      "invalid permutation for reorder");
}

TEST(LoopNest, reorderInvalidLoopNest) {
  // Input IR:
  //   for (int i = 0; i < 20; i++) {
  //     for (int j = 0; j < 30; j++) {
  //       A[i,j] = 0;
  //       for (int k = 0; k < 40; k++) {
  //         A[i,j,k] = i * j * k;
  //       }
  //     }
  //   }
  BufHandle a_buf("A", {20, 30, 40}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
  auto forK = For::make(k, 0, 40, store);
  auto forJ = For::make(j, 0, 30, forK);
  auto forI = For::make(i, 0, 20, forJ);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  auto par = Block::make({forI});

  // Specifying the loops in incorrect order fails.
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}),
      "reorder is only allowed on perfectly nested loops");

  // Adding a statement to the forJ loop fails.
  auto init = Store::make(a_buf, {i}, 0);
  forJ->body()->insert_stmt_before(init, forK);
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}),
      "reorder is only allowed on perfectly nested loops");

  // Moving that statement to the forI loop also fails.
  forJ->body()->remove_stmt(init);
  forI->body()->insert_stmt_before(init, forJ);
  ASSERT_THROWS_WITH(
      LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}),
      "reorder is only allowed on perfectly nested loops");
}

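// The compressBuffer tests below check that a buffer dimension indexed only by
// the common enclosing loop variable is shrunk to size 1: within one iteration
// of i only row i of A is written and read, so A[100, 200] can become
// A[1, 200].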
TEST(LoopNest, compressBufferSimple) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i,j] = sin(i*j)
  //     }
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[i,j] + A[i, j+1]
  //     }
  //   }
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j)));
  auto forJ2 = For::make(
      j,
      0,
      199,
      Store::make(
          bBuf,
          {i, j},
          Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1}))));
  auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, j] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}

TEST(LoopNest, compressBufferMultipleDims) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i,j] = sin(i*j)
  //       B[i,j] = A[i,j] + A[i,j]
  //     }
  //   }
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto store1 = Store::make(aBuf, {i, j}, sin(i * j));
  auto store2 = Store::make(
      bBuf,
      {i, j},
      Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j})));
  auto forJ = For::make(j, 0, 200, Block::make({store1, store2}));
  auto forI = For::make(i, 0, 100, forJ);
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, 0] =
# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
}

TEST(LoopNest, compressBufferMultipleDims2) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       for (int k = 0; k < 300; ++k) {
  //         A[i,j,k] = sin(i*j*k)
  //       }
  //       for (int k = 0; k < 299; ++k) {
  //         B[i,j,k] = A[i,j,k] + A[i,j,k+1]
  //       }
  //     }
  //   }
  BufHandle aBuf("A", {100, 200, 300}, kInt);
  BufHandle bBuf("B", {100, 200, 300}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k));
  auto forK1 = For::make(k, 0, 300, store1);
  auto store2 = Store::make(
      bBuf,
      {i, j, k},
      Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1})));
  auto forK2 = For::make(k, 0, 299, store2);
  auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2}));
  auto forI = For::make(i, 0, 100, forJ);
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: for (int k
# CHECK-NEXT: A[0, 0, k] =
# CHECK: for (int k
# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 3);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300);
}

TEST(LoopNest, compressBufferDifferentOrderIndices) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[j, i] = sin(i*j)
  //     }
  //     for (int j = 0; j < 99; ++j) {
  //       B[i, j] = A[j, i] + A[j+1, i]
  //     }
  //   }
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j)));
  auto forJ2 = For::make(
      j,
      0,
      99,
      Store::make(
          bBuf,
          {i, j},
          Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i}))));
  auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[j, 0] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
}

TEST(LoopNest, compressBufferVariableBounds) {
  // Input IR:
  //   for (int i = 0; i < M; ++i) {
  //     for (int j = 0; j < N; ++j) {
  //       A[i,j] = sin(i*j)
  //     }
  //     for (int j = 0; j < N-1; ++j) {
  //       B[i,j] = A[i,j] + A[i, j+1]
  //     }
  //   }
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle M("M", kInt);
  VarHandle N("N", kInt);
  auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j)));
  auto forJ2 = For::make(
      j,
      0,
      N - 1,
      Store::make(
          bBuf,
          {i, j},
          Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1}))));
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2}));
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, j] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}

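// Note: in the next test the producer and the consumer of A live in two
// separate loop nests with no common parent loop, so compressBuffer cannot
// prove that only one row of A is live at a time and must leave the buffer
// unchanged.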
TEST(LoopNest, compressBufferNoCommonParentLoops) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i,j] = sin(i*j)
  //     }
  //   }
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[i,j] + A[i, j+1]
  //     }
  //   }
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j)));
  auto forJ2 = For::make(
      j,
      0,
      199,
      Store::make(
          bBuf,
          {i, j},
          Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1}))));
  auto forI1 = For::make(i, 0, 100, forJ1);
  auto forI2 = For::make(i, 0, 100, forJ2);
  auto par = Block::make({forI1, forI2});
  LoopNest::compressBuffer(aBuf.node(), par);

  // There should be no change in the buffer or code.
  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i, j] =
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}

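// Note: in the next test A is indexed with i + j, which mixes the enclosing
// loop variable with the inner one, so different iterations of i touch
// overlapping rows of A and no dimension can be compressed.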
TEST(LoopNest, compressBufferIndicesMixed) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i + j, j] = sin(i*j)
  //     }
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[i + j, j] + A[i + j, j+1]
  //     }
  //   }
  BufHandle aBuf("A", {300, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j)));
  auto forJ2 = For::make(
      j,
      0,
      199,
      Store::make(
          bBuf,
          {i, j},
          Add::make(
              Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1}))));
  auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
  auto par = Block::make({forI});
  LoopNest::compressBuffer(aBuf.node(), par);

  // There should be no change in the buffer or code.
  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[i + j, j] =
# CHECK: for (int j
# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1])
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
}

TEST(LoopNest, compressMultipleBuffers) {
  // Input IR:
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i,j] = sin(i*j)
  //     }
  //     for (int k = 0; k < 199; ++k) {
  //       B[i,k] = A[i,k] + A[i, k+1]
  //     }
  //     for (int m = 0; m < 50; ++m) {
  //       C[i,m] = B[i,m]
  //     }
  //   }
  BufHandle aBuf("A", {100, 200}, kInt);
  BufHandle bBuf("B", {100, 200}, kInt);
  BufHandle cBuf("C", {100, 200}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  VarHandle m("m", kInt);
  auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j)));
  auto forK = For::make(
      k,
      0,
      199,
      Store::make(
          bBuf,
          {i, k},
          Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1}))));
  auto forM =
      For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m})));
  auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM}));
  auto par = Block::make({forI});

  // This should compress all buffers A, B, and C as follows:
  //   A[100, 200] -> A[1, 200]
  //   B[100, 200] -> B[1, 200]
  //   C[100, 200] -> C[1, 1]
  LoopNest::compressAllBuffers(par);

  std::ostringstream oss;
  oss << *par;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i
# CHECK-NEXT: for (int j
# CHECK-NEXT: A[0, j] =
# CHECK: for (int k
# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1])
# CHECK: for (int m
# CHECK-NEXT: C[0, 0] = B[0, m]
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  ASSERT_EQ(aBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
  ASSERT_EQ(bBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200);
  ASSERT_EQ(cBuf.node()->ndim(), 2);
  IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1);
  IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1);
}

TEST(LoopNest, sanitizeNames) {
  std::vector<ExprHandle> dim_args;
  // Let's pick names that would overlap with default index names if not
  // sanitized properly:
  dim_args.emplace_back(ExprHandle(alloc<Var>("i", kInt)));
  dim_args.emplace_back(ExprHandle(alloc<Var>("N:2", kInt)));
  // Now let's create many dimensions so that we have to reuse the same letter
  // for different loops.
  for (int i = 0; i < 10; i++) {
    dim_args.emplace_back(ExprHandle(alloc<Var>("N", kInt)));
  }

  // Now create two Computes whose names conflict after sanitization:
  Tensor X = Compute("$X:!", dim_args, [&](const std::vector<VarHandle>& v) {
    return v[0] + v[1] + v[9] + 1;
  });
  Tensor Y = Reduce(
      "%X\"+",
      {},
      Sum(),
      [&](const std::vector<VarHandle>& v) { return X.load(v); },
      dim_args);

  // Finally, let's verify what we got after sanitization:
  LoopNest l({X, Y});
  StmtPtr s = l.root_stmt();
  LoopNest::sanitizeNames(s);

  std::ostringstream oss;
  oss << *s;
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int i = 0; i < i_1; i++) {
# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) {
# CHECK-NEXT: for (int k = 0; k < N_9; k++) {
# CHECK-NEXT: for (int l = 0; l < N_8; l++) {
# CHECK-NEXT: for (int m = 0; m < N_7; m++) {
# CHECK-NEXT: for (int n = 0; n < N_6; n++) {
# CHECK-NEXT: for (int o = 0; o < N_5; o++) {
# CHECK-NEXT: for (int p = 0; p < N_4; p++) {
# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) {
# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) {
# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) {
# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) {
# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1;
# CHECK: v_X___1 = int(0);
# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) {
# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) {
# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) {
# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) {
# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) {
# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) {
# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) {
# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) {
# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) {
# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) {
# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) {
# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) {
# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1});
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

} // namespace jit
} // namespace torch