|
1 | 1 | """ |
2 | | -A Segment Tree is a binary tree data structure used for efficiently answering |
3 | | -range queries and updates on an array, such as sum, minimum, or maximum over |
4 | | -a subrange. It offers O(log n) time complexity for both queries and updates, |
5 | | -making it very efficient compared to a naive O(n) approach. |
6 | | -
|
7 | | -While building the tree takes O(n) time and the tree requires O(n) space, |
8 | | -this preprocessing enables fast range queries that would otherwise be slow. |
9 | | -Segment Trees are especially useful when the array is mutable and queries |
10 | | -and updates are intermixed. |
| 2 | +A Segment Tree is a binary tree used for storing intervals or segments. |
| 3 | +It allows querying the sum, minimum, or maximum of elements in a range efficiently. |
11 | 4 |
|
12 | 5 | Time Complexity: |
13 | 6 | - Build: O(n) |
14 | 7 | - Query: O(log n) |
15 | 8 | - Update: O(log n) |
16 | | -
|
17 | | -Example usage and doctests: |
18 | | -
|
19 | | ->>> data = [1, 2, 3, 4, 5] |
20 | | ->>> st = SegmentTree(data) |
21 | | ->>> st.query(1, 4) |
22 | | -9 |
23 | | ->>> st.update(2, 10) |
24 | | ->>> st.query(1, 4) |
25 | | -16 |
26 | 9 | """ |
27 | 10 |
|
28 | 11 |
|
29 | 12 | class SegmentTree: |
30 | | - """Segment Tree for efficient range sum queries.""" |
31 | | - |
32 | | - def __init__(self, data: list[int]): |
33 | | - """Initialize the segment tree with the input data. |
34 | | -
|
35 | | - Args: |
36 | | - data (list[int]): List of integers to build the segment tree. |
| 13 | + def __init__(self, data: list[int]) -> None: |
| 14 | + """ |
| 15 | + Initializes the Segment Tree from a list of integers. |
| 16 | + :param data: List of integer elements. |
| 17 | + :return: None |
37 | 18 | """ |
38 | 19 | self.n = len(data) |
39 | | - self.tree = [0] * (2 * self.n) |
40 | | - # Build the tree |
41 | | - for i in range(self.n): |
42 | | - self.tree[self.n + i] = data[i] |
43 | | - for i in range(self.n - 1, 0, -1): |
44 | | - self.tree[i] = self.tree[i << 1] + self.tree[i << 1 | 1] |
| 20 | + self.tree = [0] * (4 * self.n) |
| 21 | + self._build(data, 0, 0, self.n - 1) |
| 22 | + |
| 23 | + def _build(self, data: list[int], node: int, start: int, end: int) -> None: |
| 24 | + if start == end: |
| 25 | + self.tree[node] = data[start] |
| 26 | + else: |
| 27 | + mid = (start + end) // 2 |
| 28 | + self._build(data, 2 * node + 1, start, mid) |
| 29 | + self._build(data, 2 * node + 2, mid + 1, end) |
| 30 | + self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2] |
45 | 31 |
|
46 | 32 | def update(self, index: int, value: int) -> None: |
47 | | - """Update element at index with a new value. |
48 | | -
|
49 | | - Args: |
50 | | - index (int): Index of the element to update. |
51 | | - value (int): New value to set at the given index. |
52 | 33 | """ |
53 | | - if index < 0 or index >= self.n: |
54 | | - raise ValueError("Index out of bounds") |
55 | | - index += self.n |
56 | | - self.tree[index] = value |
57 | | - while index > 1: |
58 | | - index >>= 1 |
59 | | - self.tree[index] = self.tree[index << 1] + self.tree[index << 1 | 1] |
| 34 | + Updates the value at the given index and updates the tree. |
| 35 | + :param index: Index to update. |
| 36 | + :param value: New value. |
| 37 | + :return: None |
| 38 | + """ |
| 39 | + self._update(0, 0, self.n - 1, index, value) |
| 40 | + |
| 41 | + def _update(self, node: int, start: int, end: int, index: int, value: int) -> None: |
| 42 | + if start == end: |
| 43 | + self.tree[node] = value |
| 44 | + else: |
| 45 | + mid = (start + end) // 2 |
| 46 | + if index <= mid: |
| 47 | + self._update(2 * node + 1, start, mid, index, value) |
| 48 | + else: |
| 49 | + self._update(2 * node + 2, mid + 1, end, index, value) |
| 50 | + self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2] |
60 | 51 |
|
61 | 52 | def query(self, left: int, right: int) -> int: |
62 | | - """Compute the sum of elements in the interval [left, right). |
63 | | -
|
64 | | - Args: |
65 | | - left (int): Left index (inclusive). |
66 | | - right (int): Right index (exclusive). |
67 | | -
|
68 | | - Returns: |
69 | | - int: Sum of elements from left to right-1. |
70 | | -
|
71 | | - Raises: |
72 | | - ValueError: If indices are out of bounds or left >= right. |
73 | 53 | """ |
74 | | - if left < 0 or right > self.n or left >= right: |
75 | | - raise ValueError("Invalid query range") |
76 | | - res = 0 |
77 | | - left += self.n |
78 | | - right += self.n |
79 | | - while left < right: |
80 | | - if left & 1: |
81 | | - res += self.tree[left] |
82 | | - left += 1 |
83 | | - if right & 1: |
84 | | - right -= 1 |
85 | | - res += self.tree[right] |
86 | | - left >>= 1 |
87 | | - right >>= 1 |
88 | | - return res |
| 54 | + Returns the sum of elements in the range [left, right]. |
| 55 | + :param left: Left index (inclusive). |
| 56 | + :param right: Right index (inclusive). |
| 57 | + :return: Sum of the range. |
| 58 | +
|
| 59 | + >>> data = [1, 2, 3, 4, 5] |
| 60 | + >>> st = SegmentTree(data) |
| 61 | + >>> st.query(1, 3) |
| 62 | + 9 |
| 63 | + >>> st.update(2, 10) |
| 64 | + >>> st.query(1, 3) |
| 65 | + 16 |
| 66 | + """ |
| 67 | + return self._query(0, 0, self.n - 1, left, right) |
89 | 68 |
|
| 69 | + def _query(self, node: int, start: int, end: int, left: int, right: int) -> int: |
| 70 | + if right < start or left > end: |
| 71 | + return 0 |
| 72 | + if left <= start and end <= right: |
| 73 | + return self.tree[node] |
| 74 | + mid = (start + end) // 2 |
| 75 | + return self._query(2 * node + 1, start, mid, left, right) + self._query( |
| 76 | + 2 * node + 2, mid + 1, end, left, right |
| 77 | + ) |
90 | 78 |
|
91 | | -if __name__ == "__main__": |
92 | | - import doctest |
93 | 79 |
|
| 80 | +def test_segment_tree() -> bool: |
94 | 81 | data = [1, 2, 3, 4, 5] |
95 | 82 | st = SegmentTree(data) |
96 | | - print("Initial sum 1-4:", st.query(1, 4)) |
| 83 | + assert st.query(0, 2) == 6 |
| 84 | + assert st.query(1, 4) == 14 |
97 | 85 | st.update(2, 10) |
98 | | - print("Updated sum 1-4:", st.query(1, 4)) |
| 86 | + assert st.query(1, 3) == 16 |
| 87 | + assert st.query(0, 4) == 22 |
| 88 | + return True |
| 89 | + |
| 90 | + |
| 91 | +def print_results(msg: str, passes: bool) -> None: |
| 92 | + print(str(msg), "works!" if passes else "doesn't work :(") |
| 93 | + |
99 | 94 |
|
100 | | - doctest.testmod() |
| 95 | +def pytests() -> None: |
| 96 | + assert test_segment_tree() |
| 97 | + |
| 98 | + |
| 99 | +def main() -> None: |
| 100 | + """ |
| 101 | + >>> pytests() |
| 102 | + """ |
| 103 | + print_results("Testing Segment Tree functionality", test_segment_tree()) |
| 104 | + |
| 105 | + |
| 106 | +if __name__ == "__main__": |
| 107 | + main() |
0 commit comments