[inductor] dont reuse buffers if it affects peak (#145883) (#159530)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530
Approved by: https://github.com/eellison
This commit is contained in:
Markus Hoehnerbach
2025-08-15 12:00:45 -07:00
committed by PyTorch MergeBot
parent 62db8ec391
commit 65d21dae18
5 changed files with 608 additions and 2 deletions

View File

@ -0,0 +1,254 @@
# Owner(s): ["module: inductor"]
from hypothesis import given, strategies as st
from torch._inductor.codegen.segmented_tree import SegmentedTree
from torch._inductor.test_case import run_tests, TestCase
# Helper functions for operations
def max_op(a, b):
return max(a, b)
def add_op(a, b):
return a + b
# Naive implementations for reference
def naive_range_max(arr, start, end):
return max(arr[start : end + 1])
def naive_range_update(arr, start, end, value):
for i in range(start, end + 1):
arr[i] += value
# Strategies for hypothesis testing
positive_integers = st.lists(
st.integers(min_value=1, max_value=100), min_size=1, max_size=50
)
def valid_range_indices(array_length):
return st.tuples(
st.integers(min_value=0, max_value=array_length - 1),
st.integers(min_value=0, max_value=array_length - 1),
).map(lambda x: (min(x), max(x)))
update_values = st.integers(min_value=1, max_value=50)
class TestSegmentedTree(TestCase):
# Basic construction and initialization tests
def test_basic_construction(self):
values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, add_op, max_op, 0)
assert tree.summarize_range(0, 4) == 9
def test_empty_array(self):
with self.assertRaises(ValueError):
SegmentedTree([], add_op, max_op, 0)
# Property-based tests
@given(values=positive_integers)
def test_max_query_matches_naive(self, values):
tree = SegmentedTree(values, add_op, max_op, 0)
for start in range(len(values)):
for end in range(start, len(values)):
expected = naive_range_max(values, start, end)
actual = tree.summarize_range(start, end)
assert actual == expected, (
f"Range [{start}:{end}] expected {expected}, got {actual}"
)
@given(
values=positive_integers, range_indices=st.data(), update_value=update_values
)
def test_range_update(self, values, range_indices, update_value):
# Create a copy for naive implementation
naive_values = values.copy()
# Create segment tree
tree = SegmentedTree(values, add_op, max_op, 0)
# Get valid range indices
start, end = range_indices.draw(valid_range_indices(len(values)))
# Apply updates
tree.update_range(start, end, update_value)
naive_range_update(naive_values, start, end, update_value)
# Verify all possible ranges
for i in range(len(values)):
for j in range(i, len(values)):
expected = naive_range_max(naive_values, i, j)
actual = tree.summarize_range(i, j)
assert actual == expected, (
f"After update, range [{i}:{j}] expected {expected}, got {actual}"
)
@given(values=positive_integers, range_data=st.data())
def test_multiple_operations(self, values, range_data):
# Create a copy for naive implementation
naive_values = values.copy()
tree = SegmentedTree(values, add_op, max_op, 0)
# Perform multiple operations
num_operations = 5
for _ in range(num_operations):
# Randomly choose between query and update
operation_type = range_data.draw(st.sampled_from(["query", "update"]))
start, end = range_data.draw(valid_range_indices(len(values)))
if operation_type == "query":
expected = naive_range_max(naive_values, start, end)
actual = tree.summarize_range(start, end)
assert actual == expected, (
f"Range query [{start}:{end}] expected {expected}, got {actual}"
)
else: # update
update_value = range_data.draw(update_values)
tree.update_range(start, end, update_value)
naive_range_update(naive_values, start, end, update_value)
def test_single_element_ranges(self):
values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, add_op, max_op, 0)
for i in range(len(values)):
assert tree.summarize_range(i, i) == values[i], (
f"Single element range at index {i} failed"
)
def test_full_array_range(self):
values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, add_op, max_op, 0)
# Test querying the entire array
assert tree.summarize_range(0, len(values) - 1) == max(values)
# Update the entire array and test again
update_value = 10
tree.update_range(0, len(values) - 1, update_value)
expected = max([v + update_value for v in values])
assert tree.summarize_range(0, len(values) - 1) == expected
def test_boundary_conditions(self):
values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, add_op, max_op, 0)
# Test first element
assert tree.summarize_range(0, 0) == values[0]
# Test last element
assert tree.summarize_range(len(values) - 1, len(values) - 1) == values[-1]
# Test first two elements
assert tree.summarize_range(0, 1) == max(values[0:2])
# Test last two elements
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(
values[-2:]
)
def test_invalid_ranges(self):
values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, add_op, max_op, 0)
# Test start > end
with self.assertRaises(ValueError):
tree.summarize_range(3, 2)
with self.assertRaises(ValueError):
tree.update_range(4, 2, 10)
def test_out_of_bounds(self):
values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, add_op, max_op, 0)
# Test negative indices
with self.assertRaises(ValueError):
tree.summarize_range(-1, 3)
with self.assertRaises(ValueError):
tree.summarize_range(0, -1)
# Test indices >= n
with self.assertRaises(ValueError):
tree.summarize_range(0, len(values))
with self.assertRaises(ValueError):
tree.summarize_range(len(values), len(values) + 1)
# Test update with out of bounds indices
with self.assertRaises(ValueError):
tree.update_range(-1, 3, 10)
with self.assertRaises(ValueError):
tree.update_range(0, len(values), 10)
def test_overlapping_updates(self):
values = [1, 3, 5, 7, 9]
naive_values = values.copy()
tree = SegmentedTree(values, add_op, max_op, 0)
# Apply overlapping updates
tree.update_range(0, 2, 5) # Update [0, 1, 2]
naive_range_update(naive_values, 0, 2, 5)
tree.update_range(1, 3, 3) # Update [1, 2, 3]
naive_range_update(naive_values, 1, 3, 3)
# Verify all possible ranges
for i in range(len(values)):
for j in range(i, len(values)):
expected = naive_range_max(naive_values, i, j)
actual = tree.summarize_range(i, j)
assert actual == expected, (
f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
)
def test_sequential_updates_and_queries(self):
values = [2, 4, 6, 8, 10, 12, 14]
naive_values = values.copy()
tree = SegmentedTree(values, add_op, max_op, 0)
# Sequence of operations
operations = [
("update", 1, 3, 5), # Update range [1, 2, 3] with +5
("query", 0, 4), # Query range [0, 1, 2, 3, 4]
("update", 2, 5, 3), # Update range [2, 3, 4, 5] with +3
("query", 1, 3), # Query range [1, 2, 3]
("update", 0, 6, 2), # Update entire array with +2
("query", 0, 6), # Query entire array
("query", 3, 5), # Query range [3, 4, 5]
]
for op in operations:
if op[0] == "update":
_, start, end, value = op
tree.update_range(start, end, value)
naive_range_update(naive_values, start, end, value)
# Verify tree state after update
for i in range(len(values)):
for j in range(i, len(values)):
expected = naive_range_max(naive_values, i, j)
actual = tree.summarize_range(i, j)
assert actual == expected, (
f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
)
else: # query
_, start, end = op
expected = naive_range_max(naive_values, start, end)
assert tree.summarize_range(start, end) == expected, (
f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
)
if __name__ == "__main__":
run_tests()

View File

@ -13754,6 +13754,45 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
self.assertEqual(has_lowered, can_lower)
@staticmethod
def _is_triggering_buffer_reuse(fn, *inputs):
with config.patch(allow_buffer_reuse=True):
_, (code_allowed,) = run_and_get_code(fn, *inputs)
with config.patch(allow_buffer_reuse=False):
_, (code_disallowed,) = run_and_get_code(fn, *inputs)
code_allowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_allowed)
code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed)
return code_allowed != code_disallowed
def test_allow_reuse_disable_if_exceed_peak(self):
@torch.compile
def fn(inp): # 1*N^2
a = inp.mean(-1) # 1*N^2 + N
b = (inp - a) ** 2 # 2*N^2 + N
c = b @ b # 3*N^2 (!!) since this is the peak, can not reuse across
d = c.mean(-1) # 2*N^2 + N
return d # 1*N^2 + N
inp = torch.randn(100, 100, device=self.device)
self.assertFalse(CommonTemplate._is_triggering_buffer_reuse(fn, inp))
def test_allow_reuse_active_if_under_peak(self):
def g(inp):
return (inp - torch.logsumexp(inp, -1)) ** 2
@torch.compile
def fn(m, inp):
inp = m @ g(inp)
inp = m @ g(inp)
inp = m @ g(inp)
inp = m @ g(inp)
inp = m @ g(inp)
return inp
m = torch.randn(100, 100, device=self.device)
inp = torch.randn(100, 100, device=self.device)
self.assertTrue(CommonTemplate._is_triggering_buffer_reuse(fn, m, inp))
# end of class CommonTemplate - add new tests here

View File

@ -0,0 +1,241 @@
from typing import Callable, Generic, Optional, TypeVar
T = TypeVar("T")
def _value_or(opt: Optional[T], default: T) -> T:
return opt if opt is not None else default
class SegmentedTree(Generic[T]):
def __init__(
self,
values: list[T],
update_op: Callable[[T, T], T],
summary_op: Callable[[T, T], T],
identity_element: T,
):
"""
Initialize a segment tree with the given values and operations.
Args:
values: list of initial values
update_op: Function to apply when updating a value (e.g., addition)
summary_op: Function to summarize two values (e.g., min, max, sum)
identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
Raises:
ValueError: If the input values list is empty
"""
if not values:
raise ValueError("Cannot create a segment tree with empty values list")
self.n = len(values)
self.update_op = update_op
self.summary_op = summary_op
self.identity = identity_element
# Size of segment tree array (next power of 2 * 2)
# The tree follows a standard heap layout where
# node `n`'s children are at `2*n` and `2*n+1`.
# Index 0 is unused.
self.size = 1
while self.size < self.n:
self.size *= 2
self.size *= 2
# Initialize tree and lazy arrays
self.tree = [identity_element] * self.size
# The lazy array contains updates to the given node
# Upon update, we only push updates to the top-most
# nodes that fully receive the update. We then
# propagate the update down as required (i.e., when
# we receive an interval query that neither fully
# contains the node nor fully doesn't contain the
# node
self.lazy: list[Optional[T]] = [None] * self.size
# Build the tree
self._build(values, 1, 0, self.n - 1)
def _build(self, values: list[T], node: int, start: int, end: int) -> None:
"""
Build the segment tree recursively.
Args:
values: Original array of values
node: Current node index in the segment tree
start: Start index of the segment
end: End index of the segment
"""
if start == end:
# Leaf node
if start < len(values):
self.tree[node] = values[start]
return
mid = (start + end) // 2
left_child = 2 * node
right_child = 2 * node + 1
# Recursively build left and right subtrees
self._build(values, left_child, start, mid)
self._build(values, right_child, mid + 1, end)
# Update current node with summary of children
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
def _children(self, node: int) -> list[int]:
return [2 * node, 2 * node + 1]
def _push_lazy(self, node: int, start: int, end: int) -> None:
"""
Push lazy updates down to children.
Args:
node: Current node index
start: Start index of the segment
end: End index of the segment
"""
lazy_node = self.lazy[node]
if lazy_node is None:
return
# Apply lazy update to current node
self.tree[node] = self.update_op(self.tree[node], lazy_node)
if start != end: # Not a leaf node
# Propagate to children
for child in self._children(node):
self.lazy[child] = self.update_op(
_value_or(self.lazy[child], self.identity), lazy_node
)
# Clear the lazy value
self.lazy[node] = None
def _update_range_helper(
self, node: int, start: int, end: int, left: int, right: int, value: T
) -> None:
"""
Helper method to update a range of values in the segment tree.
Args:
node: Current node index
start: Start index of the current segment
end: End index of the current segment
left: Start index of the range to update
right: End index of the range to update
value: Value to apply to the range
"""
# Push lazy updates before processing this node
self._push_lazy(node, start, end)
# No overlap
if start > right or end < left:
return
# Complete overlap
if start >= left and end <= right:
# Apply update to current node
self.lazy[node] = value
self._push_lazy(node, start, end)
return
# Partial overlap, recurse to children
mid = (start + end) // 2
left_child = 2 * node
right_child = 2 * node + 1
self._update_range_helper(left_child, start, mid, left, right, value)
self._update_range_helper(right_child, mid + 1, end, left, right, value)
# Update current node based on children
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
def _query_range_helper(
self, node: int, start: int, end: int, left: int, right: int
) -> T:
"""
Helper method to query a range of values in the segment tree.
Args:
node: Current node index
start: Start index of the current segment
end: End index of the current segment
left: Start index of the range to query
right: End index of the range to query
Returns:
Summary value for the range
"""
# No overlap
if start > right or end < left:
return self.identity
# Push lazy updates before processing this node
self._push_lazy(node, start, end)
# Complete overlap
if start >= left and end <= right:
return self.tree[node]
# Partial overlap, recurse to children
mid = (start + end) // 2
left_child = 2 * node
right_child = 2 * node + 1
left_result = self._query_range_helper(left_child, start, mid, left, right)
right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
# Combine results from children
return self.summary_op(left_result, right_result)
def update_range(self, start: int, end: int, value: T) -> None:
"""
Update a range of values in the segment tree.
Args:
start: Start index of the range to update (inclusive)
end: End index of the range to update (inclusive)
value: Value to apply to the range
Raises:
ValueError: If start > end or indices are out of bounds
"""
if start > end:
raise ValueError("Start index must be less than or equal to end index")
if start < 0 or start >= self.n:
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
if end < 0 or end >= self.n:
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
self._update_range_helper(1, 0, self.n - 1, start, end, value)
def summarize_range(self, start: int, end: int) -> T:
"""
Query a range of values in the segment tree.
Args:
start: Start index of the range to query (inclusive)
end: End index of the range to query (inclusive)
Returns:
Summary value for the range according to the summary operation
Raises:
ValueError: If start > end or indices are out of bounds
"""
if start > end:
raise ValueError("Start index must be less than or equal to end index")
if start < 0 or start >= self.n:
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
if end < 0 or end >= self.n:
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
return self._query_range_helper(1, 0, self.n - 1, start, end)

View File

@ -48,6 +48,7 @@ from ..utils import (
cache_on_self,
DelayReplaceLine,
get_benchmark_name,
get_dtype_size,
IndentedBuffer,
is_codegen_graph_partition_subgraph,
is_using_cudagraph_partition,
@ -587,10 +588,64 @@ class MemoryPlanningLine(WrapperLine):
return f"{type(self).__name__}({', '.join(args)})"
class EfficientPeakEstimate:
def __init__(self):
from ..memory import estimate_peak_memory, get_freeable_input_buf
scheduler_nodes = V.graph.scheduler.nodes
graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs = OrderedSet(V.graph.get_output_names())
names_to_freeable_bufs = get_freeable_input_buf(scheduler_nodes, graph_inputs)
self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory(
scheduler_nodes,
names_to_freeable_bufs,
graph_outputs,
)
from .segmented_tree import SegmentedTree
self.segmented_tree = SegmentedTree(
peak_by_scheduler_node, operator.add, max, 0
)
def _get_size(self, node: BufferLike) -> int:
return V.graph.sizevars.size_hint(
V.graph.get_allocation_storage_size(node), fallback=0
) * get_dtype_size(node.get_dtype())
def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
return self.segmented_tree.summarize_range(
line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1
)
def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index:
return
self.segmented_tree.update_range(
line_a.scheduler_node_index + 1,
line_b.scheduler_node_index - 1,
self._get_size(line_b.node),
)
@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
node: BufferLike
def __post_init__(self):
assert V.graph.scheduler.current_node is not None
self.scheduler_node_index = V.graph.scheduler.nodes.index(
V.graph.scheduler.current_node
)
def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool:
if free_line.scheduler_node_index + 1 == self.scheduler_node_index:
return True
overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory
peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self)
new_peak_memory = size + peak_memory_in_range
return new_peak_memory <= overall_peak_memory
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
if self.node.get_name() in V.graph.removed_buffers:
return NullLine(self.wrapper)
@ -599,8 +654,16 @@ class AllocateLine(MemoryPlanningLine):
key = buffer_reuse_key(self.node)
if config.allow_buffer_reuse and key in state:
free_line = state.pop(key)
free_line.is_reused = True
return ReuseLine(self.wrapper, free_line.node, self.node)
size = V.graph.sizevars.size_hint(
V.graph.get_allocation_storage_size(self.node), fallback=0
) * get_dtype_size(self.node.get_dtype())
if self.should_reuse_buffer(free_line, size):
free_line.is_reused = True
self.wrapper.estimate_peak.update_peak_between(free_line, self)
return ReuseLine(self.wrapper, free_line.node, self.node)
else:
state.push(key, free_line)
return self
if self.node.get_device_or_error().type == "cpu":
static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
@ -625,6 +688,12 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
node: BufferLike
is_reused: bool = False
def __post_init__(self):
assert V.graph.scheduler.current_node is not None
self.scheduler_node_index = V.graph.scheduler.nodes.index(
V.graph.scheduler.current_node
)
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
if len(self.node.get_inputs_that_alias_output()) > 0:
return self
@ -1645,6 +1714,7 @@ class PythonWrapperCodegen(CodeGen):
if is_inference and config.memory_planning:
self.memory_plan()
else:
self.estimate_peak = EfficientPeakEstimate()
self.memory_plan_reuse()
def codegen_input_symbol_assignment(

View File

@ -2073,6 +2073,7 @@ class Scheduler:
)
self.nodes = [self.create_scheduler_node(n) for n in nodes]
self.current_node: Optional[BaseSchedulerNode] = None
self.update_zero_dim_cpu_tensor()
# some new constants could have been created above
self.available_buffer_names.update(V.graph.constants.keys())
@ -4989,6 +4990,7 @@ class Scheduler:
assert device.index is not None, "device should have an index"
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
self.current_node = node
self.buffer_names_to_free.update(node.last_usage)
if node.is_template():