KNeighborsRegressor with metric="nan_euclidean" does not actually support NaN values #25319

@jameshfisher

Description

Describe the bug

KNeighborsRegressor claims to support the distance metrics listed in its documentation, which include one called "nan_euclidean". This presumably refers to sklearn.metrics.pairwise.nan_euclidean_distances, which calculates "the euclidean distances in the presence of missing values." So using metric="nan_euclidean" in KNeighborsRegressor should allow it to find nearest neighbors even when some feature values are missing (NaN).
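For reference, the nan_euclidean_distances documentation describes the metric as follows: coordinates missing in either sample are skipped, and the squared distance over the remaining coordinates is scaled by (total coordinates / present coordinates). Below is a minimal sketch of that formula. The helper `nan_euclidean` is hypothetical and not part of scikit-learn. It is checked against scikit-learn's own function on two of the samples from this report.

```python
import math

import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances


def nan_euclidean(x, y):
    """Hypothetical helper: Euclidean distance over the coordinates present
    in both x and y, scaled by (total coordinates / present coordinates)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    present = ~(np.isnan(x) | np.isnan(y))
    n_present = int(present.sum())
    if n_present == 0:
        # No coordinate in common: the distance is undefined.
        return float("nan")
    squared = float(np.sum((x[present] - y[present]) ** 2))
    return math.sqrt(x.size / n_present * squared)


a = [1, np.nan]  # second training sample from the report
b = [0, 1]       # first training sample from the report

mine = nan_euclidean(a, b)  # sqrt(2 / 1 * (1 - 0)**2) = sqrt(2)
sklearn_value = nan_euclidean_distances([a], [b])[0, 0]
print(mine, sklearn_value)
```

Both values come out as sqrt(2) (about 1.414): only the first coordinate is present in both samples, so its squared difference of 1 is scaled by 2/1.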

However, with metric="nan_euclidean", fit() raises a ValueError complaining that the "Input contains NaN". This should not be an error, because the chosen distance metric is designed to handle NaN values.
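One way to see that the metric itself copes with NaN is a possible workaround for versions where fit() rejects NaN with this metric. The sketch below assumes that is the case. It computes the distances with nan_euclidean_distances and passes them to the estimator with metric="precomputed", so fit() and predict() only ever see a finite distance matrix. This only works if every pair of samples shares at least one non-missing coordinate; otherwise that distance is itself NaN.

```python
import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances
from sklearn.neighbors import KNeighborsRegressor

X = [[0, 1], [1, np.nan], [2, 3], [3, 5]]
y = [0, 0, 1, 1]

# NaN-aware distances between training samples (4 x 4). There is no NaN here,
# because every pair of rows shares at least one non-missing coordinate.
D_train = nan_euclidean_distances(X)

neigh = KNeighborsRegressor(n_neighbors=2, metric="precomputed")
neigh.fit(D_train, y)  # input validation only sees finite distances

# For prediction, pass the distances from each query to every training sample.
query = [[1.5, 3]]
D_query = nan_euclidean_distances(query, X)  # shape (1, 4)
pred = neigh.predict(D_query)
print(pred)
```

For the query [1.5, 3] used in the notebook cell, the two nearest training samples are [2, 3] (distance 0.5, y = 1) and [1, nan] (distance sqrt(0.5), about 0.71, y = 0). The uniform-weight prediction is therefore 0.5.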

Steps/Code to Reproduce

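import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# The second sample is missing its second feature.
X = [[0, 1], [1, np.nan], [2, 3], [3, 5]]
y = [0, 0, 1, 1]
neigh = KNeighborsRegressor(n_neighbors=2, metric="nan_euclidean")
neigh.fit(X, y)  # raises ValueError: Input contains NaN, ...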

Expected Results

No error is thrown.

Actual Results

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-227-07ade1a53cbc> in <module>
      3 from sklearn.neighbors import KNeighborsRegressor
      4 neigh = KNeighborsRegressor(n_neighbors=2, metric="nan_euclidean")
----> 5 neigh.fit(X, y)
      6 neigh.predict([[1.5, 3]])

/usr/local/lib/python3.8/dist-packages/sklearn/neighbors/_regression.py in fit(self, X, y)
    211         self.weights = _check_weights(self.weights)
    212 
--> 213         return self._fit(X, y)
    214 
    215     def predict(self, X):

/usr/local/lib/python3.8/dist-packages/sklearn/neighbors/_base.py in _fit(self, X, y)
    398         if self._get_tags()["requires_y"]:
    399             if not isinstance(X, (KDTree, BallTree, NeighborsBase)):
--> 400                 X, y = self._validate_data(X, y, accept_sparse="csr", multi_output=True)
    401 
    402             if is_classifier(self):

/usr/local/lib/python3.8/dist-packages/sklearn/base.py in _validate_data(self, X, y, reset, validate_separately, **check_params)
    579                 y = check_array(y, **check_y_params)
    580             else:
--> 581                 X, y = check_X_y(X, y, **check_params)
    582             out = X, y
    583 

/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)
    962         raise ValueError("y cannot be None")
    963 
--> 964     X = check_array(
    965         X,
    966         accept_sparse=accept_sparse,

/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
    798 
    799         if force_all_finite:
--> 800             _assert_all_finite(array, allow_nan=force_all_finite == "allow-nan")
    801 
    802     if ensure_min_samples > 0:

/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py in _assert_all_finite(X, allow_nan, msg_dtype)
    112         ):
    113             type_err = "infinity" if allow_nan else "NaN, infinity"
--> 114             raise ValueError(
    115                 msg_err.format(
    116                     type_err, msg_dtype if msg_dtype is not None else X.dtype

ValueError: Input contains NaN, infinity or a value too large for dtype('float64').
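The traceback shows that the rejection happens during input validation, before the metric is ever consulted. _fit() calls _validate_data() without a force_all_finite argument, so check_array() uses its default. It then calls _assert_all_finite() with allow_nan=False, because allow_nan is only true when force_all_finite == "allow-nan". The public helper sklearn.utils.assert_all_finite runs the same kind of check. Here is a small sketch of both settings on the X from this report; the exact message text may differ between scikit-learn versions.

```python
import numpy as np
from sklearn.utils import assert_all_finite

X = np.array([[0, 1], [1, np.nan], [2, 3], [3, 5]], dtype=float)

# Default (allow_nan=False): NaN is rejected, as in the traceback above.
try:
    assert_all_finite(X)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

# allow_nan=True: NaN is accepted (infinity would still be rejected).
assert_all_finite(X, allow_nan=True)
print(error_message)
```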

Versions

System:
    python: 3.8.16 (default, Dec  7 2022, 01:12:13)  [GCC 7.5.0]
executable: /usr/bin/python3
   machine: Linux-5.10.147+-x86_64-with-glibc2.27

Python dependencies:
          pip: 22.0.4
   setuptools: 57.4.0
      sklearn: 1.0.2
        numpy: 1.21.6
        scipy: 1.7.3
       Cython: 0.29.32
       pandas: 1.3.5
   matplotlib: 3.2.2
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

Metadata

Assignees: No one assigned
Type: No type
Project status: Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
