Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Fix unbacked symint and memory leak in inductor memory planning (#159839)
Summary:

In memory planning, some allocation sizes involve unbacked symints. These unbacked symints are not known until they are computed at runtime, so **allocation pools that involve unbacked symints cannot be allocated until we have the values of the unbacked symints**. We therefore add a notion of `earliest_available` to Allocation nodes: if an allocation node has an unbacked symint, it only becomes available when its live range begins. Then, in AllocationPool, if a pool contains an Allocation node with an earliest-available time, we restrict the pool's live range. If a block's earliest-available time is later than the start of a pool's live range, we cannot allocate that block from the pool.

We also fix a memory leak caused by allocating a tensor without wrapping it in an RAIIAtenTensorHandle. In the Python wrapper for JIT Inductor, `codegen_alloc_from_pool` does not actually write the alloc lines to the wrapper; it just returns the allocation string. In cpp_wrapper, however, `codegen_alloc_from_pool` writes to the wrapper: specifically, it writes the following and returns an `RAIIAtenTensorHandle` string.

```
AtenTensorHandle handle_name;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(....);
```

This is bug prone: **if you write the aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle as well**, otherwise you get memory leaks.

We also remove the alloc_from_pool call from codegen_create, because it does not work for AOTI: the Python wrapper can generate the same alloc_from_pool variable name for the same block, but cpp_wrapper generates a different variable name for each call to alloc_from_pool.

Test Plan:

```
python test/inductor/test_memory_planning.py
```

Rollback Plan:

Differential Revision: D79603119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159839
Approved by: https://github.com/jansel
committed by PyTorch MergeBot
parent ca7315c171
commit 9ccd0f5e31
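Before the diff itself, here is a minimal sketch of the restriction described in the summary. It is illustrative only: `Block` and `Pool` below are simplified stand-ins for the real `Allocation` and `AllocationPool` in `torch/_inductor/memory_planning.py`. The point is that a block whose size depends on an unbacked symint only becomes available when its live range begins, so it cannot be carved out of a pool that is allocated earlier.

```
# Simplified stand-ins, not the real torch._inductor.memory_planning classes.
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class LiveRange:
    begin: float
    end: float

    def contains(self, other: "LiveRange") -> bool:
        return self.begin <= other.begin and other.end <= self.end


@dataclasses.dataclass
class Block:
    live_range: LiveRange
    # Set to live_range.begin when the allocation size contains an unbacked symint,
    # i.e. the number of bytes is only known once that value is computed at runtime.
    earliest_available: Optional[float] = None


@dataclasses.dataclass
class Pool:
    live_range: LiveRange  # the pool buffer is allocated at live_range.begin
    restrict_live_range: Optional[LiveRange] = None

    def can_host(self, block: Block) -> bool:
        if (
            block.earliest_available is not None
            and block.earliest_available > self.live_range.begin
        ):
            # The pool is created before the unbacked symint is known.
            return False
        if self.restrict_live_range is not None and not self.restrict_live_range.contains(
            block.live_range
        ):
            return False
        return True


pool = Pool(live_range=LiveRange(0, 10))
static_block = Block(LiveRange(2, 5))                            # size known at compile time
unbacked_block = Block(LiveRange(4, 8), earliest_available=4.0)  # size depends on u0

print(pool.can_host(static_block))    # True
print(pool.can_host(unbacked_block))  # False: pool is allocated at t=0, before u0 exists
```

In the actual change, this check lives in `AllocationPool.allocate`, and `update_restrict_live_range` additionally narrows the pool's usable range whenever such a block does join the pool.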
@@ -24,6 +24,14 @@ from torch._inductor.utils import run_and_get_cpp_code
 from torch.export import Dim


+try:
+    from .test_aot_inductor import AOTIRunnerUtil
+except ImportError:
+    from test_aot_inductor import (  # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
+        AOTIRunnerUtil,
+    )
+
+
 @requires_gpu()
 @config.patch(memory_planning=True)
 class TestMemoryPlanning(TestCase):
@@ -76,13 +84,6 @@ class TestMemoryPlanning(TestCase):

     @skipIfXpu(msg="aoti doesn't work on XPU")
     def test_aoti(self):
-        try:
-            from .test_aot_inductor import AOTIRunnerUtil
-        except ImportError:
-            from test_aot_inductor import (  # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
-                AOTIRunnerUtil,
-            )
-
         f, args = self._generate(device=GPU_TYPE)
         dim0_x = Dim("dim0_x", min=1, max=2048)
         dynamic_shapes = ({0: dim0_x}, None, None)
@@ -103,6 +104,54 @@ class TestMemoryPlanning(TestCase):
         ).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code)
         self.assertTrue(same(f(*args), result))

+    @config.patch({"triton.autotune_at_compile_time": False})
+    def test_unbacked_symint(self):
+        # When an allocation's size has unbacked symints, those symints are
+        # only available after they are computed at runtime.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, x, y):
+                x = x + 1
+                u0 = x.item()
+                torch._check(u0 >= 1)
+                s0 = y.size(0)
+                expr = u0 * s0
+                sevens = torch.empty_strided(
+                    size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device
+                ).fill_(7)
+                return sevens * 3
+
+        example_inputs = (
+            torch.scalar_tensor(2, dtype=torch.int, device=self.device),
+            torch.ones(8, device=self.device),
+        )
+        model = Repro().to(self.device)
+        result, code = run_and_get_cpp_code(
+            lambda: AOTIRunnerUtil.run(model, example_inputs)
+        )
+        self.assertTrue(same(model(*example_inputs), result))
+
+        # Check that the allocation is done after the unbacked symint is computed.
+        FileCheck().check("auto u0 = u0_raw;").check(
+            "const int64_t int_array_2[] = {10L, 8L*u0, 32L};"
+        ).check("AtenTensorHandle pool0_handle;").check(
+            "aoti_torch_empty_strided(3, int_array_2, int_array_3"
+        ).run(code)
+
+        # Every AtenTensorHandle allocated using aoti_torch__alloc_from_pool is wrapped
+        # with RAIIAtenTensorHandle; otherwise we'd have a memory leak.
+        FileCheck().check_count(
+            "aoti_torch__alloc_from_pool(pool1", 1, exactly=True
+        ).check_count("aoti_torch__alloc_from_pool(pool0", 1, exactly=True).run(code)
+
+        FileCheck().check(
+            "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_int32, 0, int_array_1, int_array_1, &tmp_tensor_handle_0));"  # noqa: B950
+        ).check("RAIIAtenTensorHandle(tmp_tensor_handle_0);").check(
+            "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));"  # noqa: B950
+        ).check("RAIIAtenTensorHandle(tmp_tensor_handle_1);").run(code)
+

 if __name__ == "__main__":
     if HAS_GPU:
@@ -1651,7 +1651,9 @@ class CppWrapperCpu(PythonWrapperCodegen):

         return f"RAIIAtenTensorHandle {name}({handle_name});"

-    def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
+    def codegen_alloc_from_pool(
+        self, name, offset, dtype, shape, stride
+    ) -> tuple[str, list[str]]:
         size = self.codegen_shape_tuple(shape)
         stride = self.codegen_shape_tuple(stride)
         tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
@@ -1668,11 +1670,14 @@ class CppWrapperCpu(PythonWrapperCodegen):
             ),
             f"&{tmp_name}",
         ]
-        self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
-        self.wrapper_call.writeline(
-            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
-        )
-        return f"RAIIAtenTensorHandle({tmp_name})"
+        # We return the lines instead of writing them here because writing here is bug prone:
+        # if you write the aoti_torch__alloc_from_pool lines, you must write the
+        # RAIIAtenTensorHandle as well, otherwise you get memory leaks.
+        allocations_to_write = [
+            f"AtenTensorHandle {tmp_name};",
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));",
+        ]
+        return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write

     def codegen_reinterpret_view(
         self,
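To connect the two halves of the wrapper change, here is a hedged sketch of the new contract. The helper names and the plain list used as a code buffer are stand-ins, not the real `CppWrapperCpu` API, but they mirror the shape of the change: `codegen_alloc_from_pool` returns the RAII expression together with the lines that create the raw handle, and the caller (`AllocFromPoolLine.codegen`, further below) writes both so the raw `AtenTensorHandle` is always handed to an `RAIIAtenTensorHandle`.

```
# Simplified stand-ins for the inductor codegen path, not the actual classes.

def codegen_alloc_from_pool(tmp_name: str, args: list[str]) -> tuple[str, list[str]]:
    """Return (RAII wrapper expression, C++ lines that create the raw handle)."""
    lines_to_write = [
        f"AtenTensorHandle {tmp_name};",
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));",
    ]
    return f"RAIIAtenTensorHandle({tmp_name})", lines_to_write


def emit_alloc(code: list[str], result_name: str, tmp_name: str, args: list[str]) -> None:
    raii_expr, lines = codegen_alloc_from_pool(tmp_name, args)
    code.extend(lines)                                  # create the raw AtenTensorHandle...
    code.append(f"auto {result_name} = {raii_expr};")   # ...and immediately hand it to RAII


buf: list[str] = []
emit_alloc(
    buf,
    "buf0",
    "tmp_tensor_handle_0",
    ["pool0", "0", "dtype", "ndim", "sizes", "strides", "&tmp_tensor_handle_0"],
)
print("\n".join(buf))
```

The Python (JIT) wrapper keeps its old behavior and simply returns an empty list of lines, as the last hunk of the diff shows.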
@@ -10,6 +10,7 @@ from typing import Any, Optional, Protocol, TYPE_CHECKING
 import sympy

 import torch
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet

 from .. import config
@@ -142,6 +143,17 @@ class Allocation(AllocationTreeNode):
     allocated: bool = False
     pool: Optional[AllocationPool] = None
     offset: Optional[sympy.Expr] = None
+    earliest_available: Optional[float] = None
+
+    def __post_init__(self) -> None:
+        has_unbacked_sym = False
+        for s in self.node.get_layout().size:
+            if free_unbacked_symbols(s):
+                has_unbacked_sym = True
+                break
+
+        if has_unbacked_sym:
+            self.earliest_available = self.get_live_ranges().begin

     @property
     def device(self):
@@ -186,6 +198,9 @@ class Allocation(AllocationTreeNode):
             f"offset={self.offset})"
         )

+    def get_earliest_available(self):
+        return self.earliest_available
+

 @dataclasses.dataclass
 class Empty(AllocationTreeNode):
@@ -377,14 +392,26 @@ class AllocationPool:
     names_to_del: list[str] = dataclasses.field(default_factory=list)
     creation_cache: dict[str, str] = dataclasses.field(default_factory=dict)

+    def __post_init__(self) -> None:
+        for block in self.root.allocations:
+            if isinstance(block, Allocation):
+                self.update_restrict_live_range(block)
+
     def allocate(self, block: Allocation, is_last: bool):
-        if self.restrict_live_range and not self.restrict_live_range.contains(
-            block.live_range
+        if (
+            self.restrict_live_range is not None
+            and not self.restrict_live_range.contains(block.live_range)
         ):
             return False

+        block_earliest_available = block.get_earliest_available()
+        pool_begin = self.root.get_live_ranges().begin
+        if block_earliest_available and block_earliest_available > pool_begin:
+            return False
+
         is_last = self.can_expand and is_last
         if self.root.allocate(block, is_last):
+            self.update_restrict_live_range(block)
             return True

         if is_last:
@@ -392,9 +419,22 @@ class AllocationPool:

         return False

+    def update_restrict_live_range(self, block: Allocation):
+        if block_earliest_available := block.get_earliest_available():
+            if self.restrict_live_range is None:
+                self.restrict_live_range = LiveRange(
+                    block_earliest_available, float("inf")
+                )
+            else:
+                self.restrict_live_range = LiveRange(
+                    min(self.restrict_live_range.begin, block_earliest_available),
+                    self.restrict_live_range.end,
+                )
+
     def allocate_at_end(self, block):
         block.mark_allocated()
         self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))])
+        self.update_restrict_live_range(block)
         return True

     def finalize(self, name):
@@ -408,7 +448,6 @@ class AllocationPool:
         nbytes = self.root.get_symbolic_size()
         for block in self.root.allocations:
             if isinstance(block, Allocation) and nbytes == block.get_symbolic_size():
-                # optimization: fuse first allocation and pool creation
                 node = block.node
                 code.writeline(
                     wrapper.make_allocation(
@@ -419,7 +458,6 @@ class AllocationPool:
                         stride=tuple(node.get_stride()),
                     )
                 )
-                self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name
                 return
             else:
                 code.writeline(
@@ -577,7 +615,10 @@ class AllocFromPoolLine(PoolMemoryPlanningLine):
             pool.codegen_create(self.wrapper, code)

         pool.names_to_del.extend(self.group.names)
-        alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper)
+        alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool(
+            self.wrapper
+        )
+        code.writelines(allocation_lines_to_write)
         if alloc_from_pool in pool.creation_cache:
             code.writeline(
                 self.wrapper.make_tensor_alias(
@@ -1765,7 +1765,9 @@ class PythonWrapperCodegen(CodeGen):
     def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str:
         return self.codegen_python_shape_tuple(shape)

-    def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
+    def codegen_alloc_from_pool(
+        self, name, offset, dtype, shape, stride
+    ) -> tuple[str, list[str]]:
         return "alloc_from_pool({})".format(
             ", ".join(
                 [
@@ -1776,7 +1778,7 @@ class PythonWrapperCodegen(CodeGen):
                     self.codegen_python_shape_tuple(stride),
                 ]
             )
-        )
+        ), []

     def codegen_reinterpret_view(
         self,