20 changes: 9 additions & 11 deletions sklearn/gaussian_process/gpc.py
@@ -4,22 +4,21 @@
#
# License: BSD 3 clause

import warnings
from operator import itemgetter

import numpy as np
from scipy.linalg import cholesky, cho_solve, solve
from scipy.optimize import fmin_l_bfgs_b
import scipy.optimize
from scipy.special import erf, expit

from ..base import BaseEstimator, ClassifierMixin, clone
from .kernels \
import RBF, CompoundKernel, ConstantKernel as C
from ..utils.validation import check_X_y, check_is_fitted, check_array
from ..utils import check_random_state
from ..utils.optimize import _check_optimize_result
from ..preprocessing import LabelEncoder
from ..multiclass import OneVsRestClassifier, OneVsOneClassifier
from ..exceptions import ConvergenceWarning


# Values required for approximating the logistic sigmoid by
@@ -74,7 +73,7 @@ def optimizer(obj_func, initial_theta, bounds):
# the corresponding value of the target function.
return theta_opt, func_min

Per default, the 'fmin_l_bfgs_b' algorithm from scipy.optimize
Per default, the 'L-BFGS-B' algorithm from scipy.optimize.minimize
is used. If None is passed, the kernel's parameters are kept fixed.
Available internal optimizers are::

Member

Should we then deprecate fmin_l_bfgs_b as the input value in favor of L-BFGS-B, or lbfgs?

Member Author

Yes, I think that would be in order; along the same lines, I'd like to look into allowing other scipy optimizers in the Gaussian processes. I would rather do that in a follow-up PR, though, and keep this one a minimal refactoring that doesn't affect backward compatibility.

Member

I agree allowing other optimizers would be a different PR, but since you're touching the docstring here, it makes sense for the accepted value to be the same or similar to what you mention in the docstring, I think.

But if you wanna do the deprecation in a different PR, I'm happy with that as well, and then this LGTM.

Member Author

I haven't changed the docstring value here; I only changed it to say that the algorithm is called L-BFGS-B, not fmin_l_bfgs_b. That fix would apply even before this PR, since fmin_l_bfgs_b is not an algorithm name: it's the name of the scipy function that wraps this optimizer.

Will do the deprecation in a follow-up PR :) Thanks for the review!
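
For context, here is a minimal sketch of a custom optimizer callable matching the signature documented above. It is illustrative only (the name my_optimizer and the toy usage are assumptions, not part of this PR), but it shows why jac=True is the right setting: obj_func(theta) returns both the objective value and its gradient.

import scipy.optimize
from sklearn.gaussian_process import GaussianProcessRegressor


def my_optimizer(obj_func, initial_theta, bounds):
    # obj_func(theta) returns (objective, gradient), so jac=True lets
    # scipy.optimize.minimize use the analytic gradient directly.
    opt_res = scipy.optimize.minimize(
        obj_func, initial_theta, method="L-BFGS-B", jac=True, bounds=bounds)
    return opt_res.x, opt_res.fun


gpr = GaussianProcessRegressor(optimizer=my_optimizer)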

@@ -426,12 +425,11 @@ def _posterior_mode(self, K, return_temporaries=False):

def _constrained_optimization(self, obj_func, initial_theta, bounds):
if self.optimizer == "fmin_l_bfgs_b":
theta_opt, func_min, convergence_dict = \
fmin_l_bfgs_b(obj_func, initial_theta, bounds=bounds)
if convergence_dict["warnflag"] != 0:
warnings.warn("fmin_l_bfgs_b terminated abnormally with the "
" state: %s" % convergence_dict,
ConvergenceWarning)
opt_res = scipy.optimize.minimize(
obj_func, initial_theta, method="L-BFGS-B", jac=True,
bounds=bounds)
_check_optimize_result("lbfgs", opt_res)
theta_opt, func_min = opt_res.x, opt_res.fun
elif callable(self.optimizer):
theta_opt, func_min = \
self.optimizer(obj_func, initial_theta, bounds=bounds)
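
As a side-by-side sketch of the replacement this hunk performs (toy objective and variable names are assumptions, not from the PR): the tuple returned by fmin_l_bfgs_b maps onto fields of the OptimizeResult returned by scipy.optimize.minimize.

import numpy as np
import scipy.optimize
from scipy.optimize import fmin_l_bfgs_b


def obj_func(theta):
    # Toy stand-in for the GP objective with eval_gradient=True:
    # returns (value, gradient).
    return np.sum(theta ** 2), 2 * theta


initial_theta = np.array([1.0, -2.0])
bounds = [(-10.0, 10.0), (-10.0, 10.0)]

# Old call: returns (x, f, info); convergence status lives in info["warnflag"].
theta_old, f_old, info = fmin_l_bfgs_b(obj_func, initial_theta, bounds=bounds)

# New call: returns an OptimizeResult; info["warnflag"] corresponds to
# opt_res.status, info["nit"] to opt_res.nit, and (x, f) to
# (opt_res.x, opt_res.fun).
opt_res = scipy.optimize.minimize(
    obj_func, initial_theta, method="L-BFGS-B", jac=True, bounds=bounds)
theta_opt, func_min = opt_res.x, opt_res.fun

Centralizing the status check in _check_optimize_result means each estimator no longer has to interpret these fields itself.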
@@ -482,7 +480,7 @@ def optimizer(obj_func, initial_theta, bounds):
# the corresponding value of the target function.
return theta_opt, func_min

Per default, the 'fmin_l_bfgs_b' algorithm from scipy.optimize
Per default, the 'L-BFGS-B' algorithm from scipy.optimize.minimize
is used. If None is passed, the kernel's parameters are kept fixed.
Available internal optimizers are::

17 changes: 8 additions & 9 deletions sklearn/gaussian_process/gpr.py
@@ -9,14 +9,14 @@

import numpy as np
from scipy.linalg import cholesky, cho_solve, solve_triangular
from scipy.optimize import fmin_l_bfgs_b
import scipy.optimize

from ..base import BaseEstimator, RegressorMixin, clone
from ..base import MultiOutputMixin
from .kernels import RBF, ConstantKernel as C
from ..utils import check_random_state
from ..utils.validation import check_X_y, check_array
from ..exceptions import ConvergenceWarning
from ..utils.optimize import _check_optimize_result


class GaussianProcessRegressor(BaseEstimator, RegressorMixin,
@@ -77,7 +77,7 @@ def optimizer(obj_func, initial_theta, bounds):
# the corresponding value of the target function.
return theta_opt, func_min

Per default, the 'fmin_l_bfgs_b' algorithm from scipy.optimize
Per default, the 'L-BFGS-B' algorithm from scipy.optimize.minimize
is used. If None is passed, the kernel's parameters are kept fixed.
Available internal optimizers are::

@@ -461,12 +461,11 @@ def log_marginal_likelihood(self, theta=None, eval_gradient=False):

def _constrained_optimization(self, obj_func, initial_theta, bounds):
if self.optimizer == "fmin_l_bfgs_b":
theta_opt, func_min, convergence_dict = \
fmin_l_bfgs_b(obj_func, initial_theta, bounds=bounds)
if convergence_dict["warnflag"] != 0:
warnings.warn("fmin_l_bfgs_b terminated abnormally with the "
" state: %s" % convergence_dict,
ConvergenceWarning)
opt_res = scipy.optimize.minimize(
obj_func, initial_theta, method="L-BFGS-B", jac=True,
bounds=bounds)
_check_optimize_result("lbfgs", opt_res)
theta_opt, func_min = opt_res.x, opt_res.fun
elif callable(self.optimizer):
theta_opt, func_min = \
self.optimizer(obj_func, initial_theta, bounds=bounds)
27 changes: 15 additions & 12 deletions sklearn/linear_model/huber.py
@@ -11,6 +11,7 @@
from ..utils import check_consistent_length
from ..utils import axis0_safe_slice
from ..utils.extmath import safe_sparse_dot
from ..utils.optimize import _check_optimize_result


def _huber_loss_and_gradient(w, X, y, epsilon, alpha, sample_weight=None):
@@ -147,8 +148,8 @@ class HuberRegressor(LinearModel, RegressorMixin, BaseEstimator):
to outliers.

max_iter : int, default 100
Maximum number of iterations that scipy.optimize.fmin_l_bfgs_b
should run for.
Maximum number of iterations that
``scipy.optimize.minimize(method="L-BFGS-B")`` should run for.

alpha : float, default 0.0001
Regularization parameter.
@@ -180,7 +181,8 @@ class HuberRegressor(LinearModel, RegressorMixin, BaseEstimator):
The value by which ``|y - X'w - c|`` is scaled down.

n_iter_ : int
Number of iterations that fmin_l_bfgs_b has run for.
Number of iterations that
``scipy.optimize.minimize(method="L-BFGS-B")`` has run for.

.. versionchanged:: 0.20

@@ -282,18 +284,19 @@ def fit(self, X, y, sample_weight=None):
bounds = np.tile([-np.inf, np.inf], (parameters.shape[0], 1))
bounds[-1][0] = np.finfo(np.float64).eps * 10

parameters, f, dict_ = optimize.fmin_l_bfgs_b(
_huber_loss_and_gradient, parameters,
opt_res = optimize.minimize(
_huber_loss_and_gradient, parameters, method="L-BFGS-B", jac=True,
args=(X, y, self.epsilon, self.alpha, sample_weight),
maxiter=self.max_iter, pgtol=self.tol, bounds=bounds,
iprint=0)
if dict_['warnflag'] == 2:
options={"maxiter": self.max_iter, "gtol": self.tol, "iprint": -1},
bounds=bounds)

parameters = opt_res.x

if opt_res.status == 2:
raise ValueError("HuberRegressor convergence failed:"
" l-BFGS-b solver terminated with %s"
% dict_['task'].decode('ascii'))
# In scipy <= 1.0.0, nit may exceed maxiter.
# See https://github.com/scipy/scipy/issues/7854.
self.n_iter_ = min(dict_['nit'], self.max_iter)
% opt_res.message)
self.n_iter_ = _check_optimize_result("lbfgs", opt_res, self.max_iter)
self.scale_ = parameters[-1]
if self.fit_intercept:
self.intercept_ = parameters[-2]
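
The keyword remapping above follows the documented correspondence between fmin_l_bfgs_b arguments and minimize(method="L-BFGS-B") options: pgtol -> options["gtol"], maxiter -> options["maxiter"], iprint -> options["iprint"]. A minimal sketch, with a toy loss standing in for _huber_loss_and_gradient (an assumption for illustration only):

import numpy as np
import scipy.optimize


def loss_and_grad(w):
    # Toy stand-in for _huber_loss_and_gradient: returns (loss, gradient).
    return np.sum((w - 1.0) ** 2), 2.0 * (w - 1.0)


w0 = np.zeros(3)
opt_res = scipy.optimize.minimize(
    loss_and_grad, w0, method="L-BFGS-B", jac=True,
    options={"maxiter": 100, "gtol": 1e-5, "iprint": -1})
print(opt_res.status, opt_res.nit, opt_res.x)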
22 changes: 10 additions & 12 deletions sklearn/linear_model/logistic.py
@@ -28,11 +28,11 @@
squared_norm)
from ..utils.extmath import row_norms
from ..utils.fixes import logsumexp
from ..utils.optimize import newton_cg
from ..utils.optimize import newton_cg, _check_optimize_result
from ..utils.validation import check_X_y
from ..utils.validation import check_is_fitted
from ..utils import deprecated
from ..exceptions import (ConvergenceWarning, ChangedBehaviorWarning)
from ..exceptions import ChangedBehaviorWarning
from ..utils.multiclass import check_classification_targets
from ..utils.fixes import _joblib_parallel_args
from ..model_selection import check_cv
@@ -899,7 +899,8 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
w0[:, :coef.shape[1]] = coef

if multi_class == 'multinomial':
# fmin_l_bfgs_b and newton-cg accepts only ravelled parameters.
# scipy.optimize.minimize and newton-cg accepts only
# ravelled parameters.
if solver in ['lbfgs', 'newton-cg']:
w0 = w0.ravel()
target = Y_multi
@@ -926,16 +927,13 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
if solver == 'lbfgs':
iprint = [-1, 50, 1, 100, 101][
np.searchsorted(np.array([0, 1, 2, 3]), verbose)]
w0, loss, info = optimize.fmin_l_bfgs_b(
func, w0, fprime=None,
opt_res = optimize.minimize(
func, w0, method="L-BFGS-B", jac=True,
args=(X, target, 1. / C, sample_weight),
iprint=iprint, pgtol=tol, maxiter=max_iter)
if info["warnflag"] == 1:
warnings.warn("lbfgs failed to converge. Increase the number "
"of iterations.", ConvergenceWarning)
# In scipy <= 1.0.0, nit may exceed maxiter.
# See https://github.com/scipy/scipy/issues/7854.
n_iter_i = min(info['nit'], max_iter)
options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}
)
n_iter_i = _check_optimize_result(solver, opt_res, max_iter)
w0, loss = opt_res.x, opt_res.fun
elif solver == 'newton-cg':
args = (X, target, 1. / C, sample_weight)
w0, n_iter_i = newton_cg(hess, func, grad, w0, args=args,
45 changes: 16 additions & 29 deletions sklearn/neural_network/multilayer_perceptron.py
@@ -9,9 +9,10 @@
import numpy as np

from abc import ABCMeta, abstractmethod
from scipy.optimize import fmin_l_bfgs_b
import warnings

import scipy.optimize

from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ..base import is_classifier
from ._base import ACTIVATIONS, DERIVATIVES, LOSS_FUNCTIONS
@@ -26,6 +27,7 @@
from ..utils.validation import check_is_fitted
from ..utils.multiclass import _check_partial_fit_first_call, unique_labels
from ..utils.multiclass import type_of_target
from ..utils.optimize import _check_optimize_result


_STOCHASTIC_SOLVERS = ['sgd', 'adam']
@@ -458,34 +460,19 @@ def _fit_lbfgs(self, X, y, activations, deltas, coef_grads,
else:
iprint = -1

optimal_parameters, self.loss_, d = fmin_l_bfgs_b(
x0=packed_coef_inter,
func=self._loss_grad_lbfgs,
maxfun=self.max_fun,
maxiter=self.max_iter,
iprint=iprint,
pgtol=self.tol,
args=(X, y, activations, deltas, coef_grads, intercept_grads))
self.n_iter_ = d['nit']
if d['warnflag'] == 1:
if d['nit'] >= self.max_iter:
warnings.warn(
"LBFGS Optimizer: Maximum iterations (%d) "
"reached and the optimization hasn't converged yet."
% self.max_iter, ConvergenceWarning)
if d['funcalls'] >= self.max_fun:
warnings.warn(
"LBFGS Optimizer: Maximum function evaluations (%d) "
"reached and the optimization hasn't converged yet."
% self.max_fun, ConvergenceWarning)
elif d['warnflag'] == 2:
warnings.warn(
"LBFGS Optimizer: Optimization hasn't converged yet, "
"cause of LBFGS stopping: %s."
% d['task'], ConvergenceWarning)


self._unpack(optimal_parameters)
opt_res = scipy.optimize.minimize(
self._loss_grad_lbfgs, packed_coef_inter,
method="L-BFGS-B", jac=True,
options={
"maxfun": self.max_fun,
"maxiter": self.max_iter,
"iprint": iprint,
"gtol": self.tol
},
args=(X, y, activations, deltas, coef_grads, intercept_grads))
self.n_iter_ = _check_optimize_result("lbfgs", opt_res, self.max_iter)
self.loss_ = opt_res.fun
self._unpack(opt_res.x)

def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
intercept_grads, layer_units, incremental):
36 changes: 36 additions & 0 deletions sklearn/utils/optimize.py
@@ -202,3 +202,39 @@ def newton_cg(grad_hess, func, grad, x0, args=(), tol=1e-4,
warnings.warn("newton-cg failed to converge. Increase the "
"number of iterations.", ConvergenceWarning)
return xk, k


def _check_optimize_result(solver, result, max_iter=None):
"""Check the OptimizeResult for successful convergence

Parameters
----------
solver: str
solver name. Currently only `lbfgs` is supported.
result: OptimizeResult
result of the scipy.optimize.minimize function
max_iter: {int, None}
expected maximum number of iterations

Returns
-------
n_iter: int
number of iterations
"""
# handle both scipy and scikit-learn solver names
if solver == "lbfgs":
if result.status != 0:
warnings.warn("{} failed to converge (status={}): {}. "
"Increase the number of iterations."
.format(solver, result.status, result.message),
ConvergenceWarning)
if max_iter is not None:
# In scipy <= 1.0.0, nit may exceed maxiter for lbfgs.
# See https://github.com/scipy/scipy/issues/7854
n_iter_i = min(result.nit, max_iter)
else:
n_iter_i = result.nit
else:
raise NotImplementedError

return n_iter_i
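
A minimal usage sketch of the new helper, mirroring the call sites above (the toy objective is an assumption; the import path matches this PR's file layout):

import numpy as np
import scipy.optimize

from sklearn.utils.optimize import _check_optimize_result


def f_and_grad(x):
    # Toy objective returning (value, gradient).
    return np.sum(x ** 2), 2 * x


opt_res = scipy.optimize.minimize(
    f_and_grad, np.ones(2), method="L-BFGS-B", jac=True,
    options={"maxiter": 5})
# Warns with ConvergenceWarning if opt_res.status != 0, and caps the
# reported iteration count at max_iter.
n_iter = _check_optimize_result("lbfgs", opt_res, max_iter=5)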