
ENH implement CalibrationCurveDisplay.from_cv_results #21211

Draft
glemaitre wants to merge 4 commits into scikit-learn:main from glemaitre:is/calibrated_display_cv_results

Conversation

@glemaitre
Member

@glemaitre glemaitre commented Oct 1, 2021

This PR intends to add the capability of plotting the uncertainty of the different curves (calibration, precision-recall, ROC, etc.) by using the results of cross-validation (i.e. the output of cross_validate).

TODO:

  • add a parameter return_indices in cross_validate to store the train-test indices. It is the safest way to keep track of the train-test splits in the case of stochastic splitting strategies.
  • add a method from_cv_results in the plotting display to take advantage of the CV computation.
  • add unit test for from_cv_results
  • add unit test for the new keyword parameters in CalibrationDisplay
  • add unit test for the new strategy of binning in calibration_curve

Usage example

# %%
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=10_000, weights=[0.1, 0.9], random_state=42, class_sep=1
)
sample_weight = np.zeros_like(y, dtype=np.float64)
sample_weight[y == 0] = 0.1
sample_weight[y == 1] = 0.9

# %%
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test, sw_train, sw_test = train_test_split(
    X, y, sample_weight, random_state=42
)

# %%
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV

calibration_method = "isotonic"
models = {
    "LR no weights": LogisticRegression(),
    "LR class weights": LogisticRegression(class_weight="balanced"),
    "Calibrated LR no weights": CalibratedClassifierCV(
        LogisticRegression(),
        method=calibration_method,
    ),
    "Calibrated LR class weights": CalibratedClassifierCV(
        LogisticRegression(class_weight="balanced"),
        method=calibration_method,
    ),
    "Calibrated LR sample weights": CalibratedClassifierCV(
        LogisticRegression(),
        method=calibration_method,
    ),
    "Calibrated LR class and sample weights": CalibratedClassifierCV(
        LogisticRegression(class_weight="balanced"),
        method=calibration_method,
    ),
}

# %%
import matplotlib.pyplot as plt
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import balanced_accuracy_score

fig, ax = plt.subplots()

calibration_display_params = {
    "n_bins": 20,
    "strategy": "quantile",
}
for name, model in models.items():
    if "sample weights" in name:
        model.fit(X_train, y_train, sample_weight=sw_train)
    else:
        model.fit(X_train, y_train)

    score = balanced_accuracy_score(y_test, model.predict(X_test))
    CalibrationDisplay.from_estimator(
        model,
        X_test,
        y_test,
        name=name + f" - {score:.3f}",
        ax=ax,
        **calibration_display_params,
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), title="Model - Balanced Accuracy")
_ = fig.suptitle(f"Using {calibration_method} calibration")

# %%
from sklearn.model_selection import cross_validate
from sklearn.model_selection import KFold

cv_results = {}
cv = KFold(n_splits=5)
for name, model in models.items():
    if "sample weights" in name:
        fit_params = {"sample_weight": sample_weight}
    else:
        fit_params = {}
    cv_results[name] = cross_validate(
        model,
        X,
        y,
        cv=cv,
        fit_params=fit_params,
        scoring="balanced_accuracy",
        return_estimator=True,
        return_indices=True,
    )

# %%
fig, ax = plt.subplots()
for name, results in cv_results.items():
    CalibrationDisplay.from_cv_results(
        results, X, y, ax=ax, name=name, **calibration_display_params
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# %%
fig, ax = plt.subplots()
for name, results in cv_results.items():
    CalibrationDisplay.from_cv_results(
        results,
        X,
        y,
        ax=ax,
        name=name,
        plot_uncertainty_style="fill_between",
        **calibration_display_params,
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# %%
fig, ax = plt.subplots()
for name, results in cv_results.items():
    CalibrationDisplay.from_cv_results(
        results,
        X,
        y,
        ax=ax,
        name=name,
        plot_uncertainty_style="lines",
        **calibration_display_params,
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))


@glemaitre glemaitre changed the title from "ENH use uncertainty estimate" to "ENH use cv_results in the different curve display to add confidence intervals" Oct 1, 2021
@glemaitre glemaitre marked this pull request as draft October 1, 2021 12:45
@ogrisel ogrisel self-requested a review October 19, 2021 09:18
calibrated classifier.

plot_uncertainty_style : {"errorbar", "fill_between", "lines"}, \
default="errorbar"
Member

@ogrisel ogrisel Oct 21, 2021


I think the default should be plot_uncertainty_style="lines" as it's the easiest to understand without being misled. For plot_uncertainty_style="errorbar" and plot_uncertainty_style="fill_between", one needs to know that they are based on the raw standard deviation (as opposed to a pseudo confidence interval based on the standard error of the mean, for instance).
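
The std-vs-SEM distinction raised here can be made concrete with a toy computation (the names fold_values, std, and sem are illustrative, not part of the proposed API):

```python
import numpy as np

rng = np.random.default_rng(0)
# pretend these are one calibration-curve point measured on 5 CV splits
fold_values = rng.normal(loc=0.8, scale=0.05, size=5)

std = fold_values.std(ddof=1)          # raw spread across folds
sem = std / np.sqrt(fold_values.size)  # standard error of the mean
```

With 5 splits the SEM band is sqrt(5) times narrower than the raw-std band, so the two styles convey very different uncertainty ranges.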

Member

@ogrisel ogrisel Oct 21, 2021


We could also accept plot_uncertainty_style=None to only plot the mean CV calibration curve without any uncertainty markers on the plot.

Member


Also plot_uncertainty_style="shade" or plot_uncertainty_style="shaded_area" might be easier to understand than plot_uncertainty_style="fill_between".
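
Whatever the final name, this style maps onto matplotlib's Axes.fill_between primitive. A minimal sketch of the shaded-band idea (illustrative data, not the proposed API):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, for scripted use
import matplotlib.pyplot as plt

prob_pred = np.linspace(0, 1, 20)
mean_curve = prob_pred  # a perfectly calibrated mean curve, for illustration
std_curve = np.full_like(prob_pred, 0.05)

fig, ax = plt.subplots()
ax.plot(prob_pred, mean_curve)
# the "fill_between" / "shaded area" style: a band of +/- one standard deviation
band = ax.fill_between(
    prob_pred, mean_curve - std_curve, mean_curve + std_curve, alpha=0.3
)
```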

Member

@ogrisel ogrisel Dec 11, 2025


I think that for the first iteration, I would rather only implement the "lines" strategy and not the others and remove this parameter from the public API.

For the record, this is the strategy followed when adding the from_cv_results method to RocCurveDisplay.

This way, we don't have to anticipate how the from_cv_results feature will interact or not with the orthogonal feature request to add fixed model uncertainty that results from the finite size sampling of the validation/calibration sets.

default="errorbar"
Style to plot the uncertainty information. Possibilities are:

- "errorbar": error bars representing one standard deviation;
Member


two standard deviations: 1 above and 1 below.

Member


I assume (I did not check ;))

Member


I checked and I think I am right:

import numpy as np
import matplotlib.pyplot as plt

# yerr=1 draws symmetric error bars: each bar spans from y - 1 to y + 1,
# i.e. two standard deviations in total (one above, one below)
plt.errorbar(np.arange(5), np.ones(5), np.ones(5))

[figure: error bars spanning from 0 to 2 around each point at y=1]

Comment on lines +173 to +174
return_indices : bool, default=False
Whether to return the train-test indices selected for each split.
Member


Coming from #21664, I agree return_indices is useful. (I wanted to do something like this recently).

@adrinjalali
Member

@glemaitre this seems cool to be continued!

@glemaitre
Member Author

Yep, this is also part of the CZI proposal on inspection. This would be my next effort after the threshold-tuning classifier.

@ogrisel ogrisel changed the title from "ENH use cv_results in the different curve display to add confidence intervals" to "ENH implement CalibrationCurveDisplay.from_cv_results" Dec 11, 2025
@ogrisel
Member

ogrisel commented Dec 11, 2025

@AnneBeyer @lucyleeow: I think @glemaitre won't have time to revive this PR for the foreseeable future, but I think it's a very interesting feature.

Feel free to take over this work in a new PR synced with the current state of the main branch.

@lucyleeow
Member

lucyleeow commented Dec 12, 2025

I'm slowly adding from_cv_results to various displays. #30508 is still waiting for review, and #32235 is sort of waiting on #30508, as they are a bit intertwined, with a fair number of merge conflicts.
I'm not sure it is worth starting this until at least #30508 is merged, due to merge conflicts, waiting for reviews, etc.

@ogrisel ogrisel added this to Labs Jan 26, 2026
@ogrisel ogrisel moved this to In progress in Labs Jan 26, 2026
@lucyleeow lucyleeow moved this from Discussion to Todo in Visualization and displays Feb 27, 2026
@adrinjalali
Member

Since #30508 is merged, can this continue? Maybe @antoinebaker, @AnneBaker, @StefanieSenger?

@adrinjalali adrinjalali moved this from In progress to Todo in Labs Mar 9, 2026
@StefanieSenger
Member

I think we meant to tag @AnneBeyer, instead of @AnneBaker. :)

@AnneBeyer
Contributor

Yes, I have this on my to-do list.

@lucyleeow
Member

I do think DetCurveDisplay would be an easier place to start and a better introduction, though. Partly because I have a draft PR #32235, and partly because CalibrationCurveDisplay is a little bit different from the other binary displays.
