91 changes: 26 additions & 65 deletions sklearn/metrics/_scorer.py
@@ -18,8 +18,9 @@
# Arnaud Joly <[email protected]>
# License: Simplified BSD

from functools import partial
from collections import Counter
from inspect import signature
from functools import partial
from traceback import format_exc

import numpy as np
@@ -64,20 +65,23 @@

from ..utils.multiclass import type_of_target
from ..base import is_regressor
from ..utils._response import _get_response_values
from ..utils._param_validation import HasMethods, StrOptions, validate_params


def _cached_call(cache, estimator, method, *args, **kwargs):
def _cached_call(cache, estimator, response_method, *args, **kwargs):
"""Call estimator with method and args and kwargs."""
if cache is None:
return getattr(estimator, method)(*args, **kwargs)
[Review comment (Member)]
why is this not simply replacing getattr with _get_response_values?

[Reply (Member Author)]
It is a change requested by @jeremiedbb: #26037 (comment).
It makes the code more readable, as @jeremiedbb argued, at the cost of a larger diff.

if cache is not None and response_method in cache:
return cache[response_method]

result, _ = _get_response_values(
estimator, *args, response_method=response_method, **kwargs
)

if cache is not None:
cache[response_method] = result

try:
return cache[method]
except KeyError:
result = getattr(estimator, method)(*args, **kwargs)
cache[method] = result
return result
return result
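
For readers following along, a minimal sketch of the new caching behaviour. It uses the private `_cached_call` helper exactly as defined on this branch, so treat it as illustrative rather than a stable API:

```python
# Illustrative sketch: _cached_call is a private helper of sklearn.metrics._scorer.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics._scorer import _cached_call

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

cache = {}
first = _cached_call(cache, clf, "predict_proba", X)
again = _cached_call(cache, clf, "predict_proba", X)
assert first is again  # second call is served from the cache, not recomputed
```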


class _MultimetricScorer:
@@ -162,40 +166,13 @@ def __init__(self, score_func, sign, kwargs):
self._score_func = score_func
self._sign = sign

@staticmethod
def _check_pos_label(pos_label, classes):
if pos_label not in list(classes):
raise ValueError(f"pos_label={pos_label} is not a valid label: {classes}")

def _select_proba_binary(self, y_pred, classes):
"""Select the column of the positive label in `y_pred` when
probabilities are provided.

Parameters
----------
y_pred : ndarray of shape (n_samples, n_classes)
The prediction given by `predict_proba`.

classes : ndarray of shape (n_classes,)
The class labels for the estimator.

Returns
-------
y_pred : ndarray of shape (n_samples,)
Probability predictions of the positive class.
"""
if y_pred.shape[1] == 2:
pos_label = self._kwargs.get("pos_label", classes[1])
self._check_pos_label(pos_label, classes)
col_idx = np.flatnonzero(classes == pos_label)[0]
return y_pred[:, col_idx]

err_msg = (
f"Got predict_proba of shape {y_pred.shape}, but need "
f"classifier with two classes for {self._score_func.__name__} "
"scoring"
)
raise ValueError(err_msg)
def _get_pos_label(self):
if "pos_label" in self._kwargs:
return self._kwargs["pos_label"]
score_func_params = signature(self._score_func).parameters
if "pos_label" in score_func_params:
return score_func_params["pos_label"].default
return None
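
A standalone sketch of the lookup `_get_pos_label` performs: an explicit `pos_label` passed to the scorer wins; otherwise the score function's own signature default is used. `f1_score` is only an example of a metric that declares `pos_label=1`:

```python
from inspect import signature

from sklearn.metrics import f1_score

kwargs = {}  # as if make_scorer had been called without pos_label
if "pos_label" in kwargs:
    pos_label = kwargs["pos_label"]
else:
    params = signature(f1_score).parameters
    pos_label = params["pos_label"].default if "pos_label" in params else None
print(pos_label)  # 1, the default declared by f1_score
```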

def __repr__(self):
kwargs_string = "".join(
@@ -311,14 +288,7 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
score : float
Score function applied to prediction of estimator on X.
"""

y_type = type_of_target(y)
y_pred = method_caller(clf, "predict_proba", X)
if y_type == "binary" and y_pred.shape[1] <= 2:
# `y_type` could be equal to "binary" even in a multi-class
# problem: (when only 2 class are given to `y_true` during scoring)
# Thus, we need to check for the shape of `y_pred`.
y_pred = self._select_proba_binary(y_pred, clf.classes_)
y_pred = method_caller(clf, "predict_proba", X, pos_label=self._get_pos_label())
if sample_weight is not None:
return self._sign * self._score_func(
y, y_pred, sample_weight=sample_weight, **self._kwargs
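
The column selection that `_select_proba_binary` used to perform now happens inside `_get_response_values`, keyed on `pos_label`. A rough sketch against the private helper, with its signature as used on this branch:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils._response import _get_response_values

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

# Requesting pos_label=0 selects the probability column of class 0.
y_pred, pos_label = _get_response_values(
    clf, X, response_method="predict_proba", pos_label=0
)
assert y_pred.shape == (X.shape[0],)
assert pos_label == 0
```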
@@ -369,26 +339,17 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
if is_regressor(clf):
y_pred = method_caller(clf, "predict", X)
else:
pos_label = self._get_pos_label()
try:
y_pred = method_caller(clf, "decision_function", X)
y_pred = method_caller(clf, "decision_function", X, pos_label=pos_label)

if isinstance(y_pred, list):
# For multi-output multi-class estimator
y_pred = np.vstack([p for p in y_pred]).T
elif y_type == "binary" and "pos_label" in self._kwargs:
self._check_pos_label(self._kwargs["pos_label"], clf.classes_)
if self._kwargs["pos_label"] == clf.classes_[0]:
# The implicit positive class of the binary classifier
# does not match `pos_label`: we need to invert the
# predictions
y_pred *= -1

except (NotImplementedError, AttributeError):
y_pred = method_caller(clf, "predict_proba", X)

if y_type == "binary":
y_pred = self._select_proba_binary(y_pred, clf.classes_)
elif isinstance(y_pred, list):
y_pred = method_caller(clf, "predict_proba", X, pos_label=pos_label)
if isinstance(y_pred, list):
y_pred = np.vstack([p[:, -1] for p in y_pred]).T

if sample_weight is not None:
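
Likewise, the deleted sign-inversion block moved into `_get_response_values`: when the requested `pos_label` is the estimator's first class, the decision values are negated so that higher still means more positive. A sketch of that behaviour (private helper, subject to change):

```python
import numpy as np

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils._response import _get_response_values

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

scores_pos1, _ = _get_response_values(
    clf, X, response_method="decision_function", pos_label=1
)
scores_pos0, _ = _get_response_values(
    clf, X, response_method="decision_function", pos_label=0
)
np.testing.assert_allclose(scores_pos0, -scores_pos1)  # sign flipped for class 0
```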
15 changes: 10 additions & 5 deletions sklearn/metrics/tests/test_score_objects.py
@@ -759,13 +759,18 @@ def test_multimetric_scorer_calls_method_once(
X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])

mock_est = Mock()
fit_func = Mock(return_value=mock_est)
predict_func = Mock(return_value=y)
mock_est._estimator_type = "classifier"
fit_func = Mock(return_value=mock_est, name="fit")
fit_func.__name__ = "fit"
predict_func = Mock(return_value=y, name="predict")
predict_func.__name__ = "predict"

pos_proba = np.random.rand(X.shape[0])
proba = np.c_[1 - pos_proba, pos_proba]
predict_proba_func = Mock(return_value=proba)
decision_function_func = Mock(return_value=pos_proba)
predict_proba_func = Mock(return_value=proba, name="predict_proba")
predict_proba_func.__name__ = "predict_proba"
decision_function_func = Mock(return_value=pos_proba, name="decision_function")
decision_function_func.__name__ = "decision_function"

mock_est.fit = fit_func
mock_est.predict = predict_func
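
The new `__name__` assignments are needed because `_get_response_values` inspects `prediction_method.__name__`, and a bare `Mock` does not auto-create dunder attributes. A minimal repro:

```python
from unittest.mock import Mock

predict_proba = Mock(return_value=None)
try:
    predict_proba.__name__
except AttributeError:
    print("a plain Mock has no __name__ until one is assigned")
predict_proba.__name__ = "predict_proba"  # what the updated test does
assert predict_proba.__name__ == "predict_proba"
```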
@@ -961,7 +966,7 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
n_classes=3, n_informative=3, n_samples=20, random_state=0
)
lr = Perceptron().fit(X, y)
msg = "'Perceptron' object has no attribute 'predict_proba'"
msg = "Perceptron has none of the following attributes: predict_proba."
with pytest.raises(AttributeError, match=msg):
scorer(lr, X, y)
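
The updated message comes from `_check_response_method`, which `_get_response_values` now relies on. A quick way to see it, again via a private helper, so treat this as a sketch:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import Perceptron
from sklearn.utils.validation import _check_response_method

X, y = make_classification(random_state=0)
clf = Perceptron().fit(X, y)
# Raises AttributeError: Perceptron has none of the following attributes:
# predict_proba.
_check_response_method(clf, "predict_proba")
```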

10 changes: 0 additions & 10 deletions sklearn/utils/_response.py
@@ -19,9 +19,6 @@ def _get_response_values(
The response values are predictions, one scalar value for each sample in X
that depends on the specific choice of `response_method`.

This helper only accepts multiclass classifiers with the `predict` response
method.

If `estimator` is a binary classifier, also return the label for the
effective positive class.

@@ -75,15 +72,8 @@
if is_classifier(estimator):
prediction_method = _check_response_method(estimator, response_method)
classes = estimator.classes_

target_type = "binary" if len(classes) <= 2 else "multiclass"

if target_type == "multiclass" and prediction_method.__name__ != "predict":
raise ValueError(
"With a multiclass estimator, the response method should be "
f"predict, got {prediction_method.__name__} instead."
)

if pos_label is not None and pos_label not in classes.tolist():
raise ValueError(
f"pos_label={pos_label} is not a valid label: It should be "
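
With the multiclass restriction dropped, `_get_response_values` passes `predict_proba` and `decision_function` output through untouched for more than two classes, which the new test below exercises. A sketch, once more via the private helper:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._response import _get_response_values

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

proba, pos_label = _get_response_values(clf, X, response_method="predict_proba")
assert proba.shape == (X.shape[0], 3)  # one column per class, untouched
assert pos_label is None  # no positive class in the multiclass case
```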
45 changes: 25 additions & 20 deletions sklearn/utils/tests/test_response.py
@@ -6,7 +6,7 @@
LinearRegression,
LogisticRegression,
)
from sklearn.svm import SVC
from sklearn.preprocessing import scale
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._mocking import _MockEstimatorOnOffPrediction
from sklearn.utils._testing import assert_allclose, assert_array_equal
@@ -15,6 +15,8 @@


X, y = load_iris(return_X_y=True)
# scale the data to avoid ConvergenceWarning with LogisticRegression
X = scale(X, copy=False)
X_binary, y_binary = X[:100], y[:100]


@@ -29,25 +31,6 @@ def test_get_response_values_regressor_error(response_method):
_get_response_values(my_estimator, X, response_method=response_method)


@pytest.mark.parametrize(
"estimator, response_method",
[
(DecisionTreeClassifier(), "predict_proba"),
(SVC(), "decision_function"),
],
)
def test_get_response_values_error_multiclass_classifier(estimator, response_method):
"""Check that we raise an error with multiclass classifier and requesting
response values different from `predict`."""
X, y = make_classification(
n_samples=10, n_clusters_per_class=1, n_classes=3, random_state=0
)
classifier = estimator.fit(X, y)
err_msg = "With a multiclass estimator, the response method should be predict"
with pytest.raises(ValueError, match=err_msg):
_get_response_values(classifier, X, response_method=response_method)


def test_get_response_values_regressor():
"""Check the behaviour of `_get_response_values` with regressor."""
X, y = make_regression(n_samples=10, random_state=0)
@@ -227,3 +210,25 @@ def test_get_response_decision_function():
)
np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1)
assert pos_label == 0


@pytest.mark.parametrize(
"estimator, response_method",
[
(DecisionTreeClassifier(max_depth=2, random_state=0), "predict_proba"),
(LogisticRegression(), "decision_function"),
],
)
def test_get_response_values_multiclass(estimator, response_method):
"""Check that we can call `_get_response_values` with a multiclass estimator.
It should return the predictions untouched.
"""
estimator.fit(X, y)
predictions, pos_label = _get_response_values(
estimator, X, response_method=response_method
)

assert pos_label is None
assert predictions.shape == (X.shape[0], len(estimator.classes_))
if response_method == "predict_proba":
assert np.logical_and(predictions >= 0, predictions <= 1).all()