Skip to content

Commit 6ec2c8b

Browse files
committed
Merge pull request #2674 from GaelVaroquaux/bug_ovo_string_y
BUG: OneVsOneClassifier was broken with string labels
2 parents 81336ae + fedff32 commit 6ec2c8b

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ Changelog
8787
- Fixed bug in :class:`linear_model.stochastic_gradient` :
8888
``l1_ratio`` was used as ``(1.0 - l1_ratio)`` .
8989

90+
- Fixed bug in :class:`multiclass.OneVsOneClassifier` with string
91+
labels
92+
9093
API changes summary
9194
-------------------
9295

sklearn/multiclass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,11 @@ def _fit_ovo_binary(estimator, X, y, i, j):
304304
"""Fit a single binary estimator (one-vs-one)."""
305305
cond = np.logical_or(y == i, y == j)
306306
y = y[cond]
307-
y[y == i] = 0
308-
y[y == j] = 1
307+
y_binary = np.empty(y.shape, np.int)
308+
y_binary[y == i] = 0
309+
y_binary[y == j] = 1
309310
ind = np.arange(X.shape[0])
310-
return _fit_binary(estimator, X[ind[cond]], y, classes=[i, j])
311+
return _fit_binary(estimator, X[ind[cond]], y_binary, classes=[i, j])
311312

312313

313314
def fit_ovo(estimator, X, y, n_jobs=1):

sklearn/tests/test_multiclass.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,17 @@ def test_ovo_ties2():
312312
assert_equal(ovo_prediction[0], (1 + i) % 3)
313313

314314

315+
def test_ovo_string_y():
316+
"Test that the OvO doesn't screw the encoding of string labels"
317+
X = np.eye(4)
318+
y = np.array(['a', 'b', 'c', 'd'])
319+
320+
svc = LinearSVC()
321+
ovo = OneVsOneClassifier(svc)
322+
ovo.fit(X, y)
323+
assert_array_equal(y, ovo.predict(X))
324+
325+
315326
def test_ecoc_exceptions():
316327
ecoc = OutputCodeClassifier(LinearSVC(random_state=0))
317328
assert_raises(ValueError, ecoc.predict, [])

0 commit comments

Comments
 (0)