Skip to content

Commit fee0797

Browse files
authored
Update loss_functions.py
1 parent 6fbb895 commit fee0797

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

machine_learning/loss_functions.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -676,16 +676,21 @@ def root_mean_squared_error(y_true: np.array, y_pred: np.array) -> float:
676676
Reference: https://en.wikipedia.org/wiki/Root_mean_square_deviation
677677
678678
Parameters:
679-
- y_pred: Predicted Value
680-
- y_true: Actual Value
679+
y_true: Actual Value
680+
y_pred: Predicted Value
681681
682682
Returns:
683683
float: The RMSE Loss function between y_pred and y_true
684684
685-
>>> true_labels = np.array([2, 4, 6, 8])
686-
>>> predicted_probs = np.array([3, 5, 7, 10])
687-
>>> root_mean_squared_error(true_labels, predicted_probs)
688-
1.3228
685+
>>> true_labels = np.array([100, 200, 300])
686+
>>> predicted_probs = np.array([110, 190, 310])
687+
>>> round(root_mean_squared_error(true_labels, predicted_probs), 4)
688+
10.0
689+
690+
>>> true_labels = [2, 4, 6, 8]
691+
>>> predicted_probs = [3, 5, 7, 10]
692+
>>> round(root_mean_squared_error(true_labels, predicted_probs), 4)
693+
1.3229
689694
690695
>>> true_labels = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
691696
>>> predicted_probs = np.array([0.3, 0.8, 0.9, 0.2])
@@ -698,7 +703,7 @@ def root_mean_squared_error(y_true: np.array, y_pred: np.array) -> float:
698703
raise ValueError("Input arrays must have the same length.")
699704
y_true, y_pred = np.array(y_true), np.array(y_pred)
700705

701-
mse = np.mean((y_true - y_pred) ** 2)
706+
mse = np.mean((y_pred - y_true) ** 2)
702707
return np.sqrt(mse)
703708

704709

0 commit comments

Comments
 (0)