# Owner(s): ["module: inductor"]

from hypothesis import given, strategies as st

from torch._inductor.codegen.segmented_tree import SegmentedTree
from torch._inductor.test_case import run_tests, TestCase


# Binary operations handed to SegmentedTree
def max_op(a, b):
    """Summary operation: the larger of the two operands."""
    return max(a, b)


def add_op(a, b):
    """Update operation: the sum of the two operands."""
    return a + b


# Brute-force reference implementations
def naive_range_max(arr, start, end):
    """Return the maximum of ``arr[start..end]`` (both ends inclusive)."""
    window = arr[start : end + 1]
    return max(window)


def naive_range_update(arr, start, end, value):
    """Add ``value`` in place to each element of ``arr[start..end]`` (inclusive)."""
    for idx in range(start, end + 1):
        arr[idx] += value


# Hypothesis strategies
positive_integers = st.lists(
    st.integers(min_value=1, max_value=100), min_size=1, max_size=50
)


def valid_range_indices(array_length):
    """Strategy for ``(start, end)`` pairs with ``0 <= start <= end < array_length``."""
    index = st.integers(min_value=0, max_value=array_length - 1)
    return st.tuples(index, index).map(lambda pair: (min(pair), max(pair)))


update_values = st.integers(min_value=1, max_value=50)


def _inclusive_ranges(length):
    """Yield every ``(i, j)`` with ``0 <= i <= j < length``, ordered by ``i`` then ``j``."""
    for i in range(length):
        for j in range(i, length):
            yield i, j


class TestSegmentedTree(TestCase):
    """Compare SegmentedTree (add as update, max as summary) with the naive helpers."""

    # Construction
    def test_basic_construction(self):
        tree = SegmentedTree([1, 3, 5, 7, 9], add_op, max_op, 0)
        assert tree.summarize_range(0, 4) == 9

    def test_empty_array(self):
        with self.assertRaises(ValueError):
            SegmentedTree([], add_op, max_op, 0)

    # Property-based tests
    @given(values=positive_integers)
    def test_max_query_matches_naive(self, values):
        tree = SegmentedTree(values, add_op, max_op, 0)

        for start, end in _inclusive_ranges(len(values)):
            expected = naive_range_max(values, start, end)
            actual = tree.summarize_range(start, end)
            assert actual == expected, (
                f"Range [{start}:{end}] expected {expected}, got {actual}"
            )

    @given(
        values=positive_integers, range_indices=st.data(), update_value=update_values
    )
    def test_range_update(self, values, range_indices, update_value):
        # The naive side mutates its own copy so the tree's input stays untouched
        reference = values.copy()
        tree = SegmentedTree(values, add_op, max_op, 0)

        start, end = range_indices.draw(valid_range_indices(len(values)))

        # Apply the same update to both implementations
        tree.update_range(start, end, update_value)
        naive_range_update(reference, start, end, update_value)

        # Every possible range must agree afterwards
        for i, j in _inclusive_ranges(len(values)):
            expected = naive_range_max(reference, i, j)
            actual = tree.summarize_range(i, j)
            assert actual == expected, (
                f"After update, range [{i}:{j}] expected {expected}, got {actual}"
            )

    @given(values=positive_integers, range_data=st.data())
    def test_multiple_operations(self, values, range_data):
        reference = values.copy()
        tree = SegmentedTree(values, add_op, max_op, 0)

        # Five randomly chosen operations, each either a query or an update
        for _ in range(5):
            operation_type = range_data.draw(st.sampled_from(["query", "update"]))
            start, end = range_data.draw(valid_range_indices(len(values)))

            if operation_type != "query":
                update_value = range_data.draw(update_values)
                tree.update_range(start, end, update_value)
                naive_range_update(reference, start, end, update_value)
                continue

            expected = naive_range_max(reference, start, end)
            actual = tree.summarize_range(start, end)
            assert actual == expected, (
                f"Range query [{start}:{end}] expected {expected}, got {actual}"
            )

    def test_single_element_ranges(self):
        values = [1, 3, 5, 7, 9]
        tree = SegmentedTree(values, add_op, max_op, 0)

        for i, value in enumerate(values):
            assert tree.summarize_range(i, i) == value, (
                f"Single element range at index {i} failed"
            )

    def test_full_array_range(self):
        values = [1, 3, 5, 7, 9]
        last = len(values) - 1
        tree = SegmentedTree(values, add_op, max_op, 0)

        # Whole-array query before any update
        assert tree.summarize_range(0, last) == max(values)

        # Whole-array update, then query again
        update_value = 10
        tree.update_range(0, last, update_value)
        expected = max(v + update_value for v in values)
        assert tree.summarize_range(0, last) == expected

    def test_boundary_conditions(self):
        values = [1, 3, 5, 7, 9]
        last = len(values) - 1
        tree = SegmentedTree(values, add_op, max_op, 0)

        # First and last element on their own
        assert tree.summarize_range(0, 0) == values[0]
        assert tree.summarize_range(last, last) == values[-1]

        # First two and last two elements
        assert tree.summarize_range(0, 1) == max(values[0:2])
        assert tree.summarize_range(last - 1, last) == max(values[-2:])

    def test_invalid_ranges(self):
        tree = SegmentedTree([1, 3, 5, 7, 9], add_op, max_op, 0)

        # start > end is rejected by both entry points
        with self.assertRaises(ValueError):
            tree.summarize_range(3, 2)

        with self.assertRaises(ValueError):
            tree.update_range(4, 2, 10)

    def test_out_of_bounds(self):
        values = [1, 3, 5, 7, 9]
        n = len(values)
        tree = SegmentedTree(values, add_op, max_op, 0)

        # Negative indices, then indices >= n
        for start, end in [(-1, 3), (0, -1), (0, n), (n, n + 1)]:
            with self.assertRaises(ValueError):
                tree.summarize_range(start, end)

        # Updates are validated the same way
        for start, end in [(-1, 3), (0, n)]:
            with self.assertRaises(ValueError):
                tree.update_range(start, end, 10)

    def test_overlapping_updates(self):
        values = [1, 3, 5, 7, 9]
        reference = values.copy()
        tree = SegmentedTree(values, add_op, max_op, 0)

        # Indices 1 and 2 receive both updates
        for start, end, delta in [(0, 2, 5), (1, 3, 3)]:
            tree.update_range(start, end, delta)
            naive_range_update(reference, start, end, delta)

        for i, j in _inclusive_ranges(len(values)):
            expected = naive_range_max(reference, i, j)
            actual = tree.summarize_range(i, j)
            assert actual == expected, (
                f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
            )

    def test_sequential_updates_and_queries(self):
        values = [2, 4, 6, 8, 10, 12, 14]
        reference = values.copy()
        tree = SegmentedTree(values, add_op, max_op, 0)

        # ("update", start, end, delta) or ("query", start, end)
        operations = [
            ("update", 1, 3, 5),
            ("query", 0, 4),
            ("update", 2, 5, 3),
            ("query", 1, 3),
            ("update", 0, 6, 2),
            ("query", 0, 6),
            ("query", 3, 5),
        ]

        for kind, *args in operations:
            if kind == "update":
                start, end, value = args
                tree.update_range(start, end, value)
                naive_range_update(reference, start, end, value)

                # The whole tree must match the reference after each update
                for i, j in _inclusive_ranges(len(values)):
                    expected = naive_range_max(reference, i, j)
                    actual = tree.summarize_range(i, j)
                    assert actual == expected, (
                        f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
                    )
            else:
                start, end = args
                expected = naive_range_max(reference, start, end)
                assert tree.summarize_range(start, end) == expected, (
                    f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
                )


if __name__ == "__main__":
    run_tests()
from typing import Callable, Generic, Optional, TypeVar


T = TypeVar("T")


def _value_or(opt: Optional[T], default: T) -> T:
    """Return ``opt`` unless it is ``None``, in which case return ``default``.

    Kept for backward compatibility; ``SegmentedTree`` no longer uses it.
    """
    return opt if opt is not None else default


class SegmentedTree(Generic[T]):
    """
    Segment tree with lazy propagation over a fixed-length list of values.

    ``update_range`` applies ``update_op(x, value)`` to every element of an
    inclusive index range; ``summarize_range`` folds an inclusive index range
    with ``summary_op``.

    A pending update is applied directly to a node's stored summary, and
    successive pending updates on one node are merged with ``update_op``.
    Results are therefore only correct when ``update_op`` distributes over
    ``summary_op``, i.e.
    ``update(summary(a, b), v) == summary(update(a, v), update(b, v))``,
    and when it composes, i.e.
    ``update(update(x, a), b) == update(x, update(a, b))``.
    Addition as update with max or min as summary satisfies both.
    """

    def __init__(
        self,
        values: list[T],
        update_op: Callable[[T, T], T],
        summary_op: Callable[[T, T], T],
        identity_element: T,
    ):
        """
        Initialize a segment tree with the given values and operations.

        Args:
            values: list of initial values
            update_op: Function to apply when updating a value (e.g., addition)
            summary_op: Function to summarize two values (e.g., min, max, sum)
            identity_element: Identity element for the summary_op (e.g., 0 for
                sum, float('inf') for min). It fills unused tree slots and is
                returned for segments outside a query; it is never combined
                with update values.

        Raises:
            ValueError: If the input values list is empty
        """
        if not values:
            raise ValueError("Cannot create a segment tree with empty values list")

        self.n = len(values)
        self.update_op = update_op
        self.summary_op = summary_op
        self.identity = identity_element

        # Size of the backing arrays: twice the next power of two >= n.
        # The tree follows a standard heap layout where node `k`'s children
        # are at `2*k` and `2*k+1`. Index 0 is unused.
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.size *= 2

        # Initialize tree and lazy arrays
        self.tree = [identity_element] * self.size
        # lazy[k] is an update that still has to be applied to node k and its
        # subtree (None means nothing pending). An update is recorded only on
        # the top-most nodes it fully covers and is pushed further down when a
        # later update or query has to descend through the node (i.e., when
        # the requested interval only partially overlaps the node's segment).
        self.lazy: list[Optional[T]] = [None] * self.size

        # Build the tree
        self._build(values, 1, 0, self.n - 1)

    def _build(self, values: list[T], node: int, start: int, end: int) -> None:
        """
        Build the segment tree recursively.

        Args:
            values: Original array of values
            node: Current node index in the segment tree
            start: Start index of the segment
            end: End index of the segment
        """
        if start == end:
            # Leaf node; start is always a valid index because the recursion
            # begins at [0, n - 1] and only ever splits that interval
            self.tree[node] = values[start]
            return

        mid = (start + end) // 2
        left_child, right_child = self._children(node)

        # Recursively build left and right subtrees
        self._build(values, left_child, start, mid)
        self._build(values, right_child, mid + 1, end)

        # Update current node with summary of children
        self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])

    def _children(self, node: int) -> list[int]:
        """Return the indices of the left and right child of ``node``."""
        return [2 * node, 2 * node + 1]

    def _push_lazy(self, node: int, start: int, end: int) -> None:
        """
        Apply the pending update of ``node`` and hand it down to its children.

        Args:
            node: Current node index
            start: Start index of the segment
            end: End index of the segment
        """
        lazy_node = self.lazy[node]
        if lazy_node is None:
            return

        # Apply lazy update to current node
        self.tree[node] = self.update_op(self.tree[node], lazy_node)

        if start != end:  # Not a leaf node
            # Propagate to children. A child with nothing pending simply
            # inherits the update: combining with self.identity here would be
            # wrong, because that is the identity of summary_op, not of
            # update_op (e.g. inf + v == inf for a min tree).
            for child in self._children(node):
                pending = self.lazy[child]
                self.lazy[child] = (
                    lazy_node
                    if pending is None
                    else self.update_op(pending, lazy_node)
                )

        # Clear the lazy value
        self.lazy[node] = None

    def _update_range_helper(
        self, node: int, start: int, end: int, left: int, right: int, value: T
    ) -> None:
        """
        Helper method to update a range of values in the segment tree.

        Args:
            node: Current node index
            start: Start index of the current segment
            end: End index of the current segment
            left: Start index of the range to update
            right: End index of the range to update
            value: Value to apply to the range
        """
        # Push lazy updates before processing this node. This happens even
        # when there is no overlap, so that the parent can safely recompute
        # its summary from self.tree of both children afterwards.
        self._push_lazy(node, start, end)

        # No overlap
        if start > right or end < left:
            return

        # Complete overlap
        if start >= left and end <= right:
            # Record the update on this node and apply it right away
            self.lazy[node] = value
            self._push_lazy(node, start, end)
            return

        # Partial overlap, recurse to children
        mid = (start + end) // 2
        left_child, right_child = self._children(node)

        self._update_range_helper(left_child, start, mid, left, right, value)
        self._update_range_helper(right_child, mid + 1, end, left, right, value)

        # Update current node based on children
        self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])

    def _query_range_helper(
        self, node: int, start: int, end: int, left: int, right: int
    ) -> T:
        """
        Helper method to query a range of values in the segment tree.

        Args:
            node: Current node index
            start: Start index of the current segment
            end: End index of the current segment
            left: Start index of the range to query
            right: End index of the range to query

        Returns:
            Summary value for the range
        """
        # No overlap
        if start > right or end < left:
            return self.identity

        # Push lazy updates before processing this node
        self._push_lazy(node, start, end)

        # Complete overlap
        if start >= left and end <= right:
            return self.tree[node]

        # Partial overlap, recurse to children
        mid = (start + end) // 2
        left_child, right_child = self._children(node)

        left_result = self._query_range_helper(left_child, start, mid, left, right)
        right_result = self._query_range_helper(right_child, mid + 1, end, left, right)

        # Combine results from children
        return self.summary_op(left_result, right_result)

    def _check_range(self, start: int, end: int) -> None:
        """Raise ValueError unless ``0 <= start <= end < self.n``."""
        if start > end:
            raise ValueError("Start index must be less than or equal to end index")

        if start < 0 or start >= self.n:
            raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")

        if end < 0 or end >= self.n:
            raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")

    def update_range(self, start: int, end: int, value: T) -> None:
        """
        Update a range of values in the segment tree.

        Args:
            start: Start index of the range to update (inclusive)
            end: End index of the range to update (inclusive)
            value: Value to apply to the range

        Raises:
            ValueError: If start > end or indices are out of bounds
        """
        self._check_range(start, end)
        self._update_range_helper(1, 0, self.n - 1, start, end, value)

    def summarize_range(self, start: int, end: int) -> T:
        """
        Query a range of values in the segment tree.

        Args:
            start: Start index of the range to query (inclusive)
            end: End index of the range to query (inclusive)

        Returns:
            Summary value for the range according to the summary operation

        Raises:
            ValueError: If start > end or indices are out of bounds
        """
        self._check_range(start, end)
        return self._query_range_helper(1, 0, self.n - 1, start, end)
peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine): + return self.segmented_tree.summarize_range( + line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1 + ) + + def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine): + if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index: + return + self.segmented_tree.update_range( + line_a.scheduler_node_index + 1, + line_b.scheduler_node_index - 1, + self._get_size(line_b.node), + ) + + @dataclasses.dataclass class AllocateLine(MemoryPlanningLine): node: BufferLike + def __post_init__(self): + assert V.graph.scheduler.current_node is not None + self.scheduler_node_index = V.graph.scheduler.nodes.index( + V.graph.scheduler.current_node + ) + + def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool: + if free_line.scheduler_node_index + 1 == self.scheduler_node_index: + return True + overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory + peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self) + new_peak_memory = size + peak_memory_in_range + return new_peak_memory <= overall_peak_memory + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: if self.node.get_name() in V.graph.removed_buffers: return NullLine(self.wrapper) @@ -599,8 +654,16 @@ class AllocateLine(MemoryPlanningLine): key = buffer_reuse_key(self.node) if config.allow_buffer_reuse and key in state: free_line = state.pop(key) - free_line.is_reused = True - return ReuseLine(self.wrapper, free_line.node, self.node) + size = V.graph.sizevars.size_hint( + V.graph.get_allocation_storage_size(self.node), fallback=0 + ) * get_dtype_size(self.node.get_dtype()) + if self.should_reuse_buffer(free_line, size): + free_line.is_reused = True + self.wrapper.estimate_peak.update_peak_between(free_line, self) + return ReuseLine(self.wrapper, free_line.node, self.node) + else: + state.push(key, free_line) + return self if 
self.node.get_device_or_error().type == "cpu": static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) @@ -625,6 +688,12 @@ class FreeIfNotReusedLine(MemoryPlanningLine): node: BufferLike is_reused: bool = False + def __post_init__(self): + assert V.graph.scheduler.current_node is not None + self.scheduler_node_index = V.graph.scheduler.nodes.index( + V.graph.scheduler.current_node + ) + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: if len(self.node.get_inputs_that_alias_output()) > 0: return self @@ -1645,6 +1714,7 @@ class PythonWrapperCodegen(CodeGen): if is_inference and config.memory_planning: self.memory_plan() else: + self.estimate_peak = EfficientPeakEstimate() self.memory_plan_reuse() def codegen_input_symbol_assignment( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index c16d4478145c..71f7f9c8b503 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2073,6 +2073,7 @@ class Scheduler: ) self.nodes = [self.create_scheduler_node(n) for n in nodes] + self.current_node: Optional[BaseSchedulerNode] = None self.update_zero_dim_cpu_tensor() # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys()) @@ -4989,6 +4990,7 @@ class Scheduler: assert device.index is not None, "device should have an index" V.graph.wrapper_code.codegen_device_guard_enter(device.index) + self.current_node = node self.buffer_names_to_free.update(node.last_usage) if node.is_template():