mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 3be70dc30e893b552fc0f23ca06cd8f7949b6d08. Reverted https://github.com/pytorch/pytorch/pull/159530 on behalf of https://github.com/clee2000 due to newly added test fail internally D80316528, probably just a targets change, but also imo the tests should probably go into a testcase class from common or inductor utils. While I'm pretty sure CI can run the globally defined ones, theres some CI related functionality that on the testcase class that CI benefits from ([comment](https://github.com/pytorch/pytorch/pull/159530#issuecomment-3191947506))
This commit is contained in:
@ -1,261 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
from torch._inductor.codegen.segmented_tree import SegmentedTree
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
# Helper functions for operations
|
||||
def max_op(a, b):
|
||||
return max(a, b)
|
||||
|
||||
|
||||
def add_op(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
# Naive implementations for reference
|
||||
def naive_range_max(arr, start, end):
|
||||
return max(arr[start : end + 1])
|
||||
|
||||
|
||||
def naive_range_update(arr, start, end, value):
|
||||
for i in range(start, end + 1):
|
||||
arr[i] += value
|
||||
|
||||
|
||||
# Strategies for hypothesis testing
|
||||
positive_integers = st.lists(
|
||||
st.integers(min_value=1, max_value=100), min_size=1, max_size=50
|
||||
)
|
||||
|
||||
|
||||
def valid_range_indices(array_length):
|
||||
return st.tuples(
|
||||
st.integers(min_value=0, max_value=array_length - 1),
|
||||
st.integers(min_value=0, max_value=array_length - 1),
|
||||
).map(lambda x: (min(x), max(x)))
|
||||
|
||||
|
||||
update_values = st.integers(min_value=1, max_value=50)
|
||||
|
||||
|
||||
# Basic construction and initialization tests
|
||||
def test_basic_construction():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
assert tree.summarize_range(0, 4) == 9
|
||||
|
||||
|
||||
def test_empty_array():
|
||||
with pytest.raises(ValueError):
|
||||
SegmentedTree([], add_op, max_op, 0)
|
||||
|
||||
|
||||
# Property-based tests
|
||||
@given(values=positive_integers)
|
||||
def test_max_query_matches_naive(values):
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
for start in range(len(values)):
|
||||
for end in range(start, len(values)):
|
||||
expected = naive_range_max(values, start, end)
|
||||
actual = tree.summarize_range(start, end)
|
||||
assert actual == expected, (
|
||||
f"Range [{start}:{end}] expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
|
||||
@given(values=positive_integers, range_indices=st.data(), update_value=update_values)
|
||||
def test_range_update(values, range_indices, update_value):
|
||||
# Create a copy for naive implementation
|
||||
naive_values = values.copy()
|
||||
|
||||
# Create segment tree
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Get valid range indices
|
||||
start, end = range_indices.draw(valid_range_indices(len(values)))
|
||||
|
||||
# Apply updates
|
||||
tree.update_range(start, end, update_value)
|
||||
naive_range_update(naive_values, start, end, update_value)
|
||||
|
||||
# Verify all possible ranges
|
||||
for i in range(len(values)):
|
||||
for j in range(i, len(values)):
|
||||
expected = naive_range_max(naive_values, i, j)
|
||||
actual = tree.summarize_range(i, j)
|
||||
assert actual == expected, (
|
||||
f"After update, range [{i}:{j}] expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
|
||||
@given(values=positive_integers, range_data=st.data())
|
||||
def test_multiple_operations(values, range_data):
|
||||
# Create a copy for naive implementation
|
||||
naive_values = values.copy()
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Perform multiple operations
|
||||
num_operations = 5
|
||||
for _ in range(num_operations):
|
||||
# Randomly choose between query and update
|
||||
operation_type = range_data.draw(st.sampled_from(["query", "update"]))
|
||||
start, end = range_data.draw(valid_range_indices(len(values)))
|
||||
|
||||
if operation_type == "query":
|
||||
expected = naive_range_max(naive_values, start, end)
|
||||
actual = tree.summarize_range(start, end)
|
||||
assert actual == expected, (
|
||||
f"Range query [{start}:{end}] expected {expected}, got {actual}"
|
||||
)
|
||||
else: # update
|
||||
update_value = range_data.draw(update_values)
|
||||
tree.update_range(start, end, update_value)
|
||||
naive_range_update(naive_values, start, end, update_value)
|
||||
|
||||
|
||||
def test_single_element_ranges():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
for i in range(len(values)):
|
||||
assert tree.summarize_range(i, i) == values[i], (
|
||||
f"Single element range at index {i} failed"
|
||||
)
|
||||
|
||||
|
||||
def test_full_array_range():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test querying the entire array
|
||||
assert tree.summarize_range(0, len(values) - 1) == max(values)
|
||||
|
||||
# Update the entire array and test again
|
||||
update_value = 10
|
||||
tree.update_range(0, len(values) - 1, update_value)
|
||||
expected = max([v + update_value for v in values])
|
||||
assert tree.summarize_range(0, len(values) - 1) == expected
|
||||
|
||||
|
||||
def test_boundary_conditions():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test first element
|
||||
assert tree.summarize_range(0, 0) == values[0]
|
||||
|
||||
# Test last element
|
||||
assert tree.summarize_range(len(values) - 1, len(values) - 1) == values[-1]
|
||||
|
||||
# Test first two elements
|
||||
assert tree.summarize_range(0, 1) == max(values[0:2])
|
||||
|
||||
# Test last two elements
|
||||
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(values[-2:])
|
||||
|
||||
|
||||
def test_invalid_ranges():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test start > end
|
||||
with pytest.raises(ValueError):
|
||||
tree.summarize_range(3, 2)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tree.update_range(4, 2, 10)
|
||||
|
||||
|
||||
def test_out_of_bounds():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Test negative indices
|
||||
with pytest.raises(ValueError):
|
||||
tree.summarize_range(-1, 3)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tree.summarize_range(0, -1)
|
||||
|
||||
# Test indices >= n
|
||||
with pytest.raises(ValueError):
|
||||
tree.summarize_range(0, len(values))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tree.summarize_range(len(values), len(values) + 1)
|
||||
|
||||
# Test update with out of bounds indices
|
||||
with pytest.raises(ValueError):
|
||||
tree.update_range(-1, 3, 10)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tree.update_range(0, len(values), 10)
|
||||
|
||||
|
||||
def test_overlapping_updates():
|
||||
values = [1, 3, 5, 7, 9]
|
||||
naive_values = values.copy()
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Apply overlapping updates
|
||||
tree.update_range(0, 2, 5) # Update [0, 1, 2]
|
||||
naive_range_update(naive_values, 0, 2, 5)
|
||||
|
||||
tree.update_range(1, 3, 3) # Update [1, 2, 3]
|
||||
naive_range_update(naive_values, 1, 3, 3)
|
||||
|
||||
# Verify all possible ranges
|
||||
for i in range(len(values)):
|
||||
for j in range(i, len(values)):
|
||||
expected = naive_range_max(naive_values, i, j)
|
||||
actual = tree.summarize_range(i, j)
|
||||
assert actual == expected, (
|
||||
f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_updates_and_queries():
|
||||
values = [2, 4, 6, 8, 10, 12, 14]
|
||||
naive_values = values.copy()
|
||||
tree = SegmentedTree(values, add_op, max_op, 0)
|
||||
|
||||
# Sequence of operations
|
||||
operations = [
|
||||
("update", 1, 3, 5), # Update range [1, 2, 3] with +5
|
||||
("query", 0, 4), # Query range [0, 1, 2, 3, 4]
|
||||
("update", 2, 5, 3), # Update range [2, 3, 4, 5] with +3
|
||||
("query", 1, 3), # Query range [1, 2, 3]
|
||||
("update", 0, 6, 2), # Update entire array with +2
|
||||
("query", 0, 6), # Query entire array
|
||||
("query", 3, 5), # Query range [3, 4, 5]
|
||||
]
|
||||
|
||||
for op in operations:
|
||||
if op[0] == "update":
|
||||
_, start, end, value = op
|
||||
tree.update_range(start, end, value)
|
||||
naive_range_update(naive_values, start, end, value)
|
||||
|
||||
# Verify tree state after update
|
||||
for i in range(len(values)):
|
||||
for j in range(i, len(values)):
|
||||
expected = naive_range_max(naive_values, i, j)
|
||||
actual = tree.summarize_range(i, j)
|
||||
assert actual == expected, (
|
||||
f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
|
||||
)
|
||||
else: # query
|
||||
_, start, end = op
|
||||
expected = naive_range_max(naive_values, start, end)
|
||||
assert tree.summarize_range(start, end) == expected, (
|
||||
f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -13754,45 +13754,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
|
||||
self.assertEqual(has_lowered, can_lower)
|
||||
|
||||
@staticmethod
|
||||
def _is_triggering_buffer_reuse(fn, *inputs):
|
||||
with config.patch(allow_buffer_reuse=True):
|
||||
_, (code_allowed,) = run_and_get_code(fn, *inputs)
|
||||
with config.patch(allow_buffer_reuse=False):
|
||||
_, (code_disallowed,) = run_and_get_code(fn, *inputs)
|
||||
code_allowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_allowed)
|
||||
code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed)
|
||||
return code_allowed != code_disallowed
|
||||
|
||||
def test_allow_reuse_disable_if_exceed_peak(self):
|
||||
@torch.compile
|
||||
def fn(inp): # 1*N^2
|
||||
a = inp.mean(-1) # 1*N^2 + N
|
||||
b = (inp - a) ** 2 # 2*N^2 + N
|
||||
c = b @ b # 3*N^2 (!!) since this is the peak, can not reuse across
|
||||
d = c.mean(-1) # 2*N^2 + N
|
||||
return d # 1*N^2 + N
|
||||
|
||||
inp = torch.randn(100, 100, device=self.device)
|
||||
self.assertFalse(CommonTemplate._is_triggering_buffer_reuse(fn, inp))
|
||||
|
||||
def test_allow_reuse_active_if_under_peak(self):
|
||||
def g(inp):
|
||||
return (inp - torch.logsumexp(inp, -1)) ** 2
|
||||
|
||||
@torch.compile
|
||||
def fn(m, inp):
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
inp = m @ g(inp)
|
||||
return inp
|
||||
|
||||
m = torch.randn(100, 100, device=self.device)
|
||||
inp = torch.randn(100, 100, device=self.device)
|
||||
self.assertTrue(CommonTemplate._is_triggering_buffer_reuse(fn, m, inp))
|
||||
|
||||
# end of class CommonTemplate - add new tests here
|
||||
|
||||
|
||||
|
@ -1,241 +0,0 @@
|
||||
from typing import Callable, Generic, Optional, TypeVar
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _value_or(opt: Optional[T], default: T) -> T:
|
||||
return opt if opt is not None else default
|
||||
|
||||
|
||||
class SegmentedTree(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
values: list[T],
|
||||
update_op: Callable[[T, T], T],
|
||||
summary_op: Callable[[T, T], T],
|
||||
identity_element: T,
|
||||
):
|
||||
"""
|
||||
Initialize a segment tree with the given values and operations.
|
||||
|
||||
Args:
|
||||
values: list of initial values
|
||||
update_op: Function to apply when updating a value (e.g., addition)
|
||||
summary_op: Function to summarize two values (e.g., min, max, sum)
|
||||
identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input values list is empty
|
||||
"""
|
||||
if not values:
|
||||
raise ValueError("Cannot create a segment tree with empty values list")
|
||||
|
||||
self.n = len(values)
|
||||
self.update_op = update_op
|
||||
self.summary_op = summary_op
|
||||
self.identity = identity_element
|
||||
|
||||
# Size of segment tree array (next power of 2 * 2)
|
||||
# The tree follows a standard heap layout where
|
||||
# node `n`'s children are at `2*n` and `2*n+1`.
|
||||
# Index 0 is unused.
|
||||
self.size = 1
|
||||
while self.size < self.n:
|
||||
self.size *= 2
|
||||
self.size *= 2
|
||||
|
||||
# Initialize tree and lazy arrays
|
||||
self.tree = [identity_element] * self.size
|
||||
# The lazy array contains updates to the given node
|
||||
# Upon update, we only push updates to the top-most
|
||||
# nodes that fully receive the update. We then
|
||||
# propagate the update down as required (i.e., when
|
||||
# we receive an interval query that neither fully
|
||||
# contains the node nor fully doesn't contain the
|
||||
# node
|
||||
self.lazy: list[Optional[T]] = [None] * self.size
|
||||
|
||||
# Build the tree
|
||||
self._build(values, 1, 0, self.n - 1)
|
||||
|
||||
def _build(self, values: list[T], node: int, start: int, end: int) -> None:
|
||||
"""
|
||||
Build the segment tree recursively.
|
||||
|
||||
Args:
|
||||
values: Original array of values
|
||||
node: Current node index in the segment tree
|
||||
start: Start index of the segment
|
||||
end: End index of the segment
|
||||
"""
|
||||
if start == end:
|
||||
# Leaf node
|
||||
if start < len(values):
|
||||
self.tree[node] = values[start]
|
||||
return
|
||||
|
||||
mid = (start + end) // 2
|
||||
left_child = 2 * node
|
||||
right_child = 2 * node + 1
|
||||
|
||||
# Recursively build left and right subtrees
|
||||
self._build(values, left_child, start, mid)
|
||||
self._build(values, right_child, mid + 1, end)
|
||||
|
||||
# Update current node with summary of children
|
||||
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
||||
|
||||
def _children(self, node: int) -> list[int]:
|
||||
return [2 * node, 2 * node + 1]
|
||||
|
||||
def _push_lazy(self, node: int, start: int, end: int) -> None:
|
||||
"""
|
||||
Push lazy updates down to children.
|
||||
|
||||
Args:
|
||||
node: Current node index
|
||||
start: Start index of the segment
|
||||
end: End index of the segment
|
||||
"""
|
||||
lazy_node = self.lazy[node]
|
||||
if lazy_node is None:
|
||||
return
|
||||
|
||||
# Apply lazy update to current node
|
||||
self.tree[node] = self.update_op(self.tree[node], lazy_node)
|
||||
|
||||
if start != end: # Not a leaf node
|
||||
# Propagate to children
|
||||
for child in self._children(node):
|
||||
self.lazy[child] = self.update_op(
|
||||
_value_or(self.lazy[child], self.identity), lazy_node
|
||||
)
|
||||
|
||||
# Clear the lazy value
|
||||
self.lazy[node] = None
|
||||
|
||||
def _update_range_helper(
|
||||
self, node: int, start: int, end: int, left: int, right: int, value: T
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to update a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
node: Current node index
|
||||
start: Start index of the current segment
|
||||
end: End index of the current segment
|
||||
left: Start index of the range to update
|
||||
right: End index of the range to update
|
||||
value: Value to apply to the range
|
||||
"""
|
||||
# Push lazy updates before processing this node
|
||||
self._push_lazy(node, start, end)
|
||||
|
||||
# No overlap
|
||||
if start > right or end < left:
|
||||
return
|
||||
|
||||
# Complete overlap
|
||||
if start >= left and end <= right:
|
||||
# Apply update to current node
|
||||
self.lazy[node] = value
|
||||
self._push_lazy(node, start, end)
|
||||
return
|
||||
|
||||
# Partial overlap, recurse to children
|
||||
mid = (start + end) // 2
|
||||
left_child = 2 * node
|
||||
right_child = 2 * node + 1
|
||||
|
||||
self._update_range_helper(left_child, start, mid, left, right, value)
|
||||
self._update_range_helper(right_child, mid + 1, end, left, right, value)
|
||||
|
||||
# Update current node based on children
|
||||
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
||||
|
||||
def _query_range_helper(
|
||||
self, node: int, start: int, end: int, left: int, right: int
|
||||
) -> T:
|
||||
"""
|
||||
Helper method to query a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
node: Current node index
|
||||
start: Start index of the current segment
|
||||
end: End index of the current segment
|
||||
left: Start index of the range to query
|
||||
right: End index of the range to query
|
||||
|
||||
Returns:
|
||||
Summary value for the range
|
||||
"""
|
||||
# No overlap
|
||||
if start > right or end < left:
|
||||
return self.identity
|
||||
|
||||
# Push lazy updates before processing this node
|
||||
self._push_lazy(node, start, end)
|
||||
|
||||
# Complete overlap
|
||||
if start >= left and end <= right:
|
||||
return self.tree[node]
|
||||
|
||||
# Partial overlap, recurse to children
|
||||
mid = (start + end) // 2
|
||||
left_child = 2 * node
|
||||
right_child = 2 * node + 1
|
||||
|
||||
left_result = self._query_range_helper(left_child, start, mid, left, right)
|
||||
right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
|
||||
|
||||
# Combine results from children
|
||||
return self.summary_op(left_result, right_result)
|
||||
|
||||
def update_range(self, start: int, end: int, value: T) -> None:
|
||||
"""
|
||||
Update a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
start: Start index of the range to update (inclusive)
|
||||
end: End index of the range to update (inclusive)
|
||||
value: Value to apply to the range
|
||||
|
||||
Raises:
|
||||
ValueError: If start > end or indices are out of bounds
|
||||
"""
|
||||
if start > end:
|
||||
raise ValueError("Start index must be less than or equal to end index")
|
||||
|
||||
if start < 0 or start >= self.n:
|
||||
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
if end < 0 or end >= self.n:
|
||||
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
self._update_range_helper(1, 0, self.n - 1, start, end, value)
|
||||
|
||||
def summarize_range(self, start: int, end: int) -> T:
|
||||
"""
|
||||
Query a range of values in the segment tree.
|
||||
|
||||
Args:
|
||||
start: Start index of the range to query (inclusive)
|
||||
end: End index of the range to query (inclusive)
|
||||
|
||||
Returns:
|
||||
Summary value for the range according to the summary operation
|
||||
|
||||
Raises:
|
||||
ValueError: If start > end or indices are out of bounds
|
||||
"""
|
||||
if start > end:
|
||||
raise ValueError("Start index must be less than or equal to end index")
|
||||
|
||||
if start < 0 or start >= self.n:
|
||||
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
if end < 0 or end >= self.n:
|
||||
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
||||
|
||||
return self._query_range_helper(1, 0, self.n - 1, start, end)
|
@ -48,7 +48,6 @@ from ..utils import (
|
||||
cache_on_self,
|
||||
DelayReplaceLine,
|
||||
get_benchmark_name,
|
||||
get_dtype_size,
|
||||
IndentedBuffer,
|
||||
is_codegen_graph_partition_subgraph,
|
||||
is_using_cudagraph_partition,
|
||||
@ -588,64 +587,10 @@ class MemoryPlanningLine(WrapperLine):
|
||||
return f"{type(self).__name__}({', '.join(args)})"
|
||||
|
||||
|
||||
class EfficientPeakEstimate:
|
||||
def __init__(self):
|
||||
from ..memory import estimate_peak_memory, get_freeable_input_buf
|
||||
|
||||
scheduler_nodes = V.graph.scheduler.nodes
|
||||
graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
|
||||
graph_outputs = OrderedSet(V.graph.get_output_names())
|
||||
names_to_freeable_bufs = get_freeable_input_buf(scheduler_nodes, graph_inputs)
|
||||
self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory(
|
||||
scheduler_nodes,
|
||||
names_to_freeable_bufs,
|
||||
graph_outputs,
|
||||
)
|
||||
|
||||
from .segmented_tree import SegmentedTree
|
||||
|
||||
self.segmented_tree = SegmentedTree(
|
||||
peak_by_scheduler_node, operator.add, max, 0
|
||||
)
|
||||
|
||||
def _get_size(self, node: BufferLike) -> int:
|
||||
return V.graph.sizevars.size_hint(
|
||||
V.graph.get_allocation_storage_size(node), fallback=0
|
||||
) * get_dtype_size(node.get_dtype())
|
||||
|
||||
def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
|
||||
return self.segmented_tree.summarize_range(
|
||||
line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1
|
||||
)
|
||||
|
||||
def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
|
||||
if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index:
|
||||
return
|
||||
self.segmented_tree.update_range(
|
||||
line_a.scheduler_node_index + 1,
|
||||
line_b.scheduler_node_index - 1,
|
||||
self._get_size(line_b.node),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocateLine(MemoryPlanningLine):
|
||||
node: BufferLike
|
||||
|
||||
def __post_init__(self):
|
||||
assert V.graph.scheduler.current_node is not None
|
||||
self.scheduler_node_index = V.graph.scheduler.nodes.index(
|
||||
V.graph.scheduler.current_node
|
||||
)
|
||||
|
||||
def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool:
|
||||
if free_line.scheduler_node_index + 1 == self.scheduler_node_index:
|
||||
return True
|
||||
overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory
|
||||
peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self)
|
||||
new_peak_memory = size + peak_memory_in_range
|
||||
return new_peak_memory <= overall_peak_memory
|
||||
|
||||
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
||||
if self.node.get_name() in V.graph.removed_buffers:
|
||||
return NullLine(self.wrapper)
|
||||
@ -654,16 +599,8 @@ class AllocateLine(MemoryPlanningLine):
|
||||
key = buffer_reuse_key(self.node)
|
||||
if config.allow_buffer_reuse and key in state:
|
||||
free_line = state.pop(key)
|
||||
size = V.graph.sizevars.size_hint(
|
||||
V.graph.get_allocation_storage_size(self.node), fallback=0
|
||||
) * get_dtype_size(self.node.get_dtype())
|
||||
if self.should_reuse_buffer(free_line, size):
|
||||
free_line.is_reused = True
|
||||
self.wrapper.estimate_peak.update_peak_between(free_line, self)
|
||||
return ReuseLine(self.wrapper, free_line.node, self.node)
|
||||
else:
|
||||
state.push(key, free_line)
|
||||
return self
|
||||
free_line.is_reused = True
|
||||
return ReuseLine(self.wrapper, free_line.node, self.node)
|
||||
|
||||
if self.node.get_device_or_error().type == "cpu":
|
||||
static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
|
||||
@ -688,12 +625,6 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
|
||||
node: BufferLike
|
||||
is_reused: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
assert V.graph.scheduler.current_node is not None
|
||||
self.scheduler_node_index = V.graph.scheduler.nodes.index(
|
||||
V.graph.scheduler.current_node
|
||||
)
|
||||
|
||||
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
||||
if len(self.node.get_inputs_that_alias_output()) > 0:
|
||||
return self
|
||||
@ -1714,7 +1645,6 @@ class PythonWrapperCodegen(CodeGen):
|
||||
if is_inference and config.memory_planning:
|
||||
self.memory_plan()
|
||||
else:
|
||||
self.estimate_peak = EfficientPeakEstimate()
|
||||
self.memory_plan_reuse()
|
||||
|
||||
def codegen_input_symbol_assignment(
|
||||
|
@ -2073,7 +2073,6 @@ class Scheduler:
|
||||
)
|
||||
|
||||
self.nodes = [self.create_scheduler_node(n) for n in nodes]
|
||||
self.current_node: Optional[BaseSchedulerNode] = None
|
||||
self.update_zero_dim_cpu_tensor()
|
||||
# some new constants could have been created above
|
||||
self.available_buffer_names.update(V.graph.constants.keys())
|
||||
@ -4990,7 +4989,6 @@ class Scheduler:
|
||||
assert device.index is not None, "device should have an index"
|
||||
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
|
||||
|
||||
self.current_node = node
|
||||
self.buffer_names_to_free.update(node.last_usage)
|
||||
|
||||
if node.is_template():
|
||||
|
Reference in New Issue
Block a user