Skip to content

IterativeImputer InvalidIndexError on dev branch after setting transform_output #24923

@t-silvers

Description

@t-silvers

Describe the bug

After calling `sklearn.set_config(transform_output="pandas")` to set the output container globally, `IterativeImputer.transform` fails with an `InvalidIndexError`.

Steps/Code to Reproduce

Adapted from the `IterativeImputer` docstring example:

import numpy as np
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer

imp_mean = IterativeImputer(random_state=0)
imp_mean.fit([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]])
X = [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]]

# works
imp_mean.transform(X)

# fails with InvalidIndexError: (slice(None, None, None), array([1, 2]))
from sklearn import set_config
set_config(transform_output="pandas")
imp_mean.transform(X)

# fails
set_config(transform_output=None)
imp_mean.transform(X)

# fails
set_config(transform_output="default")
imp_mean.transform(X)

# still fails
imp_mean.transform(X).set_output(transform= None)

Expected Results

No error is raised; `transform` returns the imputed data.

Actual Results

In [9]: imp_mean.transform(X)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/indexes/base.py:3803, in Index.get_loc(self, key, method, tolerance)
   3802 try:
-> 3803     return self._engine.get_loc(casted_key)
   3804 except KeyError as err:

File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/_libs/index.pyx:138, in pandas._libs.index.IndexEngine.get_loc()

File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/_libs/index.pyx:144, in pandas._libs.index.IndexEngine.get_loc()

TypeError: '(slice(None, None, None), array([1, 2]))' is an invalid key

During handling of the above exception, another exception occurred:

InvalidIndexError                         Traceback (most recent call last)
Cell In [9], line 1
----> 1 imp_mean.transform(X).set_output(transform= None)

File ~/scikit-learn/sklearn/utils/_set_output.py:137, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    135 @wraps(f)
    136 def wrapped(self, X, *args, **kwargs):
--> 137     data_to_wrap = f(self, X, *args, **kwargs)
    138     if isinstance(data_to_wrap, tuple):
    139         # only wrap the first output for cross decomposition
    140         return (
    141             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    142             *data_to_wrap[1:],
    143         )

File ~/scikit-learn/sklearn/impute/_iterative.py:755, in IterativeImputer.transform(self, X)
    753 start_t = time()
    754 for it, estimator_triplet in enumerate(self.imputation_sequence_):
--> 755     Xt, _ = self._impute_one_feature(
    756         Xt,
    757         mask_missing_values,
    758         estimator_triplet.feat_idx,
    759         estimator_triplet.neighbor_feat_idx,
    760         estimator=estimator_triplet.estimator,
    761         fit_mode=False,
    762     )
    763     if not (it + 1) % imputations_per_round:
    764         if self.verbose > 1:

File ~/scikit-learn/sklearn/impute/_iterative.py:360, in IterativeImputer._impute_one_feature(self, X_filled, mask_missing_values, feat_idx, neighbor_feat_idx, estimator, fit_mode)
    357     return X_filled, estimator
    359 # get posterior samples if there is at least one missing value
--> 360 X_test = _safe_indexing(X_filled[:, neighbor_feat_idx], missing_row_mask)
    361 if self.sample_posterior:
    362     mus, sigmas = estimator.predict(X_test, return_std=True)

File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/frame.py:3804, in DataFrame.__getitem__(self, key)
   3802 if self.columns.nlevels > 1:
   3803     return self._getitem_multilevel(key)
-> 3804 indexer = self.columns.get_loc(key)
   3805 if is_integer(indexer):
   3806     indexer = [indexer]

File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/indexes/base.py:3810, in Index.get_loc(self, key, method, tolerance)
   3805         raise KeyError(key) from err
   3806     except TypeError:
   3807         # If we have a listlike key, _check_indexing_error will raise
   3808         #  InvalidIndexError. Otherwise we fall through and re-raise
   3809         #  the TypeError.
-> 3810         self._check_indexing_error(key)
   3811         raise
   3813 # GH#42269

File /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/python3.10/site-packages/pandas/core/indexes/base.py:5966, in Index._check_indexing_error(self, key)
   5962 def _check_indexing_error(self, key):
   5963     if not is_scalar(key):
   5964         # if key is not a scalar, directly raise an error (the code below
   5965         # would convert to numpy arrays and raise later any way) - GH29926
-> 5966         raise InvalidIndexError(key)

InvalidIndexError: (slice(None, None, None), array([1, 2]))

Versions

System:
    python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0]
executable: /path/to/modules/packages/conda/4.6.14/envs/myenv/bin/python3.10
   machine: Linux-3.10.0-1160.76.1.el7.x86_64-x86_64-with-glibc2.17

Python dependencies:
      sklearn: 1.2.dev0
          pip: 22.3.1
   setuptools: 65.5.1
        numpy: 1.23.4
        scipy: 1.9.3
       Cython: 0.29.32
       pandas: 1.5.1
   matplotlib: 3.6.2
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/libgomp.so.1.0.0
        version: None
    num_threads: 1

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /path/to/modules/packages/conda/4.6.14/envs/myenv/lib/libopenblasp-r0.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: Zen
    num_threads: 1

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions