Skip to content

Commit 8fa5abe

Browse files
committed
Fixes #10
1 parent 788d95b commit 8fa5abe

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

machine_learning/loss_functions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,46 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
663663
return np.sum(kl_loss)
664664

665665

666+
def symmetric_mean_absolute_percentage_error(
667+
y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-15
668+
) -> float:
669+
"""
670+
Calculate the Symmetric Mean Absolute Percentage Error (SMAPE) between y_true and
671+
y_pred.
672+
673+
SMAPE is an accuracy measure based on percentage (or relative) errors. It is
674+
symmetric and treats over- and under- predictions equally.
675+
676+
SMAPE = (1/n) * Σ( |y_true - y_pred| / ((|y_true| + |y_pred|) / 2) )
677+
678+
Reference: https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
679+
680+
Parameters:
681+
- y_true: The true values (ground truth)
682+
- y_pred: The predicted values
683+
- epsilon: Small constant to avoid division by zero
684+
685+
>>> true_values = np.array([100, 200, 300, 400])
686+
>>> predicted_values = np.array([110, 190, 310, 420])
687+
>>> float(symmetric_mean_absolute_percentage_error(true_values, predicted_values))
688+
0.05702187989273155
689+
>>> true_labels = np.array([100, 200, 300])
690+
>>> predicted_probs = np.array([110, 190, 310, 420])
691+
>>> symmetric_mean_absolute_percentage_error(true_labels, predicted_probs)
692+
Traceback (most recent call last):
693+
...
694+
ValueError: Input arrays must have the same length.
695+
"""
696+
if len(y_true) != len(y_pred):
697+
raise ValueError("Input arrays must have the same length.")
698+
699+
denominator = (np.abs(y_true) + np.abs(y_pred)) / 2.0
700+
denominator = np.where(denominator == 0, epsilon, denominator)
701+
702+
smape_loss = np.abs(y_true - y_pred) / denominator
703+
return np.mean(smape_loss)
704+
705+
666706
if __name__ == "__main__":
667707
import doctest
668708

0 commit comments

Comments
 (0)