-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
MNT replace if_delegate_has_method by available_if in _search.py #20685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
943affd
ad8b5dd
8765aa5
da77194
dc4f4a6
f588908
a91f43d
66bede3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,7 +38,7 @@ | |
| from ..utils.random import sample_without_replacement | ||
| from ..utils._tags import _safe_tags | ||
| from ..utils.validation import indexable, check_is_fitted, _check_fit_params | ||
| from ..utils.metaestimators import if_delegate_has_method | ||
| from ..utils.metaestimators import available_if | ||
| from ..utils.fixes import delayed | ||
| from ..metrics._scorer import _check_multimetric_scoring | ||
| from ..metrics import check_scoring | ||
|
|
@@ -345,6 +345,40 @@ def _check_param_grid(param_grid): | |
| ) | ||
|
|
||
|
|
||
| def _check_refit(search_cv, attr): | ||
| if not search_cv.refit: | ||
| raise AttributeError( | ||
| f"This {type(search_cv).__name__} instance was initialized with " | ||
| f"`refit=False`. {attr} is available only after refitting on the best " | ||
| "parameters. You can refit an estimator manually using the " | ||
| "`best_params_` attribute" | ||
| ) | ||
|
|
||
|
|
||
| def _estimator_has(attr): | ||
| """Check if we can delegate a method to the underlying estimator. | ||
|
|
||
| Calling a prediction method will only be available if `refit=True`. In | ||
| such case, we check first the fitted best estimator. If it is not | ||
| fitted, we check the unfitted estimator. | ||
|
|
||
| Checking the unfitted estimator allows to use `hasattr` on the `SearchCV` | ||
| instance even before calling `fit`. | ||
| """ | ||
|
|
||
| def check(self): | ||
| _check_refit(self, attr) | ||
| if hasattr(self, "best_estimator_"): | ||
| # raise an AttributeError if `attr` does not exist | ||
| getattr(self.best_estimator_, attr) | ||
| return True | ||
| # raise an AttributeError if `attr` does not exist | ||
| getattr(self.estimator, attr) | ||
| return True | ||
|
|
||
| return check | ||
|
|
||
|
|
||
| class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta): | ||
| """Abstract base class for hyper parameter search with cross-validation.""" | ||
|
|
||
|
|
@@ -418,7 +452,8 @@ def score(self, X, y=None): | |
| ------- | ||
| score : float | ||
| """ | ||
| self._check_is_fitted("score") | ||
| _check_refit(self, "score") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is certainly a backward incompatible code (I think). I agree with the new logic, but we need to discuss this.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On
With this PR, I think the behavior is still the same.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just replaced the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have you pushed the changes @glemaitre ?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I accepted the suggestions of @thomasjpfan but as I mentioned, I don't think that there is a change of behavior here. Which backward incompatibility do you observe (I might have overlooked as well since I was using the tests to ensure that). |
||
| check_is_fitted(self) | ||
| if self.scorer_ is None: | ||
| raise ValueError( | ||
| "No score function explicitly defined, " | ||
|
|
@@ -438,7 +473,7 @@ def score(self, X, y=None): | |
| score = score[self.refit] | ||
| return score | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("score_samples")) | ||
| def score_samples(self, X): | ||
| """Call score_samples on the estimator with the best found parameters. | ||
|
|
||
|
|
@@ -457,23 +492,10 @@ def score_samples(self, X): | |
| ------- | ||
| y_score : ndarray of shape (n_samples,) | ||
| """ | ||
| self._check_is_fitted("score_samples") | ||
| check_is_fitted(self) | ||
| return self.best_estimator_.score_samples(X) | ||
|
|
||
| def _check_is_fitted(self, method_name): | ||
| if not self.refit: | ||
| raise NotFittedError( | ||
| "This %s instance was initialized " | ||
| "with refit=False. %s is " | ||
| "available only after refitting on the best " | ||
| "parameters. You can refit an estimator " | ||
| "manually using the ``best_params_`` " | ||
| "attribute" % (type(self).__name__, method_name) | ||
| ) | ||
| else: | ||
| check_is_fitted(self) | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("predict")) | ||
| def predict(self, X): | ||
| """Call predict on the estimator with the best found parameters. | ||
|
|
||
|
|
@@ -487,10 +509,10 @@ def predict(self, X): | |
| underlying estimator. | ||
|
|
||
| """ | ||
| self._check_is_fitted("predict") | ||
| check_is_fitted(self) | ||
| return self.best_estimator_.predict(X) | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("predict_proba")) | ||
| def predict_proba(self, X): | ||
| """Call predict_proba on the estimator with the best found parameters. | ||
|
|
||
|
|
@@ -504,10 +526,10 @@ def predict_proba(self, X): | |
| underlying estimator. | ||
|
|
||
| """ | ||
| self._check_is_fitted("predict_proba") | ||
| check_is_fitted(self) | ||
| return self.best_estimator_.predict_proba(X) | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("predict_log_proba")) | ||
| def predict_log_proba(self, X): | ||
| """Call predict_log_proba on the estimator with the best found parameters. | ||
|
|
||
|
|
@@ -521,10 +543,10 @@ def predict_log_proba(self, X): | |
| underlying estimator. | ||
|
|
||
| """ | ||
| self._check_is_fitted("predict_log_proba") | ||
| check_is_fitted(self) | ||
| return self.best_estimator_.predict_log_proba(X) | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("decision_function")) | ||
| def decision_function(self, X): | ||
| """Call decision_function on the estimator with the best found parameters. | ||
|
|
||
|
|
@@ -538,10 +560,10 @@ def decision_function(self, X): | |
| underlying estimator. | ||
|
|
||
| """ | ||
| self._check_is_fitted("decision_function") | ||
| check_is_fitted(self) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should either use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the beauty of the decorator, the inner code look-alike any other estimator because the check is done by the decorator. So only the fitting should be checked here. |
||
| return self.best_estimator_.decision_function(X) | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("transform")) | ||
| def transform(self, X): | ||
| """Call transform on the estimator with the best found parameters. | ||
|
|
||
|
|
@@ -555,10 +577,10 @@ def transform(self, X): | |
| underlying estimator. | ||
|
|
||
| """ | ||
| self._check_is_fitted("transform") | ||
| check_is_fitted(self) | ||
| return self.best_estimator_.transform(X) | ||
|
|
||
| @if_delegate_has_method(delegate=("best_estimator_", "estimator")) | ||
| @available_if(_estimator_has("inverse_transform")) | ||
| def inverse_transform(self, Xt): | ||
| """Call inverse_transform on the estimator with the best found params. | ||
|
|
||
|
|
@@ -572,7 +594,7 @@ def inverse_transform(self, Xt): | |
| underlying estimator. | ||
|
|
||
| """ | ||
| self._check_is_fitted("inverse_transform") | ||
| check_is_fitted(self) | ||
| return self.best_estimator_.inverse_transform(Xt) | ||
|
|
||
| @property | ||
|
|
@@ -592,7 +614,7 @@ def n_features_in_(self): | |
|
|
||
| @property | ||
| def classes_(self): | ||
| self._check_is_fitted("classes_") | ||
| _estimator_has("classes_")(self) | ||
| return self.best_estimator_.classes_ | ||
|
|
||
| def _run_search(self, evaluate_candidates): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.