Skip to content

Commit 6d27c79

Browse files
committed
generalize default values, create GroupsConsumer
1 parent 29e6b07 commit 6d27c79

File tree

5 files changed

+203
-70
lines changed

5 files changed

+203
-70
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ requires = [
1313
"numpy==1.17.3; python_version>='3.8' and platform_system=='AIX'",
1414
"scipy>=0.19.1",
1515
]
16+
17+
[tool.black]
18+
line-length = 79

sklearn/base.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,22 @@ def _pprint(params, offset=0, printer=repr):
157157

158158

159159
class _MetadataRequest:
160+
def _add_defaults(self, defaults):
161+
metadata_request = _standardize_metadata_request(
162+
self._metadata_request)
163+
defaults = _standardize_metadata_request(defaults)
164+
for method in defaults:
165+
if method not in metadata_request:
166+
metadata_request[method] = copy.deepcopy(defaults[method])
167+
else:
168+
for key in defaults[method]:
169+
for value in metadata_request[method]:
170+
if key in value:
171+
break
172+
else:
173+
metadata_request[method][key] = defaults[method][key]
174+
self._metadata_request = metadata_request
175+
160176
def get_metadata_request(self):
161177
"""Get requested data properties.
162178
@@ -167,18 +183,24 @@ def get_metadata_request(self):
167183
used. Under each key, there is a dict of the form
168184
``{input_param_name: required_param_name}``.
169185
"""
170-
try:
186+
if hasattr(self, '_metadata_request'):
171187
return _standardize_metadata_request(self._metadata_request)
172-
except AttributeError:
173-
pass
174188

175-
return _empty_metadata_request()
189+
self._metadata_request = _empty_metadata_request()
190+
191+
defaults = [x for x in dir(self)
192+
if x.startswith('_metadata_request__')]
193+
for attr in defaults:
194+
self._add_defaults(getattr(self, attr))
195+
return _standardize_metadata_request(self._metadata_request)
176196

177197

178198
class MetadataConsumer:
179199
def _request_key_for_method(self, *, method, param, user_provides):
200+
if user_provides is None:
201+
return
180202
if not hasattr(self, '_metadata_request'):
181-
self._metadata_request = _empty_metadata_request()
203+
self._metadata_request = self.get_metadata_request()
182204
self._metadata_request = _standardize_metadata_request(
183205
self._metadata_request)
184206

@@ -205,26 +227,24 @@ def _set_metadata_request(self, props):
205227

206228

207229
class SampleWeightConsumer(MetadataConsumer):
208-
def request_sample_weight(self, *, fit=True, score=False):
230+
def request_sample_weight(self, *, fit=None, score=None):
209231
"""Define how to receive sample_weight from a parent meta-estimator
210232
211-
When called with default arguments, fitting will be weighted,
212-
and the meta-estimator should be passed a fit parameter by the name
213-
'sample_weight'.
214-
215233
Parameters
216234
----------
217-
fit : string or bool, default=True
235+
fit : string or bool, default=None
218236
The fit parameter name that a meta-estimator should pass to this
219237
estimator as sample_weight. If true, the name will be
220238
'sample_weight'.
221239
If False, no fit parameter will be passed.
240+
If None, no change in routing.
222241
223-
score : string or bool, default=True
242+
score : string or bool, default=None
224243
The parameter name that a meta-estimator should pass to this
225244
estimator as sample_weight when calling its `score` method.
226245
If true, the name will be 'sample_weight'.
227-
If False, no fit parameter will be passed.
246+
If False, no score parameter will be passed.
247+
If None, no change in routing.
228248
229249
Returns
230250
-------

sklearn/metrics/_scorer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
585585
586586
request_props : list of strings, or dict of {str: str}, default=None
587587
A list of required properties, or a mapping of the form
588-
``{required_parameter: provided_parameter}``, or None.
588+
``{provided_metadata: required_metadata}``, or None.
589589
590590
**kwargs : additional arguments
591591
Additional parameters to be passed to score_func.

sklearn/model_selection/_split.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..utils.validation import check_array
2727
from ..utils.validation import _deprecate_positional_args
2828
from ..utils.multiclass import type_of_target
29-
from ..base import _pprint, _MetadataRequest
29+
from ..base import _pprint, _MetadataRequest, MetadataConsumer
3030

3131
__all__ = ['BaseCrossValidator',
3232
'KFold',
@@ -46,6 +46,34 @@
4646
'check_cv']
4747

4848

49+
class GroupsConsumer(MetadataConsumer):
50+
51+
_metadata_request__groups = {'split': ['groups']}
52+
53+
def request_groups(self, *, split=None):
54+
"""Define how to receive groups from the caller in `split` and
55+
`get_n_splits`.
56+
57+
Parameters
58+
----------
59+
split : string or bool, default=None
60+
The parameter name that a meta-estimator should pass to this
61+
estimator as groups when calling its `split` or `get_n_splits`
62+
method.
63+
If true, the name will be 'groups'.
64+
If False, no groups parameter will be passed.
65+
If None, no change in routing.
66+
67+
Returns
68+
-------
69+
self
70+
"""
71+
self._request_key_for_method(method='split',
72+
param='groups',
73+
user_provides=split)
74+
return self
75+
76+
4977
class BaseCrossValidator(_MetadataRequest, metaclass=ABCMeta):
5078
"""Base class for all cross-validators
5179
@@ -444,7 +472,7 @@ def _iter_test_indices(self, X, y=None, groups=None):
444472
current = stop
445473

446474

447-
class GroupKFold(_BaseKFold):
475+
class GroupKFold(GroupsConsumer, _BaseKFold):
448476
"""K-fold iterator variant with non-overlapping groups.
449477
450478
The same group will not appear in two different folds (the number of
@@ -495,9 +523,6 @@ class GroupKFold(_BaseKFold):
495523
LeaveOneGroupOut : For splitting the data according to explicit
496524
domain-specific stratification of the dataset.
497525
"""
498-
499-
_metadata_request = {'split': ['groups']}
500-
501526
def __init__(self, n_splits=5):
502527
super().__init__(n_splits, shuffle=False, random_state=None)
503528

@@ -891,7 +916,7 @@ def split(self, X, y=None, groups=None):
891916
indices[test_start:test_start + test_size])
892917

893918

894-
class LeaveOneGroupOut(BaseCrossValidator):
919+
class LeaveOneGroupOut(GroupsConsumer, BaseCrossValidator):
895920
"""Leave One Group Out cross-validator
896921
897922
Provides train/test indices to split data according to a third-party
@@ -932,9 +957,6 @@ class LeaveOneGroupOut(BaseCrossValidator):
932957
[7 8]] [1 2] [1 2]
933958
934959
"""
935-
936-
_metadata_request = {'split': ['groups']}
937-
938960
def _iter_test_masks(self, X, y, groups):
939961
if groups is None:
940962
raise ValueError("The 'groups' parameter should not be None.")
@@ -1002,7 +1024,7 @@ def split(self, X, y=None, groups=None):
10021024
return super().split(X, y, groups)
10031025

10041026

1005-
class LeavePGroupsOut(BaseCrossValidator):
1027+
class LeavePGroupsOut(GroupsConsumer, BaseCrossValidator):
10061028
"""Leave P Group(s) Out cross-validator
10071029
10081030
Provides train/test indices to split data according to a third-party
@@ -1057,9 +1079,6 @@ class LeavePGroupsOut(BaseCrossValidator):
10571079
--------
10581080
GroupKFold : K-fold iterator variant with non-overlapping groups.
10591081
"""
1060-
1061-
_metadata_request = {'split': ['groups']}
1062-
10631082
def __init__(self, n_groups):
10641083
self.n_groups = n_groups
10651084

@@ -1135,7 +1154,7 @@ def split(self, X, y=None, groups=None):
11351154
return super().split(X, y, groups)
11361155

11371156

1138-
class _RepeatedSplits(metaclass=ABCMeta):
1157+
class _RepeatedSplits(_MetadataRequest, metaclass=ABCMeta):
11391158
"""Repeated splits for an arbitrary randomized CV splitter.
11401159
11411160
Repeats splits for cross-validators n times with different randomization
@@ -1349,7 +1368,7 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
13491368
n_splits=n_splits)
13501369

13511370

1352-
class BaseShuffleSplit(metaclass=ABCMeta):
1371+
class BaseShuffleSplit(_MetadataRequest, metaclass=ABCMeta):
13531372
"""Base class for ShuffleSplit and StratifiedShuffleSplit"""
13541373
@_deprecate_positional_args
13551374
def __init__(self, n_splits=10, *, test_size=None, train_size=None,
@@ -1510,7 +1529,7 @@ def _iter_indices(self, X, y=None, groups=None):
15101529
yield ind_train, ind_test
15111530

15121531

1513-
class GroupShuffleSplit(ShuffleSplit):
1532+
class GroupShuffleSplit(GroupsConsumer, ShuffleSplit):
15141533
'''Shuffle-Group(s)-Out cross-validation iterator
15151534
15161535
Provides randomized train/test indices to split data according to a
@@ -1576,8 +1595,6 @@ class GroupShuffleSplit(ShuffleSplit):
15761595
TRAIN: [2 3 4 5 6 7] TEST: [0 1]
15771596
TRAIN: [0 1 5 6 7] TEST: [2 3 4]
15781597
'''
1579-
_metadata_request = {'split': ['groups']}
1580-
15811598
@_deprecate_positional_args
15821599
def __init__(self, n_splits=5, *, test_size=None, train_size=None,
15831600
random_state=None):

0 commit comments

Comments
 (0)