Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3ed49d5
FIX Removes double validation in fit
thomasjpfan Nov 6, 2021
7021d0c
DOC Adds whats new
thomasjpfan Nov 6, 2021
4b3d2f5
DOC Adds pr number
thomasjpfan Nov 6, 2021
7fd4f65
Merge remote-tracking branch 'upstream/main' into remove_warning_rand…
thomasjpfan Nov 8, 2021
d9ca88a
CLN Address comments
thomasjpfan Nov 8, 2021
a449a64
Merge remote-tracking branch 'upstream/main' into remove_warning_rand…
thomasjpfan Nov 10, 2021
45ed3ad
DOC Adds comment about CSR format for prediction
thomasjpfan Nov 10, 2021
27027a0
TST Updates common test to check every estimator that predicts during…
thomasjpfan Nov 10, 2021
a13f7ee
DOC Update whats new
thomasjpfan Nov 10, 2021
cde9373
Merge remote-tracking branch 'upstream/main' into remove_warning_rand…
thomasjpfan Nov 18, 2021
29aa492
ENH Use tocsr
thomasjpfan Nov 18, 2021
5c22dec
BUG Fixes bug
thomasjpfan Nov 18, 2021
fbf7006
Trigger CI
ogrisel Nov 19, 2021
c74fec0
Merge remote-tracking branch 'upstream/main' into remove_warning_rand…
thomasjpfan Nov 24, 2021
d0a5c4c
CLN Address comments
thomasjpfan Nov 27, 2021
4e3e8e8
Merge remote-tracking branch 'upstream/main' into remove_warning_rand…
thomasjpfan Nov 27, 2021
3c1d57f
TST Fixes test (Will fail now on CI)
thomasjpfan Nov 27, 2021
5a22b70
FIX Fixes issue same issue with MLP
thomasjpfan Nov 27, 2021
8cea8d8
XFAIL the MLP case
thomasjpfan Nov 27, 2021
5d323ac
REV Enable the other tests
thomasjpfan Nov 27, 2021
0849124
Merge remote-tracking branch 'upstream/main' into remove_warning_rand…
thomasjpfan Nov 29, 2021
fab1268
DOC Fix syntax error in whats_new
thomasjpfan Nov 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ Version 1.0.2

**In Development**

- |Fix| :class:`cluster.Birch`,
:class:`feature_selection.RFECV`, :class:`ensemble.RandomForestRegressor`,
:class:`ensemble.RandomForestClassifier`,
:class:`ensemble.GradientBoostingRegressor`, and
:class:`ensemble.GradientBoostingClassifier` do not raise warning when fitted
on a pandas DataFrame anymore. :pr:`21578` by `Thomas Fan`_.

Changelog
---------

Expand Down
6 changes: 5 additions & 1 deletion sklearn/cluster/_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,10 @@ def predict(self, X):
"""
check_is_fitted(self)
X = self._validate_data(X, accept_sparse="csr", reset=False)
return self._predict(X)

def _predict(self, X):
"""Predict data using the ``centroids_`` of subclusters."""
kwargs = {"Y_norm_squared": self._subcluster_norms}

with config_context(assume_finite=True):
Expand Down Expand Up @@ -745,4 +749,4 @@ def _global_clustering(self, X=None):
self.subcluster_labels_ = clusterer.fit_predict(self.subcluster_centers_)

if compute_labels:
self.labels_ = self.predict(X)
self.labels_ = self._predict(X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't a cleaner API be like self.predict(X, validate_input=False), or self.validate_input(predict=False).predict(X)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.predict(X, validate_input=False)

I don't think introducing a public arg for that is cleaner. I find it clean that we use a private function internally and expose a public function that does extra validation.

self.validate_input(predict=False).predict(X)

I don't get the predict=False arg, could explain more what you have in mind ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other reasons why a public, or a developer API, would be nice to have when it comes to [skipping] validation: #16653 (comment)

The predict=False would kinda set a flag in the estimator to skip the validation in a certain method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The predict=False would kinda set a flag in the estimator to skip the validation in a certain method.

I think adding more state to the estimator after __init__ is outside the scope of this PR, but we can use this PR as a motivation to do it. It would kind of be like "inference mode".

self.predict(X, validate_input=False)

I think it would be very nice to have this type of kwarg everywhere. It would be similar to the check_finite flag in SciPy. (Every year I see the "Scikit-learn is slow during prediction" and it comes down to the validation we do.)

In both cases, I do not think we should change public API with a bug fix PR.

6 changes: 4 additions & 2 deletions sklearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,10 @@ def _compute_oob_predictions(self, X, y):
oob_pred : ndarray of shape (n_samples, n_classes, n_outputs) or \
(n_samples, 1, n_outputs)
The OOB predictions.
"""
X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
"""
# Prediction requires X to be in CSR format
if issparse(X):
X = X.tocsr()

n_samples = y.shape[0]
n_outputs = self.n_outputs_
Expand Down
14 changes: 9 additions & 5 deletions sklearn/ensemble/_gb.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def _fit_stages(
loss_history = np.full(self.n_iter_no_change, np.inf)
# We create a generator to get the predictions for X_val after
# the addition of each successive stage
y_val_pred_iter = self._staged_raw_predict(X_val)
y_val_pred_iter = self._staged_raw_predict(X_val, check_input=False)

# perform boosting iterations
i = begin_at_stage
Expand Down Expand Up @@ -736,7 +736,7 @@ def _raw_predict(self, X):
predict_stages(self.estimators_, X, self.learning_rate, raw_predictions)
return raw_predictions

def _staged_raw_predict(self, X):
def _staged_raw_predict(self, X, check_input=True):
"""Compute raw predictions of ``X`` for each iteration.

This method allows monitoring (i.e. determine error on testing set)
Expand All @@ -749,6 +749,9 @@ def _staged_raw_predict(self, X):
``dtype=np.float32`` and if a sparse matrix is provided
to a sparse ``csr_matrix``.

check_input : bool, default=True
If False, the input arrays X will not be checked.

Returns
-------
raw_predictions : generator of ndarray of shape (n_samples, k)
Expand All @@ -757,9 +760,10 @@ def _staged_raw_predict(self, X):
Regression and binary classification are special cases with
``k == 1``, otherwise ``k==n_classes``.
"""
X = self._validate_data(
X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
)
if check_input:
X = self._validate_data(
X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
)
raw_predictions = self._raw_predict_init(X)
for i in range(self.estimators_.shape[0]):
predict_stage(self.estimators_, i, X, self.learning_rate, raw_predictions)
Expand Down
4 changes: 4 additions & 0 deletions sklearn/feature_selection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def transform(self, X):
force_all_finite=not _safe_tags(self, key="allow_nan"),
reset=False,
)
return self._transform(X)

def _transform(self, X):
"""Reduce X to the selected features."""
mask = self.get_support()
if not mask.any():
warn(
Expand Down
2 changes: 1 addition & 1 deletion sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def fit(self, X, y, groups=None):
self.n_features_ = rfe.n_features_
self.ranking_ = rfe.ranking_
self.estimator_ = clone(self.estimator)
self.estimator_.fit(self.transform(X), y)
self.estimator_.fit(self._transform(X), y)

# reverse to stay consistent with before
scores_rev = scores[:, ::-1]
Expand Down
19 changes: 19 additions & 0 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,24 @@ def test_check_n_features_in_after_fitting(estimator):
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)


def _estimators_that_predict_in_fit():
for estimator in _tested_estimators():
est_params = set(estimator.get_params())
if "oob_score" in est_params:
yield estimator.set_params(oob_score=True, bootstrap=True)
elif "early_stopping" in est_params:
est = estimator.set_params(early_stopping=True, n_iter_no_change=1)
if est.__class__.__name__ in {"MLPClassifier", "MLPRegressor"}:
# TODO: FIX MLP to not check validation set during MLP
yield pytest.param(
est, marks=pytest.mark.xfail(msg="MLP still validates in fit")
)
else:
yield est
elif "n_iter_no_change" in est_params:
yield estimator.set_params(n_iter_no_change=1)


# NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that
# delegates validation to a base estimator, the check is testing that the base estimator
# is checking for column name consistency.
Expand All @@ -340,6 +358,7 @@ def test_check_n_features_in_after_fitting(estimator):
_tested_estimators(),
[make_pipeline(LogisticRegression(C=1))],
list(_generate_search_cv_instances()),
_estimators_that_predict_in_fit(),
)
)

Expand Down
20 changes: 18 additions & 2 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3791,7 +3791,16 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
else:
y = rng.randint(low=0, high=2, size=n_samples)
y = _enforce_estimator_tags_y(estimator, y)
estimator.fit(X, y)

# Check that calling `fit` does not raise any warnings about feature names.
with warnings.catch_warnings():
warnings.filterwarnings(
"error",
message="X does not have valid feature names",
category=UserWarning,
module="sklearn",
)
estimator.fit(X, y)

if not hasattr(estimator, "feature_names_in_"):
raise ValueError(
Expand Down Expand Up @@ -3853,6 +3862,12 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n",
),
]
params = {
key: value
for key, value in estimator.get_params().items()
if "early_stopping" in key
}
early_stopping_enabled = any(value is True for value in params.values())

for invalid_name, additional_message in invalid_names:
X_bad = pd.DataFrame(X, columns=invalid_name)
Expand All @@ -3876,7 +3891,8 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
method(X_bad)

# partial_fit checks on second call
if not hasattr(estimator, "partial_fit"):
# Do not call partial fit if early_stopping is on
if not hasattr(estimator, "partial_fit") or early_stopping_enabled:
continue

estimator = clone(estimator_orig)
Expand Down