Skip to content

Commit f86ed77

Browse files
committed
Merge pull request #2656 from arjoly/np1.8-warnings
[MRG] FIX silence numpy 1.8 warning for using non integer
2 parents 833a1f3 + 20694a8 commit f86ed77

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

sklearn/metrics/metrics.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -437,29 +437,28 @@ def _average_binary_score(binary_metric, y_true, y_score, average):
437437
y_true = y_true.ravel()
438438
y_score = y_score.ravel()
439439

440+
if average == 'weighted':
441+
weights = np.sum(y_true, axis=0)
442+
if weights.sum() == 0:
443+
return 0
444+
else:
445+
weights = None
446+
440447
if y_true.ndim == 1:
441448
y_true = y_true.reshape((-1, 1))
442449

443450
if y_score.ndim == 1:
444451
y_score = y_score.reshape((-1, 1))
445452

446-
average_axis = 1 if average == 'samples' else 0
447-
n_classes = y_score.shape[not average_axis]
453+
not_average_axis = 0 if average == 'samples' else 1
454+
n_classes = y_score.shape[not_average_axis]
448455
score = np.zeros((n_classes,))
449-
450456
for c in range(n_classes):
451-
y_true_c = y_true.take([c], axis=not average_axis).ravel()
452-
y_score_c = y_score.take([c], axis=not average_axis).ravel()
457+
y_true_c = y_true.take([c], axis=not_average_axis).ravel()
458+
y_score_c = y_score.take([c], axis=not_average_axis).ravel()
453459
score[c] = binary_metric(y_true_c, y_score_c)
454460

455461
# Average the results
456-
if average == 'weighted':
457-
weights = np.sum(y_true, axis=average_axis)
458-
if weights.sum() == 0:
459-
return 0
460-
else:
461-
weights = None
462-
463462
if average is not None:
464463
return np.average(score, weights=weights)
465464
else:

0 commit comments

Comments
 (0)