33function.
44"""
55
6+ from __future__ import annotations
7+
68import numpy as np
9+ from typing import Literal , Sequence
710
811# List of input, output pairs
9- train_data = (
12+ train_data : tuple [ tuple [ tuple [ int , ...], int ], ...] = (
1013 ((5 , 2 , 3 ), 15 ),
1114 ((6 , 5 , 9 ), 25 ),
1215 ((11 , 12 , 13 ), 41 ),
1316 ((1 , 1 , 1 ), 8 ),
1417 ((11 , 12 , 13 ), 41 ),
1518)
16- test_data = (((515 , 22 , 13 ), 555 ), ((61 , 35 , 49 ), 150 ))
17- parameter_vector = [2 , 4 , 1 , 5 ]
18- m = len (train_data )
19- LEARNING_RATE = 0.009
19+ test_data : tuple [ tuple [ tuple [ int , ...], int ], ...] = (((515 , 22 , 13 ), 555 ), ((61 , 35 , 49 ), 150 ))
20+ parameter_vector : list [ float ] = [2.0 , 4.0 , 1.0 , 5.0 ]
21+ m : int = len (train_data )
22+ LEARNING_RATE : float = 0.009
2023
2124
22- def _error (example_no , data_set = "train" ):
25+ def _error (example_no : int , data_set : Literal [ "train" , "test" ] = "train" ) -> float :
2326 """
27+ Compute prediction error for a given example.
28+
2429 :param data_set: train data or test data
2530 :param example_no: example number whose error has to be checked
2631 :return: error in example pointed by example number.
@@ -30,9 +35,10 @@ def _error(example_no, data_set="train"):
3035 )
3136
3237
33- def _hypothesis_value (data_input_tuple ) :
38+ def _hypothesis_value (data_input_tuple : Sequence [ int ]) -> float :
3439 """
35- Calculates hypothesis function value for a given input
40+ Calculates hypothesis value for a given input tuple.
41+
3642 :param data_input_tuple: Input tuple of a particular example
3743 :return: Value of hypothesis function at that point.
3844 Note that there is an 'biased input' whose value is fixed as 1.
@@ -46,8 +52,10 @@ def _hypothesis_value(data_input_tuple):
4652 return hyp_val
4753
4854
49- def output (example_no , data_set ) :
55+ def output (example_no : int , data_set : Literal [ "train" , "test" ]) -> float :
5056 """
57+ Get the true output value of an example.
58+
5159 :param data_set: test data or train data
5260 :param example_no: example whose output is to be fetched
5361 :return: output for that example
@@ -59,9 +67,10 @@ def output(example_no, data_set):
5967 return None
6068
6169
62- def calculate_hypothesis_value (example_no , data_set ) :
70+ def calculate_hypothesis_value (example_no : int , data_set : Literal [ "train" , "test" ]) -> float :
6371 """
64- Calculates hypothesis value for a given example
72+ Calculates hypothesis value for a given example.
73+
6574 :param data_set: test data or train_data
6675 :param example_no: example whose hypothesis value is to be calculated
6776 :return: hypothesis value for that example
@@ -73,9 +82,10 @@ def calculate_hypothesis_value(example_no, data_set):
7382 return None
7483
7584
76- def summation_of_cost_derivative (index , end = m ) :
85+ def summation_of_cost_derivative (index : int , end : int = m ) -> float :
7786 """
78- Calculates the sum of cost function derivative
87+ Calculates the summation term of the cost derivative.
88+
7989 :param index: index wrt derivative is being calculated
8090 :param end: value where summation ends, default is m, number of examples
8191 :return: Returns the summation of cost derivative
@@ -91,8 +101,10 @@ def summation_of_cost_derivative(index, end=m):
91101 return summation_value
92102
93103
94- def get_cost_derivative (index ) :
104+ def get_cost_derivative (index : int ) -> float :
95105 """
106+ Compute ∂J/∂θᵢ for a given parameter index.
107+
96108 :param index: index of the parameter vector wrt to derivative is to be calculated
97109 :return: derivative wrt to that index
98110 Note: If index is -1, this means we are calculating summation wrt to biased
@@ -102,7 +114,10 @@ def get_cost_derivative(index):
102114 return cost_derivative_value
103115
104116
105- def run_gradient_descent ():
117+ def run_gradient_descent () -> None :
118+ """
119+ Perform gradient descent to optimize the parameter vector.
120+ """
106121 global parameter_vector
107122 # Tune these values to set a tolerance value for predicted output
108123 absolute_error_limit = 0.000002
@@ -127,7 +142,7 @@ def run_gradient_descent():
127142 print (("Number of iterations:" , j ))
128143
129144
130- def test_gradient_descent ():
145+ def test_gradient_descent () -> None :
131146 for i in range (len (test_data )):
132147 print (("Actual output value:" , output (i , "test" )))
133148 print (("Hypothesis output:" , calculate_hypothesis_value (i , "test" )))
0 commit comments