Fix unbacked symint and memory leak in inductor memory planning (#159839)

Summary:

In memory planning, some allocation sizes involve unbacked symints. These unbacked symints are not known until they are computed at runtime, so **allocation pools that involve unbacked symints cannot be allocated until we have the values of those unbacked symints**.

We therefore add a notion of `earliest_available` to Allocation nodes. If an allocation node involves an unbacked symint, it becomes available only when its live range begins.

Then, in AllocationPool, if a pool contains an Allocation node that has an earliest available time, we restrict the pool's live range.

If a block's earliest available time is later than the start of a pool's live range, we cannot allocate the block from that pool.
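
A minimal sketch of that gating check, using simplified stand-in types rather than the real inductor `Allocation`/`AllocationPool` classes:

```python
# Simplified sketch of the "earliest available" gating described above.
# SimpleLiveRange / SimpleBlock / can_allocate_from_pool are illustrative
# stand-ins, not the actual torch._inductor memory-planning classes.
import dataclasses
from typing import Optional


@dataclasses.dataclass
class SimpleLiveRange:
    begin: float
    end: float


@dataclasses.dataclass
class SimpleBlock:
    live_range: SimpleLiveRange
    has_unbacked_symint: bool = False

    @property
    def earliest_available(self) -> Optional[float]:
        # A block whose size depends on an unbacked symint can only be
        # materialized once its live range begins, because that is when the
        # symint's runtime value is known.
        return self.live_range.begin if self.has_unbacked_symint else None


def can_allocate_from_pool(pool_begin: float, block: SimpleBlock) -> bool:
    # A block that only becomes available after the pool's live range has
    # started must not be placed in that pool.
    ea = block.earliest_available
    return ea is None or ea <= pool_begin
```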

We also fix a memory leak caused by allocating a tensor without wrapping it in an RAIIAtenTensorHandle.

In the Python wrapper for JIT Inductor, `codegen_alloc_from_pool` doesn't actually write the alloc lines to the wrapper; it just returns the allocation string. However, in cpp_wrapper, `codegen_alloc_from_pool` actually writes to the wrapper. Specifically, it writes the following and returns the string `RAIIAtenTensorHandle(...)`.

```
AtenTensorHandle handle_name;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(...));
```

This is bug-prone. **If you write the aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle as well**, otherwise you get memory leaks.

We also remove the alloc_from_pool call from codegen_create, because it doesn't work for AOTI: the Python wrapper can generate the same alloc_from_pool variable name for the same block, but cpp_wrapper generates a different variable name for each call to alloc_from_pool.
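
With this change, `codegen_alloc_from_pool` returns the allocation lines together with the RAII expression instead of writing them itself, so the caller emits both in one place. A rough sketch of the caller-side pattern (the helper and buffer names here are illustrative, not the exact inductor code):

```python
# Illustrative caller-side pattern: codegen_alloc_from_pool now returns
# (raii_expr, lines_to_write) rather than writing into the wrapper, so the
# caller always emits the raw alloc lines and the RAII wrapper together.
def emit_alloc_from_pool(code, wrapper, allocation, buf_name):
    # For cpp_wrapper, raii_expr looks like "RAIIAtenTensorHandle(tmp_tensor_handle_0)"
    # and lines_to_write contains the "AtenTensorHandle tmp_tensor_handle_0;" declaration
    # plus the AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(...)) call.
    raii_expr, lines_to_write = allocation.codegen_alloc_from_pool(wrapper)
    code.writelines(lines_to_write)
    # Binding the handle to an RAIIAtenTensorHandle right away is what prevents
    # the leak; buf_name is a hypothetical generated buffer name.
    code.writeline(f"auto {buf_name} = {raii_expr};")
```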

Test Plan:
```
 python test/inductor/test_memory_planning.py
```

Rollback Plan:

Differential Revision: D79603119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159839
Approved by: https://github.com/jansel
Author: Shangdi Yu
Date: 2025-08-11 17:16:15 +00:00
Committed by: PyTorch MergeBot
Parent: ca7315c171
Commit: 9ccd0f5e31
4 changed files with 117 additions and 20 deletions


@@ -24,6 +24,14 @@ from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim
try:
from .test_aot_inductor import AOTIRunnerUtil
except ImportError:
from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
AOTIRunnerUtil,
)
@requires_gpu()
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
@@ -76,13 +84,6 @@ class TestMemoryPlanning(TestCase):
@skipIfXpu(msg="aoti doesn't work on XPU")
def test_aoti(self):
try:
from .test_aot_inductor import AOTIRunnerUtil
except ImportError:
from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
AOTIRunnerUtil,
)
f, args = self._generate(device=GPU_TYPE)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = ({0: dim0_x}, None, None)
@@ -103,6 +104,54 @@ class TestMemoryPlanning(TestCase):
).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code)
self.assertTrue(same(f(*args), result))
@config.patch({"triton.autotune_at_compile_time": False})
def test_unbacked_symint(self):
# when allocation's size has unbacked symints
# the unbacked symints are only available after computed
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Repro(torch.nn.Module):
def forward(self, x, y):
x = x + 1
u0 = x.item()
torch._check(u0 >= 1)
s0 = y.size(0)
expr = u0 * s0
sevens = torch.empty_strided(
size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device
).fill_(7)
return sevens * 3
example_inputs = (
torch.scalar_tensor(2, dtype=torch.int, device=self.device),
torch.ones(8, device=self.device),
)
model = Repro().to(self.device)
result, code = run_and_get_cpp_code(
lambda: AOTIRunnerUtil.run(model, example_inputs)
)
self.assertTrue(same(model(*example_inputs), result))
# check allocation is done after the unbacked symint is computed
FileCheck().check("auto u0 = u0_raw;").check(
"const int64_t int_array_2[] = {10L, 8L*u0, 32L};"
).check("AtenTensorHandle pool0_handle;").check(
"aoti_torch_empty_strided(3, int_array_2, int_array_3"
).run(code)
# all AtenTensorHandle allocated using aoti_torch__alloc_from_pool are wrapped with RAIIAtenTensorHandle
# otherwise we'll have memory leak
FileCheck().check_count(
"aoti_torch__alloc_from_pool(pool1", 1, exactly=True
).check_count("aoti_torch__alloc_from_pool(pool0", 1, exactly=True).run(code)
FileCheck().check(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_int32, 0, int_array_1, int_array_1, &tmp_tensor_handle_0));" # noqa: B950
).check("RAIIAtenTensorHandle(tmp_tensor_handle_0);").check(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));" # noqa: B950
).check("RAIIAtenTensorHandle(tmp_tensor_handle_1);").run(code)
if __name__ == "__main__":
if HAS_GPU:


@@ -1651,7 +1651,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
return f"RAIIAtenTensorHandle {name}({handle_name});"
def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
def codegen_alloc_from_pool(
self, name, offset, dtype, shape, stride
) -> tuple[str, list[str]]:
size = self.codegen_shape_tuple(shape)
stride = self.codegen_shape_tuple(stride)
tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
@@ -1668,11 +1670,14 @@ class CppWrapperCpu(PythonWrapperCodegen):
),
f"&{tmp_name}",
]
self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
self.wrapper_call.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
)
return f"RAIIAtenTensorHandle({tmp_name})"
# We return the lines instead of writing here because writing here is bug prone.
# If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle
# as well, otherwise you get memory leaks
allocations_to_write = [
f"AtenTensorHandle {tmp_name};",
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));",
]
return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write
def codegen_reinterpret_view(
self,


@@ -10,6 +10,7 @@ from typing import Any, Optional, Protocol, TYPE_CHECKING
import sympy
import torch
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
from .. import config
@@ -142,6 +143,17 @@ class Allocation(AllocationTreeNode):
allocated: bool = False
pool: Optional[AllocationPool] = None
offset: Optional[sympy.Expr] = None
earliest_available: Optional[float] = None
def __post_init__(self) -> None:
has_unbacked_sym = False
for s in self.node.get_layout().size:
if free_unbacked_symbols(s):
has_unbacked_sym = True
break
if has_unbacked_sym:
self.earliest_available = self.get_live_ranges().begin
@property
def device(self):
@@ -186,6 +198,9 @@ class Allocation(AllocationTreeNode):
f"offset={self.offset})"
)
def get_earliest_available(self):
return self.earliest_available
@dataclasses.dataclass
class Empty(AllocationTreeNode):
@@ -377,14 +392,26 @@ class AllocationPool:
names_to_del: list[str] = dataclasses.field(default_factory=list)
creation_cache: dict[str, str] = dataclasses.field(default_factory=dict)
def __post_init__(self) -> None:
for block in self.root.allocations:
if isinstance(block, Allocation):
self.update_restrict_live_range(block)
def allocate(self, block: Allocation, is_last: bool):
if self.restrict_live_range and not self.restrict_live_range.contains(
block.live_range
if (
self.restrict_live_range is not None
and not self.restrict_live_range.contains(block.live_range)
):
return False
block_earliest_available = block.get_earliest_available()
pool_begin = self.root.get_live_ranges().begin
if block_earliest_available and block_earliest_available > pool_begin:
return False
is_last = self.can_expand and is_last
if self.root.allocate(block, is_last):
self.update_restrict_live_range(block)
return True
if is_last:
@@ -392,9 +419,22 @@
return False
def update_restrict_live_range(self, block: Allocation):
if block_earliest_available := block.get_earliest_available():
if self.restrict_live_range is None:
self.restrict_live_range = LiveRange(
block_earliest_available, float("inf")
)
else:
self.restrict_live_range = LiveRange(
min(self.restrict_live_range.begin, block_earliest_available),
self.restrict_live_range.end,
)
def allocate_at_end(self, block):
block.mark_allocated()
self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))])
self.update_restrict_live_range(block)
return True
def finalize(self, name):
@@ -408,7 +448,6 @@ class AllocationPool:
nbytes = self.root.get_symbolic_size()
for block in self.root.allocations:
if isinstance(block, Allocation) and nbytes == block.get_symbolic_size():
# optimization: fuse first allocation and pool creation
node = block.node
code.writeline(
wrapper.make_allocation(
@@ -419,7 +458,6 @@
stride=tuple(node.get_stride()),
)
)
self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name
return
else:
code.writeline(
@@ -577,7 +615,10 @@ class AllocFromPoolLine(PoolMemoryPlanningLine):
pool.codegen_create(self.wrapper, code)
pool.names_to_del.extend(self.group.names)
alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper)
alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool(
self.wrapper
)
code.writelines(allocation_lines_to_write)
if alloc_from_pool in pool.creation_cache:
code.writeline(
self.wrapper.make_tensor_alias(


@@ -1765,7 +1765,9 @@ class PythonWrapperCodegen(CodeGen):
def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str:
return self.codegen_python_shape_tuple(shape)
def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
def codegen_alloc_from_pool(
self, name, offset, dtype, shape, stride
) -> tuple[str, list[str]]:
return "alloc_from_pool({})".format(
", ".join(
[
@@ -1776,7 +1778,7 @@
self.codegen_python_shape_tuple(stride),
]
)
)
), []
def codegen_reinterpret_view(
self,