3232#
3333# License: BSD 3 clause
3434
35+ import array
3536import numpy as np
3637import warnings
3738import scipy .sparse as sp
@@ -91,11 +92,11 @@ def fit_ovr(estimator, X, y, n_jobs=1):
9192
9293 if sp .issparse (Y ):
9394 estimators = Parallel (n_jobs = n_jobs )(
94- delayed (_fit_binary )(estimator , X , Y .getcol (i ).toarray (),
95+ delayed (_fit_binary )(estimator , X , Y .getcol (i ).toarray (),
9596 classes = ["not %s" % i , i ]) for i in range (Y .shape [1 ]))
9697 else :
9798 estimators = Parallel (n_jobs = n_jobs )(
98- delayed (_fit_binary )(estimator , X , Y [:, i ],
99+ delayed (_fit_binary )(estimator , X , Y [:, i ],
99100 classes = ["not %s" % i , i ]) for i in range (Y .shape [1 ]))
100101
101102 return estimators , lb
@@ -114,17 +115,19 @@ def predict_ovr(estimators, label_binarizer, X):
114115 for e in estimators :
115116 x_scores = np .append (x_scores , _predict_binary (e , x ))
116117 c = label_binarizer .classes_ [x_scores .argmax ()]
117- Y = np .append (Y ,c )
118+ Y = np .append (Y , c )
118119 return np .array (Y .T , dtype = label_binarizer .classes_ .dtype )
119120
120121 else :
121- Y = sp .coo_matrix (np .array (_predict_binary (e , X ) > thresh ,
122- dtype = np .int ))
123- for e in estimators [1 :]:
124- r = sp .coo_matrix (np .array (_predict_binary (e , X ) > thresh ,
125- dtype = np .int ))
126- Y = sp .vstack ([Y , r ])
127- return label_binarizer .inverse_transform (Y .T , threshold = 0.5 )
122+ indices = array .array ('i' )
123+ indptr = array .array ('i' , [0 ])
124+ for e in estimators :
125+ indices .extend (np .where (_predict_binary (e , X ) > thresh )[0 ])
126+ indptr .append (len (indices ))
127+ data = np .ones (len (indices ), dtype = int )
128+ indicator = sp .csc_matrix ((data , indices , indptr ),
129+ shape = (len (X ), len (estimators )))
130+ return label_binarizer .inverse_transform (indicator )
128131
129132
130133def predict_proba_ovr (estimators , X , is_multilabel ):
0 commit comments