Skip to content

Commit 0979dd0

Browse files
hamzehhamzeh
authored andcommitted
Enforced correct dtype in predict_ovr
1 parent 643b9b5 commit 0979dd0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

sklearn/multiclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def predict_ovr(estimators, label_binarizer, X):
115115
x_scores = np.append(x_scores, _predict_binary(e, x))
116116
c = label_binarizer.classes_[x_scores.argmax()]
117117
Y = np.append(Y,c)
118-
return Y.T
118+
return np.array(Y.T, dtype=label_binarizer.classes_.dtype)
119119

120120
else:
121121
Y = sp.coo_matrix(np.array(_predict_binary(e, X) > thresh,

0 commit comments

Comments
 (0)