Right now the Stratified*Split classes always use y as the strata, but y is not always what we want to stratify on.
In train_test_split we allow a stratify arg (I wonder whether it should be called strata), which defines the strata the samples belong to. Internally, we basically do this:
cv = StratifiedShuffleSplit(test_size=n_test, train_size=n_train, random_state=random_state)
train, test = next(cv.split(X=arrays[0], y=stratify))
return list(
    chain.from_iterable(
        (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays
    )
)

As you can see, we're passing stratify as y to the splitter. I think it would make sense to add a strata arg to the splitters, and if it is None, we'd take the values in y instead, as it is now.
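
To make the current workaround and the proposed fallback concrete, here is a minimal sketch. The helper split_with_strata and the toy data are hypothetical and only for illustration; nothing like it exists in scikit-learn today. It only mimics the "use strata if given, else y" rule by forwarding the chosen array as y to the existing split method.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def split_with_strata(cv, X, y=None, strata=None):
    """Hypothetical helper: stratify on `strata` if given, otherwise on `y`."""
    labels = y if strata is None else strata
    # Today the splitter stratifies on whatever is passed as `y`.
    return cv.split(X, labels)


rng = np.random.RandomState(0)
X = rng.normal(size=(16, 2))
y = rng.normal(size=16)  # continuous target, not usable as strata
strata = np.array(["a"] * 8 + ["b"] * 8)  # made-up strata, independent of y

cv = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train, test = next(split_with_strata(cv, X, y, strata=strata))

# Both strata are represented proportionally in the test set.
n_a_test = int(np.sum(strata[test] == "a"))
n_b_test = int(np.sum(strata[test] == "b"))
```

With a continuous y like this one, stratifying on y itself is not an option, which is exactly the case where a separate strata argument would help.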
Note that now that we have SLEP6, we won't need separate classes for this; we'd simply request strata on the splitter:
cv = StratifiedShuffleSplit().set_split_request(strata=True)
...
GridSearchCV(model, param_grid, cv=cv).fit(X, y, strata=strata_values)
cross_validate(model, X, y, cv=cv, props={"strata": strata_values})