Skip to content

Commit 26a98f7

Browse files
committed
remove routing from non fit and score methods
1 parent e527d95 commit 26a98f7

File tree

3 files changed

+43
-113
lines changed

3 files changed

+43
-113
lines changed

sklearn/model_selection/_search.py

Lines changed: 12 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def score_samples(self, X):
555555
return self.best_estimator_.score_samples(X)
556556

557557
@available_if(_estimator_has("predict"))
558-
def predict(self, X, **params):
558+
def predict(self, X):
559559
"""Call predict on the estimator with the best found parameters.
560560
561561
Only available if ``refit=True`` and the underlying estimator supports
@@ -567,27 +567,17 @@ def predict(self, X, **params):
567567
Must fulfill the input assumptions of the
568568
underlying estimator.
569569
570-
**params : dict
571-
Parameters to be passed the underlying estimator's ``predict``.
572-
573-
Only available if `enable_metadata_routing=True`. See the
574-
:ref:`User Guide <metadata_routing>`.
575-
576-
..versionadded:: 1.4
577-
578570
Returns
579571
-------
580572
y_pred : ndarray of shape (n_samples,)
581573
The predicted labels or values for `X` based on the estimator with
582574
the best found parameters.
583575
"""
584576
check_is_fitted(self)
585-
return self.best_estimator_.predict(
586-
X, **_get_params_for_method(self, "predict", params)
587-
)
577+
return self.best_estimator_.predict(X)
588578

589579
@available_if(_estimator_has("predict_proba"))
590-
def predict_proba(self, X, **params):
580+
def predict_proba(self, X):
591581
"""Call predict_proba on the estimator with the best found parameters.
592582
593583
Only available if ``refit=True`` and the underlying estimator supports
@@ -599,14 +589,6 @@ def predict_proba(self, X, **params):
599589
Must fulfill the input assumptions of the
600590
underlying estimator.
601591
602-
**params : dict
603-
Parameters to be passed the underlying estimator's ``predict_proba``.
604-
605-
Only available if `enable_metadata_routing=True`. See the
606-
:ref:`User Guide <metadata_routing>`.
607-
608-
..versionadded:: 1.4
609-
610592
Returns
611593
-------
612594
y_pred : ndarray of shape (n_samples,) or (n_samples, n_classes)
@@ -615,12 +597,10 @@ def predict_proba(self, X, **params):
615597
to that in the fitted attribute :term:`classes_`.
616598
"""
617599
check_is_fitted(self)
618-
return self.best_estimator_.predict_proba(
619-
X, **_get_params_for_method(self, "predict_proba", params)
620-
)
600+
return self.best_estimator_.predict_proba(X)
621601

622602
@available_if(_estimator_has("predict_log_proba"))
623-
def predict_log_proba(self, X, **params):
603+
def predict_log_proba(self, X):
624604
"""Call predict_log_proba on the estimator with the best found parameters.
625605
626606
Only available if ``refit=True`` and the underlying estimator supports
@@ -632,14 +612,6 @@ def predict_log_proba(self, X, **params):
632612
Must fulfill the input assumptions of the
633613
underlying estimator.
634614
635-
**params : dict
636-
Parameters to be passed the underlying estimator's ``predict_log_proba``.
637-
638-
Only available if `enable_metadata_routing=True`. See the
639-
:ref:`User Guide <metadata_routing>`.
640-
641-
..versionadded:: 1.4
642-
643615
Returns
644616
-------
645617
y_pred : ndarray of shape (n_samples,) or (n_samples, n_classes)
@@ -648,12 +620,10 @@ def predict_log_proba(self, X, **params):
648620
corresponds to that in the fitted attribute :term:`classes_`.
649621
"""
650622
check_is_fitted(self)
651-
return self.best_estimator_.predict_log_proba(
652-
X, **_get_params_for_method(self, "predict_log_proba", params)
653-
)
623+
return self.best_estimator_.predict_log_proba(X)
654624

655625
@available_if(_estimator_has("decision_function"))
656-
def decision_function(self, X, **params):
626+
def decision_function(self, X):
657627
"""Call decision_function on the estimator with the best found parameters.
658628
659629
Only available if ``refit=True`` and the underlying estimator supports
@@ -665,14 +635,6 @@ def decision_function(self, X, **params):
665635
Must fulfill the input assumptions of the
666636
underlying estimator.
667637
668-
**params : dict
669-
Parameters to be passed the underlying estimator's ``decision_function``.
670-
671-
Only available if `enable_metadata_routing=True`. See the
672-
:ref:`User Guide <metadata_routing>`.
673-
674-
..versionadded:: 1.4
675-
676638
Returns
677639
-------
678640
y_score : ndarray of shape (n_samples,) or (n_samples, n_classes) \
@@ -681,12 +643,10 @@ def decision_function(self, X, **params):
681643
the best found parameters.
682644
"""
683645
check_is_fitted(self)
684-
return self.best_estimator_.decision_function(
685-
X, **_get_params_for_method(self, "decision_function", params)
686-
)
646+
return self.best_estimator_.decision_function(X)
687647

688648
@available_if(_estimator_has("transform"))
689-
def transform(self, X, **params):
649+
def transform(self, X):
690650
"""Call transform on the estimator with the best found parameters.
691651
692652
Only available if the underlying estimator supports ``transform`` and
@@ -698,27 +658,17 @@ def transform(self, X, **params):
698658
Must fulfill the input assumptions of the
699659
underlying estimator.
700660
701-
**params : dict
702-
Parameters to be passed the underlying estimator's ``transform``.
703-
704-
Only available if `enable_metadata_routing=True`. See the
705-
:ref:`User Guide <metadata_routing>`.
706-
707-
..versionadded:: 1.4
708-
709661
Returns
710662
-------
711663
Xt : {ndarray, sparse matrix} of shape (n_samples, n_features)
712664
`X` transformed in the new space based on the estimator with
713665
the best found parameters.
714666
"""
715667
check_is_fitted(self)
716-
return self.best_estimator_.transform(
717-
X, **_get_params_for_method(self, "transform", params)
718-
)
668+
return self.best_estimator_.transform(X)
719669

720670
@available_if(_estimator_has("inverse_transform"))
721-
def inverse_transform(self, Xt, **params):
671+
def inverse_transform(self, Xt):
722672
"""Call inverse_transform on the estimator with the best found params.
723673
724674
Only available if the underlying estimator implements
@@ -730,24 +680,14 @@ def inverse_transform(self, Xt, **params):
730680
Must fulfill the input assumptions of the
731681
underlying estimator.
732682
733-
**params : dict
734-
Parameters to be passed the underlying estimator's ``inverse_transform``.
735-
736-
Only available if `enable_metadata_routing=True`. See the
737-
:ref:`User Guide <metadata_routing>`.
738-
739-
..versionadded:: 1.4
740-
741683
Returns
742684
-------
743685
X : {ndarray, sparse matrix} of shape (n_samples, n_features)
744686
Result of the `inverse_transform` function for `Xt` based on the
745687
estimator with the best found parameters.
746688
"""
747689
check_is_fitted(self)
748-
return self.best_estimator_.inverse_transform(
749-
Xt, **_get_params_for_method(self, "inverse_transform", params)
750-
)
690+
return self.best_estimator_.inverse_transform(Xt)
751691

752692
@property
753693
def n_features_in_(self):

sklearn/tests/metadata_routing_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
4949
split_params : tuple, default=empty
5050
specifies any parameters which are to be checked as being a subset
5151
of the original values.
52-
5352
"""
5453
records = getattr(obj, "_records", dict()).get(method, dict())
5554
assert set(kwargs.keys()) == set(records.keys())

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,8 @@ def enable_slep006():
107107
"init_args": {"param_grid": {"alpha": [0.1, 0.2]}},
108108
"X": X,
109109
"y": y,
110-
"estimator_routing_methods": [
111-
"fit",
112-
"predict",
113-
"predict_proba",
114-
"predict_log_proba",
115-
"decision_function",
116-
],
117-
"preserves_metadata": False,
110+
"estimator_routing_methods": ["fit"],
111+
"preserves_metadata": "subset",
118112
"scorer_name": "scoring",
119113
"scorer_routing_methods": ["fit", "score"],
120114
"cv_name": "cv",
@@ -127,14 +121,8 @@ def enable_slep006():
127121
"init_args": {"param_distributions": {"alpha": [0.1, 0.2]}},
128122
"X": X,
129123
"y": y,
130-
"estimator_routing_methods": [
131-
"fit",
132-
"predict",
133-
"predict_proba",
134-
"predict_log_proba",
135-
"decision_function",
136-
],
137-
"preserves_metadata": False,
124+
"estimator_routing_methods": ["fit"],
125+
"preserves_metadata": "subset",
138126
"scorer_name": "scoring",
139127
"scorer_routing_methods": ["fit", "score"],
140128
"cv_name": "cv",
@@ -147,14 +135,8 @@ def enable_slep006():
147135
"init_args": {"param_grid": {"alpha": [0.1, 0.2]}},
148136
"X": X,
149137
"y": y,
150-
"estimator_routing_methods": [
151-
"fit",
152-
"predict",
153-
"predict_proba",
154-
"predict_log_proba",
155-
"decision_function",
156-
],
157-
"preserves_metadata": False,
138+
"estimator_routing_methods": ["fit"],
139+
"preserves_metadata": "subset",
158140
"scorer_name": "scoring",
159141
"scorer_routing_methods": ["fit", "score"],
160142
"cv_name": "cv",
@@ -167,14 +149,8 @@ def enable_slep006():
167149
"init_args": {"param_distributions": {"alpha": [0.1, 0.2]}},
168150
"X": X,
169151
"y": y,
170-
"estimator_routing_methods": [
171-
"fit",
172-
"predict",
173-
"predict_proba",
174-
"predict_log_proba",
175-
"decision_function",
176-
],
177-
"preserves_metadata": False,
152+
"estimator_routing_methods": ["fit"],
153+
"preserves_metadata": "subset",
178154
"scorer_name": "scoring",
179155
"scorer_routing_methods": ["fit", "score"],
180156
"cv_name": "cv",
@@ -193,10 +169,15 @@ def enable_slep006():
193169
- y: y-data to fit
194170
- estimator_routing_methods: list of all methods to check for routing metadata
195171
to the sub-estimator
196-
- preserves_metadata: Whether the metaestimator passes the metadata to the
197-
sub-estimator without modification or not. If it does, we check that the
198-
values are identical. If it doesn't, no check is performed. TODO Maybe
199-
something smarter could be done if the data is modified.
172+
- preserves_metadata:
173+
- True (default): the metaestimator passes the metadata to the
174+
sub-estimator without modification. We check that the values recorded by
175+
the sub-estimator are identical to what we've passed to the
176+
metaestimator.
177+
- False: no check is performed regarding values, we only check that a
178+
metadata with the expected names/keys are passed.
179+
- "subset": we check that the recorded metadata by the sub-estimator is a
180+
subset of what is passed to the metaestimator.
200181
- scorer_name: The name of the argument for the scorer
201182
- scorer_routing_methods: list of all methods to check for routing metadata
202183
to the scorer
@@ -341,6 +322,8 @@ def set_request(estimator, method_name):
341322

342323
for method_name in routing_methods:
343324
for key in ["sample_weight", "metadata"]:
325+
if method_name == "predict":
326+
pass
344327
val = {"sample_weight": sample_weight, "metadata": metadata}[key]
345328
method_kwargs = {key: val}
346329

@@ -360,12 +343,20 @@ def set_request(estimator, method_name):
360343
instance.fit(X, y)
361344
method(X, **method_kwargs)
362345

363-
if preserves_metadata:
364-
# sanity check that registry is not empty, or else the test
365-
# passes trivially
366-
assert registry
346+
# sanity check that registry is not empty, or else the test passes
347+
# trivially
348+
assert registry
349+
if preserves_metadata is True:
367350
for estimator in registry:
368351
check_recorded_metadata(estimator, method_name, **method_kwargs)
352+
elif preserves_metadata == "subset":
353+
for estimator in registry:
354+
check_recorded_metadata(
355+
estimator,
356+
method_name,
357+
split_params=method_kwargs.keys(),
358+
**method_kwargs,
359+
)
369360

370361

371362
@pytest.mark.parametrize("metaestimator", METAESTIMATORS, ids=METAESTIMATOR_IDS)

0 commit comments

Comments
 (0)