Closed
Description
Currently, the `fit` method of `_BaseStacking` does not support `**fit_params`:
`scikit-learn/sklearn/ensemble/_stacking.py`, line 110 at commit fd23727:

```python
def fit(self, X, y, sample_weight=None):
```
Since issue #15953 introduced this support for `_MultiOutputEstimator`, it seems natural to extend it to stacking. A proposed implementation for the base stacking class is as follows:
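```python
from ..utils.validation import _check_fit_params

def fit(self, X, y, sample_weight=None, **fit_params):
    # ... existing validation and estimator setup ...

    # Right before `predictions = Parallel(...)`: merge sample_weight into
    # fit_params so it is not dropped when other fit params are also passed.
    if sample_weight is not None:
        fit_params["sample_weight"] = sample_weight
    # Validate the params (sample-aligned ones are made indexable);
    # pass None when there is nothing to forward.
    fit_params = _check_fit_params(X, fit_params) if fit_params else None

    # Then, pass fit_params to the parallelized cross_val_predict calls.
```

Subsequently, the `fit` methods of `StackingClassifier` and `StackingRegressor` would be changed to accept `**fit_params` and forward them to the base class. If this is favorable, I can write an implementation and open the pull request.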
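The two steps of the proposal (the base class forwarding fit params into the out-of-fold predictions, and the subclasses accepting and passing on `**fit_params`) can be illustrated with a small, self-contained toy. This is only a sketch under stated assumptions and does not use scikit-learn: `ToyBaseStacking`, `ToyStackingRegressor`, `WeightedMeanRegressor` and `slice_fit_params` are hypothetical names, and `slice_fit_params` is a simplified stand-in for the per-fold slicing that `cross_val_predict` performs on sample-aligned fit params (roughly what the private `_check_fit_params(X, fit_params, indices)` is used for). It also omits refitting the base estimators on the full data and fitting the final estimator.

```python
from typing import Any, Callable, Dict, List, Sequence


def slice_fit_params(n_samples: int, fit_params: Dict[str, Any],
                     indices: Sequence[int]) -> Dict[str, Any]:
    """Subset sample-aligned fit params to `indices`; pass the rest through.

    Hypothetical, simplified heuristic: any list/tuple with one entry per
    sample is treated as sample-aligned.
    """
    sliced = {}
    for key, value in fit_params.items():
        if isinstance(value, (list, tuple)) and len(value) == n_samples:
            sliced[key] = [value[i] for i in indices]
        else:
            sliced[key] = value
    return sliced


class WeightedMeanRegressor:
    """Hypothetical toy estimator: predicts the (weighted) mean of y plus `shift`."""

    def fit(self, X, y, sample_weight=None, shift=0.0):
        w = sample_weight if sample_weight is not None else [1.0] * len(y)
        self.mean_ = sum(wi * yi for wi, yi in zip(w, y)) / sum(w) + shift
        return self

    def predict(self, X):
        return [self.mean_] * len(X)


class ToyBaseStacking:
    """Hypothetical stand-in for `_BaseStacking`; models only the fit-param plumbing."""

    def __init__(self, estimator_factories: List[Callable[[], Any]], n_splits: int = 2):
        self.estimator_factories = estimator_factories
        self.n_splits = n_splits

    def fit(self, X, y, sample_weight=None, **fit_params):
        # Merge sample_weight into fit_params so neither is silently dropped.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight
        n = len(X)
        # Contiguous, unshuffled folds (similar to KFold without shuffling).
        bounds = [k * n // self.n_splits for k in range(self.n_splits + 1)]
        folds = [list(range(bounds[k], bounds[k + 1])) for k in range(self.n_splits)]
        self.oof_predictions_ = []
        for make_estimator in self.estimator_factories:
            preds = [None] * n
            for test_idx in folds:
                test_set = set(test_idx)
                train_idx = [i for i in range(n) if i not in test_set]
                # Forward only the training-fold slice of sample-aligned params.
                params = slice_fit_params(n, fit_params, train_idx)
                est = make_estimator().fit([X[i] for i in train_idx],
                                           [y[i] for i in train_idx], **params)
                for i, p in zip(test_idx, est.predict([X[i] for i in test_idx])):
                    preds[i] = p
            self.oof_predictions_.append(preds)
        return self


class ToyStackingRegressor(ToyBaseStacking):
    def fit(self, X, y, sample_weight=None, **fit_params):
        # The subclass only needs to accept **fit_params and forward them.
        return super().fit(X, y, sample_weight=sample_weight, **fit_params)


X = [[0], [1], [2], [3]]
y = [0.0, 2.0, 4.0, 6.0]
stack = ToyStackingRegressor([WeightedMeanRegressor]).fit(
    X, y, sample_weight=[1.0, 3.0, 1.0, 1.0])
print(stack.oof_predictions_)  # [[5.0, 5.0, 1.5, 1.5]]
```

Without `sample_weight`, the second fold would predict 1.0 instead of 1.5, which shows that the weights reached each fold's `fit` after being sliced to the training indices. Non-sample-aligned params, such as a scalar `shift=10.0`, are passed through unchanged.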