Skip to content

Commit 7555469

Browse files
Add RMSE and Log-Cosh loss functions for problem #13379
1 parent 788d95b commit 7555469

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

machine_learning/loss_functions.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,94 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
663663
return np.sum(kl_loss)
664664

665665

666+
667+
def root_mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the Root Mean Squared Error (RMSE) of predictions against ground truth.

    RMSE is simply the square root of the mean squared error, so it reports the
    typical magnitude of a prediction error in the same units as the target
    variable, which makes it a common regression metric.

    RMSE = sqrt( (1/n) * Σ(y_true - y_pred)^2 )

    Reference: https://en.wikipedia.org/wiki/Mean_squared_error#Root-mean-square_error

    Parameters:
    - y_true: The true values (ground truth)
    - y_pred: The predicted values

    >>> true_values = np.array([3, -0.5, 2, 7])
    >>> predicted_values = np.array([2.5, 0.0, 2, 8])
    >>> float(root_mean_squared_error(true_values, predicted_values))
    0.6123724356957945
    >>> true_values = np.array([1, 2, 3])
    >>> predicted_values = np.array([1, 2, 3])
    >>> float(root_mean_squared_error(true_values, predicted_values))
    0.0
    >>> true_values = np.array([0, 0, 0])
    >>> predicted_values = np.array([1, 1, 1])
    >>> float(root_mean_squared_error(true_values, predicted_values))
    1.0
    >>> true_values = np.array([1, 2])
    >>> predicted_values = np.array([1, 2, 3])
    >>> root_mean_squared_error(true_values, predicted_values)
    Traceback (most recent call last):
    ...
    ValueError: Input arrays must have the same length.
    """
    # Guard clause: mismatched lengths make an element-wise difference meaningless.
    if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")

    squared_errors = np.square(y_true - y_pred)
    return np.sqrt(np.mean(squared_errors))
706+
707+
708+
709+
def log_cosh_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Calculate the Log-Cosh Loss between ground truth and predicted values.

    Log-Cosh is the logarithm of the hyperbolic cosine of the prediction error.
    It behaves like mean squared error for small errors and like mean absolute error
    for large errors, making it less sensitive to outliers.

    Log-Cosh = (1/n) * Σ log(cosh(y_pred - y_true))

    Reference: https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss

    Parameters:
    - y_true: The true values (ground truth)
    - y_pred: The predicted values

    >>> true_values = np.array([0.0, 0.0, 0.0])
    >>> predicted_values = np.array([0.0, 0.0, 0.0])
    >>> float(log_cosh_loss(true_values, predicted_values))
    0.0
    >>> true_values = np.array([1.0, 2.0])
    >>> predicted_values = np.array([1.1, 2.1])
    >>> float(round(log_cosh_loss(true_values, predicted_values), 10))
    0.0049916888
    >>> true_values = np.array([0.0])
    >>> predicted_values = np.array([1.0])
    >>> float(round(log_cosh_loss(true_values, predicted_values), 10))
    0.4337808305
    >>> true_values = np.array([1, 2])
    >>> predicted_values = np.array([1, 2, 3])
    >>> log_cosh_loss(true_values, predicted_values)
    Traceback (most recent call last):
    ...
    ValueError: Input arrays must have the same length.
    """
    # Guard clause: mismatched lengths make an element-wise difference meaningless.
    if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")

    errors = y_pred - y_true
    # Numerically stable evaluation of log(cosh(x)):
    #   log(cosh(x)) = log((e^x + e^-x) / 2) = logaddexp(x, -x) - log(2)
    # np.logaddexp avoids the overflow that np.cosh(x) would hit for large |x|.
    loss = np.logaddexp(errors, -errors) - np.log(2)
    # Coerce np.float64 to a plain Python float to match the annotated return type.
    return float(np.mean(loss))
753+
666754
if __name__ == "__main__":
667755
import doctest
668756

0 commit comments

Comments
 (0)