|
25 | 25 |
|
26 | 26 | from sklearn.metrics import accuracy_score |
27 | 27 | from sklearn.metrics import average_precision_score |
| 28 | +from sklearn.metrics import brier_score_loss |
28 | 29 | from sklearn.metrics import confusion_matrix |
29 | 30 | from sklearn.metrics import coverage_error |
30 | 31 | from sklearn.metrics import explained_variance_score |
|
148 | 149 |
|
149 | 150 | "hinge_loss": hinge_loss, |
150 | 151 |
|
| 152 | + "brier_score_loss": brier_score_loss, |
| 153 | + |
151 | 154 | "roc_auc_score": roc_auc_score, |
152 | 155 | "weighted_roc_auc": partial(roc_auc_score, average="weighted"), |
153 | 156 | "samples_roc_auc": partial(roc_auc_score, average="samples"), |
|
197 | 200 | "macro_roc_auc", "samples_roc_auc", |
198 | 201 |
|
199 | 202 | "coverage_error", |
| 203 | + "brier_score_loss" |
200 | 204 | ] |
201 | 205 |
|
202 | 206 | # Metrics with an "average" argument |
|
211 | 215 |
|
212 | 216 | # Metrics with a "pos_label" argument |
213 | 217 | METRICS_WITH_POS_LABEL = [ |
214 | | - "roc_curve", "hinge_loss", |
| 218 | + "roc_curve", |
| 219 | + |
| 220 | + "brier_score_loss", |
215 | 221 |
|
216 | 222 | "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score", |
217 | 223 |
|
@@ -554,9 +560,15 @@ def test_invariance_string_vs_numbers_labels(): |
554 | 560 | "invariance test".format(name)) |
555 | 561 |
|
556 | 562 | for name, metric in THRESHOLDED_METRICS.items(): |
557 | | - if name in ("log_loss", "hinge_loss", "unnormalized_log_loss"): |
| 563 | + if name in ("log_loss", "hinge_loss", "unnormalized_log_loss", |
| 564 | + "brier_score_loss"): |
| 565 | + # Ugly, but handle case with a pos_label and label |
| 566 | + metric_str = metric |
| 567 | + if name in METRICS_WITH_POS_LABEL: |
| 568 | + metric_str = partial(metric_str, pos_label=pos_label_str) |
| 569 | + |
558 | 570 | measure_with_number = metric(y1, y2) |
559 | | - measure_with_str = metric(y1_str, y2) |
| 571 | + measure_with_str = metric_str(y1_str, y2) |
560 | 572 | assert_array_equal(measure_with_number, measure_with_str, |
561 | 573 | err_msg="{0} failed string vs number " |
562 | 574 | "invariance test".format(name)) |
|
0 commit comments