Skip to content

Commit 10f23b1

Browse files
author
Hamzeh Alsalhi
committed
Implemented construction of csc_matrix by column indices in predict_ovr
1 parent 897a3ba commit 10f23b1

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

sklearn/multiclass.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
#
3333
# License: BSD 3 clause
3434

35+
import array
3536
import numpy as np
3637
import warnings
3738
import scipy.sparse as sp
@@ -91,11 +92,11 @@ def fit_ovr(estimator, X, y, n_jobs=1):
9192

9293
if sp.issparse(Y):
9394
estimators = Parallel(n_jobs=n_jobs)(
94-
delayed(_fit_binary)(estimator, X, Y.getcol(i).toarray(),
95+
delayed(_fit_binary)(estimator, X, Y.getcol(i).toarray(),
9596
classes=["not %s" % i, i]) for i in range(Y.shape[1]))
9697
else:
9798
estimators = Parallel(n_jobs=n_jobs)(
98-
delayed(_fit_binary)(estimator, X, Y[:, i],
99+
delayed(_fit_binary)(estimator, X, Y[:, i],
99100
classes=["not %s" % i, i]) for i in range(Y.shape[1]))
100101

101102
return estimators, lb
@@ -114,17 +115,19 @@ def predict_ovr(estimators, label_binarizer, X):
114115
for e in estimators:
115116
x_scores = np.append(x_scores, _predict_binary(e, x))
116117
c = label_binarizer.classes_[x_scores.argmax()]
117-
Y = np.append(Y,c)
118+
Y = np.append(Y, c)
118119
return np.array(Y.T, dtype=label_binarizer.classes_.dtype)
119120

120121
else:
121-
Y = sp.coo_matrix(np.array(_predict_binary(e, X) > thresh,
122-
dtype=np.int))
123-
for e in estimators[1:]:
124-
r = sp.coo_matrix(np.array(_predict_binary(e, X) > thresh,
125-
dtype=np.int))
126-
Y = sp.vstack([Y, r])
127-
return label_binarizer.inverse_transform(Y.T, threshold=0.5)
122+
indices = array.array('i')
123+
indptr = array.array('i', [0])
124+
for e in estimators:
125+
indices.extend(np.where(_predict_binary(e, X) > thresh)[0])
126+
indptr.append(len(indices))
127+
data = np.ones(len(indices), dtype=int)
128+
indicator = sp.csc_matrix((data, indices, indptr),
129+
shape=(len(X), len(estimators)))
130+
return label_binarizer.inverse_transform(indicator)
128131

129132

130133
def predict_proba_ovr(estimators, X, is_multilabel):

0 commit comments

Comments
 (0)