[nnc] Insert alloc/free at global scope (#61725)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61725

Alloc/free inside a loop isn't really an optimization, and furthermore
it breaks an attempted optimization in the LLVM backend: we use alloca for
small allocations, which is efficient since alloca is on the stack, but there's
no corresponding free, so we leak tons of stack.  I hit this while building an
rfactor buffer inside a very deeply nested loop.
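
For intuition, a minimal standalone sketch of the failure mode (illustrative
C++, not code from this PR): alloca'd memory is reclaimed only when the
enclosing function returns, so a per-iteration alloca grows the stack on
every trip through the loop.

  #include <alloca.h>

  float sum_with_temp(const float* a, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; i++) {
      // A fresh stack slot is reserved each iteration; nothing releases it
      // until sum_with_temp returns, so a long loop can overflow the stack.
      float* tmp = static_cast<float*>(alloca(16 * sizeof(float)));
      tmp[0] = a[i];
      acc += tmp[0];
    }
    return acc;
  }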
ghstack-source-id: 133627310

Test Plan:
Unit test which simulates use of a temp buffer in a deeply nested
loop.

Reviewed By: navahgar

Differential Revision: D29533364

fbshipit-source-id: c321f4cb05304cfb9146afe32edc4567b623412e
Author: Bert Maher
Date: 2021-07-16 08:35:22 -07:00
Committed by: Facebook GitHub Bot
Parent: 4c3d9cfe03
Commit: b963607d50

4 changed files with 44 additions and 28 deletions

diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp

@@ -1863,8 +1863,8 @@ TEST(LoopNest, LoopNestComputeAt_1) {
   Stmt* s = l.root_stmt();
   checkIR(s, R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[1]
 # CHECK: for (int i_b = 0; i_b < N; i_b++)
-# CHECK: Allocate(temp); // dtype=int, dims=[1]
 # CHECK: temp[
 # CHECK-NOT: A[
 # CHECK: B[i_b] = temp[0]
@@ -1931,15 +1931,16 @@ TEST(LoopNest, LoopNestComputeAt_2) {
   Stmt* s = l.root_stmt();
   // Check the IR we produced
   std::clog << *s << "\n";
   checkIR(s, R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
 # CHECK: for (int cy = 0; cy < H; cy++)
-# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
 # CHECK: for
 # CHECK: for
 # CHECK: for (int cx = 0; cx < W; cx++)
 # CHECK-NOT: prod[
 # CHECK: cons[
-# CHECK: Free(temp))IR");
+# CHECK: Free(temp))IR");
 
   // Now check that the loop still produces the correct result.
   std::vector<int> c_data(kW * kH, 0);
@@ -1958,14 +1959,14 @@ TEST(LoopNest, LoopNestComputeAt_2) {
   // Check the IR we produced
   checkIR(s, R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
 # CHECK: for (int cy = 0; cy < H; cy++)
 # CHECK: for (int cx = 0; cx < W; cx++)
-# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
 # CHECK: for
 # CHECK: for
 # CHECK-NOT: prod[
 # CHECK: cons[
-# CHECK: Free(temp))IR");
+# CHECK: Free(temp))IR");
 
   // Now check that the loop still produces the correct result.
   std::vector<int> c_data(kW * kH, 0);
@@ -2032,6 +2033,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
   // Check the IR we produced
   checkIR(s, R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
 # CHECK: for (int ay = 0; ay < H + 1; ay++)
 # CHECK: for (int ax = 0; ax < W + 1; ax++)
 # CHECK: A[
@@ -2042,7 +2044,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
 # CHECK: for (int cx = 0; cx < W; cx++)
 # CHECK: C[
 # CHECK: for (int dy = 0; dy < H; dy++)
-# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
 # CHECK: for (int dx = 0; dx < W; dx++)
 # CHECK-NOT: A[)IR");
@@ -2063,6 +2064,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
   // Check the IR we produced
   checkIR(s, R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
 # CHECK: for (int ay = 0; ay < H + 1; ay++)
 # CHECK: for (int ax = 0; ax < W + 1; ax++)
 # CHECK: A[
@@ -2074,7 +2076,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
 # CHECK: C[
 # CHECK: for (int dy = 0; dy < H; dy++)
 # CHECK: for (int dx = 0; dx < W; dx++)
-# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
 # CHECK-NOT: A[)IR");
 
   // Now check that the loop still produces the correct result.
@@ -2143,8 +2144,8 @@ TEST(LoopNest, Reduce2dComputeAt) {
   l.eliminateDeadStores();
   l.prepareForCodegen();
   checkIR(l.root_stmt(), R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
 # CHECK: for (int cy = 0; cy < H; cy++) {
-# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
 # CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
 # CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) {
 # CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + cy) * (idx1 + 0);
@@ -2158,8 +2159,8 @@ TEST(LoopNest, Reduce2dComputeAt) {
 # CHECK: }
 # CHECK: }
 # CHECK: }
-# CHECK: Free(temp);
 # CHECK: }
+# CHECK: Free(temp);
 )IR");
 
   Stmt* s = l.root_stmt();
@@ -2178,9 +2179,9 @@ TEST(LoopNest, Reduce2dComputeAt) {
   l.eliminateDeadStores();
   l.prepareForCodegen();
   checkIR(l.root_stmt(), R"IR(
+# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
 # CHECK: for (int cy = 0; cy < H; cy++) {
 # CHECK: for (int cx = 0; cx < W; cx++) {
-# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
 # CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
 # CHECK: for (int idx1 = 0; idx1 < 2; idx1++) {
 # CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (cy + idx0) * (cx + idx1);
@@ -2192,9 +2193,9 @@ TEST(LoopNest, Reduce2dComputeAt) {
 # CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * 2)) + s * 1]);
 # CHECK: }
 # CHECK: }
-# CHECK: Free(temp);
 # CHECK: }
 # CHECK: }
+# CHECK: Free(temp);
 )IR");
 
   Stmt* s = l.root_stmt();
@@ -3758,26 +3759,26 @@ TEST(LoopNest, CacheReadsSimple) {
   // just this once: verify the whole thing.
   checkIR(result, R"IR(
 #CHECK: Allocate(A); // dtype=int, dims=[64, 64]
+#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
 #CHECK: for (int i
 #CHECK: for (int j
 #CHECK: A[
 #CHECK: }
 #CHECK: }
 #CHECK: for (int i_1
-#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
 #CHECK: for (int j_1
 #CHECK: A_local[j_1] = A[
 #CHECK: }
 #CHECK: for (int j_2
 #CHECK: B[10 * i_1 + j_2] = A_local[j_2];
 #CHECK: }
-#CHECK: Free(A_local);
 #CHECK: }
 #CHECK: for (int i_2
 #CHECK: for (int j_3
 #CHECK: C[
 #CHECK: }
 #CHECK: }
+#CHECK: Free(A_local);
 #CHECK: Free(A);
 )IR");

diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp

@@ -1548,16 +1548,16 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
   oss << *result;
   const std::string& expected_ir =
       R"IR(
+#CHECK: Allocate(d_local); // dtype=float, dims=[1]
 #CHECK: sum[l1] = 0
 #CHECK: for (int n1
-#CHECK: Allocate(d_local); // dtype=float, dims=[1]
 #CHECK: d_local[0] = 0
 #CHECK: for (int m1
 #CHECK: d_local[0] = (d_local[0]) + (scale[
 #CHECK: }
 #CHECK: sum[l1] = (sum[l1]) + (d_local[0])
-#CHECK: Free(d_local);
 #CHECK: }
+#CHECK: Free(d_local);
 #CHECK-NOT: d_local
 )IR";
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
@@ -1621,8 +1621,8 @@ TEST(Reductions, ReductionCacheBodyAccess) {
 #CHECK: for (int k = 0; k < 12; k++) {
 #CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j];
 #CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]);
-#CHECK: Free(scale_local);
 #CHECK: scale_1[l] = (b[l]) * (sum[l]);
+#CHECK: Free(scale_local);
 )IR";
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
 }
@@ -1660,8 +1660,8 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
   oss << *result;
   const std::string& expected_ir =
       R"IR(
-#CHECK: sum[l1] = (sum[l1]) + (scale[
 #CHECK: Allocate(sum_local); // dtype=float, dims=[4]
+#CHECK: sum[l1] = (sum[l1]) + (scale[
 #CHECK: for (int i = 0; i < 4
 #CHECK: sum_local[i] = sum[i + 4 * l_outer];
 #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
@@ -1709,8 +1709,8 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
   oss << *result;
   const std::string& expected_ir =
       R"IR(
-#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]);
 #CHECK: Allocate(sum_local); // dtype=float, dims=[4]
+#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]);
 #CHECK: for (int i = 0; i < 4
 #CHECK: sum_local[i] = sum[i + 4 * l_outer];
 #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
@@ -1759,8 +1759,8 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
   oss << *result;
   const std::string& expected_ir =
       R"IR(
-#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]);
 #CHECK: Allocate(sum_local); // dtype=float, dims=[4]
+#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]);
 #CHECK: for (int i = 0; i < 4
 #CHECK: sum_local[i] = sum[i + 4 * l_outer];
 #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
@@ -1815,8 +1815,8 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
   const std::string& expected_ir =
       R"IR(
 #CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
+#CHECK: Allocate(tmp); // dtype=float, dims=[n]
 #CHECK: for (int a = 0; a < m
-#CHECK: Allocate(tmp); // dtype=float, dims=[n]
 #CHECK: for (int i = 0; i < n
 #CHECK: tmp[i] = 0
 #CHECK: }
@@ -1886,9 +1886,9 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
   const std::string& expected_ir =
       R"IR(
 #CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
+#CHECK: Allocate(tmp); // dtype=float, dims=[1]
 #CHECK: for (int a = 0; a < m
 #CHECK: for (int b = 0; b < n
-#CHECK: Allocate(tmp); // dtype=float, dims=[1]
 #CHECK: tmp[0] = 0
 #CHECK: for (int c
 #CHECK: tmp[0] = (tmp[0]) + (B[

diff --git a/test/test_tensorexpr_pybind.py b/test/test_tensorexpr_pybind.py

@@ -397,6 +397,25 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
     def test_forgot_kernel_arena(self):
         self.assertRaises(RuntimeError, lambda: torch._C._te.VarHandle("n", torch._C._te.Dtype.Int))
 
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_alloc_in_loop(self):
+        with kernel_arena_scope():
+            a, tmp, b = [
+                te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
+                for name in ["a", "tmp", "b"]]
+            t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
+            body = te.Block([
+                tmp.store([t0], a.load([t0])),
+                b.store([t0], tmp.load([t0]))
+            ])
+            for _ in range(4):
+                i = te.VarHandle("i", te.Dtype.Int)
+                body = te.For.make(i, t0, t100, body)
+            nest = te.LoopNest(body, [b.data()])
+            nest.prepare_for_codegen()
+            f = te.construct_codegen("llvm", nest.simplify(), [a, b])
+            ta, tb = [torch.ones(1) for _ in range(2)]
+            f.call([ta.data_ptr(), tb.data_ptr()])
+
 if __name__ == '__main__':
     run_tests()

diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp

@@ -897,14 +897,10 @@ Stmt* LoopNest::insertAllocFree(Stmt* stmt) {
   std::unordered_map<const Buf*, std::vector<BufLoadOrStoreUse>> uses =
       findLoadOrStoreUses(stmt);
 
-  // Insert allocations and frees for temporary buffers in the innermost
-  // possible scope.
+  // Insert allocations and frees for temporary buffers at global scope.
   for (const Buf* buf : intermediate_bufs) {
-    Stmt* alloc = new Allocate(buf);
-    Stmt* free = new Free(buf);
-    Block* alloc_block = findLowestContainingBlock(uses.at(buf));
-    alloc_block->prepend_stmt(alloc);
-    alloc_block->append_stmt(free);
+    b->prepend_stmt(new Allocate(buf));
+    b->append_stmt(new Free(buf));
   }
 
   return b;
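
For reference, the effect of this change on a simple lowered nest looks
roughly like the following (a hand-written sketch of the IR shapes, not
actual codegen output):

  // Before: alloc/free in the lowest block containing the buffer's uses,
  // i.e. re-executed on every iteration of the surrounding loop.
  for (int i = 0; i < N; i++) {
    Allocate(temp); // dtype=float, dims=[1]
    temp[0] = a[i];
    b[i] = temp[0];
    Free(temp);
  }

  // After: alloc/free hoisted to the enclosing (global) block, so the LLVM
  // backend's alloca lowering reserves the stack slot exactly once.
  Allocate(temp); // dtype=float, dims=[1]
  for (int i = 0; i < N; i++) {
    temp[0] = a[i];
    b[i] = temp[0];
  }
  Free(temp);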