Merged
82 changes: 52 additions & 30 deletions sklearn/model_selection/_search.py
@@ -38,7 +38,7 @@
from ..utils.random import sample_without_replacement
from ..utils._tags import _safe_tags
from ..utils.validation import indexable, check_is_fitted, _check_fit_params
from ..utils.metaestimators import if_delegate_has_method
from ..utils.metaestimators import available_if
from ..utils.fixes import delayed
from ..metrics._scorer import _check_multimetric_scoring
from ..metrics import check_scoring
@@ -345,6 +345,40 @@ def _check_param_grid(param_grid):
)


def _check_refit(search_cv, attr):
if not search_cv.refit:
raise AttributeError(
f"This {type(search_cv).__name__} instance was initialized with "
f"`refit=False`. {attr} is available only after refitting on the best "
"parameters. You can refit an estimator manually using the "
"`best_params_` attribute"
)


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

A prediction method is only available if `refit=True`. In that case, we
first check the fitted best estimator; if it is not fitted, we check the
unfitted estimator.

Checking the unfitted estimator allows using `hasattr` on the `SearchCV`
instance even before calling `fit`.
"""

def check(self):
_check_refit(self, attr)
if hasattr(self, "best_estimator_"):
# raise an AttributeError if `attr` does not exist
getattr(self.best_estimator_, attr)
return True
# raise an AttributeError if `attr` does not exist
getattr(self.estimator, attr)
return True

return check
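As a side note, a minimal sketch of what this helper enables once combined with available_if below (the estimators and parameter grids are illustrative only; this assumes a scikit-learn version that provides available_if):

```python
# Sketch: hasattr on an unfitted SearchCV instance mirrors the underlying
# estimator, because the check falls back to `self.estimator` before fit.
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

clf_search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0]})
reg_search = GridSearchCV(SVR(), {"C": [0.1, 1.0]})

# LogisticRegression exposes predict_proba, SVR does not; the unfitted
# search objects report the same without raising.
print(hasattr(clf_search, "predict_proba"))  # True
print(hasattr(reg_search, "predict_proba"))  # False
```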


class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
"""Abstract base class for hyper parameter search with cross-validation."""

Expand Down Expand Up @@ -418,7 +452,8 @@ def score(self, X, y=None):
-------
score : float
"""
self._check_is_fitted("score")
_check_refit(self, "score")
Member: This is certainly backward-incompatible code (I think). I agree with the new logic, but we need to discuss this.

Member: On main:

  • If refit=False, a NotFittedError is raised stating that refit=False.
  • If refit=True, check_is_fitted is called.

With this PR, I think the behavior is still the same.

Member Author: I just replaced the NotFittedError with an AttributeError, since that is indeed the issue. In terms of behaviour, this is exactly the current behaviour.

Member: Have you pushed the changes @glemaitre?

Member Author: I accepted the suggestions of @thomasjpfan, but as I mentioned, I don't think there is a change of behavior here. Which backward incompatibility do you observe? I might have overlooked it as well, since I was relying on the tests to check for that.
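For concreteness, here is a small sketch of the behaviour being discussed, assuming the code in this PR (the estimator and parameter grid below are purely illustrative, not taken from the PR or its tests):

```python
# Sketch: with refit=False, looking up a delegated method on the search
# object raises AttributeError (previously a NotFittedError was raised).
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

search = GridSearchCV(LogisticRegression(), {"C": [1.0]}, refit=False)
try:
    getattr(search, "predict")  # attribute lookup itself triggers _check_refit
except AttributeError as exc:
    print(exc)  # "...initialized with `refit=False`. predict is available only after refitting..."
```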

check_is_fitted(self)
if self.scorer_ is None:
raise ValueError(
"No score function explicitly defined, "
@@ -438,7 +473,7 @@ def score(self, X, y=None):
score = score[self.refit]
return score

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("score_samples"))
def score_samples(self, X):
"""Call score_samples on the estimator with the best found parameters.

@@ -457,23 +492,10 @@ def score_samples(self, X):
-------
y_score : ndarray of shape (n_samples,)
"""
self._check_is_fitted("score_samples")
check_is_fitted(self)
return self.best_estimator_.score_samples(X)

def _check_is_fitted(self, method_name):
if not self.refit:
raise NotFittedError(
"This %s instance was initialized "
"with refit=False. %s is "
"available only after refitting on the best "
"parameters. You can refit an estimator "
"manually using the ``best_params_`` "
"attribute" % (type(self).__name__, method_name)
)
else:
check_is_fitted(self)

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("predict"))
def predict(self, X):
"""Call predict on the estimator with the best found parameters.

@@ -487,10 +509,10 @@ def predict(self, X):
underlying estimator.

"""
self._check_is_fitted("predict")
check_is_fitted(self)
return self.best_estimator_.predict(X)

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("predict_proba"))
def predict_proba(self, X):
"""Call predict_proba on the estimator with the best found parameters.

@@ -504,10 +526,10 @@ def predict_proba(self, X):
underlying estimator.

"""
self._check_is_fitted("predict_proba")
check_is_fitted(self)
return self.best_estimator_.predict_proba(X)

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("predict_log_proba"))
def predict_log_proba(self, X):
"""Call predict_log_proba on the estimator with the best found parameters.

@@ -521,10 +543,10 @@ def predict_log_proba(self, X):
underlying estimator.

"""
self._check_is_fitted("predict_log_proba")
check_is_fitted(self)
return self.best_estimator_.predict_log_proba(X)

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("decision_function"))
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.

@@ -538,10 +560,10 @@ def decision_function(self, X):
underlying estimator.

"""
self._check_is_fitted("decision_function")
check_is_fitted(self)
Member: We should either use _check_refit in all of these functions or in none of them, shouldn't we?

Member: _estimator_has will call _check_refit.

Member Author: This is the beauty of the decorator: the inner code looks like that of any other estimator, because the check is done by the decorator. So only the fitting needs to be checked here.
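To make the pattern concrete, a toy sketch of the idea (the wrapper class and helper below are invented for illustration and are not scikit-learn code):

```python
# Toy sketch of the available_if pattern: the availability check runs at
# attribute access, so the method body only needs to verify fitting.
from sklearn.base import BaseEstimator, clone
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted


def _wrapped_has(attr):
    def check(self):
        # raises AttributeError if the wrapped estimator lacks `attr`
        getattr(self.estimator, attr)
        return True

    return check


class ToyWrapper(BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y=None):
        self.estimator_ = clone(self.estimator).fit(X, y)
        return self

    @available_if(_wrapped_has("predict_proba"))
    def predict_proba(self, X):
        check_is_fitted(self)  # only the fit check lives in the body
        return self.estimator_.predict_proba(X)
```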

return self.best_estimator_.decision_function(X)

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("transform"))
def transform(self, X):
"""Call transform on the estimator with the best found parameters.

@@ -555,10 +577,10 @@ def transform(self, X):
underlying estimator.

"""
self._check_is_fitted("transform")
check_is_fitted(self)
return self.best_estimator_.transform(X)

@if_delegate_has_method(delegate=("best_estimator_", "estimator"))
@available_if(_estimator_has("inverse_transform"))
def inverse_transform(self, Xt):
"""Call inverse_transform on the estimator with the best found params.

@@ -572,7 +594,7 @@ def inverse_transform(self, Xt):
underlying estimator.

"""
self._check_is_fitted("inverse_transform")
check_is_fitted(self)
return self.best_estimator_.inverse_transform(Xt)

@property
@@ -592,7 +614,7 @@ def n_features_in_(self):

@property
def classes_(self):
self._check_is_fitted("classes_")
_estimator_has("classes_")(self)
return self.best_estimator_.classes_

def _run_search(self, evaluate_candidates):
4 changes: 2 additions & 2 deletions sklearn/model_selection/tests/test_search.py
@@ -390,10 +390,10 @@ def test_no_refit():
"inverse_transform",
):
error_msg = (
f"refit=False. {fn_name} is available only after "
f"`refit=False`. {fn_name} is available only after "
"refitting on the best parameters"
)
with pytest.raises(NotFittedError, match=error_msg):
with pytest.raises(AttributeError, match=error_msg):
getattr(grid_search, fn_name)(X)

# Test that an invalid refit param raises appropriate error messages
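As a standalone illustration of the updated assertion (a sketch only; the estimator, grid, and data are made up and not taken from test_search.py):

```python
# Sketch: with refit=False, each delegated method raises an AttributeError
# whose message matches the pattern asserted in the updated test.
import pytest

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=30, random_state=0)
grid_search = GridSearchCV(
    DecisionTreeClassifier(), {"max_depth": [1, 2]}, refit=False
).fit(X, y)

for fn_name in ("predict", "predict_proba", "predict_log_proba"):
    error_msg = (
        f"`refit=False`. {fn_name} is available only after "
        "refitting on the best parameters"
    )
    with pytest.raises(AttributeError, match=error_msg):
        getattr(grid_search, fn_name)(X)
```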