Skip to content

Commit bd3072b

Browse files
Added Mean Squared Logarithmic Error (MSLE) Loss Function (#10637)
* Added Mean Squared Logarithmic Error (MSLE) * Added Mean Squared Logarithmic Error (MSLE) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 289a4dd commit bd3072b

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Mean Squared Logarithmic Error (MSLE) Loss Function
3+
4+
Description:
5+
MSLE measures the mean squared logarithmic difference between
6+
true values and predicted values, particularly useful when
7+
dealing with regression problems involving skewed or large-value
8+
targets. It is often used when the relative differences between
9+
predicted and true values are more important than absolute
10+
differences.
11+
12+
Formula:
13+
MSLE = (1/n) * Σ(log(1 + y_true) - log(1 + y_pred))^2
14+
15+
Source:
16+
(https://insideaiml.com/blog/MeanSquared-Logarithmic-Error-Loss-1035)
17+
"""
18+
19+
import numpy as np
20+
21+
22+
def mean_squared_logarithmic_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
23+
"""
24+
Calculate the Mean Squared Logarithmic Error (MSLE) between two arrays.
25+
26+
Parameters:
27+
- y_true: The true values (ground truth).
28+
- y_pred: The predicted values.
29+
30+
Returns:
31+
- msle: The Mean Squared Logarithmic Error between y_true and y_pred.
32+
33+
Example usage:
34+
>>> true_values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
35+
>>> predicted_values = np.array([0.8, 2.1, 2.9, 4.2, 5.2])
36+
>>> mean_squared_logarithmic_error(true_values, predicted_values)
37+
0.0030860877925181344
38+
>>> true_labels = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
39+
>>> predicted_probs = np.array([0.3, 0.8, 0.9, 0.2])
40+
>>> mean_squared_logarithmic_error(true_labels, predicted_probs)
41+
Traceback (most recent call last):
42+
...
43+
ValueError: Input arrays must have the same length.
44+
"""
45+
if len(y_true) != len(y_pred):
46+
raise ValueError("Input arrays must have the same length.")
47+
48+
squared_logarithmic_errors = (np.log1p(y_true) - np.log1p(y_pred)) ** 2
49+
return np.mean(squared_logarithmic_errors)
50+
51+
52+
if __name__ == "__main__":
53+
import doctest
54+
55+
doctest.testmod()

0 commit comments

Comments
 (0)