Skip to content

Commit e7160a5

Browse files
committed
Refine balanced_accuracy_score on 2d array
1 parent de92e87 commit e7160a5

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

sklearn/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .ranking import roc_curve
1212

1313
from .classification import accuracy_score
14+
from .classification import balanced_accuracy_score
1415
from .classification import classification_report
1516
from .classification import confusion_matrix
1617
from .classification import f1_score

sklearn/metrics/classification.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,16 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
170170

171171

172172
def _1d_balanced_accuracy_score(y_true, y_pred, sample_weight=None):
    """Compute the balanced accuracy of one 1-d label vector.

    Balanced accuracy is the mean of sensitivity (recall on the positive
    class) and specificity (recall on the negative class); it corrects
    plain accuracy for class imbalance.

    Parameters
    ----------
    y_true : array-like of 0/1 ground-truth labels.
    y_pred : array-like of 0/1 predicted labels.
    sample_weight : array-like of per-sample weights, optional.
        Defaults to uniform weights.

    Returns
    -------
    float
        ``(sensitivity + specificity) / 2``.

    Raises
    ------
    ValueError
        For any target type other than binary.
    """
    # Only support binary classification for now.
    y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)
    if y_type == 'binary':
        # Element-wise correctness mask; averaged per class below.
        score = y_true == y_pred
    else:
        raise ValueError("%s is not yet implemented" % y_type)

    # Positive and negative index in y_true ([0] unwraps np.where's tuple).
    p_idx = np.where(y_true == 1)[0]
    n_idx = np.where(y_true == 0)[0]

    if sample_weight is None:
        sample_weight = np.ones(y_true.shape[0])

    # NOTE(review): original source lines 186-187 are not visible in this
    # diff view — confirm nothing else happens between the weight default
    # and the sensitivity computation.

    # A class absent from y_true contributes a perfect per-class score,
    # so a single-class vector degrades gracefully instead of dividing
    # by a zero weight sum.
    if len(p_idx) == 0:
        sensitive = 1
    else:
        sensitive = np.average(score[p_idx], weights=sample_weight[p_idx])

    if len(n_idx) == 0:
        specificity = 1
    else:
        # BUG FIX: weight the negative samples by THEIR OWN weights.
        # The committed code used sample_weight[p_idx] here, which pairs
        # negative-sample scores with positive-sample weights (and raises
        # when the class counts differ).
        specificity = np.average(score[n_idx], weights=sample_weight[n_idx])

    score = (sensitive + specificity) / 2

    # NOTE(review): the visible hunk ends at the line above; the return is
    # presumed from the caller, which averages this helper's results.
    return score
199199

def balanced_accuracy_score(y_true, y_pred, sample_weight=None):
    """Balanced accuracy classification score.

    The balanced accuracy is the average of sensitivity and specificity.
    For 2-d (multilabel indicator) input, each row is scored with the 1-d
    helper and the row scores are averaged.

    Parameters
    ----------
    y_true : 1-d or 2-d array-like of 0/1 ground-truth labels.
    y_pred : 1-d or 2-d array-like of 0/1 predicted labels.
    sample_weight : array-like of per-sample weights, optional.

    Returns
    -------
    float
        Mean per-row balanced accuracy.

    Examples
    --------
    >>> y_pred = [0, 1, 1, 1]
    >>> y_true = [0, 1, 0, 1]
    >>> balanced_accuracy_score(y_true, y_pred)
    0.75

    In the multilabel case with binary label indicators:

    >>> balanced_accuracy_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
    0.75
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Promote 1-d targets to a single-row 2-d array so the 1-d and
    # multilabel cases share one code path.
    if y_true.ndim == 1:
        y_true = y_true.reshape(1, -1)
        y_pred = y_pred.reshape(1, -1)

    # BUG FIX: materialize the per-row scores with a comprehension instead
    # of map(lambda ...).  np.mean of a lazy map object fails on Python 3
    # (it becomes a 0-d object array); a list works on both 2 and 3 and
    # reads more clearly than map + lambda over zip.
    scores = [_1d_balanced_accuracy_score(row_true, row_pred,
                                          sample_weight=sample_weight)
              for row_true, row_pred in zip(y_true, y_pred)]

    return np.mean(scores)
253261

sklearn/metrics/metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .ranking import roc_curve
1313

1414
from .classification import accuracy_score
15+
from .classification import balanced_accuracy_score
1516
from .classification import classification_report
1617
from .classification import confusion_matrix
1718
from .classification import f1_score

0 commit comments

Comments
 (0)