Skip to content
Merged
6 changes: 5 additions & 1 deletion doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -618,14 +618,18 @@ Changelog
:class:`linear_model.ARDRegression` now preserve float32 dtype. :pr:`9087` by
:user:`Arthur Imbert <Henley13>` and :pr:`22525` by :user:`Meekail Zain <micky774>`.

- |Feature| :class:`ElasticNet`, :class:`ElasticNetCV`, :class:`Lasso` and
:class:`LassoCV` support `sample_weight` for sparse input `X`.
:pr:`22808` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Fix| The `coef_` and `intercept_` attributes of :class:`LinearRegression` are now
correctly computed in the presence of sample weights when the input is sparse.
:pr:`22891` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

- |Fix| The `coef_` and `intercept_` attributes of :class:`Ridge` with
`solver="sparse_cg"` and `solver="lbfgs"` are now correctly computed in the presence
of sample weights when the input is sparse.
:pr:`22899` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
:pr:`22899` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.manifold`
.......................
Expand Down
28 changes: 19 additions & 9 deletions sklearn/linear_model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def _preprocess_data(
normalize=False,
copy=True,
sample_weight=None,
return_mean=False,
check_input=True,
):
"""Center and scale data.
Expand All @@ -231,7 +230,7 @@ def _preprocess_data(

X_scale is the L2 norm of X - X_offset. If sample_weight is not None,
then the weighted mean of X and y is zero, and not the mean itself. If
return_mean=True, the mean, eventually weighted, is returned, independently
fit_intercept=True, the mean, eventually weighted, is returned, independently
of whether X was centered (option used for optimization with sparse data in
coordinate_descend).

Expand Down Expand Up @@ -271,8 +270,6 @@ def _preprocess_data(
if fit_intercept:
if sp.issparse(X):
X_offset, X_var = mean_variance_axis(X, axis=0, weights=sample_weight)
if not return_mean:
X_offset[:] = X.dtype.type(0)
else:
if normalize:
X_offset, X_var, _ = _incremental_mean_and_var(
Expand Down Expand Up @@ -328,7 +325,18 @@ def _preprocess_data(
def _rescale_data(X, y, sample_weight):
"""Rescale data sample-wise by square root of sample_weight.

For many linear models, this enables easy support for sample_weight.
For many linear models, this enables easy support for sample_weight because

(y - X w)' S (y - X w)

with S = diag(sample_weight) becomes

||y_rescaled - X_rescaled w||_2^2

when setting

y_rescaled = sqrt(S) y
X_rescaled = sqrt(S) X

Returns
-------
Expand Down Expand Up @@ -687,7 +695,6 @@ def fit(self, X, y, sample_weight=None):
normalize=_normalize,
copy=self.copy_X,
sample_weight=sample_weight,
return_mean=True,
)

# Sample weight can be implemented via a simple rescaling.
Expand Down Expand Up @@ -824,8 +831,8 @@ def _pre_fit(
fit_intercept=fit_intercept,
normalize=normalize,
copy=False,
return_mean=True,
check_input=check_input,
sample_weight=sample_weight,
)
else:
# copy was done in fit if necessary
Expand All @@ -838,8 +845,11 @@ def _pre_fit(
check_input=check_input,
sample_weight=sample_weight,
)
if sample_weight is not None:
X, y, _ = _rescale_data(X, y, sample_weight=sample_weight)
# Rescale only in dense case. Sparse cd solver directly deals with
# sample_weight.
if sample_weight is not None:
# This triggers copies anyway.
X, y, _ = _rescale_data(X, y, sample_weight=sample_weight)

# FIXME: 'normalize' to be removed in 1.2
if hasattr(precompute, "__array__"):
Expand Down
154 changes: 109 additions & 45 deletions sklearn/linear_model/_cd_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -260,22 +260,45 @@ def enet_coordinate_descent(floating[::1] w,
return w, gap, tol, n_iter + 1


def sparse_enet_coordinate_descent(floating [::1] w,
floating alpha, floating beta,
np.ndarray[floating, ndim=1, mode='c'] X_data,
np.ndarray[int, ndim=1, mode='c'] X_indices,
np.ndarray[int, ndim=1, mode='c'] X_indptr,
np.ndarray[floating, ndim=1] y,
floating[:] X_mean, int max_iter,
floating tol, object rng, bint random=0,
bint positive=0):
def sparse_enet_coordinate_descent(
floating [::1] w,
floating alpha,
floating beta,
np.ndarray[floating, ndim=1, mode='c'] X_data,
np.ndarray[int, ndim=1, mode='c'] X_indices,
np.ndarray[int, ndim=1, mode='c'] X_indptr,
floating[::1] y,
floating[::1] sample_weight,
floating[::1] X_mean,
int max_iter,
floating tol,
object rng,
bint random=0,
bint positive=0,
):
"""Cython version of the coordinate descent algorithm for Elastic-Net

We minimize:

(1/2) * norm(y - X w, 2)^2 + alpha norm(w, 1) + (beta/2) * norm(w, 2)^2
1/2 * norm(y - Z w, 2)^2 + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2

where Z = X - X_mean.
With sample weights sw, this becomes

1/2 * sum(sw * (y - Z w)^2, axis=0) + alpha * norm(w, 1)
+ (beta/2) * norm(w, 2)^2

and X_mean is the weighted average of X (per column).
"""
# Notes for sample_weight:
# For dense X, one centers X and y and then rescales them by sqrt(sample_weight).
# Here, for sparse X, we get the sample_weight averaged center X_mean. We take care
# that every calculation results as if we had rescaled y and X (and therefore also
# X_mean) by sqrt(sample_weight) without actually calculating the square root.
# We work with:
# yw = sample_weight
# R = sample_weight * residual
# norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0)

# get the data information into easy vars
cdef unsigned int n_samples = y.shape[0]
Expand All @@ -289,18 +312,17 @@ def sparse_enet_coordinate_descent(floating [::1] w,
cdef unsigned int endptr

# initial value of the residuals
cdef floating[:] R = y.copy()

cdef floating[:] X_T_R
cdef floating[:] XtA
# R = y - Zw, weighted version R = sample_weight * (y - Zw)
cdef floating[::1] R
cdef floating[::1] XtA
cdef floating[::1] yw

if floating is float:
dtype = np.float32
else:
dtype = np.float64

norm_cols_X = np.zeros(n_features, dtype=dtype)
X_T_R = np.zeros(n_features, dtype=dtype)
XtA = np.zeros(n_features, dtype=dtype)

cdef floating tmp
Expand All @@ -324,6 +346,14 @@ def sparse_enet_coordinate_descent(floating [::1] w,
cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
cdef UINT32_t* rand_r_state = &rand_r_state_seed
cdef bint center = False
cdef bint no_sample_weights = sample_weight is None

if no_sample_weights:
yw = y
R = y.copy()
else:
yw = np.multiply(sample_weight, y)
R = yw.copy()

with nogil:
# center = (X_mean != 0).any()
Expand All @@ -338,19 +368,32 @@ def sparse_enet_coordinate_descent(floating [::1] w,
normalize_sum = 0.0
w_ii = w[ii]

for jj in range(startptr, endptr):
normalize_sum += (X_data[jj] - X_mean_ii) ** 2
R[X_indices[jj]] -= X_data[jj] * w_ii
norm_cols_X[ii] = normalize_sum + \
(n_samples - endptr + startptr) * X_mean_ii ** 2

if center:
for jj in range(n_samples):
R[jj] += X_mean_ii * w_ii
if no_sample_weights:
for jj in range(startptr, endptr):
normalize_sum += (X_data[jj] - X_mean_ii) ** 2
R[X_indices[jj]] -= X_data[jj] * w_ii
norm_cols_X[ii] = normalize_sum + \
(n_samples - endptr + startptr) * X_mean_ii ** 2
if center:
for jj in range(n_samples):
R[jj] += X_mean_ii * w_ii
else:
for jj in range(startptr, endptr):
tmp = sample_weight[X_indices[jj]]
# second term will be subtracted by loop over range(n_samples)
normalize_sum += (tmp * (X_data[jj] - X_mean_ii) ** 2
- tmp * X_mean_ii ** 2)
R[X_indices[jj]] -= tmp * X_data[jj] * w_ii
if center:
for jj in range(n_samples):
normalize_sum += sample_weight[jj] * X_mean_ii ** 2
R[jj] += sample_weight[jj] * X_mean_ii * w_ii
norm_cols_X[ii] = normalize_sum
startptr = endptr

# tol *= np.dot(y, y)
tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
# with sample weights: tol *= y @ (sw * y)
tol *= _dot(n_samples, &y[0], 1, &yw[0], 1)

for n_iter in range(max_iter):

Expand All @@ -373,11 +416,19 @@ def sparse_enet_coordinate_descent(floating [::1] w,

if w_ii != 0.0:
# R += w_ii * X[:,ii]
for jj in range(startptr, endptr):
R[X_indices[jj]] += X_data[jj] * w_ii
if center:
for jj in range(n_samples):
R[jj] -= X_mean_ii * w_ii
if no_sample_weights:
for jj in range(startptr, endptr):
R[X_indices[jj]] += X_data[jj] * w_ii
if center:
for jj in range(n_samples):
R[jj] -= X_mean_ii * w_ii
else:
for jj in range(startptr, endptr):
tmp = sample_weight[X_indices[jj]]
R[X_indices[jj]] += tmp * X_data[jj] * w_ii
if center:
for jj in range(n_samples):
R[jj] -= sample_weight[jj] * X_mean_ii * w_ii

# tmp = (X[:,ii] * R).sum()
tmp = 0.0
Expand All @@ -398,20 +449,25 @@ def sparse_enet_coordinate_descent(floating [::1] w,

if w[ii] != 0.0:
# R -= w[ii] * X[:,ii] # Update residual
for jj in range(startptr, endptr):
R[X_indices[jj]] -= X_data[jj] * w[ii]

if center:
for jj in range(n_samples):
R[jj] += X_mean_ii * w[ii]
if no_sample_weights:
for jj in range(startptr, endptr):
R[X_indices[jj]] -= X_data[jj] * w[ii]
if center:
for jj in range(n_samples):
R[jj] += X_mean_ii * w[ii]
else:
for jj in range(startptr, endptr):
tmp = sample_weight[X_indices[jj]]
R[X_indices[jj]] -= tmp * X_data[jj] * w[ii]
if center:
for jj in range(n_samples):
R[jj] += sample_weight[jj] * X_mean_ii * w[ii]

# update the maximum absolute coefficient update
d_w_ii = fabs(w[ii] - w_ii)
if d_w_ii > d_w_max:
d_w_max = d_w_ii
d_w_max = fmax(d_w_max, d_w_ii)

if fabs(w[ii]) > w_max:
w_max = fabs(w[ii])
w_max = fmax(w_max, fabs(w[ii]))

if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
# the biggest coordinate update of this iteration was smaller than
Expand All @@ -424,22 +480,30 @@ def sparse_enet_coordinate_descent(floating [::1] w,
for jj in range(n_samples):
R_sum += R[jj]

# XtA = X.T @ R - beta * w
for ii in range(n_features):
X_T_R[ii] = 0.0
XtA[ii] = 0.0
for jj in range(X_indptr[ii], X_indptr[ii + 1]):
X_T_R[ii] += X_data[jj] * R[X_indices[jj]]
XtA[ii] += X_data[jj] * R[X_indices[jj]]

if center:
X_T_R[ii] -= X_mean[ii] * R_sum
XtA[ii] = X_T_R[ii] - beta * w[ii]
XtA[ii] -= X_mean[ii] * R_sum
XtA[ii] -= beta * w[ii]

if positive:
dual_norm_XtA = max(n_features, &XtA[0])
else:
dual_norm_XtA = abs_max(n_features, &XtA[0])

# R_norm2 = np.dot(R, R)
R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
if no_sample_weights:
R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
else:
R_norm2 = 0.0
for jj in range(n_samples):
# R is already multiplied by sample_weight
if sample_weight[jj] != 0:
R_norm2 += (R[jj] ** 2) / sample_weight[jj]

# w_norm2 = np.dot(w, w)
w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
Expand Down
Loading