Skip to content

Commit dfe35fb

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d0d0b40 commit dfe35fb

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

machine_learning/gradient_descent.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@
1616
((1, 1, 1), 8),
1717
((11, 12, 13), 41),
1818
)
19-
test_data: tuple[tuple[tuple[int, ...], int], ...] = (((515, 22, 13), 555), ((61, 35, 49), 150))
19+
test_data: tuple[tuple[tuple[int, ...], int], ...] = (
20+
((515, 22, 13), 555),
21+
((61, 35, 49), 150),
22+
)
2023
parameter_vector: list[float] = [2.0, 4.0, 1.0, 5.0]
2124
m: int = len(train_data)
2225
LEARNING_RATE: float = 0.009
2326

2427

25-
def _error(example_no: int, data_set: Literal["train", "test"]="train") -> float:
28+
def _error(example_no: int, data_set: Literal["train", "test"] = "train") -> float:
2629
"""
2730
Compute prediction error for a given example.
28-
31+
2932
:param data_set: train data or test data
3033
:param example_no: example number whose error has to be checked
3134
:return: error in example pointed by example number.
@@ -38,7 +41,7 @@ def _error(example_no: int, data_set: Literal["train", "test"]="train") -> float
3841
def _hypothesis_value(data_input_tuple: Sequence[int]) -> float:
3942
"""
4043
Calculates hypothesis value for a given input tuple.
41-
44+
4245
:param data_input_tuple: Input tuple of a particular example
4346
:return: Value of hypothesis function at that point.
4447
Note that there is a 'biased input' whose value is fixed as 1.
@@ -55,7 +58,7 @@ def _hypothesis_value(data_input_tuple: Sequence[int]) -> float:
5558
def output(example_no: int, data_set: Literal["train", "test"]) -> float:
5659
"""
5760
Get the true output value of an example.
58-
61+
5962
:param data_set: test data or train data
6063
:param example_no: example whose output is to be fetched
6164
:return: output for that example
@@ -67,10 +70,12 @@ def output(example_no: int, data_set: Literal["train", "test"]) -> float:
6770
return None
6871

6972

70-
def calculate_hypothesis_value(example_no: int, data_set: Literal["train", "test"]) -> float:
73+
def calculate_hypothesis_value(
74+
example_no: int, data_set: Literal["train", "test"]
75+
) -> float:
7176
"""
7277
Calculates hypothesis value for a given example.
73-
78+
7479
:param data_set: test data or train_data
7580
:param example_no: example whose hypothesis value is to be calculated
7681
:return: hypothesis value for that example
@@ -85,7 +90,7 @@ def calculate_hypothesis_value(example_no: int, data_set: Literal["train", "test
8590
def summation_of_cost_derivative(index: int, end: int = m) -> float:
8691
"""
8792
Calculates the summation term of the cost derivative.
88-
93+
8994
:param index: index with respect to which the derivative is being calculated
9095
:param end: value where summation ends, default is m, number of examples
9196
:return: Returns the summation of cost derivative
@@ -104,7 +109,7 @@ def summation_of_cost_derivative(index: int, end: int = m) -> float:
104109
def get_cost_derivative(index: int) -> float:
105110
"""
106111
Compute ∂J/∂θᵢ for a given parameter index.
107-
112+
108113
:param index: index of the parameter vector with respect to which the derivative is to be calculated
109114
:return: derivative with respect to that index
110115
Note: If index is -1, this means we are calculating summation wrt to biased

0 commit comments

Comments
 (0)