Wrong multi-class label requirement for the compute_roc_auc function #3095
Closed
Labels: enhancement
Describe the bug
monai.metrics.compute_roc_auc rejects multi-class labels that are not one-hot encoded, for example class indices of shape [N, 1].
To Reproduce
from monai.metrics import compute_roc_auc
import torch
y = torch.randint(4, size=(10,1))
y_pred = torch.randint(4, size=(10,4))
y_pred = y_pred/y_pred.sum(dim=1).view(-1,1)
compute_roc_auc(y_pred, y)
Expected behavior
The documentation says that a label of shape [16, 1] will be converted into [16, 2] (where 2 is inferred from y_pred). Either the function should convert the label into one-hot format internally, or the documentation should explicitly require one-hot labels.
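To make the expected conversion concrete, here is a minimal sketch of turning class-index labels into one-hot labels with the class count inferred from y_pred. The helper name labels_to_one_hot is hypothetical and is not part of MONAI; it uses torch.nn.functional.one_hot rather than MONAI's own utility.

```python
import torch
import torch.nn.functional as F


def labels_to_one_hot(y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map [N, 1] or [N] class indices to [N, C] one-hot."""
    num_classes = y_pred.shape[1]   # C is inferred from the prediction tensor
    indices = y.reshape(-1).long()  # [N, 1] -> [N]; F.one_hot needs int64 indices
    return F.one_hot(indices, num_classes=num_classes).to(y_pred.dtype)


# Four samples, four classes; labels are given as class indices.
y = torch.tensor([[0], [3], [1], [2]])
y_pred = torch.full((4, 4), 0.25)
y_one_hot = labels_to_one_hot(y, y_pred)
print(y_one_hot.shape)
```

Each row of y_one_hot has exactly one nonzero entry, located at the column given by the original class index.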
Additional context
Potential fix for converting the label internally:
Add the import
from monai.networks.utils import one_hot
and then the check
if y_pred_ndim == 2 and y.ndimension() == 1:
    y = one_hot(y, num_classes=y_pred.shape[1])
after
Lines 137 to 141 in 1ecf5b6
if y_pred_ndim == 2 and y_pred.shape[1] == 1:
    y_pred = y_pred.squeeze(dim=-1)
    y_pred_ndim = 1
if y_ndim == 2 and y.shape[1] == 1:
    y = y.squeeze(dim=-1)
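The following sketch shows how the quoted shape handling and the proposed check might fit together. It is a standalone illustration, not MONAI's implementation: normalize_roc_inputs is a hypothetical function name, and torch.nn.functional.one_hot stands in for monai.networks.utils.one_hot so the example depends only on PyTorch.

```python
import torch
import torch.nn.functional as F


def normalize_roc_inputs(y_pred: torch.Tensor, y: torch.Tensor):
    """Hypothetical stand-in for the shape handling around the quoted lines."""
    y_pred_ndim, y_ndim = y_pred.ndimension(), y.ndimension()
    # Existing behavior quoted in the issue: drop a trailing singleton dimension.
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)
    # Proposed addition: multi-class predictions with class-index labels.
    if y_pred_ndim == 2 and y.ndimension() == 1:
        y = F.one_hot(y.long(), num_classes=y_pred.shape[1])
    return y_pred, y


# Multi-class case from the report: labels [10, 1], predictions [10, 4].
mc_pred, mc_y = normalize_roc_inputs(torch.rand(10, 4), torch.randint(4, size=(10, 1)))
print(mc_pred.shape, mc_y.shape)

# Binary case: both tensors collapse to 1-D and no one-hot encoding is applied.
bin_pred, bin_y = normalize_roc_inputs(torch.rand(10, 1), torch.randint(2, size=(10, 1)))
print(bin_pred.shape, bin_y.shape)
```

Placing the new check after the squeeze steps means binary inputs are left untouched, since y_pred_ndim has already been reduced to 1 by the time the check runs.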