// pytorch/test/cpp/tensorexpr/test_memdependency.cpp
#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
// Test helper function used to determine if two regions of a buffer have an
// overlap. NoOverlap and PartialOverlap are obvious. Contains means A is
// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal
// ranges are ContainedOrEqual.
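// For example (matching the asserts below): [2, 15] vs [7, 9] is Contains,
// [0, 3] vs [2, 5] is PartialOverlap, and [0, 0] vs [1, 1] is NoOverlap.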
TEST(MemDependency, BoundOverlap) {
using namespace analysis;
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
// Sanity check 3 overlap cases.
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0)));
ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1)));
// Partial overlap works in either order.
ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14)));
ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10)));
// Total Overlap works when one bound encloses the other, and returns which.
ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9)));
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16)));
// Total overlap works when the bounds are an identical range, and returns
// ContainedOrEqual.
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15)));
// Total overlap when only one end of the bound matches.
ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10)));
ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15)));
ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9)));
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15)));
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15)));
// No overlap when a < b.
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130)));
// No overlap when a > b.
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120)));
// No overlap when adjacent.
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1)));
// Partial overlap when the bounds meet at a shared endpoint.
ASSERT_EQ(
OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120)));
ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4)));
ASSERT_EQ(
OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100)));
ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2)));
// Total overlap when one bound is a single element at one end of the other.
ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15)));
ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2)));
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15)));
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15)));
}
TEST(MemDependency, BoundComparison) {
using namespace analysis;
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
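// compareBound returns True or False only when the comparison can be decided
// for every pair of points drawn from the two bounds; any mix of outcomes
// yields NotDetermined.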
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE));
ASSERT_EQ(
CmpEvalResult::True,
compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE));
ASSERT_EQ(
CmpEvalResult::False,
compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE));
ASSERT_EQ(
CmpEvalResult::NotDetermined,
compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE));
}
TEST(MemDependency, BoundOverlapSymbolic) {
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
VarHandle w("w", kInt);
using namespace analysis;
auto CB = [](ExprHandle s, ExprHandle e) {
return Bound(s.node(), e.node());
};
// Sanity check cases where the start and end are symbolic but the diff is
// constant.
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x)));
ASSERT_EQ(
OverlapKind::PartialOverlap,
boundOverlap(CB(x, x + 3), CB(x + 2, x + 5)));
ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1)));
// We can't infer the sign of y, so we cannot tell whether x + y is larger or
// smaller than x + y / 2.
ASSERT_EQ(
OverlapKind::PartialOverlap,
boundOverlap(CB(x, x + y), CB(x, x + y / 2)));
// With no information about these bounds we have to take the most
// conservative option: there may be an overlap.
ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w)));
// Math on opaque terms works.
ASSERT_EQ(
OverlapKind::ContainedOrEqual,
boundOverlap(CB(x + w, y - z), CB(x + w, y - z)));
// Even requiring simplification.
ASSERT_EQ(
OverlapKind::ContainedOrEqual,
boundOverlap(CB(x - w - w, y), CB(x - w * 2, y)));
}
// Tests the helper function for overlap of multi dimensional indices bounds.
// This uses boundOverlap on each dimension and returns the "lowest" kind of
// overlap.
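// e.g. NoOverlap in any single dim makes the whole access NoOverlap, while a
// partial overlap in just one dim downgrades total overlap elsewhere to
// PartialOverlap.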
TEST(MemDependency, BoundOverlapMultiDim) {
using namespace analysis;
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
// Sanity check one dimensional cases.
ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)}));
ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)}));
ASSERT_EQ(
OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)}));
// Total overlap in 3 dims.
ASSERT_EQ(
OverlapKind::ContainedOrEqual,
overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)}));
ASSERT_EQ(
OverlapKind::ContainedOrEqual,
overlaps(
{CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)}));
// Total overlap in 2 dims, no overlap in another.
ASSERT_EQ(
OverlapKind::NoOverlap,
overlaps(
{CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
// Total overlap in 2 dims, partial overlap in another.
ASSERT_EQ(
OverlapKind::PartialOverlap,
overlaps(
{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
// This case is most important: verify that partial overlap in any single dim
// is detected (here dim 2).
ASSERT_EQ(
OverlapKind::PartialOverlap,
overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)}));
// Dim 1.
ASSERT_EQ(
OverlapKind::PartialOverlap,
overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)}));
// Total overlap in 1 dim, partial in 2.
ASSERT_EQ(
OverlapKind::PartialOverlap,
overlaps(
{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)}));
// Total overlap, partial overlap, no overlap.
ASSERT_EQ(
OverlapKind::NoOverlap,
overlaps(
{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)}));
// Total overlap (B) in 2 dims, total overlap (A) in another.
ASSERT_EQ(
OverlapKind::Contains,
overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)}));
// Total overlap (A) in 2 dims, total overlap (B) in another.
ASSERT_EQ(
OverlapKind::Contains,
overlaps(
{CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)}));
// Total (B), No Overlap, Total (A).
ASSERT_EQ(
OverlapKind::NoOverlap,
overlaps(
{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)}));
}
// Test the helper we use to subtract bounds: returns the region(s) of A which
// remain after removing the region of B.
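// e.g. [1, 5] minus [2, 3] leaves the head [1, 1] and the tail [4, 5], while
// subtracting an enclosing bound leaves nothing.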
TEST(MemDependency, BoundSubtract) {
using namespace analysis;
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
// One element subtract.
ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0);
ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0);
// No Overlap.
ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)}));
ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)}));
// one side overlap.
ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)}));
ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)}));
ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)}));
ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)}));
// both sides overlap.
ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {}));
ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {}));
// internal overlap.
ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)}));
ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)}));
}
TEST(MemDependency, BoundSubtractSymbolic) {
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
VarHandle w("w", kInt);
using namespace analysis;
auto CB = [](ExprHandle s, ExprHandle e) {
return Bound(s.node(), e.node());
};
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
// One element subtract.
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {}));
ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {}));
ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {}));
// Subtract constant range low.
ASSERT_TRUE(
EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)}));
// Subtract constant range high.
ASSERT_TRUE(
EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)}));
// Subtract constant range total overlap.
ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {}));
ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {}));
// Subtract constant range internal.
ASSERT_TRUE(
EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)),
{CB(x, x + 2), CB(x + 8, x + 10)}));
// Size is inferable but not constant, only works with a single var.
ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {}));
ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)}));
// Size is not inferable.
ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)}));
ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)}));
ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)}));
ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)}));
}
// Tests the helper function that does subtraction, but for multi dimensional
// indices bounds.
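// e.g. subtracting [2, 8] x [2, 8] from [0, 5] x [0, 5] slices off one
// region per overlapping dimension: [0, 1] x [0, 5] and [2, 5] x [0, 1].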
TEST(MemDependency, BoundSubtractMultiDim) {
using namespace analysis;
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
if (x.size() != y.size()) {
return false;
}
for (auto i = 0U; i < x.size(); ++i) {
if (!indexBoundsEquals(x[i], y[i])) {
return false;
}
}
return true;
};
// sanity check one dimension.
ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {}));
ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {}));
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}}));
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}}));
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}}));
// Multi dim total overlap.
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {}));
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {}));
// Multi dim one way partial in dim 1.
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}),
{{CB(4, 9), CB(0, 2)}}));
// Multi dim one way partial in dim 2.
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}),
{{CB(0, 9), CB(11, 20)}}));
// Partial overlap in 2 dims.
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}),
{{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}}));
// Partial overlap in 3 dims.
ASSERT_TRUE(
EQ(subtractIndicesBounds(
{CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}),
{{CB(0, 1), CB(0, 5), CB(0, 5)},
{CB(2, 5), CB(0, 1), CB(0, 5)},
{CB(2, 5), CB(2, 5), CB(0, 1)}}));
}
// Tests the multi dimensional subtraction code for bounds that cannot be fully
// materialized.
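// e.g. subtracting [0, x - 5] from [0, x] leaves [x - 4, x], even though the
// value of x is unknown.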
TEST(MemDependency, BoundSubtractMultiDimSymbolic) {
VarHandle x("x", kInt);
VarHandle y("y", kInt);
using namespace analysis;
auto CB = [](ExprHandle s, ExprHandle e) {
return Bound(s.node(), e.node());
};
auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
if (x.size() != y.size()) {
return false;
}
for (auto i = 0U; i < x.size(); ++i) {
if (!indexBoundsEquals(x[i], y[i])) {
return false;
}
}
return true;
};
// Cannot determine overlaps.
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}}));
// Various total Overlaps.
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {}));
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {}));
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {}));
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {}));
// one-way overlap in first dim.
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}),
{{CB(x - 4, x), CB(0, y)}}));
// second dim.
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}),
{{CB(0, x), CB(0, 4)}}));
// Internal overlap in first dim.
ASSERT_TRUE(
EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}),
{{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}}));
// second dim.
ASSERT_TRUE(EQ(
subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}),
{{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}}));
// Overlap in both dimensions.
ASSERT_TRUE(
EQ(subtractIndicesBounds(
{CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}),
{
{CB(0, 4), CB(0, y)},
{CB(x - 4, x), CB(0, y)},
{CB(0, x), CB(0, 9)},
{CB(0, x), CB(y - 9, y)},
}));
}
// Simple check that the analyzer does anything at all...
TEST(MemDependency, MemDependencyCheckerSimple) {
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
analysis::MemDependencyChecker analyzer;
/*
* A[0] = 3;
* B[0] = A[0] + 1;
*/
StorePtr aStore = Store::make(a, {0}, 3);
StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
StmtPtr stmt = Block::make({aStore, bStore});
stmt->accept(&analyzer);
ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
// sanity check, but anything that depends directly must depend indirectly.
ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore));
}
// Check that there is a difference between direct and indirect dependence.
TEST(MemDependency, MemDependencyCheckerMultiStmt) {
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
BufHandle c("C", {1}, kInt);
analysis::MemDependencyChecker analyzer;
/*
* A[0] = 3;
* B[0] = A[0];
* C[0] = B[0] + 1;
*/
StorePtr aStore = Store::make(a, {0}, 3);
StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1));
StmtPtr stmt = Block::make({aStore, bStore, cStore});
stmt->accept(&analyzer);
// C depends on A indirectly.
ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore));
ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore));
// C depends on B directly, which depends on A directly.
ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore));
ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
// Dependency goes top to bottom only.
ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore));
ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore));
}
// Verify that we do filter writes that are totally overlapped by later writes.
TEST(MemDependency, MemDependencyCheckerOverlap) {
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
analysis::MemDependencyChecker analyzer;
/*
* A[0] = 3;
* A[0] = 6;
* B[0] = A[0] + 1;
*/
StorePtr aStore = Store::make(a, {0}, 3);
StorePtr a2Store = Store::make(a, {0}, 6);
StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
StmtPtr stmt = Block::make({aStore, a2Store, bStore});
stmt->accept(&analyzer);
// B store depends on second A store but not first since it is completely
// overlapped.
ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store));
ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore));
// No dependency between either A store.
ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store));
ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore));
}
// Verify that bounds match loop iterations, and that dependencies progress
// across loop scopes.
TEST(MemDependency, MemDependencyCheckerLoop) {
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
MemDependencyChecker analyzer;
/*
* for (int x = 0; x < 10; ++x) {
* A[x] = x;
* }
* B[0] = A[0] + 1;
*/
StorePtr aStore = Store::make(a, {x}, x);
StmtPtr loop = For::make(x, 0, 10, aStore);
StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));
StmtPtr stmt = Block::make({loop, bStore});
stmt->accept(&analyzer);
// Same A->B dependency.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
// B depends on the loop.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
// A is in the loop but does not depend on any loop iteration.
ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop));
auto aStoreAccess = analyzer.accessFor(aStore);
ASSERT_NE(aStoreAccess, nullptr);
// It should have bounds covering the range of x: 0 <= x < 10.
ASSERT_TRUE(indexBoundsEquals(
aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
// Reductions should promote dependencies as well.
TEST(MemDependency, MemDependencyCheckerLoopReduce) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
MemDependencyChecker analyzer;
/*
* A[0] = 0;
* for (int x = 0; x < 10; ++x) {
* A[0] = A[x] + 1;
* }
* B[0] = A[0];
*/
StorePtr aInit = Store::make(a, {0}, 0);
ExprHandle reduce = Sum()(a, 1, {x}, {x});
StorePtr aReduce = Store::make(a, {0}, reduce);
StmtPtr loop = For::make(x, 0, 10, aReduce);
StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
StmtPtr stmt = Block::make({aInit, loop, bStore});
stmt->accept(&analyzer);
// B -> A.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
// B depends indirectly on the initializer of A, since the reduction depends
// on it.
ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
// B depends on the loop.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
// A is in the loop and depends on other iterations.
ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
// The loop contents depend on the initializer too.
ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
// Find loads within the reduction:
auto reduceLoads = NodeFinder<Load>::find(reduce.node());
// Pull out the access for the load inside the loop.
for (auto load : reduceLoads) {
auto loopLoad = analyzer.accessFor(load);
// It should have 10 element long bounds.
ASSERT_TRUE(indexBoundsEquals(
loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
}
// Lowering a reduction doesn't affect dependency analysis.
TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
MemDependencyChecker analyzer;
/*
* A[0] = 0;
* for (int x = 0; x < 10; ++x) {
* A[0] = A[x] + 1;
* }
* B[0] = A[0];
*/
StorePtr aInit = Store::make(a, {0}, 0);
ExprHandle aLoad = Load::make(a, {x});
StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
StmtPtr loop = For::make(x, 0, 10, aReduce);
StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
StmtPtr stmt = Block::make({aInit, loop, bStore});
stmt->accept(&analyzer);
// B -> A.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
// B depends indirectly on the initializer of A, since the reduction depends
// on it.
ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
// B depends on the loop.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
// A is in the loop and depends on other iterations.
ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
// The loop contents depend on the initializer too.
ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
// Pull out the access for the load inside the loop.
auto loopLoad = analyzer.accessFor(aLoad.node());
// It should have 10 element long bounds.
ASSERT_TRUE(indexBoundsEquals(
loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
// Can determine dependencies of outputs, through to inputs.
TEST(MemDependency, MemDependencyCheckerInputsOutputs) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
// initialize analyzer with inputs and outputs.
analysis::MemDependencyChecker analyzer({a}, {b});
// Here's a Relu.
/*
* for (int x = 0; x < 10; ++x) {
* B[x] = Max(A[x], 0);
* }
*/
ExprHandle aLoad = Load::make(a, {x});
StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
StmtPtr loop = For::make(x, 0, 10, bStore);
StmtPtr stmt = Block::make({loop});
stmt->accept(&analyzer);
// Output depends indirectly on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
// aLoad depends directly on the input A.
ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node()));
// bStore therefore depends directly on the input A.
ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node()));
// The output depends directly on the store.
ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
// Check AccessInfo based overloads.
auto input = analyzer.input(a.node());
auto output = analyzer.output(b.node());
// Output depends indirectly on input.
ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
// Not directly.
ASSERT_FALSE(analyzer.dependsDirectly(output, input));
// Not in reverse order.
ASSERT_FALSE(analyzer.dependsIndirectly(input, output));
// output -> bStore -> aLoad -> input.
auto storeAccess = analyzer.accessFor(bStore);
auto loadAccess = analyzer.accessFor(aLoad.node());
ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess));
ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input));
}
// Can tell if an output does not depend on an input.
TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
// initialize analyzer with inputs and outputs.
analysis::MemDependencyChecker analyzer({a}, {b});
// Here's a dumb Relu.
/*
* for (int x = 0; x < 10; ++x) {
* B[x] = Max(x, 0);
* }
*/
StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true));
StmtPtr loop = For::make(x, 0, 10, bStore);
StmtPtr stmt = Block::make({loop});
stmt->accept(&analyzer);
// Output does not depend indirectly on input.
ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node()));
// The output still depends directly on the store.
ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
// Check AccessInfo based overloads.
auto input = analyzer.input(a.node());
auto output = analyzer.output(b.node());
// Output does not depend indirectly on input.
ASSERT_FALSE(analyzer.dependsIndirectly(output, input));
}
// Verify different loop extents produce accesses with different bounds, and
// that later accesses find dependencies that overlap their entire bound range.
TEST(MemDependency, MemDependencyCheckerLoopBounds) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
BufHandle c("C", {10}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
MemDependencyChecker analyzer({a}, {c});
// This enables using the execution order of the loops to determine if some
// loops are self dependent or not.
analyzer.allowLoopExecutionOrderAnalysis();
/*
* for (int x = 1; x < 10; ++x) {
* B[x] = A[x];
* }
* for (int x = 1; x < 9; ++x) {
* B[x] = B[x] * 2;
* }
* for (int x = 3; x < 4; ++x) {
* C[x] = A[x];
* }
* for (int x = 0; x < 10; ++x) {
* C[x] = B[x];
* }
*/
std::vector<StmtPtr> stmts(
{For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))),
For::make(
x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))),
For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))),
For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))});
StmtPtr stmt = Block::make(stmts);
stmt->accept(&analyzer);
auto input = analyzer.input(a.node());
auto output = analyzer.output(c.node());
// sanity check Output -> Input.
ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
// Check the For loop dependencies:
// Last write to C depends on both writes to B since they contain the last
// write to at least one element.
ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1]));
ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0]));
// The last write to C does not depend on the other write to C.
ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2]));
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
/* 0. Input: A[(0, 9)] - dependents: 1 5
* 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2
* 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7
* 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4
* 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7
* 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6
* 6. Store: C[(3, 3)] - depends on: 5
* 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8
* 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9
* 9. Output: C[(0, 9)] - depends on: 8
*/
// Now let's look at the bounds of each access.
// There are 10 accesses in this Stmt, so this check is exhaustive; we won't
// do this much normally.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 10);
VarPtr aVar = a.node()->base_handle();
VarPtr bVar = b.node()->base_handle();
VarPtr cVar = c.node()->base_handle();
// The first access is the input A.
ASSERT_EQ(history[0]->type(), AccessType::Input);
ASSERT_EQ(history[0]->var(), aVar);
// It has the bounds of the producing Input.
ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
// sanity check the input we retrieved earlier matches.
ASSERT_EQ(history[0], input);
// The second access is the load of A in the first loop.
ASSERT_EQ(history[1]->type(), AccessType::Load);
ASSERT_EQ(history[1]->var(), aVar);
// It has the bounds of the loop, i.e. start == 1.
ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)}));
// It reads from A, so it should have a dependency on the last write to this
// range, which is the input.
ASSERT_EQ(history[1]->dependencies().size(), 1);
ASSERT_TRUE(history[1]->hasDependency(history[0]));
// The third access is the store into B in the first loop.
ASSERT_EQ(history[2]->type(), AccessType::Store);
ASSERT_EQ(history[2]->var(), bVar);
// It also has the bounds of the loop, i.e. start == 1.
ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
// The previous load is in its RHS, so it depends on it.
ASSERT_EQ(history[2]->dependencies().size(), 1);
ASSERT_TRUE(history[2]->hasDependency(history[1]));
// The fourth access is the load from B in the second loop.
ASSERT_EQ(history[3]->type(), AccessType::Load);
ASSERT_EQ(history[3]->var(), bVar);
// It has the bounds of the second loop, i.e. >= 1 < 9.
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)}));
// It reads from B in a smaller range, so should depend on the previous
// store.
ASSERT_EQ(history[3]->dependencies().size(), 1);
ASSERT_TRUE(history[3]->hasDependency(history[2]));
// The fifth: the store to B in the second loop.
ASSERT_EQ(history[4]->type(), AccessType::Store);
ASSERT_EQ(history[4]->var(), bVar);
// It also has the bounds of the second loop.
ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)}));
// The previous load is in its RHS, so it depends on it as before.
ASSERT_EQ(history[4]->dependencies().size(), 1);
ASSERT_TRUE(history[4]->hasDependency(history[3]));
// The sixth access is the load of A in the third loop, and skips previous B
// accesses.
ASSERT_EQ(history[5]->type(), AccessType::Load);
ASSERT_EQ(history[5]->var(), aVar);
// It has the bounds of the third loop: >= 3 < 4.
ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)}));
// It depends on the last thing to write to A, which is the A input.
ASSERT_EQ(history[5]->dependencies().size(), 1);
ASSERT_TRUE(history[5]->hasDependency(history[0]));
// Seventh: the store into the output C.
ASSERT_EQ(history[6]->type(), AccessType::Store);
ASSERT_EQ(history[6]->var(), cVar);
// It also has the bounds of the third loop.
ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)}));
// The previous load is in its RHS, so it depends on it as always.
ASSERT_EQ(history[6]->dependencies().size(), 1);
ASSERT_TRUE(history[6]->hasDependency(history[5]));
// The eighth access is the load of B in the fourth loop.
ASSERT_EQ(history[7]->type(), AccessType::Load);
ASSERT_EQ(history[7]->var(), bVar);
// It has the bounds of the final loop: >= 0 < 10.
ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
// The bounds of this read are larger than the bounds of the previous write,
// so it depends on both previous Stores to B.
ASSERT_EQ(history[7]->dependencies().size(), 2);
ASSERT_TRUE(history[7]->hasDependency(history[2]));
ASSERT_TRUE(history[7]->hasDependency(history[4]));
// Ninth: the final store into the output C.
ASSERT_EQ(history[8]->type(), AccessType::Store);
ASSERT_EQ(history[8]->var(), cVar);
// It also has the bounds of the final loop.
ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
// The previous load is in its RHS, so it depends on it as always.
ASSERT_EQ(history[8]->dependencies().size(), 1);
ASSERT_TRUE(history[8]->hasDependency(history[7]));
// The last access represents the output Buf.
ASSERT_EQ(history[9]->type(), AccessType::Output);
ASSERT_EQ(history[9]->var(), cVar);
// It has the bounds of the output Buf.
ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)}));
// sanity check the output we retrieved earlier matches.
ASSERT_EQ(history[9], output);
// It depends on the last write to C only.
ASSERT_EQ(history[9]->dependencies().size(), 1);
ASSERT_TRUE(history[9]->hasDependency(history[8]));
}
// Verify that we can still infer bounds when the loop var is offset.
TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
MemDependencyChecker analyzer({a}, {b});
// This enables using the execution order of the loops to determine if some
// loops are self dependent or not.
analyzer.allowLoopExecutionOrderAnalysis();
/*
* for (int x = 1; x < 10; x++) {
* A[x] = A[x - 1];
* }
* for (int x = 0; x < 9; x++) {
* A[x] = A[x + 1];
* }
* for (int x = 0; x < 9; x++) {
* A[9 - x] = A[8 - x];
* }
* for (int x = 0; x < 10; x++) {
* A[x] = A[9 - x];
* }
* for (int x = 0; x < 10; x++) {
* B[x] = A[x];
* }
*/
StmtPtr stmt = Block::make(
{For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))),
For::make(
x,
0,
9,
Store::make(
a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))),
For::make(
x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))),
For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))});
stmt->accept(&analyzer);
// Sanity check output depends on Input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
auto CB = [](int s, int e) {
return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
};
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
/* 0. Input: A[(0, 9)] - dependents: 1
* 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2
* 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3
* 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4
* 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7
* 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6
* 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7
* 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8
* 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9
* 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10
* 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11
* 11. Output: B[(0, 9)] - depends on: 10
*/
// Now let's look at the bounds of each access.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 12);
VarPtr aVar = a.node()->base_handle();
VarPtr bVar = b.node()->base_handle();
// The first access is the input A.
ASSERT_EQ(history[0]->type(), AccessType::Input);
ASSERT_EQ(history[0]->var(), aVar);
// It has the bounds of the producing Input.
ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
// The second access is the load A[x-1].
ASSERT_EQ(history[1]->type(), AccessType::Load);
ASSERT_EQ(history[1]->var(), aVar);
// It has the bounds of the loop modified by the offset of each index, in
// this case -1.
ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)}));
// It depends on the input, but also the store in the same loop, since
// different iterations of the loop depend on each other.
ASSERT_EQ(history[1]->dependencies().size(), 2);
ASSERT_TRUE(history[1]->hasDependency(history[0]));
ASSERT_TRUE(history[1]->hasDependency(history[2]));
// The third access is the Store to A[x] in the first loop.
ASSERT_EQ(history[2]->type(), AccessType::Store);
ASSERT_EQ(history[2]->var(), aVar);
// It has no offset on x, so should have the same bounds as the loop.
ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
// The fourth access is the load A[x+1] in the second loop.
ASSERT_EQ(history[3]->type(), AccessType::Load);
ASSERT_EQ(history[3]->var(), aVar);
// It has the bounds of the loop (0 <= x < 9) modified by the offset of each
// index, in this case 1.
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)}));
// This load totally overlaps the previous write to A, so it depends only on
// it and not the input.
ASSERT_EQ(history[3]->dependencies().size(), 1);
ASSERT_TRUE(history[3]->hasDependency(history[2]));
// The fifth access is the store to A[x] in the second loop.
ASSERT_EQ(history[4]->type(), AccessType::Store);
ASSERT_EQ(history[4]->var(), aVar);
// It has no offset on x, so should have the same bounds as the loop.
ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)}));
// The sixth access is the load to A[8 - x] in the third loop.
ASSERT_EQ(history[5]->type(), AccessType::Load);
ASSERT_EQ(history[5]->var(), aVar);
// It has the bounds of the loop (0 <= x < 9) modified by the offset of each
// index, in this case 8 - x.
// This access has a negative stride, which will be normalized.
ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)}));
// This load totally overlaps the most recent write to A, so it depends only
// on it and not the input or the first write to A.
ASSERT_EQ(history[5]->dependencies().size(), 1);
ASSERT_TRUE(history[5]->hasDependency(history[4]));
// The seventh access is the store to A[9 - x] in the third loop.
ASSERT_EQ(history[6]->type(), AccessType::Store);
ASSERT_EQ(history[6]->var(), aVar);
// This store has a negative stride on its indices, but is normalized
// internally.
ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)}));
// The eighth access is the load A[9-x] in the fourth loop.
ASSERT_EQ(history[7]->type(), AccessType::Load);
ASSERT_EQ(history[7]->var(), aVar);
// It has the bounds of the loop (0 <= x < 10), modified by the offset 9 - x,
// which essentially traverses the loop backwards.
ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
// This Load has three write dependencies:
ASSERT_EQ(history[7]->dependencies().size(), 3);
// * The previous store (#6) for elements 1-9
ASSERT_TRUE(history[7]->hasDependency(history[6]));
// * An earlier store (#4) covering element 0
ASSERT_TRUE(history[7]->hasDependency(history[4]));
// * A future store inside this loop, since this loop modifies the buffer
// in a non distinct way (due to the load and store having different access
// strides).
ASSERT_TRUE(history[7]->hasDependency(history[8]));
// The ninth access is the store to A[x] in the fourth loop.
ASSERT_EQ(history[8]->type(), AccessType::Store);
ASSERT_EQ(history[8]->var(), aVar);
// It has no offset on x, so should have the same bounds as the loop.
ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
// The tenth and eleventh accesses are the copy from A[x] to B[x].
ASSERT_EQ(history[9]->type(), AccessType::Load);
ASSERT_EQ(history[9]->var(), aVar);
ASSERT_EQ(history[10]->type(), AccessType::Store);
ASSERT_EQ(history[10]->var(), bVar);
// The last access represents the output Buf.
ASSERT_EQ(history[11]->type(), AccessType::Output);
ASSERT_EQ(history[11]->var(), bVar);
// It has the bounds of the output Buf.
ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)}));
// It depends on the last write to B only.
ASSERT_EQ(history[11]->dependencies().size(), 1);
ASSERT_TRUE(history[11]->hasDependency(history[10]));
// ok that's enough of that.
}
// Check many different cases of loop self dependency - when a load within a
// loop is dependent on a Store later in the same loop but in a different
// iteration. This is affected by whether or not we can trust the execution
// order of the loop.
TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) {
BufHandle a("A", {5}, kInt);
BufHandle b("B", {5}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
using namespace analysis;
// This check assumes that the Stmt has a single Store with a single Load on
// the RHS.
auto isSelfDependent =
[](const std::vector<std::shared_ptr<AccessInfo>>& history) -> bool {
return history.front()->hasDependency(history.back());
};
{
/* for (int y = 0; y < 10; y++) {
* A[y] = (A[y]) + 1;
* } */
// Not self dependent since all loop iterations use a different y.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
y,
0,
10,
Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))}));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int y = 0; y < 10; y++) {
* A[y + 1] = (A[y + 1]) + 1;
* }
*/
// Not self dependent due to different y (with offset).
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
y,
0,
10,
Block::make(
{Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))}));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[0] = (A[0]) + x;
* }
*/
// Is self dependent since all iterations use a common constant element of A.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x,
0,
10,
Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[0] = (B[0]) + x;
* }
*/
// Is not self dependent because there is no store to the buffer that is
// read.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x,
0,
10,
Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))}));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[y] = (A[y]) + x;
* }
*/
// Is self dependent since all iterations use a common symbolic element of A.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x,
0,
10,
Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))}));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x] = A[x + 1];
* }
*/
// In this case it depends if we are considering execution order.
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
// With analysis of order disabled, this is self dependent since the read of
// A[x + 1] in one iteration and the write of that element in a later
// iteration could occur in either order.
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x] = A[x + 1];
* }
*/
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
// If order analysis is enabled, this is not dependent since the read for
// each element occurs before the write to that element.
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 1; x < 10; x++) {
* A[x] = A[x - 1];
* }
*/
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 1; x < 10; x++) {
* A[x] = A[x - 1];
* }
*/
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
StmtPtr stmt =
For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
stmt->accept(&analyzer);
// In this case, even with order analysis the Load is dependent on the
// Store, since the write to X occurs before the read from X.
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 3; x < 10; x++) {
* A[9 - x] = A[8 - x];
* }
*/
// Still works if the execution order is reversed, so long as the read
// comes before the write.
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
StmtPtr stmt = For::make(
x,
3,
10,
Store::make(
a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
stmt->accept(&analyzer);
// Here we can determine that each element is read before it is written, so
// there is no dependency.
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 3; x < 10; x++) {
* A[8 - x] = A[9 - x];
* }
*/
// But not if it doesn't.
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
StmtPtr stmt = For::make(
x,
3,
10,
Store::make(
a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 3; x < 10; x++) {
* A[9 - x] = A[8 - x];
* }
*/
// And not if we're not relying on execution order.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x,
3,
10,
Store::make(
a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 3; x < 10; x++) {
* A[x - 2] = A[x - 1];
* }
*/
// Forward order but with negative offsets.
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
StmtPtr stmt =
For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1})));
stmt->accept(&analyzer);
// Again each element is read before it is written, so there is no
// dependency.
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2];
* }
*/
// With an access stride.
MemDependencyChecker analyzer;
// Execution order doesn't matter since the read and the write are totally
// distinct.
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2 + 1];
* }
*/
// Here we can use the common stride of the accesses to determine they are
// distinct.
// Note, this is the only place (loop self dependency) we use this stride
// to avoid unnecessary dependence.
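// Here the store touches only even elements (x * 2) and the load only odd
// elements (x * 2 + 1); since the offsets differ modulo the shared stride of
// 2, the accesses can never alias.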
MemDependencyChecker analyzer;
// Execution order doesn't matter since the read and the write are totally
// distinct.
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2 - 1];
* }
*/
// Same if the read is behind the write, so long as they are distinct.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2 + 2];
* }
*/
// But not if the offset is a multiple of the stride.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2 - 2];
* }
*/
// Works with negative offsets too.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2 + 7];
* }
*/
// Detects accesses are distinct when offset is large but not a multiple
// of stride.
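// e.g. the store hits even elements while the load hits x * 2 + 7, which is
// always odd, so the two sets are disjoint.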
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 2 + 4];
* }
*/
// Works with offsets which are multiples of the stride.
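// e.g. element 4 is read at x = 0 but written at x = 2, so iterations alias.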
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 6] = A[x * 6 + 5];
* }
*/
// detects accesses are distinct with large strides when the offset is
// within the stride.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 6];
* }
*/
// detects accesses are overlapping when stride is different but a
// multiple.
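// e.g. both accesses hit only even elements: element 6 is read at x = 1 and
// written at x = 3.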
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 4] = A[x * 2];
* }
*/
// still works when the read has the smaller stride.
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 6 + 1];
* }
*/
// detects accesses are distinct when stride is different but a multiple
// and there is an offset.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 6 + 4];
* }
*/
// The smaller stride determines whether there is overlap.
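// e.g. the load x * 6 + 4 only hits even elements, all of which the store
// x * 2 also covers.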
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2 + 3] = A[x * 6];
* }
*/
// The smaller stride determines whether there is overlap, not the larger.
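// e.g. the store x * 2 + 3 only hits odd elements while the load x * 6 only
// hits even elements.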
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[x * 3 + 1];
* }
*/
// If the strides have no common factor greater than 1, they overlap.
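// e.g. with strides 2 and 3, element 4 is read at x = 1 (3 * 1 + 1) and
// written at x = 2 (2 * 2).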
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x] = A[x + 10];
* }
*/
// If the offset is greater than the size of the loop, they can't overlap.
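// Here the writes cover [0, 9] and the reads cover [10, 19], which are
// disjoint ranges.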
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x] = A[9 - x];
* }
*/
// If they have different execution orders they may overlap.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x * 2] = A[19 - x * 2];
* }
*/
// Or they may not, depending on their start offset and strides.
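// Here the writes hit even elements [0, 18] and the reads hit odd elements
// [1, 19], so even the reversed traversal cannot alias.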
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x,
0,
10,
Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x / 2] = A[x / 2];
* }
*/
// If the index is not strictly monotonic, they overlap.
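// e.g. x = 0 and x = 1 both access A[0], so the write at x = 0 aliases the
// read at x = 1.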
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x / 2] = A[x / 2 + 1];
* }
*/
// If the index is not strictly monotonic, they overlap - even with an offset.
MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = 0; x < 10; x++) {
* A[x % 2] = A[x % 2];
* }
*/
// Mod too...
analysis::MemDependencyChecker analyzer;
StmtPtr stmt = For::make(
x,
0,
10,
Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
/* for (int x = y; x < z; x++) {
* A[x] = A[x + 1];
* }
*/
// Still works with symbolic loop extents.
{
MemDependencyChecker analyzer;
StmtPtr stmt =
For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
}
{
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
StmtPtr stmt =
For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
}
}
}
// Verify that a strided access still works.
// TODO: actually this only works because of the size of the ranges, revisit
// this test after strided overlap is implemented.
TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
BufHandle a("A", {20}, kInt);
BufHandle b("B", {20}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
using namespace analysis;
MemDependencyChecker analyzer({a.node()}, {b.node()});
StmtPtr stmt = Block::make(
{For::make(
x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2})))
});
stmt->accept(&analyzer);
// Sanity check output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
// Output has 2 dependencies... the store in each loop.
auto outputAccess = analyzer.output(b.node());
ASSERT_EQ(outputAccess->dependencies().size(), 2);
}
/* TODO(nickg) - this test will fail due to the lack of stride math in Bound
TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
BufHandle a("A", {20}, kInt);
BufHandle b("B", {20}, kInt);
BufHandle c("C", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
{
analysis::MemDependencyChecker analyzer({a.node()}, {c.node()});
StmtPtr stmt = Block::make(
{For::make(
x,
0,
10,
Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
For::make(
x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))),
For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))
});
stmt->accept(&analyzer);
std::cout << *stmt << "\n";
for (auto& wi : analyzer.getHistory()) {
wi->print();
}
}
}*/
// Analysis on Stmts using Cond.
TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
BufHandle c("C", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
using namespace analysis;
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
* if (y<5 ? 1 : 0) {
* C[0] = (B[0]) + 1;
* } else {
* C[0] = (B[1]) + 1;
* }
*/
// Future usages may depend on accesses in both branches of a condition.
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)),
Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))});
stmt->accept(&analyzer);
    // Output C should have 3 dependencies, one for each of the three stores.
auto outputAccess = analyzer.output(c.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 3);
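    // Either branch of the Cond may execute, so both conditional stores are
    // retained as possible dependencies alongside the loop store.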
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
* if (y<5 ? 1 : 0) {
* for (int x = 0; x < 10; x++) {
* C[x] = B[x];
* }
* } else {
* for (int x = 0; x < 10; x++) {
* C[x] = (B[x]) + 1;
* }
* }
*/
// Future usages may depend on accesses in both branches of a condition.
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))),
For::make(
x,
0,
10,
Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
stmt->accept(&analyzer);
    // Output C should have 3 dependencies, one for each of the three stores.
auto outputAccess = analyzer.output(c.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 3);
// TODO(nickg): actually since the true and false branch cover the total
// range of the first store this should have 2 dependencies, but we don't
// do that yet.
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
* if (y<5 ? 1 : 0) {
* for (int x = 0; x < 10; x++) {
* C[x] = (B[x]) + 1;
* }
* }
*/
// Only has true branch.
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
For::make(
x,
0,
10,
Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))),
nullptr)});
stmt->accept(&analyzer);
    // Output C should have 2 dependencies: the loop store and the store in
    // the true branch.
auto outputAccess = analyzer.output(c.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 2);
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
* if (y<5 ? 1 : 0) {
* } else {
* for (int x = 0; x < 10; x++) {
* C[x] = (B[x]) + 1;
* }
* }
*/
// Only has false branch.
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
nullptr,
For::make(
x,
0,
10,
Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
stmt->accept(&analyzer);
    // Output C should have 2 dependencies: the loop store and the store in
    // the false branch.
auto outputAccess = analyzer.output(c.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 2);
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
* if (C[0]<5 ? 1 : 0) {
* C[0] = 5;
* }
*/
// Cond's Condition depends on a previous access.
MemDependencyChecker analyzer({a}, {c});
StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
ExprHandle conditionalLoad = Load::make(c, {0});
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, initStore),
Cond::make(
CompareSelect::make(
conditionalLoad, 5, CompareSelectOperation::kLT),
Store::make(c, {0}, 5),
nullptr)});
stmt->accept(&analyzer);
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore));
ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node()));
}
}
// Stmts using IfThenElse.
TEST(MemDependency, MemDependencyCheckerIfThenElse) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
BufHandle c("C", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
using namespace analysis;
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
   * C[0] = (y < 5) ? (B[0]) + 1 : (B[1]) + 1;
*/
// Future usages may depend on accesses in both branches of a condition.
MemDependencyChecker analyzer({a, b}, {c});
StorePtr ifStore = Store::make(
c,
{0},
IfThenElse::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Add::make(Load::make(b, {0}), 1),
Add::make(Load::make(b, {1}), 1)));
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
ifStore});
stmt->accept(&analyzer);
// Output C should have 2 dependencies, each of the two stores.
auto outputAccess = analyzer.output(c.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 2);
// Now we need to check the Store containing the IfThenElse.
auto ifStoreAccess = analyzer.accessFor(ifStore);
// It should have 2 dependencies.
ASSERT_EQ(ifStoreAccess->dependencies().size(), 2);
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[x];
* }
   * C[0] = (y < 5) ? (B[0]) + 1 : 42;
*/
// If the load appears in only one side of an IfThenElse the output may be
// dependent on it.
MemDependencyChecker analyzer({a, b}, {c});
StorePtr ifStore = Store::make(
c,
{0},
IfThenElse::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Add::make(Load::make(b, {0}), 1),
42));
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
ifStore});
stmt->accept(&analyzer);
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
{
/* for (int x = 0; x < 10; x++) {
   *   C[0] = (y < 5) ? B[x] : A[x];
* }
*/
// In this case C is dependent on both A and B.
// TODO: in cases like this it would be possible to split the range of B
// into two bounds, one dependent on A and one dependent on B. We'd need to
// examine conditions relative to previously encountered loop variables. I'm
// uncertain if this would be helpful.
MemDependencyChecker analyzer({a, b}, {c});
StorePtr ifStore = Store::make(
c,
{0},
IfThenElse::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Load::make(b, {x}),
Load::make(a, {x})));
StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)});
stmt->accept(&analyzer);
// C depends indirectly on A and B.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
}
}
// Cutting a loop with single-element writes.
TEST(MemDependency, MemDependencyCheckerCutLoop) {
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
{
/* for (int x = 0; x < 10; x++) {
* B[x] = A[x];
* }
* B[5] = 100;
*/
    // Cutting a loop with single-element writes.
MemDependencyChecker analyzer({a}, {b});
StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))),
Store::make(b, {5}, 100)});
stmt->accept(&analyzer);
// Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
// Output has 2 dependencies.
auto outputAccess = analyzer.output(b.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 2);
}
{
/* for (int x = 0; x < 10; x++) {
* B[x] = A[x];
* }
* for (int x = 4; x < 7; x++) {
   *   B[x] = B[x] + 1;
* }
   * B[4] = 100;
   * B[5] = 101;
   * B[6] = 102;
*/
    // Cutting a loop with a smaller loop, then totally overlapping that
    // second loop with single-element writes.
MemDependencyChecker analyzer({a}, {b});
ForPtr firstLoop =
For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})));
StorePtr secondStore =
Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
ForPtr secondLoop = For::make(x, 4, 7, secondStore);
StmtPtr stmt = Block::make(
{firstLoop,
secondLoop,
Store::make(b, {4}, 100),
Store::make(b, {5}, 101),
Store::make(b, {6}, 102)});
stmt->accept(&analyzer);
// Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
    // Output has 4 dependencies: the first loop's store and the three
    // single-element stores.
auto outputAccess = analyzer.output(b.node());
ASSERT_NE(outputAccess, nullptr);
ASSERT_EQ(outputAccess->dependencies().size(), 4);
// Second loop depends on first loop.
ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop));
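    // The stores to B[4], B[5] and B[6] completely cover the second loop's
    // write range, so nothing it writes can reach the output.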
// Output does not depend on second loop or store.
ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop));
ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore));
}
}
// Dynamic shapes (load in indices).
TEST(MemDependency, MemDependencyCheckerDynamicShapes) {
BufHandle a("A", {100}, kInt);
BufHandle b("B", {100}, kInt);
BufHandle c("C", {100}, kInt);
VarHandle x("x", kInt);
using namespace analysis;
auto CB = [](ExprHandle s, ExprHandle e) {
return Bound(s.node(), e.node());
};
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
{
/* for (int x = 0; x < B[0]; x++) {
* C[x] = A[x];
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make({For::make(
x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))});
stmt->accept(&analyzer);
/* 0. Input: B[(0, 99)] - dependents: 2
* 1. Input: A[(0, 99)] - dependents: 3
* 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4
* 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4
* 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5
* 5. Output: C[(0, 99)] - depends on: 4
*/
// Output dependent on A input.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
// Also dependent on B input to determine the size of the region written.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 6);
// The accesses in the loop depend on the load in the stop condition.
ASSERT_TRUE(history[4]->hasDependency(history[2]));
ASSERT_TRUE(history[3]->hasDependency(history[2]));
// Make a load from B to compare against.
ExprHandle loadFromB = Load::make(b, {0});
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)}));
ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)}));
}
{
/* for (int x = B[0]; x < B[1]; x++) {
* C[x] = A[x];
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make({For::make(
x,
Load::make(b, {0}),
Load::make(b, {1}),
Store::make(c, {x}, Load::make(a, {x})))});
stmt->accept(&analyzer);
/* 0. Input: B[(0, 99)] - dependents: 2 3
* 1. Input: A[(0, 99)] - dependents: 4
* 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5
* 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5
* 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5
* 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6
* 6. Output: C[(0, 99)] - depends on: 5
*/
// Sanity check output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 7);
// The accesses in the loop depend on the load in the start condition.
ASSERT_TRUE(history[5]->hasDependency(history[2]));
ASSERT_TRUE(history[4]->hasDependency(history[2]));
// also the stop condition.
ASSERT_TRUE(history[5]->hasDependency(history[3]));
ASSERT_TRUE(history[4]->hasDependency(history[3]));
// Make loads from B to compare against.
ExprHandle loadFromB0 = Load::make(b, {0});
ExprHandle loadFromB1 = Load::make(b, {1});
ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
}
{
/* for (int x = 0; x < 10; x++) {
* C[x] = A[B[x]];
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make({For::make(
x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))});
stmt->accept(&analyzer);
/* 0. Input: B[(0, 99)] - dependents: 2
* 1. Input: A[(0, 99)] - dependents: 3
* 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4
* 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4
* 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5
* 5. Output: C[(0, 99)] - depends on: 4
*/
// Sanity check output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 6);
    // The store depends on both loads, and the load of A depends on the load
    // of B.
ASSERT_TRUE(history[4]->hasDependency(history[2]));
ASSERT_TRUE(history[4]->hasDependency(history[3]));
ASSERT_TRUE(history[3]->hasDependency(history[2]));
// The loads in the indices depend on the relevant input buffer.
ASSERT_TRUE(history[3]->hasDependency(history[1]));
ASSERT_TRUE(history[2]->hasDependency(history[0]));
// The load from B has the loop bounds.
ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
// The load from A has bounds B[0] to B[9].
ExprHandle loadFromB0 = Load::make(b, {0});
ExprHandle loadFromB9 = Load::make(b, {9});
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)}));
}
{
/* for (int x = 0; x < 10; x++) {
* C[B[x]] = A[x];
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make({For::make(
x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))});
stmt->accept(&analyzer);
/* 0. Input: B[(0, 99)] - dependents: 3
* 1. Input: A[(0, 99)] - dependents: 2
* 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4
* 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4
* 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5
* 5. Output: C[(0, 99)] - depends on: 4
*/
// Sanity check output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 6);
    // The store depends on both loads; neither load depends on the other.
ASSERT_TRUE(history[4]->hasDependency(history[2]));
ASSERT_TRUE(history[4]->hasDependency(history[3]));
ASSERT_FALSE(history[3]->hasDependency(history[2]));
ASSERT_FALSE(history[2]->hasDependency(history[3]));
    // The loads each depend on their relevant input (but the accesses appear
    // in a different order than in the last case).
ASSERT_TRUE(history[3]->hasDependency(history[0]));
ASSERT_TRUE(history[2]->hasDependency(history[1]));
// The load from B has the loop bounds.
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)}));
// And so does the load from A.
ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
}
{
/* for (int x = 0; x < 10; x++) {
* C[B[A[x]]] = x;
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
StmtPtr stmt = Block::make({For::make(
x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))});
stmt->accept(&analyzer);
/* 0. Input: B[(0, 99)] - dependents: 3
* 1. Input: A[(0, 99)] - dependents: 2
* 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4
* 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4
* 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5
* 5. Output: C[(0, 99)] - depends on: 4
*/
// Sanity check output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 6);
// The store depends on both loads.
ASSERT_TRUE(history[4]->hasDependency(history[2]));
ASSERT_TRUE(history[4]->hasDependency(history[3]));
// The outer load depends on the inner.
ASSERT_TRUE(history[3]->hasDependency(history[2]));
    // The loads each depend on their relevant input (but the accesses appear
    // in a different order than in the last case).
ASSERT_TRUE(history[3]->hasDependency(history[0]));
ASSERT_TRUE(history[2]->hasDependency(history[1]));
// The load from A has the loop bounds.
ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
    // The load from B has bounds A[0] to A[9].
ExprHandle loadFromA0 = Load::make(a, {0});
ExprHandle loadFromA9 = Load::make(a, {9});
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)}));
// The store has bounds of B[A[0]] to B[A[9]].
ExprHandle loadFromBA0 = Load::make(b, {loadFromA0});
ExprHandle loadFromBA9 = Load::make(b, {loadFromA9});
ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)}));
}
}
// Verify that multi-dimensional bounds work.
TEST(MemDependency, MemDependencyCheckerMultiDim) {
int M = 10, N = 9, K = 12;
BufHandle a("A", {M, N, K}, kInt);
BufHandle b("B", {M, N, K}, kInt);
BufHandle c("C", {M, K}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
using namespace analysis;
auto CB = [](ExprHandle s, ExprHandle e) {
return Bound(s.node(), e.node());
};
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
{
/* for (int x = 0; x < 10; x++) {
* for (int y = 0; y < 9; y++) {
* for (int z = 0; z < 12; z++) {
* B[x, y, z] = A[x, y, z];
* }
* }
* }
*/
// Full range.
MemDependencyChecker analyzer({a}, {b});
StmtPtr stmt = Block::make({For::make(
x,
0,
M,
For::make(
y,
0,
N,
For::make(
z,
0,
K,
Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
stmt->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
// 4 accesses: input, load, store, output.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 4);
// Simple chain from input to output.
ASSERT_TRUE(history[3]->hasDependency(history[2]));
ASSERT_TRUE(history[2]->hasDependency(history[1]));
ASSERT_TRUE(history[1]->hasDependency(history[0]));
ASSERT_TRUE(
EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
ASSERT_TRUE(
EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
}
{
/* for (int x = 0; x < 5; x++) {
* for (int y = 0; y < 5; y++) {
* for (int z = 0; z < 5; z++) {
* B[x, y, z] = A[x, y, z];
* }
* }
* }
*/
// Partial range.
MemDependencyChecker analyzer({a}, {b});
StmtPtr stmt = Block::make({For::make(
x,
0,
5,
For::make(
y,
0,
5,
For::make(
z,
0,
5,
Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
stmt->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
// 4 accesses: input, load, store, output.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 4);
// Simple chain from input to output.
ASSERT_TRUE(history[3]->hasDependency(history[2]));
ASSERT_TRUE(history[2]->hasDependency(history[1]));
ASSERT_TRUE(history[1]->hasDependency(history[0]));
ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
}
{
    /* for (int x = 0; x < 9; x++) {
* for (int y = 0; y < 12; y++) {
* B[x, 0, y] = A[x, 0, y];
* }
* }
*/
// Partial loops.
MemDependencyChecker analyzer({a}, {b});
StmtPtr stmt = Block::make({For::make(
x,
0,
N,
For::make(
y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))});
stmt->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
// 4 accesses: input, load, store, output.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 4);
// Simple chain from input to output.
ASSERT_TRUE(history[3]->hasDependency(history[2]));
ASSERT_TRUE(history[2]->hasDependency(history[1]));
ASSERT_TRUE(history[1]->hasDependency(history[0]));
ASSERT_TRUE(
EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
ASSERT_TRUE(
EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
}
{
/* for (int x = 0; x < 10; x++) {
* for (int y = 0; y < 100; y++) {
* for (int z = 0; z < 12; z++) {
* B[x, 0, z] = (A[x, 0, z]) + (C[x, z]);
* }
* }
* }
*/
// Loops that don't correspond to an index, bufs with different
// dimensionality.
MemDependencyChecker analyzer({a, c}, {b});
StmtPtr stmt = Block::make({For::make(
x,
0,
M,
For::make(
y,
0,
100,
For::make(
z,
0,
K,
Store::make(
b,
{x, 0, z},
Add::make(
Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))});
stmt->accept(&analyzer);
// Sanity test: Output depends on both inputs.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node()));
// 6 accesses: 2 inputs, 2 loads, store, output.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 6);
// Simple chain from input to output over the A buf.
// history[0] is the C input, history[3] is the load from C.
ASSERT_TRUE(history[5]->hasDependency(history[4]));
ASSERT_TRUE(history[4]->hasDependency(history[2]));
ASSERT_TRUE(history[2]->hasDependency(history[1]));
// The store also depends on the load from the C input.
ASSERT_TRUE(history[4]->hasDependency(history[3]));
ASSERT_TRUE(history[3]->hasDependency(history[0]));
// A Buf accesses.
ASSERT_TRUE(
EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
ASSERT_TRUE(
EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
// C buf access.
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)}));
}
{
    /* for (int x = 0; x < 10; x++) {
     *   for (int y = 0; y < 9; y++) {
     *     for (int z = 0; z < 12; z++) {
* B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]);
* }
* }
* }
*/
// Multi-dim reductions.
MemDependencyChecker analyzer({a}, {b});
StmtPtr stmt = Block::make({For::make(
x,
0,
M,
For::make(
y,
0,
N,
For::make(
z,
0,
K,
Store::make(
b,
{x, 0, 0},
Add::make(
Load::make(b, {x, y, z}),
Load::make(a, {x, y, z}))))))});
stmt->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
    // 5 accesses: input, 2 loads, store, output.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 5);
// Simple chain from input to output.
ASSERT_TRUE(history[4]->hasDependency(history[3]));
ASSERT_TRUE(history[3]->hasDependency(history[2]));
ASSERT_TRUE(history[3]->hasDependency(history[1]));
ASSERT_TRUE(history[2]->hasDependency(history[0]));
// The load from B depends on the store to B.
ASSERT_TRUE(history[1]->hasDependency(history[3]));
ASSERT_TRUE(
EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
ASSERT_TRUE(
EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)}));
}
}
// Various tests using the external Compute/Reduce API.
TEST(MemDependency, MemDependencyCheckerComputeAPI) {
using namespace analysis;
/* for (int m = 0; m < 4; m++) {
* for (int n = 0; n < 5; n++) {
* for (int k = 0; k < 6; k++) {
* broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]);
* }
* }
* }
* for (int m_1 = 0; m_1 < 4; m_1++) {
* for (int n_1 = 0; n_1 < 5; n_1++) {
* for (int k_1 = 0; k_1 < 6; k_1++) {
* d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1);
* }
* }
* }
*/
// Can determine if 2 loops created by Compute are dependent.
BufHandle a_buf("a", {4, 5}, kFloat);
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
Tensor d = Compute(
"d",
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return c.load(m, n, k) + 1;
});
LoopNest l({d}, {c, d});
MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
l.root_stmt()->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
// Second loop depends on first loop.
auto c_loop = l.getLoopStmtsFor(c)[0];
auto d_loop = l.getLoopStmtsFor(d)[0];
ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
}
TEST(MemDependency, MemDependencyCheckerComputeInline) {
using namespace analysis;
/* for (int m = 0; m < 4; m++) {
* for (int n = 0; n < 5; n++) {
* for (int k = 0; k < 6; k++) {
* d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1);
* }
* }
* }
*/
  // Check that inlining affects the number of accesses returned.
BufHandle a_buf("a", {4, 5}, kFloat);
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
Tensor d = Compute(
"d",
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return c.load(m, n, k) + 1;
});
LoopNest l({d}, {c, d});
l.computeInline(c.buf());
MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
l.root_stmt()->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
// broadcast_add tensor should not appear in trace at all.
for (auto& wi : analyzer.getHistory()) {
ASSERT_NE(wi->var(), c.buf()->base_handle());
}
}
TEST(MemDependency, MemDependencyCheckerComputeSplit) {
using namespace analysis;
// Split an axis, so the number of loops != the number of dimensions.
BufHandle a_buf("a", {4, 5}, kFloat);
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
LoopNest l({c});
MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
l.root_stmt()->accept(&analyzer_before);
l.splitWithTail(l.getLoopStmtsFor(c)[0], 2);
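  // splitWithTail with factor 2 replaces the outer loop with outer/inner
  // (plus tail) loops without changing which elements are accessed.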
MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
stmt->accept(&analyzer_after);
// Splitting should not change accesses at all.
auto history_before = analyzer_before.getHistory();
auto history_after = analyzer_after.getHistory();
ASSERT_EQ(history_before.size(), history_after.size());
for (size_t i = 0; i < history_before.size(); ++i) {
ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
ASSERT_EQ(
history_before[i]->bounds().size(), history_after[i]->bounds().size());
ASSERT_TRUE(indexBoundsEquals(
history_before[i]->bounds(), history_after[i]->bounds()));
ASSERT_EQ(
history_before[i]->dependencies().size(),
history_after[i]->dependencies().size());
ASSERT_EQ(
history_before[i]->dependents().size(),
history_after[i]->dependents().size());
}
}
TEST(MemDependency, MemDependencyCheckerComputeReorder) {
using namespace analysis;
// Reorder an axis, so the loop order doesn't match the indexing order.
BufHandle a_buf("a", {4, 5}, kFloat);
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
LoopNest l({c});
MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
l.root_stmt()->accept(&analyzer_before);
auto loops = l.getLoopStmtsFor(c);
l.reorderAxis(loops[0], loops[1]);
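  // reorderAxis swaps the two outermost loops; iteration order changes but
  // the set of accessed elements does not.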
MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
stmt->accept(&analyzer_after);
// Reordering should not change accesses at all.
auto history_before = analyzer_before.getHistory();
auto history_after = analyzer_after.getHistory();
ASSERT_EQ(history_before.size(), history_after.size());
for (size_t i = 0; i < history_before.size(); ++i) {
ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
ASSERT_EQ(
history_before[i]->bounds().size(), history_after[i]->bounds().size());
ASSERT_TRUE(indexBoundsEquals(
history_before[i]->bounds(), history_after[i]->bounds()));
ASSERT_EQ(
history_before[i]->dependencies().size(),
history_after[i]->dependencies().size());
ASSERT_EQ(
history_before[i]->dependents().size(),
history_after[i]->dependents().size());
}
}
TEST(MemDependency, MemDependencyCheckerComputeReduce) {
using namespace analysis;
/* for (int l2 = 0; l2 < 2; l2++) {
* for (int n1 = 0; n1 < 3; n1++) {
* for (int m1 = 0; m1 < 6; m1++) {
* scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]);
* }
* }
* }
* for (int l1 = 0; l1 < 2; l1++) {
* sum[l1] = float(0);
* for (int n1_1 = 0; n1_1 < 3; n1_1++) {
* for (int m1_1 = 0; m1_1 < 6; m1_1++) {
* sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)),
* out_args={l1}, reduce_args={n1, m1});
* }
* }
* }
*/
// Can determine dependencies of a Reduction.
BufHandle a("a", {2, 3, 6}, kFloat);
BufHandle b("b", {2, 3, 6}, kFloat);
Tensor c = Compute(
"scale",
{2, 3, 6},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6});
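  // Reduce sums "scale" over its two inner dimensions (extents 3 and 6),
  // producing the 1-D tensor "sum".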
LoopNest l({d}, {c, d});
MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()});
l.root_stmt()->accept(&analyzer);
// Sanity test: Output depends on input.
ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node()));
// Second loop depends on first loop.
auto c_loop = l.getLoopStmtsFor(c)[0];
auto d_loop = l.getLoopStmtsFor(d)[0];
ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
// Reduction depends on both inputs.
auto reduces = NodeFinder<ReduceOp>::find(l.root_stmt());
ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node()));
ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node()));
}
TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
int M = 1024;
int N = 1024;
int K = 2048;
using namespace analysis;
BufHandle AP("A", {M, K}, kFloat);
BufHandle BP("B", {K, N}, kFloat);
Tensor CT = Reduce(
"gemm",
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{K});
LoopNest loop({CT});
{
auto const& loops = loop.getLoopStmtsFor(CT);
ForPtr m = loops[0];
loop.splitWithMask(m, 4);
}
{
auto const& loops = loop.getLoopStmtsFor(CT);
ForPtr n = loops[2];
loop.splitWithMask(n, 16);
}
// mo, mi, no, ni, k ->
// mo, no, mi, ni, k
{
auto const& loops = loop.getLoopStmtsFor(CT);
ForPtr mi = loops[1];
ForPtr no = loops[2];
loop.reorderAxis(mi, no);
}
// mo, no, mi, ni, k ->
// mo, no, mi, k, ni
{
auto const& loops = loop.getLoopStmtsFor(CT);
ForPtr ni = loops[3];
ForPtr k = loops[4];
loop.reorderAxis(ni, k);
}
// mo, no, mi, k, ni ->
// mo, no, k, mi, ni
{
auto const& loops = loop.getLoopStmtsFor(CT);
ForPtr mi = loops[2];
ForPtr k = loops[3];
loop.reorderAxis(mi, k);
}
{
auto const& loops = loop.getLoopStmtsFor(CT);
loop.cacheAccesses(CT.buf(), "C_regs", loops[2]);
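    // cacheAccesses stages reads and writes of the output through a temporary
    // "C_regs" buffer; after lowering this shows up as an Alloc/Free pair.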
}
MemDependencyChecker analyzer_unlowered(
loop.getInputBufs(), loop.getOutputBufs());
MemDependencyChecker analyzer_lowered(
loop.getInputBufs(), loop.getOutputBufs());
// Test both unlowered and lowered form.
{
StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
stmt->accept(&analyzer_unlowered);
// Outputs depend on inputs.
ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node()));
ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node()));
// The last write to gemm should cover the total bound of the output.
std::shared_ptr<AccessInfo> outputAccess =
analyzer_unlowered.output(CT.buf());
// A single dependency.
ASSERT_EQ(outputAccess->dependencies().size(), 1);
// dependencies is a set with 1 element, so can just deref begin().
std::shared_ptr<AccessInfo> gemmStore =
outputAccess->dependencies().begin()->second;
    // Check it's a store.
ASSERT_EQ(gemmStore->type(), AccessType::Store);
ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds()));
    // Likewise, the first read from each input covers the entire range of
    // the input.
auto aInput = analyzer_unlowered.input(AP.node());
auto bInput = analyzer_unlowered.input(BP.node());
// A single dependent each.
ASSERT_EQ(aInput->dependents().size(), 1);
ASSERT_EQ(bInput->dependents().size(), 1);
// They're both loads.
std::shared_ptr<AccessInfo> aLoad = aInput->dependents().begin()->second;
std::shared_ptr<AccessInfo> bLoad = bInput->dependents().begin()->second;
ASSERT_EQ(aLoad->type(), AccessType::Load);
ASSERT_EQ(bLoad->type(), AccessType::Load);
ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds()));
ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds()));
}
loop.prepareForCodegen();
SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT});
  // Now check the lowered dependency graph.
{
StmtPtr stmt = IRSimplifier::simplify(cg.stmt());
stmt->accept(&analyzer_lowered);
// Lowering will change the dimensionality of all bounds due to index
// flattening and will insert Allocates and Frees.
auto history_before = analyzer_unlowered.getHistory();
auto history_after = analyzer_lowered.getHistory();
ASSERT_EQ(history_before.size() + 2, history_after.size());
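    // The two extra accesses are the Alloc and Free of the C_regs cache.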
    // Filter out the Alloc and Free accesses.
auto isAllocFree = [](const auto& info) {
return info->type() == AccessType::Alloc ||
info->type() == AccessType::Free;
};
history_after.erase(
std::remove_if(history_after.begin(), history_after.end(), isAllocFree),
history_after.end());
ASSERT_EQ(history_before.size(), history_after.size());
for (size_t i = 0; i < history_before.size(); ++i) {
ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
if (history_before[i]->dependencies().size() !=
history_after[i]->dependencies().size()) {
        // The extra dependency must be an Alloc.
ASSERT_TRUE(std::any_of(
history_after[i]->dependencies().begin(),
history_after[i]->dependencies().end(),
[](const auto& pair) {
return pair.second->type() == AccessType::Alloc;
}));
ASSERT_EQ(
history_before[i]->dependencies().size() + 1,
history_after[i]->dependencies().size());
}
if (history_before[i]->dependents().size() !=
history_after[i]->dependents().size()) {
        // The extra dependent must be a Free.
ASSERT_TRUE(std::any_of(
history_after[i]->dependents().begin(),
history_after[i]->dependents().end(),
[](const auto& pair) {
return pair.second->type() == AccessType::Free;
}));
ASSERT_EQ(
history_before[i]->dependents().size() + 1,
history_after[i]->dependents().size());
}
// Inputs and outputs are not flattened, only accesses.
if (history_before[i]->type() == AccessType::Input ||
history_before[i]->type() == AccessType::Output) {
ASSERT_EQ(
history_before[i]->bounds().size(),
history_after[i]->bounds().size());
ASSERT_TRUE(indexBoundsEquals(
history_before[i]->bounds(), history_after[i]->bounds()));
} else {
ASSERT_EQ(history_after[i]->bounds().size(), 1);
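        // The flattened extent must equal the product of the per-dimension
        // extents (end + 1 in each dimension).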
ExprPtr flat_bounds = alloc<IntImm>(1);
for (auto& b : history_before[i]->bounds()) {
flat_bounds =
alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1)));
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start));
}
flat_bounds = IRSimplifier::simplify(flat_bounds);
ExprPtr after_bounds = IRSimplifier::simplify(
alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1)));
ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
}
}
}
}
} // namespace jit
} // namespace torch