Skip to content

Commit eae87f7

Browse files
authored
Create brent_method.py
1 parent 7530a41 commit eae87f7

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Brent's Method for Root Finding
3+
-------------------------------
4+
5+
Brent's method is a robust and efficient algorithm for finding a zero of a
6+
function in a given interval [left, right]. It combines bisection,
7+
secant, and inverse quadratic interpolation methods.
8+
9+
References:
10+
- https://en.wikipedia.org/wiki/Brent%27s_method
11+
- https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brentq.html
12+
13+
Example usage:
14+
>>> def cubic(x):
15+
... return x**3 - x - 2
16+
>>> round(brent_root(cubic, 1, 2), 5)
17+
1.52138
18+
"""
19+
20+
from typing import Callable
21+
22+
def brent_root(
23+
function: Callable[[float], float],
24+
left: float,
25+
right: float,
26+
tolerance: float = 1e-5,
27+
max_iterations: int = 100,
28+
) -> float:
29+
value_left, value_right = function(left), function(right)
30+
if value_left * value_right >= 0:
31+
raise ValueError("Function must have opposite signs at endpoints left and right.")
32+
33+
previous_point = current_point = left
34+
value_previous = value_current = value_left
35+
distance = interval_length = right - left
36+
37+
for iteration in range(max_iterations):
38+
if value_current * value_previous > 0:
39+
previous_point, value_previous = left, value_left
40+
distance = interval_length = right - left
41+
42+
if abs(value_previous) < abs(value_current):
43+
left, current_point, previous_point = current_point, previous_point, current_point
44+
value_left, value_current, value_previous = value_current, value_previous, value_current
45+
46+
tolerance1 = 2.0 * 1e-16 * abs(current_point) + 0.5 * tolerance
47+
midpoint = 0.5 * (previous_point - current_point)
48+
49+
if abs(midpoint) <= tolerance1 or value_current == 0.0:
50+
return current_point
51+
52+
if abs(interval_length) >= tolerance1 and abs(value_left) > abs(value_current):
53+
ratio = value_current / value_left
54+
if left == previous_point:
55+
numerator = 2 * midpoint * ratio
56+
denominator = 1 - ratio
57+
else:
58+
q = value_left / value_previous
59+
r = value_current / value_previous
60+
numerator = ratio * (2 * midpoint * q * (q - r) - (current_point - left) * (r - 1))
61+
denominator = (q - 1) * (r - 1) * (ratio - 1)
62+
63+
if numerator > 0:
64+
denominator = -denominator
65+
numerator = abs(numerator)
66+
67+
if 2 * numerator < min(3 * midpoint * denominator - abs(tolerance1 * denominator), abs(interval_length * denominator)):
68+
interval_length = distance
69+
distance = numerator / denominator
70+
else:
71+
distance = midpoint
72+
interval_length = midpoint
73+
else:
74+
distance = midpoint
75+
interval_length = midpoint
76+
77+
left, value_left = current_point, value_current
78+
if abs(distance) > tolerance1:
79+
current_point += distance
80+
else:
81+
current_point += tolerance1 if midpoint > 0 else -tolerance1
82+
value_current = function(current_point)
83+
84+
raise RuntimeError("Maximum iterations exceeded in Brent's method.")
85+
86+
if __name__ == "__main__":
87+
import doctest
88+
doctest.testmod()

0 commit comments

Comments
 (0)