mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530 Approved by: https://github.com/eellison
255 lines
8.8 KiB
Python
255 lines
8.8 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
from hypothesis import given, strategies as st
|
|
|
|
from torch._inductor.codegen.segmented_tree import SegmentedTree
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
|
|
|
|
# Helper functions for operations
|
|
def max_op(a, b):
|
|
return max(a, b)
|
|
|
|
|
|
def add_op(a, b):
|
|
return a + b
|
|
|
|
|
|
# Naive implementations for reference
|
|
def naive_range_max(arr, start, end):
|
|
return max(arr[start : end + 1])
|
|
|
|
|
|
def naive_range_update(arr, start, end, value):
|
|
for i in range(start, end + 1):
|
|
arr[i] += value
|
|
|
|
|
|
# Strategies for hypothesis testing
|
|
positive_integers = st.lists(
|
|
st.integers(min_value=1, max_value=100), min_size=1, max_size=50
|
|
)
|
|
|
|
|
|
def valid_range_indices(array_length):
|
|
return st.tuples(
|
|
st.integers(min_value=0, max_value=array_length - 1),
|
|
st.integers(min_value=0, max_value=array_length - 1),
|
|
).map(lambda x: (min(x), max(x)))
|
|
|
|
|
|
update_values = st.integers(min_value=1, max_value=50)
|
|
|
|
|
|
class TestSegmentedTree(TestCase):
|
|
# Basic construction and initialization tests
|
|
def test_basic_construction(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
assert tree.summarize_range(0, 4) == 9
|
|
|
|
def test_empty_array(self):
|
|
with self.assertRaises(ValueError):
|
|
SegmentedTree([], add_op, max_op, 0)
|
|
|
|
# Property-based tests
|
|
@given(values=positive_integers)
|
|
def test_max_query_matches_naive(self, values):
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
for start in range(len(values)):
|
|
for end in range(start, len(values)):
|
|
expected = naive_range_max(values, start, end)
|
|
actual = tree.summarize_range(start, end)
|
|
assert actual == expected, (
|
|
f"Range [{start}:{end}] expected {expected}, got {actual}"
|
|
)
|
|
|
|
@given(
|
|
values=positive_integers, range_indices=st.data(), update_value=update_values
|
|
)
|
|
def test_range_update(self, values, range_indices, update_value):
|
|
# Create a copy for naive implementation
|
|
naive_values = values.copy()
|
|
|
|
# Create segment tree
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Get valid range indices
|
|
start, end = range_indices.draw(valid_range_indices(len(values)))
|
|
|
|
# Apply updates
|
|
tree.update_range(start, end, update_value)
|
|
naive_range_update(naive_values, start, end, update_value)
|
|
|
|
# Verify all possible ranges
|
|
for i in range(len(values)):
|
|
for j in range(i, len(values)):
|
|
expected = naive_range_max(naive_values, i, j)
|
|
actual = tree.summarize_range(i, j)
|
|
assert actual == expected, (
|
|
f"After update, range [{i}:{j}] expected {expected}, got {actual}"
|
|
)
|
|
|
|
@given(values=positive_integers, range_data=st.data())
|
|
def test_multiple_operations(self, values, range_data):
|
|
# Create a copy for naive implementation
|
|
naive_values = values.copy()
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Perform multiple operations
|
|
num_operations = 5
|
|
for _ in range(num_operations):
|
|
# Randomly choose between query and update
|
|
operation_type = range_data.draw(st.sampled_from(["query", "update"]))
|
|
start, end = range_data.draw(valid_range_indices(len(values)))
|
|
|
|
if operation_type == "query":
|
|
expected = naive_range_max(naive_values, start, end)
|
|
actual = tree.summarize_range(start, end)
|
|
assert actual == expected, (
|
|
f"Range query [{start}:{end}] expected {expected}, got {actual}"
|
|
)
|
|
else: # update
|
|
update_value = range_data.draw(update_values)
|
|
tree.update_range(start, end, update_value)
|
|
naive_range_update(naive_values, start, end, update_value)
|
|
|
|
def test_single_element_ranges(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
for i in range(len(values)):
|
|
assert tree.summarize_range(i, i) == values[i], (
|
|
f"Single element range at index {i} failed"
|
|
)
|
|
|
|
def test_full_array_range(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Test querying the entire array
|
|
assert tree.summarize_range(0, len(values) - 1) == max(values)
|
|
|
|
# Update the entire array and test again
|
|
update_value = 10
|
|
tree.update_range(0, len(values) - 1, update_value)
|
|
expected = max([v + update_value for v in values])
|
|
assert tree.summarize_range(0, len(values) - 1) == expected
|
|
|
|
def test_boundary_conditions(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Test first element
|
|
assert tree.summarize_range(0, 0) == values[0]
|
|
|
|
# Test last element
|
|
assert tree.summarize_range(len(values) - 1, len(values) - 1) == values[-1]
|
|
|
|
# Test first two elements
|
|
assert tree.summarize_range(0, 1) == max(values[0:2])
|
|
|
|
# Test last two elements
|
|
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(
|
|
values[-2:]
|
|
)
|
|
|
|
def test_invalid_ranges(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Test start > end
|
|
with self.assertRaises(ValueError):
|
|
tree.summarize_range(3, 2)
|
|
|
|
with self.assertRaises(ValueError):
|
|
tree.update_range(4, 2, 10)
|
|
|
|
def test_out_of_bounds(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Test negative indices
|
|
with self.assertRaises(ValueError):
|
|
tree.summarize_range(-1, 3)
|
|
|
|
with self.assertRaises(ValueError):
|
|
tree.summarize_range(0, -1)
|
|
|
|
# Test indices >= n
|
|
with self.assertRaises(ValueError):
|
|
tree.summarize_range(0, len(values))
|
|
|
|
with self.assertRaises(ValueError):
|
|
tree.summarize_range(len(values), len(values) + 1)
|
|
|
|
# Test update with out of bounds indices
|
|
with self.assertRaises(ValueError):
|
|
tree.update_range(-1, 3, 10)
|
|
|
|
with self.assertRaises(ValueError):
|
|
tree.update_range(0, len(values), 10)
|
|
|
|
def test_overlapping_updates(self):
|
|
values = [1, 3, 5, 7, 9]
|
|
naive_values = values.copy()
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Apply overlapping updates
|
|
tree.update_range(0, 2, 5) # Update [0, 1, 2]
|
|
naive_range_update(naive_values, 0, 2, 5)
|
|
|
|
tree.update_range(1, 3, 3) # Update [1, 2, 3]
|
|
naive_range_update(naive_values, 1, 3, 3)
|
|
|
|
# Verify all possible ranges
|
|
for i in range(len(values)):
|
|
for j in range(i, len(values)):
|
|
expected = naive_range_max(naive_values, i, j)
|
|
actual = tree.summarize_range(i, j)
|
|
assert actual == expected, (
|
|
f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
|
|
)
|
|
|
|
def test_sequential_updates_and_queries(self):
|
|
values = [2, 4, 6, 8, 10, 12, 14]
|
|
naive_values = values.copy()
|
|
tree = SegmentedTree(values, add_op, max_op, 0)
|
|
|
|
# Sequence of operations
|
|
operations = [
|
|
("update", 1, 3, 5), # Update range [1, 2, 3] with +5
|
|
("query", 0, 4), # Query range [0, 1, 2, 3, 4]
|
|
("update", 2, 5, 3), # Update range [2, 3, 4, 5] with +3
|
|
("query", 1, 3), # Query range [1, 2, 3]
|
|
("update", 0, 6, 2), # Update entire array with +2
|
|
("query", 0, 6), # Query entire array
|
|
("query", 3, 5), # Query range [3, 4, 5]
|
|
]
|
|
|
|
for op in operations:
|
|
if op[0] == "update":
|
|
_, start, end, value = op
|
|
tree.update_range(start, end, value)
|
|
naive_range_update(naive_values, start, end, value)
|
|
|
|
# Verify tree state after update
|
|
for i in range(len(values)):
|
|
for j in range(i, len(values)):
|
|
expected = naive_range_max(naive_values, i, j)
|
|
actual = tree.summarize_range(i, j)
|
|
assert actual == expected, (
|
|
f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
|
|
)
|
|
else: # query
|
|
_, start, end = op
|
|
expected = naive_range_max(naive_values, start, end)
|
|
assert tree.summarize_range(start, end) == expected, (
|
|
f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|