LinearSVC does not correctly handle sample_weight under class_weight strategy 'balanced' #30056
Closed
Description
Describe the bug
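LinearSVC does not pass sample weights through when computing class weights under the "balanced" strategy. This breaks sample weight invariance: fitting with integer sample weights does not give the same model as fitting on a dataset where each sample is repeated that many times. Cross-linked to meta-issue #16298.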
Steps/Code to Reproduce
from sklearn.svm import LinearSVC
from sklearn.base import clone
from sklearn.datasets import make_classification
import numpy as np

# Note: the RNG is unseeded, so the exact numbers vary from run to run
rng = np.random.RandomState()

X, y = make_classification(
    n_samples=100,
    n_features=5,
    n_informative=3,
    n_classes=4,
    random_state=0,
)

# Create dataset with repetitions and corresponding sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)
y_resampled_by_weights = np.repeat(y, sample_weight)

est_sw = LinearSVC(dual=False, class_weight="balanced").fit(
    X, y, sample_weight=sample_weight
)
est_dup = LinearSVC(dual=False, class_weight="balanced").fit(
    X_resampled_by_weights, y_resampled_by_weights, sample_weight=None
)

np.testing.assert_allclose(est_sw.coef_, est_dup.coef_, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(
    est_sw.decision_function(X_resampled_by_weights),
    est_dup.decision_function(X_resampled_by_weights),
    rtol=1e-10,
    atol=1e-10,
)
Expected Results
No error thrown
Actual Results
AssertionError:
Not equal to tolerance rtol=1e-10, atol=1e-10
Mismatched elements: 20 / 20 (100%)
Max absolute difference among violations: 0.00818953
Max relative difference among violations: 0.10657042
ACTUAL: array([[ 0.157045, -0.399979, -0.050654, 0.236997, -0.313416],
[-0.038369, -0.169516, -0.239528, -0.164231, 0.29698 ],
[ 0.069654, 0.250218, 0.268922, -0.065565, -0.195888],
[-0.117921, 0.185563, 0.005148, 0.006144, 0.130577]])
DESIRED: array([[ 0.157595, -0.401087, -0.051018, 0.23653 , -0.313528],
[-0.041687, -0.169006, -0.243102, -0.16373 , 0.302628],
[ 0.065096, 0.245549, 0.260732, -0.061577, -0.188419],
[-0.117224, 0.184116, 0.004652, 0.005555, 0.130453]])
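To illustrate why the two fits can diverge, here is a minimal NumPy-only sketch of the "balanced" heuristic, n_samples / (n_classes * np.bincount(y)), computed with and without sample weights. The helper `balanced_class_weight` is hypothetical and written only for this illustration. It is not the code path LinearSVC uses, and it assumes integer labels 0..k-1 with every class present.

```python
import numpy as np


def balanced_class_weight(y, sample_weight=None):
    """Hypothetical helper: 'balanced' class weights, optionally weighted.

    With sample_weight=None this is n_samples / (n_classes * np.bincount(y)).
    With weights, sample counts are replaced by sums of sample weights.
    Assumes integer labels 0..k-1 and that every class has nonzero weight.
    """
    y = np.asarray(y)
    if sample_weight is None:
        sample_weight = np.ones(y.shape[0])
    sample_weight = np.asarray(sample_weight, dtype=float)
    per_class = np.bincount(y, weights=sample_weight)
    n_classes = per_class.shape[0]
    return sample_weight.sum() / (n_classes * per_class)


# Toy data: class 1 has a single sample, but it carries weight 3
y = np.array([0, 0, 0, 1])
sample_weight = np.array([1, 1, 1, 3])

# Class weights computed from raw counts, ignoring sample_weight
cw_ignoring_weights = balanced_class_weight(y)

# Class weights computed from weighted counts
cw_using_weights = balanced_class_weight(y, sample_weight)

# Class weights computed on the dataset with explicit repetitions
cw_on_repeated = balanced_class_weight(np.repeat(y, sample_weight))

print(cw_ignoring_weights, cw_using_weights, cw_on_repeated)
```

In this toy case the weighted computation agrees with the repeated dataset, while the unweighted one does not. If the class weights are derived from unweighted counts, the weighted fit and the repeated fit end up optimizing different objectives, which would be consistent with the mismatch reported above.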
Versions
System:
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
machine: macOS-14.3-arm64-arm-64bit
Python dependencies:
sklearn: 1.6.dev0
pip: 24.0
setuptools: 70.1.1
numpy: 2.0.0
scipy: 1.14.0
Cython: 3.0.10
pandas: 2.2.2
matplotlib: 3.9.0
joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
...
num_threads: 8
prefix: libomp
filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
version: None