How observations with sample_weight of zero influence the fit of HistGradientBoostingRegressor #24728

@JoaquinAmatRodrigo

Describe the bug

Hello,

I am trying to exclude some training observations by giving them a weight of zero through the "sample_weight" argument of fit.

As I understand it, observations with a weight of 0 should not influence training, so changing their feature values should not change the resulting model. However, this is not what I see: when I train two models that differ only in the feature values of the zero-weight samples, their predictions differ.

Does anyone know exactly how sample weights are handled inside the code?

Thanks a lot!

Steps/Code to Reproduce

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor

# Synthetic data: the first 10 training rows get weight 0, the other 90 get weight 1
rng = np.random.default_rng(12345)
X_train = rng.normal(size=(100, 3))
y_train = rng.normal(loc=10, size=(100, 1)).ravel()
X_test  = rng.normal(size=(5, 3))
weights = np.repeat([0, 1], repeats=[10, 90])

# Model 1: original feature values
regressor = HistGradientBoostingRegressor(random_state=123)
regressor.fit(X=X_train, y=y_train, sample_weight=weights)
print(regressor.predict(X=X_test))

# Model 2: same data, but the zero-weight rows are overwritten with an extreme value
X_train_2 = X_train.copy()
X_train_2[:10, :] = 50000
regressor = HistGradientBoostingRegressor(random_state=123)
regressor.fit(X=X_train_2, y=y_train, sample_weight=weights)
print(regressor.predict(X=X_test))

Expected Results

array([10.02028706, 10.36184929, 9.45997232, 9.18761327, 9.93495853])
array([10.02028706, 10.36184929, 9.45997232, 9.18761327, 9.93495853])

Actual Results

array([10.02028706, 10.36184929, 9.45997232, 9.18761327, 9.93495853])
array([10.06163207, 10.17065387, 9.04865117, 9.15543545, 9.92940014])

Versions

System:
    python: 3.8.13 (default, Mar 28 2022, 11:38:47)  [GCC 7.5.0]
executable: /home/ubuntu/anaconda3/envs/skforecast/bin/python
   machine: Linux-5.15.0-1022-aws-x86_64-with-glibc2.17

Python dependencies:
      sklearn: 1.1.0
          pip: 22.1.2
   setuptools: 63.4.1
        numpy: 1.23.0
        scipy: 1.9.1
       Cython: None
       pandas: 1.4.0
   matplotlib: 3.5.0
       joblib: 1.1.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/ubuntu/anaconda3/envs/skforecast/lib/python3.8/site-packages/numpy.libs/libopenblas64_p-r0-742d56dc.3.20.so
        version: 0.3.20
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 8

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /home/ubuntu/anaconda3/envs/skforecast/lib/python3.8/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
    num_threads: 8

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/ubuntu/anaconda3/envs/skforecast/lib/python3.8/site-packages/scipy.libs/libopenblasp-r0-9f9f5dbc.3.18.so
        version: 0.3.18
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 8
