
Add strata to Stratified*Split CV splitters #26821

@adrinjalali


Currently, the Stratified*Split classes always use y as the strata, but the strata are not always the same as y.

In train_test_split we accept a stratify argument (which I'm wondering whether it should be called strata instead) that defines the group each sample belongs to. Internally, we basically do this:

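    from itertools import chain

    from sklearn.model_selection import StratifiedShuffleSplit
    from sklearn.utils import _safe_indexing

    # Excerpt from inside train_test_split, when stratify is not None:
    cv = StratifiedShuffleSplit(
        test_size=n_test, train_size=n_train, random_state=random_state
    )
    # `stratify` is passed as `y`, so the splitter stratifies on it.
    train, test = next(cv.split(X=arrays[0], y=stratify))
    return list(
        chain.from_iterable(
            (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays
        )
    )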
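To make the current behaviour concrete, here is a small runnable sketch (the data and the `strata` array are made up for illustration). A continuous target `y` cannot be stratified on, so the grouping goes through `stratify`; reproducing the same split with `StratifiedShuffleSplit` directly means passing that grouping as the splitter's `y`.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

rng = np.random.RandomState(0)
X = rng.normal(size=(100, 3))
y = rng.normal(size=100)  # continuous target: not usable as strata
strata = np.repeat([0, 1, 2, 3], 25)  # hypothetical grouping, e.g. a site id

# Stratify on `strata` (not on `y`) through train_test_split.
X_train, X_test, y_train, y_test, s_train, s_test = train_test_split(
    X, y, strata, test_size=0.2, stratify=strata, random_state=0
)

# The same split done by hand: `strata` has to be passed as the splitter's `y`.
cv = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(cv.split(X, strata))

print(np.bincount(s_test))  # → [5 5 5 5]
print(np.array_equal(X_test, X[test_idx]))
```

Both routes should give the same test indices here, since train_test_split builds the splitter with the same sizes and random_state; the point is that the splitter itself has no way to tell strata and target apart.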

As you can see, we pass stratify as y to the splitter. I think it would make sense to add a strata argument to the splitters and, if it is None, fall back to the values in y, as we do now.
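A minimal sketch of the proposed fallback, written as a hypothetical subclass (`StrataShuffleSplit` and its `strata` keyword are illustrative names, not existing scikit-learn API): `split` stratifies on `strata` when it is given and otherwise on `y`, as today, delegating the actual sampling to `StratifiedShuffleSplit`.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


class StrataShuffleSplit(StratifiedShuffleSplit):
    """Hypothetical splitter: stratify on `strata` if given, else on `y`."""

    def split(self, X, y=None, groups=None, strata=None):
        if strata is None:
            # Current behaviour: the target itself defines the strata.
            if y is None:
                raise ValueError("Either y or strata must be provided.")
            strata = y
        # Reuse the existing implementation, passing strata in place of y.
        return super().split(X, strata, groups=groups)


rng = np.random.RandomState(0)
X = rng.normal(size=(100, 3))
y = rng.normal(size=100)  # continuous target
strata = np.repeat([0, 1, 2, 3], 25)  # hypothetical grouping variable

cv = StrataShuffleSplit(n_splits=3, test_size=0.2, random_state=0)
for train_idx, test_idx in cv.split(X, y, strata=strata):
    print(np.bincount(strata[test_idx]))  # → [5 5 5 5] for each split
```

Delegating keeps the sampling logic in one place; a real implementation would also have to hook into metadata routing, which this sketch does not attempt.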

Note that now that we have SLEP6 (metadata routing), we won't need a separate class for this; we'd simply request strata for the splitter:

    cv = StratifiedShuffleSplit().set_split_request(strata=True)
    ...
    GridSearchCV(model, param_grid, cv=cv).fit(X, y, strata=strata_values)
    cross_validate(model, X, y, cv=cv, props={"strata": strata_values})
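Until something like this exists, one workaround under the current API (a suggestion, not part of the proposal) is to generate the stratified splits on the strata up front and pass them to `cv` as a list of `(train, test)` index pairs, which both `cross_validate` and `GridSearchCV` accept. The data, the `Ridge` estimator and the `strata` array below are placeholders.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import (
    GridSearchCV,
    StratifiedShuffleSplit,
    cross_validate,
)

rng = np.random.RandomState(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=100)
strata = np.repeat([0, 1, 2, 3], 25)  # hypothetical grouping variable

# Stratify on `strata` by passing it as the splitter's `y`, then hand the
# materialized (train, test) index pairs to `cv`. A list can be reused.
splitter = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
splits = list(splitter.split(X, strata))

scores = cross_validate(Ridge(), X, y, cv=splits)
search = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0, 10.0]}, cv=splits)
search.fit(X, y)
print(len(scores["test_score"]), search.best_params_)
```

The downside is that the splits are fixed before fitting, which is exactly the plumbing that routing strata to the splitter would remove.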

cc @marenwestermann
