Skip to content

Commit 9257674

Browse files
committed
Add Segment Tree implementation for range sum queries with type hints and doctests
1 parent a187c99 commit 9257674

File tree

1 file changed

+82
-75
lines changed

1 file changed

+82
-75
lines changed
Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,107 @@
11
"""
2-
A Segment Tree is a binary tree data structure used for efficiently answering
3-
range queries and updates on an array, such as sum, minimum, or maximum over
4-
a subrange. It offers O(log n) time complexity for both queries and updates,
5-
making it very efficient compared to a naive O(n) approach.
6-
7-
While building the tree takes O(n) time and the tree requires O(n) space,
8-
this preprocessing enables fast range queries that would otherwise be slow.
9-
Segment Trees are especially useful when the array is mutable and queries
10-
and updates are intermixed.
2+
A Segment Tree is a binary tree used for storing intervals or segments.
3+
It allows querying the sum, minimum, or maximum of elements in a range efficiently.
114
125
Time Complexity:
136
- Build: O(n)
147
- Query: O(log n)
158
- Update: O(log n)
16-
17-
Example usage and doctests:
18-
19-
>>> data = [1, 2, 3, 4, 5]
20-
>>> st = SegmentTree(data)
21-
>>> st.query(1, 4)
22-
9
23-
>>> st.update(2, 10)
24-
>>> st.query(1, 4)
25-
16
269
"""
2710

2811

2912
class SegmentTree:
30-
"""Segment Tree for efficient range sum queries."""
31-
32-
def __init__(self, data: list[int]):
33-
"""Initialize the segment tree with the input data.
34-
35-
Args:
36-
data (list[int]): List of integers to build the segment tree.
13+
def __init__(self, data: list[int]) -> None:
14+
"""
15+
Initializes the Segment Tree from a list of integers.
16+
:param data: List of integer elements.
17+
:return: None
3718
"""
3819
self.n = len(data)
39-
self.tree = [0] * (2 * self.n)
40-
# Build the tree
41-
for i in range(self.n):
42-
self.tree[self.n + i] = data[i]
43-
for i in range(self.n - 1, 0, -1):
44-
self.tree[i] = self.tree[i << 1] + self.tree[i << 1 | 1]
20+
self.tree = [0] * (4 * self.n)
21+
self._build(data, 0, 0, self.n - 1)
22+
23+
def _build(self, data: list[int], node: int, start: int, end: int) -> None:
24+
if start == end:
25+
self.tree[node] = data[start]
26+
else:
27+
mid = (start + end) // 2
28+
self._build(data, 2 * node + 1, start, mid)
29+
self._build(data, 2 * node + 2, mid + 1, end)
30+
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
4531

4632
def update(self, index: int, value: int) -> None:
47-
"""Update element at index with a new value.
48-
49-
Args:
50-
index (int): Index of the element to update.
51-
value (int): New value to set at the given index.
5233
"""
53-
if index < 0 or index >= self.n:
54-
raise ValueError("Index out of bounds")
55-
index += self.n
56-
self.tree[index] = value
57-
while index > 1:
58-
index >>= 1
59-
self.tree[index] = self.tree[index << 1] + self.tree[index << 1 | 1]
34+
Updates the value at the given index and updates the tree.
35+
:param index: Index to update.
36+
:param value: New value.
37+
:return: None
38+
"""
39+
self._update(0, 0, self.n - 1, index, value)
40+
41+
def _update(self, node: int, start: int, end: int, index: int, value: int) -> None:
42+
if start == end:
43+
self.tree[node] = value
44+
else:
45+
mid = (start + end) // 2
46+
if index <= mid:
47+
self._update(2 * node + 1, start, mid, index, value)
48+
else:
49+
self._update(2 * node + 2, mid + 1, end, index, value)
50+
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
6051

6152
def query(self, left: int, right: int) -> int:
62-
"""Compute the sum of elements in the interval [left, right).
63-
64-
Args:
65-
left (int): Left index (inclusive).
66-
right (int): Right index (exclusive).
67-
68-
Returns:
69-
int: Sum of elements from left to right-1.
70-
71-
Raises:
72-
ValueError: If indices are out of bounds or left >= right.
7353
"""
74-
if left < 0 or right > self.n or left >= right:
75-
raise ValueError("Invalid query range")
76-
res = 0
77-
left += self.n
78-
right += self.n
79-
while left < right:
80-
if left & 1:
81-
res += self.tree[left]
82-
left += 1
83-
if right & 1:
84-
right -= 1
85-
res += self.tree[right]
86-
left >>= 1
87-
right >>= 1
88-
return res
54+
Returns the sum of elements in the range [left, right].
55+
:param left: Left index (inclusive).
56+
:param right: Right index (inclusive).
57+
:return: Sum of the range.
58+
59+
>>> data = [1, 2, 3, 4, 5]
60+
>>> st = SegmentTree(data)
61+
>>> st.query(1, 3)
62+
9
63+
>>> st.update(2, 10)
64+
>>> st.query(1, 3)
65+
16
66+
"""
67+
return self._query(0, 0, self.n - 1, left, right)
8968

69+
def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
70+
if right < start or left > end:
71+
return 0
72+
if left <= start and end <= right:
73+
return self.tree[node]
74+
mid = (start + end) // 2
75+
return self._query(2 * node + 1, start, mid, left, right) + self._query(
76+
2 * node + 2, mid + 1, end, left, right
77+
)
9078

91-
if __name__ == "__main__":
92-
import doctest
9379

80+
def test_segment_tree() -> bool:
9481
data = [1, 2, 3, 4, 5]
9582
st = SegmentTree(data)
96-
print("Initial sum 1-4:", st.query(1, 4))
83+
assert st.query(0, 2) == 6
84+
assert st.query(1, 4) == 14
9785
st.update(2, 10)
98-
print("Updated sum 1-4:", st.query(1, 4))
86+
assert st.query(1, 3) == 16
87+
assert st.query(0, 4) == 22
88+
return True
89+
90+
91+
def print_results(msg: str, passes: bool) -> None:
92+
print(str(msg), "works!" if passes else "doesn't work :(")
93+
9994

100-
doctest.testmod()
95+
def pytests() -> None:
96+
assert test_segment_tree()
97+
98+
99+
def main() -> None:
100+
"""
101+
>>> pytests()
102+
"""
103+
print_results("Testing Segment Tree functionality", test_segment_tree())
104+
105+
106+
if __name__ == "__main__":
107+
main()

0 commit comments

Comments
 (0)