
Conversation

@ogrisel
Collaborator

@ogrisel ogrisel commented Jun 20, 2022

Fixes scikit-learn#23670.

This is a PR to scikit-learn#23619 (it still needs to be updated to target main instead) to make sure that GLMs with fit_intercept=True and the default LBFGS solver converge to the minimum norm solution when alpha=0.

The main change is to initialize the intercept at zero (or very close to zero for GLMs with an identity link) instead of the previous "smart" init of the intercept, which was set to:

coef[-1] = linear_loss.base_loss.link.link(
    np.average(y, weights=sample_weight)
)
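For contrast, here is a minimal sketch of the near-zero init this PR switches to (array shape and variable names are illustrative; the 64 * eps nudge is the one discussed further down in the review):

import numpy as np

# Sketch of the new init (illustrative shape): start the intercept essentially
# at zero, nudged slightly away from exact zero so that inverse link functions
# (e.g. the identity link for Tweedie losses) can still produce a valid first
# prediction.
n_features = 10
loss_dtype = np.float64
coef = np.zeros(n_features + 1, dtype=loss_dtype)  # [w_1, ..., w_p, intercept]
coef[-1] = 64 * np.finfo(loss_dtype).eps           # ~1.4e-14 instead of link(mean(y))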

This PR also updates the tolerances in some tests and adds a few hacks to keep the model converging in all cases, which turns out to be harder with the zero-init intercept than with the previous "smart" init.

Note

I have not yet benchmarked whether using a zero init for the intercept has any impact on the convergence speed on a non-toy dataset.

Note 2

test_ridge.py has a similar problem:

https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/tests/test_ridge.py#L317-L322

and I believe the solution would probably be similar. However, the Ridge class has many different solvers, so I would rather not start investigating there before getting a good understanding of the impact on convergence speed and numerical stability for GLMs with the LBFGS solver.
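For reference, a small sketch of what "minimum norm solution" means for an unpenalized, underdetermined problem. Plain least squares is used here purely as an illustration (it is not the GLM objective under discussion):

import numpy as np

# Illustration only: for a wide, unpenalized least squares problem,
# np.linalg.lstsq returns the interpolating solution with the smallest
# L2 norm, which is the kind of reference the alpha=0 tests compare against.
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 25))  # n_samples < n_features
y = rng.standard_normal(10)
coef_min_norm, *_ = np.linalg.lstsq(X, y, rcond=None)
assert np.allclose(X @ coef_min_norm, y)  # it interpolates the training data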

assert_allclose(model.coef_, np.r_[coef, coef], rtol=rtol)
atol = 1e-6
assert model.intercept_ == pytest.approx(intercept, rel=rtol, abs=atol)
assert_allclose(model.coef_, np.r_[coef, coef], rtol=rtol, atol=atol)
Collaborator Author

@ogrisel ogrisel Jun 20, 2022

Note that I had to add quite a large atol for this test to pass with the new init. Otherwise the test would fail for a few random seeds, either with wide or long datasets, for various distribution families.

I am not very happy about this but I am not sure what I can do about it.

# Possible causes: 1 error in function or gradient evaluation;
# 2 rounding error dominate computation.
pytest.xfail()

Collaborator Author

This would no longer fail, so I removed the xfail. Not sure if it's related to the changes in this PR or if it is a consequence of the increased MAXLS setting in 79ec862.

Owner

I think it's because of this PR.

@ogrisel
Collaborator Author

ogrisel commented Jun 20, 2022

This is not ready to merge: I got the following failures on my local run with all random seeds:

Details
___________________________________ test_glm_regression_vstacked_X[wide-GammaRegressor()-53-True-lbfgs] ____________________________________

solver = 'lbfgs', fit_intercept = True
glm_dataset = (GammaRegressor(), array([[-0.39615967,  0.09731388,  0.10966576, -0.27396143,  0.25232577,
        -0.27279365, -0.08...-0.50710729,  0.10123481,
       -0.50007489, -0.19223934,  0.69307535,  0.15496709,  0.23387751,
        0.26531359]))

>   ???
E   AssertionError: 
E   Not equal to tolerance rtol=3e-05, atol=0
E   
E   Mismatched elements: 1 / 11 (9.09%)
E   Max absolute difference: 1.25760229e-08
E   Max relative difference: 3.10611783e-05
E    x: array([-0.038494, -0.016638, -0.06312 , -0.056001, -0.003116,  0.000188,
E          -0.002503,  0.173183,  0.055399, -0.039526,  0.068564])
E    y: array([-0.038494, -0.016638, -0.06312 , -0.056001, -0.003116,  0.000188,
E          -0.002503,  0.173183,  0.055399, -0.039526,  0.068564])

sklearn/linear_model/_glm/tests/test_glm.py:329: AssertionError
____________________________ test_glm_regression_unpenalized_hstacked_X[wide-PoissonRegressor()-27-False-lbfgs] ____________________________

solver = 'lbfgs', fit_intercept = False
glm_dataset = (PoissonRegressor(), array([[ 0.00963101,  0.16101125, -0.1674972 , -0.30527104, -0.38960182,
         0.35346062, -0.... 0.54125044, -0.13846211,
       -0.03145258, -0.06858661,  0.23625749, -0.16402337,  0.03745139,
       -0.14795716]))

>   ???
E   AssertionError: 
E   Not equal to tolerance rtol=5e-05, atol=0
E   
E   Mismatched elements: 2 / 24 (8.33%)
E   Max absolute difference: 5.65458174e-08
E   Max relative difference: 5.05820463e-05
E    x: array([-3.048566e-02, -8.716149e-01, -2.109112e+00,  2.561165e+00,
E           9.263738e-04, -1.270580e+00,  1.649646e-03,  1.151003e+00,
E          -1.374081e+00, -4.612122e-01, -2.020498e+00,  9.284053e-02,...
E    y: array([-3.048565e-02, -8.716149e-01, -2.109111e+00,  2.561165e+00,
E           9.263269e-04, -1.270580e+00,  1.649629e-03,  1.151003e+00,
E          -1.374081e+00, -4.612122e-01, -2.020498e+00,  9.284052e-02,...

On top of these tolerance adjustments, I also observed weirder failures in other tests:

Details
______________________________________________________ test_tweedie_score[identity-1] ______________________________________________________

regression_data = (array([[-0.20198874, -0.29717478,  0.38599808, ...,  1.19451016,
         0.33945977,  1.00147133],
       [-0.875896...       307.93198786,  -91.18215801,   10.408502  , -154.57681977,
        105.70151776,   22.8229874 , -264.93461075]))
power = 1, link = 'identity'

    @pytest.mark.parametrize("link", ["log", "identity"])
    def test_tweedie_score(regression_data, power, link):
        """Test that GLM score equals d2_tweedie_score for Tweedie losses."""
        X, y = regression_data
        # make y positive
        y = np.abs(y) + 1.0
        glm = TweedieRegressor(power=power, link=link).fit(X, y)
>       assert glm.score(X, y) == pytest.approx(
            d2_tweedie_score(y, glm.predict(X), power=power)
        )

sklearn/linear_model/_glm/tests/test_glm.py:940: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
sklearn/linear_model/_glm/glm.py:278: in fit
    func = linear_loss.loss_gradient
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_minimize.py:681: in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py:308: in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_optimize.py:263: in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:158: in __init__
    self._update_fun()
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:251: in _update_fun
    self._update_fun_impl()
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:155: in update_fun
    self.f = fun_wrapped(self.x)
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:137: in fun_wrapped
    fx = fun(np.copy(x), *args)
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_optimize.py:76: in __call__
    self._compute_if_needed(x, *args)
../../mambaforge/envs/dev/lib/python3.10/site-packages/scipy/optimize/_optimize.py:70: in _compute_if_needed
    fg = self.fun(x, *args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <sklearn.linear_model._linear_loss.LinearModelLoss object at 0x7fb845181cf0>
coef = array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
X = array([[-0.20198874, -0.29717478,  0.38599808, ...,  1.19451016,
         0.33945977,  1.00147133],
       [-0.8758963...548, -0.18325657],
       [-0.05626683, -1.05795222, -0.41675785, ..., -0.84174737,
         0.50288142, -1.79343559]])
y = array([114.25188979, 102.29594843,  22.94623126,  47.15897115,
       274.64132622,  24.76862888,  19.15994658, 223.09...25504,
       308.93198786,  92.18215801,  11.408502  , 155.57681977,
       106.70151776,  23.8229874 , 265.93461075])
sample_weight = array([0.00934579, 0.00934579, 0.00934579, 0.00934579, 0.00934579,
       0.00934579, 0.00934579, 0.00934579, 0.009345...0934579, 0.00934579,
       0.00934579, 0.00934579, 0.00934579, 0.00934579, 0.00934579,
       0.00934579, 0.00934579])
l2_reg_strength = 1.0, n_threads = 8

    def loss_gradient(
        self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1
    ):
        """Computes the sum of loss and gradient w.r.t. coef.
    
        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        y : contiguous array of shape (n_samples,)
            Observed, true target values.
        sample_weight : None or contiguous array of shape (n_samples,), default=None
            Sample weights.
        l2_reg_strength : float, default=0.0
            L2 regularization strength
        n_threads : int, default=1
            Number of OpenMP threads to use.
    
        Returns
        -------
        loss : float
            Sum of losses per sample plus penalty.
    
        gradient : ndarray of shape coef.shape
             The gradient of the loss.
        """
        n_features, n_classes = X.shape[1], self.base_loss.n_classes
        n_dof = n_features + int(self.fit_intercept)
        weights, intercept, raw_prediction = self._w_intercept_raw(coef, X)
    
        loss, grad_per_sample = self.base_loss.loss_gradient(
            y_true=y,
            raw_prediction=raw_prediction,
            sample_weight=sample_weight,
            n_threads=n_threads,
        )
        loss = loss.sum()
    
        if not self.base_loss.is_multiclass:
            loss += 0.5 * l2_reg_strength * (weights @ weights)
            grad = np.empty_like(coef, dtype=weights.dtype)
>           grad[:n_features] = X.T @ grad_per_sample + l2_reg_strength * weights
E           RuntimeWarning: invalid value encountered in matmul

@ogrisel
Collaborator Author

ogrisel commented Jun 20, 2022

Zero init intercept causes test_tweedie_score to fail with grad_per_sample being equal to an array of all -inf values...

@lorentzenchr
Owner

Zero init intercept causes test_tweedie_score to fail with grad_per_sample being equal to an array of all -inf values...

For Tweedie with power>=1 (e.g. Poisson) and the identity link, we must initialize away from zero, otherwise we get y_pred = identity(raw_prediction) = 0, which is not allowed, i.e. it produces inf.
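A tiny illustration of that failure mode, using just the math of the half Poisson deviance (this is not the scikit-learn implementation):

import numpy as np

# Half Poisson deviance: loss = y_pred - y * log(y_pred). With the identity
# link, the per-sample gradient w.r.t. the raw prediction is 1 - y / y_pred,
# which is non-finite when the all-zero init gives y_pred = 0.
y_true = np.array([1.0, 2.0, 3.0])
raw_prediction = np.zeros_like(y_true)  # zero init with the identity link
with np.errstate(divide="ignore", invalid="ignore"):
    grad_per_sample = 1.0 - y_true / raw_prediction
print(grad_per_sample)  # [-inf -inf -inf]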

@ogrisel
Collaborator Author

ogrisel commented Jun 21, 2022

OK, I pushed a hack to make all tests pass (I tried with the default 42 random seed and am currently waiting for 'all' to complete).

I had to do several hacks:

  • use gradient clipping so that test_tweedie_score with power=0 and the log link does not run into infinite gradients;
  • increase max_iter so that test_tweedie_score with power>=2 and the identity link passes without a convergence warning (see the sketch below).
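A sketch of that second workaround with illustrative values (the exact max_iter used in the PR is not reproduced here):

from sklearn.linear_model import TweedieRegressor

# Illustrative values only: give LBFGS more iterations when the optimization
# starts from a (near-)zero intercept with the identity link.
model = TweedieRegressor(power=2, link="identity", max_iter=1_000)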

All in all, I am not sure it's worth it. The smart intercept init in main seems to make the optimization problem much more numerically stable for the first few iterations, even if it changes the inductive bias and prevents convergence to the minimum norm solution for unpenalized models.

@ogrisel
Collaborator Author

ogrisel commented Jun 21, 2022

It took 27 min but:

SKLEARN_TESTS_GLOBAL_RANDOM_SEED='all' pytest -Werror -vs sklearn/linear_model/_glm/tests/test_glm.py

completed successfully on this PR.

@ogrisel
Collaborator Author

ogrisel commented Jun 21, 2022

TODO next: quantify impact on convergence speed on realistic datasets.

# infinite gradients
CLIP = 1e100
np.clip(grad_per_sample, -CLIP, CLIP, out=grad_per_sample)
np.clip(loss, None, CLIP, out=loss)
Collaborator Author

This might not be needed. We might just want to silence the RuntimeWarning if we know that the code is robust to inf values, as suggested in scikit-learn#23314 (comment).
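For instance, a hypothetical sketch of that alternative (assumed variable names, not code from this PR): suppress the floating point warning locally and let the LBFGS line search back off on its own when the gradient is non-finite.

import numpy as np

# Hypothetical sketch: instead of clipping, ignore the "invalid value
# encountered in matmul" RuntimeWarning triggered by non-finite per-sample
# gradients and propagate the non-finite values to the optimizer.
X = np.array([[0.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
grad_per_sample = np.array([-np.inf, -np.inf, -np.inf])
with np.errstate(invalid="ignore"):
    grad = X.T @ grad_per_sample  # no warning; result stays non-finite
print(grad)  # [nan, -inf]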

Owner

The comment means that the line search handles inf correctly. It does not mean that the result will become finite again if all gradient elements are inf.

If possible, we should get rid of these clips.

Owner

@lorentzenchr lorentzenchr left a comment

@ogrisel Thanks for investigating this.

# link function (in particular for Tweedie with p>=1
# with identity link) to be able to compute the first
# gradient.
coef[-1] = 64 * np.finfo(loss_dtype).eps
Owner

The magic 64 again 😉

Collaborator Author

I think it works for a variety of small eps but getting too close to the minimum eps would make convergence less numerically stable for some combinations of family / link.
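For a sense of scale (assuming float64):

import numpy as np

# Magnitude of the intercept nudge for float64: far below any realistic
# intercept value, but safely away from exact zero.
print(64 * np.finfo(np.float64).eps)  # ~1.42e-14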


Comment on lines +378 to +379
# XXX: Do we have any theoretical guarantees why this should be the
# case?
Owner

There are papers about gradient descent and minimum norm.

Collaborator Author

Do they cover LBFGS, line search and the combination of stopping criteria we use?

Collaborator Author

@ogrisel ogrisel Jun 23, 2022

Not a direct answer to the question above, but the following quite recent paper seems relevant:

https://opt-ml.org/papers/2020/paper_33.pdf

In particular, this paper is interesting because it also investigates the case of under-parameterized linear classifiers trained on linearly separable data.

Maybe @genji would like to comment on this PR?

Collaborator Author

I think we should conduct some experiments to see the impact of the intercept init on the test accuracy.

We could also explore doing the smart intercept init and then re-projecting onto the data span as done in the paper above (not sure if it makes sense for all GLMs).
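A minimal sketch of that re-projection idea, as an illustration only (assumed shapes; not code from this PR or from the paper): project the coefficient vector onto the row space of X, which is the subspace where the minimum norm solution lives.

import numpy as np

# Illustration only: orthogonal projection of a coefficient vector onto the
# row space of a wide design matrix X, via the pseudo-inverse (P = pinv(X) @ X).
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 11))  # n_samples < n_features
coef = rng.standard_normal(11)    # e.g. coming from a "smart" init
P = np.linalg.pinv(X) @ X         # projector onto the span of the rows of X
coef_projected = P @ coef
assert np.allclose(X @ coef_projected, X @ coef)  # same predictions on X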


@lorentzenchr
Owner

Giving it a lot of thought: If it's too hard to reach the minimum norm solution for GLMs, then it's maybe not worth it.

@ogrisel
Collaborator Author

ogrisel commented Jun 23, 2022

Giving it a lot of thought: If it's too hard to reach the minimum norm solution for GLMs, then it's maybe not worth it.

I agree. However, I think it's still interesting to understand whether gradient clipping is needed, because the same convergence problem can happen with fit_intercept=False. At the moment, the tests pass in main only because test_tweedie_score does not cover the fit_intercept=False case.

@ogrisel
Collaborator Author

ogrisel commented Jun 30, 2022

Next steps:

  • rework this PR to target main;
  • conduct some evaluation of the convergence speed on various realistic datasets;
  • evaluate the impact on generalization performance (e.g. test deviance).

@lorentzenchr
Owner

I'm closing this since we can now open PRs directly against scikit-learn main.
