Description
Describe the bug
I wanted to learn more about how cross-validation is implemented in sklearn and came across what I think is a faulty, or at least confusing, test.
The test test_cross_validate in model_selection/tests/test_validation.py creates two datasets, one for regression and one for classification, along with a regressor (Lasso) and a classifier (SVC). It then loops over the (dataset, estimator) pairs, fitting each model and scoring the evaluation metrics. Inside the loop, est refers to the estimator under consideration (Lasso or SVC). Within each iteration, however, clone(reg).fit(...) is called, so the Lasso regressor is fitted and evaluated in both iterations, including on the classification data. If I'm not mistaken, this should be clone(est).fit(...), so that the SVC classifier is trained and evaluated on the classification data.
However, replacing clone(reg) with clone(est) actually makes the test fail, which I did not expect given the comment "It's okay to evaluate regression metrics on classification too".
To me, this suggests that the test is not testing what it is supposed to test.
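To make the pattern concrete, here is a minimal standalone sketch. It is not the actual test code, only a simplified stand-in under stated assumptions: the datasets come from make_regression/make_classification, the estimators are Lasso() and SVC(kernel="linear"), the folds come from KFold(n_splits=3), and the helper fitted_types and its use_loop_estimator flag are hypothetical. It records which estimator class actually gets fitted for each (dataset, estimator) pair.

```python
from sklearn.base import clone
from sklearn.datasets import make_classification, make_regression
from sklearn.linear_model import Lasso
from sklearn.model_selection import KFold
from sklearn.svm import SVC

# Simplified stand-ins for the test's data and estimators (assumptions, not the real test setup)
X_reg, y_reg = make_regression(n_samples=30, random_state=0)
X_clf, y_clf = make_classification(n_samples=30, random_state=0)
reg = Lasso()
clf = SVC(kernel="linear")


def fitted_types(use_loop_estimator):
    """Hypothetical helper: return the class name of the estimator fitted for each pair."""
    seen = []
    for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)):
        for train, _ in KFold(n_splits=3).split(X, y):
            # clone(reg) mirrors the pattern described above; clone(est) is the proposed fix
            template = est if use_loop_estimator else reg
            fitted = clone(template).fit(X[train], y[train])
        seen.append(type(fitted).__name__)
    return seen


print(fitted_types(use_loop_estimator=False))  # ['Lasso', 'Lasso']
print(fitted_types(use_loop_estimator=True))   # ['Lasso', 'SVC']
```

With clone(reg), the SVC defined for the classification pair is never fitted, which is the behaviour this report is about.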
Steps/Code to Reproduce
Run the test test_cross_validate() once as-is, and once with clone(reg).fit(...) replaced by clone(est).fit(...).
Expected Results
The test should fit the regressor on the regression data and the classifier on the classification data, and still pass.
Actual Results
The test fits the Lasso regressor on both the regression data and the classification data, and passes.
Versions
System:
python: 3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 08:41:22) [MSC v.1929 64 bit (AMD64)]
executable: C:\Users\s.elgewily\Miniconda3\envs\sklearn-env2\python.exe
machine: Windows-10-10.0.19044-SP0
Python dependencies:
sklearn: 1.3.dev0
pip: 22.3.1
setuptools: 66.0.0
numpy: 1.24.1
scipy: 1.10.0
Cython: 0.29.33
pandas: None
matplotlib: None
joblib: 1.2.0
threadpoolctl: 3.1.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: mkl
prefix: libblas
filepath: C:\Users\s.elgewily\Miniconda3\envs\sklearn-env2\Library\bin\libblas.dll
version: 2022.1-Product
threading_layer: intel
num_threads: 4

user_api: openmp
internal_api: openmp
prefix: vcomp
filepath: C:\Users\s.elgewily\Miniconda3\envs\sklearn-env2\vcomp140.dll
version: None
num_threads: 8