mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nnc] Insert alloc/free at global scope (#61725)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61725

Allocating and freeing a temporary buffer inside a loop isn't really an optimization, and furthermore it breaks an attempted optimization in the LLVM backend: we use alloca for small allocations, which is efficient because alloca memory lives on the stack, but there is no corresponding free, so an alloca placed inside a loop leaks a large amount of stack space. I hit this while building an rfactor buffer inside a very deeply nested loop.

ghstack-source-id: 133627310

Test Plan: Unit test that simulates the use of a temporary buffer in a deeply nested loop.

Reviewed By: navahgar

Differential Revision: D29533364

fbshipit-source-id: c321f4cb05304cfb9146afe32edc4567b623412e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4c3d9cfe03
commit
b963607d50
@ -1863,8 +1863,8 @@ TEST(LoopNest, LoopNestComputeAt_1) {
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1]
|
||||
# CHECK: for (int i_b = 0; i_b < N; i_b++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1]
|
||||
# CHECK: temp[
|
||||
# CHECK-NOT: A[
|
||||
# CHECK: B[i_b] = temp[0]
|
||||
@ -1931,15 +1931,16 @@ TEST(LoopNest, LoopNestComputeAt_2) {
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
// Check the IR we produced
|
||||
std::clog << *s << "\n";
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
||||
# CHECK: for (int cy = 0; cy < H; cy++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
||||
# CHECK: for
|
||||
# CHECK: for
|
||||
# CHECK: for (int cx = 0; cx < W; cx++)
|
||||
# CHECK-NOT: prod[
|
||||
# CHECK: cons[
|
||||
# CHECK: Free(temp))IR");
|
||||
# CHECK: Free(temp))IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> c_data(kW * kH, 0);
|
||||
@ -1958,14 +1959,14 @@ TEST(LoopNest, LoopNestComputeAt_2) {
|
||||
|
||||
// Check the IR we produced
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
||||
# CHECK: for (int cy = 0; cy < H; cy++)
|
||||
# CHECK: for (int cx = 0; cx < W; cx++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
||||
# CHECK: for
|
||||
# CHECK: for
|
||||
# CHECK-NOT: prod[
|
||||
# CHECK: cons[
|
||||
# CHECK: Free(temp))IR");
|
||||
# CHECK: Free(temp))IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
std::vector<int> c_data(kW * kH, 0);
|
||||
@ -2032,6 +2033,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
|
||||
// Check the IR we produced
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
|
||||
# CHECK: for (int ay = 0; ay < H + 1; ay++)
|
||||
# CHECK: for (int ax = 0; ax < W + 1; ax++)
|
||||
# CHECK: A[
|
||||
@ -2042,7 +2044,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
# CHECK: for (int cx = 0; cx < W; cx++)
|
||||
# CHECK: C[
|
||||
# CHECK: for (int dy = 0; dy < H; dy++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
|
||||
# CHECK: for (int dx = 0; dx < W; dx++)
|
||||
# CHECK-NOT: A[)IR");
|
||||
|
||||
@ -2063,6 +2064,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
|
||||
// Check the IR we produced
|
||||
checkIR(s, R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
|
||||
# CHECK: for (int ay = 0; ay < H + 1; ay++)
|
||||
# CHECK: for (int ax = 0; ax < W + 1; ax++)
|
||||
# CHECK: A[
|
||||
@ -2074,7 +2076,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
|
||||
# CHECK: C[
|
||||
# CHECK: for (int dy = 0; dy < H; dy++)
|
||||
# CHECK: for (int dx = 0; dx < W; dx++)
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
|
||||
# CHECK-NOT: A[)IR");
|
||||
|
||||
// Now check that the loop still produces the correct result.
|
||||
@ -2143,8 +2144,8 @@ TEST(LoopNest, Reduce2dComputeAt) {
|
||||
l.eliminateDeadStores();
|
||||
l.prepareForCodegen();
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
||||
# CHECK: for (int cy = 0; cy < H; cy++) {
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
|
||||
# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
|
||||
# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) {
|
||||
# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + cy) * (idx1 + 0);
|
||||
@ -2158,8 +2159,8 @@ TEST(LoopNest, Reduce2dComputeAt) {
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: Free(temp);
|
||||
# CHECK: }
|
||||
# CHECK: Free(temp);
|
||||
)IR");
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
@ -2178,9 +2179,9 @@ TEST(LoopNest, Reduce2dComputeAt) {
|
||||
l.eliminateDeadStores();
|
||||
l.prepareForCodegen();
|
||||
checkIR(l.root_stmt(), R"IR(
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
||||
# CHECK: for (int cy = 0; cy < H; cy++) {
|
||||
# CHECK: for (int cx = 0; cx < W; cx++) {
|
||||
# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
|
||||
# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
|
||||
# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) {
|
||||
# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (cy + idx0) * (cx + idx1);
|
||||
@ -2192,9 +2193,9 @@ TEST(LoopNest, Reduce2dComputeAt) {
|
||||
# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * 2)) + s * 1]);
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: Free(temp);
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: Free(temp);
|
||||
)IR");
|
||||
Stmt* s = l.root_stmt();
|
||||
|
||||
@ -3758,26 +3759,26 @@ TEST(LoopNest, CacheReadsSimple) {
|
||||
// just this once: verify the whole thing.
|
||||
checkIR(result, R"IR(
|
||||
#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
|
||||
#CHECK: for (int i
|
||||
#CHECK: for (int j
|
||||
#CHECK: A[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: for (int i_1
|
||||
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
|
||||
#CHECK: for (int j_1
|
||||
#CHECK: A_local[j_1] = A[
|
||||
#CHECK: }
|
||||
#CHECK: for (int j_2
|
||||
#CHECK: B[10 * i_1 + j_2] = A_local[j_2];
|
||||
#CHECK: }
|
||||
#CHECK: Free(A_local);
|
||||
#CHECK: }
|
||||
#CHECK: for (int i_2
|
||||
#CHECK: for (int j_3
|
||||
#CHECK: C[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: Free(A_local);
|
||||
#CHECK: Free(A);
|
||||
)IR");
|
||||
|
||||
|
@ -1548,16 +1548,16 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
||||
#CHECK: sum[l1] = 0
|
||||
#CHECK: for (int n1
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
||||
#CHECK: d_local[0] = 0
|
||||
#CHECK: for (int m1
|
||||
#CHECK: d_local[0] = (d_local[0]) + (scale[
|
||||
#CHECK: }
|
||||
#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK: }
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK-NOT: d_local
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
@ -1621,8 +1621,8 @@ TEST(Reductions, ReductionCacheBodyAccess) {
|
||||
#CHECK: for (int k = 0; k < 12; k++) {
|
||||
#CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j];
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]);
|
||||
#CHECK: Free(scale_local);
|
||||
#CHECK: scale_1[l] = (b[l]) * (sum[l]);
|
||||
#CHECK: Free(scale_local);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
}
|
||||
@ -1660,8 +1660,8 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[
|
||||
#CHECK: Allocate(sum_local); // dtype=float, dims=[4]
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
@ -1709,8 +1709,8 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]);
|
||||
#CHECK: Allocate(sum_local); // dtype=float, dims=[4]
|
||||
#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]);
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
@ -1759,8 +1759,8 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
|
||||
oss << *result;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]);
|
||||
#CHECK: Allocate(sum_local); // dtype=float, dims=[4]
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]);
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
@ -1815,8 +1815,8 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
|
||||
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
|
||||
#CHECK: for (int a = 0; a < m
|
||||
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
|
||||
#CHECK: for (int i = 0; i < n
|
||||
#CHECK: tmp[i] = 0
|
||||
#CHECK: }
|
||||
@ -1886,9 +1886,9 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
|
||||
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
|
||||
#CHECK: for (int a = 0; a < m
|
||||
#CHECK: for (int b = 0; b < n
|
||||
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
|
||||
#CHECK: tmp[0] = 0
|
||||
#CHECK: for (int c
|
||||
#CHECK: tmp[0] = (tmp[0]) + (B[
|
||||
|
@ -397,6 +397,25 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
|
||||
def test_forgot_kernel_arena(self):
|
||||
self.assertRaises(RuntimeError, lambda: torch._C._te.VarHandle("n", torch._C._te.Dtype.Int))
|
||||
|
||||
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
|
||||
def test_alloc_in_loop(self):
|
||||
with kernel_arena_scope():
|
||||
a, tmp, b = [
|
||||
te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
|
||||
for name in ["a", "tmp", "b"]]
|
||||
t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
|
||||
body = te.Block([
|
||||
tmp.store([t0], a.load([t0])),
|
||||
b.store([t0], tmp.load([t0]))
|
||||
])
|
||||
for _ in range(4):
|
||||
i = te.VarHandle("i", te.Dtype.Int)
|
||||
body = te.For.make(i, t0, t100, body)
|
||||
nest = te.LoopNest(body, [b.data()])
|
||||
nest.prepare_for_codegen()
|
||||
f = te.construct_codegen("llvm", nest.simplify(), [a, b])
|
||||
ta, tb = [torch.ones(1) for _ in range(2)]
|
||||
f.call([ta.data_ptr(), tb.data_ptr()])
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -897,14 +897,10 @@ Stmt* LoopNest::insertAllocFree(Stmt* stmt) {
|
||||
|
||||
std::unordered_map<const Buf*, std::vector<BufLoadOrStoreUse>> uses =
|
||||
findLoadOrStoreUses(stmt);
|
||||
// Insert allocations and frees for temporary buffers in the innermost
|
||||
// possible scope.
|
||||
// Insert allocations and frees for temporary buffers at global scope.
|
||||
for (const Buf* buf : intermediate_bufs) {
|
||||
Stmt* alloc = new Allocate(buf);
|
||||
Stmt* free = new Free(buf);
|
||||
Block* alloc_block = findLowestContainingBlock(uses.at(buf));
|
||||
alloc_block->prepend_stmt(alloc);
|
||||
alloc_block->append_stmt(free);
|
||||
b->prepend_stmt(new Allocate(buf));
|
||||
b->append_stmt(new Free(buf));
|
||||
}
|
||||
|
||||
return b;
|
||||
|
Reference in New Issue
Block a user