Skip to content

Commit d0d0b40

Browse files
Refactor gradient_descent.py with full type hints and documentation
This PR modernizes machine_learning/gradient_descent.py while preserving the original algorithm behavior. Changes include: (1) added full type hints, and (2) improved docstrings and readability. No functional or algorithmic behavior was modified.
1 parent 678dedb commit d0d0b40

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed

machine_learning/gradient_descent.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,29 @@
33
function.
44
"""
55

6+
from __future__ import annotations
7+
68
import numpy as np
9+
from typing import Literal, Sequence
710

811
# List of input, output pairs
9-
train_data = (
12+
train_data: tuple[tuple[tuple[int, ...], int], ...] = (
1013
((5, 2, 3), 15),
1114
((6, 5, 9), 25),
1215
((11, 12, 13), 41),
1316
((1, 1, 1), 8),
1417
((11, 12, 13), 41),
1518
)
16-
test_data = (((515, 22, 13), 555), ((61, 35, 49), 150))
17-
parameter_vector = [2, 4, 1, 5]
18-
m = len(train_data)
19-
LEARNING_RATE = 0.009
19+
test_data: tuple[tuple[tuple[int, ...], int], ...] = (((515, 22, 13), 555), ((61, 35, 49), 150))
20+
parameter_vector: list[float] = [2.0, 4.0, 1.0, 5.0]
21+
m: int = len(train_data)
22+
LEARNING_RATE: float = 0.009
2023

2124

22-
def _error(example_no, data_set="train"):
25+
def _error(example_no: int, data_set: Literal["train", "test"]="train") -> float:
2326
"""
27+
Compute prediction error for a given example.
28+
2429
:param data_set: train data or test data
2530
:param example_no: example number whose error has to be checked
2631
:return: error in example pointed by example number.
@@ -30,9 +35,10 @@ def _error(example_no, data_set="train"):
3035
)
3136

3237

33-
def _hypothesis_value(data_input_tuple):
38+
def _hypothesis_value(data_input_tuple: Sequence[int]) -> float:
3439
"""
35-
Calculates hypothesis function value for a given input
40+
Calculates hypothesis value for a given input tuple.
41+
3642
:param data_input_tuple: Input tuple of a particular example
3743
:return: Value of hypothesis function at that point.
3844
Note that there is a 'biased input' whose value is fixed as 1.
@@ -46,8 +52,10 @@ def _hypothesis_value(data_input_tuple):
4652
return hyp_val
4753

4854

49-
def output(example_no, data_set):
55+
def output(example_no: int, data_set: Literal["train", "test"]) -> float:
5056
"""
57+
Get the true output value of an example.
58+
5159
:param data_set: test data or train data
5260
:param example_no: example whose output is to be fetched
5361
:return: output for that example
@@ -59,9 +67,10 @@ def output(example_no, data_set):
5967
return None
6068

6169

62-
def calculate_hypothesis_value(example_no, data_set):
70+
def calculate_hypothesis_value(example_no: int, data_set: Literal["train", "test"]) -> float:
6371
"""
64-
Calculates hypothesis value for a given example
72+
Calculates hypothesis value for a given example.
73+
6574
:param data_set: test data or train_data
6675
:param example_no: example whose hypothesis value is to be calculated
6776
:return: hypothesis value for that example
@@ -73,9 +82,10 @@ def calculate_hypothesis_value(example_no, data_set):
7382
return None
7483

7584

76-
def summation_of_cost_derivative(index, end=m):
85+
def summation_of_cost_derivative(index: int, end: int = m) -> float:
7786
"""
78-
Calculates the sum of cost function derivative
87+
Calculates the summation term of the cost derivative.
88+
7989
:param index: index wrt derivative is being calculated
8090
:param end: value where summation ends, default is m, number of examples
8191
:return: Returns the summation of cost derivative
@@ -91,8 +101,10 @@ def summation_of_cost_derivative(index, end=m):
91101
return summation_value
92102

93103

94-
def get_cost_derivative(index):
104+
def get_cost_derivative(index: int) -> float:
95105
"""
106+
Compute ∂J/∂θᵢ for a given parameter index.
107+
96108
:param index: index of the parameter vector wrt to derivative is to be calculated
97109
:return: derivative wrt to that index
98110
Note: If index is -1, this means we are calculating summation wrt to biased
@@ -102,7 +114,10 @@ def get_cost_derivative(index):
102114
return cost_derivative_value
103115

104116

105-
def run_gradient_descent():
117+
def run_gradient_descent() -> None:
118+
"""
119+
Perform gradient descent to optimize the parameter vector.
120+
"""
106121
global parameter_vector
107122
# Tune these values to set a tolerance value for predicted output
108123
absolute_error_limit = 0.000002
@@ -127,7 +142,7 @@ def run_gradient_descent():
127142
print(("Number of iterations:", j))
128143

129144

130-
def test_gradient_descent():
145+
def test_gradient_descent() -> None:
131146
for i in range(len(test_data)):
132147
print(("Actual output value:", output(i, "test")))
133148
print(("Hypothesis output:", calculate_hypothesis_value(i, "test")))

0 commit comments

Comments
 (0)