Revert "[inductor] dont reuse buffers if it affects peak (#145883) (#159530)"

This reverts commit 3be70dc30e893b552fc0f23ca06cd8f7949b6d08.

Reverted https://github.com/pytorch/pytorch/pull/159530 on behalf of https://github.com/clee2000 because the newly added tests fail internally (D80316528). That is probably just a targets change, but the tests should also probably go into a TestCase class from common or inductor utils: while CI can run the globally defined ones, there is CI-related functionality on the TestCase class that CI benefits from ([comment](https://github.com/pytorch/pytorch/pull/159530#issuecomment-3191947506))
Author: PyTorch MergeBot
Date:   2025-08-15 15:49:04 +00:00
Parent: 846963fa9b
Commit: 9df07ecfbe

5 changed files with 2 additions and 615 deletions
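
For reference, the restructuring the reverter asks for amounts to hosting the module-level hypothesis tests on a TestCase subclass from the shared test utilities. A minimal sketch, assuming the deleted torch._inductor.codegen.segmented_tree module were kept; the class name is illustrative and not from the PR:

# Hypothetical restructuring of the deleted test file: the same checks, but on a
# TestCase subclass so CI gets the shared infrastructure (device/skip decorators,
# retries, standardized reporting) in addition to the globally defined functions.
from torch._inductor.codegen.segmented_tree import SegmentedTree
from torch.testing._internal.common_utils import TestCase, run_tests


class TestSegmentedTree(TestCase):  # illustrative name
    def test_basic_construction(self):
        tree = SegmentedTree([1, 3, 5, 7, 9], lambda a, b: a + b, max, 0)
        self.assertEqual(tree.summarize_range(0, 4), 9)


if __name__ == "__main__":
    run_tests()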


@@ -1,261 +0,0 @@
# Owner(s): ["module: inductor"]
import pytest
from hypothesis import given, strategies as st

from torch._inductor.codegen.segmented_tree import SegmentedTree
from torch.testing._internal.common_utils import run_tests


# Helper functions for operations
def max_op(a, b):
    return max(a, b)


def add_op(a, b):
    return a + b


# Naive implementations for reference
def naive_range_max(arr, start, end):
    return max(arr[start : end + 1])


def naive_range_update(arr, start, end, value):
    for i in range(start, end + 1):
        arr[i] += value


# Strategies for hypothesis testing
positive_integers = st.lists(
    st.integers(min_value=1, max_value=100), min_size=1, max_size=50
)


def valid_range_indices(array_length):
    return st.tuples(
        st.integers(min_value=0, max_value=array_length - 1),
        st.integers(min_value=0, max_value=array_length - 1),
    ).map(lambda x: (min(x), max(x)))


update_values = st.integers(min_value=1, max_value=50)


# Basic construction and initialization tests
def test_basic_construction():
    values = [1, 3, 5, 7, 9]
    tree = SegmentedTree(values, add_op, max_op, 0)
    assert tree.summarize_range(0, 4) == 9


def test_empty_array():
    with pytest.raises(ValueError):
        SegmentedTree([], add_op, max_op, 0)


# Property-based tests
@given(values=positive_integers)
def test_max_query_matches_naive(values):
    tree = SegmentedTree(values, add_op, max_op, 0)
    for start in range(len(values)):
        for end in range(start, len(values)):
            expected = naive_range_max(values, start, end)
            actual = tree.summarize_range(start, end)
            assert actual == expected, (
                f"Range [{start}:{end}] expected {expected}, got {actual}"
            )


@given(values=positive_integers, range_indices=st.data(), update_value=update_values)
def test_range_update(values, range_indices, update_value):
    # Create a copy for naive implementation
    naive_values = values.copy()
    # Create segment tree
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Get valid range indices
    start, end = range_indices.draw(valid_range_indices(len(values)))
    # Apply updates
    tree.update_range(start, end, update_value)
    naive_range_update(naive_values, start, end, update_value)
    # Verify all possible ranges
    for i in range(len(values)):
        for j in range(i, len(values)):
            expected = naive_range_max(naive_values, i, j)
            actual = tree.summarize_range(i, j)
            assert actual == expected, (
                f"After update, range [{i}:{j}] expected {expected}, got {actual}"
            )


@given(values=positive_integers, range_data=st.data())
def test_multiple_operations(values, range_data):
    # Create a copy for naive implementation
    naive_values = values.copy()
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Perform multiple operations
    num_operations = 5
    for _ in range(num_operations):
        # Randomly choose between query and update
        operation_type = range_data.draw(st.sampled_from(["query", "update"]))
        start, end = range_data.draw(valid_range_indices(len(values)))
        if operation_type == "query":
            expected = naive_range_max(naive_values, start, end)
            actual = tree.summarize_range(start, end)
            assert actual == expected, (
                f"Range query [{start}:{end}] expected {expected}, got {actual}"
            )
        else:  # update
            update_value = range_data.draw(update_values)
            tree.update_range(start, end, update_value)
            naive_range_update(naive_values, start, end, update_value)


def test_single_element_ranges():
    values = [1, 3, 5, 7, 9]
    tree = SegmentedTree(values, add_op, max_op, 0)
    for i in range(len(values)):
        assert tree.summarize_range(i, i) == values[i], (
            f"Single element range at index {i} failed"
        )


def test_full_array_range():
    values = [1, 3, 5, 7, 9]
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Test querying the entire array
    assert tree.summarize_range(0, len(values) - 1) == max(values)
    # Update the entire array and test again
    update_value = 10
    tree.update_range(0, len(values) - 1, update_value)
    expected = max([v + update_value for v in values])
    assert tree.summarize_range(0, len(values) - 1) == expected


def test_boundary_conditions():
    values = [1, 3, 5, 7, 9]
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Test first element
    assert tree.summarize_range(0, 0) == values[0]
    # Test last element
    assert tree.summarize_range(len(values) - 1, len(values) - 1) == values[-1]
    # Test first two elements
    assert tree.summarize_range(0, 1) == max(values[0:2])
    # Test last two elements
    assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(values[-2:])


def test_invalid_ranges():
    values = [1, 3, 5, 7, 9]
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Test start > end
    with pytest.raises(ValueError):
        tree.summarize_range(3, 2)
    with pytest.raises(ValueError):
        tree.update_range(4, 2, 10)


def test_out_of_bounds():
    values = [1, 3, 5, 7, 9]
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Test negative indices
    with pytest.raises(ValueError):
        tree.summarize_range(-1, 3)
    with pytest.raises(ValueError):
        tree.summarize_range(0, -1)
    # Test indices >= n
    with pytest.raises(ValueError):
        tree.summarize_range(0, len(values))
    with pytest.raises(ValueError):
        tree.summarize_range(len(values), len(values) + 1)
    # Test update with out of bounds indices
    with pytest.raises(ValueError):
        tree.update_range(-1, 3, 10)
    with pytest.raises(ValueError):
        tree.update_range(0, len(values), 10)


def test_overlapping_updates():
    values = [1, 3, 5, 7, 9]
    naive_values = values.copy()
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Apply overlapping updates
    tree.update_range(0, 2, 5)  # Update [0, 1, 2]
    naive_range_update(naive_values, 0, 2, 5)
    tree.update_range(1, 3, 3)  # Update [1, 2, 3]
    naive_range_update(naive_values, 1, 3, 3)
    # Verify all possible ranges
    for i in range(len(values)):
        for j in range(i, len(values)):
            expected = naive_range_max(naive_values, i, j)
            actual = tree.summarize_range(i, j)
            assert actual == expected, (
                f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
            )


def test_sequential_updates_and_queries():
    values = [2, 4, 6, 8, 10, 12, 14]
    naive_values = values.copy()
    tree = SegmentedTree(values, add_op, max_op, 0)
    # Sequence of operations
    operations = [
        ("update", 1, 3, 5),  # Update range [1, 2, 3] with +5
        ("query", 0, 4),  # Query range [0, 1, 2, 3, 4]
        ("update", 2, 5, 3),  # Update range [2, 3, 4, 5] with +3
        ("query", 1, 3),  # Query range [1, 2, 3]
        ("update", 0, 6, 2),  # Update entire array with +2
        ("query", 0, 6),  # Query entire array
        ("query", 3, 5),  # Query range [3, 4, 5]
    ]
    for op in operations:
        if op[0] == "update":
            _, start, end, value = op
            tree.update_range(start, end, value)
            naive_range_update(naive_values, start, end, value)
            # Verify tree state after update
            for i in range(len(values)):
                for j in range(i, len(values)):
                    expected = naive_range_max(naive_values, i, j)
                    actual = tree.summarize_range(i, j)
                    assert actual == expected, (
                        f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
                    )
        else:  # query
            _, start, end = op
            expected = naive_range_max(naive_values, start, end)
            assert tree.summarize_range(start, end) == expected, (
                f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
            )


if __name__ == "__main__":
    run_tests()
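
To make the behavior these tests pin down concrete, here is a small worked example of the API they exercise (standalone; it assumes the SegmentedTree module deleted by this revert and is not part of the diff):

from torch._inductor.codegen.segmented_tree import SegmentedTree

values = [1, 3, 5, 7, 9]
tree = SegmentedTree(values, lambda a, b: a + b, max, 0)  # range-add updates, range-max queries

print(tree.summarize_range(0, 4))  # 9: max over the whole array
tree.update_range(0, 2, 5)         # logically values[0:3] += 5 -> [6, 8, 10, 7, 9]
print(tree.summarize_range(0, 2))  # 10: max over the updated prefix
print(tree.summarize_range(3, 4))  # 9: the untouched suffix is unchanged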


@@ -13754,45 +13754,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        has_lowered = not re.search(r"repeat_interleave.Tensor", code)
        self.assertEqual(has_lowered, can_lower)

    @staticmethod
    def _is_triggering_buffer_reuse(fn, *inputs):
        with config.patch(allow_buffer_reuse=True):
            _, (code_allowed,) = run_and_get_code(fn, *inputs)
        with config.patch(allow_buffer_reuse=False):
            _, (code_disallowed,) = run_and_get_code(fn, *inputs)

        code_allowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_allowed)
        code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed)
        return code_allowed != code_disallowed

    def test_allow_reuse_disable_if_exceed_peak(self):
        @torch.compile
        def fn(inp):  # 1*N^2
            a = inp.mean(-1)  # 1*N^2 + N
            b = (inp - a) ** 2  # 2*N^2 + N
            c = b @ b  # 3*N^2 (!!) since this is the peak, can not reuse across
            d = c.mean(-1)  # 2*N^2 + N
            return d  # 1*N^2 + N

        inp = torch.randn(100, 100, device=self.device)
        self.assertFalse(CommonTemplate._is_triggering_buffer_reuse(fn, inp))

    def test_allow_reuse_active_if_under_peak(self):
        def g(inp):
            return (inp - torch.logsumexp(inp, -1)) ** 2

        @torch.compile
        def fn(m, inp):
            inp = m @ g(inp)
            inp = m @ g(inp)
            inp = m @ g(inp)
            inp = m @ g(inp)
            inp = m @ g(inp)
            return inp

        m = torch.randn(100, 100, device=self.device)
        inp = torch.randn(100, 100, device=self.device)
        self.assertTrue(CommonTemplate._is_triggering_buffer_reuse(fn, m, inp))

    # end of class CommonTemplate - add new tests here
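
The per-line comments in test_allow_reuse_disable_if_exceed_peak carry the reasoning the reverted feature relied on; a rough paraphrase of that accounting (my numbers, assuming N = 100 and float32):

# One N x N float32 intermediate is 100 * 100 * 4 bytes.
N, itemsize = 100, 4
one_buffer = N * N * itemsize          # 40_000 bytes
peak = 3 * one_buffer                  # live at `c = b @ b`: roughly inp, b and c
with_reuse_across_peak = peak + one_buffer
assert with_reuse_across_peak > peak   # keeping a "freed" buffer alive across the
                                       # peak for later reuse would raise the peak

That is why the first test expects identical code whether buffer reuse is allowed or not (the planner declines the reuse), while the second test, whose reuse candidates never straddle the peak, still expects reuse to kick in.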


@@ -1,241 +0,0 @@
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


def _value_or(opt: Optional[T], default: T) -> T:
    return opt if opt is not None else default


class SegmentedTree(Generic[T]):
    def __init__(
        self,
        values: list[T],
        update_op: Callable[[T, T], T],
        summary_op: Callable[[T, T], T],
        identity_element: T,
    ):
        """
        Initialize a segment tree with the given values and operations.

        Args:
            values: list of initial values
            update_op: Function to apply when updating a value (e.g., addition)
            summary_op: Function to summarize two values (e.g., min, max, sum)
            identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)

        Raises:
            ValueError: If the input values list is empty
        """
        if not values:
            raise ValueError("Cannot create a segment tree with empty values list")

        self.n = len(values)
        self.update_op = update_op
        self.summary_op = summary_op
        self.identity = identity_element

        # Size of segment tree array (next power of 2 * 2)
        # The tree follows a standard heap layout where
        # node `n`'s children are at `2*n` and `2*n+1`.
        # Index 0 is unused.
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.size *= 2

        # Initialize tree and lazy arrays
        self.tree = [identity_element] * self.size
        # The lazy array contains updates to the given node
        # Upon update, we only push updates to the top-most
        # nodes that fully receive the update. We then
        # propagate the update down as required (i.e., when
        # we receive an interval query that neither fully
        # contains the node nor fully doesn't contain the
        # node
        self.lazy: list[Optional[T]] = [None] * self.size

        # Build the tree
        self._build(values, 1, 0, self.n - 1)

    def _build(self, values: list[T], node: int, start: int, end: int) -> None:
        """
        Build the segment tree recursively.

        Args:
            values: Original array of values
            node: Current node index in the segment tree
            start: Start index of the segment
            end: End index of the segment
        """
        if start == end:
            # Leaf node
            if start < len(values):
                self.tree[node] = values[start]
            return

        mid = (start + end) // 2
        left_child = 2 * node
        right_child = 2 * node + 1

        # Recursively build left and right subtrees
        self._build(values, left_child, start, mid)
        self._build(values, right_child, mid + 1, end)

        # Update current node with summary of children
        self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])

    def _children(self, node: int) -> list[int]:
        return [2 * node, 2 * node + 1]

    def _push_lazy(self, node: int, start: int, end: int) -> None:
        """
        Push lazy updates down to children.

        Args:
            node: Current node index
            start: Start index of the segment
            end: End index of the segment
        """
        lazy_node = self.lazy[node]
        if lazy_node is None:
            return

        # Apply lazy update to current node
        self.tree[node] = self.update_op(self.tree[node], lazy_node)

        if start != end:  # Not a leaf node
            # Propagate to children
            for child in self._children(node):
                self.lazy[child] = self.update_op(
                    _value_or(self.lazy[child], self.identity), lazy_node
                )

        # Clear the lazy value
        self.lazy[node] = None

    def _update_range_helper(
        self, node: int, start: int, end: int, left: int, right: int, value: T
    ) -> None:
        """
        Helper method to update a range of values in the segment tree.

        Args:
            node: Current node index
            start: Start index of the current segment
            end: End index of the current segment
            left: Start index of the range to update
            right: End index of the range to update
            value: Value to apply to the range
        """
        # Push lazy updates before processing this node
        self._push_lazy(node, start, end)

        # No overlap
        if start > right or end < left:
            return

        # Complete overlap
        if start >= left and end <= right:
            # Apply update to current node
            self.lazy[node] = value
            self._push_lazy(node, start, end)
            return

        # Partial overlap, recurse to children
        mid = (start + end) // 2
        left_child = 2 * node
        right_child = 2 * node + 1

        self._update_range_helper(left_child, start, mid, left, right, value)
        self._update_range_helper(right_child, mid + 1, end, left, right, value)

        # Update current node based on children
        self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])

    def _query_range_helper(
        self, node: int, start: int, end: int, left: int, right: int
    ) -> T:
        """
        Helper method to query a range of values in the segment tree.

        Args:
            node: Current node index
            start: Start index of the current segment
            end: End index of the current segment
            left: Start index of the range to query
            right: End index of the range to query

        Returns:
            Summary value for the range
        """
        # No overlap
        if start > right or end < left:
            return self.identity

        # Push lazy updates before processing this node
        self._push_lazy(node, start, end)

        # Complete overlap
        if start >= left and end <= right:
            return self.tree[node]

        # Partial overlap, recurse to children
        mid = (start + end) // 2
        left_child = 2 * node
        right_child = 2 * node + 1

        left_result = self._query_range_helper(left_child, start, mid, left, right)
        right_result = self._query_range_helper(right_child, mid + 1, end, left, right)

        # Combine results from children
        return self.summary_op(left_result, right_result)

    def update_range(self, start: int, end: int, value: T) -> None:
        """
        Update a range of values in the segment tree.

        Args:
            start: Start index of the range to update (inclusive)
            end: End index of the range to update (inclusive)
            value: Value to apply to the range

        Raises:
            ValueError: If start > end or indices are out of bounds
        """
        if start > end:
            raise ValueError("Start index must be less than or equal to end index")
        if start < 0 or start >= self.n:
            raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
        if end < 0 or end >= self.n:
            raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")

        self._update_range_helper(1, 0, self.n - 1, start, end, value)

    def summarize_range(self, start: int, end: int) -> T:
        """
        Query a range of values in the segment tree.

        Args:
            start: Start index of the range to query (inclusive)
            end: End index of the range to query (inclusive)

        Returns:
            Summary value for the range according to the summary operation

        Raises:
            ValueError: If start > end or indices are out of bounds
        """
        if start > end:
            raise ValueError("Start index must be less than or equal to end index")
        if start < 0 or start >= self.n:
            raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
        if end < 0 or end >= self.n:
            raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")

        return self._query_range_helper(1, 0, self.n - 1, start, end)
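
The deleted module is a standard lazy-propagation segment tree in a 1-indexed heap layout; both update_range and summarize_range touch O(log n) nodes. As a quick, illustrative sanity check of the sizing logic in __init__ (not part of the diff):

# Mirror of the sizing loop above: round n up to a power of two, then double it
# so every internal node and leaf of the 1-indexed heap layout has a slot.
def tree_size(n: int) -> int:
    size = 1
    while size < n:
        size *= 2
    return size * 2

assert tree_size(1) == 2
assert tree_size(5) == 16   # 5 leaves -> 8 after rounding up -> 16 slots
assert tree_size(8) == 16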


@@ -48,7 +48,6 @@ from ..utils import (
    cache_on_self,
    DelayReplaceLine,
    get_benchmark_name,
    get_dtype_size,
    IndentedBuffer,
    is_codegen_graph_partition_subgraph,
    is_using_cudagraph_partition,
@@ -588,64 +587,10 @@ class MemoryPlanningLine(WrapperLine):
        return f"{type(self).__name__}({', '.join(args)})"


class EfficientPeakEstimate:
    def __init__(self):
        from ..memory import estimate_peak_memory, get_freeable_input_buf

        scheduler_nodes = V.graph.scheduler.nodes
        graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
        graph_outputs = OrderedSet(V.graph.get_output_names())
        names_to_freeable_bufs = get_freeable_input_buf(scheduler_nodes, graph_inputs)
        self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory(
            scheduler_nodes,
            names_to_freeable_bufs,
            graph_outputs,
        )

        from .segmented_tree import SegmentedTree

        self.segmented_tree = SegmentedTree(
            peak_by_scheduler_node, operator.add, max, 0
        )

    def _get_size(self, node: BufferLike) -> int:
        return V.graph.sizevars.size_hint(
            V.graph.get_allocation_storage_size(node), fallback=0
        ) * get_dtype_size(node.get_dtype())

    def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
        return self.segmented_tree.summarize_range(
            line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1
        )

    def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
        if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index:
            return
        self.segmented_tree.update_range(
            line_a.scheduler_node_index + 1,
            line_b.scheduler_node_index - 1,
            self._get_size(line_b.node),
        )


@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    node: BufferLike

    def __post_init__(self):
        assert V.graph.scheduler.current_node is not None
        self.scheduler_node_index = V.graph.scheduler.nodes.index(
            V.graph.scheduler.current_node
        )

    def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool:
        if free_line.scheduler_node_index + 1 == self.scheduler_node_index:
            return True

        overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory
        peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self)
        new_peak_memory = size + peak_memory_in_range
        return new_peak_memory <= overall_peak_memory

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)
@@ -654,16 +599,8 @@ class AllocateLine(MemoryPlanningLine):
        key = buffer_reuse_key(self.node)
        if config.allow_buffer_reuse and key in state:
            free_line = state.pop(key)
            size = V.graph.sizevars.size_hint(
                V.graph.get_allocation_storage_size(self.node), fallback=0
            ) * get_dtype_size(self.node.get_dtype())
            if self.should_reuse_buffer(free_line, size):
                free_line.is_reused = True
                self.wrapper.estimate_peak.update_peak_between(free_line, self)
                return ReuseLine(self.wrapper, free_line.node, self.node)
            else:
                state.push(key, free_line)
                return self
            free_line.is_reused = True
            return ReuseLine(self.wrapper, free_line.node, self.node)

        if self.node.get_device_or_error().type == "cpu":
            static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
@@ -688,12 +625,6 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
    node: BufferLike
    is_reused: bool = False

    def __post_init__(self):
        assert V.graph.scheduler.current_node is not None
        self.scheduler_node_index = V.graph.scheduler.nodes.index(
            V.graph.scheduler.current_node
        )

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if len(self.node.get_inputs_that_alias_output()) > 0:
            return self
@@ -1714,7 +1645,6 @@ class PythonWrapperCodegen(CodeGen):
        if is_inference and config.memory_planning:
            self.memory_plan()
        else:
            self.estimate_peak = EfficientPeakEstimate()
            self.memory_plan_reuse()

    def codegen_input_symbol_assignment(
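
The heart of the reverted wrapper change is the check in should_reuse_buffer: a freed buffer is handed to a later allocation only if keeping its storage alive across the scheduler nodes in between cannot raise the graph's estimated peak. Stripped of inductor's classes, the gate reduces to roughly the following (a sketch with made-up names, not inductor API):

def would_reuse(buffer_size: int, peak_between: int, overall_peak: int) -> bool:
    # buffer_size: bytes the new allocation needs (storage size * dtype size).
    # peak_between: estimated peak over the nodes between the free site and the
    #   allocation site, as read from EfficientPeakEstimate's segment tree.
    # overall_peak: estimated peak of the whole graph.
    # Reuse extends the freed storage's lifetime through that gap, so it is only
    # allowed when the extra bytes still fit under the overall peak.
    return buffer_size + peak_between <= overall_peak

In the removed code, adjacent free/allocate pairs short-circuit to reuse, and a successful reuse calls update_peak_between, which adds the buffer's size onto that range of the segment tree so later decisions see the raised baseline.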


@@ -2073,7 +2073,6 @@ class Scheduler:
        )
        self.nodes = [self.create_scheduler_node(n) for n in nodes]
        self.current_node: Optional[BaseSchedulerNode] = None
        self.update_zero_dim_cpu_tensor()

        # some new constants could have been created above
        self.available_buffer_names.update(V.graph.constants.keys())
@@ -4990,7 +4989,6 @@
                assert device.index is not None, "device should have an index"
                V.graph.wrapper_code.codegen_device_guard_enter(device.index)

            self.current_node = node
            self.buffer_names_to_free.update(node.last_usage)

            if node.is_template():