How observations with sample_weight of zero influence the fit of HistGradientBoostingRegressor #24728
Description
Describe the bug
Hello,
I am trying to exclude some training observations by giving them a weight of zero through the `sample_weight` argument of `fit`.
As I understand it, observations with a weight of 0 should not influence training, so changing their feature values should not change the resulting model. However, that is not what I see: if I train two models that differ only in the feature values of the zero-weight samples, their predictions differ.
Does anyone know exactly how the algorithm uses sample weights internally?
Thanks a lot!
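One possible explanation, stated here as an assumption and not as a confirmed reading of the scikit-learn source: the estimator discretizes each feature into bins before growing trees, and if that binning step looks at all rows regardless of `sample_weight`, then zero-weight rows can still move the candidate split thresholds. The sketch below illustrates the idea with NumPy only. The helper `midpoint_thresholds` is hypothetical and is not a scikit-learn function; it simply takes midpoints between sorted distinct values, which is one common way to derive thresholds when a feature has few distinct values.

```python
import numpy as np


def midpoint_thresholds(column):
    """Hypothetical helper: midpoints between sorted distinct values.

    This ignores sample weights on purpose, to show what happens when a
    binning step does not take them into account.
    """
    distinct = np.unique(column)
    return (distinct[:-1] + distinct[1:]) / 2


rng = np.random.default_rng(12345)
feature = rng.normal(size=100)

# Same feature, but the first 10 rows (the ones that would get weight 0)
# are overwritten with an extreme value.
feature_2 = feature.copy()
feature_2[:10] = 50000

thresholds_1 = midpoint_thresholds(feature)
thresholds_2 = midpoint_thresholds(feature_2)

# The candidate thresholds differ even though only "zero-weight" rows changed.
print(np.array_equal(thresholds_1, thresholds_2))
```

If the real binning behaves similarly, the trees would be built on different discretized inputs in the two fits, which could account for different predictions.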
Steps/Code to Reproduce
import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor
rng = np.random.default_rng(12345)
X_train = rng.normal(size=(100, 3))
y_train = rng.normal(loc=10, size=(100, 1)).ravel()
X_test = rng.normal(size=(5, 3))
weights = np.repeat([0, 1], repeats=[10, 90])
regressor = HistGradientBoostingRegressor(random_state=123)
regressor.fit(X=X_train, y=y_train, sample_weight=weights)
print(regressor.predict(X=X_test))
X_train_2 = X_train.copy()
X_train_2[:10, :] = 50000
regressor = HistGradientBoostingRegressor(random_state=123)
regressor.fit(X=X_train_2, y=y_train, sample_weight=weights)
print(regressor.predict(X=X_test))
Expected Results
array([10.02028706, 10.36184929, 9.45997232, 9.18761327, 9.93495853])
array([10.02028706, 10.36184929, 9.45997232, 9.18761327, 9.93495853])
Actual Results
array([10.02028706, 10.36184929, 9.45997232, 9.18761327, 9.93495853])
array([10.06163207, 10.17065387, 9.04865117, 9.15543545, 9.92940014])
Versions
System:
    python: 3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0]
    executable: /home/ubuntu/anaconda3/envs/skforecast/bin/python
    machine: Linux-5.15.0-1022-aws-x86_64-with-glibc2.17

Python dependencies:
    sklearn: 1.1.0
    pip: 22.1.2
    setuptools: 63.4.1
    numpy: 1.23.0
    scipy: 1.9.1
    Cython: None
    pandas: 1.4.0
    matplotlib: 3.5.0
    joblib: 1.1.0
    threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: /home/ubuntu/anaconda3/envs/skforecast/lib/python3.8/site-packages/numpy.libs/libopenblas64_p-r0-742d56dc.3.20.so
    version: 0.3.20
    threading_layer: pthreads
    architecture: SkylakeX
    num_threads: 8

    user_api: openmp
    internal_api: openmp
    prefix: libgomp
    filepath: /home/ubuntu/anaconda3/envs/skforecast/lib/python3.8/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    version: None
    num_threads: 8

    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: /home/ubuntu/anaconda3/envs/skforecast/lib/python3.8/site-packages/scipy.libs/libopenblasp-r0-9f9f5dbc.3.18.so
    version: 0.3.18
    threading_layer: pthreads
    architecture: SkylakeX
    num_threads: 8