-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
Closed
Description
Describe the bug
After calling sklearn.set_config(transform_output="pandas") to set the output container globally, IterativeImputer.transform fails with an InvalidIndexError.
Steps/Code to Reproduce
Adapted from the IterativeImputer documentation example:
import numpy as np
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
imp_mean = IterativeImputer(random_state=0)
imp_mean.fit([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]])
X = [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]]
# works
imp_mean.transform(X)
# fails with InvalidIndexError: (slice(None, None, None), array([1, 2]))
from sklearn import set_config
set_config(transform_output="pandas")
imp_mean.transform(X)
# fails
set_config(transform_output=None)
imp_mean.transform(X)
# fails
set_config(transform_output="default")
imp_mean.transform(X)
# still fails
imp_mean.transform(X).set_output(transform= None)
Expected Results
No error is raised.
Actual Results
In [9]: imp_mean.transform(X)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/indexes/base.py:3803, in Index.get_loc(self, key, method, tolerance)
3802 try:
-> 3803 return self._engine.get_loc(casted_key)
3804 except KeyError as err:
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/_libs/index.pyx:138, in pandas._libs.index.IndexEngine.get_loc()
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/_libs/index.pyx:144, in pandas._libs.index.IndexEngine.get_loc()
TypeError: '(slice(None, None, None), array([1, 2]))' is an invalid key
During handling of the above exception, another exception occurred:
InvalidIndexError Traceback (most recent call last)
Cell In [9], line 1
----> 1 imp_mean.transform(X).set_output(transform= None)
File ~/scikit-learn/sklearn/utils/_set_output.py:137, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
135 @wraps(f)
136 def wrapped(self, X, *args, **kwargs):
--> 137 data_to_wrap = f(self, X, *args, **kwargs)
138 if isinstance(data_to_wrap, tuple):
139 # only wrap the first output for cross decomposition
140 return (
141 _wrap_data_with_container(method, data_to_wrap[0], X, self),
142 *data_to_wrap[1:],
143 )
File ~/scikit-learn/sklearn/impute/_iterative.py:755, in IterativeImputer.transform(self, X)
753 start_t = time()
754 for it, estimator_triplet in enumerate(self.imputation_sequence_):
--> 755 Xt, _ = self._impute_one_feature(
756 Xt,
757 mask_missing_values,
758 estimator_triplet.feat_idx,
759 estimator_triplet.neighbor_feat_idx,
760 estimator=estimator_triplet.estimator,
761 fit_mode=False,
762 )
763 if not (it + 1) % imputations_per_round:
764 if self.verbose > 1:
File ~/scikit-learn/sklearn/impute/_iterative.py:360, in IterativeImputer._impute_one_feature(self, X_filled, mask_missing_values, feat_idx, neighbor_feat_idx, estimator, fit_mode)
357 return X_filled, estimator
359 # get posterior samples if there is at least one missing value
--> 360 X_test = _safe_indexing(X_filled[:, neighbor_feat_idx], missing_row_mask)
361 if self.sample_posterior:
362 mus, sigmas = estimator.predict(X_test, return_std=True)
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/frame.py:3804, in DataFrame.__getitem__(self, key)
3802 if self.columns.nlevels > 1:
3803 return self._getitem_multilevel(key)
-> 3804 indexer = self.columns.get_loc(key)
3805 if is_integer(indexer):
3806 indexer = [indexer]
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/indexes/base.py:3810, in Index.get_loc(self, key, method, tolerance)
3805 raise KeyError(key) from err
3806 except TypeError:
3807 # If we have a listlike key, _check_indexing_error will raise
3808 # InvalidIndexError. Otherwise we fall through and re-raise
3809 # the TypeError.
-> 3810 self._check_indexing_error(key)
3811 raise
3813 # GH#42269
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/indexes/base.py:5966, in Index._check_indexing_error(self, key)
5962 def _check_indexing_error(self, key):
5963 if not is_scalar(key):
5964 # if key is not a scalar, directly raise an error (the code below
5965 # would convert to numpy arrays and raise later any way) - GH29926
-> 5966 raise InvalidIndexError(key)
InvalidIndexError: (slice(None, None, None), array([1, 2]))
Versions
System:
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0]
executable: /path/to/modules/packages/conda/4.6.14/envs/myenv/bin/python3.10
machine: Linux-3.10.0-1160.76.1.el7.x86_64-x86_64-with-glibc2.17
Python dependencies:
sklearn: 1.2.dev0
pip: 22.3.1
setuptools: 65.5.1
numpy: 1.23.4
scipy: 1.9.3
Cython: 0.29.32
pandas: 1.5.1
matplotlib: 3.6.2
joblib: 1.2.0
threadpoolctl: 3.1.0
Built with OpenMP: True
threadpoolctl info:
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/libgomp.so.1.0.0
version: None
num_threads: 1
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/libopenblasp-r0.3.21.so
version: 0.3.21
threading_layer: pthreads
architecture: Zen
num_threads: 1