Skip to content

Commit c5dd85b

Browse files
committed
feat: add confusion matrix with precision, recall, and F1 score
Add classification evaluation metrics: - confusion_matrix: binary and multiclass support - precision: TP / (TP + FP) - recall (sensitivity): TP / (TP + FN) - f1_score: harmonic mean of precision and recall All functions include doctests.
1 parent 678dedb commit c5dd85b

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
Confusion Matrix implementation for evaluating classification models.
3+
4+
A confusion matrix is a table used to evaluate the performance of a
5+
classification algorithm by comparing predicted labels against actual labels.
6+
7+
Reference: https://en.wikipedia.org/wiki/Confusion_matrix
8+
"""
9+
10+
import numpy as np
11+
12+
13+
def confusion_matrix(actual: list, predicted: list) -> np.ndarray:
    """
    Calculate the confusion matrix for binary or multiclass classification.

    Rows correspond to actual classes and columns to predicted classes,
    both ordered by sorted label value.

    Args:
        actual: List of actual class labels.
        predicted: List of predicted class labels.

    Returns:
        A 2D numpy array of shape (n_classes, n_classes) where entry
        [i][j] counts samples with actual class i predicted as class j.

    Raises:
        ValueError: If ``actual`` and ``predicted`` have different lengths.

    Examples:
        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 0, 0, 1, 0, 0]
        >>> confusion_matrix(actual, predicted)
        array([[2, 0],
               [2, 2]])

        >>> actual = [0, 0, 1, 1, 2, 2]
        >>> predicted = [0, 1, 1, 2, 2, 0]
        >>> confusion_matrix(actual, predicted)
        array([[1, 1, 0],
               [0, 1, 1],
               [1, 0, 1]])

        >>> confusion_matrix([0], [0, 1])
        Traceback (most recent call last):
            ...
        ValueError: actual and predicted must have the same length
    """
    if len(actual) != len(predicted):
        # zip() would silently truncate the longer sequence, producing
        # a plausible-looking but wrong matrix; fail loudly instead.
        raise ValueError("actual and predicted must have the same length")

    classes = sorted(set(actual) | set(predicted))
    n = len(classes)
    class_to_index = {c: i for i, c in enumerate(classes)}

    matrix = np.zeros((n, n), dtype=int)
    for a, p in zip(actual, predicted):
        # Tuple indexing is the idiomatic (single-lookup) numpy form.
        matrix[class_to_index[a], class_to_index[p]] += 1

    return matrix
47+
48+
49+
def precision(actual: list, predicted: list, positive_label: int = 1) -> float:
    """
    Calculate precision: TP / (TP + FP).

    Of all samples predicted positive, the fraction that truly are.

    Args:
        actual: List of actual class labels.
        predicted: List of predicted class labels.
        positive_label: The label considered as positive class.

    Returns:
        Precision score as a float (0.0 when nothing was predicted positive).

    Examples:
        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 0, 0, 1, 0, 0]
        >>> precision(actual, predicted)
        1.0

        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 1, 0, 1, 0, 0]
        >>> precision(actual, predicted)
        0.6666666666666666
    """
    true_pos = 0
    false_pos = 0
    # Single pass: classify every positive prediction as TP or FP.
    for truth, guess in zip(actual, predicted):
        if guess == positive_label:
            if truth == positive_label:
                true_pos += 1
            else:
                false_pos += 1

    predicted_positive = true_pos + false_pos
    if predicted_positive == 0:
        return 0.0
    return true_pos / predicted_positive
75+
76+
77+
def recall(actual: list, predicted: list, positive_label: int = 1) -> float:
    """
    Calculate recall (sensitivity): TP / (TP + FN).

    Of all truly positive samples, the fraction the model found.

    Args:
        actual: List of actual class labels.
        predicted: List of predicted class labels.
        positive_label: The label considered as positive class.

    Returns:
        Recall score as a float (0.0 when no sample is actually positive).

    Examples:
        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 0, 0, 1, 0, 0]
        >>> recall(actual, predicted)
        0.5

        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 1, 1, 1, 0, 1]
        >>> recall(actual, predicted)
        1.0
    """
    true_pos = 0
    false_neg = 0
    # Single pass: classify every actually-positive sample as TP or FN.
    for truth, guess in zip(actual, predicted):
        if truth == positive_label:
            if guess == positive_label:
                true_pos += 1
            else:
                false_neg += 1

    actually_positive = true_pos + false_neg
    if actually_positive == 0:
        return 0.0
    return true_pos / actually_positive
103+
104+
105+
def f1_score(actual: list, predicted: list, positive_label: int = 1) -> float:
    """
    Calculate F1 score: harmonic mean of precision and recall.

    Args:
        actual: List of actual class labels.
        predicted: List of predicted class labels.
        positive_label: The label considered as positive class.

    Returns:
        F1 score as a float (0.0 when both precision and recall are zero).

    Examples:
        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 0, 0, 1, 0, 0]
        >>> round(f1_score(actual, predicted), 4)
        0.6667

        >>> actual = [1, 0, 1, 1, 0, 1]
        >>> predicted = [1, 0, 1, 1, 0, 1]
        >>> f1_score(actual, predicted)
        1.0
    """
    prec = precision(actual, predicted, positive_label)
    rec = recall(actual, predicted, positive_label)
    denominator = prec + rec
    # Both components zero means the harmonic mean is undefined; use 0.0.
    if denominator == 0:
        return 0.0
    return (2 * prec * rec) / denominator
131+
132+
133+
if __name__ == "__main__":
    # Run the doctests embedded in this module's docstrings.
    import doctest

    doctest.testmod()

0 commit comments

Comments
 (0)