mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
62db8ec391
commit
65d21dae18
254
test/inductor/test_segmented_tree.py
Normal file
254
test/inductor/test_segmented_tree.py
Normal file
@ -0,0 +1,254 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
from torch._inductor.codegen.segmented_tree import SegmentedTree
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
|
||||
|
||||
# Helper functions for operations
|
||||
def max_op(a, b):
|
||||
return max(a, b)
|
||||
|
||||
|
||||
def add_op(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
# Naive implementations for reference
|
||||
def naive_range_max(arr, start, end):
|
||||
return max(arr[start : end + 1])
|
||||
|
||||
|
||||
def naive_range_update(arr, start, end, value):
|
||||
for i in range(start, end + 1):
|
||||
arr[i] += value
|
||||
|
||||
|
||||
# Strategies for hypothesis testing
|
||||
positive_integers = st.lists(
|
||||
st.integers(min_value=1, max_value=100), min_size=1, max_size=50
|
||||
)
|
||||
|
||||
|
||||
def valid_range_indices(array_length):
|
||||
return st.tuples(
|
||||
st.integers(min_value=0, max_value=array_length - 1),
|
||||
st.integers(min_value=0, max_value=array_length - 1),
|
||||
).map(lambda x: (min(x), max(x)))
|
||||
|
||||
|
||||
update_values = st.integers(min_value=1, max_value=50)
|
||||
|
||||
|
||||
class TestSegmentedTree(TestCase):
|
||||
# Basic construction and initialization tests
|
||||
def test_basic_construction(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
assert tree.summarize_range(0, 4) == 9
|
||||
|
||||
def test_empty_array(self):
|
||||
with self.assertRaises(ValueError):
|
||||
SegmentedTree([], add_op, max_op, 0)
|
||||
|
||||
# Property-based tests
|
||||
@given(values=positive_integers)
|
||||
def test_max_query_matches_naive(self, values):
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
for start in range(len(values)):
|
||||
for end in range(start, len(values)):
|
||||
expected = naive_range_max(values, start, end)
|
||||
actual = tree.summarize_range(start, end)
|
||||
assert actual == expected, (
|
||||
f"Range [{start}:{end}] expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
@given(
|
||||
values=positive_integers, range_indices=st.data(), update_value=update_values
|
||||
)
|
||||
def test_range_update(self, values, range_indices, update_value):
|
||||
# Create a copy for naive implementation
|
||||
naive_values = values.copy()
|
||||
|
||||
# Create segment tree
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Get valid range indices
|
||||
start, end = range_indices.draw(valid_range_indices(len(values)))
|
||||
|
||||
# Apply updates
|
||||
tree.update_range(start, end, update_value)
|
||||
naive_range_update(naive_values, start, end, update_value)
|
||||
|
||||
# Verify all possible ranges
|
||||
for i in range(len(values)):
|
||||
for j in range(i, len(values)):
|
||||
expected = naive_range_max(naive_values, i, j)
|
||||
actual = tree.summarize_range(i, j)
|
||||
assert actual == expected, (
|
||||
f"After update, range [{i}:{j}] expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
@given(values=positive_integers, range_data=st.data())
|
||||
def test_multiple_operations(self, values, range_data):
|
||||
# Create a copy for naive implementation
|
||||
naive_values = values.copy()
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Perform multiple operations
|
||||
num_operations = 5
|
||||
for _ in range(num_operations):
|
||||
# Randomly choose between query and update
|
||||
operation_type = range_data.draw(st.sampled_from(["query", "update"]))
|
||||
start, end = range_data.draw(valid_range_indices(len(values)))
|
||||
|
||||
if operation_type == "query":
|
||||
expected = naive_range_max(naive_values, start, end)
|
||||
actual = tree.summarize_range(start, end)
|
||||
assert actual == expected, (
|
||||
f"Range query [{start}:{end}] expected {expected}, got {actual}"
|
||||
)
|
||||
else: # update
|
||||
update_value = range_data.draw(update_values)
|
||||
tree.update_range(start, end, update_value)
|
||||
naive_range_update(naive_values, start, end, update_value)
|
||||
|
||||
def test_single_element_ranges(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
for i in range(len(values)):
|
||||
assert tree.summarize_range(i, i) == values[i], (
|
||||
f"Single element range at index {i} failed"
|
||||
)
|
||||
|
||||
def test_full_array_range(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test querying the entire array
|
||||
assert tree.summarize_range(0, len(values) - 1) == max(values)
|
||||
|
||||
# Update the entire array and test again
|
||||
update_value = 10
|
||||
tree.update_range(0, len(values) - 1, update_value)
|
||||
expected = max([v + update_value for v in values])
|
||||
assert tree.summarize_range(0, len(values) - 1) == expected
|
||||
|
||||
def test_boundary_conditions(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test first element
|
||||
assert tree.summarize_range(0, 0) == values[0]
|
||||
|
||||
# Test last element
|
||||
assert tree.summarize_range(len(values) - 1, len(values) - 1) == values[-1]
|
||||
|
||||
# Test first two elements
|
||||
assert tree.summarize_range(0, 1) == max(values[0:2])
|
||||
|
||||
# Test last two elements
|
||||
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(
|
||||
values[-2:]
|
||||
)
|
||||
|
||||
def test_invalid_ranges(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test start > end
|
||||
with self.assertRaises(ValueError):
|
||||
tree.summarize_range(3, 2)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tree.update_range(4, 2, 10)
|
||||
|
||||
def test_out_of_bounds(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test negative indices
|
||||
with self.assertRaises(ValueError):
|
||||
tree.summarize_range(-1, 3)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tree.summarize_range(0, -1)
|
||||
|
||||
# Test indices >= n
|
||||
with self.assertRaises(ValueError):
|
||||
tree.summarize_range(0, len(values))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tree.summarize_range(len(values), len(values) + 1)
|
||||
|
||||
# Test update with out of bounds indices
|
||||
with self.assertRaises(ValueError):
|
||||
tree.update_range(-1, 3, 10)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tree.update_range(0, len(values), 10)
|
||||
|
||||
def test_overlapping_updates(self):
|
||||
values = [1, 3, 5, 7, 9]
|
||||
naive_values = values.copy()
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Apply overlapping updates
|
||||
tree.update_range(0, 2, 5) # Update [0, 1, 2]
|
||||
naive_range_update(naive_values, 0, 2, 5)
|
||||
|
||||
tree.update_range(1, 3, 3) # Update [1, 2, 3]
|
||||
naive_range_update(naive_values, 1, 3, 3)
|
||||
|
||||
# Verify all possible ranges
|
||||
for i in range(len(values)):
|
||||
for j in range(i, len(values)):
|
||||
expected = naive_range_max(naive_values, i, j)
|
||||
actual = tree.summarize_range(i, j)
|
||||
assert actual == expected, (
|
||||
f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
def test_sequential_updates_and_queries(self):
|
||||
values = [2, 4, 6, 8, 10, 12, 14]
|
||||
naive_values = values.copy()
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Sequence of operations
|
||||
operations = [
|
||||
("update", 1, 3, 5), # Update range [1, 2, 3] with +5
|
||||
("query", 0, 4), # Query range [0, 1, 2, 3, 4]
|
||||
("update", 2, 5, 3), # Update range [2, 3, 4, 5] with +3
|
||||
("query", 1, 3), # Query range [1, 2, 3]
|
||||
("update", 0, 6, 2), # Update entire array with +2
|
||||
("query", 0, 6), # Query entire array
|
||||
("query", 3, 5), # Query range [3, 4, 5]
|
||||
]
|
||||
|
||||
for op in operations:
|
||||
if op[0] == "update":
|
||||
_, start, end, value = op
|
||||
tree.update_range(start, end, value)
|
||||
naive_range_update(naive_values, start, end, value)
|
||||
|
||||
# Verify tree state after update
|
||||
for i in range(len(values)):
|
||||
for j in range(i, len(values)):
|
||||
expected = naive_range_max(naive_values, i, j)
|
||||
actual = tree.summarize_range(i, j)
|
||||
assert actual == expected, (
|
||||
f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
|
||||
)
|
||||
else: # query
|
||||
_, start, end = op
|
||||
expected = naive_range_max(naive_values, start, end)
|
||||
assert tree.summarize_range(start, end) == expected, (
|
||||
f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -13754,6 +13754,45 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
|
||||
self.assertEqual(has_lowered, can_lower)
|
||||
|
||||
@staticmethod
|
||||
def _is_triggering_buffer_reuse(fn, *inputs):
|
||||
with config.patch(allow_buffer_reuse=True):
|
||||
_, (code_allowed,) = run_and_get_code(fn, *inputs)
|
||||
with config.patch(allow_buffer_reuse=False):
|
||||
_, (code_disallowed,) = run_and_get_code(fn, *inputs)
|
||||
code_allowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_allowed)
|
||||
code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed)
|
||||
return code_allowed != code_disallowed
|
||||
|
||||
def test_allow_reuse_disable_if_exceed_peak(self):
|
||||
@torch.compile
|
||||
def fn(inp): # 1*N^2
|
||||
a = inp.mean(-1) # 1*N^2 + N
|
||||
b = (inp - a) ** 2 # 2*N^2 + N
|
||||
c = b @ b # 3*N^2 (!!) since this is the peak, can not reuse across
|
||||
d = c.mean(-1) # 2*N^2 + N
|
||||
return d # 1*N^2 + N
|
||||
|
||||
inp = torch.randn(100, 100, device=self.device)
|
||||
self.assertFalse(CommonTemplate._is_triggering_buffer_reuse(fn, inp))
|
||||
|
||||
def test_allow_reuse_active_if_under_peak(self):
|
||||
def g(inp):
|
||||
return (inp - torch.logsumexp(inp, -1)) ** 2
|
||||
|
||||
@torch.compile
|
||||
def fn(m, inp):
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
return inp
|
||||
|
||||
m = torch.randn(100, 100, device=self.device)
|
||||
inp = torch.randn(100, 100, device=self.device)
|
||||
self.assertTrue(CommonTemplate._is_triggering_buffer_reuse(fn, m, inp))
|
||||
|
||||
# end of class CommonTemplate - add new tests here
|
||||
|
||||
|
||||
|
241
torch/_inductor/codegen/segmented_tree.py
Normal file
241
torch/_inductor/codegen/segmented_tree.py
Normal file
@ -0,0 +1,241 @@
|
||||
from typing import Callable, Generic, Optional, TypeVar
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _value_or(opt: Optional[T], default: T) -> T:
|
||||
return opt if opt is not None else default
|
||||
|
||||
|
||||
class SegmentedTree(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
values: list[T],
|
||||
update_op: Callable[[T, T], T],
|
||||
summary_op: Callable[[T, T], T],
|
||||
identity_element: T,
|
||||
):
|
||||
"""
|
||||
Initialize a segment tree with the given values and operations.
|
||||
|
||||
Args:
|
||||
values: list of initial values
|
||||
update_op: Function to apply when updating a value (e.g., addition)
|
||||
summary_op: Function to summarize two values (e.g., min, max, sum)
|
||||
identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input values list is empty
|
||||
"""
|
||||
if not values:
|
||||
raise ValueError("Cannot create a segment tree with empty values list")
|
||||
|
||||
self.n = len(values)
|
||||
self.update_op = update_op
|
||||
self.summary_op = summary_op
|
||||
self.identity = identity_element
|
||||
|
||||
# Size of segment tree array (next power of 2 * 2)
|
||||
# The tree follows a standard heap layout where
|
||||
# node `n`'s children are at `2*n` and `2*n+1`.
|
||||
# Index 0 is unused.
|
||||
self.size = 1
|
||||
while self.size < self.n:
|
||||
self.size *= 2
|
||||
self.size *= 2
|
||||
|
||||
# Initialize tree and lazy arrays
|
||||
self.tree = [identity_element] * self.size
|
||||
# The lazy array contains updates to the given node
|
||||
# Upon update, we only push updates to the top-most
|
||||
# nodes that fully receive the update. We then
|
||||
# propagate the update down as required (i.e., when
|
||||
# we receive an interval query that neither fully
|
||||
# contains the node nor fully doesn't contain the
|
||||
# node
|
||||
self.lazy: list[Optional[T]] = [None] * self.size
|
||||
|
||||
# Build the tree
|
||||
self._build(values, 1, 0, self.n - 1)
|
||||
|
||||
def _build(self, values: list[T], node: int, start: int, end: int) -> None:
|
||||
"""
|
||||
Build the segment tree recursively.
|
||||
|
||||
Args:
|
||||
values: Original array of values
|
||||
node: Current node index in the segment tree
|
||||
start: Start index of the segment
|
||||
end: End index of the segment
|
||||
"""
|
||||
if start == end:
|
||||
# Leaf node
|
||||
if start < len(values):
|
||||
self.tree[node] = values[start]
|
||||
return
|
||||
|
||||
mid = (start + end) // 2
|
||||
left_child = 2 * node
|
||||
right_child = 2 * node + 1
|
||||
|
||||
# Recursively build left and right subtrees
|
||||
self._build(values, left_child, start, mid)
|
||||
self._build(values, right_child, mid + 1, end)
|
||||
|
||||
# Update current node with summary of children
|
||||
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
||||
|
||||
def _children(self, node: int) -> list[int]:
|
||||
return [2 * node, 2 * node + 1]
|
||||
|
||||
def _push_lazy(self, node: int, start: int, end: int) -> None:
|
||||
"""
|
||||
Push lazy updates down to children.
|
||||
|
||||
Args:
|
||||
node: Current node index
|
||||
start: Start index of the segment
|
||||
end: End index of the segment
|
||||
"""
|
||||
lazy_node = self.lazy[node]
|
||||
if lazy_node is None:
|
||||
return
|
||||
|
||||
# Apply lazy update to current node
|
||||
self.tree[node] = self.update_op(self.tree[node], lazy_node)
|
||||
|
||||
if start != end: # Not a leaf node
|
||||
# Propagate to children
|
||||
for child in self._children(node):
|
||||
self.lazy[child] = self.update_op(
|
||||
_value_or(self.lazy[child], self.identity), lazy_node
|
||||
)
|
||||
|
||||
# Clear the lazy value
|
||||
self.lazy[node] = None
|
||||
|
||||
def _update_range_helper(
|
||||
self, node: int, start: int, end: int, left: int, right: int, value: T
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to update a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
node: Current node index
|
||||
start: Start index of the current segment
|
||||
end: End index of the current segment
|
||||
left: Start index of the range to update
|
||||
right: End index of the range to update
|
||||
value: Value to apply to the range
|
||||
"""
|
||||
# Push lazy updates before processing this node
|
||||
self._push_lazy(node, start, end)
|
||||
|
||||
# No overlap
|
||||
if start > right or end < left:
|
||||
return
|
||||
|
||||
# Complete overlap
|
||||
if start >= left and end <= right:
|
||||
# Apply update to current node
|
||||
self.lazy[node] = value
|
||||
self._push_lazy(node, start, end)
|
||||
return
|
||||
|
||||
# Partial overlap, recurse to children
|
||||
mid = (start + end) // 2
|
||||
left_child = 2 * node
|
||||
right_child = 2 * node + 1
|
||||
|
||||
self._update_range_helper(left_child, start, mid, left, right, value)
|
||||
self._update_range_helper(right_child, mid + 1, end, left, right, value)
|
||||
|
||||
# Update current node based on children
|
||||
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
||||
|
||||
def _query_range_helper(
|
||||
self, node: int, start: int, end: int, left: int, right: int
|
||||
) -> T:
|
||||
"""
|
||||
Helper method to query a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
node: Current node index
|
||||
start: Start index of the current segment
|
||||
end: End index of the current segment
|
||||
left: Start index of the range to query
|
||||
right: End index of the range to query
|
||||
|
||||
Returns:
|
||||
Summary value for the range
|
||||
"""
|
||||
# No overlap
|
||||
if start > right or end < left:
|
||||
return self.identity
|
||||
|
||||
# Push lazy updates before processing this node
|
||||
self._push_lazy(node, start, end)
|
||||
|
||||
# Complete overlap
|
||||
if start >= left and end <= right:
|
||||
return self.tree[node]
|
||||
|
||||
# Partial overlap, recurse to children
|
||||
mid = (start + end) // 2
|
||||
left_child = 2 * node
|
||||
right_child = 2 * node + 1
|
||||
|
||||
left_result = self._query_range_helper(left_child, start, mid, left, right)
|
||||
right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
|
||||
|
||||
# Combine results from children
|
||||
return self.summary_op(left_result, right_result)
|
||||
|
||||
def update_range(self, start: int, end: int, value: T) -> None:
|
||||
"""
|
||||
Update a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
start: Start index of the range to update (inclusive)
|
||||
end: End index of the range to update (inclusive)
|
||||
value: Value to apply to the range
|
||||
|
||||
Raises:
|
||||
ValueError: If start > end or indices are out of bounds
|
||||
"""
|
||||
if start > end:
|
||||
raise ValueError("Start index must be less than or equal to end index")
|
||||
|
||||
if start < 0 or start >= self.n:
|
||||
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
if end < 0 or end >= self.n:
|
||||
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
self._update_range_helper(1, 0, self.n - 1, start, end, value)
|
||||
|
||||
def summarize_range(self, start: int, end: int) -> T:
|
||||
"""
|
||||
Query a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
start: Start index of the range to query (inclusive)
|
||||
end: End index of the range to query (inclusive)
|
||||
|
||||
Returns:
|
||||
Summary value for the range according to the summary operation
|
||||
|
||||
Raises:
|
||||
ValueError: If start > end or indices are out of bounds
|
||||
"""
|
||||
if start > end:
|
||||
raise ValueError("Start index must be less than or equal to end index")
|
||||
|
||||
if start < 0 or start >= self.n:
|
||||
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
if end < 0 or end >= self.n:
|
||||
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
return self._query_range_helper(1, 0, self.n - 1, start, end)
|
@ -48,6 +48,7 @@ from ..utils import (
|
||||
cache_on_self,
|
||||
DelayReplaceLine,
|
||||
get_benchmark_name,
|
||||
get_dtype_size,
|
||||
IndentedBuffer,
|
||||
is_codegen_graph_partition_subgraph,
|
||||
is_using_cudagraph_partition,
|
||||
@ -587,10 +588,64 @@ class MemoryPlanningLine(WrapperLine):
|
||||
return f"{type(self).__name__}({', '.join(args)})"
|
||||
|
||||
|
||||
class EfficientPeakEstimate:
|
||||
def __init__(self):
|
||||
from ..memory import estimate_peak_memory, get_freeable_input_buf
|
||||
|
||||
scheduler_nodes = V.graph.scheduler.nodes
|
||||
graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
|
||||
graph_outputs = OrderedSet(V.graph.get_output_names())
|
||||
names_to_freeable_bufs = get_freeable_input_buf(scheduler_nodes, graph_inputs)
|
||||
self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory(
|
||||
scheduler_nodes,
|
||||
names_to_freeable_bufs,
|
||||
graph_outputs,
|
||||
)
|
||||
|
||||
from .segmented_tree import SegmentedTree
|
||||
|
||||
self.segmented_tree = SegmentedTree(
|
||||
peak_by_scheduler_node, operator.add, max, 0
|
||||
)
|
||||
|
||||
def _get_size(self, node: BufferLike) -> int:
|
||||
return V.graph.sizevars.size_hint(
|
||||
V.graph.get_allocation_storage_size(node), fallback=0
|
||||
) * get_dtype_size(node.get_dtype())
|
||||
|
||||
def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
|
||||
return self.segmented_tree.summarize_range(
|
||||
line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1
|
||||
)
|
||||
|
||||
def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
|
||||
if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index:
|
||||
return
|
||||
self.segmented_tree.update_range(
|
||||
line_a.scheduler_node_index + 1,
|
||||
line_b.scheduler_node_index - 1,
|
||||
self._get_size(line_b.node),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocateLine(MemoryPlanningLine):
|
||||
node: BufferLike
|
||||
|
||||
def __post_init__(self):
|
||||
assert V.graph.scheduler.current_node is not None
|
||||
self.scheduler_node_index = V.graph.scheduler.nodes.index(
|
||||
V.graph.scheduler.current_node
|
||||
)
|
||||
|
||||
def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool:
|
||||
if free_line.scheduler_node_index + 1 == self.scheduler_node_index:
|
||||
return True
|
||||
overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory
|
||||
peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self)
|
||||
new_peak_memory = size + peak_memory_in_range
|
||||
return new_peak_memory <= overall_peak_memory
|
||||
|
||||
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
||||
if self.node.get_name() in V.graph.removed_buffers:
|
||||
return NullLine(self.wrapper)
|
||||
@ -599,8 +654,16 @@ class AllocateLine(MemoryPlanningLine):
|
||||
key = buffer_reuse_key(self.node)
|
||||
if config.allow_buffer_reuse and key in state:
|
||||
free_line = state.pop(key)
|
||||
free_line.is_reused = True
|
||||
return ReuseLine(self.wrapper, free_line.node, self.node)
|
||||
size = V.graph.sizevars.size_hint(
|
||||
V.graph.get_allocation_storage_size(self.node), fallback=0
|
||||
) * get_dtype_size(self.node.get_dtype())
|
||||
if self.should_reuse_buffer(free_line, size):
|
||||
free_line.is_reused = True
|
||||
self.wrapper.estimate_peak.update_peak_between(free_line, self)
|
||||
return ReuseLine(self.wrapper, free_line.node, self.node)
|
||||
else:
|
||||
state.push(key, free_line)
|
||||
return self
|
||||
|
||||
if self.node.get_device_or_error().type == "cpu":
|
||||
static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
|
||||
@ -625,6 +688,12 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
|
||||
node: BufferLike
|
||||
is_reused: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
assert V.graph.scheduler.current_node is not None
|
||||
self.scheduler_node_index = V.graph.scheduler.nodes.index(
|
||||
V.graph.scheduler.current_node
|
||||
)
|
||||
|
||||
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
||||
if len(self.node.get_inputs_that_alias_output()) > 0:
|
||||
return self
|
||||
@ -1645,6 +1714,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||
if is_inference and config.memory_planning:
|
||||
self.memory_plan()
|
||||
else:
|
||||
self.estimate_peak = EfficientPeakEstimate()
|
||||
self.memory_plan_reuse()
|
||||
|
||||
def codegen_input_symbol_assignment(
|
||||
|
@ -2073,6 +2073,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
self.nodes = [self.create_scheduler_node(n) for n in nodes]
|
||||
self.current_node: Optional[BaseSchedulerNode] = None
|
||||
self.update_zero_dim_cpu_tensor()
|
||||
# some new constants could have been created above
|
||||
self.available_buffer_names.update(V.graph.constants.keys())
|
||||
@ -4989,6 +4990,7 @@ class Scheduler:
|
||||
assert device.index is not None, "device should have an index"
|
||||
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
|
||||
|
||||
self.current_node = node
|
||||
self.buffer_names_to_free.update(node.last_usage)
|
||||
|
||||
if node.is_template():
|
||||
|
Reference in New Issue
Block a user