Skip to content

Commit a798d9e

Browse files
Jan Hendrik Metzenogrisel
authored andcommitted
TST Adding brier_score_loss to test_common.py
1 parent c181308 commit a798d9e

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from sklearn.metrics import accuracy_score
2727
from sklearn.metrics import average_precision_score
28+
from sklearn.metrics import brier_score_loss
2829
from sklearn.metrics import confusion_matrix
2930
from sklearn.metrics import coverage_error
3031
from sklearn.metrics import explained_variance_score
@@ -148,6 +149,8 @@
148149

149150
"hinge_loss": hinge_loss,
150151

152+
"brier_score_loss": brier_score_loss,
153+
151154
"roc_auc_score": roc_auc_score,
152155
"weighted_roc_auc": partial(roc_auc_score, average="weighted"),
153156
"samples_roc_auc": partial(roc_auc_score, average="samples"),
@@ -197,6 +200,7 @@
197200
"macro_roc_auc", "samples_roc_auc",
198201

199202
"coverage_error",
203+
"brier_score_loss"
200204
]
201205

202206
# Metrics with an "average" argument
@@ -211,7 +215,9 @@
211215

212216
# Metrics with a "pos_label" argument
213217
METRICS_WITH_POS_LABEL = [
214-
"roc_curve", "hinge_loss",
218+
"roc_curve",
219+
220+
"brier_score_loss",
215221

216222
"precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
217223

@@ -554,9 +560,15 @@ def test_invariance_string_vs_numbers_labels():
554560
"invariance test".format(name))
555561

556562
for name, metric in THRESHOLDED_METRICS.items():
557-
if name in ("log_loss", "hinge_loss", "unnormalized_log_loss"):
563+
if name in ("log_loss", "hinge_loss", "unnormalized_log_loss",
564+
"brier_score_loss"):
565+
# Ugly, but handle case with a pos_label and label
566+
metric_str = metric
567+
if name in METRICS_WITH_POS_LABEL:
568+
metric_str = partial(metric_str, pos_label=pos_label_str)
569+
558570
measure_with_number = metric(y1, y2)
559-
measure_with_str = metric(y1_str, y2)
571+
measure_with_str = metric_str(y1_str, y2)
560572
assert_array_equal(measure_with_number, measure_with_str,
561573
err_msg="{0} failed string vs number "
562574
"invariance test".format(name))

0 commit comments

Comments
 (0)