mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530 Approved by: https://github.com/eellison
242 lines
8.0 KiB
Python
242 lines
8.0 KiB
Python
from typing import Callable, Generic, Optional, TypeVar
|
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
def _value_or(opt: Optional[T], default: T) -> T:
|
|
return opt if opt is not None else default
|
|
|
|
|
|
class SegmentedTree(Generic[T]):
|
|
def __init__(
|
|
self,
|
|
values: list[T],
|
|
update_op: Callable[[T, T], T],
|
|
summary_op: Callable[[T, T], T],
|
|
identity_element: T,
|
|
):
|
|
"""
|
|
Initialize a segment tree with the given values and operations.
|
|
|
|
Args:
|
|
values: list of initial values
|
|
update_op: Function to apply when updating a value (e.g., addition)
|
|
summary_op: Function to summarize two values (e.g., min, max, sum)
|
|
identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
|
|
|
|
Raises:
|
|
ValueError: If the input values list is empty
|
|
"""
|
|
if not values:
|
|
raise ValueError("Cannot create a segment tree with empty values list")
|
|
|
|
self.n = len(values)
|
|
self.update_op = update_op
|
|
self.summary_op = summary_op
|
|
self.identity = identity_element
|
|
|
|
# Size of segment tree array (next power of 2 * 2)
|
|
# The tree follows a standard heap layout where
|
|
# node `n`'s children are at `2*n` and `2*n+1`.
|
|
# Index 0 is unused.
|
|
self.size = 1
|
|
while self.size < self.n:
|
|
self.size *= 2
|
|
self.size *= 2
|
|
|
|
# Initialize tree and lazy arrays
|
|
self.tree = [identity_element] * self.size
|
|
# The lazy array contains updates to the given node
|
|
# Upon update, we only push updates to the top-most
|
|
# nodes that fully receive the update. We then
|
|
# propagate the update down as required (i.e., when
|
|
# we receive an interval query that neither fully
|
|
# contains the node nor fully doesn't contain the
|
|
# node
|
|
self.lazy: list[Optional[T]] = [None] * self.size
|
|
|
|
# Build the tree
|
|
self._build(values, 1, 0, self.n - 1)
|
|
|
|
def _build(self, values: list[T], node: int, start: int, end: int) -> None:
|
|
"""
|
|
Build the segment tree recursively.
|
|
|
|
Args:
|
|
values: Original array of values
|
|
node: Current node index in the segment tree
|
|
start: Start index of the segment
|
|
end: End index of the segment
|
|
"""
|
|
if start == end:
|
|
# Leaf node
|
|
if start < len(values):
|
|
self.tree[node] = values[start]
|
|
return
|
|
|
|
mid = (start + end) // 2
|
|
left_child = 2 * node
|
|
right_child = 2 * node + 1
|
|
|
|
# Recursively build left and right subtrees
|
|
self._build(values, left_child, start, mid)
|
|
self._build(values, right_child, mid + 1, end)
|
|
|
|
# Update current node with summary of children
|
|
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
|
|
|
def _children(self, node: int) -> list[int]:
|
|
return [2 * node, 2 * node + 1]
|
|
|
|
def _push_lazy(self, node: int, start: int, end: int) -> None:
|
|
"""
|
|
Push lazy updates down to children.
|
|
|
|
Args:
|
|
node: Current node index
|
|
start: Start index of the segment
|
|
end: End index of the segment
|
|
"""
|
|
lazy_node = self.lazy[node]
|
|
if lazy_node is None:
|
|
return
|
|
|
|
# Apply lazy update to current node
|
|
self.tree[node] = self.update_op(self.tree[node], lazy_node)
|
|
|
|
if start != end: # Not a leaf node
|
|
# Propagate to children
|
|
for child in self._children(node):
|
|
self.lazy[child] = self.update_op(
|
|
_value_or(self.lazy[child], self.identity), lazy_node
|
|
)
|
|
|
|
# Clear the lazy value
|
|
self.lazy[node] = None
|
|
|
|
def _update_range_helper(
|
|
self, node: int, start: int, end: int, left: int, right: int, value: T
|
|
) -> None:
|
|
"""
|
|
Helper method to update a range of values in the segment tree.
|
|
|
|
Args:
|
|
node: Current node index
|
|
start: Start index of the current segment
|
|
end: End index of the current segment
|
|
left: Start index of the range to update
|
|
right: End index of the range to update
|
|
value: Value to apply to the range
|
|
"""
|
|
# Push lazy updates before processing this node
|
|
self._push_lazy(node, start, end)
|
|
|
|
# No overlap
|
|
if start > right or end < left:
|
|
return
|
|
|
|
# Complete overlap
|
|
if start >= left and end <= right:
|
|
# Apply update to current node
|
|
self.lazy[node] = value
|
|
self._push_lazy(node, start, end)
|
|
return
|
|
|
|
# Partial overlap, recurse to children
|
|
mid = (start + end) // 2
|
|
left_child = 2 * node
|
|
right_child = 2 * node + 1
|
|
|
|
self._update_range_helper(left_child, start, mid, left, right, value)
|
|
self._update_range_helper(right_child, mid + 1, end, left, right, value)
|
|
|
|
# Update current node based on children
|
|
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
|
|
|
def _query_range_helper(
|
|
self, node: int, start: int, end: int, left: int, right: int
|
|
) -> T:
|
|
"""
|
|
Helper method to query a range of values in the segment tree.
|
|
|
|
Args:
|
|
node: Current node index
|
|
start: Start index of the current segment
|
|
end: End index of the current segment
|
|
left: Start index of the range to query
|
|
right: End index of the range to query
|
|
|
|
Returns:
|
|
Summary value for the range
|
|
"""
|
|
# No overlap
|
|
if start > right or end < left:
|
|
return self.identity
|
|
|
|
# Push lazy updates before processing this node
|
|
self._push_lazy(node, start, end)
|
|
|
|
# Complete overlap
|
|
if start >= left and end <= right:
|
|
return self.tree[node]
|
|
|
|
# Partial overlap, recurse to children
|
|
mid = (start + end) // 2
|
|
left_child = 2 * node
|
|
right_child = 2 * node + 1
|
|
|
|
left_result = self._query_range_helper(left_child, start, mid, left, right)
|
|
right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
|
|
|
|
# Combine results from children
|
|
return self.summary_op(left_result, right_result)
|
|
|
|
def update_range(self, start: int, end: int, value: T) -> None:
|
|
"""
|
|
Update a range of values in the segment tree.
|
|
|
|
Args:
|
|
start: Start index of the range to update (inclusive)
|
|
end: End index of the range to update (inclusive)
|
|
value: Value to apply to the range
|
|
|
|
Raises:
|
|
ValueError: If start > end or indices are out of bounds
|
|
"""
|
|
if start > end:
|
|
raise ValueError("Start index must be less than or equal to end index")
|
|
|
|
if start < 0 or start >= self.n:
|
|
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
|
|
|
if end < 0 or end >= self.n:
|
|
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
|
|
|
self._update_range_helper(1, 0, self.n - 1, start, end, value)
|
|
|
|
def summarize_range(self, start: int, end: int) -> T:
|
|
"""
|
|
Query a range of values in the segment tree.
|
|
|
|
Args:
|
|
start: Start index of the range to query (inclusive)
|
|
end: End index of the range to query (inclusive)
|
|
|
|
Returns:
|
|
Summary value for the range according to the summary operation
|
|
|
|
Raises:
|
|
ValueError: If start > end or indices are out of bounds
|
|
"""
|
|
if start > end:
|
|
raise ValueError("Start index must be less than or equal to end index")
|
|
|
|
if start < 0 or start >= self.n:
|
|
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
|
|
|
if end < 0 or end >= self.n:
|
|
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
|
|
|
return self._query_range_helper(1, 0, self.n - 1, start, end)
|