Skip to content

Commit d06bad1

Browse files
committed
only warn for multi_class default if the problem is not binary
1 parent 85f0ecc commit d06bad1

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

sklearn/linear_model/logistic.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def hessp(v):
424424
return grad, hessp
425425

426426

427-
def _check_solver_option(solver, multi_class, penalty, dual):
427+
def _check_solver_option(solver, multi_class, penalty, dual, classes):
428428

429429
# Default values raises a future warning
430430
if solver == 'warn':
@@ -435,9 +435,11 @@ def _check_solver_option(solver, multi_class, penalty, dual):
435435

436436
if multi_class == 'warn':
437437
multi_class = 'ovr'
438-
warnings.warn("Default multi_class will be changed to 'multinomial' in"
439-
" 0.22. Use a specific option to silence this warning.",
440-
FutureWarning)
438+
if len(classes) > 2: # only warn if the problem is not binary
439+
warnings.warn(
440+
"Default multi_class will be changed to 'multinomial' in 0.22."
441+
" Use a specific option to silence this warning.",
442+
FutureWarning)
441443

442444
# Check the string parameters
443445
if multi_class not in ['multinomial', 'ovr']:
@@ -613,20 +615,21 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
613615
if isinstance(Cs, numbers.Integral):
614616
Cs = np.logspace(-4, 4, Cs)
615617

616-
solver, multi_class = _check_solver_option(
617-
solver, multi_class, penalty, dual)
618-
619618
# Preprocessing.
620619
if check_input:
620+
accept_large_sparse = solver not in ['liblinear', 'warn']
621621
X = check_array(X, accept_sparse='csr', dtype=np.float64,
622-
accept_large_sparse=solver != 'liblinear')
622+
accept_large_sparse=accept_large_sparse)
623623
y = check_array(y, ensure_2d=False, dtype=None)
624624
check_consistent_length(X, y)
625625
_, n_features = X.shape
626626

627627
classes = np.unique(y)
628628
random_state = check_random_state(random_state)
629629

630+
solver, multi_class = _check_solver_option(
631+
solver, multi_class, penalty, dual, classes=classes)
632+
630633
if pos_class is None and multi_class != 'multinomial':
631634
if (classes.size > 2):
632635
raise ValueError('To fit OvR, use the pos_class argument')
@@ -927,7 +930,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
927930
Actual number of iteration for each Cs.
928931
"""
929932
solver, multi_class = _check_solver_option(
930-
solver, multi_class, penalty, dual)
933+
solver, multi_class, penalty, dual, classes=np.unique(y))
931934

932935
X_train = X[train]
933936
X_test = X[test]
@@ -1247,20 +1250,22 @@ def fit(self, X, y, sample_weight=None):
12471250
raise ValueError("Tolerance for stopping criteria must be "
12481251
"positive; got (tol=%r)" % self.tol)
12491252

1250-
solver, multi_class = _check_solver_option(
1251-
self.solver, self.multi_class, self.penalty, self.dual)
1252-
1253-
if solver in ['newton-cg']:
1253+
if self.solver in ['newton-cg']:
12541254
_dtype = [np.float64, np.float32]
12551255
else:
12561256
_dtype = np.float64
12571257

1258+
accept_large_sparse = self.solver not in ['liblinear', 'warn']
12581259
X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
1259-
accept_large_sparse=solver != 'liblinear')
1260+
accept_large_sparse=accept_large_sparse)
12601261
check_classification_targets(y)
12611262
self.classes_ = np.unique(y)
12621263
n_samples, n_features = X.shape
12631264

1265+
solver, multi_class = _check_solver_option(
1266+
self.solver, self.multi_class, self.penalty, self.dual,
1267+
classes=self.classes_)
1268+
12641269
if solver == 'liblinear':
12651270
if self.n_jobs != 1:
12661271
warnings.warn("'n_jobs' > 1 does not have any effect when"
@@ -1649,21 +1654,22 @@ def fit(self, X, y, sample_weight=None):
16491654
-------
16501655
self : object
16511656
"""
1652-
solver, multi_class = _check_solver_option(
1653-
self.solver, self.multi_class, self.penalty, self.dual)
1654-
16551657
if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
16561658
raise ValueError("Maximum number of iteration must be positive;"
16571659
" got (max_iter=%r)" % self.max_iter)
16581660
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
16591661
raise ValueError("Tolerance for stopping criteria must be "
16601662
"positive; got (tol=%r)" % self.tol)
16611663

1664+
accept_large_sparse = self.solver not in ['liblinear', 'warn']
16621665
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
1663-
order="C",
1664-
accept_large_sparse=solver != 'liblinear')
1666+
order="C", accept_large_sparse=accept_large_sparse)
16651667
check_classification_targets(y)
16661668

1669+
solver, multi_class = _check_solver_option(
1670+
self.solver, self.multi_class, self.penalty, self.dual,
1671+
classes=np.unique(y))
1672+
16671673
class_weight = self.class_weight
16681674

16691675
# Encode for string labels

0 commit comments

Comments
 (0)