Closed
Describe the bug
The bug occurs only when early_stopping=True.
It looks like the data used inside the _fit() training loop (including the validation split y_val and the training y) are plain ndarrays, while the model records the column names of the input pandas.DataFrame during the self._validate_input step.
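To make the suspected mechanism concrete, here is a minimal toy sketch. It is NOT scikit-learn code: the class ToyNamedFeatureModel and its n_epochs parameter are hypothetical. It only assumes the behaviour described above: column names are recorded at fit time, the data is then converted to an unnamed array, and with early stopping the unnamed validation split is scored once per epoch, so the feature-name check in predict() fires repeatedly.

```python
import warnings


class ToyNamedFeatureModel:
    """Hypothetical toy model (not scikit-learn code) mimicking the
    feature-name check: names are recorded at fit time, and predict()
    warns whenever it later receives data without names."""

    def __init__(self, n_epochs=3, validation_fraction=0.2):
        self.n_epochs = n_epochs
        self.validation_fraction = validation_fraction

    def fit(self, rows, column_names=None):
        # Like validating a DataFrame: remember the column names, if any.
        self.feature_names_in_ = list(column_names) if column_names else None
        # The data is then converted to a plain, unnamed array...
        data = [list(r) for r in rows]
        n_val = max(1, int(len(data) * self.validation_fraction))
        # ...so the validation split carries no column names.
        val_rows = data[-n_val:]
        for _ in range(self.n_epochs):
            # Early stopping scores the validation split once per epoch.
            self.predict(val_rows)
        return self

    def predict(self, rows, column_names=None):
        if self.feature_names_in_ is not None and column_names is None:
            warnings.warn(
                "X does not have valid feature names, but "
                f"{type(self).__name__} was fitted with feature names",
                UserWarning,
            )
        return [0] * len(rows)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # show every repeat, not just the first
    ToyNamedFeatureModel().fit(
        [[i, i] for i in range(10)], column_names=["colname_a", "colname_b"]
    )
print(len(caught))  # 3: one warning per (toy) epoch
```

Fitting the same toy model without column_names produces no warnings, which matches the observation that the problem is tied to DataFrame input.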
Steps/Code to Reproduce
from sklearn.neural_network import MLPClassifier, MLPRegressor
import pandas as pd
X = pd.DataFrame(data=[[i, i] for i in range(10)], columns=['colname_a', 'colname_b'])
y = pd.DataFrame(data=[[1] for i in range(10)], columns=['colname_y'])
print('training classifier')
model = MLPClassifier(
    early_stopping=True,
    validation_fraction=0.2
)
model.fit(X, y['colname_y'])
print('training regressor')
model = MLPRegressor(
    early_stopping=True,
    validation_fraction=0.2
)
model.fit(X, y['colname_y'])
Expected Results
No messages like "UserWarning: X does not have valid feature names, but MLPClassifier was fitted with feature names"
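One way to turn the expected result into a checkable assertion is to record warnings during fit and keep only the feature-name ones. The helper collect_feature_name_warnings below is hypothetical (not part of scikit-learn) and relies only on the standard warnings module; filtering on the message text matters because an MLP fit can also emit sklearn's ConvergenceWarning, which is a UserWarning subclass.

```python
import warnings

# Start of the message reported in this issue.
FEATURE_NAME_MSG = "X does not have valid feature names"


def collect_feature_name_warnings(fit_callable, *args, **kwargs):
    """Hypothetical helper: call fit_callable(*args, **kwargs) and return the
    messages of any UserWarning about missing feature names."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" keeps repeated warnings from the same line visible.
        warnings.simplefilter("always")
        fit_callable(*args, **kwargs)
    return [
        str(w.message)
        for w in caught
        if issubclass(w.category, UserWarning) and FEATURE_NAME_MSG in str(w.message)
    ]


# With scikit-learn and pandas installed, reusing X and y from the reproduction:
# clf = MLPClassifier(early_stopping=True, validation_fraction=0.2)
# msgs = collect_feature_name_warnings(clf.fit, X, y['colname_y'])
# assert not msgs, f"{len(msgs)} feature-name warnings"
```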
Actual Results
Multiple messages like:
...venv/lib/python3.10/site-packages/sklearn/base.py:450: UserWarning: X does not have valid feature names, but MLPClassifier was fitted with feature names
warnings.warn(
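Until this is fixed, a possible user-side workaround is to silence only this specific warning around fit(). The helper fit_ignoring_feature_name_warning is hypothetical; it is a sketch built on the standard warnings.filterwarnings, whose message argument is a regular expression matched against the start of the warning text, so other warnings (for example ConvergenceWarning) still reach the user.

```python
import warnings


def fit_ignoring_feature_name_warning(model, X, y):
    """Hypothetical workaround: call model.fit(X, y) while silencing only the
    'X does not have valid feature names' UserWarning."""
    with warnings.catch_warnings():
        # The filter only applies inside this block and is restored afterwards.
        warnings.filterwarnings(
            "ignore",
            message="X does not have valid feature names",
            category=UserWarning,
        )
        return model.fit(X, y)
```

An alternative is to fit on X.to_numpy(), so no feature names are recorded at all; the trade-off is that feature_names_in_ is not set, and predicting later with a DataFrame produces the opposite warning ("X has feature names, but ... was fitted without feature names").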
Versions
System:
python: 3.10.6 (main, Nov 2 2022, 18:53:38) [GCC 11.3.0]
executable: /home/usopp/sinp/mineralogy-2022/main/venv/bin/python
machine: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Python dependencies:
sklearn: 1.1.3
pip: 22.2.2
setuptools: 62.3.4
numpy: 1.23.3
scipy: 1.9.1
Cython: None
pandas: 1.5.0
matplotlib: 3.6.0
joblib: 1.2.0
threadpoolctl: 3.1.0
Built with OpenMP: True
threadpoolctl info:
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: /home/usopp/sinp/mineralogy-2022/main/venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None
num_threads: 12
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/usopp/sinp/mineralogy-2022/main/venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-742d56dc.3.20.so
version: 0.3.20
threading_layer: pthreads
architecture: Haswell
num_threads: 12
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/usopp/sinp/mineralogy-2022/main/venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-9f9f5dbc.3.18.so
version: 0.3.18
threading_layer: pthreads
architecture: Haswell
num_threads: 12