Files
pytorch/torch/_inductor/codegen/segmented_tree.py

242 lines
8.0 KiB
Python

from typing import Callable, Generic, Optional, TypeVar
T = TypeVar("T")
def _value_or(opt: Optional[T], default: T) -> T:
return opt if opt is not None else default
class SegmentedTree(Generic[T]):
def __init__(
self,
values: list[T],
update_op: Callable[[T, T], T],
summary_op: Callable[[T, T], T],
identity_element: T,
):
"""
Initialize a segment tree with the given values and operations.
Args:
values: list of initial values
update_op: Function to apply when updating a value (e.g., addition)
summary_op: Function to summarize two values (e.g., min, max, sum)
identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
Raises:
ValueError: If the input values list is empty
"""
if not values:
raise ValueError("Cannot create a segment tree with empty values list")
self.n = len(values)
self.update_op = update_op
self.summary_op = summary_op
self.identity = identity_element
# Size of segment tree array (next power of 2 * 2)
# The tree follows a standard heap layout where
# node `n`'s children are at `2*n` and `2*n+1`.
# Index 0 is unused.
self.size = 1
while self.size < self.n:
self.size *= 2
self.size *= 2
# Initialize tree and lazy arrays
self.tree = [identity_element] * self.size
# The lazy array contains updates to the given node
# Upon update, we only push updates to the top-most
# nodes that fully receive the update. We then
# propagate the update down as required (i.e., when
# we receive an interval query that neither fully
# contains the node nor fully doesn't contain the
# node
self.lazy: list[Optional[T]] = [None] * self.size
# Build the tree
self._build(values, 1, 0, self.n - 1)
def _build(self, values: list[T], node: int, start: int, end: int) -> None:
"""
Build the segment tree recursively.
Args:
values: Original array of values
node: Current node index in the segment tree
start: Start index of the segment
end: End index of the segment
"""
if start == end:
# Leaf node
if start < len(values):
self.tree[node] = values[start]
return
mid = (start + end) // 2
left_child = 2 * node
right_child = 2 * node + 1
# Recursively build left and right subtrees
self._build(values, left_child, start, mid)
self._build(values, right_child, mid + 1, end)
# Update current node with summary of children
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
def _children(self, node: int) -> list[int]:
return [2 * node, 2 * node + 1]
def _push_lazy(self, node: int, start: int, end: int) -> None:
"""
Push lazy updates down to children.
Args:
node: Current node index
start: Start index of the segment
end: End index of the segment
"""
lazy_node = self.lazy[node]
if lazy_node is None:
return
# Apply lazy update to current node
self.tree[node] = self.update_op(self.tree[node], lazy_node)
if start != end: # Not a leaf node
# Propagate to children
for child in self._children(node):
self.lazy[child] = self.update_op(
_value_or(self.lazy[child], self.identity), lazy_node
)
# Clear the lazy value
self.lazy[node] = None
def _update_range_helper(
self, node: int, start: int, end: int, left: int, right: int, value: T
) -> None:
"""
Helper method to update a range of values in the segment tree.
Args:
node: Current node index
start: Start index of the current segment
end: End index of the current segment
left: Start index of the range to update
right: End index of the range to update
value: Value to apply to the range
"""
# Push lazy updates before processing this node
self._push_lazy(node, start, end)
# No overlap
if start > right or end < left:
return
# Complete overlap
if start >= left and end <= right:
# Apply update to current node
self.lazy[node] = value
self._push_lazy(node, start, end)
return
# Partial overlap, recurse to children
mid = (start + end) // 2
left_child = 2 * node
right_child = 2 * node + 1
self._update_range_helper(left_child, start, mid, left, right, value)
self._update_range_helper(right_child, mid + 1, end, left, right, value)
# Update current node based on children
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
def _query_range_helper(
self, node: int, start: int, end: int, left: int, right: int
) -> T:
"""
Helper method to query a range of values in the segment tree.
Args:
node: Current node index
start: Start index of the current segment
end: End index of the current segment
left: Start index of the range to query
right: End index of the range to query
Returns:
Summary value for the range
"""
# No overlap
if start > right or end < left:
return self.identity
# Push lazy updates before processing this node
self._push_lazy(node, start, end)
# Complete overlap
if start >= left and end <= right:
return self.tree[node]
# Partial overlap, recurse to children
mid = (start + end) // 2
left_child = 2 * node
right_child = 2 * node + 1
left_result = self._query_range_helper(left_child, start, mid, left, right)
right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
# Combine results from children
return self.summary_op(left_result, right_result)
def update_range(self, start: int, end: int, value: T) -> None:
"""
Update a range of values in the segment tree.
Args:
start: Start index of the range to update (inclusive)
end: End index of the range to update (inclusive)
value: Value to apply to the range
Raises:
ValueError: If start > end or indices are out of bounds
"""
if start > end:
raise ValueError("Start index must be less than or equal to end index")
if start < 0 or start >= self.n:
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
if end < 0 or end >= self.n:
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
self._update_range_helper(1, 0, self.n - 1, start, end, value)
def summarize_range(self, start: int, end: int) -> T:
"""
Query a range of values in the segment tree.
Args:
start: Start index of the range to query (inclusive)
end: End index of the range to query (inclusive)
Returns:
Summary value for the range according to the summary operation
Raises:
ValueError: If start > end or indices are out of bounds
"""
if start > end:
raise ValueError("Start index must be less than or equal to end index")
if start < 0 or start >= self.n:
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
if end < 0 or end >= self.n:
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
return self._query_range_helper(1, 0, self.n - 1, start, end)