|
26 | 26 | from ..utils.validation import check_array |
27 | 27 | from ..utils.validation import _deprecate_positional_args |
28 | 28 | from ..utils.multiclass import type_of_target |
29 | | -from ..base import _pprint, _MetadataRequest |
| 29 | +from ..base import _pprint, _MetadataRequest, MetadataConsumer |
30 | 30 |
|
31 | 31 | __all__ = ['BaseCrossValidator', |
32 | 32 | 'KFold', |
|
46 | 46 | 'check_cv'] |
47 | 47 |
|
48 | 48 |
|
| 49 | +class GroupsConsumer(MetadataConsumer): |
| 50 | + |
| 51 | + _metadata_request__groups = {'split': ['groups']} |
| 52 | + |
| 53 | + def request_groups(self, *, split=None): |
| 54 | + """Define how to receive groups from the caller in `split` and |
| 55 | + `get_n_splits`. |
| 56 | +
|
| 57 | + Parameters |
| 58 | + ---------- |
| 59 | + split : string or bool, default=None |
| 60 | + The parameter name that a meta-estimator should pass to this |
| 61 | + estimator as groups when calling its `split` or `get_n_splits` |
| 62 | + method. |
| 63 | + If true, the name will be 'groups'. |
| 64 | + If False, no score parameter will be passed. |
| 65 | + If None, no change in routing. |
| 66 | +
|
| 67 | + Returns |
| 68 | + ------- |
| 69 | + self |
| 70 | + """ |
| 71 | + self._request_key_for_method(method='split', |
| 72 | + param='groups', |
| 73 | + user_provides=split) |
| 74 | + return self |
| 75 | + |
| 76 | + |
49 | 77 | class BaseCrossValidator(_MetadataRequest, metaclass=ABCMeta): |
50 | 78 | """Base class for all cross-validators |
51 | 79 |
|
@@ -444,7 +472,7 @@ def _iter_test_indices(self, X, y=None, groups=None): |
444 | 472 | current = stop |
445 | 473 |
|
446 | 474 |
|
447 | | -class GroupKFold(_BaseKFold): |
| 475 | +class GroupKFold(GroupsConsumer, _BaseKFold): |
448 | 476 | """K-fold iterator variant with non-overlapping groups. |
449 | 477 |
|
450 | 478 | The same group will not appear in two different folds (the number of |
@@ -495,9 +523,6 @@ class GroupKFold(_BaseKFold): |
495 | 523 | LeaveOneGroupOut : For splitting the data according to explicit |
496 | 524 | domain-specific stratification of the dataset. |
497 | 525 | """ |
498 | | - |
499 | | - _metadata_request = {'split': ['groups']} |
500 | | - |
501 | 526 | def __init__(self, n_splits=5): |
502 | 527 | super().__init__(n_splits, shuffle=False, random_state=None) |
503 | 528 |
|
@@ -891,7 +916,7 @@ def split(self, X, y=None, groups=None): |
891 | 916 | indices[test_start:test_start + test_size]) |
892 | 917 |
|
893 | 918 |
|
894 | | -class LeaveOneGroupOut(BaseCrossValidator): |
| 919 | +class LeaveOneGroupOut(GroupsConsumer, BaseCrossValidator): |
895 | 920 | """Leave One Group Out cross-validator |
896 | 921 |
|
897 | 922 | Provides train/test indices to split data according to a third-party |
@@ -932,9 +957,6 @@ class LeaveOneGroupOut(BaseCrossValidator): |
932 | 957 | [7 8]] [1 2] [1 2] |
933 | 958 |
|
934 | 959 | """ |
935 | | - |
936 | | - _metadata_request = {'split': ['groups']} |
937 | | - |
938 | 960 | def _iter_test_masks(self, X, y, groups): |
939 | 961 | if groups is None: |
940 | 962 | raise ValueError("The 'groups' parameter should not be None.") |
@@ -1002,7 +1024,7 @@ def split(self, X, y=None, groups=None): |
1002 | 1024 | return super().split(X, y, groups) |
1003 | 1025 |
|
1004 | 1026 |
|
1005 | | -class LeavePGroupsOut(BaseCrossValidator): |
| 1027 | +class LeavePGroupsOut(GroupsConsumer, BaseCrossValidator): |
1006 | 1028 | """Leave P Group(s) Out cross-validator |
1007 | 1029 |
|
1008 | 1030 | Provides train/test indices to split data according to a third-party |
@@ -1057,9 +1079,6 @@ class LeavePGroupsOut(BaseCrossValidator): |
1057 | 1079 | -------- |
1058 | 1080 | GroupKFold : K-fold iterator variant with non-overlapping groups. |
1059 | 1081 | """ |
1060 | | - |
1061 | | - _metadata_request = {'split': ['groups']} |
1062 | | - |
1063 | 1082 | def __init__(self, n_groups): |
1064 | 1083 | self.n_groups = n_groups |
1065 | 1084 |
|
@@ -1135,7 +1154,7 @@ def split(self, X, y=None, groups=None): |
1135 | 1154 | return super().split(X, y, groups) |
1136 | 1155 |
|
1137 | 1156 |
|
1138 | | -class _RepeatedSplits(metaclass=ABCMeta): |
| 1157 | +class _RepeatedSplits(_MetadataRequest, metaclass=ABCMeta): |
1139 | 1158 | """Repeated splits for an arbitrary randomized CV splitter. |
1140 | 1159 |
|
1141 | 1160 | Repeats splits for cross-validators n times with different randomization |
@@ -1349,7 +1368,7 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None): |
1349 | 1368 | n_splits=n_splits) |
1350 | 1369 |
|
1351 | 1370 |
|
1352 | | -class BaseShuffleSplit(metaclass=ABCMeta): |
| 1371 | +class BaseShuffleSplit(_MetadataRequest, metaclass=ABCMeta): |
1353 | 1372 | """Base class for ShuffleSplit and StratifiedShuffleSplit""" |
1354 | 1373 | @_deprecate_positional_args |
1355 | 1374 | def __init__(self, n_splits=10, *, test_size=None, train_size=None, |
@@ -1510,7 +1529,7 @@ def _iter_indices(self, X, y=None, groups=None): |
1510 | 1529 | yield ind_train, ind_test |
1511 | 1530 |
|
1512 | 1531 |
|
1513 | | -class GroupShuffleSplit(ShuffleSplit): |
| 1532 | +class GroupShuffleSplit(GroupsConsumer, ShuffleSplit): |
1514 | 1533 | '''Shuffle-Group(s)-Out cross-validation iterator |
1515 | 1534 |
|
1516 | 1535 | Provides randomized train/test indices to split data according to a |
@@ -1576,8 +1595,6 @@ class GroupShuffleSplit(ShuffleSplit): |
1576 | 1595 | TRAIN: [2 3 4 5 6 7] TEST: [0 1] |
1577 | 1596 | TRAIN: [0 1 5 6 7] TEST: [2 3 4] |
1578 | 1597 | ''' |
1579 | | - _metadata_request = {'split': ['groups']} |
1580 | | - |
1581 | 1598 | @_deprecate_positional_args |
1582 | 1599 | def __init__(self, n_splits=5, *, test_size=None, train_size=None, |
1583 | 1600 | random_state=None): |
|
0 commit comments